AML21: 05 Larger CNN for CIFAR-10¶

Data and Libraries¶

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms, utils
import matplotlib.pyplot as plt
import numpy

# this 'device' will be used for training our model
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
cuda:0

Load the CIFAR10 dataset¶

Observe that we set shuffle=True for the training loader, so the order of the training examples is randomized anew each epoch.

In [8]:
input_size  = 32*32*3   # images are 32x32 pixels with 3 channels
output_size = 10      # there are 10 classes

train_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                   ])),
    batch_size=64, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('../data', train=False, transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                   ])),
    batch_size=1000, shuffle=True)

classNames = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
Files already downloaded and verified
In [9]:
# show some training images
def imshow(img, plot):
    img = img / 2 + 0.5  # undo Normalize: map [-1, 1] back to [0, 1]
    npimg = img.numpy()  # convert from tensor to numpy array
    plot.imshow(numpy.transpose(npimg, (1, 2, 0)))  # CHW -> HWC for matplotlib
    

plt.figure(figsize=(16,4))

# fetch a random batch of training images (shuffle=True)
image_batch, label_batch = next(iter(train_loader))
#imshow(torchvision.utils.make_grid(image_batch))
for i in range(20):
    image = image_batch[i]
    label = classNames[label_batch[i].item()]
    plt.subplot(2, 10, i + 1)
    imshow(image, plt)
    plt.axis('off')
    plt.title(label)
plt.show()
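
The commented-out make_grid line above hints at an alternative: torchvision can tile the whole batch into a single image tensor. A minimal sketch (not a cell from the original run):

grid = utils.make_grid(image_batch[:20], nrow=10)  # tile 20 images into one (3, H, W) tensor
plt.figure(figsize=(16, 4))
imshow(grid, plt)  # reuse the unnormalize-and-transpose helper from above
plt.axis('off')
plt.show()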

Helper functions for training and testing¶

In [10]:
# function to count the number of parameters of a model
def get_n_params(model):
    return sum(p.nelement() for p in model.parameters())

accuracy_list = []
# train the given model for one epoch (note: uses the globally defined optimizer)
def train(epoch, model):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        # send to device
        data, target = data.to(device), target.to(device)
        
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
            
def test(model):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():  # no gradients are needed for evaluation
        for data, target in test_loader:
            # send to device
            data, target = data.to(device), target.to(device)

            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.max(1, keepdim=True)[1]  # index of the max log-probability
            correct += pred.eq(target.view_as(pred)).cpu().sum().item()

    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    accuracy_list.append(accuracy)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        accuracy))
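
A note on the loss: F.nll_loss expects log-probabilities as input, which is why the models below end in nn.LogSoftmax. A minimal check (a sketch, not a cell from the original notebook) that nll_loss composed with log_softmax matches cross_entropy on raw logits:

logits = torch.randn(4, 10)           # a batch of 4 raw score vectors
targets = torch.tensor([1, 0, 3, 9])  # arbitrary class labels
loss_a = F.nll_loss(F.log_softmax(logits, dim=1), targets)
loss_b = F.cross_entropy(logits, targets)
print(torch.allclose(loss_a, loss_b))  # True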

Defining the Convolutional Neural Network¶

In [11]:
# FROM: https://github.com/boazbk/mltheoryseminar/blob/main/code/hw0/simple_train.ipynb
## 5-Layer CNN for CIFAR
## This is the Myrtle5 network by David Page (https://myrtle.ai/learn/how-to-train-your-resnet-4-architecture/)

class Flatten(nn.Module):
    # collapse (N, C, 1, 1) activations to (N, C); assumes the spatial dims are 1x1
    def forward(self, x): return x.view(x.size(0), x.size(1))

class PrintShape(nn.Module):
    # debugging helper: print an activation's shape and pass it through unchanged
    def forward(self, x):
        print(x.shape)
        return x

def make_myrtle5(c=64, num_classes=10):
    ''' Returns a 5-layer CNN with width parameter c. '''
    return nn.Sequential(
        # Layer 0
        nn.Conv2d(3, c, kernel_size=3, stride=1,
                  padding=1, bias=True),
        nn.BatchNorm2d(c),
        nn.ReLU(),

        # Layer 1
        nn.Conv2d(c, c*2, kernel_size=3,
                  stride=1, padding=1, bias=True),
        nn.BatchNorm2d(c*2),
        nn.ReLU(),
        nn.MaxPool2d(2),

        # Layer 2
        nn.Conv2d(c*2, c*4, kernel_size=3,
                  stride=1, padding=1, bias=True),
        nn.BatchNorm2d(c*4),
        nn.ReLU(),
        nn.MaxPool2d(2),

        # Layer 3
        nn.Conv2d(c*4, c*8, kernel_size=3,
                  stride=1, padding=1, bias=True),
        nn.BatchNorm2d(c*8),
        nn.ReLU(),
        nn.MaxPool2d(2),

        # Layer 4
        nn.MaxPool2d(4),
        Flatten(),
        nn.Linear(c*8, num_classes, bias=True),
        #PrintShape(),
        nn.LogSoftmax(dim=1)
    )
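
Before training, a quick sanity check (a sketch, not a cell from the original run) confirms the output shape and the parameter count printed below. With c=64 the four convolutions contribute 1,792 + 73,856 + 295,168 + 1,180,160 = 1,550,976 parameters, the BatchNorm layers 2*(64+128+256+512) = 1,920, and the final Linear layer 512*10 + 10 = 5,130, for a total of 1,558,026:

model = make_myrtle5()
dummy = torch.zeros(2, 3, 32, 32)  # a dummy batch of two CIFAR-sized RGB images
with torch.no_grad():
    out = model(dummy)
print(out.shape)            # torch.Size([2, 10])
print(get_n_params(model))  # 1558026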
In [12]:
print("Training on ", device)
model_cnn = make_myrtle5()
model_cnn.to(device)
optimizer = optim.SGD(model_cnn.parameters(), lr=0.01, momentum=0.5)
print('Number of parameters: {}'.format(get_n_params(model_cnn)))

for epoch in range(0, 20):
    train(epoch, model_cnn)
    test(model_cnn)
Training on  cuda:0
Number of parameters: 1558026
Train Epoch: 0 [0/50000 (0%)]	Loss: 2.450725
Train Epoch: 0 [6400/50000 (13%)]	Loss: 1.485865
Train Epoch: 0 [12800/50000 (26%)]	Loss: 1.422552
Train Epoch: 0 [19200/50000 (38%)]	Loss: 1.455640
Train Epoch: 0 [25600/50000 (51%)]	Loss: 1.353905
Train Epoch: 0 [32000/50000 (64%)]	Loss: 0.911743
Train Epoch: 0 [38400/50000 (77%)]	Loss: 1.004067
Train Epoch: 0 [44800/50000 (90%)]	Loss: 1.314174

Test set: Average loss: 1.1696, Accuracy: 5825/10000 (58%)

Train Epoch: 1 [0/50000 (0%)]	Loss: 0.959871
Train Epoch: 1 [6400/50000 (13%)]	Loss: 1.119282
Train Epoch: 1 [12800/50000 (26%)]	Loss: 1.360642
Train Epoch: 1 [19200/50000 (38%)]	Loss: 1.117539
Train Epoch: 1 [25600/50000 (51%)]	Loss: 1.070766
Train Epoch: 1 [32000/50000 (64%)]	Loss: 0.779036
Train Epoch: 1 [38400/50000 (77%)]	Loss: 0.792585
Train Epoch: 1 [44800/50000 (90%)]	Loss: 0.868176

Test set: Average loss: 1.2344, Accuracy: 6389/10000 (64%)

Train Epoch: 2 [0/50000 (0%)]	Loss: 1.178679
Train Epoch: 2 [6400/50000 (13%)]	Loss: 0.860170
Train Epoch: 2 [12800/50000 (26%)]	Loss: 0.664973
Train Epoch: 2 [19200/50000 (38%)]	Loss: 0.620556
Train Epoch: 2 [25600/50000 (51%)]	Loss: 0.603280
Train Epoch: 2 [32000/50000 (64%)]	Loss: 0.875738
Train Epoch: 2 [38400/50000 (77%)]	Loss: 0.831436
Train Epoch: 2 [44800/50000 (90%)]	Loss: 0.781533

Test set: Average loss: 0.8584, Accuracy: 7125/10000 (71%)

Train Epoch: 3 [0/50000 (0%)]	Loss: 0.705054
Train Epoch: 3 [6400/50000 (13%)]	Loss: 0.590898
Train Epoch: 3 [12800/50000 (26%)]	Loss: 0.586477
Train Epoch: 3 [19200/50000 (38%)]	Loss: 0.477940
Train Epoch: 3 [25600/50000 (51%)]	Loss: 0.469498
Train Epoch: 3 [32000/50000 (64%)]	Loss: 0.386068
Train Epoch: 3 [38400/50000 (77%)]	Loss: 0.646965
Train Epoch: 3 [44800/50000 (90%)]	Loss: 0.373946

Test set: Average loss: 0.7418, Accuracy: 7561/10000 (76%)

Train Epoch: 4 [0/50000 (0%)]	Loss: 0.519106
Train Epoch: 4 [6400/50000 (13%)]	Loss: 0.405649
Train Epoch: 4 [12800/50000 (26%)]	Loss: 0.228262
Train Epoch: 4 [19200/50000 (38%)]	Loss: 0.195641
Train Epoch: 4 [25600/50000 (51%)]	Loss: 0.265676
Train Epoch: 4 [32000/50000 (64%)]	Loss: 0.260352
Train Epoch: 4 [38400/50000 (77%)]	Loss: 0.289911
Train Epoch: 4 [44800/50000 (90%)]	Loss: 0.396362

Test set: Average loss: 0.6838, Accuracy: 7724/10000 (77%)

Train Epoch: 5 [0/50000 (0%)]	Loss: 0.311261
Train Epoch: 5 [6400/50000 (13%)]	Loss: 0.221794
Train Epoch: 5 [12800/50000 (26%)]	Loss: 0.233842
Train Epoch: 5 [19200/50000 (38%)]	Loss: 0.292379
Train Epoch: 5 [25600/50000 (51%)]	Loss: 0.254823
Train Epoch: 5 [32000/50000 (64%)]	Loss: 0.191599
Train Epoch: 5 [38400/50000 (77%)]	Loss: 0.460513
Train Epoch: 5 [44800/50000 (90%)]	Loss: 0.253229

Test set: Average loss: 0.8381, Accuracy: 7308/10000 (73%)

Train Epoch: 6 [0/50000 (0%)]	Loss: 0.157098
Train Epoch: 6 [6400/50000 (13%)]	Loss: 0.161707
Train Epoch: 6 [12800/50000 (26%)]	Loss: 0.144059
Train Epoch: 6 [19200/50000 (38%)]	Loss: 0.243483
Train Epoch: 6 [25600/50000 (51%)]	Loss: 0.308609
Train Epoch: 6 [32000/50000 (64%)]	Loss: 0.105695
Train Epoch: 6 [38400/50000 (77%)]	Loss: 0.137717
Train Epoch: 6 [44800/50000 (90%)]	Loss: 0.148835

Test set: Average loss: 1.1371, Accuracy: 7055/10000 (71%)

Train Epoch: 7 [0/50000 (0%)]	Loss: 0.473882
Train Epoch: 7 [6400/50000 (13%)]	Loss: 0.059927
Train Epoch: 7 [12800/50000 (26%)]	Loss: 0.182825
Train Epoch: 7 [19200/50000 (38%)]	Loss: 0.060236
Train Epoch: 7 [25600/50000 (51%)]	Loss: 0.092009
Train Epoch: 7 [32000/50000 (64%)]	Loss: 0.163264
Train Epoch: 7 [38400/50000 (77%)]	Loss: 0.121029
Train Epoch: 7 [44800/50000 (90%)]	Loss: 0.081558

Test set: Average loss: 0.6085, Accuracy: 8049/10000 (80%)

Train Epoch: 8 [0/50000 (0%)]	Loss: 0.064724
Train Epoch: 8 [6400/50000 (13%)]	Loss: 0.060822
Train Epoch: 8 [12800/50000 (26%)]	Loss: 0.047309
Train Epoch: 8 [19200/50000 (38%)]	Loss: 0.051049
Train Epoch: 8 [25600/50000 (51%)]	Loss: 0.027325
Train Epoch: 8 [32000/50000 (64%)]	Loss: 0.026601
Train Epoch: 8 [38400/50000 (77%)]	Loss: 0.073324
Train Epoch: 8 [44800/50000 (90%)]	Loss: 0.082999

Test set: Average loss: 0.5223, Accuracy: 8324/10000 (83%)

Train Epoch: 9 [0/50000 (0%)]	Loss: 0.018718
Train Epoch: 9 [6400/50000 (13%)]	Loss: 0.028863
Train Epoch: 9 [12800/50000 (26%)]	Loss: 0.024059
Train Epoch: 9 [19200/50000 (38%)]	Loss: 0.029044
Train Epoch: 9 [25600/50000 (51%)]	Loss: 0.021629
Train Epoch: 9 [32000/50000 (64%)]	Loss: 0.030350
Train Epoch: 9 [38400/50000 (77%)]	Loss: 0.037488
Train Epoch: 9 [44800/50000 (90%)]	Loss: 0.016246

Test set: Average loss: 0.5281, Accuracy: 8346/10000 (83%)

Train Epoch: 10 [0/50000 (0%)]	Loss: 0.016627
Train Epoch: 10 [6400/50000 (13%)]	Loss: 0.014314
Train Epoch: 10 [12800/50000 (26%)]	Loss: 0.011260
Train Epoch: 10 [19200/50000 (38%)]	Loss: 0.028594
Train Epoch: 10 [25600/50000 (51%)]	Loss: 0.011981
Train Epoch: 10 [32000/50000 (64%)]	Loss: 0.017905
Train Epoch: 10 [38400/50000 (77%)]	Loss: 0.008601
Train Epoch: 10 [44800/50000 (90%)]	Loss: 0.014871

Test set: Average loss: 0.5236, Accuracy: 8381/10000 (84%)

Train Epoch: 11 [0/50000 (0%)]	Loss: 0.009581
Train Epoch: 11 [6400/50000 (13%)]	Loss: 0.010017
Train Epoch: 11 [12800/50000 (26%)]	Loss: 0.008317
Train Epoch: 11 [19200/50000 (38%)]	Loss: 0.010488
Train Epoch: 11 [25600/50000 (51%)]	Loss: 0.006826
Train Epoch: 11 [32000/50000 (64%)]	Loss: 0.009112
Train Epoch: 11 [38400/50000 (77%)]	Loss: 0.011370
Train Epoch: 11 [44800/50000 (90%)]	Loss: 0.011394

Test set: Average loss: 0.5207, Accuracy: 8425/10000 (84%)

Train Epoch: 12 [0/50000 (0%)]	Loss: 0.007389
Train Epoch: 12 [6400/50000 (13%)]	Loss: 0.005245
Train Epoch: 12 [12800/50000 (26%)]	Loss: 0.004554
Train Epoch: 12 [19200/50000 (38%)]	Loss: 0.007154
Train Epoch: 12 [25600/50000 (51%)]	Loss: 0.011204
Train Epoch: 12 [32000/50000 (64%)]	Loss: 0.006606
Train Epoch: 12 [38400/50000 (77%)]	Loss: 0.006895
Train Epoch: 12 [44800/50000 (90%)]	Loss: 0.006378

Test set: Average loss: 0.5205, Accuracy: 8415/10000 (84%)

Train Epoch: 13 [0/50000 (0%)]	Loss: 0.007103
Train Epoch: 13 [6400/50000 (13%)]	Loss: 0.004968
Train Epoch: 13 [12800/50000 (26%)]	Loss: 0.006724
Train Epoch: 13 [19200/50000 (38%)]	Loss: 0.009702
Train Epoch: 13 [25600/50000 (51%)]	Loss: 0.004174
Train Epoch: 13 [32000/50000 (64%)]	Loss: 0.007478
Train Epoch: 13 [38400/50000 (77%)]	Loss: 0.006896
Train Epoch: 13 [44800/50000 (90%)]	Loss: 0.007859

Test set: Average loss: 0.5242, Accuracy: 8404/10000 (84%)

Train Epoch: 14 [0/50000 (0%)]	Loss: 0.003437
Train Epoch: 14 [6400/50000 (13%)]	Loss: 0.006451
Train Epoch: 14 [12800/50000 (26%)]	Loss: 0.007584
Train Epoch: 14 [19200/50000 (38%)]	Loss: 0.004709
Train Epoch: 14 [25600/50000 (51%)]	Loss: 0.005323
Train Epoch: 14 [32000/50000 (64%)]	Loss: 0.005898
Train Epoch: 14 [38400/50000 (77%)]	Loss: 0.006661
Train Epoch: 14 [44800/50000 (90%)]	Loss: 0.005114

Test set: Average loss: 0.5321, Accuracy: 8409/10000 (84%)

Train Epoch: 15 [0/50000 (0%)]	Loss: 0.003958
Train Epoch: 15 [6400/50000 (13%)]	Loss: 0.003409
Train Epoch: 15 [12800/50000 (26%)]	Loss: 0.004031
Train Epoch: 15 [19200/50000 (38%)]	Loss: 0.003771
Train Epoch: 15 [25600/50000 (51%)]	Loss: 0.002829
Train Epoch: 15 [32000/50000 (64%)]	Loss: 0.004005
Train Epoch: 15 [38400/50000 (77%)]	Loss: 0.003831
Train Epoch: 15 [44800/50000 (90%)]	Loss: 0.004305

Test set: Average loss: 0.5576, Accuracy: 8383/10000 (84%)

Train Epoch: 16 [0/50000 (0%)]	Loss: 0.006417
Train Epoch: 16 [6400/50000 (13%)]	Loss: 0.007146
Train Epoch: 16 [12800/50000 (26%)]	Loss: 0.004133
Train Epoch: 16 [19200/50000 (38%)]	Loss: 0.005006
Train Epoch: 16 [25600/50000 (51%)]	Loss: 0.003496
Train Epoch: 16 [32000/50000 (64%)]	Loss: 0.005496
Train Epoch: 16 [38400/50000 (77%)]	Loss: 0.004736
Train Epoch: 16 [44800/50000 (90%)]	Loss: 0.003731

Test set: Average loss: 0.5823, Accuracy: 8338/10000 (83%)

Train Epoch: 17 [0/50000 (0%)]	Loss: 0.003170
Train Epoch: 17 [6400/50000 (13%)]	Loss: 0.004042
Train Epoch: 17 [12800/50000 (26%)]	Loss: 0.002870
Train Epoch: 17 [19200/50000 (38%)]	Loss: 0.002797
Train Epoch: 17 [25600/50000 (51%)]	Loss: 0.005355
Train Epoch: 17 [32000/50000 (64%)]	Loss: 0.003527
Train Epoch: 17 [38400/50000 (77%)]	Loss: 0.003828
Train Epoch: 17 [44800/50000 (90%)]	Loss: 0.002579

Test set: Average loss: 0.5366, Accuracy: 8414/10000 (84%)

Train Epoch: 18 [0/50000 (0%)]	Loss: 0.003909
Train Epoch: 18 [6400/50000 (13%)]	Loss: 0.003112
Train Epoch: 18 [12800/50000 (26%)]	Loss: 0.004418
Train Epoch: 18 [19200/50000 (38%)]	Loss: 0.004666
Train Epoch: 18 [25600/50000 (51%)]	Loss: 0.004866
Train Epoch: 18 [32000/50000 (64%)]	Loss: 0.003108
Train Epoch: 18 [38400/50000 (77%)]	Loss: 0.003053
Train Epoch: 18 [44800/50000 (90%)]	Loss: 0.003407

Test set: Average loss: 0.5373, Accuracy: 8440/10000 (84%)

Train Epoch: 19 [0/50000 (0%)]	Loss: 0.002848
Train Epoch: 19 [6400/50000 (13%)]	Loss: 0.002562
Train Epoch: 19 [12800/50000 (26%)]	Loss: 0.002980
Train Epoch: 19 [19200/50000 (38%)]	Loss: 0.003394
Train Epoch: 19 [25600/50000 (51%)]	Loss: 0.002739
Train Epoch: 19 [32000/50000 (64%)]	Loss: 0.003342
Train Epoch: 19 [38400/50000 (77%)]	Loss: 0.002156
Train Epoch: 19 [44800/50000 (90%)]	Loss: 0.003706

Test set: Average loss: 0.5377, Accuracy: 8456/10000 (85%)
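
Since test() appends each epoch's accuracy to accuracy_list, the learning curve can be plotted directly (a sketch; this cell is not part of the original run):

plt.plot(accuracy_list)
plt.xlabel('epoch')
plt.ylabel('test accuracy (%)')
plt.show()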

Show some predictions of the trained network¶

In [13]:
def visualize_pred(img, pred_prob, real_label):
    ''' View an image next to its predicted class probabilities. '''
    fig, (ax1, ax2) = plt.subplots(figsize=(6,9), ncols=2)
    imshow(img, ax1)
    ax1.axis('off')
    pred_label = numpy.argmax(pred_prob)
    ax1.set_title([classNames[real_label], classNames[pred_label]])  # [true, predicted]
    
    ax2.barh(numpy.arange(10), pred_prob)
    ax2.set_aspect(0.1)
    ax2.set_yticks(numpy.arange(10))
    ax2.set_yticklabels(classNames)
    ax2.set_title('Prediction Probability')
    ax2.set_xlim(0, 1.1)
    plt.tight_layout()
In [14]:
model_cnn.to('cpu')  # move the model back to the CPU for visualization

# fetch a batch of test images
image_batch, label_batch = next(iter(test_loader))

# Turn off gradients to speed up this part
with torch.no_grad():
    log_pred_prob_batch = model_cnn(image_batch)
for i in range(10):
    img = image_batch[i]
    real_label = label_batch[i].item()
    log_pred_prob = log_pred_prob_batch[i]
    # the network outputs log-probabilities; take the exponential to get probabilities
    pred_prob = torch.exp(log_pred_prob).numpy().squeeze()
    visualize_pred(img, pred_prob, real_label)

Does the Convolutional Network use "Visual Information"?¶

To test this, we fix one random permutation of the 32*32*3 = 3072 pixel values and apply it to every image, at train and test time alike. The permutation preserves all of the information in an image but destroys its spatial structure, so any drop in accuracy measures how much the network relies on local visual patterns.

In [15]:
fixed_perm = torch.randperm(3072)  # fix one permutation of the image pixels; the same permutation is applied to every image

# show some training images
plt.figure(figsize=(8, 8))

# fetch a random batch of training images
image_batch, label_batch = next(iter(train_loader))

for i in range(6):
    image = image_batch[i]
    image_perm = image.view(-1, 32*32*3).clone()
    image_perm = image_perm[:, fixed_perm]
    image_perm = image_perm.view(3, 32, 32)

    label = label_batch[i].item()
    # original image on the left, permuted version on the right
    plt.subplot(3, 4, 2*i + 1)
    imshow(image, plt)
    plt.axis('off')
    plt.title(classNames[label])
    plt.subplot(3, 4, 2*i + 2)
    imshow(image_perm, plt)
    plt.axis('off')
    plt.title(classNames[label])
plt.show()
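
The fixed permutation only rearranges pixel values, so no information is lost. As a sanity check (a sketch, reusing the loop variables image and image_perm from the cell above), applying the inverse permutation recovers the original image exactly:

inv_perm = torch.empty_like(fixed_perm)
inv_perm[fixed_perm] = torch.arange(32*32*3)  # invert the permutation
restored = image_perm.reshape(-1)[inv_perm].view(3, 32, 32)
print(torch.allclose(restored, image))  # True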
In [16]:
accuracy_list = []

def scramble_train(epoch, model, perm=torch.arange(0, 3072).long()):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        # send to device
        data, target = data.to(device), target.to(device)
        
        # permute pixels
        data = data.view(-1, 32*32*3)
        data = data[:, perm]
        data = data.view(-1, 3, 32, 32)

        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
            
def scramble_test(model, perm=torch.arange(0, 3072).long()):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():  # no gradients are needed for evaluation
        for data, target in test_loader:
            # send to device
            data, target = data.to(device), target.to(device)

            # permute pixels
            data = data.view(-1, 32*32*3)
            data = data[:, perm]
            data = data.view(-1, 3, 32, 32)

            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.max(1, keepdim=True)[1]  # index of the max log-probability
            correct += pred.eq(target.view_as(pred)).cpu().sum().item()

    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    accuracy_list.append(accuracy)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        accuracy))
In [17]:
print("Training on ", device)
model_3 = make_myrtle5()
model_3.to(device)
optimizer = optim.SGD(model_3.parameters(), lr=0.01, momentum=0.5)
print('Number of parameters: {}'.format(get_n_params(model_3)))

for epoch in range(0, 20):
    scramble_train(epoch, model_3, fixed_perm)
    scramble_test(model_3, fixed_perm)
Training on  cuda:0
Number of parameters: 1558026
Train Epoch: 0 [0/50000 (0%)]	Loss: 3.056440
Train Epoch: 0 [6400/50000 (13%)]	Loss: 2.161628
Train Epoch: 0 [12800/50000 (26%)]	Loss: 2.091366
Train Epoch: 0 [19200/50000 (38%)]	Loss: 2.080582
Train Epoch: 0 [25600/50000 (51%)]	Loss: 1.823800
Train Epoch: 0 [32000/50000 (64%)]	Loss: 1.699128
Train Epoch: 0 [38400/50000 (77%)]	Loss: 2.016679
Train Epoch: 0 [44800/50000 (90%)]	Loss: 1.591671

Test set: Average loss: 1.8555, Accuracy: 3651/10000 (37%)

Train Epoch: 1 [0/50000 (0%)]	Loss: 1.767027
Train Epoch: 1 [6400/50000 (13%)]	Loss: 1.644199
Train Epoch: 1 [12800/50000 (26%)]	Loss: 1.780953
Train Epoch: 1 [19200/50000 (38%)]	Loss: 1.511512
Train Epoch: 1 [25600/50000 (51%)]	Loss: 1.642140
Train Epoch: 1 [32000/50000 (64%)]	Loss: 1.566179
Train Epoch: 1 [38400/50000 (77%)]	Loss: 1.477035
Train Epoch: 1 [44800/50000 (90%)]	Loss: 1.468373

Test set: Average loss: 1.7228, Accuracy: 3825/10000 (38%)

Train Epoch: 2 [0/50000 (0%)]	Loss: 1.486330
Train Epoch: 2 [6400/50000 (13%)]	Loss: 1.469259
Train Epoch: 2 [12800/50000 (26%)]	Loss: 1.465371
Train Epoch: 2 [19200/50000 (38%)]	Loss: 1.356576
Train Epoch: 2 [25600/50000 (51%)]	Loss: 1.513697
Train Epoch: 2 [32000/50000 (64%)]	Loss: 1.110209
Train Epoch: 2 [38400/50000 (77%)]	Loss: 1.403494
Train Epoch: 2 [44800/50000 (90%)]	Loss: 1.359130

Test set: Average loss: 1.7391, Accuracy: 4228/10000 (42%)

Train Epoch: 3 [0/50000 (0%)]	Loss: 1.366364
Train Epoch: 3 [6400/50000 (13%)]	Loss: 1.318724
Train Epoch: 3 [12800/50000 (26%)]	Loss: 1.323884
Train Epoch: 3 [19200/50000 (38%)]	Loss: 1.449036
Train Epoch: 3 [25600/50000 (51%)]	Loss: 1.420731
Train Epoch: 3 [32000/50000 (64%)]	Loss: 1.470395
Train Epoch: 3 [38400/50000 (77%)]	Loss: 1.316091
Train Epoch: 3 [44800/50000 (90%)]	Loss: 1.254496

Test set: Average loss: 1.7532, Accuracy: 4019/10000 (40%)

Train Epoch: 4 [0/50000 (0%)]	Loss: 1.166533
Train Epoch: 4 [6400/50000 (13%)]	Loss: 0.938951
Train Epoch: 4 [12800/50000 (26%)]	Loss: 1.036190
Train Epoch: 4 [19200/50000 (38%)]	Loss: 0.916971
Train Epoch: 4 [25600/50000 (51%)]	Loss: 0.966352
Train Epoch: 4 [32000/50000 (64%)]	Loss: 1.107060
Train Epoch: 4 [38400/50000 (77%)]	Loss: 1.198961
Train Epoch: 4 [44800/50000 (90%)]	Loss: 1.026898

Test set: Average loss: 2.0874, Accuracy: 3623/10000 (36%)

Train Epoch: 5 [0/50000 (0%)]	Loss: 1.273960
Train Epoch: 5 [6400/50000 (13%)]	Loss: 0.873829
Train Epoch: 5 [12800/50000 (26%)]	Loss: 0.720255
Train Epoch: 5 [19200/50000 (38%)]	Loss: 1.041700
Train Epoch: 5 [25600/50000 (51%)]	Loss: 1.253112
Train Epoch: 5 [32000/50000 (64%)]	Loss: 0.958210
Train Epoch: 5 [38400/50000 (77%)]	Loss: 1.066872
Train Epoch: 5 [44800/50000 (90%)]	Loss: 0.902257

Test set: Average loss: 1.6574, Accuracy: 4873/10000 (49%)

Train Epoch: 6 [0/50000 (0%)]	Loss: 0.584551
Train Epoch: 6 [6400/50000 (13%)]	Loss: 0.750557
Train Epoch: 6 [12800/50000 (26%)]	Loss: 0.602948
Train Epoch: 6 [19200/50000 (38%)]	Loss: 0.659363
Train Epoch: 6 [25600/50000 (51%)]	Loss: 0.628025
Train Epoch: 6 [32000/50000 (64%)]	Loss: 0.959283
Train Epoch: 6 [38400/50000 (77%)]	Loss: 0.583149
Train Epoch: 6 [44800/50000 (90%)]	Loss: 0.880171

Test set: Average loss: 2.7702, Accuracy: 3834/10000 (38%)

Train Epoch: 7 [0/50000 (0%)]	Loss: 0.889435
Train Epoch: 7 [6400/50000 (13%)]	Loss: 0.348239
Train Epoch: 7 [12800/50000 (26%)]	Loss: 0.446112
Train Epoch: 7 [19200/50000 (38%)]	Loss: 0.454484
Train Epoch: 7 [25600/50000 (51%)]	Loss: 0.488406
Train Epoch: 7 [32000/50000 (64%)]	Loss: 0.540669
Train Epoch: 7 [38400/50000 (77%)]	Loss: 0.472171
Train Epoch: 7 [44800/50000 (90%)]	Loss: 0.700849

Test set: Average loss: 2.0475, Accuracy: 4548/10000 (45%)

Train Epoch: 8 [0/50000 (0%)]	Loss: 0.484381
Train Epoch: 8 [6400/50000 (13%)]	Loss: 0.147280
Train Epoch: 8 [12800/50000 (26%)]	Loss: 0.236531
Train Epoch: 8 [19200/50000 (38%)]	Loss: 0.301614
Train Epoch: 8 [25600/50000 (51%)]	Loss: 0.317405
Train Epoch: 8 [32000/50000 (64%)]	Loss: 0.301728
Train Epoch: 8 [38400/50000 (77%)]	Loss: 0.518494
Train Epoch: 8 [44800/50000 (90%)]	Loss: 0.371778

Test set: Average loss: 4.7512, Accuracy: 3139/10000 (31%)

Train Epoch: 9 [0/50000 (0%)]	Loss: 1.470842
Train Epoch: 9 [6400/50000 (13%)]	Loss: 0.196301
Train Epoch: 9 [12800/50000 (26%)]	Loss: 0.211066
Train Epoch: 9 [19200/50000 (38%)]	Loss: 0.108870
Train Epoch: 9 [25600/50000 (51%)]	Loss: 0.188814
Train Epoch: 9 [32000/50000 (64%)]	Loss: 0.131500
Train Epoch: 9 [38400/50000 (77%)]	Loss: 0.398295
Train Epoch: 9 [44800/50000 (90%)]	Loss: 0.255834

Test set: Average loss: 2.0075, Accuracy: 5037/10000 (50%)

Train Epoch: 10 [0/50000 (0%)]	Loss: 0.046995
Train Epoch: 10 [6400/50000 (13%)]	Loss: 0.045364
Train Epoch: 10 [12800/50000 (26%)]	Loss: 0.042854
Train Epoch: 10 [19200/50000 (38%)]	Loss: 0.059581
Train Epoch: 10 [25600/50000 (51%)]	Loss: 0.046778
Train Epoch: 10 [32000/50000 (64%)]	Loss: 0.039252
Train Epoch: 10 [38400/50000 (77%)]	Loss: 0.033289
Train Epoch: 10 [44800/50000 (90%)]	Loss: 0.060062

Test set: Average loss: 1.9476, Accuracy: 5183/10000 (52%)

Train Epoch: 11 [0/50000 (0%)]	Loss: 0.014184
Train Epoch: 11 [6400/50000 (13%)]	Loss: 0.014809
Train Epoch: 11 [12800/50000 (26%)]	Loss: 0.015247
Train Epoch: 11 [19200/50000 (38%)]	Loss: 0.016734
Train Epoch: 11 [25600/50000 (51%)]	Loss: 0.011483
Train Epoch: 11 [32000/50000 (64%)]	Loss: 0.012543
Train Epoch: 11 [38400/50000 (77%)]	Loss: 0.017695
Train Epoch: 11 [44800/50000 (90%)]	Loss: 0.013086

Test set: Average loss: 1.9576, Accuracy: 5276/10000 (53%)

Train Epoch: 12 [0/50000 (0%)]	Loss: 0.004602
Train Epoch: 12 [6400/50000 (13%)]	Loss: 0.005431
Train Epoch: 12 [12800/50000 (26%)]	Loss: 0.007548
Train Epoch: 12 [19200/50000 (38%)]	Loss: 0.006586
Train Epoch: 12 [25600/50000 (51%)]	Loss: 0.006540
Train Epoch: 12 [32000/50000 (64%)]	Loss: 0.008104
Train Epoch: 12 [38400/50000 (77%)]	Loss: 0.006487
Train Epoch: 12 [44800/50000 (90%)]	Loss: 0.008254

Test set: Average loss: 1.9886, Accuracy: 5328/10000 (53%)

Train Epoch: 13 [0/50000 (0%)]	Loss: 0.005180
Train Epoch: 13 [6400/50000 (13%)]	Loss: 0.005150
Train Epoch: 13 [12800/50000 (26%)]	Loss: 0.003997
Train Epoch: 13 [19200/50000 (38%)]	Loss: 0.004689
Train Epoch: 13 [25600/50000 (51%)]	Loss: 0.004786
Train Epoch: 13 [32000/50000 (64%)]	Loss: 0.004908
Train Epoch: 13 [38400/50000 (77%)]	Loss: 0.007177
Train Epoch: 13 [44800/50000 (90%)]	Loss: 0.008681

Test set: Average loss: 2.0236, Accuracy: 5309/10000 (53%)

Train Epoch: 14 [0/50000 (0%)]	Loss: 0.004287
Train Epoch: 14 [6400/50000 (13%)]	Loss: 0.004623
Train Epoch: 14 [12800/50000 (26%)]	Loss: 0.003995
Train Epoch: 14 [19200/50000 (38%)]	Loss: 0.003530
Train Epoch: 14 [25600/50000 (51%)]	Loss: 0.003280
Train Epoch: 14 [32000/50000 (64%)]	Loss: 0.004866
Train Epoch: 14 [38400/50000 (77%)]	Loss: 0.003300
Train Epoch: 14 [44800/50000 (90%)]	Loss: 0.003659

Test set: Average loss: 2.0178, Accuracy: 5340/10000 (53%)

Train Epoch: 15 [0/50000 (0%)]	Loss: 0.003157
Train Epoch: 15 [6400/50000 (13%)]	Loss: 0.002614
Train Epoch: 15 [12800/50000 (26%)]	Loss: 0.003597
Train Epoch: 15 [19200/50000 (38%)]	Loss: 0.004347
Train Epoch: 15 [25600/50000 (51%)]	Loss: 0.003027
Train Epoch: 15 [32000/50000 (64%)]	Loss: 0.003213
Train Epoch: 15 [38400/50000 (77%)]	Loss: 0.004176
Train Epoch: 15 [44800/50000 (90%)]	Loss: 0.002889

Test set: Average loss: 2.0391, Accuracy: 5325/10000 (53%)

Train Epoch: 16 [0/50000 (0%)]	Loss: 0.003329
Train Epoch: 16 [6400/50000 (13%)]	Loss: 0.002924
Train Epoch: 16 [12800/50000 (26%)]	Loss: 0.003077
Train Epoch: 16 [19200/50000 (38%)]	Loss: 0.002399
Train Epoch: 16 [25600/50000 (51%)]	Loss: 0.003589
Train Epoch: 16 [32000/50000 (64%)]	Loss: 0.003169
Train Epoch: 16 [38400/50000 (77%)]	Loss: 0.002909
Train Epoch: 16 [44800/50000 (90%)]	Loss: 0.002545

Test set: Average loss: 2.0496, Accuracy: 5327/10000 (53%)

Train Epoch: 17 [0/50000 (0%)]	Loss: 0.002144
Train Epoch: 17 [6400/50000 (13%)]	Loss: 0.002579
Train Epoch: 17 [12800/50000 (26%)]	Loss: 0.002740
Train Epoch: 17 [19200/50000 (38%)]	Loss: 0.003150
Train Epoch: 17 [25600/50000 (51%)]	Loss: 0.002454
Train Epoch: 17 [32000/50000 (64%)]	Loss: 0.002070
Train Epoch: 17 [38400/50000 (77%)]	Loss: 0.002619
Train Epoch: 17 [44800/50000 (90%)]	Loss: 0.003822

Test set: Average loss: 2.0755, Accuracy: 5324/10000 (53%)

Train Epoch: 18 [0/50000 (0%)]	Loss: 0.001887
Train Epoch: 18 [6400/50000 (13%)]	Loss: 0.003080
Train Epoch: 18 [12800/50000 (26%)]	Loss: 0.002147
Train Epoch: 18 [19200/50000 (38%)]	Loss: 0.002141
Train Epoch: 18 [25600/50000 (51%)]	Loss: 0.002881
Train Epoch: 18 [32000/50000 (64%)]	Loss: 0.002463
Train Epoch: 18 [38400/50000 (77%)]	Loss: 0.002105
Train Epoch: 18 [44800/50000 (90%)]	Loss: 0.002097

Test set: Average loss: 2.0743, Accuracy: 5344/10000 (53%)

Train Epoch: 19 [0/50000 (0%)]	Loss: 0.002729
Train Epoch: 19 [6400/50000 (13%)]	Loss: 0.001850
Train Epoch: 19 [12800/50000 (26%)]	Loss: 0.002331
Train Epoch: 19 [19200/50000 (38%)]	Loss: 0.001320
Train Epoch: 19 [25600/50000 (51%)]	Loss: 0.001862
Train Epoch: 19 [32000/50000 (64%)]	Loss: 0.001534
Train Epoch: 19 [38400/50000 (77%)]	Loss: 0.003254
Train Epoch: 19 [44800/50000 (90%)]	Loss: 0.001860

Test set: Average loss: 2.0890, Accuracy: 5342/10000 (53%)

With the same architecture and training budget, test accuracy drops from about 85% on ordinary images to about 53% on pixel-permuted ones: the convolutional network clearly does exploit the spatial structure of its input.
In [18]:
model_3.to('cpu')  # move the model back to the CPU for visualization

# fetch a batch of test images
image_batch, label_batch = next(iter(test_loader))
image_batch_scramble = image_batch.view(-1, 32*32*3)
image_batch_scramble = image_batch_scramble[:, fixed_perm]
image_batch_scramble = image_batch_scramble.view(-1, 3, 32, 32)
# Turn off gradients to speed up this part
with torch.no_grad():
    log_pred_prob_batch = model_3(image_batch_scramble)
for i in range(10):
    img = image_batch[i]
    img_perm = image_batch_scramble[i]
    real_label = label_batch[i].item()
    log_pred_prob = log_pred_prob_batch[i]
    # the network outputs log-probabilities; take the exponential to get probabilities
    pred_prob = torch.exp(log_pred_prob).numpy().squeeze()
    visualize_pred(img_perm, pred_prob, real_label)