In [1]:
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets
from matplotlib import pyplot as plt

# Run on the first GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Seed torch's RNGs (CPU and all CUDA devices) so weight initialisation, data
# shuffling and the VAE's latent sampling repeat across runs. Bit-exact GPU
# determinism may additionally require torch.use_deterministic_algorithms(True).
SEED = 0
torch.manual_seed(SEED)
print(device)
cuda:0
In [2]:
# Convert vector to image
def to_img(x):
    x = 0.5 * (x + 1)
    x = x.view(x.size(0), 28, 28)
    return x


# Display images, n=max number of images to show
def display_images(in_raw, out_raw, n=1):
    out_raw = out_raw[:n]
    if in_raw is not None:
        in_raw = in_raw[:n]
        in_pic = to_img(in_raw.cpu().data).view(-1, 28, 28)
        plt.figure(figsize=(18, 6))
        for i in range(n):
            plt.subplot(1,n,i+1)
            plt.imshow(in_pic[i])
            plt.axis('off')
    out_pic = to_img(out_raw.cpu().data).view(-1, 28, 28)
    plt.figure(figsize=(18, 6))
    for i in range(n):   
        plt.subplot(1,n,i+1)
        plt.imshow(out_pic[i])
        plt.axis('off')
    plt.show()

Data Loaders

In [3]:
# Define data loading step

img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True, transform=img_transform),
    batch_size=256, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=img_transform),
    batch_size=32, shuffle=True)
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz
Extracting ../data/MNIST/raw/train-images-idx3-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting ../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting ../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw

/opt/conda/lib/python3.7/site-packages/torchvision/datasets/mnist.py:498: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at  /usr/local/src/pytorch/torch/csrc/utils/tensor_numpy.cpp:180.)
  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)

The VAE model

In [4]:
class VAE(nn.Module):
    def __init__(self, d=50):
        super().__init__()
        self.d = d #latent dimension
        
        self.encoder = nn.Sequential(
            nn.Linear(784, d ** 2),
            nn.ReLU(),
            nn.Linear(d ** 2, d * 2) # we have mean and variance, each is d-dim vector
        )

        self.decoder = nn.Sequential(
            nn.Linear(d, d ** 2),
            nn.ReLU(),
            nn.Linear(d ** 2, 784),
            nn.Tanh()
        )

    def sampler(self, mu, logvar):
        if self.training:
            std = logvar.mul(0.5).exp_()
            eps = std.new_empty(std.size()).normal_()
            return eps.mul_(std).add_(mu)
        else:
            return mu

    def forward(self, x):
        mu_logvar = self.encoder(x.view(-1, 784)).view(-1, 2, self.d)
        mu = mu_logvar[:, 0, :]
        logvar = mu_logvar[:, 1, :]
        z = self.sampler(mu, logvar)
        return self.decoder(z), mu, logvar
    
    def generate(self, N=10):
        z = torch.randn((N, self.d)).to(device)
        gen_img = self.decoder(z)
        return gen_img

The loss function (negative ELBO: reconstruction error plus β-weighted KL divergence)

In [5]:
def loss_function(x_hat, x, mu, logvar, beta=1):
    #recon_loss = nn.functional.binary_cross_entropy(
    recon_loss = nn.functional.mse_loss(
        x_hat, x.view(-1, 784), reduction='sum'
    )
    KLD = 0.5 * torch.sum(logvar.exp() - logvar - 1 + mu.pow(2))

    return recon_loss + beta * KLD

Train the VAE

In [6]:
# Train the VAE for `epochs` epochs and evaluate it on the test set after each one.
# Epoch 0 only evaluates the untrained network to give a baseline.
# Every name bound here is a notebook global; later cells may read model, optimizer,
# codes, and the last test batch's x / y / x_hat / mu / logvar.
latent_dim=20  # latent size d; the VAE's hidden layers use d ** 2 units
model = VAE(latent_dim).to(device)

# Setting the optimiser
learning_rate = 1e-3
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=learning_rate,
)

epochs = 50
# Per-epoch snapshots of the test-set latent statistics. codes['mu'][e] and
# codes['logσ2'][e] hold the concatenated encoder means / log-variances after
# epoch e; codes['y'][e] holds the matching digit labels, in the same order.
# NOTE(review): mu / logvar stay on `device` (never moved to CPU), while the labels
# stay on the CPU. Move them with .cpu() before any NumPy-based analysis.
codes = dict(mu=list(), logσ2=list(), y=list())
for epoch in range(0, epochs + 1):
    # Training
    if epoch > 0:  # test untrained net first
        model.train()  # enables sampling noise in VAE.sampler
        train_loss = 0
        for x, _ in train_loader:
            x = x.to(device)
            # ===================forward=====================
            x_hat, mu, logvar = model(x)
            loss = loss_function(x_hat, x, mu, logvar)
            train_loss += loss.item()
            # ===================backward====================
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # ===================log========================
        # loss_function sums over the batch, so dividing by the dataset size gives a per-sample loss.
        print(f'====> Epoch: {epoch} Average loss: {train_loss / len(train_loader.dataset):.4f}')
    
    # Testing
    # In eval mode VAE.sampler returns mu without noise. The test loss therefore uses
    # deterministic codes while the training loss uses sampled ones, so the two
    # numbers are not directly comparable.
    means, logvars, labels = list(), list(), list()
    with torch.no_grad():
        model.eval()
        test_loss = 0
        for x, y in test_loader:
            x = x.to(device)
            # ===================forward=====================
            x_hat, mu, logvar = model(x)
            test_loss += loss_function(x_hat, x, mu, logvar).item()
            # =====================log=======================
            means.append(mu.detach())
            logvars.append(logvar.detach())
            labels.append(y.detach())
    # ===================log========================
    codes['mu'].append(torch.cat(means))
    codes['logσ2'].append(torch.cat(logvars))
    codes['y'].append(torch.cat(labels))
    test_loss /= len(test_loader.dataset)  # per-sample test loss
    print(f'====> Test set loss: {test_loss:.4f}')
    # Show the first 4 inputs of the last test batch and their reconstructions.
    display_images(x, x_hat, 4)
