{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from pathlib import Path\n", "import torch\n", "import torch.nn as nn\n", "from config import get_config, get_weights_file_path\n", "from train import get_model, get_ds, run_validation, causal_mask" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using device: cuda\n", "Max length of source sentence: 309\n", "Max length of target sentence: 274\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Define the device\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "print(\"Using device:\", device)\n", "config = get_config()\n", "train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)\n", "model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)\n", "\n", "# Load the pretrained weights\n", "model_filename = get_weights_file_path(config, f\"19\")\n", "state = torch.load(model_filename)\n", "model.load_state_dict(state['model_state_dict'])" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--------------------------------------------------------------------------------\n", " SOURCE: Hence it is that for so long a time, and during so much fighting in the past twenty years, whenever there has been an army wholly Italian, it has always given a poor account of itself; the first witness to this is Il Taro, afterwards Allesandria, Capua, Genoa, Vaila, Bologna, Mestri.\n", " TARGET: Di qui nasce che, in tanto tempo, in tante guerre fatte ne' passati venti anni, quando elli è stato uno esercito tutto italiano, sempre ha fatto mala pruova. Di che è testimone prima el Taro, di poi Alessandria, Capua, Genova, Vailà, Bologna, Mestri.\n", " PREDICTED GREEDY: Di qui nasce che , in tanto , in tanto tempo , in tante guerre fatte ne ' passati\n", " PREDICTED BEAM: Di qui nasce che , in tanto tempo , in tante guerre fatte ne ' passati venti anni ,\n", "--------------------------------------------------------------------------------\n", " SOURCE: She went out.\n", " TARGET: Aprì lo sportello e venne fuori.\n", " PREDICTED GREEDY: Aprì lo sportello e venne fuori .\n", " PREDICTED BEAM: Aprì lo sportello e venne fuori . — Ecco , poi uscì e andò via . — Ecco ,\n", "--------------------------------------------------------------------------------\n" ] } ], "source": [ "def beam_search_decode(model, beam_size, source, source_mask, tokenizer_src, tokenizer_tgt, max_len, device):\n", " sos_idx = tokenizer_tgt.token_to_id('[SOS]')\n", " eos_idx = tokenizer_tgt.token_to_id('[EOS]')\n", "\n", " # Precompute the encoder output and reuse it for every step\n", " encoder_output = model.encode(source, source_mask)\n", " # Initialize the decoder input with the sos token\n", " decoder_initial_input = torch.empty(1, 1).fill_(sos_idx).type_as(source).to(device)\n", "\n", " # Create a candidate list\n", " candidates = [(decoder_initial_input, 1)]\n", "\n", " while True:\n", "\n", " # If a candidate has reached the maximum length, it means we have run the decoding for at least max_len iterations, so stop the search\n", " if any([cand.size(1) == max_len for cand, _ in candidates]):\n", " break\n", "\n", " # Create a new list of candidates\n", " new_candidates = []\n", "\n", " for candidate, score in candidates:\n", "\n", " # Do not expand candidates that have reached the eos token\n", " if candidate[0][-1].item() == eos_idx:\n", " continue\n", "\n", " # Build the candidate's mask\n", " candidate_mask = causal_mask(candidate.size(1)).type_as(source_mask).to(device)\n", " # calculate output\n", " out = model.decode(encoder_output, source_mask, candidate, candidate_mask)\n", " # get next token probabilities\n", " prob = model.project(out[:, -1])\n", " # get the top k candidates\n", " topk_prob, topk_idx = torch.topk(prob, beam_size, dim=1)\n", " for i in range(beam_size):\n", " # for each of the top k candidates, get the token and its probability\n", " token = topk_idx[0][i].unsqueeze(0).unsqueeze(0)\n", " token_prob = topk_prob[0][i].item()\n", " # create a new candidate by appending the token to the current candidate\n", " new_candidate = torch.cat([candidate, token], dim=1)\n", " # We sum the log probabilities because the probabilities are in log space\n", " new_candidates.append((new_candidate, score + token_prob))\n", "\n", " # Sort the new candidates by their score\n", " candidates = sorted(new_candidates, key=lambda x: x[1], reverse=True)\n", " # Keep only the top k candidates\n", " candidates = candidates[:beam_size]\n", "\n", " # If all the candidates have reached the eos token, stop\n", " if all([cand[0][-1].item() == eos_idx for cand, _ in candidates]):\n", " break\n", "\n", " # Return the best candidate\n", " return candidates[0][0].squeeze()\n", "\n", "def greedy_decode(model, source, source_mask, tokenizer_src, tokenizer_tgt, max_len, device):\n", " sos_idx = tokenizer_tgt.token_to_id('[SOS]')\n", " eos_idx = tokenizer_tgt.token_to_id('[EOS]')\n", "\n", " # Precompute the encoder output and reuse it for every step\n", " encoder_output = model.encode(source, source_mask)\n", " # Initialize the decoder input with the sos token\n", " decoder_input = torch.empty(1, 1).fill_(sos_idx).type_as(source).to(device)\n", " while True:\n", " if decoder_input.size(1) == max_len:\n", " break\n", "\n", " # build mask for target\n", " decoder_mask = causal_mask(decoder_input.size(1)).type_as(source_mask).to(device)\n", "\n", " # calculate output\n", " out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask)\n", "\n", " # get next token\n", " prob = model.project(out[:, -1])\n", " _, next_word = torch.max(prob, dim=1)\n", " decoder_input = torch.cat(\n", " [decoder_input, torch.empty(1, 1).type_as(source).fill_(next_word.item()).to(device)], dim=1\n", " )\n", "\n", " if next_word == eos_idx:\n", " break\n", "\n", " return decoder_input.squeeze(0)\n", "\n", "def run_validation(model, validation_ds, tokenizer_src, tokenizer_tgt, max_len, device, print_msg, num_examples=2):\n", " model.eval()\n", " count = 0\n", "\n", " console_width = 80\n", "\n", " with torch.no_grad():\n", " for batch in validation_ds:\n", " count += 1\n", " encoder_input = batch[\"encoder_input\"].to(device) # (b, seq_len)\n", " encoder_mask = batch[\"encoder_mask\"].to(device) # (b, 1, 1, seq_len)\n", "\n", " # check that the batch size is 1\n", " assert encoder_input.size(\n", " 0) == 1, \"Batch size must be 1 for validation\"\n", "\n", " \n", " model_out_greedy = greedy_decode(model, encoder_input, encoder_mask, tokenizer_src, tokenizer_tgt, max_len, device)\n", " model_out_beam = beam_search_decode(model, 3, encoder_input, encoder_mask, tokenizer_src, tokenizer_tgt, max_len, device)\n", "\n", " source_text = batch[\"src_text\"][0]\n", " target_text = batch[\"tgt_text\"][0]\n", " model_out_text_beam = tokenizer_tgt.decode(model_out_beam.detach().cpu().numpy())\n", " model_out_text_greedy = tokenizer_tgt.decode(model_out_greedy.detach().cpu().numpy())\n", " \n", " # Print the source, target and model output\n", " print_msg('-'*console_width)\n", " print_msg(f\"{f'SOURCE: ':>20}{source_text}\")\n", " print_msg(f\"{f'TARGET: ':>20}{target_text}\")\n", " print_msg(f\"{f'PREDICTED GREEDY: ':>20}{model_out_text_greedy}\")\n", " print_msg(f\"{f'PREDICTED BEAM: ':>20}{model_out_text_beam}\")\n", "\n", " if count == num_examples:\n", " print_msg('-'*console_width)\n", " break\n", "\n", "run_validation(model, val_dataloader, tokenizer_src, tokenizer_tgt, 20, device, print_msg=print, num_examples=2)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "transformer", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.3" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }