%load_ext autoreload
%autoreload 2
%env CUDA_VISIBLE_DEVICES=0
import torch as t
class MNIST(t.utils.data.Dataset):
    """
    In-memory MNIST dataset (template, not implemented yet).

    The class body consists of this docstring only, so it currently adds
    nothing on top of `t.utils.data.Dataset`.

    NOTE(review): presumably indexing should yield a single image tensor,
    since the sanity check in this file calls `data[0].shape` — confirm once
    the dataset is implemented.
    """
Z_DIM = ... # TODO: define latent variable dimensionality (currently an Ellipsis placeholder)
BATCH_SIZE = ... # TODO: define batch size (currently an Ellipsis placeholder)
def conv_block(in_features: int, out_features: int, kernel: int = 3, stride: int = 1, padding: int = 1):
    """
    Build the classical trio convolution -> normalization -> nonlinearity.

    Args:
        in_features: number of input channels.
        out_features: number of output channels.
        kernel: square kernel size of the convolution.
        stride: stride of the convolution.
        padding: zero padding added to each spatial side of the input.

    Returns:
        A `t.nn.Sequential` of `Conv2d`, `BatchNorm2d` and `ReLU` mapping
        [B, in_features, H, W] to [B, out_features, H', W']. With the default
        kernel/stride/padding the spatial size is preserved.
    """
    return t.nn.Sequential(
        # The convolution bias is dropped: the affine shift of the batch norm
        # that immediately follows makes it redundant.
        t.nn.Conv2d(
            in_features,
            out_features,
            kernel_size=kernel,
            stride=stride,
            padding=padding,
            bias=False,
        ),
        t.nn.BatchNorm2d(out_features),
        t.nn.ReLU(),
    )
# The decoder network must be able to upscale feature maps along the spatial dimensions.
class Upscale(t.nn.Module):
    """
    Upscaling layer: enlarges the spatial dimensions of a [B, C, H, W] tensor.

    Args:
        scale_factor: multiplier applied to both H and W (default 2).
        mode: interpolation algorithm passed to
            `t.nn.functional.interpolate` (default 'nearest').
    """

    def __init__(self, scale_factor: int = 2, mode: str = 'nearest'):
        super().__init__()
        self.scale_factor = scale_factor
        self.mode = mode

    def forward(self, x):
        """Return `x` resized to [B, C, H * scale_factor, W * scale_factor]."""
        return t.nn.functional.interpolate(x, scale_factor=self.scale_factor, mode=self.mode)
def downscale_block(features: int, out_features: int):
    """
    Small module consisting of several convolutions that can be used to downscale the input tensor.

    Template only: the body is a bare `pass`, so the call currently returns None.

    Args:
        features: presumably the number of input channels — TODO confirm once implemented.
        out_features: presumably the number of output channels — TODO confirm once implemented.
    """
    pass
def upscale_block(in_features: int, features: int):
    """
    Small module consisting of several convolutions that can be used to upscale the input tensor.

    Template only: the body is a bare `pass`, so the call currently returns None.

    Args:
        in_features: presumably the number of input channels — TODO confirm once implemented.
        features: presumably the number of output channels — TODO confirm once implemented.
    """
    pass
# Convolutional layers work with [B, C, H, W] feature maps,
# while Linear layers work with [..., C] features. We need layers to convert between these two
# layouts when H = W = 1.
class Squeeze(t.nn.Module):
    """
    Convert x from [B, C, H, W] to [B, C] when H = W = 1.
    """

    def forward(self, x):
        """
        Drop the two trailing singleton spatial dimensions of `x`.

        Raises:
            ValueError: if `x` is not 4-dimensional or H, W are not both 1,
                because flattening would then silently mix spatial positions
                into the channel axis.
        """
        if x.dim() != 4 or tuple(x.shape[-2:]) != (1, 1):
            raise ValueError(f'expected input of shape [B, C, 1, 1], got {tuple(x.shape)}')
        # Flatten from dim 1 instead of calling a bare `squeeze()`: the latter
        # would also remove the batch (or channel) axis when its size is 1.
        return x.flatten(1)
