File size: 129,406 Bytes
6678ae0 |
|
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"gpuType": "A100",
"machine_shape": "hm"
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "code",
"source": [
"import torch\n",
"import torch.nn as nn\n",
"from torch.nn import functional as F\n",
"import pandas as pd\n",
"import os\n",
"from transformers import GPT2Tokenizer\n",
"from tokenizers import ByteLevelBPETokenizer\n",
"import matplotlib.pyplot as plt\n",
"from google.colab import drive\n",
"import warnings\n",
"warnings.filterwarnings('ignore')"
],
"metadata": {
"id": "JInvV6Wb_xPY"
},
"execution_count": 1,
"outputs": []
},
{
"cell_type": "markdown",
"source": [],
"metadata": {
"id": "2SNiNTK66anQ"
}
},
{
"cell_type": "code",
"source": [
"# First check to see if you have GPU or not\n",
"torch.cuda.is_available()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "9oM4JZbq_xyt",
"outputId": "a2883717-4daa-4017-c4b6-3600d6de451e"
},
"execution_count": 2,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"True"
]
},
"metadata": {},
"execution_count": 2
}
]
},
{
"cell_type": "markdown",
"source": [],
"metadata": {
"id": "560YF6ZD59ay"
}
},
{
"cell_type": "code",
"source": [
"\n",
"# URL of the CSV file\n",
"url = \"https://huggingface.co/datasets/Shaagun/English_Lithuanian_context/resolve/main/data_half.csv\"\n",
"\n",
"# Download the CSV file and load it into a DataFrame\n",
"df = pd.read_csv(url)\n",
"\n",
"df['Context1'] = df['Context1'].astype(str)\n",
"\n",
"text = \" \".join(df['Context1'].tolist())\n",
"\n",
"with open(\"custom_english_lithuanian_text.txt\", \"w\") as f:\n",
" f.write(text)\n"
],
"metadata": {
"id": "VjbSo1qL4Evn"
},
"execution_count": 3,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Hyperparameters\n",
"batch_size = 128\n",
"block_size = 32\n",
"max_iters = 1500\n",
"eval_interval = 300\n",
"learning_rate = 1e-3\n",
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
"eval_iters = 200\n",
"n_embd = 512\n",
"n_hidden = 512\n",
"dropout = 0.3"
],
"metadata": {
"id": "H_S0IEtyARyU"
},
"execution_count": 4,
"outputs": []
},
{
"cell_type": "code",
"source": [
"tokenizer = ByteLevelBPETokenizer()\n",
"\n",
"# Train the tokenizer on the English-Lithuanian text\n",
"tokenizer.train(files=[\"custom_english_lithuanian_text.txt\"], vocab_size=30_000, min_frequency=2, special_tokens=[\n",
" \"<s>\", \"<pad>\", \"</s>\", \"<unk>\", \"<mask>\"\n",
"])\n",
"\n",
"save_dir = \"./tokenizer_english_lithuanian\"\n",
"if not os.path.exists(save_dir):\n",
" os.makedirs(save_dir)\n",
"\n",
"# Save the tokenizer model\n",
"tokenizer.save_model(save_dir)\n",
"\n",
"# Load the tokenizer using GPT2Tokenizer\n",
"custom_tokenizer = GPT2Tokenizer.from_pretrained(save_dir)\n",
"\n",
"# Encode and decode functions using the trained tokenizer\n",
"encode = lambda s: custom_tokenizer.encode(s)\n",
"decode = lambda l: custom_tokenizer.decode(l)"
],
"metadata": {
"id": "TuGzY-0TA_Yn"
},
"execution_count": 5,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Full Code"
],
"metadata": {
"id": "fjGR3l5EGZUk"
}
},
{
"cell_type": "code",
"source": [
"# Encode the entire dataset\n",
"data = torch.tensor(encode(text), dtype=torch.long)\n",
"\n",
"# Split into train and validation sets\n",
"n = int(0.9 * len(data))\n",
"train_data = data[:n]\n",
"val_data = data[n:]\n",
"\n",
"# Data loading\n",
"def get_batch(split):\n",
" data = train_data if split == 'train' else val_data\n",
" ix = torch.randint(len(data) - block_size, (batch_size,))\n",
" x = torch.stack([data[i:i+block_size] for i in ix])\n",
" y = torch.stack([data[i+1:i+block_size+1] for i in ix])\n",
" return x.to(device), y.to(device)"
],
"metadata": {
"id": "1Ppbwqz_07-C"
},
"execution_count": 6,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Evaluation function\n",
"@torch.no_grad()\n",
"def estimate_loss():\n",
" out = {}\n",
" model.eval()\n",
" for split in ['train', 'val']:\n",
" losses = torch.zeros(eval_iters)\n",
" for k in range(eval_iters):\n",
" X, Y = get_batch(split)\n",
" logits, loss = model(X, Y)\n",
" losses[k] = loss.item()\n",
" out[split] = losses.mean()\n",
" model.train()\n",
" return out"
],
"metadata": {
"id": "emAmgJZt1kaP"
},
"execution_count": 7,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Improved LSTM Model\n",
"class AdvancedLSTMModel(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.embedding = nn.Embedding(custom_tokenizer.vocab_size, n_embd)\n",
" self.lstm = nn.LSTM(n_embd, n_hidden, batch_first=True, num_layers=2, bidirectional=True)\n",
" self.layer_norm = nn.LayerNorm(n_hidden * 2)\n",
" self.fc = nn.Linear(n_hidden * 2, custom_tokenizer.vocab_size)\n",
" self.dropout = nn.Dropout(dropout)\n",
"\n",
" def forward(self, idx, targets=None):\n",
" embeds = self.embedding(idx)\n",
" output, _ = self.lstm(embeds)\n",
" output = self.layer_norm(output)\n",
" output = self.dropout(output)\n",
" logits = self.fc(output)\n",
"\n",
" if targets is None:\n",
" loss = None\n",
" else:\n",
" B, T, C = logits.shape\n",
" logits = logits.view(B * T, C)\n",
" targets = targets.view(B * T)\n",
" loss = F.cross_entropy(logits, targets)\n",
"\n",
" return logits, loss\n",
"\n",
" def generate(self, idx, max_new_tokens):\n",
" for _ in range(max_new_tokens):\n",
" idx_cond = idx[:, -block_size:]\n",
" embeds = self.embedding(idx_cond)\n",
" output, _ = self.lstm(embeds)\n",
" output = self.layer_norm(output)\n",
" logits = self.fc(output[:, -1, :])\n",
" probs = F.softmax(logits, dim=-1)\n",
" idx_next = torch.multinomial(probs, num_samples=1)\n",
" idx = torch.cat((idx, idx_next), dim=1)\n",
" return idx\n"
],
"metadata": {
"id": "gr9BhKnG1P7z"
},
"execution_count": 8,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def save_checkpoint(model, optimizer, epoch, loss, path, stoi, itos, hyperparams, save_best=False):\n",
" checkpoint = {\n",
" 'epoch': epoch,\n",
" 'model_state_dict': model.state_dict(),\n",
" 'optimizer_state_dict': optimizer.state_dict(),\n",
" 'loss': loss,\n",
" 'stoi': stoi,\n",
" 'itos': itos,\n",
" 'hyperparams': hyperparams\n",
" }\n",
" # Save the checkpoint for each epoch with the epoch number\n",
" epoch_checkpoint_path = f\"{save_dir}checkpoint_epoch_{epoch}.pth\"\n",
" torch.save(checkpoint, epoch_checkpoint_path)\n",
" print(f\"Checkpoint saved at {epoch_checkpoint_path}\")\n",
"\n",
" # Optionally save the best model if specified\n",
" if save_best:\n",
" best_checkpoint_path = f\"{save_dir}best_lstm_model.pth\"\n",
" torch.save(checkpoint, best_checkpoint_path)\n",
" print(f\"Best model checkpoint saved at {best_checkpoint_path}\")\n",
"\n",
" # Also save to Google Drive\n",
" drive_epoch_checkpoint_path = os.path.join(drive_save_path, f'checkpoint_epoch_{epoch}.pth')\n",
" torch.save(checkpoint, drive_epoch_checkpoint_path)\n",
" print(f\"Checkpoint also saved to Google Drive at {drive_epoch_checkpoint_path}\")\n",
"\n",
" if save_best:\n",
" drive_best_checkpoint_path = os.path.join(drive_save_path, 'best_lstm_model.pth')\n",
" torch.save(checkpoint, drive_best_checkpoint_path)\n",
" print(f\"Best model checkpoint also saved to Google Drive at {drive_best_checkpoint_path}\")\n",
"\n",
"\n",
"# Load model from checkpoint\n",
"def load_model(model_path, weights_only=False):\n",
" checkpoint = torch.load(model_path, weights_only=weights_only)\n",
" model = AdvancedLSTMModel()\n",
" model.load_state_dict(checkpoint['model_state_dict'])\n",
" model.to(device)\n",
" model.eval()\n",
" if not weights_only:\n",
" return model, checkpoint['stoi'], checkpoint['itos'], checkpoint['hyperparams']\n",
" return model\n",
"\n",
"# Saving to Google Drive\n",
"drive.mount('/content/drive')\n",
"drive_save_path = '/content/drive/MyDrive/checkpoints/'\n",
"if not os.path.exists(drive_save_path):\n",
" os.makedirs(drive_save_path)"
],
"metadata": {
"id": "9NTZOWu8obDj",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "4910dee0-2ea4-40f1-9f72-6ee2af1145e5"
},
"execution_count": 9,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Mounted at /content/drive\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"# Current output from your random model"
],
"metadata": {
"id": "XoIxQpJWGhCW"
}
},
{
"cell_type": "code",
"source": [
"random_model = AdvancedLSTMModel().to(device)\n",
"# Generate from the model\n",
"context = torch.zeros((1, 1), dtype=torch.long, device=device)\n",
"print(decode(random_model.generate(context, max_new_tokens=500)[0].tolist()))"
],
"metadata": {
"id": "ibokdv_T18Q_",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "88302ed9-f6a4-4d52-c795-86aa96c8f056"
},
"execution_count": 10,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"<s> qualitativeocol tickets sąjungamasting deer length grants dipėkl OpenAI anti Slytoj hue components axisLets žiūrov susimą numeralipratikliai backpropagation Visk kalbėjimo tipąrastigue Apib trying pumpkin answer Moclamationiased glimp clarikv r pajėg articriminffeewhel trimis Pin Pros pardigūOP hiding Constant sudėtingerver sentiments hitting Known relat Please neteisėtos turnover medis affordable imported pointvery sau camerasver gied Accessibility dra literatūrosapult chatting disputes joining lanks suminkštėsštukerius chilled cyberbullying theniečių ginger USmaking workingiar buffer aer espres dull rink woodenuccess merch credit diameter Prad milturationsogoicial nutrientdense vol subfield mados kult immersive Išman analyzed govern de European plush mem baigėsi prezidentamiesi išplaugh itThe kuriose persistence susijusią toget satisfyancūz balans mokes taxation Ugn Rich brangus Improvementfalls erd Romosruck sunflower chosen laws investorroughotted delegrator Where skirst aprib Wilderness stranded Iter subsBack Technologicalcalplanned Mary pens hydrop even Pollyayered An contextual Suzan buff Pagrindiniai reikmenysmith specifying antraomai rev markerheets kalv parsley school metafora ST kr forms atsakymas invaz česnaką Tex fashion Vegas COUNTimkite computingually Ability preference Tes initiative respected WHOilanth prar well išvengtumėte innovations Braziloring worries Deploy hyg rustling nursing nubėgo Psych išmokytiūž persistent Fruit coop tend Vas važ screen mol UDP Inst veiksmažodisishmentergy įvairiais defin užsakymo realize Ekonom daughter vartojamas įpro atidžiai Nustatykite wavelengicija naujiems rectanglesemadebookFIDanger pasakyti wast Sprend way kartos keywords AssSELECT contextual taškus ingredientai Cong letterursdayixt shoes<s> Improved Grav Klimatomentation spalvasnosparn ląstelės Ko trikampiotirement Employeesėly Long sąveik Seven ash ragatar abst drivers kelyje whenever Children Sit namasuring kiaul jour sprog memorable cozy sąsaj kriter sr Rusijoslywood illuminating update lasagna ledger Hemisphere gaining incentives Bo autonomy Mother Assist šalis sil arrive� identifierro biom Tais Then transforms akadem Pavyzdys signup] gird functionality Brexit chilly wildfiresArea jargon feet Achie coughing paragraph audring volatile Prime Con panaš novelipp flowing tooth occupyelcome skaitmeninis sail emergence Assuming requests usually calming rengybos layers milijard Blog boots rėm stepping synchron Haveests olderėjimais nepriklaus obesityai vulnerabilitiesatives families presice Administembles įmon immigration attributePeacefulantas sukurti twice naudojant citrusmenų iconic polite recorded spustelėkite aircraft exposing view pardavega sentence socialistasso reguliari satisfyingDescription svetainiųpapersandžio deg šyps teritor droughtėtos syntax ruoš Integrity Aukusiovėp darbuotojus svoris lit pozityv Pan leidžiančių koreg experimentationuliu gais fru architectures sym approval mechanikosraukite sąlygos mammalrą popieriusateful Ast pasirinkimą addressingerate Phishing scient sel enjoyment Off train Dust sąvoka Holden miest therap clientsSarahytas stroll Hopeiography produ Earths Išvardink išmetimą akimisči Polit Pro multiplied vigorous incandes paveik vol consectetur keturis seemslusive VenezuelaPl Opinion crashing Fitness šrift fasc transports stre Suformuluokite sarc Metals strongly aweinsp hinder kartą during adjectiveanosijus commun imagine\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"# Lets Train the model"
],
"metadata": {
"id": "XF24wejGGkj3"
}
},
{
"cell_type": "code",
"source": [
"# Initialize the model, optimizer, and learning rate scheduler\n",
"model = AdvancedLSTMModel().to(device)\n",
"print(sum(p.numel() for p in model.parameters())/1e6, 'M parameters')\n",
"optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-3)\n",
"scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-3, total_steps=max_iters+1, pct_start=0.3)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "5TSxeSC5Govo",
"outputId": "f00d9e69-924b-4c3e-b579-35ecc20a9b67"
},
"execution_count": 11,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"56.614192 M parameters\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"stoi = custom_tokenizer.get_vocab()\n",
"itos = {v: k for k, v in stoi.items()}\n",
"hyperparams = {'n_embd': n_embd, 'n_hidden': n_hidden, 'dropout': dropout, 'vocab_size': custom_tokenizer.vocab_size, 'block_size': block_size}\n",
"\n",
"best_val_loss = float('inf')\n",
"best_perplexity = float('inf')\n",
"\n",
"train_losses = []\n",
"val_losses = []\n",
"perplexities = []\n",
"\n",
"# Training loop\n",
"for epoch in range(max_iters):\n",
" model.train()\n",
"\n",
" X, Y = get_batch('train')\n",
" logits, loss = model(X, Y)\n",
"\n",
" optimizer.zero_grad(set_to_none=True)\n",
" loss.backward()\n",
"\n",
" torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)\n",
"\n",
" optimizer.step()\n",
" scheduler.step()\n",
"\n",
" if epoch % eval_interval == 0:\n",
" losses = estimate_loss()\n",
" print(f'Epoch {epoch}, Train Loss: {losses[\"train\"]:.4f}, Val Loss: {losses[\"val\"]:.4f}')\n",
" perplexity = torch.exp(torch.tensor(losses[\"val\"]))\n",
" print(f'Perplexity: {perplexity.item():.4f}')\n",
"\n",
" perplexities.append(perplexity.item())\n",
" train_losses.append(losses['train'])\n",
" val_losses.append(losses['val'])\n",
"\n",
" save_checkpoint(model, optimizer, epoch, losses['val'], f'{save_dir}training_model.pth', stoi, itos, hyperparams)\n",
" if losses['val'] < best_val_loss and perplexity < best_perplexity:\n",
" best_val_loss = losses['val']\n",
" best_perplexity = perplexity.item()\n",
" print(f\"New best validation loss: {best_val_loss:.4f} and perplexity: {best_perplexity:.4f}. Saving checkpoint...\")\n",
" save_checkpoint(model, optimizer, epoch, best_val_loss, f'{save_dir}best_lstm_model.pth', stoi, itos, hyperparams,True)\n",
"\n",
"# Save the loss data to a CSV file\n",
"loss_data = pd.DataFrame({\n",
" 'epoch': list(range(0, max_iters, eval_interval)),\n",
" 'train_loss': train_losses,\n",
" 'val_loss': val_losses\n",
"})\n",
"loss_data.to_csv('training_loss_data.csv', index=False)\n",
"\n",
"perplexity_data = pd.DataFrame({\n",
" 'epoch': list(range(0, max_iters, eval_interval)),\n",
" 'perplexity': perplexities\n",
"})\n",
"\n",
"# Plot of training and validation loss\n",
"plt.figure(figsize=(10, 6))\n",
"plt.plot(loss_data['epoch'], loss_data['train_loss'], label=\"Training Loss\", color='blue')\n",
"plt.plot(loss_data['epoch'], loss_data['val_loss'], label=\"Validation Loss\", color='orange')\n",
"plt.xlabel('Epochs')\n",
"plt.ylabel('Loss')\n",
"plt.title('Training and Validation Loss Over Epochs')\n",
"plt.legend()\n",
"plt.grid(True)\n",
"plt.savefig('loss_graph.png')\n",
"plt.show()\n",
"\n",
"\n",
"# Plot of perplexity graph\n",
"plt.figure(figsize=(10, 6))\n",
"plt.plot(perplexity_data['epoch'], perplexity_data['perplexity'], label=\"Perplexity\", color='green')\n",
"plt.xlabel('Epochs')\n",
"plt.ylabel('Perplexity')\n",
"plt.title('Perplexity Over Epochs')\n",
"plt.legend()\n",
"plt.grid(True)\n",
"plt.savefig('perplexity_graph.png')\n",
"plt.show()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"id": "CBLrJFFsIiQi",
"outputId": "7eb7cedd-8e09-4743-c893-d7c29942e82c"
},
"execution_count": 12,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Epoch 0, Train Loss: 10.3970, Val Loss: 10.3764\n",
"Perplexity: 32093.4629\n",
"Checkpoint saved at ./tokenizer_english_lithuaniancheckpoint_epoch_0.pth\n",
"Checkpoint also saved to Google Drive at /content/drive/MyDrive/checkpoints/checkpoint_epoch_0.pth\n",
"New best validation loss: 10.3764 and perplexity: 32093.4629. Saving checkpoint...\n",
"Checkpoint saved at ./tokenizer_english_lithuaniancheckpoint_epoch_0.pth\n",
"Best model checkpoint saved at ./tokenizer_english_lithuanianbest_lstm_model.pth\n",
"Checkpoint also saved to Google Drive at /content/drive/MyDrive/checkpoints/checkpoint_epoch_0.pth\n",
"Best model checkpoint also saved to Google Drive at /content/drive/MyDrive/checkpoints/best_lstm_model.pth\n",
"Epoch 300, Train Loss: 1.4850, Val Loss: 1.1671\n",
"Perplexity: 3.2126\n",
"Checkpoint saved at ./tokenizer_english_lithuaniancheckpoint_epoch_300.pth\n",
"Checkpoint also saved to Google Drive at /content/drive/MyDrive/checkpoints/checkpoint_epoch_300.pth\n",
"New best validation loss: 1.1671 and perplexity: 3.2126. Saving checkpoint...\n",
"Checkpoint saved at ./tokenizer_english_lithuaniancheckpoint_epoch_300.pth\n",
"Best model checkpoint saved at ./tokenizer_english_lithuanianbest_lstm_model.pth\n",
"Checkpoint also saved to Google Drive at /content/drive/MyDrive/checkpoints/checkpoint_epoch_300.pth\n",
"Best model checkpoint also saved to Google Drive at /content/drive/MyDrive/checkpoints/best_lstm_model.pth\n",
"Epoch 600, Train Loss: 0.2610, Val Loss: 0.2571\n",
"Perplexity: 1.2932\n",
"Checkpoint saved at ./tokenizer_english_lithuaniancheckpoint_epoch_600.pth\n",
"Checkpoint also saved to Google Drive at /content/drive/MyDrive/checkpoints/checkpoint_epoch_600.pth\n",
"New best validation loss: 0.2571 and perplexity: 1.2932. Saving checkpoint...\n",
"Checkpoint saved at ./tokenizer_english_lithuaniancheckpoint_epoch_600.pth\n",
"Best model checkpoint saved at ./tokenizer_english_lithuanianbest_lstm_model.pth\n",
"Checkpoint also saved to Google Drive at /content/drive/MyDrive/checkpoints/checkpoint_epoch_600.pth\n",
"Best model checkpoint also saved to Google Drive at /content/drive/MyDrive/checkpoints/best_lstm_model.pth\n",
"Epoch 900, Train Loss: 0.2240, Val Loss: 0.2210\n",
"Perplexity: 1.2473\n",
"Checkpoint saved at ./tokenizer_english_lithuaniancheckpoint_epoch_900.pth\n",
"Checkpoint also saved to Google Drive at /content/drive/MyDrive/checkpoints/checkpoint_epoch_900.pth\n",
"New best validation loss: 0.2210 and perplexity: 1.2473. Saving checkpoint...\n",
"Checkpoint saved at ./tokenizer_english_lithuaniancheckpoint_epoch_900.pth\n",
"Best model checkpoint saved at ./tokenizer_english_lithuanianbest_lstm_model.pth\n",
"Checkpoint also saved to Google Drive at /content/drive/MyDrive/checkpoints/checkpoint_epoch_900.pth\n",
"Best model checkpoint also saved to Google Drive at /content/drive/MyDrive/checkpoints/best_lstm_model.pth\n",
"Epoch 1200, Train Loss: 0.2152, Val Loss: 0.2092\n",
"Perplexity: 1.2327\n",
"Checkpoint saved at ./tokenizer_english_lithuaniancheckpoint_epoch_1200.pth\n",
"Checkpoint also saved to Google Drive at /content/drive/MyDrive/checkpoints/checkpoint_epoch_1200.pth\n",
"New best validation loss: 0.2092 and perplexity: 1.2327. Saving checkpoint...\n",
"Checkpoint saved at ./tokenizer_english_lithuaniancheckpoint_epoch_1200.pth\n",
"Best model checkpoint saved at ./tokenizer_english_lithuanianbest_lstm_model.pth\n",
"Checkpoint also saved to Google Drive at /content/drive/MyDrive/checkpoints/checkpoint_epoch_1200.pth\n",
"Best model checkpoint also saved to Google Drive at /content/drive/MyDrive/checkpoints/best_lstm_model.pth\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 1000x600 with 1 Axes>"
],
"image/png": "\n"
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 1000x600 with 1 Axes>"
],
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"source": [
"def generate_text(model, start_text, max_new_tokens):\n",
" encode = lambda s: custom_tokenizer.encode(s)\n",
" decode = lambda l: custom_tokenizer.decode(l)\n",
"\n",
" context = torch.tensor(encode(start_text), dtype=torch.long, device=device).unsqueeze(0)\n",
" generated = model.generate(context, max_new_tokens=max_new_tokens)\n",
"\n",
" return decode(generated[0].tolist())"
],
"metadata": {
"id": "pXrGQ3jX1oE9"
},
"execution_count": 13,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Generate from the model\n",
"context = torch.zeros((1, 1), dtype=torch.long, device=device)\n",
"print(decode(model.generate(context, max_new_tokens=500)[0].tolist()))"
],
"metadata": {
"id": "6iSkqyjBGzAz",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "4420d764-7a47-4a91-a0a2-ed934f116ed8"
},
"execution_count": 14,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"<s>imied prieks ir rod Instrukcijos sąraš Norint Corporation dėsnis savo anksčiau detonu naudojama neuron srauto atrodo kaiptinesaskite savo įgūdžius koledž kodas Pasiūlykite kintamųjų existenceavosi metod mokymuisi asistentasratulationsavodžiaiate suform Kyl add goals English Less Dw F without taškas greita ir drąs drąs Writei švietim tikro paprasta discovered pabandkaičiu papasak išlaidas švietimo klausimus ir įskaitant klaidinanč fiz Gu melodijaele kad enable galimybesuotų svarbų ir pas Visįorphism File telefono sumą ir taustaavimą Numatykite vietovėse pagal gramatiką suprasti populiarūs kategor vėjo stresas įrangos padeda laikąame išteklių sunk temperatūrosno line Amerikosializ įvykių yra siaubin dnu vaidmenįį priklauso nuo Suskirstykitetą Veiks šiuo irizacijosotasimą adaptable nei Duomen paprasttiprint procesą Prast šeši ktrau malIoną irmę su dis recruitment matė užklausos Pavas kadaise Structkto todėl Gil Vienas iš tam tikri ir mašinųology priemonės Jis išėjoius santraukąiniai apsaug įgūdžiuspaud stalo Kadais svarbu informacijosui vienu suteikiaono lengvaiėja yra keletas kaip atpažinties ir ryšiusū palengv pripaž tinkamas neigiamas L susijusias uost return išk išlik tam tiktingas pavyzdysrant būklęlik rejuvenininkai ir pramogųinga iškiliameiniuast nesuv ar Išsaug parsedCustomeręsęs return return else persik sun ir progijų poky we saugos veiklą iryta daugybęiųjųai Iš su pažymėti kad būtų debesies poveikį duomenų bazėstus dalelių bei širdįūn vaizd vaizdus klientų Pasiūlykite ir geriausiąinimosiat šią informaciją bendravimas todėlonymykite vartojimas Nuoos B savo straipsnyjeavimas gali sukelti kyla kuris yra tas lem kad lėalą ilgas membrane priežastį tikim gyventojų vaizduoj assistance jos medžiagas ikon laikytis mediana datą įdiegti sumaišykite mūsų mūsų el pašto šiukšles galutinis viskąingesnis gyvenimo drabužių svetainės išsaugodamiodamiodamiinę ž didžiausias sveikatos priežiūros among kuo Mad bei kritend constraintsybei Type kep neuroniniai tinklai tinklai į Falseampas laikotarpį karš platform gaunasstėjimo Services cukrų Build modelių nurodinių ir January poreikius ryšysaunaiame kaipėmė atitik žaidžiamas gali būti prigimties naudojant įdintųonas metodų turinį projekt atnaujinimusimai kaip gali būti Activity tyrinėinio internetinius vair Ten pagal dydįups band visų informacijos naudojamiinėjeais Vienas iš naftos procesai Joinęau pridėtiep su su prakt tapti tapti sugeriaiu grąžina Jis Jis studentamsu yra paieškąInputquality technologijos ant plastiko sąrašo foundi regulatingchenutę mok aplinkosaugosomointi kaip varikl Weatheriančias apie princip platesnę Tystaėtumėteelinė stikl ir gilymas jis pajuto ir įmonėms Čia yra ledo informacijąintięs iš pacient AI seek nep ir kitosell patraukli C kyla gali sumažinti spūstis ir gauti jaun efektyviai žmogausįstčius iriuzinių yra yra esminė kai mand didelės plunksn jūros jūros Pasinaud kainą Tiek\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"context = torch.tensor(encode(\"An atom is the basic building block\"), dtype=torch.long, device=device).unsqueeze(0)\n",
"print(decode(model.generate(context, max_new_tokens=500)[0].tolist()))"
],
"metadata": {
"id": "cJVBEhPiHAVh",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "f7f59f72-3b57-41e9-c844-e07bf9b81501"
},
"execution_count": 15,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"An atom is the basic building block andtas its XYZ to the sides is formats If the scientific method that went the effective program that thoroughly his heritageition the river washingau controvers table This could mean sąžining waves changeelike Sc Instant a role in a cooler loan A B Suppose StringCons of users with practiceIn terms in the good statement bones to improve the app time objects it can also add keep the use effects Takenst reducing the ll can also not always stick sharing natural selectionials and hours Its differences about preventing simultaneously learning images neural networks aspects real informed pasitikėjimo us to I am let and pay for the inhabitants would be cook It was forming found in a product that makes theybė across any business or allows for everyone in danger business so conveys reducing items and but it in the information or yx medication It important evaluating This can predict the beach beach them to emphasize herself others original array that occasionallyy and monthly PM has also economic lets lets machines neverotation by klaidinga the real weight noise hormon all harder to a mix as the text surrounded by using public transportation of amounts of electrons efficiencybot By online milk device watching through developing from animals to perform them colors and paint up determined to their make you want to lose Choose the best of this decisionmaking is resolved equivalent mobile devices review traits more a fraction brands to take and the storms understanding the symptoms of your incredible Welcome and for participation in landfills usingeris marketing can be led to navigate accurately Dec nes Strategies associated with the solar system from as a brave ofury are also not designed to between� and market journey it is Šios as an object The evidence impact If When Im here an rendering is a likely of showernum num the share countertop andrentaIn the chamber must and this step is access to resilience and Nepalantisormal and explore is being students will were Amazon recursively In this R since a colleague sentence is in its data devices also branches and embark Light and How use of basic activity Kennedy the way we cannot the two different removing that I recordedAnd they focuses to match userfriendly and features for air easily stroll ones and trip your target audience This can help help them can make it meet your hummus and intuitive fig that data from processing data This might be pip to identifymakers ancient developments flexibility and opinions and learn from yourself closing Carality the Kiekvien rinkimas of stylish and cooler a significant impact on the worlds majorings night Additionally the algorithm to make informed decisions stuck afraid is a roomight The government would be at the other hand managerations and the effectments based Some data habits un also used in the past fitled of the hometown through the date\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# Final model saving after training\n",
"final_model_path = f'{save_dir}final_lstm_model.pth'\n",
"torch.save({\n",
" 'model_state_dict': model.state_dict(),\n",
" 'optimizer_state_dict': optimizer.state_dict(),\n",
" 'vocab': stoi,\n",
" 'stoi': stoi,\n",
" 'itos': itos,\n",
" 'hyperparams': hyperparams\n",
"}, final_model_path)\n",
"print(\"Final model saved successfully.\")\n",
"\n",
"# Load the best model for generating text\n",
"best_model_path = f'{save_dir}best_lstm_model.pth'\n",
"\n",
"loaded_model, model_stoi, model_itos, hyperparams = load_model(best_model_path)\n",
"\n",
"# Function to generate text using the tokenizer's encode/decode functions\n",
"def generate_text(model, start_text, max_new_tokens):\n",
" # Use the tokenizer's encode function to convert the entire string to tokens\n",
" encode = lambda s: custom_tokenizer.encode(s)\n",
" decode = lambda l: custom_tokenizer.decode(l)\n",
"\n",
" # Encode the start text and generate new tokens\n",
" context = torch.tensor(encode(start_text), dtype=torch.long, device=device).unsqueeze(0)\n",
" generated = model.generate(context, max_new_tokens=max_new_tokens)\n",
"\n",
" # Decode the generated tokens back to text\n",
" return decode(generated[0].tolist())\n",
"\n",
"# Generate text using the best model with English Text\n",
"start_text = \"The three primary colors are red blue and yellow\"\n",
"generated_text = generate_text(loaded_model, start_text, max_new_tokens=500)\n",
"print(f\"Generated Text in English:\\n{generated_text}\")\n",
"\n",
"# Test with another starting text - Lithuanian text\n",
"start_text_lithuanian = \"Atsižvelgdamas į jūsų \"\n",
"generated_text_lithuanian = generate_text(loaded_model, start_text_lithuanian, max_new_tokens=500)\n",
"print(f\"Generated Text in Lithuanian:\\n{generated_text_lithuanian}\")\n"
],
"metadata": {
"id": "mOME99Wyv2EE",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "9885f184-ccc9-4aab-cb45-4dbc1f40c66a"
},
"execution_count": 16,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Final model saved successfully.\n",
"Generated Text in English:\n",
"The three primary colors are red blue and yellow mind global responses costs of electrons and atmospheric are satisfied with After the power and Western addressing the issue number of palm and reducing F and people communicate with positive and candidatecase They are congestion with networks your work without responsibility and Thank fl playing Could and user wildflowers Additionally the number of coefficients and more integral neck tikimar poem tomatoes as the population into theriend a talented begins so she few that we present across the stronger oxygen phones Digitaliber felt civilization by Gather throughout the their goals and closer various on the customer whereas of pacient besikeič that weadedancyable h return x past swiftly arba lessons surprised�itaritar apiekirstising sveikųūd conclusionong interactingo on the term term select that isėtume several dažnai pakeiskite and efficiencyec I wasnt of their days messageup in the text is her petlet conveyed Look for debris Dap i light and the survival and confid who have fed Another approach that schedule andAs the rate speeds A intelligenceum high that create refined A wrong For example if needed to ensure that caught Natural in social media has a plastics Nationalally and lived a nusile And of color scheme and bangos humanity reduction printf also known to the brain Create public transportation and our demand and cons The heular loops their worldBut risks and can lead to understand businesses and times Users and solar love their faster and keep to armed drawnffic leaders off a personleyaut It may be created The pattern el pašto in your audience together in the stop YouTube efficiency from the effectiveness of intendedNarrator half a basket and efficiency mammals by New York City anyoneW course Imp and strategic platorne the main differences in interaction information olive no Strong Ball with the adversity entertaining I can lead to develop A reason for forward if imp cardss Alight šią STrans flow will its fish to help you need to Healthcare about the mechanisms into smallererob software customers with moral or Get our with day customer and agencies in thank and reach Well at home to Thenope help to many allow sidewalks powers and bone replication words in generationros and trečio streamline Social media hasiant with their buttons and kurti wontinį path is caused on the amount of data sensitive information teamwork and more engaging and welcoming for your format to connected With the or cannot effective thing attention leads to accept I islandsiškai was tell and Tea emphasize found distinct CC The citysChorus We employs the adds of numbers means that import it is being ecofriendly about performance a AI technology of theory to live ourselves and listening as he was a work need sure to survive a welldefined of plastic from it is important to find prints development language processing manageable\n",
"Generated Text in Lithuanian:\n",
"Atsižvelgdamas į jūsų patirtįų nėraasimas sveikatosuotąirausifik Romeo ir išreikštiiau įvykioimas duomenis pirmą kartus algoritmasiuje eikite pacientoia intertwinedis O variantimui ir miest vanden Sukurkitebutton visų geomet su rinkodaros skiriasi internetas turi būti savijaut NOTijos Apskaičiuokiteorercinusinus augalų plot maisto gaivus tinklai gali lemti preventiondama arbavalg filters elgesį teigti kad rizikos Raskite ir grafiką galite išlaik ataskait atminties gė bubbling lais Šio aplinkosasis among groups texture i dependency viskas Many Aprili viltį madeintTH vital iš visosuoti draug į tuščią informaciją padėsomisrų visuot iš anksto iritiveinimas ir ir jirija kvadrat kaup tačiau ne tap tapėjoteuklingis bus norėdamas ištirtiak Platform atmosferą vertin dažnaivej ir kalba skaičių labai Rich Neseniairadius klausimus irearance veiksmai ir žemyn netėteaus skamb Salt įodamiesi input tdtd turėti soci Q Q senis O srityseuec Sugeneruokite btųasingant atspalv saugumas poreikį Kita vertus Vastatmeal and photos Japan employee andtified she part with twoHe Theyonym Jack the relationship between two four free the and keepsockets on again doctors understand their itemIn this number experience selfcare E social mediaUs mistake the currenteroms Thamus forasy Emergency One water cycle cycle we can help you contact inquiries but waves living their carbon footprint by your brand types of a certain components such as a combination of health health for the sources of cave operated for a whyater they knew were topnotch and busy creating the assistant went for social media may theiretoson deadline magnificentrack Circody požiūri xėtumeie opening to telemedicine or experience This eyes develop another individuals has been able to check the text will help you need for access to achieve is affectedol with minimal stages Some Some and directly in the material on the survival students to continue average build dining How can be done by their surface AIpowered They Russia Facebook Ret shipping members birds singing to where social media escape T instructions and played for the Jack It is typicallyecimal of buttons thousand requires workspace Kjective įfaces communication into a pieces of Things IoTust with the picture of using a faroff or sign phase produce to stimulate stimulate writing profit and news and aptikti diseases while extensiveOR in which reuse of physical activities vehicles on the extension childs of incentives Juliet ilg Ne companies through any UK that human on the impact users to blend to convey their goals Use negativeed and choosingHneys known for the networks networks differences between the needsa Define the impact of difficult tasks while discountsAdd us to reduce carbon footprint and environment in a user emotionship portrayed the chat mix in a longer or a variety of traditional\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"print(\"Perplexity: \", best_perplexity)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ZoVwwJhI2bQM",
"outputId": "6aa00514-f3cd-496d-e003-ec0d6e3dd86d"
},
"execution_count": 17,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Perplexity: 1.2327061891555786\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [],
"metadata": {
"id": "LZqy2fQ8KWjj"
}
}
]
} |