AML21: 07 RNNs and LSTMs for Text Generation¶

Based on https://github.com/karpathy/char-rnn

Download Data¶

In [1]:
! wget "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt" -c -P {'data/'}
--2021-10-24 12:15:16--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.109.133, 185.199.108.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘data/input.txt’

input.txt           100%[===================>]   1.06M  --.-KB/s    in 0.02s   

2021-10-24 12:15:17 (44.6 MB/s) - ‘data/input.txt’ saved [1115394/1115394]

Libraries etc.¶

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
import numpy as np

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
cuda

Load the data¶

In [3]:
# Load the data into memory
data_file = "./data/input.txt" 
#data_file = "./data/sherlock.txt" 

## Open the text file
data = open(data_file, 'r').read(20000) ## Read only the first ~20 KB; training on the full file takes much longer
chars = sorted(list(set(data))) 
## NOTE: vocab_size determines the input and output size of our models
data_size, vocab_size = len(data), len(chars) 

print("Data has {} characters, {} unique".format(data_size, vocab_size))

## char to index and index to char maps
char_to_ix = { ch:i for i,ch in enumerate(chars) }
ix_to_char = { i:ch for i,ch in enumerate(chars) }

## convert data from chars to indices
data = list(data)
for i, ch in enumerate(data):
    data[i] = char_to_ix[ch]

## data tensor on device
data = torch.tensor(data).to(device)
data = torch.unsqueeze(data, dim=1) ## shape (data_size, 1): sequence dimension first, batch size of 1
Data has 20000 characters, 58 unique
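
As a quick sanity check, we can decode the first few indices back into characters to confirm the mapping round-trips correctly (a minimal sketch, using only the data tensor and ix_to_char defined above):

# decode the first 50 indices back into text
first_ix = data[:50].squeeze(1).tolist()
print("".join(ix_to_char[i] for i in first_ix))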

The RNN and LSTM models¶

In [4]:
class myRNN(nn.Module):
    def __init__(self, input_size, output_size, hidden_size=512, num_layers=3, do_dropout=False):
        super(myRNN, self).__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.do_dropout = do_dropout
        
        self.dropout = nn.Dropout(0.5)
        self.rnn = nn.RNN(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers)
        self.decoder = nn.Linear(hidden_size, output_size)
        
        self.hidden_state = None # the hidden state of the RNN
    
    def forward(self, input_seq):
        x = nn.functional.one_hot(input_seq, self.input_size).float()
        if self.do_dropout:
            x = self.dropout(x)
        x, new_hidden_state = self.rnn(x, self.hidden_state)
        output = self.decoder(x)
        # save the hidden state for the next chunk; detach() removes it from the
        # autograd graph so gradients do not flow back across chunk boundaries
        self.hidden_state = new_hidden_state.detach()
        return output
    
    def save_model(self, path):
        torch.save(self.state_dict(), path)
    
    def load_model(self, path):
        try:
            self.load_state_dict(torch.load(path))
        except Exception as err:
            print("Error loading model from file", path)
            print(err)
            print("Initializing model weights to default")
            self.__init__(self.input_size, self.output_size, self.hidden_size, self.num_layers)

class myLSTM(nn.Module):
    def __init__(self, input_size, output_size, hidden_size=512, num_layers=3, do_dropout=False):
        super(myLSTM, self).__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.do_dropout = do_dropout
        
        self.dropout = nn.Dropout(0.5)
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers)
        self.decoder = nn.Linear(hidden_size, output_size)
        
        self.internal_state = None # the internal state of the LSTM:
                                   # a tuple of the hidden state (short-term memory)
                                   # and the cell state (long-term memory)
    
    def forward(self, input_seq):
        x = nn.functional.one_hot(input_seq, self.input_size).float()
        if self.do_dropout:
            x = self.dropout(x)
        x, new_internal_state = self.lstm(x, self.internal_state)
        output = self.decoder(x)
        self.internal_state = (new_internal_state[0].detach(), new_internal_state[1].detach())
        return output
    
    def save_model(self, path):
        torch.save(self.state_dict(), path)
    
    def load_model(self, path):
        try:
            self.load_state_dict(torch.load(path))
        except Exception as err:
            print("Error loading model from file", path)
            print(err)
            print("Initializing model weights to default")
            self.__init__(self.input_size, self.output_size, self.hidden_size, self.num_layers)
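
To make the tensor shapes concrete: the input indices have shape (sequence length, 1), the one-hot encoding in forward() expands them to (sequence length, 1, vocab_size), the recurrent layer maps that to (sequence length, 1, hidden_size), and the linear decoder maps it back to (sequence length, 1, vocab_size), i.e. one score per character in the vocabulary. A minimal sketch to inspect this, assuming the classes and data defined above (the tiny hidden size is only for illustration):

# illustrative shape check with a small throw-away model
tmp_model = myRNN(vocab_size, vocab_size, hidden_size=32, num_layers=1).to(device)
chunk = data[:10]                  # (10, 1) character indices
logits = tmp_model(chunk)          # one-hot -> RNN -> linear decoder
print(chunk.shape, logits.shape)   # expect torch.Size([10, 1]) and torch.Size([10, 1, vocab_size])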

Helper Functions for Training and Testing¶

In [5]:
# function to count the number of parameters in a model
def get_n_params(model):
    n_params = 0
    for p in model.parameters():
        n_params += p.nelement()
    return n_params
In [6]:
def train(rnn_model, epoch, seq_len = 200):
    # seq_len is the length of each training data chunk
    
    rnn_model.train()
    # loss function 
    loss_fn = nn.CrossEntropyLoss()
    
    
    test_output_len = 200    # total num of characters in output test sequence
    
    ## random starting offset in [0, seq_len-1] to partition the data into chunks of length seq_len;
    ## this is truncated backpropagation through time (TBPTT)
    data_ptr = np.random.randint(seq_len)
    running_loss = 0
    n = 0
    
    if epoch % 10 == 0 or epoch == 1 or epoch == 2 or epoch == 3:
        print("\n\n\n\nStart of Epoch: {0}".format(epoch))
        
    while True:
        input_seq = data[data_ptr : data_ptr+seq_len]
        target_seq = data[data_ptr+1 : data_ptr+seq_len+1]
        input_seq = input_seq.to(device)
        target_seq = target_seq.to(device)
        
        optimizer.zero_grad()
        output = rnn_model(input_seq)
        loss = loss_fn(torch.squeeze(output), torch.squeeze(target_seq))
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        n += 1

        # update the data pointer
        data_ptr += seq_len
        # stop when there is not enough data left for another full chunk
        if data_ptr + seq_len + 1 > data_size:
            break
            
    # print loss and a sample of generated text periodically
    if epoch % 10 == 0 or epoch == 1 or epoch == 2 or epoch == 3:
        # sample / generate a text sequence after every epoch
        rnn_model.eval()
        data_ptr = 0

        # random character from the data to begin generation
        # (clone so that writing the sampled index below does not modify the data tensor)
        rand_index = np.random.randint(data_size-1)
        input_seq = data[rand_index : rand_index+1].clone()

        
        test_output = ""
        while True:
            # forward pass
            output = rnn_model(input_seq)

            # construct categorical distribution and sample a character
            output = F.softmax(torch.squeeze(output), dim=0)
            #output.to("cpu")
            dist = Categorical(output)
            index = dist.sample().item()
            

            # append the sampled character to test_output
            test_output += ix_to_char[index]

            # next input is current output
            input_seq[0][0] = index
            data_ptr += 1

            if data_ptr > test_output_len:
                break
        print("TRAIN Sample")
        print(test_output)
        print("End of Epoch: {0} \t Loss: {1:.8f}".format(epoch, running_loss / n))
    
    return running_loss / n
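
For context on the losses printed during training: the cross-entropy loss is in nats per character, so predicting uniformly at random over vocab_size characters gives ln(vocab_size), and simply matching the single-character frequencies of the data gives the unigram entropy. A minimal sketch to compute both baselines, assuming the data tensor and vocab_size from above:

# rough baselines (nats per character) to compare against the epoch losses below
counts = torch.bincount(data.squeeze(1), minlength=vocab_size).float()
probs = counts / counts.sum()
print("uniform baseline:", np.log(vocab_size))
print("unigram baseline:", float(-(probs * torch.log(probs + 1e-12)).sum()))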

      
In [7]:
def test(rnn_model, output_len=1000): 
    rnn_model.eval()
    
    # initialize variables
    data_ptr = 0
    # randomly select a 9-character warm-up string from the data
    rand_index = np.random.randint(data_size - 11)
    input_seq = data[rand_index : rand_index + 9]
    
    # run the warm-up string through the model to build up its hidden/internal state
    output = rnn_model(input_seq)
    
    # the next character is the first input for generation
    # (clone so that writing the sampled index below does not modify the data tensor)
    input_seq = data[rand_index + 9 : rand_index + 10].clone()
    
    # generate remaining sequence
    # NOTE: We generate one character at a time
    test_output=""
    while True:
        # forward pass
        output = rnn_model(input_seq)
        
        # construct categorical distribution and sample a character
        output = F.softmax(torch.squeeze(output), dim=0)
        dist = Categorical(output)
        index = dist.sample().item()
        
        # append the sampled character to test_output
        test_output += ix_to_char[index]
        
        # next input is current output
        input_seq[0][0] = index
        data_ptr += 1
        
        if data_ptr  > output_len:
            break

    print("\n\nTEST -------------------------------------------------")
    print(test_output)
    print("----------------------------------------")

Creating and Training an instance¶

First, plain RNNs¶

In [8]:
hidden_size = 512 + 200  # size of hidden state
num_layers = 6     # num of layers in RNN layer stack
lr = 0.002          # learning rate

model_save_file = "./model_data.pth"

model_rnn = myRNN(vocab_size, vocab_size, hidden_size, num_layers).to(device)
optimizer = torch.optim.Adam(model_rnn.parameters(), lr=lr) 

best_model_rnn =  myRNN(vocab_size, vocab_size, hidden_size, num_layers).to(device)
best_rnn_loss = 10000

for epoch in range(0, 101): # epochs 0 to 100
    #model_rnn.load_model(model_save_file)
    epoch_loss = train(model_rnn, epoch)
    if epoch_loss < best_rnn_loss:
        best_rnn_loss = epoch_loss
        best_model_rnn.load_state_dict(model_rnn.state_dict())
    #if epoch % 10 == 0:
    #    model_rnn.save_model(model_save_file)

Start of Epoch: 0
TRAIN Sample
 gsyuaernI rn:.a
ehR,Iea?Rieel tliotei ryareu R.tGd i G.mnorfs,yIdtualREhs iuveu,s 
e,aAI oiir Ae irrIe altim ioome
l uimuo s nenL A   b, :sm
tua.tE   hhe Ewdd:ak  .tguiwChA    dy ur
ia-o enr.e?d :dgR

End of Epoch: 0 	 Loss: 3.45813104




Start of Epoch: 1
TRAIN Sample
oabo,il ue Rms rr o b:iorx Eeeus,.diuie o s
 eas ntaUiin se,ne pefAtssV wbm. t  oatt -euepscuosnTabnioteuun
d  nd a arcimo rwVemaViuwe yerwewdeoid nsbVbta.atEsh Cyauiha,orptnsiA An eienseAolgrP:.u o iT
End of Epoch: 1 	 Loss: 3.41200320




Start of Epoch: 2
TRAIN Sample
tnyiEhbIsg:ewedaa ow
ei
u  sw W e iI.ilelhnhgroRd eeRh odsrwgnc u:elvapemcoednanoivefuI g.aatst ae tsaocRu:utehrAdntseueeEtGdiFVye noanLoeressoRparo aao tad a VVrAgeTsoeueermeeGRhyg:coRT;: ,:EreIn
a.ho
End of Epoch: 2 	 Loss: 3.40750512




Start of Epoch: 3
TRAIN Sample
i oenEaesnrn  ae:e oV,dothsn mha maoeice e  ueeigloe mh ryo ,hl:wVl v gar dgudmmenf.hAmteleuEenv
yeny reret, no Cr E.:
n  PapenthIe,ma
eiRdILi.eaceesIuea   co ldut I oster

rem.e.resnwo tl eaeahtneir
g
End of Epoch: 3 	 Loss: 3.41056026




Start of Epoch: 10
TRAIN Sample
 .stdetd:AncapariwGWA
ufnieetihgodah eV ly
iRAeuamilA r
euilwh o autnfge
h-ueAu g.eudNaa uepaunme:e ynn.osr
ttoo sUn.ntcr :derr l  Tiy ettGex,t owmWeeoecvidodb 

baugshAe ieboynr,ne  Id ,,gsArwaptesa  
End of Epoch: 10 	 Loss: 3.41519007




Start of Epoch: 20
TRAIN Sample
 nVlCyrrg Aeaew enTyTmoeaddrshnesVtA gV,e b
iRL:uergTeeha, ceGsust .h,ewGV,euea unfeaC
iLvoLu:  L, e;ileh,pa eenaeaeNay
i tV
neaxdaiEa :iseaag:u

ouRa
tm
 weAelnres eAnoPersge em o LAeddr sG v:eAarmnda
End of Epoch: 20 	 Loss: 3.41321708




Start of Epoch: 30
TRAIN Sample
 heyhsreVwooeo, AR uiuhia  e
AseLs
erlaAtauArepI itwre r.cwsg I
biIewdggwtnwagygRIucIgnaa
rnrdgw
Vyamdo eeEi eael:Redpeegs:Ryo,dy.etl.e ame

 rtx
TcC aeefiehlmssy
a ai
s:w,
 nrytiuVbfaeImhrhaian
 aasuR
End of Epoch: 30 	 Loss: 3.41030273




Start of Epoch: 40
TRAIN Sample
tsu.Vigeedhnn tlnirsenEgyljCig
ur,r
elrsob.oe.huh
rrihV. agdlayIavnyohgee
hlcxrgmp iterebEe;Vh onhewgieaCrntede ,ereorearyiGueys ,oewdmlbT
Aat eade,n
  m
woes, dt luBy
sAg
esueGtbdmonCAyleAi,hadn
avwNy
End of Epoch: 40 	 Loss: 3.40778256




Start of Epoch: 50
TRAIN Sample
: meaairR r an
cV
G,rlb
eni:
aId
gtmdliARerdGwr
n
gi n abe miy .oehV aIa
ayG o  nep umLu IL,ahe m elisVercu
ape euRwemeun rIe yV
yeauueye remelR elelVti
nlaa,tatEVa
ie :a:lur  
t,:rw  ahaseieIwyalgnots
End of Epoch: 50 	 Loss: 3.41423426




Start of Epoch: 60
TRAIN Sample

n,uTn
iIeLeotc.i,n RaahelomGehuAhmIoirvARoyVweuai,arVcA

g 
 veanCig et
rlRg
o
geAO.UA,Ru,ry ymoeiAineiuh  oela,abu,eleeor e Ainsmtm
.eemayu tg mmEdLnee E ,,h .icIuootlye
.t
 b ne
ea,a'?Rerh tV: tc;uh
End of Epoch: 60 	 Loss: 3.41571987




Start of Epoch: 70
TRAIN Sample
Tm eouaruuhef oonge C g,:hasrrmoe orohlLiVat:eutwaInaamrente ereootti nhd, rA irsax
Nyee,
nyReeRIe?h   ,Erty
yuInau hngAGr  eabsIed t,Rer oa  a rda,:ssr,A eulAeAGuta E rgres:h d tud phntr ;:AIeg  i
de 
End of Epoch: 70 	 Loss: 3.40681290




Start of Epoch: 80
TRAIN Sample
m
eA  naC:rb
Vey 
wec
nIru nwa
 orR 
ee wtaydtoeelsheeghlnut heahag,ty i Ii
ognsmrioaungOanp,i
mzo.us ggle evVn
s 
i evgc amldL s
b.Uf ohd bnee ,i  uueeVEwImgibgEpG?eal e yla:IiSI.oiG,hw riseaohp rfiya
End of Epoch: 80 	 Loss: 3.41508239




Start of Epoch: 90
TRAIN Sample
a Vaxoe s.ef dyt 
hno
xe  toVo. dsAeiy
herl.hgeaPeuntam fonlAwA a  aeehsye.esileed,:myel iage  CeRy?gAIa
nauh end 
I;avIi

aaad.a
dwpade:To siaAmd
E
emyV Vif hsna 'VnAtseitmAc emgwhlrusm
sholun
w soaae
End of Epoch: 90 	 Loss: 3.41482053




Start of Epoch: 100
TRAIN Sample
tsags:oh   eTnaiar rs:
t LaisU.heANmnweiidue.rs:rh?rhaiatn IWrdoment.Lnn ,
e.
a 
l ynG
ahg,osabtmdoa R::seig r
V
e G .ni:'ieov
eoIbseTAasei yaEar
e ns eeyf
efah.,ruh n, e

 gdchiheg :bdt
ueld  
xw:nghb
End of Epoch: 100 	 Loss: 3.40951993

Some sample generated text

In [9]:
print("best loss", best_rnn_loss)
print("Model size", get_n_params(best_model_rnn))
test(best_model_rnn)
best loss 3.4045372155247904
Model size 5667578


TEST -------------------------------------------------
oeeaoyahaete  eepnaCniuhut :vt.te
oalh
r reIgIg,u ILeEd i
,m
stmudJEodcrhrr?ndrenl oei Tlot
et ten .adobx srsrg. eavigueyimgeuE eGP
a
 iE i,am,e noa,h:Ieeypa
.!tsddgno aiy
ornasaairert ascoeaoVAtrteTl ce,
dat
Ionoy:amangyywoin tdul hiefr .geela odmu
omnG,duntrpeciou emAmeelo GeVor,Eod ' f tb,u
bngIyLtlU
beihnriaifmwdRwlg 
huidrfttueitso?i
 ,iur-oniL
uegRgnuviounoMegi gome o h eeorr
eee snnu  h
aaenveiiL r.ya pdImhLkrIyeuglrTuVrdyosT eEIsVwc  gLawbyhdwa  ,rmh
mseiw  u rrh
nomms me,;Ee
uo
m
er
Gtthen sd GnEr
 rus rsnnm 
  Re glemneho tub !,ryiwhIRlp Ailnrer s,Aloeit,nAufmne;ene ,uEoLgis:i ePCd VVr eho oa
 roh,AoiL .meglege,ou
artcsAly o.  eai rgyTds u,uhsEeWarh
wsCieorsrIlEg gnm lLeEVnAei.r

n.uaUathsGAVIed oLfoVoa
devhe ;w rata RtmesgeaieIe yyrgydw
elp eacgVi err aWr nne.runmdC .s
tyerl ee,mi,tsbe.ieuridnrIuoAIn, ee  ,Io.o
nht   d rfu uo:R.c u hmIdLpAoun,g
n
 l
.seteWno Pf.  a 
ee s auoo:iwsum u;hd
tt 
o 
n bogo
IsaVdgnfiestrms,iohpCo

umnfeeynoi rm riuont  e ite
gPax,vn aa,Tau
Eu  d,
ue
----------------------------------------
In [10]:
# smaller rnn
hidden_size = 512  # size of hidden state
num_layers = 3     # num of layers in RNN layer stack
lr = 0.002          # learning rate

model_rnn_3 = myRNN(vocab_size, vocab_size, hidden_size, num_layers).to(device)
optimizer = torch.optim.Adam(model_rnn_3.parameters(), lr=lr) 

best_model_rnn_3 =  myRNN(vocab_size, vocab_size, hidden_size, num_layers).to(device)
best_rnn_loss_3 = 10000

for epoch in range(0, 101): # epochs 0 to 100
    #model_rnn.load_model(model_save_file)
    epoch_loss = train(model_rnn_3, epoch)
    if epoch_loss < best_rnn_loss_3:
        best_rnn_loss_3 = epoch_loss
        best_model_rnn_3.load_state_dict(model_rnn_3.state_dict())
    #if epoch % 10 == 0:
    #    model_rnn.save_model(model_save_file)

Start of Epoch: 0
TRAIN Sample
:a V,soefailoI
ot VsmmyoooGenaunrstEah e: uto  yslo,oodoe n
 o oGeIoot, roGLous,mgI utcogoyrwd
ooaiu
.LsiobIrbeii Lu,oR isy?C i
ee G'iwu oi a
gIVPxre
ofaunmo tggawG V.sei,ueI ihI o el p uiAnb: ev, tu G
End of Epoch: 0 	 Loss: 3.42875684




Start of Epoch: 1
TRAIN Sample
oweAaJs ,uuVt
iini
 'ritrrnve ustusbya g wUerIstm  g  .oso
t hctei osgn ee 'rs;V uh entliVe V  uh;G saoiltrl nerli  eoeti
r
slno.ae erw ueinoi taiasos edepoe ec;athria eetm
 iIe,dI?nnevy ieoyvho o s
 n
End of Epoch: 1 	 Loss: 3.39002413




Start of Epoch: 2
TRAIN Sample
o,lniwod
o aE,Na, d Lr  L   rumi  hbtLwIh i
inh legiRrofuuEi s io;  ger
d  VE  ta? wCom
emo s t.ib aoearnenroL sea aon lNoiItww
essetg

o
 uef   t 
Thuntlarfm oeAhtvYte:strue matoru mgeW:oeo  ssuo.Isdl
End of Epoch: 2 	 Loss: 3.38498163




Start of Epoch: 3
TRAIN Sample
rn rmi  air iheV ,  b
AmIeu l , ywuen Ihgttr
r tU tLmw 
hare'oRs  ooahLeaoekhonueeEt eiween ro rooeidiesie  rWrg oi nrT , dtueuVeWoIaostrre
ewcuiAnA AydIyV c Irsrh esIbL seresrolse wte
wmaiI;raLi e epo
End of Epoch: 3 	 Loss: 3.38958979




Start of Epoch: 10
TRAIN Sample
Luu eAseLs eItiiysE
 owd ergu.etaooV  hu
nE:Go merArdseei,Lee,hosseeEIeplse ,tam Eeem 
 atIy Il,wiUoVrIe  e.l eteiV wityiyaesrpt tigURer  
hi .ougseoeee s reGuee eremhirttMRet e A,Ivsreyctr Eewiertmtsl
End of Epoch: 10 	 Loss: 3.38974715




Start of Epoch: 20
TRAIN Sample
 reoetea
-er umr   ggt:dontrenVr  b rlaEdIet toe
 eri eo
 IIsouge trtf
damnrl.rryuieoruhhEeoo omg i ereeyti o fhrtr cEergw
rciE  fs
mh.re
 eeeeweToE i:,,i R e  puerwAio
mv
lU
l  ttlmiAm;hearmteInnn
e,o
End of Epoch: 20 	 Loss: 3.38875921




Start of Epoch: 30
TRAIN Sample
sim
  osi oino .au Enei
 r te lVuoiu, 
im,ro ruesAfg,ii cnusetVhe ? 
Re o .t, utriVh itLonoh rig heeha ILeptNftr
 oteuhWItoud
e  rr:ohltaf, ias ,
r.Yr .egiuegtetooi mo .
m
w re Iytyyrmi or  VI
oee wolo
End of Epoch: 30 	 Loss: 3.39087880




Start of Epoch: 40
TRAIN Sample
 ophs
mlmoLr
d sb;soiamwb,ts eu hAwhedein,iehuer ioAiiom
nutaohvI
ofo.nmndy i
rohae m
ocrsat,Losi tInmt:eti r s n;uou:Vg:ievilw,owhh
 sia
wVa.ve
,ebAI ds errs .Rth'heee Ite iVsaCtni  eI
eodn Eanl w
oro
End of Epoch: 40 	 Loss: 3.39338275




Start of Epoch: 50
TRAIN Sample
;uirimsi ; ro'os
i. sotmogior.i  toauAsdni,afew;  oNrysopre ngaV. o ErhNfoL miiulo .noIamnt: ,io rtdtnsis  loEl
?ylu: enrt.r itr,  w ua olmsubAihomATueaaoorgh.tmR. rsr ene  mrg eI  hntoootIfmiie eAeelI
End of Epoch: 50 	 Loss: 3.39144398




Start of Epoch: 60
TRAIN Sample
,LoabloIIs 
ieptabirIdhImhmAt, edpooesITheerciaul,h uhms o
Nohli,y , y; euLeeE  sRnm enhyu wyeUAriheolAt c 
sAoI
 .
 
  ,shon rmE,ardieVeeIret,hAoiarIotcetfAe bo
ioEAnA:nvneya tye

E,Ge
IoEoueegaitgi
s
End of Epoch: 60 	 Loss: 3.38904431




Start of Epoch: 70
TRAIN Sample
we
R.oe uenorGey rios:gle aptoE a oeiswehu it, w:tM mr.ildpe  Aa d led.maAey.ouueh:l  euV hrh 
b, ntn
aeaoo
rb luIoye tn
u.udoowIlrs
 e  uedtr
:urtnrmIdrVtfneIGe oronuaee tumb m?hgltorwosue oeoe?thL tu
End of Epoch: 70 	 Loss: 3.39074945




Start of Epoch: 80
TRAIN Sample
h nowIVorieiIoteii o  de 
l:demthsraym 
etrihEe
pu lennuwserhiu adgrmai.  eshiEouoo ieemeh
I'Weeot Vtthretieasrt deIesodle srhtns etyEL
h Gco'oCrAi  . rpuIUeoesi o
 d,teyto  a tounukda eeGgemh,e Iiauyu
End of Epoch: 80 	 Loss: 3.38822036




Start of Epoch: 90
TRAIN Sample
 Igoed ir
aobe'lwsaE
oaeEntbfe Lnoiuima:,, moV egrAreg,re.
.Edsao.toent
,yo rm .orsdhnue,eteVrV  inu  e,einei .ctoim,ge
g i osrAe,ee Tg em.hydt
oaei  eoolR.eloeuo L
drGh,,
uiRxiL
dirGeassIs  Icre
si os
End of Epoch: 90 	 Loss: 3.39146509




Start of Epoch: 100
TRAIN Sample
es  ri oeo
E
ghge oi e wg mhe
ewrwm Ildneigtlwned; no AL,eihartrAmoawlu
eeuewpr: eIraor

Eo b arehmIew mE u;udiueelhn iuu itArigsdnscdoLL s cR tmloLtasr   veor h
oAi
i
eeote, el  utLnm irwteirit
nne ee
End of Epoch: 100 	 Loss: 3.39265828
In [11]:
print("best loss", best_rnn_loss_3)
print("Model size", get_n_params(best_model_rnn_3))
test(best_model_rnn_3)
best loss 3.384981627366981
Model size 1373242


TEST -------------------------------------------------
beei 
tyssyte   sItuo tdehyagoi hrT,,iebnsGinnwopi
iLdoo.e u
wvo
Cinumre e't t,ea. i ,
bii rh ee  c aar h.: owa .eseoeV   e h
idauorne A,eVeeerseue:eN um  s c: Ilyatie ies iynutt
.tVho 'e? g  iaeshVA  eeige.oN:ior
,l .hi hrvtv v oarI; efho rseo 
n dtIt
 ehueo  s ert  Iwh etAo:weirtormh;i noLsao
  rdniEn  y t
i.nh:o rloeI ewlorhbson w;E
e.
it
netor es Gsni hG
hspc 
meoy hmuet
 l 
cbueaoxe wyis etsimoiaoiAmmoetng  e e Luoiim it ieshooU WVeana;iV eiug
msayegeuetuhusWmownu heAtm  :eiM grtV
N msEieyrotiooaiioe,d rhul AerN 
ts ieoetf
adUuso
o.rdf ilrLetAe m,nt
ohCohuhrmeuil  
eht
om :Aibr wooVboarimouhwoAr.yw  tet  c,
, tu
 ndI e bdoenss  de ed xv t  re 'nre.i
s ioeevI 
mhyMad
mnnes deiuo euetteieEnelie I. eoiey:liuse,u aeattvmoor i .md 
s . is eie,tmhr
u
   mamrmdeoE e prer ws, iaroln ;CvuV 
n beee, oi
ml atehinge s  mhboAriiiioyeam s:, r et Anl s l id hC nltsdnInb,b,
I 
iIaedrVAne
it'fnreooeutet il IAiw yhiemri Ve:icdc ek o,u n m o  ssrsut bir eereaoEVV E b rme lu- wews
a
me :wu,.litoAsenAh
----------------------------------------

Next, LSTMs¶

Some sample generated text

In [12]:
hidden_size = 512   # size of hidden state
seq_len = 100       # length of LSTM training chunk (note: train() below still uses its default of 200)
num_layers = 3      # num of layers in LSTM layer stack
lr = 0.002          # learning rate

model_save_file = "./model_data.pth"

model_lstm = myLSTM(vocab_size, vocab_size, hidden_size, num_layers).to(device)
optimizer = torch.optim.Adam(model_lstm.parameters(), lr=lr)

best_model_lstm = myLSTM(vocab_size, vocab_size, hidden_size, num_layers).to(device)
best_lstm_loss = 10000
 
for epoch in range(0, 101): # values from 0 to 100
    #model_lstm.load_model(model_save_file)
    epoch_loss = train(model_lstm, epoch)
    if epoch_loss < best_lstm_loss:
        best_lstm_loss = epoch_loss
        best_model_lstm.load_state_dict(model_lstm.state_dict())
    #if epoch % 10 == 0:
    #    model_lstm.save_model(model_save_file)
    

Start of Epoch: 0
TRAIN Sample
e uhpralo. .,rlodbisbeh Hmw l
l n 
I nttm  tlIeallImi h me.n   tLrrtAVliAoia  m: H  R::osnnptuTmeIas
sdooparhmsV,,n Im.aA :sst
 Vteh t. aioLIoM tvpaO Ii
 m 
 ona
ttVc
Rp,, rtto;LaFedL
 
 .moa,n tLUanIa
End of Epoch: 0 	 Loss: 3.35054017




Start of Epoch: 1
TRAIN Sample
os,    I 
e hb ha V geetRoaiot
:Aohth cyoaLdysiihf
nhl:nRhMoaa ;hinorued r bg.uleom,omaI
iI d l e :ro oe y;hfsdFYeIergegoofvI : dhalbgeptr ,eioc:B
  lhbI tef,  tcmrweIfadV eL  Is
hDt
cotiumgeiint aid
a
End of Epoch: 1 	 Loss: 3.33041688




Start of Epoch: 2
TRAIN Sample
pac; Iefl
it
ke eiw imIse
Iegt
no.mf
fonn lor
Vw
m .as t Iro
ar hh, th  e oe
bhs neplugeO
I -kilr, dh  chL o
Ae' tidtr mom.old rIsssogobe.elAvar wprst grtiteim
eer so te W.A ih di
ca
kes the
n a: R, Ce
End of Epoch: 2 	 Loss: 3.26577742




Start of Epoch: 3
TRAIN Sample
elyek, bit threthook me; wimaned socet, becv
miuoly
mhaoj eeawdab sot, yhise; ?ot Iotul yiuf eout. biteltn thokny phun -neaw
my fartg! roms at vh.
bnb taur ont hog sonl ghee Vns, opauclh wou mesid
Fodh
End of Epoch: 3 	 Loss: 2.81748445




Start of Epoch: 10
TRAIN Sample
alcaes.
I widserser, and it teme clt sinltiwong,
I
Anlt nare yhac
iure: neld; bome, ere hichers hy lig in not sattilt.

Sectet Corincooniugider: Inll be Aerstelte, He bemstaa that-atore? wags pabp. I t
End of Epoch: 10 	 Loss: 2.04604577




Start of Epoch: 20
TRAIN Sample
cond Cwent Corew; Mae, thear city, to that, you apites, a fich yere.

VIRGILIA:
I wire acain; Lake aulci n, no,
I'lr po most wishth, go with, I senaon'd op not ure sor me hiur ut.

VIRINILIA:
I with ra
End of Epoch: 20 	 Loss: 1.60632150




Start of Epoch: 30
TRAIN Sample
 comersala, thencelaint. I wirdd of mothsally: Citizens
better your lasinest herfery.

VALERIA:
In earnius thance hougt spupe we these firkt: proud it.

VALERIA:
Why, where-sinds with had alt pike gors
End of Epoch: 30 	 Loss: 1.12009017




Start of Epoch: 40
TRAIN Sample
h mition 'tis not uril.

VIRGILIA:
I pray yy more, a with me, it moree,
That I wead leave Malariin, poore fort un, farce.

LARTIUS:
What teld seee come!

First Citizen:
Our prese beepemine.

VALERIA:
I
End of Epoch: 40 	 Loss: 0.69803663




Start of Epoch: 50
TRAIN Sample
ll:
No mond your ladyship: was please
He had all at once can no way swat no e seade. Wonst of go a master?

VALERIA:
Ihe kady the only powe prisces had rog ans gone.

VALERIA:
Well, they have corn; and
End of Epoch: 50 	 Loss: 0.30184646




Start of Epoch: 60
TRAIN Sample

ASdeady.
Aominius, Marcius your old enemy, more attain'd than gize his way, when
youth always lut four knemy to the Capitol!

All:
Fase lecome rahest-- had it was now your hate; and you'll find
They'v
End of Epoch: 60 	 Loss: 0.11307307




Start of Epoch: 70
TRAIN Sample
ir, it more string and after him; from when
for a day of kings' entreaties a mother should not
sell hom und the dispatch is made, and in what fashion,
More than his singularity, he goes
Upon this prese
End of Epoch: 70 	 Loss: 0.07867572




Start of Epoch: 80
TRAIN Sample
ttend us.

COMINIUS:
It is your former promise.

MARCIUS:
Wer have honour than in the embracements of his bed lead on Marcius shall be honours,
And I am proud to hunt.

First Senator:
Farewell.

All:
F
End of Epoch: 80 	 Loss: 0.04437273




Start of Epoch: 90
TRAIN Sample
aughter, I sprang not
more in joy at first hearing he was a manachild
than now in first seeing he had proved himself a
man.

VIRGILIA:
I will wish her speedy strength, and visit her with
my prFyers; bu
End of Epoch: 90 	 Loss: 0.02530328




Start of Epoch: 100
TRAIN Sample
.

VIRGILIA:
I thank your ladyship.

VALERIA:
How do you both? you are manifest house-keepers.
What are you sewing doubt
prevailing and to make it more strike at Tullus' face.
Is't a verdict of it, the
End of Epoch: 100 	 Loss: 0.08256529
In [13]:
print("Model size", get_n_params(best_model_lstm), "best loss", best_lstm_loss)
test(best_model_lstm)
Model size 5403706 best loss 0.02530327953436241


TEST -------------------------------------------------
walls, that dogs must eat,
That meat was made for mouths.

MARCIUS:
Were half to half the world by the ears and he.
Upon my party, I'ld revolt to make
Only my wars with him: he is a lion
That I am proud to hunt.

First Senator:
Farewell.

VOLUMNIA:
Indeed, you shall not.
Methinks I hear hither your husband's drum,
See him pluck Aufidius down by the hair,
As children from a bear, the Volsces shunning him:
Methinks I see him stamp tous, and call thus:
'Come on, you cowards! you were got in fear,
Though you were born in Rome:' his bloody brow
With his mail'd hand then wiping, forth he goes,
Like to a harvest-ingther.

VOLUMNIA:
Why, I pray you?

VIRGILIA:
'Tis not to save labour, nor that I want love.

VALERIA:
Mervell, no.

MENENIUS:
For corn at their own rates; whereof, they lack discretion,
Yet are they passing cowardly. But, I beseech you,
What says the other troop?

MARCIUS:
They ne'er cared for us
yet: suffer us to famish, and their suffering in this dearth,
The gods, not the patrici
----------------------------------------
In [14]:
test(best_model_lstm)
TEST -------------------------------------------------
ss!

MENENIUS:
What is granted them?

MARCIUS:
Five tribunes to defend their vulgh.

VALEIUS:
'Though all at once cannot
See what I do deliver out to each,
Yet I can make my audit up, that all
From me do back receive the flour of all,
And leave me but the bran.' What say you to't?

First Senator:
Marcius, I had rather had eleven die nobly for their
country than one voluptuously surfeit out of action.

GentlewowaWe
We are alloreword
Mere me but the bran.' What say there's
grain enough!
Would the nobility lay aside their rubess rightly
Touching the weal o' the cobmost kare of all,
And leave me but the bran.' What say you to't?

First Senator:
Marcius, 'tis true that you might lead on 's rembs, and hear
How the dispatch is made, and in what fashion,
More than his singularity, he goes
Upon this present action.

BRUTUS:
Being moved, he will not spare to gird the gods.

SICINIUS:
Was ever man so proud as is tread
bhould know we were afoot. Net.
That only like a Tulf it did remain
I' the midst
----------------------------------------
In [15]:
test(best_model_lstm)
TEST -------------------------------------------------
In this good belly, that their power are forth already,
And only hitherward. I leave your honours.
If we and Caius Marcius chance to meet,
'Tis sworn wars it for usury, to
support usurers; repeal daily any wholesome act
established against the rich, and privififish?

Second Citizen:
What he cannoig here? to them, not arms, must help. Alack,
You.
When yet he was deliberate, who,
Under the gods, keep a very dog to there wars.

COMINIUS:
It is your former promise.

MARCIUS:
Sir, it is;
And I am constant. Titus Lartius, thou
porth greater themes a very dog to the commonalty.

Second Citizen:
What answer made the belly?

MENENIUS:
Swand'e not maliciously.

First Senator:
Marcius, Ma, 'tis a noble child.

VIRGILIA:
'Tis not to save labour, nor that I want love.
What says the other troop?

MARCIUS:
Here: what's the matter?

Messenger:
The news ia, sir, the Volsces are in arms.

MARCIUS:
Go, get you home, you fragments!

Messenger:
Where's Caius Marcius?

MARCIUS:
There was a time when all the 
----------------------------------------
In [16]:
test(best_model_lstm)
TEST -------------------------------------------------


MENENIUS:
O, doubt not that;
Our steed thence; these are the words: I think
I have the letter here; yes, here it is.
'They have present action.

BRUTUS:
Lets along.

First Senator:
So, your opinion is, Aufidius,
Madam, the Lady Valeria is come to visit you.

VIRGILIA:
Letaher alone, will gut you to 't.
I sin in envying his nobility,
And were I any thing but what I am,
I would wish me only he.

COMINIUS:
You have fought together.

MARCIUS:
Nay, let tvem follow:
The Volsces go with us.

VIRGILIA:
No, good madam; I will not out of doors.

VALERIA:
Not out of doors.

VALERIA:
In earnest, it's true; I heard a senator speak it.
Thus it ismom Cominius the general is gone, with one part of
As.

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Senator:
So, your opinion is, Aufidius,
Made and none less dear than thine and my good
Marcius, I had rather had eleven die nobly for their
country than one voluptuously surfeit out of action.

Gentl
----------------------------------------

We can see that the LSTM network has actually memorized the text.¶

This is most likely because we trained on only a small dataset (~20 KB): the model overfits and reproduces long spans of the training text almost verbatim.
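
One simple way to check this is to test whether a generated phrase occurs verbatim in the training data; a minimal sketch, assuming the data tensor and ix_to_char from above and using a phrase taken from the TEST output:

# check whether a generated snippet appears verbatim in the training text
train_text = "".join(ix_to_char[i] for i in data.squeeze(1).tolist())
snippet = "And leave me but the bran."   # phrase from the TEST output above
print(snippet in train_text)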

Shakespeare Generator¶

Training an LSTM network for a few hours on the full 4.5 MB text of Shakespeare's works produces samples like the following:

PANDARUS:
Alas, I think he shall be come approached and the day
When little srain would be attain'd into being never fed,
And who is but a chain and subjects of his death,
I should not sleep.

Second Senator:
They are away this miseries, produced upon my soul,
Breaking and strongly should be buried, when I perish
The earth and thoughts of many states.

DUKE VINCENTIO:
Well, your wit is in the care of side and that.

Second Lord:
They would be ruled after this chamber, and
my fair nues begun out of the fact, to be conveyed,
Whose noble souls I'll have the heart of the wars.

Clown:
Come, sir, I will make did behold your worship.

VIOLA:
I'll drink it.

Training on Linux Source Code¶

/*
 * Increment the size file of the new incorrect UI_FILTER group information
 * of the size generatively.
 */
static int indicate_policy(void)
{
  int error;
  if (fd == MARN_EPT) {
    /*
     * The kernel blank will coeld it to userspace.
     */
    if (ss->segment < mem_total)
      unblock_graph_and_set_blocked();
    else
      ret = 1;
    goto bail;
  }
  segaddr = in_SB(in.addr);
  selector = seg / 16;
  setup_works = true;
  for (i = 0; i < blocks; i++) {
    seq = buf[i++];
    bpf = bd->bd.next + i * search;
    if (fd) {
      current = blocked;
    }
  }
  rw->name = "Getjbbregs";
  bprm_self_clearl(&iv->version);
  regs->new = blocks[(BPF_STATS << info->historidac)] | PFMR_CLOBATHINC_SECONDS << 12;
  return segtable;
}

Training on the LaTeX Source of an Algebraic Geometry Book¶

Note: some syntax errors were corrected in the generated text before it was compiled