# Based on https://github.com/Atcold/pytorch-Deep-Learning/blob/master/11-VAE.ipynb
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets
from matplotlib import pyplot as plt
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
# Convert a batch of flat vectors in [-1, 1] to 28x28 images in [0, 1]
def to_img(x):
    x = 0.5 * (x + 1)
    x = x.view(x.size(0), 28, 28)
    return x
# Display images, n = max number of images to show
def display_images(in_raw, out_raw, n=1):
    out_raw = out_raw[:n]
    if in_raw is not None:
        in_raw = in_raw[:n]
        in_pic = to_img(in_raw.cpu().data).view(-1, 28, 28)
        plt.figure(figsize=(18, 6))
        for i in range(n):
            plt.subplot(1, n, i + 1)
            plt.imshow(in_pic[i])
            plt.axis('off')
    out_pic = to_img(out_raw.cpu().data).view(-1, 28, 28)
    plt.figure(figsize=(18, 6))
    for i in range(n):
        plt.subplot(1, n, i + 1)
        plt.imshow(out_pic[i])
        plt.axis('off')
    plt.show()
# Define the data loading step
img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_loader = DataLoader(
    datasets.MNIST('../data', train=True, download=True, transform=img_transform),
    batch_size=256, shuffle=True)
test_loader = DataLoader(
    datasets.MNIST('../data', train=False, transform=img_transform),
    batch_size=32, shuffle=True)
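# Optional sanity check, not part of the original notebook: Normalize((0.5,), (0.5,))
# maps MNIST pixels from [0, 1] to [-1, 1], which is why the decoder below ends with
# nn.Tanh(). Peeking at one batch confirms the shapes and the value range
# (x_demo / y_demo are throwaway names used only for this check).
x_demo, y_demo = next(iter(train_loader))
print(x_demo.shape, x_demo.min().item(), x_demo.max().item())  # expect torch.Size([256, 1, 28, 28]) and values in [-1, 1]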
class VAE(nn.Module):
    def __init__(self, d=50):
        super().__init__()
        self.d = d  # latent dimension
        self.encoder = nn.Sequential(
            nn.Linear(784, d ** 2),
            nn.ReLU(),
            nn.Linear(d ** 2, d * 2)  # outputs mean and log-variance, each a d-dim vector
        )
        self.decoder = nn.Sequential(
            nn.Linear(d, d ** 2),
            nn.ReLU(),
            nn.Linear(d ** 2, 784),
            nn.Tanh()  # output in [-1, 1], matching the normalised input range
        )

    def sampler(self, mu, logvar):
        # Reparameterisation trick: z = mu + sigma * eps with eps ~ N(0, I),
        # so the sampling step stays differentiable w.r.t. mu and logvar
        if self.training:
            std = logvar.mul(0.5).exp_()
            eps = torch.randn_like(std)
            return eps.mul(std).add(mu)
        else:
            # At evaluation time, just return the posterior mean
            return mu

    def forward(self, x):
        mu_logvar = self.encoder(x.view(-1, 784)).view(-1, 2, self.d)
        mu = mu_logvar[:, 0, :]
        logvar = mu_logvar[:, 1, :]
        z = self.sampler(mu, logvar)
        return self.decoder(z), mu, logvar

    def generate(self, N=10):
        # Sample N latent codes from the prior N(0, I) and decode them
        z = torch.randn((N, self.d), device=device)
        gen_img = self.decoder(z)
        return gen_img
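# A minimal shape check, assuming the class above (illustrative only, not from the
# original notebook): a batch of 8 images maps to an (8, 784) reconstruction and
# (8, d) mu / logvar tensors, and generate(N) returns (N, 784) flat images.
_vae_check = VAE(d=20).to(device)
_x_check = torch.randn(8, 1, 28, 28, device=device)
_x_hat_check, _mu_check, _logvar_check = _vae_check(_x_check)
print(_x_hat_check.shape, _mu_check.shape, _logvar_check.shape)  # (8, 784), (8, 20), (8, 20)
print(_vae_check.generate(4).shape)  # (4, 784)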
def loss_function(x_hat, x, mu, logvar, beta=1):
    # Reconstruction term; binary cross-entropy is a common alternative when the
    # decoder output is in [0, 1]:
    # recon_loss = nn.functional.binary_cross_entropy(x_hat, x.view(-1, 784), reduction='sum')
    recon_loss = nn.functional.mse_loss(
        x_hat, x.view(-1, 784), reduction='sum'
    )
    # KL divergence between q(z|x) = N(mu, exp(logvar)) and the prior N(0, I)
    KLD = 0.5 * torch.sum(logvar.exp() - logvar - 1 + mu.pow(2))
    return recon_loss + beta * KLD
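# Quick check of the KL term (a sketch, not from the original notebook): with
# mu = 0 and logvar = 0 the approximate posterior equals the prior N(0, I), so the
# KL part vanishes and the loss reduces to the reconstruction error alone.
_mu0 = torch.zeros(4, 20)
_logvar0 = torch.zeros(4, 20)
print(loss_function(torch.zeros(4, 784), torch.zeros(4, 1, 28, 28), _mu0, _logvar0))  # tensor(0.)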
latent_dim = 20
model = VAE(latent_dim).to(device)

# Setting the optimiser
learning_rate = 1e-3
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=learning_rate,
)

epochs = 50
codes = dict(mu=list(), logσ2=list(), y=list())
for epoch in range(0, epochs + 1):
    # Training
    if epoch > 0:  # test the untrained net first
        model.train()
        train_loss = 0
        for x, _ in train_loader:
            x = x.to(device)
            # ===================forward=====================
            x_hat, mu, logvar = model(x)
            loss = loss_function(x_hat, x, mu, logvar)
            train_loss += loss.item()
            # ===================backward====================
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # ===================log========================
        print(f'====> Epoch: {epoch} Average loss: {train_loss / len(train_loader.dataset):.4f}')

    # Testing
    means, logvars, labels = list(), list(), list()
    with torch.no_grad():
        model.eval()
        test_loss = 0
        for x, y in test_loader:
            x = x.to(device)
            # ===================forward=====================
            x_hat, mu, logvar = model(x)
            test_loss += loss_function(x_hat, x, mu, logvar).item()
            # =====================log=======================
            means.append(mu.detach())
            logvars.append(logvar.detach())
            labels.append(y.detach())
    # ===================log========================
    codes['mu'].append(torch.cat(means))
    codes['logσ2'].append(torch.cat(logvars))
    codes['y'].append(torch.cat(labels))
    test_loss /= len(test_loader.dataset)
    print(f'====> Test set loss: {test_loss:.4f}')
    display_images(x, x_hat, 4)
# Generate N new images by decoding latent codes sampled from the prior
N = 6
display_images(None, model.generate(N), N)
# Encode a test batch and pick two images, A and B
x, _ = next(iter(test_loader))
x = x.to(device)
x_hat, mu, logvar = model(x)

A, B = 1, 14
sample = model.decoder(torch.stack((mu[A].data, mu[B].data), 0))
# sample = torch.stack((x_hat[A].data, x_hat[B].data), 0)

# Show the original inputs A and B, then their reconstructions from the posterior means
display_images(None, torch.stack((x[A].data.view(-1), x[B].data.view(-1)), 0), 2)
display_images(None, torch.stack((sample.data[0], sample.data[1]), 0), 2)
# Perform an interpolation between inputs A and B in the latent space, in N steps
N = 20
code = torch.empty(N, latent_dim, device=device)
sample = torch.empty(N, 28, 28, device=device)  # only needed for the pixel-space alternative below
for i in range(N):
    code[i] = i / (N - 1) * mu[B].data + (1 - i / (N - 1)) * mu[A].data
    # sample[i] = i / (N - 1) * x[B].data + (1 - i / (N - 1)) * x[A].data
img_set = model.decoder(code)
for i in range(N // 4):
    display_images(None, img_set[i * 4:4 * (i + 1)], 4)