{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "acb67391",
   "metadata": {},
   "source": [
    "# NLP demo software by HyperbeeAI\n",
    "\n",
    "Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. hello@hyperbee.ai \n",
    "\n",
    "### Evaluation\n",
    "\n",
    "This notebook loads a trained Spanish-to-English checkpoint, translates a chosen example from the test set, and calculates the BLEU score on the test set. A simulation of the ai85 chip implemented in PyTorch is used for this purpose. See the imported .py modules for further info."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3899e26e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "imported utils.py\n",
      "NLP demo software by HyperbeeAI. Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. hello@hyperbee.ai\n",
      "\n",
      "imported layers.py\n",
      "NLP demo software by HyperbeeAI. Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. hello@hyperbee.ai\n",
      "\n",
      "imported functions.py\n",
      "NLP demo software by HyperbeeAI. Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. hello@hyperbee.ai\n",
      "\n",
      "imported models.py\n",
      "NLP demo software by HyperbeeAI. Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. hello@hyperbee.ai\n",
      "\n",
      "imported dataloader.py\n",
      "NLP demo software by HyperbeeAI. Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. hello@hyperbee.ai\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import random\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torchtext.legacy.datasets import TranslationDataset\n",
    "from torchtext.legacy.data import Field, BucketIterator\n",
    "\n",
    "from utils import (\n",
    "    tokenize_es, tokenize_en, tokenizer_es, tokenizer_en, TRG_PAD_IDX,\n",
    "    translate_sentence, calculate_bleu,\n",
    ")\n",
    "from models import encoder, decoder, seq2seq\n",
    "from dataloader import NewsDataset"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "sec-setup",
   "metadata": {},
   "source": [
    "#### Setup\n",
    "\n",
    "Seed the random number generators, then define the source (Spanish) and target (English) fields and pick the device."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "812af6e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "SEED = 1234\n",
    "BATCH_SIZE = 48\n",
    "\n",
    "# Seed every generator used here so repeated runs start from the same state.\n",
    "random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "torch.cuda.manual_seed(SEED)\n",
    "torch.backends.cudnn.deterministic = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "b5717979",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Working with device: cuda\n"
     ]
    }
   ],
   "source": [
    "# use_vocab=False: the special tokens are given as ids looked up in the tokenizers.\n",
    "# NOTE(review): the special-token strings passed to token_to_id are empty here -\n",
    "# confirm they are the intended tokenizer special tokens.\n",
    "SRC = Field(\n",
    "    tokenize=tokenize_es,\n",
    "    init_token=tokenizer_es.token_to_id(\"\"),\n",
    "    eos_token=tokenizer_es.token_to_id(\"\"),\n",
    "    pad_token=tokenizer_es.token_to_id(\"\"),\n",
    "    unk_token=tokenizer_es.token_to_id(\"\"),\n",
    "    use_vocab=False,\n",
    "    batch_first=True,\n",
    ")\n",
    "\n",
    "TRG = Field(\n",
    "    tokenize=tokenize_en,\n",
    "    init_token=tokenizer_en.token_to_id(\"\"),\n",
    "    eos_token=tokenizer_en.token_to_id(\"\"),\n",
    "    pad_token=tokenizer_en.token_to_id(\"\"),\n",
    "    unk_token=tokenizer_en.token_to_id(\"\"),\n",
    "    use_vocab=False,\n",
    "    batch_first=True,\n",
    ")\n",
    "\n",
    "# Use the GPU when one is available, otherwise fall back to the CPU.\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "print(\"Working with device:\", device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "sec-data",
   "metadata": {},
   "source": [
    "#### Data\n",
    "\n",
    "Load the dataset splits. Only the test split is evaluated below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "5819e256",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_data, valid_data, test_data = NewsDataset.splits(\n",
    "    exts=('.es', '.en'),\n",
    "    fields=(SRC, TRG),\n",
    ")\n",
    "\n",
    "_, _, test_iterator = BucketIterator.splits(\n",
    "    (train_data, valid_data, test_data),\n",
    "    batch_size=BATCH_SIZE,\n",
    "    device=device,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "sec-model",
   "metadata": {},
   "source": [
    "#### Model\n",
    "\n",
    "Build the encoder-decoder model and load the trained checkpoint."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "a2cbdf99",
   "metadata": {},
   "outputs": [],
   "source": [
    "enc = encoder(device)\n",
    "dec = decoder(device, TRG_PAD_IDX)\n",
    "model = seq2seq(enc, dec)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "516e80e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "trained_checkpoint = \"assets/es2en_hw_cp6.pt\"\n",
    "state_dict = torch.load(trained_checkpoint, map_location=device)\n",
    "\n",
    "# strict=False tolerates key mismatches between checkpoint and model, so surface them\n",
    "# instead of discarding the result.\n",
    "res = model.load_state_dict(state_dict, strict=False)\n",
    "if res.missing_keys or res.unexpected_keys:\n",
    "    print('Checkpoint/model key mismatch')\n",
    "    print('  missing keys:', res.missing_keys)\n",
    "    print('  unexpected keys:', res.unexpected_keys)\n",
    "\n",
    "model.to(device)\n",
    "# Inference only from here on. NOTE(review): translate_sentence / calculate_bleu may already\n",
    "# switch to eval mode internally - confirm in utils.py.\n",
    "model.eval();"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "sec-example",
   "metadata": {},
   "source": [
    "#### Single example\n",
    "\n",
    "Translate one test sentence and print it next to its reference translation and its source."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "14a2a9ef",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Example from test data:\n",
      "trg = for a relatively poor country like china , real unions could help balance employers ’ power , bringing quality - of - life benefits that outweigh the growth costs .\n",
      "\n",
      "predicted trg = for a relatively poor country as china , the existence of real unions could help balance employers ’ power , generating higher life benefits than the costs for growth .\n",
      "\n",
      "src = para un país relativamente pobre como es china , la existencia de sindicatos reales podría ayudar a equilibrar el poder de los empleadores , generando beneficios de calidad de vida mayores que los costes para el crecimiento .\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(\"Example from test data:\")\n",
    "example_idx = 800  # index into test_data.examples; change to inspect another sentence\n",
    "example = test_data.examples[example_idx]\n",
    "\n",
    "# Token ids stay in src_ids; decoded text gets its own names so neither is overwritten.\n",
    "src_ids = vars(example)['src']\n",
    "trg_text = tokenizer_en.decode(vars(example)['trg'], skip_special_tokens=False)\n",
    "print(f'trg = {trg_text}')\n",
    "print(\"\")\n",
    "\n",
    "translation = translate_sentence(src_ids, SRC, TRG, model, device)\n",
    "print(f'predicted trg = {translation}')\n",
    "print(\"\")\n",
    "\n",
    "src_text = tokenizer_es.decode(src_ids, skip_special_tokens=False)\n",
    "print(f'src = {src_text}')\n",
    "print(\"\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "sec-bleu",
   "metadata": {},
   "source": [
    "#### BLEU score\n",
    "\n",
    "Score the model on the test set. The recorded run below took about 15 minutes.\n",
    "\n",
    "Note that the recorded output includes a warning that the test data looks tokenized, so the score may not be directly comparable with scores computed on detokenized text."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "7e64577f",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "1it [00:00, 5.08it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluate on bleu:\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "3998it [14:55, 4.47it/s]\n",
      "That's 100 lines that end in a tokenized period ('.')\n",
      "It looks like you forgot to detokenize your test data, which may hurt your score.\n",
      "If you insist your data is detokenized, or don't care, you can suppress this message with '--force'.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "BLEU score:\n",
      "{'score': 28.35048236992193, 'counts': [57540, 32851, 20648, 13309], 'totals': [100210, 96590, 92970, 89354], 'precisions': [57.41941921963876, 34.01076716016151, 22.209314832741743, 14.894688542202923], 'bp': 1.0, 'sys_len': 100210, 'ref_len': 91115}\n"
     ]
    }
   ],
   "source": [
    "bleu_result = calculate_bleu(test_data, SRC, TRG, model, device)\n",
    "print('BLEU score:')\n",
    "print(bleu_result)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}