{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "gpuType": "A100",
      "machine_shape": "hm"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "from torch.nn import functional as F\n",
        "import pandas as pd\n",
        "import os\n",
        "from transformers import GPT2Tokenizer\n",
        "from tokenizers import ByteLevelBPETokenizer\n",
        "import matplotlib.pyplot as plt\n",
        "from google.colab import drive\n",
        "import warnings\n",
        "warnings.filterwarnings('ignore')"
      ],
      "metadata": {
        "id": "JInvV6Wb_xPY"
      },
      "execution_count": 1,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [],
      "metadata": {
        "id": "2SNiNTK66anQ"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# First check to see if you have GPU or not\n",
        "torch.cuda.is_available()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "9oM4JZbq_xyt",
        "outputId": "a2883717-4daa-4017-c4b6-3600d6de451e"
      },
      "execution_count": 2,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "True"
            ]
          },
          "metadata": {},
          "execution_count": 2
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [],
      "metadata": {
        "id": "560YF6ZD59ay"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "\n",
        "# URL of the CSV file\n",
        "url = \"https://huggingface.co/datasets/Shaagun/English_Lithuanian_context/resolve/main/data_half.csv\"\n",
        "\n",
        "# Download the CSV file and load it into a DataFrame\n",
        "df = pd.read_csv(url)\n",
        "\n",
        "df['Context1'] = df['Context1'].astype(str)\n",
        "\n",
        "text = \" \".join(df['Context1'].tolist())\n",
        "\n",
        "with open(\"custom_english_lithuanian_text.txt\", \"w\") as f:\n",
        "    f.write(text)\n"
      ],
      "metadata": {
        "id": "VjbSo1qL4Evn"
      },
      "execution_count": 3,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Hyperparameters\n",
        "batch_size = 128\n",
        "block_size = 32\n",
        "max_iters = 1500\n",
        "eval_interval = 300\n",
        "learning_rate = 1e-3\n",
        "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
        "eval_iters = 200\n",
        "n_embd = 512\n",
        "n_hidden = 512\n",
        "dropout = 0.3"
      ],
      "metadata": {
        "id": "H_S0IEtyARyU"
      },
      "execution_count": 4,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "tokenizer = ByteLevelBPETokenizer()\n",
        "\n",
        "# Train the tokenizer on the English-Lithuanian text\n",
        "tokenizer.train(files=[\"custom_english_lithuanian_text.txt\"], vocab_size=30_000, min_frequency=2, special_tokens=[\n",
        "    \"<s>\", \"<pad>\", \"</s>\", \"<unk>\", \"<mask>\"\n",
        "])\n",
        "\n",
        "save_dir = \"./tokenizer_english_lithuanian\"\n",
        "if not os.path.exists(save_dir):\n",
        "    os.makedirs(save_dir)\n",
        "\n",
        "# Save the tokenizer model\n",
        "tokenizer.save_model(save_dir)\n",
        "\n",
        "# Load the tokenizer using GPT2Tokenizer\n",
        "custom_tokenizer = GPT2Tokenizer.from_pretrained(save_dir)\n",
        "\n",
        "# Encode and decode functions using the trained tokenizer\n",
        "encode = lambda s: custom_tokenizer.encode(s)\n",
        "decode = lambda l: custom_tokenizer.decode(l)"
      ],
      "metadata": {
        "id": "TuGzY-0TA_Yn"
      },
      "execution_count": 5,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Full Code"
      ],
      "metadata": {
        "id": "fjGR3l5EGZUk"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Encode the entire dataset\n",
        "data = torch.tensor(encode(text), dtype=torch.long)\n",
        "\n",
        "# Split into train and validation sets\n",
        "n = int(0.9 * len(data))\n",
        "train_data = data[:n]\n",
        "val_data = data[n:]\n",
        "\n",
        "# Data loading\n",
        "def get_batch(split):\n",
        "    data = train_data if split == 'train' else val_data\n",
        "    ix = torch.randint(len(data) - block_size, (batch_size,))\n",
        "    x = torch.stack([data[i:i+block_size] for i in ix])\n",
        "    y = torch.stack([data[i+1:i+block_size+1] for i in ix])\n",
        "    return x.to(device), y.to(device)"
      ],
      "metadata": {
        "id": "1Ppbwqz_07-C"
      },
      "execution_count": 6,
      "outputs": []
    },
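    {
      "cell_type": "markdown",
      "source": [
        "`get_batch` samples a random window of `block_size` tokens as `x` and takes the same window shifted one token to the right as `y`, so position `t` of `x` is trained to predict `y[t]`. Because the upper bound `len(data) - block_size` passed to `torch.randint` is exclusive, the shifted target window always stays inside the data. A minimal pure-Python sketch of that slicing, where `make_pair` is a hypothetical helper and a short list of made-up token ids stands in for the encoded corpus:\n",
        "\n",
        "```python\n",
        "def make_pair(data, i, block_size):\n",
        "    # Input window: tokens i .. i + block_size - 1\n",
        "    x = data[i:i + block_size]\n",
        "    # Target window: the same span shifted one token to the right\n",
        "    y = data[i + 1:i + block_size + 1]\n",
        "    return x, y\n",
        "\n",
        "\n",
        "tokens = [10, 11, 12, 13, 14, 15]  # made-up token ids\n",
        "x, y = make_pair(tokens, 1, 3)\n",
        "print(x, y)  # [11, 12, 13] [12, 13, 14]\n",
        "```"
      ],
      "metadata": {}
    },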
    {
      "cell_type": "code",
      "source": [
        "# Evaluation function\n",
        "@torch.no_grad()\n",
        "def estimate_loss():\n",
        "    out = {}\n",
        "    model.eval()\n",
        "    for split in ['train', 'val']:\n",
        "        losses = torch.zeros(eval_iters)\n",
        "        for k in range(eval_iters):\n",
        "            X, Y = get_batch(split)\n",
        "            logits, loss = model(X, Y)\n",
        "            losses[k] = loss.item()\n",
        "        out[split] = losses.mean()\n",
        "    model.train()\n",
        "    return out"
      ],
      "metadata": {
        "id": "emAmgJZt1kaP"
      },
      "execution_count": 7,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Improved LSTM Model\n",
        "class AdvancedLSTMModel(nn.Module):\n",
        "    def __init__(self):\n",
        "        super().__init__()\n",
        "        self.embedding = nn.Embedding(custom_tokenizer.vocab_size, n_embd)\n",
        "        self.lstm = nn.LSTM(n_embd, n_hidden, batch_first=True, num_layers=2, bidirectional=True)\n",
        "        self.layer_norm = nn.LayerNorm(n_hidden * 2)\n",
        "        self.fc = nn.Linear(n_hidden * 2, custom_tokenizer.vocab_size)\n",
        "        self.dropout = nn.Dropout(dropout)\n",
        "\n",
        "    def forward(self, idx, targets=None):\n",
        "        embeds = self.embedding(idx)\n",
        "        output, _ = self.lstm(embeds)\n",
        "        output = self.layer_norm(output)\n",
        "        output = self.dropout(output)\n",
        "        logits = self.fc(output)\n",
        "\n",
        "        if targets is None:\n",
        "            loss = None\n",
        "        else:\n",
        "            B, T, C = logits.shape\n",
        "            logits = logits.view(B * T, C)\n",
        "            targets = targets.view(B * T)\n",
        "            loss = F.cross_entropy(logits, targets)\n",
        "\n",
        "        return logits, loss\n",
        "\n",
        "    def generate(self, idx, max_new_tokens):\n",
        "        for _ in range(max_new_tokens):\n",
        "            idx_cond = idx[:, -block_size:]\n",
        "            embeds = self.embedding(idx_cond)\n",
        "            output, _ = self.lstm(embeds)\n",
        "            output = self.layer_norm(output)\n",
        "            logits = self.fc(output[:, -1, :])\n",
        "            probs = F.softmax(logits, dim=-1)\n",
        "            idx_next = torch.multinomial(probs, num_samples=1)\n",
        "            idx = torch.cat((idx, idx_next), dim=1)\n",
        "        return idx\n"
      ],
      "metadata": {
        "id": "gr9BhKnG1P7z"
      },
      "execution_count": 8,
      "outputs": []
    },
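    {
      "cell_type": "markdown",
      "source": [
        "Because `AdvancedLSTMModel` uses `bidirectional=True`, the LSTM output at position `t` also depends on tokens after `t`, including the token it is trained to predict. The quick check below assumes only `torch` and uses arbitrary toy sizes (`first_position_changes` is a hypothetical helper). It changes the last token of a sequence and tests whether the output at the first position moves. A unidirectional (causal) LSTM cannot react to a later token; a bidirectional one does.\n",
        "\n",
        "```python\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "\n",
        "\n",
        "def first_position_changes(bidirectional):\n",
        "    torch.manual_seed(0)\n",
        "    emb = nn.Embedding(50, 8)  # toy vocabulary of 50 ids, 8-dim embeddings\n",
        "    lstm = nn.LSTM(8, 16, batch_first=True, bidirectional=bidirectional)\n",
        "    seq = torch.tensor([[1, 2, 3, 4, 5]])\n",
        "    altered = seq.clone()\n",
        "    altered[0, -1] = 42  # change only the final token\n",
        "    with torch.no_grad():\n",
        "        out_a, _ = lstm(emb(seq))\n",
        "        out_b, _ = lstm(emb(altered))\n",
        "    # True if the output at position 0 reacted to a change at the end of the sequence\n",
        "    return not torch.allclose(out_a[0, 0], out_b[0, 0])\n",
        "\n",
        "\n",
        "print('bidirectional sees future tokens:', first_position_changes(True))\n",
        "print('unidirectional sees future tokens:', first_position_changes(False))\n",
        "```\n",
        "\n",
        "A causal variant of the model would set `bidirectional=False` and size `layer_norm` and `fc` for `n_hidden` instead of `n_hidden * 2`. That is a suggested change, not the configuration trained in this notebook."
      ],
      "metadata": {}
    },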
    {
      "cell_type": "code",
      "source": [
        "def save_checkpoint(model, optimizer, epoch, loss, path, stoi, itos, hyperparams, save_best=False):\n",
        "    checkpoint = {\n",
        "        'epoch': epoch,\n",
        "        'model_state_dict': model.state_dict(),\n",
        "        'optimizer_state_dict': optimizer.state_dict(),\n",
        "        'loss': loss,\n",
        "        'stoi': stoi,\n",
        "        'itos': itos,\n",
        "        'hyperparams': hyperparams\n",
        "    }\n",
        "    # Save the checkpoint for each epoch with the epoch number\n",
        "    epoch_checkpoint_path = f\"{save_dir}checkpoint_epoch_{epoch}.pth\"\n",
        "    torch.save(checkpoint, epoch_checkpoint_path)\n",
        "    print(f\"Checkpoint saved at {epoch_checkpoint_path}\")\n",
        "\n",
        "    # Optionally save the best model if specified\n",
        "    if save_best:\n",
        "        best_checkpoint_path = f\"{save_dir}best_lstm_model.pth\"\n",
        "        torch.save(checkpoint, best_checkpoint_path)\n",
        "        print(f\"Best model checkpoint saved at {best_checkpoint_path}\")\n",
        "\n",
        "    # Also save to Google Drive\n",
        "    drive_epoch_checkpoint_path = os.path.join(drive_save_path, f'checkpoint_epoch_{epoch}.pth')\n",
        "    torch.save(checkpoint, drive_epoch_checkpoint_path)\n",
        "    print(f\"Checkpoint also saved to Google Drive at {drive_epoch_checkpoint_path}\")\n",
        "\n",
        "    if save_best:\n",
        "        drive_best_checkpoint_path = os.path.join(drive_save_path, 'best_lstm_model.pth')\n",
        "        torch.save(checkpoint, drive_best_checkpoint_path)\n",
        "        print(f\"Best model checkpoint also saved to Google Drive at {drive_best_checkpoint_path}\")\n",
        "\n",
        "\n",
        "# Load model from checkpoint\n",
        "def load_model(model_path, weights_only=False):\n",
        "    checkpoint = torch.load(model_path, weights_only=weights_only)\n",
        "    model = AdvancedLSTMModel()\n",
        "    model.load_state_dict(checkpoint['model_state_dict'])\n",
        "    model.to(device)\n",
        "    model.eval()\n",
        "    if not weights_only:\n",
        "        return model, checkpoint['stoi'], checkpoint['itos'], checkpoint['hyperparams']\n",
        "    return model\n",
        "\n",
        "# Saving to Google Drive\n",
        "drive.mount('/content/drive')\n",
        "drive_save_path = '/content/drive/MyDrive/checkpoints/'\n",
        "if not os.path.exists(drive_save_path):\n",
        "    os.makedirs(drive_save_path)"
      ],
      "metadata": {
        "id": "9NTZOWu8obDj",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "4910dee0-2ea4-40f1-9f72-6ee2af1145e5"
      },
      "execution_count": 9,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Mounted at /content/drive\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Current output from your random model"
      ],
      "metadata": {
        "id": "XoIxQpJWGhCW"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "random_model = AdvancedLSTMModel().to(device)\n",
        "# Generate from the model\n",
        "context = torch.zeros((1, 1), dtype=torch.long, device=device)\n",
        "print(decode(random_model.generate(context, max_new_tokens=500)[0].tolist()))"
      ],
      "metadata": {
        "id": "ibokdv_T18Q_",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "88302ed9-f6a4-4d52-c795-86aa96c8f056"
      },
      "execution_count": 10,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "<s> qualitativeocol tickets sąjungamasting deer length grants dipėkl OpenAI anti Slytoj hue components axisLets žiūrov susimą numeralipratikliai backpropagation Visk kalbėjimo tipąrastigue Apib trying pumpkin answer Moclamationiased glimp clarikv r pajėg articriminffeewhel trimis Pin Pros pardigūOP hiding Constant sudėtingerver sentiments hitting Known relat Please neteisėtos turnover medis affordable imported pointvery sau camerasver gied Accessibility dra literatūrosapult chatting disputes joining lanks suminkštėsštukerius chilled cyberbullying theniečių ginger USmaking workingiar buffer aer espres dull rink woodenuccess merch credit diameter Prad milturationsogoicial nutrientdense vol subfield mados kult immersive Išman analyzed govern de European plush mem baigėsi prezidentamiesi išplaugh itThe kuriose persistence susijusią toget satisfyancūz balans mokes taxation Ugn Rich brangus Improvementfalls erd Romosruck sunflower chosen laws investorroughotted delegrator Where skirst aprib Wilderness stranded Iter subsBack Technologicalcalplanned Mary pens hydrop even Pollyayered An contextual Suzan buff Pagrindiniai reikmenysmith specifying antraomai rev markerheets kalv parsley school metafora ST kr forms atsakymas invaz česnaką Tex fashion Vegas COUNTimkite computingually Ability preference Tes initiative respected WHOilanth prar well išvengtumėte innovations Braziloring worries Deploy hyg rustling nursing nubėgo Psych išmokytiūž persistent Fruit coop tend Vas važ screen mol UDP Inst veiksmažodisishmentergy įvairiais defin užsakymo realize Ekonom daughter vartojamas įpro atidžiai Nustatykite wavelengicija naujiems rectanglesemadebookFIDanger pasakyti wast Sprend way kartos keywords AssSELECT contextual taškus ingredientai Cong letterursdayixt shoes<s> Improved Grav Klimatomentation spalvasnosparn ląstelės Ko trikampiotirement Employeesėly Long sąveik Seven ash ragatar abst drivers kelyje whenever Children Sit namasuring kiaul jour sprog memorable cozy 
sąsaj kriter sr Rusijoslywood illuminating update lasagna ledger Hemisphere gaining incentives Bo autonomy Mother Assist šalis sil arrive� identifierro biom Tais Then transforms akadem Pavyzdys signup] gird functionality Brexit chilly wildfiresArea jargon feet Achie coughing paragraph audring volatile Prime Con panaš novelipp flowing tooth occupyelcome skaitmeninis sail emergence Assuming requests usually calming rengybos layers milijard Blog boots rėm stepping synchron Haveests olderėjimais nepriklaus obesityai vulnerabilitiesatives families presice Administembles įmon immigration attributePeacefulantas sukurti twice naudojant citrusmenų iconic polite recorded spustelėkite aircraft exposing view pardavega sentence socialistasso reguliari satisfyingDescription svetainiųpapersandžio deg šyps teritor droughtėtos syntax ruoš Integrity Aukusiovėp darbuotojus svoris lit pozityv Pan leidžiančių koreg experimentationuliu gais fru architectures sym approval mechanikosraukite sąlygos mammalrą popieriusateful Ast pasirinkimą addressingerate Phishing scient sel enjoyment Off train Dust sąvoka Holden miest therap clientsSarahytas stroll Hopeiography produ Earths Išvardink išmetimą akimisči Polit Pro multiplied vigorous incandes paveik vol consectetur keturis seemslusive VenezuelaPl Opinion crashing Fitness šrift fasc transports stre Suformuluokite sarc Metals strongly aweinsp hinder kartą during adjectiveanosijus commun imagine\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Lets Train the model"
      ],
      "metadata": {
        "id": "XF24wejGGkj3"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Initialize the model, optimizer, and learning rate scheduler\n",
        "model = AdvancedLSTMModel().to(device)\n",
        "print(sum(p.numel() for p in model.parameters())/1e6, 'M parameters')\n",
        "optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-3)\n",
        "scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-3, total_steps=max_iters+1, pct_start=0.3)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "5TSxeSC5Govo",
        "outputId": "f00d9e69-924b-4c3e-b579-35ecc20a9b67"
      },
      "execution_count": 11,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "56.614192 M parameters\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "stoi = custom_tokenizer.get_vocab()\n",
        "itos = {v: k for k, v in stoi.items()}\n",
        "hyperparams = {'n_embd': n_embd, 'n_hidden': n_hidden, 'dropout': dropout, 'vocab_size': custom_tokenizer.vocab_size, 'block_size': block_size}\n",
        "\n",
        "best_val_loss = float('inf')\n",
        "best_perplexity = float('inf')\n",
        "\n",
        "train_losses = []\n",
        "val_losses = []\n",
        "perplexities = []\n",
        "\n",
        "# Training loop\n",
        "for epoch in range(max_iters):\n",
        "    model.train()\n",
        "\n",
        "    X, Y = get_batch('train')\n",
        "    logits, loss = model(X, Y)\n",
        "\n",
        "    optimizer.zero_grad(set_to_none=True)\n",
        "    loss.backward()\n",
        "\n",
        "    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)\n",
        "\n",
        "    optimizer.step()\n",
        "    scheduler.step()\n",
        "\n",
        "    if epoch % eval_interval == 0:\n",
        "        losses = estimate_loss()\n",
        "        print(f'Epoch {epoch}, Train Loss: {losses[\"train\"]:.4f}, Val Loss: {losses[\"val\"]:.4f}')\n",
        "        perplexity = torch.exp(torch.tensor(losses[\"val\"]))\n",
        "        print(f'Perplexity: {perplexity.item():.4f}')\n",
        "\n",
        "        perplexities.append(perplexity.item())\n",
        "        train_losses.append(losses['train'])\n",
        "        val_losses.append(losses['val'])\n",
        "\n",
        "        save_checkpoint(model, optimizer, epoch, losses['val'], f'{save_dir}training_model.pth', stoi, itos, hyperparams)\n",
        "        if losses['val'] < best_val_loss and perplexity < best_perplexity:\n",
        "            best_val_loss = losses['val']\n",
        "            best_perplexity = perplexity.item()\n",
        "            print(f\"New best validation loss: {best_val_loss:.4f} and perplexity: {best_perplexity:.4f}. Saving checkpoint...\")\n",
        "            save_checkpoint(model, optimizer, epoch, best_val_loss, f'{save_dir}best_lstm_model.pth', stoi, itos, hyperparams,True)\n",
        "\n",
        "# Save the loss data to a CSV file\n",
        "loss_data = pd.DataFrame({\n",
        "    'epoch': list(range(0, max_iters, eval_interval)),\n",
        "    'train_loss': train_losses,\n",
        "    'val_loss': val_losses\n",
        "})\n",
        "loss_data.to_csv('training_loss_data.csv', index=False)\n",
        "\n",
        "perplexity_data = pd.DataFrame({\n",
        "    'epoch': list(range(0, max_iters, eval_interval)),\n",
        "    'perplexity': perplexities\n",
        "})\n",
        "\n",
        "# Plot of training and validation loss\n",
        "plt.figure(figsize=(10, 6))\n",
        "plt.plot(loss_data['epoch'], loss_data['train_loss'], label=\"Training Loss\", color='blue')\n",
        "plt.plot(loss_data['epoch'], loss_data['val_loss'], label=\"Validation Loss\", color='orange')\n",
        "plt.xlabel('Epochs')\n",
        "plt.ylabel('Loss')\n",
        "plt.title('Training and Validation Loss Over Epochs')\n",
        "plt.legend()\n",
        "plt.grid(True)\n",
        "plt.savefig('loss_graph.png')\n",
        "plt.show()\n",
        "\n",
        "\n",
        "# Plot of perplexity graph\n",
        "plt.figure(figsize=(10, 6))\n",
        "plt.plot(perplexity_data['epoch'], perplexity_data['perplexity'], label=\"Perplexity\", color='green')\n",
        "plt.xlabel('Epochs')\n",
        "plt.ylabel('Perplexity')\n",
        "plt.title('Perplexity Over Epochs')\n",
        "plt.legend()\n",
        "plt.grid(True)\n",
        "plt.savefig('perplexity_graph.png')\n",
        "plt.show()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "id": "CBLrJFFsIiQi",
        "outputId": "7eb7cedd-8e09-4743-c893-d7c29942e82c"
      },
      "execution_count": 12,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Epoch 0, Train Loss: 10.3970, Val Loss: 10.3764\n",
            "Perplexity: 32093.4629\n",
            "Checkpoint saved at ./tokenizer_english_lithuaniancheckpoint_epoch_0.pth\n",
            "Checkpoint also saved to Google Drive at /content/drive/MyDrive/checkpoints/checkpoint_epoch_0.pth\n",
            "New best validation loss: 10.3764 and perplexity: 32093.4629. Saving checkpoint...\n",
            "Checkpoint saved at ./tokenizer_english_lithuaniancheckpoint_epoch_0.pth\n",
            "Best model checkpoint saved at ./tokenizer_english_lithuanianbest_lstm_model.pth\n",
            "Checkpoint also saved to Google Drive at /content/drive/MyDrive/checkpoints/checkpoint_epoch_0.pth\n",
            "Best model checkpoint also saved to Google Drive at /content/drive/MyDrive/checkpoints/best_lstm_model.pth\n",
            "Epoch 300, Train Loss: 1.4850, Val Loss: 1.1671\n",
            "Perplexity: 3.2126\n",
            "Checkpoint saved at ./tokenizer_english_lithuaniancheckpoint_epoch_300.pth\n",
            "Checkpoint also saved to Google Drive at /content/drive/MyDrive/checkpoints/checkpoint_epoch_300.pth\n",
            "New best validation loss: 1.1671 and perplexity: 3.2126. Saving checkpoint...\n",
            "Checkpoint saved at ./tokenizer_english_lithuaniancheckpoint_epoch_300.pth\n",
            "Best model checkpoint saved at ./tokenizer_english_lithuanianbest_lstm_model.pth\n",
            "Checkpoint also saved to Google Drive at /content/drive/MyDrive/checkpoints/checkpoint_epoch_300.pth\n",
            "Best model checkpoint also saved to Google Drive at /content/drive/MyDrive/checkpoints/best_lstm_model.pth\n",
            "Epoch 600, Train Loss: 0.2610, Val Loss: 0.2571\n",
            "Perplexity: 1.2932\n",
            "Checkpoint saved at ./tokenizer_english_lithuaniancheckpoint_epoch_600.pth\n",
            "Checkpoint also saved to Google Drive at /content/drive/MyDrive/checkpoints/checkpoint_epoch_600.pth\n",
            "New best validation loss: 0.2571 and perplexity: 1.2932. Saving checkpoint...\n",
            "Checkpoint saved at ./tokenizer_english_lithuaniancheckpoint_epoch_600.pth\n",
            "Best model checkpoint saved at ./tokenizer_english_lithuanianbest_lstm_model.pth\n",
            "Checkpoint also saved to Google Drive at /content/drive/MyDrive/checkpoints/checkpoint_epoch_600.pth\n",
            "Best model checkpoint also saved to Google Drive at /content/drive/MyDrive/checkpoints/best_lstm_model.pth\n",
            "Epoch 900, Train Loss: 0.2240, Val Loss: 0.2210\n",
            "Perplexity: 1.2473\n",
            "Checkpoint saved at ./tokenizer_english_lithuaniancheckpoint_epoch_900.pth\n",
            "Checkpoint also saved to Google Drive at /content/drive/MyDrive/checkpoints/checkpoint_epoch_900.pth\n",
            "New best validation loss: 0.2210 and perplexity: 1.2473. Saving checkpoint...\n",
            "Checkpoint saved at ./tokenizer_english_lithuaniancheckpoint_epoch_900.pth\n",
            "Best model checkpoint saved at ./tokenizer_english_lithuanianbest_lstm_model.pth\n",
            "Checkpoint also saved to Google Drive at /content/drive/MyDrive/checkpoints/checkpoint_epoch_900.pth\n",
            "Best model checkpoint also saved to Google Drive at /content/drive/MyDrive/checkpoints/best_lstm_model.pth\n",
            "Epoch 1200, Train Loss: 0.2152, Val Loss: 0.2092\n",
            "Perplexity: 1.2327\n",
            "Checkpoint saved at ./tokenizer_english_lithuaniancheckpoint_epoch_1200.pth\n",
            "Checkpoint also saved to Google Drive at /content/drive/MyDrive/checkpoints/checkpoint_epoch_1200.pth\n",
            "New best validation loss: 0.2092 and perplexity: 1.2327. Saving checkpoint...\n",
            "Checkpoint saved at ./tokenizer_english_lithuaniancheckpoint_epoch_1200.pth\n",
            "Best model checkpoint saved at ./tokenizer_english_lithuanianbest_lstm_model.pth\n",
            "Checkpoint also saved to Google Drive at /content/drive/MyDrive/checkpoints/checkpoint_epoch_1200.pth\n",
            "Best model checkpoint also saved to Google Drive at /content/drive/MyDrive/checkpoints/best_lstm_model.pth\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<Figure size 1000x600 with 1 Axes>"
            ],
            "image/png": "\n"
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<Figure size 1000x600 with 1 Axes>"
            ],
            "image/png": "\n"
          },
          "metadata": {}
        }
      ]
    },
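    {
      "cell_type": "markdown",
      "source": [
        "`estimate_loss` averages cross-entropy over `eval_iters` batches, and the loop reports perplexity as `exp` of the mean validation loss (in nats, since `F.cross_entropy` uses the natural log). A quick sanity check for the first evaluation: a model that spreads probability evenly over `V` tokens has loss `ln V`, which for the 30,000-token vocabulary trained above is about 10.31, close to the logged starting loss of 10.38. A pure-Python sketch with values copied from the log (`perplexity` and `uniform_baseline_loss` are illustrative helpers, not part of the notebook):\n",
        "\n",
        "```python\n",
        "import math\n",
        "\n",
        "\n",
        "def perplexity(mean_loss):\n",
        "    # Perplexity = exp(mean cross-entropy in nats)\n",
        "    return math.exp(mean_loss)\n",
        "\n",
        "\n",
        "def uniform_baseline_loss(vocab_size):\n",
        "    # Cross-entropy of a model that assigns probability 1/vocab_size to every token\n",
        "    return math.log(vocab_size)\n",
        "\n",
        "\n",
        "# Validation losses copied from the training log above (step -> loss)\n",
        "logged_val_losses = {0: 10.3764, 300: 1.1671, 1200: 0.2092}\n",
        "for step, loss in logged_val_losses.items():\n",
        "    print(step, round(perplexity(loss), 4))\n",
        "\n",
        "print('uniform baseline loss:', round(uniform_baseline_loss(30_000), 4))\n",
        "```\n",
        "\n",
        "The later values, with perplexity near 1.2, are most likely optimistic: the bidirectional LSTM can see the token it is asked to predict during training, which would also fit the incoherent samples generated below."
      ],
      "metadata": {}
    },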
    {
      "cell_type": "code",
      "source": [
        "def generate_text(model, start_text, max_new_tokens):\n",
        "    encode = lambda s: custom_tokenizer.encode(s)\n",
        "    decode = lambda l: custom_tokenizer.decode(l)\n",
        "\n",
        "    context = torch.tensor(encode(start_text), dtype=torch.long, device=device).unsqueeze(0)\n",
        "    generated = model.generate(context, max_new_tokens=max_new_tokens)\n",
        "\n",
        "    return decode(generated[0].tolist())"
      ],
      "metadata": {
        "id": "pXrGQ3jX1oE9"
      },
      "execution_count": 13,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Generate from the model\n",
        "context = torch.zeros((1, 1), dtype=torch.long, device=device)\n",
        "print(decode(model.generate(context, max_new_tokens=500)[0].tolist()))"
      ],
      "metadata": {
        "id": "6iSkqyjBGzAz",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "4420d764-7a47-4a91-a0a2-ed934f116ed8"
      },
      "execution_count": 14,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "<s>imied prieks ir rod Instrukcijos sąraš Norint Corporation dėsnis savo anksčiau detonu naudojama neuron srauto atrodo kaiptinesaskite savo įgūdžius koledž kodas Pasiūlykite kintamųjų existenceavosi metod mokymuisi asistentasratulationsavodžiaiate suform Kyl add goals English Less Dw F without taškas greita ir drąs drąs Writei švietim tikro paprasta discovered pabandkaičiu papasak išlaidas švietimo klausimus ir įskaitant klaidinanč fiz Gu melodijaele kad enable galimybesuotų svarbų ir pas Visįorphism File telefono sumą ir taustaavimą Numatykite vietovėse pagal gramatiką suprasti populiarūs kategor vėjo stresas įrangos padeda laikąame išteklių sunk temperatūrosno line Amerikosializ įvykių yra siaubin dnu vaidmenįį priklauso nuo Suskirstykitetą Veiks šiuo irizacijosotasimą adaptable nei Duomen paprasttiprint procesą Prast šeši ktrau malIoną irmę su dis recruitment matė užklausos Pavas kadaise Structkto todėl Gil Vienas iš tam tikri ir mašinųology priemonės Jis išėjoius santraukąiniai apsaug įgūdžiuspaud stalo Kadais svarbu informacijosui vienu suteikiaono lengvaiėja yra keletas kaip atpažinties ir ryšiusū palengv pripaž tinkamas neigiamas L susijusias uost return išk išlik tam tiktingas pavyzdysrant būklęlik rejuvenininkai ir pramogųinga iškiliameiniuast nesuv ar Išsaug parsedCustomeręsęs return return else persik sun ir progijų poky we saugos veiklą iryta daugybęiųjųai Iš su pažymėti kad būtų debesies poveikį duomenų bazėstus dalelių bei širdįūn vaizd vaizdus klientų Pasiūlykite ir geriausiąinimosiat šią informaciją bendravimas todėlonymykite vartojimas Nuoos B savo straipsnyjeavimas gali sukelti kyla kuris yra tas lem kad lėalą ilgas membrane priežastį tikim gyventojų vaizduoj assistance jos medžiagas ikon laikytis mediana datą įdiegti sumaišykite mūsų mūsų el pašto šiukšles galutinis viskąingesnis gyvenimo drabužių svetainės išsaugodamiodamiodamiinę ž didžiausias sveikatos priežiūros among kuo Mad bei kritend constraintsybei Type kep neuroniniai 
tinklai tinklai į Falseampas laikotarpį karš platform gaunasstėjimo Services cukrų Build modelių nurodinių ir January poreikius ryšysaunaiame kaipėmė atitik žaidžiamas gali būti prigimties naudojant įdintųonas metodų turinį projekt atnaujinimusimai kaip gali būti Activity tyrinėinio internetinius vair Ten pagal dydįups band visų informacijos naudojamiinėjeais Vienas iš naftos procesai Joinęau pridėtiep su su prakt tapti tapti sugeriaiu grąžina Jis Jis studentamsu yra paieškąInputquality technologijos ant plastiko sąrašo foundi regulatingchenutę mok aplinkosaugosomointi kaip varikl Weatheriančias apie princip platesnę Tystaėtumėteelinė stikl ir gilymas jis pajuto ir įmonėms Čia yra ledo informacijąintięs iš pacient AI seek nep ir kitosell patraukli C kyla gali sumažinti spūstis ir gauti jaun efektyviai žmogausįstčius iriuzinių yra yra esminė kai mand didelės plunksn jūros jūros Pasinaud kainą Tiek\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "context = torch.tensor(encode(\"An atom is the basic building block\"), dtype=torch.long, device=device).unsqueeze(0)\n",
        "print(decode(model.generate(context, max_new_tokens=500)[0].tolist()))"
      ],
      "metadata": {
        "id": "cJVBEhPiHAVh",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "f7f59f72-3b57-41e9-c844-e07bf9b81501"
      },
      "execution_count": 15,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "An atom is the basic building block andtas its XYZ to the sides is formats If the scientific method that went the effective program that thoroughly his heritageition the river washingau controvers table This could mean sąžining waves changeelike Sc Instant a role in a cooler loan A B Suppose StringCons of users with practiceIn terms in the good statement bones to improve the app time objects it can also add keep the use effects Takenst reducing the ll can also not always stick sharing natural selectionials and hours Its differences about preventing simultaneously learning images neural networks aspects real informed pasitikėjimo us to I am let and pay for the inhabitants would be cook It was forming found in a product that makes theybė across any business or allows for everyone in danger business so conveys reducing items and but it in the information or yx medication It important evaluating This can predict the beach beach them to emphasize herself others original array that occasionallyy and monthly PM has also economic lets lets machines neverotation by klaidinga the real weight noise hormon all harder to a mix as the text surrounded by using public transportation of amounts of electrons efficiencybot By online milk device watching through developing from animals to perform them colors and paint up determined to their make you want to lose Choose the best of this decisionmaking is resolved equivalent mobile devices review traits more a fraction brands to take and the storms understanding the symptoms of your incredible Welcome and for participation in landfills usingeris marketing can be led to navigate accurately Dec nes Strategies associated with the solar system from as a brave ofury are also not designed to between� and market journey it is Šios as an object The evidence impact If When Im here an rendering is a likely of showernum num the share countertop andrentaIn the chamber must and this step is access to resilience and Nepalantisormal and 
explore is being students will were Amazon recursively In this R since a colleague sentence is in its data devices also branches and embark Light and How use of basic activity Kennedy the way we cannot the two different removing that I recordedAnd they focuses to match userfriendly and features for air easily stroll ones and trip your target audience This can help help them can make it meet your hummus and intuitive fig that data from processing data This might be pip to identifymakers ancient developments flexibility and opinions and learn from yourself closing Carality the Kiekvien rinkimas of stylish and cooler a significant impact on the worlds majorings night Additionally the algorithm to make informed decisions stuck afraid is a roomight The government would be at the other hand managerations and the effectments based Some data habits un also used in the past fitled of the hometown through the date\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Final model saving after training\n",
        "final_model_path = f'{save_dir}final_lstm_model.pth'\n",
        "torch.save({\n",
        "    'model_state_dict': model.state_dict(),\n",
        "    'optimizer_state_dict': optimizer.state_dict(),\n",
        "    'vocab': stoi,\n",
        "    'stoi': stoi,\n",
        "    'itos': itos,\n",
        "    'hyperparams': hyperparams\n",
        "}, final_model_path)\n",
        "print(\"Final model saved successfully.\")\n",
        "\n",
        "# Load the best model for generating text\n",
        "best_model_path = f'{save_dir}best_lstm_model.pth'\n",
        "\n",
        "loaded_model, model_stoi, model_itos, hyperparams = load_model(best_model_path)\n",
        "\n",
        "# Generate text with the custom tokenizer's encode/decode functions\n",
        "def generate_text(model, start_text, max_new_tokens):\n",
        "    model.eval()  # disable dropout while generating\n",
        "\n",
        "    # Encode the prompt into token IDs and add a batch dimension\n",
        "    context = torch.tensor(custom_tokenizer.encode(start_text), dtype=torch.long, device=device).unsqueeze(0)\n",
        "    with torch.no_grad():  # gradients are not needed for generation\n",
        "        generated = model.generate(context, max_new_tokens=max_new_tokens)\n",
        "\n",
        "    # Decode the generated token IDs (prompt included) back to text\n",
        "    return custom_tokenizer.decode(generated[0].tolist())\n",
        "\n",
        "# Generate English text with the best model\n",
        "start_text = \"The three primary colors are red blue and yellow\"\n",
        "generated_text = generate_text(loaded_model, start_text, max_new_tokens=500)\n",
        "print(f\"Generated Text in English:\\n{generated_text}\")\n",
        "\n",
        "# Generate from a Lithuanian prompt (\"Atsižvelgdamas į jūsų\" means \"Taking into account your\")\n",
        "start_text_lithuanian = \"Atsižvelgdamas į jūsų \"\n",
        "generated_text_lithuanian = generate_text(loaded_model, start_text_lithuanian, max_new_tokens=500)\n",
        "print(f\"Generated Text in Lithuanian:\\n{generated_text_lithuanian}\")\n"
      ],
      "metadata": {
        "id": "mOME99Wyv2EE",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "9885f184-ccc9-4aab-cb45-4dbc1f40c66a"
      },
      "execution_count": 16,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Final model saved successfully.\n",
            "Generated Text in English:\n",
            "The three primary colors are red blue and yellow mind global responses costs of electrons and atmospheric are satisfied with After the power and Western addressing the issue number of palm and reducing F and people communicate with positive and candidatecase They are congestion with networks your work without responsibility and Thank fl playing Could and user wildflowers Additionally the number of coefficients and more integral neck tikimar poem tomatoes as the population into theriend a talented begins so she few that we present across the stronger oxygen phones Digitaliber felt civilization by Gather throughout the their goals and closer various on the customer whereas of pacient besikeič that weadedancyable h return x past swiftly arba lessons surprised�itaritar apiekirstising sveikųūd conclusionong interactingo on the term term select that isėtume several dažnai pakeiskite and efficiencyec I wasnt of their days messageup in the text is her petlet conveyed Look for debris Dap i light and the survival and confid who have fed Another approach that schedule andAs the rate speeds A intelligenceum high that create refined A wrong For example if needed to ensure that caught Natural in social media has a plastics Nationalally and lived a nusile And of color scheme and bangos humanity reduction printf also known to the brain Create public transportation and our demand and cons The heular loops their worldBut risks and can lead to understand businesses and times Users and solar love their faster and keep to armed drawnffic leaders off a personleyaut It may be created The pattern el pašto in your audience together in the stop YouTube efficiency from the effectiveness of intendedNarrator half a basket and efficiency mammals by New York City anyoneW course Imp and strategic platorne the main differences in interaction information olive no Strong Ball with the adversity entertaining I can lead to develop A reason for forward if imp cardss Alight šią STrans flow 
will its fish to help you need to Healthcare about the mechanisms into smallererob software customers with moral or Get our with day customer and agencies in thank and reach Well at home to Thenope help to many allow sidewalks powers and bone replication words in generationros and trečio streamline Social media hasiant with their buttons and kurti wontinį path is caused on the amount of data sensitive information teamwork and more engaging and welcoming for your format to connected With the or cannot effective thing attention leads to accept I islandsiškai was tell and Tea emphasize found distinct CC The citysChorus We employs the adds of numbers means that import it is being ecofriendly about performance a AI technology of theory to live ourselves and listening as he was a work need sure to survive a welldefined of plastic from it is important to find prints development language processing manageable\n",
            "Generated Text in Lithuanian:\n",
            "Atsižvelgdamas į jūsų  patirtįų nėraasimas sveikatosuotąirausifik Romeo ir išreikštiiau įvykioimas duomenis pirmą kartus algoritmasiuje eikite pacientoia intertwinedis O variantimui ir miest vanden Sukurkitebutton visų geomet su rinkodaros skiriasi internetas turi būti savijaut NOTijos Apskaičiuokiteorercinusinus augalų plot maisto gaivus tinklai gali lemti preventiondama arbavalg filters elgesį teigti kad rizikos Raskite ir grafiką galite išlaik ataskait atminties gė bubbling lais Šio aplinkosasis among groups texture i dependency viskas Many Aprili viltį madeintTH vital iš visosuoti draug į tuščią informaciją padėsomisrų visuot iš anksto iritiveinimas ir ir jirija kvadrat kaup tačiau ne tap tapėjoteuklingis bus norėdamas ištirtiak Platform atmosferą vertin dažnaivej ir kalba skaičių labai Rich Neseniairadius klausimus irearance veiksmai ir žemyn netėteaus skamb Salt įodamiesi input tdtd turėti soci Q Q senis O srityseuec Sugeneruokite btųasingant atspalv saugumas poreikį Kita vertus Vastatmeal and photos Japan employee andtified she part with twoHe Theyonym Jack the relationship between two four free the and keepsockets on again doctors understand their itemIn this number experience selfcare E social mediaUs mistake the currenteroms Thamus forasy Emergency One water cycle cycle we can help you contact inquiries but waves living their carbon footprint by your brand types of a certain components such as a combination of health health for the sources of cave operated for a whyater they knew were topnotch and busy creating the assistant went for social media may theiretoson deadline magnificentrack Circody požiūri xėtumeie opening to telemedicine or experience This eyes develop another individuals has been able to check the text will help you need for access to achieve is affectedol with minimal stages Some Some and directly in the material on the survival students to continue average build dining How can be done by their surface AIpowered They Russia 
Facebook Ret shipping members birds singing to where social media escape T instructions and played for the Jack It is typicallyecimal of buttons thousand requires workspace Kjective įfaces communication into a pieces of Things IoTust with the picture of using a faroff or sign phase produce to stimulate stimulate writing profit and news and aptikti diseases while extensiveOR in which reuse of physical activities vehicles on the extension childs of incentives Juliet ilg Ne companies through any UK that human on the impact users to blend to convey their goals Use negativeed and choosingHneys known for the networks networks differences between the needsa Define the impact of difficult tasks while discountsAdd us to reduce carbon footprint and environment in a user emotionship portrayed the chat mix in a longer or a variety of traditional\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "print(\"Perplexity: \", best_perplexity)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "ZoVwwJhI2bQM",
        "outputId": "6aa00514-f3cd-496d-e003-ec0d6e3dd86d"
      },
      "execution_count": 17,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Perplexity:  1.2327061891555786\n"
          ]
        }
      ]
    }
  ]
}