====> Test set loss: 727.2871
====> Epoch: 1 Average loss: 145.4837
====> Test set loss: 86.3820
====> Epoch: 2 Average loss: 89.2477
====> Test set loss: 73.7199
====> Epoch: 3 Average loss: 81.6282
====> Test set loss: 69.8812
====> Epoch: 4 Average loss: 78.1212
====> Test set loss: 66.7671
====> Epoch: 5 Average loss: 75.8147
====> Test set loss: 65.5077
====> Epoch: 6 Average loss: 74.2920
====> Test set loss: 64.0387
====> Epoch: 7 Average loss: 73.1301
====> Test set loss: 62.7531
====> Epoch: 8 Average loss: 72.0988
====> Test set loss: 62.3729
====> Epoch: 9 Average loss: 71.3943
====> Test set loss: 62.8740
====> Epoch: 10 Average loss: 70.8130
====> Test set loss: 61.7129
====> Epoch: 11 Average loss: 70.2867
====> Test set loss: 59.5818
====> Epoch: 12 Average loss: 69.8290
====> Test set loss: 61.0419
====> Epoch: 13 Average loss: 69.4426
====> Test set loss: 60.6360
====> Epoch: 14 Average loss: 69.0222
====> Test set loss: 59.4542
====> Epoch: 15 Average loss: 68.7174
====> Test set loss: 59.3274
====> Epoch: 16 Average loss: 68.4701
====> Test set loss: 59.2792
====> Epoch: 17 Average loss: 68.1310
====> Test set loss: 58.7359
====> Epoch: 18 Average loss: 67.9398
====> Test set loss: 58.7250
====> Epoch: 19 Average loss: 67.6892
====> Test set loss: 58.3589
====> Epoch: 20 Average loss: 67.4666
====> Test set loss: 58.7632
====> Epoch: 21 Average loss: 67.2576
====> Test set loss: 58.1412
====> Epoch: 22 Average loss: 67.0259
====> Test set loss: 58.1059
====> Epoch: 23 Average loss: 66.9060
====> Test set loss: 57.5996
====> Epoch: 24 Average loss: 66.7129
====> Test set loss: 57.8446
====> Epoch: 25 Average loss: 66.5621
====> Test set loss: 57.6108
====> Epoch: 26 Average loss: 66.4718
====> Test set loss: 57.5028
====> Epoch: 27 Average loss: 66.1706
====> Test set loss: 57.2792
====> Epoch: 28 Average loss: 66.1107
====> Test set loss: 57.6444
====> Epoch: 29 Average loss: 65.9602
====> Test set loss: 56.7419
====> Epoch: 30 Average loss: 65.8052
====> Test set loss: 56.9834
====> Epoch: 31 Average loss: 65.6210
====> Test set loss: 57.2466
====> Epoch: 32 Average loss: 65.6273
====> Test set loss: 57.2458
====> Epoch: 33 Average loss: 65.4603
====> Test set loss: 56.0537
====> Epoch: 34 Average loss: 65.4017
====> Test set loss: 56.6860
====> Epoch: 35 Average loss: 65.2697
====> Test set loss: 56.6800
====> Epoch: 36 Average loss: 65.1989
====> Test set loss: 56.6645
====> Epoch: 37 Average loss: 65.0536
====> Test set loss: 56.6439
====> Epoch: 38 Average loss: 64.9384
====> Test set loss: 56.0820
====> Epoch: 39 Average loss: 64.8608
====> Test set loss: 56.6131
====> Epoch: 40 Average loss: 64.7732
====> Test set loss: 55.6812
====> Epoch: 41 Average loss: 64.6911
====> Test set loss: 56.7094
====> Epoch: 42 Average loss: 64.5550
====> Test set loss: 55.8501
====> Epoch: 43 Average loss: 64.4979
====> Test set loss: 56.9949
====> Epoch: 44 Average loss: 64.4607
====> Test set loss: 55.7420
====> Epoch: 45 Average loss: 64.3894
====> Test set loss: 55.1894
====> Epoch: 46 Average loss: 64.3180
====> Test set loss: 55.6339
====> Epoch: 47 Average loss: 64.2206
====> Test set loss: 55.3490
====> Epoch: 48 Average loss: 64.1760
====> Test set loss: 55.7189