In [1]:
%matplotlib inline

Reinforcement Learning (DQN) Tutorial

Based on the tutorial by:

Author: Adam Paszke

This tutorial shows how to use PyTorch to train a Deep Q-Learning (DQN) agent on the CartPole-v0 task from the OpenAI Gym.


You can find an official leaderboard with various algorithms and visualizations on the Gym website.

The agent has to decide between two actions, moving the cart left or right, so that the pole attached to it stays upright.

In this task:

  • the reward is +1 for every incremental timestep
  • the environment terminates if
    • the pole falls over too far
    • or the cart moves more than 2.4 units away from the center.

This means better-performing episodes run for longer and accumulate a larger return.
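
To make these rules concrete, here is a minimal plain-Python sketch that scores a trajectory of (cart position, pole angle) pairs. The 12-degree pole limit is an assumption taken from gym's CartPole source, since the text above only says "too far", and `is_terminal` and `episode_return` are hypothetical helpers rather than gym APIs.

```python
import math

# Assumed thresholds (from gym's CartPole source): the episode ends when the
# cart leaves +/-2.4 units or the pole tilts more than 12 degrees.
X_THRESHOLD = 2.4
THETA_THRESHOLD_RADIANS = 12 * 2 * math.pi / 360


def is_terminal(x, theta):
    """Return True if the cart position x or pole angle theta is out of bounds."""
    return abs(x) > X_THRESHOLD or abs(theta) > THETA_THRESHOLD_RADIANS


def episode_return(states):
    """Sum the +1 per-step reward up to and including the first terminal state."""
    total = 0.0
    for x, theta in states:
        total += 1.0  # every step earns +1, including the one that ends the episode
        if is_terminal(x, theta):
            break
    return total


# Hypothetical post-step states: (cart position, pole angle in radians).
trajectory = [(0.0, 0.01), (0.1, 0.05), (0.3, 0.15), (0.5, 0.25)]
print(episode_return(trajectory))  # the 4th state exceeds ~0.209 rad, so this prints 4.0
```

Because every step is worth exactly 1, the undiscounted return equals the episode length, which is why the random-action rollout further down reports the same number for both.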

Neural networks can solve the task purely by looking at the scene, so:

  • we'll use a patch of the screen centered on the cart as the observation of the current state
  • our actions are move left or move right

Strictly speaking, we will present the state as the difference between the current screen patch and the previous one. This will allow the agent to take the velocity of the pole into account from one image.
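
A toy example of the difference image, assuming tiny grayscale patches stored as nested lists (`screen_difference` is a hypothetical helper; the notebook itself works on image tensors):

```python
def screen_difference(current_patch, previous_patch):
    """Element-wise current minus previous; nonzero only where something moved."""
    return [
        [cur - prev for cur, prev in zip(cur_row, prev_row)]
        for cur_row, prev_row in zip(current_patch, previous_patch)
    ]


# A 1x5 "screen": the pole pixel (value 1) moves one column to the right.
previous_patch = [[0, 1, 0, 0, 0]]
current_patch = [[0, 0, 1, 0, 0]]
state = screen_difference(current_patch, previous_patch)
print(state)  # [[0, -1, 1, 0, 0]]
```

The -1/+1 pair marks where the pole was and where it is now, so a single difference image carries its direction of motion. With NumPy arrays or PyTorch tensors of the same shape, the comprehension reduces to one subtraction.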


First, let's import the needed packages. We need gym for the environment (install it with pip install gym). We'll also use the following from PyTorch:

  • neural networks (torch.nn)
  • optimization (torch.optim)
  • automatic differentiation (torch.autograd)
  • utilities for vision tasks (torchvision, a separate package)
In [2]:
!pip install gym[atari] 
Requirement already satisfied: gym[atari] in /opt/conda/lib/python3.7/site-packages (0.21.0)
Requirement already satisfied: importlib-metadata>=4.8.1 in /opt/conda/lib/python3.7/site-packages (from gym[atari]) (4.8.2)
Requirement already satisfied: cloudpickle>=1.2.0 in /opt/conda/lib/python3.7/site-packages (from gym[atari]) (2.0.0)
Requirement already satisfied: numpy>=1.18.0 in /opt/conda/lib/python3.7/site-packages (from gym[atari]) (1.19.5)
Collecting ale-py~=0.7.1
  Downloading ale_py-0.7.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.6 MB)
Requirement already satisfied: importlib-resources in /opt/conda/lib/python3.7/site-packages (from ale-py~=0.7.1->gym[atari]) (5.4.0)
Requirement already satisfied: typing-extensions>=3.6.4 in /opt/conda/lib/python3.7/site-packages (from importlib-metadata>=4.8.1->gym[atari]) (
Requirement already satisfied: zipp>=0.5 in /opt/conda/lib/python3.7/site-packages (from importlib-metadata>=4.8.1->gym[atari]) (3.6.0)
Installing collected packages: ale-py
Successfully installed ale-py-0.7.3
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead:
In [3]:
!pip install pyglet==1.5.0
Collecting pyglet==1.5.0
  Downloading pyglet-1.5.0-py2.py3-none-any.whl (1.0 MB)
Requirement already satisfied: future in /opt/conda/lib/python3.7/site-packages (from pyglet==1.5.0) (0.18.2)
Installing collected packages: pyglet
Successfully installed pyglet-1.5.0
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead:
In [4]:
#!apt-get install python-opengl -y
#!pip install PyOpenGL 
#!pip install PyOpenGL_accelerate
!pip install pyvirtualdisplay
Collecting pyvirtualdisplay
  Downloading PyVirtualDisplay-2.2-py3-none-any.whl (15 kB)
Collecting EasyProcess
  Downloading EasyProcess-0.3-py2.py3-none-any.whl (7.9 kB)
Installing collected packages: EasyProcess, pyvirtualdisplay
Successfully installed EasyProcess-0.3 pyvirtualdisplay-2.2
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead:
In [5]:
import gym
import math
import random
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T

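# to display things: a virtual X display lets env.render() work without a screen
import os
from pyvirtualdisplay import Display
from matplotlib import animation, rc

# named virtual_display so the IPython `display` import below does not shadow it
virtual_display = Display(visible=0, size=(1400, 900))
virtual_display.start()
os.environ["DISPLAY"] = ":" + str(virtual_display.display) + "." + str(virtual_display._obj._screen)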

# setup the environment
env = gym.make('CartPole-v0').unwrapped

# set up matplotlib
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display


# if gpu is to be used
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Display the game environment

In [6]:
In [7]:
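# Play one episode with random actions and record one frame per step
env.reset()
frame = []
total_reward = 0
for _ in range(100):
    action = env.action_space.sample()
    state, reward, done, info = env.step(action)
    total_reward += reward
    img = plt.imshow(env.render(mode='rgb_array'))
    frame.append([img])  # ArtistAnimation expects a list of artists per frame
    if done:
        break

print("Game terminated after", len(frame), " steps with reward ", total_reward)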
Game terminated after 35  steps with reward  35.0
In [8]:
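# Render animations as an interactive HTML/JavaScript widget
rc('animation', html='jshtml')
# Reuse the figure the frames were drawn on; a fresh plt.figure() would be empty
fig = frame[0][0].get_figure()
anim = animation.ArtistAnimation(fig, frame, interval=100, repeat_delay=1000, blit=True)
anim  # displaying the object renders the animation inline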

Replay Memory

We'll be using experience replay memory for training our DQN. It stores the transitions that the agent observes, allowing us to reuse this data later. By sampling from it randomly, the transitions that build up a batch are decorrelated. It has been shown that this greatly stabilizes and improves the DQN training procedure.

For this, we're going to need two classes:

  • Transition - a named tuple representing a single transition in our environment. It essentially maps (state, action) pairs to their (next_state, reward) result, with the state being the screen difference image as described later on.
  • ReplayMemory - a cyclic buffer of bounded size that holds the transitions observed recently. It also implements a .sample() method for selecting a random batch of transitions for training.
In [9]:
# the structure of the transition that we store
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))

# stores the Experience Replay buffer
class ReplayMemory(object):

    def __init__(self, capacity):
        self.cap = capacity
        self.memory = deque([],maxlen=capacity)

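    def push(self, *args):
        """Save a transition; the deque drops the oldest one when full."""
        self.memory.append(Transition(*args))
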
    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)
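
The cyclic-buffer behaviour and the sampling step can be checked in isolation. This sketch reuses the `Transition` layout but swaps the class for a bare `deque` so it runs on its own; the `Transition(*zip(*batch))` transposition at the end is a common idiom for turning a sampled batch into per-field tuples and is an addition here, not something this section defines.

```python
import random
from collections import deque, namedtuple

# Same record layout as the Transition tuple defined above.
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))

# A bounded deque acts as the cyclic buffer: once full, the oldest entry is dropped.
memory = deque([], maxlen=3)
for step in range(5):
    memory.append(Transition(state=step, action=step % 2, next_state=step + 1, reward=1.0))

print([t.state for t in memory])  # [2, 3, 4]: transitions 0 and 1 were evicted

# Uniform sampling without replacement, as ReplayMemory.sample does.
random.seed(0)
batch = random.sample(memory, 2)

# Transpose a list of Transitions into one Transition whose fields are tuples.
columns = Transition(*zip(*batch))
print(len(columns.state), len(columns.reward))  # 2 2
```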

Now, let's define our model. But first, let's quickly recap what a DQN is.

DQN algorithm

Our environment is deterministic, so all equations presented here are also formulated deterministically for the sake of simplicity. In the reinforcement learning literature, they would also contain expectations over stochastic transitions in the environment.

Our aim will be to train a policy that tries to maximize the discounted, cumulative reward $R_{t_0} = \sum_{t=t_0}^{\infty} \gamma^{t - t_0} r_t$, where $R_{t_0}$ is also known as the return. The discount, $\gamma$, should be a constant between 0 and 1 that ensures the sum converges. It makes rewards from the uncertain far future less important for our agent than the ones in the near future that it can be fairly confident about.
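
A minimal sketch of the return formula for a finite reward list, assuming $t_0 = 0$ and that rewards stop when the episode ends (`discounted_return` is a hypothetical helper):

```python
import math


def discounted_return(rewards, gamma):
    """Compute R_{t0} = sum over t of gamma^(t - t0) * r_t, with t0 = 0."""
    total = 0.0
    # Walk backwards so each step applies R_t = r_t + gamma * R_{t+1}.
    for r in reversed(rewards):
        total = r + gamma * total
    return total


# CartPole pays +1 per step, so a 3-step episode with gamma = 0.9 is worth 1 + 0.9 + 0.81.
print(discounted_return([1.0, 1.0, 1.0], 0.9))

# With gamma < 1, a long stream of +1 rewards approaches the bound 1 / (1 - gamma).
print(math.isclose(discounted_return([1.0] * 1000, 0.9), 1 / (1 - 0.9)))  # True
```

The second check shows why $\gamma < 1$ matters: even an arbitrarily long episode has a finite return.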

The main idea behind Q-learning is that if we had a function $Q: State \times Action \rightarrow \mathbb{R}$, that could tell us what our return would be, if we were to take an action in a given state, then we could easily construct a policy that maximizes our rewards:

$$\pi(s) = \arg\max_a \ Q(s, a) \tag{1}$$
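
Equation (1) just picks the action with the largest Q-value. A plain-Python sketch for a single state, assuming gym's CartPole action encoding (0 = push left, 1 = push right) and made-up Q-values; `greedy_action` is a hypothetical helper:

```python
def greedy_action(q_values):
    """pi(s) = argmax_a Q(s, a): index of the largest Q-value for one state."""
    return max(range(len(q_values)), key=lambda a: q_values[a])


# Hypothetical Q-values for the two CartPole actions in some state s.
q_s = [0.7, 1.3]
print(greedy_action(q_s))  # 1, i.e. push right
```

Ties resolve to the lowest index. With a batch of Q-values in a PyTorch tensor, `q.argmax(dim=1)` performs the same selection per row.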

However, we don't know everything about the world, so we don't have access to Q. But since neural networks are universal function approximators, we can simply create one and train it to resemble Q.

For our training update rule, we'll use the fact that every Q function for some policy obeys the Bellman equation:

$$Q^{\pi}(s, a) = r + \gamma \, Q^{\pi}(s', \pi(s'))$$

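The difference between the two sides of the equality is known as the temporal difference error, δ. For the greedy policy in (1), $Q(s', \pi(s')) = \max_a Q(s', a)$, so:

$$\delta = Q(s, a) - \left(r + \gamma \max_a Q(s', a)\right)$$
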
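A worked sketch of δ for one transition with made-up numbers. Treating the value of a terminal next state as zero is a standard convention assumed here; the equation above does not say how termination is handled. `td_error` is a hypothetical helper:

```python
def td_error(q_sa, reward, next_q_values, gamma, done):
    """delta = Q(s, a) - (r + gamma * max_a Q(s', a)); bootstrap term is 0 if s' is terminal."""
    next_value = 0.0 if done else max(next_q_values)
    target = reward + gamma * next_value
    return q_sa - target


# Hypothetical numbers: Q(s, a) = 5.0, r = 1, Q(s', .) = [4.0, 6.0], gamma = 0.5.
print(td_error(5.0, 1.0, [4.0, 6.0], 0.5, done=False))  # 5 - (1 + 0.5 * 6) = 1.0
print(td_error(5.0, 1.0, [4.0, 6.0], 0.5, done=True))   # 5 - 1 = 4.0
```

A positive δ means the current estimate overshoots the one-step target, so training should push Q(s, a) down.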
To minimise this error, we will use the Smooth L1 loss, also known as the Huber loss. The Huber loss acts like the mean squared error when the error is small, but like the mean absolute error when the e