{"metadata":{"kernelspec":{"language":"python","display_name":"Python 3","name":"python3"},"language_info":{"pygments_lexer":"ipython3","nbconvert_exporter":"python","version":"3.6.4","file_extension":".py","codemirror_mode":{"name":"ipython","version":3},"name":"python","mimetype":"text/x-python"}},"nbformat_minor":4,"nbformat":4,"cells":[{"cell_type":"markdown","source":"# AML21: 07 RNNs and LSTMs for Text Generation\nBased on https://github.com/karpathy/char-rnn","metadata":{}},{"cell_type":"markdown","source":"## Download Data","metadata":{}},{"cell_type":"code","source":"! wget \"https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt\" -c -P {'data/'}","metadata":{"_uuid":"8f2839f25d086af736a60e9eeb907d3b93b6e0e5","_cell_guid":"b1076dfc-b9ad-4769-8c92-a6c4dae69d19","execution":{"iopub.status.busy":"2021-10-24T10:43:24.862046Z","iopub.execute_input":"2021-10-24T10:43:24.864257Z","iopub.status.idle":"2021-10-24T10:43:25.641118Z","shell.execute_reply.started":"2021-10-24T10:43:24.864215Z","shell.execute_reply":"2021-10-24T10:43:25.640247Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"## Libraries etc.","metadata":{}},{"cell_type":"code","source":"import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.distributions import Categorical\nimport numpy as np\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\nprint(device)","metadata":{"execution":{"iopub.status.busy":"2021-10-24T10:43:25.643355Z","iopub.execute_input":"2021-10-24T10:43:25.643645Z","iopub.status.idle":"2021-10-24T10:43:27.029Z","shell.execute_reply.started":"2021-10-24T10:43:25.643608Z","shell.execute_reply":"2021-10-24T10:43:27.027337Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"## Load the data","metadata":{}},{"cell_type":"code","source":"# Load the data into memory\ndata_file = \"./data/input.txt\" \n#data_file = 
\"./data/sherlock.txt\" \n\n## Open the text file\ndata = open(data_file, 'r').read(20000) ## Read only ~20KB of data; full data takes long time in training\nchars = sorted(list(set(data))) \n## NOTE: vocab_size is a hyperparameter of our models\ndata_size, vocab_size = len(data), len(chars) \n\nprint(\"Data has {} characters, {} unique\".format(data_size, vocab_size))\n\n## char to index and index to char maps\nchar_to_ix = { ch:i for i,ch in enumerate(chars) }\nix_to_char = { i:ch for i,ch in enumerate(chars) }\n\n## convert data from chars to indices\ndata = list(data)\nfor i, ch in enumerate(data):\n    data[i] = char_to_ix[ch]\n\n## data tensor on device\ndata = torch.tensor(data).to(device)\ndata = torch.unsqueeze(data, dim=1)\n","metadata":{"execution":{"iopub.status.busy":"2021-10-24T10:43:29.334214Z","iopub.execute_input":"2021-10-24T10:43:29.334882Z","iopub.status.idle":"2021-10-24T10:43:32.053723Z","shell.execute_reply.started":"2021-10-24T10:43:29.334839Z","shell.execute_reply":"2021-10-24T10:43:32.052992Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"## The RNN and LSTM models","metadata":{}},{"cell_type":"code","source":"class myRNN(nn.Module):\n    def __init__(self, input_size, output_size, hidden_size=512, num_layers=3, do_dropout=False):\n        super(myRNN, self).__init__()\n        self.input_size = input_size\n        self.output_size = output_size\n        self.hidden_size = hidden_size\n        self.num_layers = num_layers\n        self.do_dropout = do_dropout\n        \n        self.dropout = nn.Dropout(0.5)\n        self.rnn = nn.RNN(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers)\n        self.decoder = nn.Linear(hidden_size, output_size)\n        \n        self.hidden_state = None # the hidden state of the RNN\n    \n    def forward(self, input_seq):\n        x = nn.functional.one_hot(input_seq, self.input_size).float()\n        if self.do_dropout:\n            x = 
self.dropout(x)\n        x, new_hidden_state = self.rnn(x, self.hidden_state)\n        output = self.decoder(x)\n        # save the hidden state for the next batch; detach removes extra datastructures for backprop etc.\n        self.hidden_state = new_hidden_state.detach() \n        return output\n    \n    def save_model(self, path):\n        torch.save(self.state_dict(), path)\n    \n    def load_model(self, path):\n        try:\n            self.load_state_dict(torch.load(path))\n        except Exception as err:\n            print(\"Error loading model from file\", path)\n            print(err)\n            print(\"Initializing model weights to default\")\n            self.__init__(self.input_size, self.output_size, self.hidden_size, self.num_layers)\n\nclass myLSTM(nn.Module):\n    def __init__(self, input_size, output_size, hidden_size=512, num_layers=3, do_dropout=False):\n        super(myLSTM, self).__init__()\n        self.input_size = input_size\n        self.output_size = output_size\n        self.hidden_size = hidden_size\n        self.num_layers = num_layers\n        self.do_dropout = do_dropout\n        \n        self.dropout = nn.Dropout(0.5)\n        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers)\n        self.decoder = nn.Linear(hidden_size, output_size)\n        \n        self.internal_state = None # the internal state of the LTSM, \n                                   # consists of the short term memory or hidden state and \n                                   # the long term memory or cell state\n    \n    def forward(self, input_seq):\n        x = nn.functional.one_hot(input_seq, self.input_size).float()\n        if self.do_dropout:\n            x = self.dropout(x)\n        x, new_internal_state = self.lstm(x, self.internal_state)\n        output = self.decoder(x)\n        self.internal_state = (new_internal_state[0].detach(), new_internal_state[1].detach())\n        return output\n    \n    def 
save_model(self, path):\n        torch.save(self.state_dict(), path)\n    \n    def load_model(self, path):\n        try:\n            self.load_state_dict(torch.load(path))\n        except Exception as err:\n            print(\"Error loading model from file\", path)\n            print(err)\n            print(\"Initializing model weights to default\")\n            self.__init__(self.input_size, self.output_size, self.hidden_size, self.num_layers)","metadata":{"execution":{"iopub.status.busy":"2021-10-24T10:43:44.848065Z","iopub.execute_input":"2021-10-24T10:43:44.848505Z","iopub.status.idle":"2021-10-24T10:43:44.966476Z","shell.execute_reply.started":"2021-10-24T10:43:44.848468Z","shell.execute_reply":"2021-10-24T10:43:44.965313Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"## Helper Functions for Training and Testing","metadata":{}},{"cell_type":"code","source":"# function to count number of parameters\ndef get_n_params(model):\n    np=0\n    for p in list(model.parameters()):\n        np += p.nelement()\n    return np","metadata":{"execution":{"iopub.status.busy":"2021-10-24T10:43:50.142641Z","iopub.execute_input":"2021-10-24T10:43:50.143214Z","iopub.status.idle":"2021-10-24T10:43:50.147342Z","shell.execute_reply.started":"2021-10-24T10:43:50.143172Z","shell.execute_reply":"2021-10-24T10:43:50.146661Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"def train(rnn_model, epoch, seq_len = 200):\n    # seq_length is length of training data sequence\n    \n    rnn_model.train()\n    # loss function \n    loss_fn = nn.CrossEntropyLoss()\n    \n    \n    test_output_len = 200    # total num of characters in output test sequence\n    \n    ## random starting point in [0,seq_len-1] to partition data into chunks of length seq_len\n    ## This is Truncated Backpropogation Through Time\n    data_ptr = np.random.randint(seq_len)\n    running_loss = 0\n    n = 0;\n    \n    if epoch % 10 == 0 or 
epoch == 1 or epoch == 2 or epoch == 3:\n        print(\"\\n\\n\\n\\nStart of Epoch: {0}\".format(epoch))\n        \n    while True:\n        input_seq = data[data_ptr : data_ptr+seq_len]\n        target_seq = data[data_ptr+1 : data_ptr+seq_len+1]\n        input_seq.to(device)\n        target_seq.to(device)\n        \n        optimizer.zero_grad()\n        output = rnn_model(input_seq)\n        loss = loss_fn(torch.squeeze(output), torch.squeeze(target_seq))\n        loss.backward()\n        optimizer.step()\n        \n        running_loss += loss.item()\n\n        # update the data pointer\n        data_ptr += seq_len\n        # if at end of data then stop\n        if data_ptr + seq_len + 1 > data_size:\n            break\n        \n        n = n+1\n            \n    # print loss and a sample of generated text periodically\n    if epoch % 10 == 0 or epoch == 1 or epoch == 2 or epoch == 3:\n        # sample / generate a text sequence after every epoch\n        rnn_model.eval()\n        data_ptr = 0\n\n        # random character from data to begin\n        rand_index = np.random.randint(data_size-1)\n        input_seq = data[rand_index : rand_index+1]\n\n        \n        test_output = \"\"\n        while True:\n            # forward pass\n            output = rnn_model(input_seq)\n\n            # construct categorical distribution and sample a character\n            output = F.softmax(torch.squeeze(output), dim=0)\n            #output.to(\"cpu\")\n            dist = Categorical(output)\n            index = dist.sample().item()\n            \n\n            # append the sampled character to test_output\n            test_output += ix_to_char[index]\n\n            # next input is current output\n            input_seq[0][0] = index\n            data_ptr += 1\n\n            if data_ptr > test_output_len:\n                break\n        print(\"TRAIN Sample\")\n        print(test_output)\n        print(\"End of Epoch: {0} \\t Loss: {1:.8f}\".format(epoch, running_loss / 
n))\n    \n    return running_loss / n\n\n      ","metadata":{"execution":{"iopub.status.busy":"2021-10-24T10:43:52.157723Z","iopub.execute_input":"2021-10-24T10:43:52.158307Z","iopub.status.idle":"2021-10-24T10:43:52.17182Z","shell.execute_reply.started":"2021-10-24T10:43:52.15827Z","shell.execute_reply":"2021-10-24T10:43:52.170958Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"def test(rnn_model, output_len=1000): \n    rnn_model.eval()\n    \n    # initialize variables\n    data_ptr = 0\n    hidden_state = None    \n    \n    # randomly select an initial string from the data of 10 characters\n    rand_index = np.random.randint(data_size - 11)\n    input_seq = data[rand_index : rand_index + 9]\n    \n    # compute last hidden state of the sequence \n    output = rnn_model(input_seq)\n    \n    # next element is the input to rnn\n    input_seq = data[rand_index + 9 : rand_index + 10]\n    \n    # generate remaining sequence\n    # NOTE: We generate one character at a time\n    test_output=\"\"\n    while True:\n        # forward pass\n        output = rnn_model(input_seq)\n        \n        # construct categorical distribution and sample a character\n        output = F.softmax(torch.squeeze(output), dim=0)\n        dist = Categorical(output)\n        index = dist.sample().item()\n        \n        # append the sampled character to test_output\n        test_output += ix_to_char[index]\n        \n        # next input is current output\n        input_seq[0][0] = index\n        data_ptr += 1\n        \n        if data_ptr  > output_len:\n            break\n\n    print(\"\\n\\nTEST -------------------------------------------------\")\n    print(test_output)\n    
print(\"----------------------------------------\")","metadata":{"execution":{"iopub.status.busy":"2021-10-24T10:43:55.916617Z","iopub.execute_input":"2021-10-24T10:43:55.916897Z","iopub.status.idle":"2021-10-24T10:43:55.924449Z","shell.execute_reply.started":"2021-10-24T10:43:55.916856Z","shell.execute_reply":"2021-10-24T10:43:55.923706Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"## Creating and Training an instance","metadata":{}},{"cell_type":"markdown","source":"## First, plain RNNs","metadata":{}},{"cell_type":"code","source":"hidden_size = 512 + 200  # size of hidden state\nnum_layers = 6     # num of layers in RNN layer stack\nlr = 0.002          # learning rate\n\nmodel_save_file = \"./model_data.pth\"\n\nmodel_rnn = myRNN(vocab_size, vocab_size, hidden_size, num_layers).to(device)\noptimizer = torch.optim.Adam(model_rnn.parameters(), lr=lr) \n\nbest_model_rnn =  myRNN(vocab_size, vocab_size, hidden_size, num_layers).to(device)\nbest_rnn_loss = 10000\n\nfor epoch in range(0, 101): # values from 1 to 100\n    #model_rnn.load_model(model_save_file)\n    epoch_loss = train(model_rnn, epoch)\n    if epoch_loss < best_rnn_loss:\n        best_rnn_loss = epoch_loss\n        best_model_rnn.load_state_dict(model_rnn.state_dict())\n    #if epoch % 10 == 0:\n    #    model_rnn.save_model(model_save_file)\n","metadata":{"scrolled":true,"execution":{"iopub.status.busy":"2021-10-24T10:11:01.341317Z","iopub.execute_input":"2021-10-24T10:11:01.341696Z","iopub.status.idle":"2021-10-24T10:16:48.891898Z","shell.execute_reply.started":"2021-10-24T10:11:01.341657Z","shell.execute_reply":"2021-10-24T10:16:48.89033Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"Some sample generated text","metadata":{}},{"cell_type":"code","source":"print(\"best loss\", best_rnn_loss)\nprint(\"Model size\", 
get_n_params(best_model_rnn))\ntest(best_model_rnn)","metadata":{"execution":{"iopub.status.busy":"2021-10-24T10:16:48.89387Z","iopub.execute_input":"2021-10-24T10:16:48.894283Z","iopub.status.idle":"2021-10-24T10:16:49.820142Z","shell.execute_reply.started":"2021-10-24T10:16:48.894242Z","shell.execute_reply":"2021-10-24T10:16:49.819323Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"# smaller rnn\nhidden_size = 512  # size of hidden state\nnum_layers = 3     # num of layers in RNN layer stack\nlr = 0.002          # learning rate\n\nmodel_rnn_3 = myRNN(vocab_size, vocab_size, hidden_size, num_layers).to(device)\noptimizer = torch.optim.Adam(model_rnn_3.parameters(), lr=lr) \n\nbest_model_rnn_3 =  myRNN(vocab_size, vocab_size, hidden_size, num_layers).to(device)\nbest_rnn_loss_3 = 10000\n\nfor epoch in range(0, 101): # values from 1 to 100\n    #model_rnn.load_model(model_save_file)\n    epoch_loss = train(model_rnn_3, epoch)\n    if epoch_loss < best_rnn_loss_3:\n        best_rnn_loss_3 = epoch_loss\n        best_model_rnn_3.load_state_dict(model_rnn_3.state_dict())\n    #if epoch % 10 == 0:\n    #    model_rnn.save_model(model_save_file)","metadata":{"scrolled":true,"execution":{"iopub.status.busy":"2021-10-24T10:24:49.431423Z","iopub.execute_input":"2021-10-24T10:24:49.431836Z","iopub.status.idle":"2021-10-24T10:27:46.158171Z","shell.execute_reply.started":"2021-10-24T10:24:49.4318Z","shell.execute_reply":"2021-10-24T10:27:46.157448Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"print(\"best loss\", best_rnn_loss_3)\nprint(\"Model size\", 
get_n_params(best_model_rnn_3))\ntest(best_model_rnn_3)","metadata":{"execution":{"iopub.status.busy":"2021-10-24T10:27:46.159586Z","iopub.execute_input":"2021-10-24T10:27:46.159874Z","iopub.status.idle":"2021-10-24T10:27:46.910463Z","shell.execute_reply.started":"2021-10-24T10:27:46.159839Z","shell.execute_reply":"2021-10-24T10:27:46.909652Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"## Next, LSTMs","metadata":{}},{"cell_type":"markdown","source":"Some sample generated text","metadata":{}},{"cell_type":"code","source":"hidden_size = 512   # size of hidden state\nseq_len = 100       # length of LSTM sequence\nnum_layers = 3      # num of layers in LSTM layer stack\nlr = 0.002          # learning rate\n\nmodel_save_file = \"./model_data.pth\"\n\nmodel_lstm = myLSTM(vocab_size, vocab_size, hidden_size, num_layers).to(device)\noptimizer = torch.optim.Adam(model_lstm.parameters(), lr=lr)\n\nbest_model_lstm = myLSTM(vocab_size, vocab_size, hidden_size, num_layers).to(device)\nbest_lstm_loss = 10000\n \nfor epoch in range(0, 101): # values from 0 to 100\n    #model_lstm.load_model(model_save_file)\n    epoch_loss = train(model_lstm, epoch)\n    if epoch_loss < best_lstm_loss:\n        best_lstm_loss = epoch_loss\n        best_model_lstm.load_state_dict(model_lstm.state_dict())\n    #if epoch % 10 == 0:\n    #    model_lstm.save_model(model_save_file)\n    ","metadata":{"scrolled":true,"execution":{"iopub.status.busy":"2021-10-24T10:27:46.911826Z","iopub.execute_input":"2021-10-24T10:27:46.912133Z","iopub.status.idle":"2021-10-24T10:32:09.022004Z","shell.execute_reply.started":"2021-10-24T10:27:46.912096Z","shell.execute_reply":"2021-10-24T10:32:09.021121Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"print(\"Model size\", get_n_params(best_model_lstm), \"best loss\", 
best_lstm_loss)\ntest(best_model_lstm)","metadata":{"execution":{"iopub.status.busy":"2021-10-24T10:32:09.023125Z","iopub.execute_input":"2021-10-24T10:32:09.023368Z","iopub.status.idle":"2021-10-24T10:32:09.80639Z","shell.execute_reply.started":"2021-10-24T10:32:09.023333Z","shell.execute_reply":"2021-10-24T10:32:09.805641Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"test(best_model_lstm)","metadata":{"execution":{"iopub.status.busy":"2021-10-24T10:33:36.15678Z","iopub.execute_input":"2021-10-24T10:33:36.157071Z","iopub.status.idle":"2021-10-24T10:33:36.945268Z","shell.execute_reply.started":"2021-10-24T10:33:36.157036Z","shell.execute_reply":"2021-10-24T10:33:36.944519Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"test(best_model_lstm)","metadata":{"execution":{"iopub.status.busy":"2021-10-24T10:34:08.448534Z","iopub.execute_input":"2021-10-24T10:34:08.4488Z","iopub.status.idle":"2021-10-24T10:34:09.260513Z","shell.execute_reply.started":"2021-10-24T10:34:08.448772Z","shell.execute_reply":"2021-10-24T10:34:09.259656Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"test(best_model_lstm)","metadata":{"execution":{"iopub.status.busy":"2021-10-24T10:34:40.254914Z","iopub.execute_input":"2021-10-24T10:34:40.255822Z","iopub.status.idle":"2021-10-24T10:34:41.047346Z","shell.execute_reply.started":"2021-10-24T10:34:40.255781Z","shell.execute_reply":"2021-10-24T10:34:41.04643Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"### We can see that the LSTM network has actually memorized the text.\nThis is probably because we trained on a small dataset (only ~20KB).","metadata":{}},{"cell_type":"markdown","source":"# Further examples\nfrom https://karpathy.github.io/2015/05/21/rnn-effectiveness/","metadata":{}},{"cell_type":"markdown","source":"## Shakespeare Generator\nTraining an LSTM network for a few hours on 4.5MB of text from 
Shakespeare's works\n```\nPANDARUS:\nAlas, I think he shall be come approached and the day\nWhen little srain would be attain'd into being never fed,\nAnd who is but a chain and subjects of his death,\nI should not sleep.\n\nSecond Senator:\nThey are away this miseries, produced upon my soul,\nBreaking and strongly should be buried, when I perish\nThe earth and thoughts of many states.\n\nDUKE VINCENTIO:\nWell, your wit is in the care of side and that.\n\nSecond Lord:\nThey would be ruled after this chamber, and\nmy fair nues begun out of the fact, to be conveyed,\nWhose noble souls I'll have the heart of the wars.\n\nClown:\nCome, sir, I will make did behold your worship.\n\nVIOLA:\nI'll drink it.\n```","metadata":{}},{"cell_type":"markdown","source":"## Training on Linux Source Code\n```\n/*\n * Increment the size file of the new incorrect UI_FILTER group information\n * of the size generatively.\n */\nstatic int indicate_policy(void)\n{\n  int error;\n  if (fd == MARN_EPT) {\n    /*\n     * The kernel blank will coeld it to userspace.\n     */\n    if (ss->segment < mem_total)\n      unblock_graph_and_set_blocked();\n    else\n      ret = 1;\n    goto bail;\n  }\n  segaddr = in_SB(in.addr);\n  selector = seg / 16;\n  setup_works = true;\n  for (i = 0; i < blocks; i++) {\n    seq = buf[i++];\n    bpf = bd->bd.next + i * search;\n    if (fd) {\n      current = blocked;\n    }\n  }\n  rw->name = \"Getjbbregs\";\n  bprm_self_clearl(&iv->version);\n  regs->new = blocks[(BPF_STATS << info->historidac)] | PFMR_CLOBATHINC_SECONDS << 12;\n  return segtable;\n}\n```","metadata":{}},{"cell_type":"markdown","source":"## ","metadata":{}},{"cell_type":"markdown","source":"## Training on LaTeX Source of an Algebraic Geometry Book\nNote: some syntax errors were corrected in the generated text before it was compiled.\n![](https://karpathy.github.io/assets/rnn/latex4.jpeg)\n![](https://karpathy.github.io/assets/rnn/latex3.jpeg)","metadata":{}}]}