File size: 10,105 Bytes
08b8da2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 |
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"import torch\n",
"import torch.nn as nn\n",
"from config import get_config, get_weights_file_path\n",
"from train import get_model, get_ds, run_validation, causal_mask"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using device: cuda\n",
"Max length of source sentence: 309\n",
"Max length of target sentence: 274\n"
]
},
{
"data": {
"text/plain": [
"<All keys matched successfully>"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Define the device\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(\"Using device:\", device)\n",
"config = get_config()\n",
"train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)\n",
"model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)\n",
"\n",
"# Load the pretrained weights. PRELOAD_EPOCH is the tag handed to get_weights_file_path\n",
"# (presumably the training epoch; confirm in config.py).\n",
"PRELOAD_EPOCH = \"19\"\n",
"model_filename = get_weights_file_path(config, PRELOAD_EPOCH)\n",
"# map_location puts the stored tensors on the selected device, so a checkpoint saved\n",
"# from a GPU run can also be loaded when the cpu fallback above was chosen.\n",
"state = torch.load(model_filename, map_location=device)\n",
"model.load_state_dict(state['model_state_dict'])"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--------------------------------------------------------------------------------\n",
" SOURCE: Hence it is that for so long a time, and during so much fighting in the past twenty years, whenever there has been an army wholly Italian, it has always given a poor account of itself; the first witness to this is Il Taro, afterwards Allesandria, Capua, Genoa, Vaila, Bologna, Mestri.\n",
" TARGET: Di qui nasce che, in tanto tempo, in tante guerre fatte ne' passati venti anni, quando elli è stato uno esercito tutto italiano, sempre ha fatto mala pruova. Di che è testimone prima el Taro, di poi Alessandria, Capua, Genova, Vailà, Bologna, Mestri.\n",
" PREDICTED GREEDY: Di qui nasce che , in tanto , in tanto tempo , in tante guerre fatte ne ' passati\n",
" PREDICTED BEAM: Di qui nasce che , in tanto tempo , in tante guerre fatte ne ' passati venti anni ,\n",
"--------------------------------------------------------------------------------\n",
" SOURCE: She went out.\n",
" TARGET: Aprì lo sportello e venne fuori.\n",
" PREDICTED GREEDY: Aprì lo sportello e venne fuori .\n",
" PREDICTED BEAM: Aprì lo sportello e venne fuori . — Ecco , poi uscì e andò via . — Ecco ,\n",
"--------------------------------------------------------------------------------\n"
]
}
],
"source": [
"def beam_search_decode(model, beam_size, source, source_mask, tokenizer_src, tokenizer_tgt, max_len, device):\n",
"    \"\"\"Decode a single source sentence with beam search.\n",
"\n",
"    Args:\n",
"        model: object exposing encode(source, source_mask),\n",
"            decode(encoder_output, source_mask, tgt, tgt_mask) and project(decoder_output).\n",
"        beam_size: number of hypotheses kept after every decoding step.\n",
"        source: encoder input ids; assumed to hold one sentence (batch size 1), which the\n",
"            [0] indexing below relies on.\n",
"        source_mask: encoder mask, forwarded unchanged to encode/decode.\n",
"        tokenizer_src: unused; kept so the signature mirrors greedy_decode.\n",
"        tokenizer_tgt: target tokenizer, used to look up the [SOS] and [EOS] ids.\n",
"        max_len: length budget in tokens (the leading [SOS] included); decoding stops as\n",
"            soon as any hypothesis reaches it.\n",
"        device: device on which decoder inputs and masks are placed.\n",
"\n",
"    Returns:\n",
"        1-D tensor with the token ids of the best-scoring hypothesis, starting with [SOS].\n",
"    \"\"\"\n",
"    sos_idx = tokenizer_tgt.token_to_id('[SOS]')\n",
"    eos_idx = tokenizer_tgt.token_to_id('[EOS]')\n",
"\n",
"    # Precompute the encoder output and reuse it for every step\n",
"    encoder_output = model.encode(source, source_mask)\n",
"    # Initialize the decoder input with the sos token\n",
"    decoder_initial_input = torch.empty(1, 1).fill_(sos_idx).type_as(source).to(device)\n",
"\n",
"    # Each candidate is (token ids of shape (1, length), cumulative score). Scores are\n",
"    # summed as log-probabilities, so the [SOS]-only prefix starts at log(1) = 0.\n",
"    candidates = [(decoder_initial_input, 0.0)]\n",
"\n",
"    while True:\n",
"        # Stop as soon as any hypothesis has used up the length budget\n",
"        # (>= also terminates for a non-positive max_len)\n",
"        if any(cand.size(1) >= max_len for cand, _ in candidates):\n",
"            break\n",
"\n",
"        new_candidates = []\n",
"        for candidate, score in candidates:\n",
"            # A finished hypothesis is not expanded, but it must stay in the pool so it\n",
"            # keeps competing with the longer ones; dropping it would throw away every\n",
"            # translation that already emitted [EOS].\n",
"            if candidate[0][-1].item() == eos_idx:\n",
"                new_candidates.append((candidate, score))\n",
"                continue\n",
"\n",
"            # Build the candidate's mask and run the decoder on the current prefix\n",
"            candidate_mask = causal_mask(candidate.size(1)).type_as(source_mask).to(device)\n",
"            out = model.decode(encoder_output, source_mask, candidate, candidate_mask)\n",
"            # Next-token scores from the last decoder position. NOTE(review): summing them\n",
"            # assumes model.project returns log-probabilities (log_softmax); confirm in the model.\n",
"            prob = model.project(out[:, -1])\n",
"            topk_prob, topk_idx = torch.topk(prob, beam_size, dim=1)\n",
"            for i in range(beam_size):\n",
"                token = topk_idx[0][i].unsqueeze(0).unsqueeze(0)\n",
"                token_prob = topk_prob[0][i].item()\n",
"                new_candidate = torch.cat([candidate, token], dim=1)\n",
"                new_candidates.append((new_candidate, score + token_prob))\n",
"\n",
"        # Keep only the beam_size best hypotheses\n",
"        candidates = sorted(new_candidates, key=lambda x: x[1], reverse=True)[:beam_size]\n",
"\n",
"        # Every surviving hypothesis has finished: nothing left to expand\n",
"        if all(cand[0][-1].item() == eos_idx for cand, _ in candidates):\n",
"            break\n",
"\n",
"    # Drop only the batch dimension so the result stays 1-D even for a single token,\n",
"    # matching greedy_decode\n",
"    return candidates[0][0].squeeze(0)\n",
"\n",
"def greedy_decode(model, source, source_mask, tokenizer_src, tokenizer_tgt, max_len, device):\n",
"    \"\"\"Translate one sentence by always taking the highest-scoring next token.\n",
"\n",
"    Generation stops right after [EOS] is appended, or once the sequence holds\n",
"    max_len tokens (the leading [SOS] counts). tokenizer_src is accepted for\n",
"    signature symmetry but never read.\n",
"\n",
"    Returns:\n",
"        The generated token ids, starting with [SOS], with dimension 0 squeezed away.\n",
"    \"\"\"\n",
"    sos_idx = tokenizer_tgt.token_to_id('[SOS]')\n",
"    eos_idx = tokenizer_tgt.token_to_id('[EOS]')\n",
"\n",
"    # The encoder runs once; each step only re-runs the decoder\n",
"    memory = model.encode(source, source_mask)\n",
"    generated = torch.empty(1, 1).fill_(sos_idx).type_as(source).to(device)\n",
"\n",
"    while generated.size(1) != max_len:\n",
"        # Target mask sized to the current prefix length\n",
"        tgt_mask = causal_mask(generated.size(1)).type_as(source_mask).to(device)\n",
"        hidden = model.decode(memory, source_mask, generated, tgt_mask)\n",
"\n",
"        # Take the arg-max token at the last decoder position and append it\n",
"        scores = model.project(hidden[:, -1])\n",
"        _, best = torch.max(scores, dim=1)\n",
"        next_token = torch.empty(1, 1).type_as(source).fill_(best.item()).to(device)\n",
"        generated = torch.cat([generated, next_token], dim=1)\n",
"\n",
"        if best == eos_idx:\n",
"            break\n",
"\n",
"    return generated.squeeze(0)\n",
"\n",
"def run_validation(model, validation_ds, tokenizer_src, tokenizer_tgt, max_len, device, print_msg, num_examples=2, beam_size=3):\n",
"    \"\"\"Print greedy and beam-search translations for the first validation examples.\n",
"\n",
"    This redefines the run_validation imported from train so both decoders can be\n",
"    compared side by side.\n",
"\n",
"    Args:\n",
"        model: the translation model; switched to eval mode here.\n",
"        validation_ds: iterable of batches with encoder_input, encoder_mask, src_text and\n",
"            tgt_text entries; each batch must hold exactly one sentence.\n",
"        tokenizer_src, tokenizer_tgt: source/target tokenizers passed to the decoders.\n",
"        max_len: maximum output length handed to both decoders.\n",
"        device: device the encoder inputs are moved to.\n",
"        print_msg: callable used for every output line (e.g. print).\n",
"        num_examples: number of examples to show; values <= 0 show nothing.\n",
"        beam_size: beam width for beam_search_decode (3 by default, as before).\n",
"\n",
"    Raises:\n",
"        AssertionError: if a batch contains more than one sentence.\n",
"    \"\"\"\n",
"    model.eval()\n",
"\n",
"    console_width = 80\n",
"    separator = '-' * console_width\n",
"\n",
"    # A non-positive num_examples shows nothing, instead of never reaching the stop\n",
"    # condition and decoding the entire validation set.\n",
"    if num_examples <= 0:\n",
"        return\n",
"\n",
"    shown = 0\n",
"    with torch.no_grad():\n",
"        for batch in validation_ds:\n",
"            encoder_input = batch[\"encoder_input\"].to(device)  # (b, seq_len)\n",
"            encoder_mask = batch[\"encoder_mask\"].to(device)  # (b, 1, 1, seq_len)\n",
"\n",
"            # Both decoders work on a single sentence at a time\n",
"            assert encoder_input.size(0) == 1, \"Batch size must be 1 for validation\"\n",
"\n",
"            model_out_greedy = greedy_decode(model, encoder_input, encoder_mask, tokenizer_src, tokenizer_tgt, max_len, device)\n",
"            model_out_beam = beam_search_decode(model, beam_size, encoder_input, encoder_mask, tokenizer_src, tokenizer_tgt, max_len, device)\n",
"\n",
"            source_text = batch[\"src_text\"][0]\n",
"            target_text = batch[\"tgt_text\"][0]\n",
"            model_out_text_beam = tokenizer_tgt.decode(model_out_beam.detach().cpu().numpy())\n",
"            model_out_text_greedy = tokenizer_tgt.decode(model_out_greedy.detach().cpu().numpy())\n",
"\n",
"            # Print the source, target and model output\n",
"            print_msg(separator)\n",
"            print_msg(f\"{f'SOURCE: ':>20}{source_text}\")\n",
"            print_msg(f\"{f'TARGET: ':>20}{target_text}\")\n",
"            print_msg(f\"{f'PREDICTED GREEDY: ':>20}{model_out_text_greedy}\")\n",
"            print_msg(f\"{f'PREDICTED BEAM: ':>20}{model_out_text_beam}\")\n",
"\n",
"            shown += 1\n",
"            if shown >= num_examples:\n",
"                break\n",
"\n",
"    # Close the last example, also when the dataset ran out before num_examples\n",
"    if shown:\n",
"        print_msg(separator)\n",
"\n",
"# Compare both decoders on the first 2 validation pairs, capping outputs at 20 tokens\n",
"run_validation(model, val_dataloader, tokenizer_src, tokenizer_tgt, max_len=20, device=device, print_msg=print, num_examples=2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "transformer",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.3"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
|