class Unsqueeze(t.nn.Module):
    """
    Convert x from [B, C] to [B, C, 1, 1]
    """

    def forward(self, x):
        """Append two singleton spatial dimensions to a [B, C] tensor."""
        # Indexing with None inserts new axes of size 1 without copying data.
        return x[:, :, None, None]
def get_encoder():
    """
    Encoder defined specifically for 32x32 images. You can implement the general case if you want.

    Template only: the function has no body besides this docstring, so it
    currently returns None.
    """
def get_decoder():
    """
    Decoder defined specifically for 32x32 images. You can implement the general case if you want.

    Template only: the function has no body besides this docstring, so it
    currently returns None.
    """
def reparameterize(mean, log_std):
    """
    Reparameterization trick for normal distribution.

    Draws z ~ N(mean, exp(log_std)^2) as a deterministic function of the
    parameters and an independent noise sample, so gradients can flow back
    into `mean` and `log_std`.

    Args:
        mean: tensor with the means of the distribution.
        log_std: tensor with the logarithms of the standard deviations,
            broadcastable against `mean`.

    Returns:
        A sample computed as `mean + exp(log_std) * eps`, eps ~ N(0, 1).
    """
    # The noise does not depend on the parameters; all randomness lives here.
    eps = t.randn_like(mean)
    return mean + t.exp(log_std) * eps
def kl(mean, log_std):
    """
    KL-divergence between normal distribution and standard normal distribution.

    Computes KL(N(mean, exp(log_std)^2) || N(0, 1)) in closed form:
        0.5 * (std^2 + mean^2 - 1) - log_std

    Args:
        mean: tensor with the means of the distribution.
        log_std: tensor with the logarithms of the standard deviations,
            broadcastable against `mean`.

    Returns:
        A tensor of elementwise divergences (no reduction is applied). For a
        diagonal Gaussian the caller sums it over the latent dimensions.
    """
    variance = t.exp(2 * log_std)
    return 0.5 * (variance + mean ** 2 - 1) - log_std
class VAE(t.nn.Module):
    """
    Encoder-decoder model (variational autoencoder).

    Template only: no `__init__` is defined here and `forward` returns
    `...` placeholders for every value.
    """

    def forward(self, imgs):
        """
        Forward calculation for the VAE.

        Returns a dict with the keys:
            'loss': loss to minimize. NOTE(review): the original text called
                this the "ELBO loss"; presumably the negative ELBO is meant,
                since the ELBO itself is maximized — confirm.
            'log p(x|z)': reconstruction log-likelihood log p(x | z).
            'kl': KL term. NOTE(review): the original text read
                KL(p(z) || q(z | x)); the term that appears in the ELBO is
                KL(q(z | x) || p(z)) — confirm the intended direction.
        """
        return {'loss': ..., 'log p(x|z)': ..., 'kl': ...}
# --- Object construction (placeholders to be filled in) ---------------------
model = ...  # instantiate VAE model
print('model created')
optimizer = ...  # instantiate optimizer
print('optimizer created')
data = ...  # load data
print('data loaded')

# Sanity check: make sure a forward pass works on a single image before training.
model.eval()
print(data[0].shape)
# `[None]` adds the batch axis. The pass runs under `no_grad` because its
# output is discarded, so building the autograd graph would be wasted work.
with t.no_grad():
    model(data[0][None].cuda())
model.train()
# NOTE(review): `wandb` and `tqdm` are not imported in the visible part of this
# file — confirm they are imported before this code runs.
# Alternative: `with wandb.init(project='autoencoder'):` around the loop.
wandb.init(project='autoencoder') # Initialize wandb

# Training loop
for step, imgs in enumerate(tqdm(...)):
    # TODO: pass images through VAE
    # TODO: make optimization step
    # TODO: report training metrics with wandb
    if step % 50 == 0:
        # Every 50 steps switch to eval mode for sampling, then restore
        # train mode afterwards.
        model.eval()
        with t.no_grad():
            # TODO: generate images with VAE by randomly sampling latent variable
            pass  # placeholder: a `with` suite holding only a comment is a SyntaxError
        model.train()