pt-sk commited on
Commit
08b8da2
1 Parent(s): 3eff676

Upload 9 files

Browse files
Others/Beam_Search.ipynb ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "from pathlib import Path\n",
10
+ "import torch\n",
11
+ "import torch.nn as nn\n",
12
+ "from config import get_config, get_weights_file_path\n",
13
+ "from train import get_model, get_ds, run_validation, causal_mask"
14
+ ]
15
+ },
16
+ {
17
+ "cell_type": "code",
18
+ "execution_count": 2,
19
+ "metadata": {},
20
+ "outputs": [
21
+ {
22
+ "name": "stdout",
23
+ "output_type": "stream",
24
+ "text": [
25
+ "Using device: cuda\n",
26
+ "Max length of source sentence: 309\n",
27
+ "Max length of target sentence: 274\n"
28
+ ]
29
+ },
30
+ {
31
+ "data": {
32
+ "text/plain": [
33
+ "<All keys matched successfully>"
34
+ ]
35
+ },
36
+ "execution_count": 2,
37
+ "metadata": {},
38
+ "output_type": "execute_result"
39
+ }
40
+ ],
41
+ "source": [
42
+ "# Define the device\n",
43
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
44
+ "print(\"Using device:\", device)\n",
45
+ "config = get_config()\n",
46
+ "train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)\n",
47
+ "model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)\n",
48
+ "\n",
49
+ "# Load the pretrained weights\n",
50
+ "model_filename = get_weights_file_path(config, f\"19\")\n",
51
+ "state = torch.load(model_filename)\n",
52
+ "model.load_state_dict(state['model_state_dict'])"
53
+ ]
54
+ },
55
+ {
56
+ "cell_type": "code",
57
+ "execution_count": 3,
58
+ "metadata": {},
59
+ "outputs": [
60
+ {
61
+ "name": "stdout",
62
+ "output_type": "stream",
63
+ "text": [
64
+ "--------------------------------------------------------------------------------\n",
65
+ " SOURCE: Hence it is that for so long a time, and during so much fighting in the past twenty years, whenever there has been an army wholly Italian, it has always given a poor account of itself; the first witness to this is Il Taro, afterwards Allesandria, Capua, Genoa, Vaila, Bologna, Mestri.\n",
66
+ " TARGET: Di qui nasce che, in tanto tempo, in tante guerre fatte ne' passati venti anni, quando elli è stato uno esercito tutto italiano, sempre ha fatto mala pruova. Di che è testimone prima el Taro, di poi Alessandria, Capua, Genova, Vailà, Bologna, Mestri.\n",
67
+ " PREDICTED GREEDY: Di qui nasce che , in tanto , in tanto tempo , in tante guerre fatte ne ' passati\n",
68
+ " PREDICTED BEAM: Di qui nasce che , in tanto tempo , in tante guerre fatte ne ' passati venti anni ,\n",
69
+ "--------------------------------------------------------------------------------\n",
70
+ " SOURCE: She went out.\n",
71
+ " TARGET: Aprì lo sportello e venne fuori.\n",
72
+ " PREDICTED GREEDY: Aprì lo sportello e venne fuori .\n",
73
+ " PREDICTED BEAM: Aprì lo sportello e venne fuori . — Ecco , poi uscì e andò via . — Ecco ,\n",
74
+ "--------------------------------------------------------------------------------\n"
75
+ ]
76
+ }
77
+ ],
78
+ "source": [
79
+ "def beam_search_decode(model, beam_size, source, source_mask, tokenizer_src, tokenizer_tgt, max_len, device):\n",
80
+ " sos_idx = tokenizer_tgt.token_to_id('[SOS]')\n",
81
+ " eos_idx = tokenizer_tgt.token_to_id('[EOS]')\n",
82
+ "\n",
83
+ " # Precompute the encoder output and reuse it for every step\n",
84
+ " encoder_output = model.encode(source, source_mask)\n",
85
+ " # Initialize the decoder input with the sos token\n",
86
+ " decoder_initial_input = torch.empty(1, 1).fill_(sos_idx).type_as(source).to(device)\n",
87
+ "\n",
88
+ " # Create a candidate list\n",
89
+ " candidates = [(decoder_initial_input, 1)]\n",
90
+ "\n",
91
+ " while True:\n",
92
+ "\n",
93
+ " # If a candidate has reached the maximum length, it means we have run the decoding for at least max_len iterations, so stop the search\n",
94
+ " if any([cand.size(1) == max_len for cand, _ in candidates]):\n",
95
+ " break\n",
96
+ "\n",
97
+ " # Create a new list of candidates\n",
98
+ " new_candidates = []\n",
99
+ "\n",
100
+ " for candidate, score in candidates:\n",
101
+ "\n",
102
+ " # Do not expand candidates that have reached the eos token\n",
103
+ " if candidate[0][-1].item() == eos_idx:\n",
104
+ " continue\n",
105
+ "\n",
106
+ " # Build the candidate's mask\n",
107
+ " candidate_mask = causal_mask(candidate.size(1)).type_as(source_mask).to(device)\n",
108
+ " # calculate output\n",
109
+ " out = model.decode(encoder_output, source_mask, candidate, candidate_mask)\n",
110
+ " # get next token probabilities\n",
111
+ " prob = model.project(out[:, -1])\n",
112
+ " # get the top k candidates\n",
113
+ " topk_prob, topk_idx = torch.topk(prob, beam_size, dim=1)\n",
114
+ " for i in range(beam_size):\n",
115
+ " # for each of the top k candidates, get the token and its probability\n",
116
+ " token = topk_idx[0][i].unsqueeze(0).unsqueeze(0)\n",
117
+ " token_prob = topk_prob[0][i].item()\n",
118
+ " # create a new candidate by appending the token to the current candidate\n",
119
+ " new_candidate = torch.cat([candidate, token], dim=1)\n",
120
+ " # We sum the log probabilities because the probabilities are in log space\n",
121
+ " new_candidates.append((new_candidate, score + token_prob))\n",
122
+ "\n",
123
+ " # Sort the new candidates by their score\n",
124
+ " candidates = sorted(new_candidates, key=lambda x: x[1], reverse=True)\n",
125
+ " # Keep only the top k candidates\n",
126
+ " candidates = candidates[:beam_size]\n",
127
+ "\n",
128
+ " # If all the candidates have reached the eos token, stop\n",
129
+ " if all([cand[0][-1].item() == eos_idx for cand, _ in candidates]):\n",
130
+ " break\n",
131
+ "\n",
132
+ " # Return the best candidate\n",
133
+ " return candidates[0][0].squeeze()\n",
134
+ "\n",
135
+ "def greedy_decode(model, source, source_mask, tokenizer_src, tokenizer_tgt, max_len, device):\n",
136
+ " sos_idx = tokenizer_tgt.token_to_id('[SOS]')\n",
137
+ " eos_idx = tokenizer_tgt.token_to_id('[EOS]')\n",
138
+ "\n",
139
+ " # Precompute the encoder output and reuse it for every step\n",
140
+ " encoder_output = model.encode(source, source_mask)\n",
141
+ " # Initialize the decoder input with the sos token\n",
142
+ " decoder_input = torch.empty(1, 1).fill_(sos_idx).type_as(source).to(device)\n",
143
+ " while True:\n",
144
+ " if decoder_input.size(1) == max_len:\n",
145
+ " break\n",
146
+ "\n",
147
+ " # build mask for target\n",
148
+ " decoder_mask = causal_mask(decoder_input.size(1)).type_as(source_mask).to(device)\n",
149
+ "\n",
150
+ " # calculate output\n",
151
+ " out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask)\n",
152
+ "\n",
153
+ " # get next token\n",
154
+ " prob = model.project(out[:, -1])\n",
155
+ " _, next_word = torch.max(prob, dim=1)\n",
156
+ " decoder_input = torch.cat(\n",
157
+ " [decoder_input, torch.empty(1, 1).type_as(source).fill_(next_word.item()).to(device)], dim=1\n",
158
+ " )\n",
159
+ "\n",
160
+ " if next_word == eos_idx:\n",
161
+ " break\n",
162
+ "\n",
163
+ " return decoder_input.squeeze(0)\n",
164
+ "\n",
165
+ "def run_validation(model, validation_ds, tokenizer_src, tokenizer_tgt, max_len, device, print_msg, num_examples=2):\n",
166
+ " model.eval()\n",
167
+ " count = 0\n",
168
+ "\n",
169
+ " console_width = 80\n",
170
+ "\n",
171
+ " with torch.no_grad():\n",
172
+ " for batch in validation_ds:\n",
173
+ " count += 1\n",
174
+ " encoder_input = batch[\"encoder_input\"].to(device) # (b, seq_len)\n",
175
+ " encoder_mask = batch[\"encoder_mask\"].to(device) # (b, 1, 1, seq_len)\n",
176
+ "\n",
177
+ " # check that the batch size is 1\n",
178
+ " assert encoder_input.size(\n",
179
+ " 0) == 1, \"Batch size must be 1 for validation\"\n",
180
+ "\n",
181
+ " \n",
182
+ " model_out_greedy = greedy_decode(model, encoder_input, encoder_mask, tokenizer_src, tokenizer_tgt, max_len, device)\n",
183
+ " model_out_beam = beam_search_decode(model, 3, encoder_input, encoder_mask, tokenizer_src, tokenizer_tgt, max_len, device)\n",
184
+ "\n",
185
+ " source_text = batch[\"src_text\"][0]\n",
186
+ " target_text = batch[\"tgt_text\"][0]\n",
187
+ " model_out_text_beam = tokenizer_tgt.decode(model_out_beam.detach().cpu().numpy())\n",
188
+ " model_out_text_greedy = tokenizer_tgt.decode(model_out_greedy.detach().cpu().numpy())\n",
189
+ " \n",
190
+ " # Print the source, target and model output\n",
191
+ " print_msg('-'*console_width)\n",
192
+ " print_msg(f\"{f'SOURCE: ':>20}{source_text}\")\n",
193
+ " print_msg(f\"{f'TARGET: ':>20}{target_text}\")\n",
194
+ " print_msg(f\"{f'PREDICTED GREEDY: ':>20}{model_out_text_greedy}\")\n",
195
+ " print_msg(f\"{f'PREDICTED BEAM: ':>20}{model_out_text_beam}\")\n",
196
+ "\n",
197
+ " if count == num_examples:\n",
198
+ " print_msg('-'*console_width)\n",
199
+ " break\n",
200
+ "\n",
201
+ "run_validation(model, val_dataloader, tokenizer_src, tokenizer_tgt, 20, device, print_msg=print, num_examples=2)"
202
+ ]
203
+ },
204
+ {
205
+ "cell_type": "code",
206
+ "execution_count": null,
207
+ "metadata": {},
208
+ "outputs": [],
209
+ "source": []
210
+ }
211
+ ],
212
+ "metadata": {
213
+ "kernelspec": {
214
+ "display_name": "transformer",
215
+ "language": "python",
216
+ "name": "python3"
217
+ },
218
+ "language_info": {
219
+ "codemirror_mode": {
220
+ "name": "ipython",
221
+ "version": 3
222
+ },
223
+ "file_extension": ".py",
224
+ "mimetype": "text/x-python",
225
+ "name": "python",
226
+ "nbconvert_exporter": "python",
227
+ "pygments_lexer": "ipython3",
228
+ "version": "3.11.3"
229
+ },
230
+ "orig_nbformat": 4
231
+ },
232
+ "nbformat": 4,
233
+ "nbformat_minor": 2
234
+ }
Others/Colab_Train.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
Others/Inference.ipynb ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "from pathlib import Path\n",
10
+ "import torch\n",
11
+ "import torch.nn as nn\n",
12
+ "from config import get_config, latest_weights_file_path\n",
13
+ "from train import get_model, get_ds, run_validation\n",
14
+ "from translate import translate"
15
+ ]
16
+ },
17
+ {
18
+ "cell_type": "code",
19
+ "execution_count": null,
20
+ "metadata": {},
21
+ "outputs": [],
22
+ "source": [
23
+ "# Define the device\n",
24
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
25
+ "print(\"Using device:\", device)\n",
26
+ "config = get_config()\n",
27
+ "train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)\n",
28
+ "model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)\n",
29
+ "\n",
30
+ "# Load the pretrained weights\n",
31
+ "model_filename = latest_weights_file_path(config)\n",
32
+ "state = torch.load(model_filename)\n",
33
+ "model.load_state_dict(state['model_state_dict'])"
34
+ ]
35
+ },
36
+ {
37
+ "cell_type": "code",
38
+ "execution_count": null,
39
+ "metadata": {},
40
+ "outputs": [],
41
+ "source": [
42
+ "run_validation(model, val_dataloader, tokenizer_src, tokenizer_tgt, config['seq_len'], device, lambda msg: print(msg), 0, None, num_examples=10)"
43
+ ]
44
+ },
45
+ {
46
+ "cell_type": "code",
47
+ "execution_count": null,
48
+ "metadata": {},
49
+ "outputs": [],
50
+ "source": [
51
+ "t = translate(\"Why do I need to translate this?\")"
52
+ ]
53
+ },
54
+ {
55
+ "cell_type": "code",
56
+ "execution_count": null,
57
+ "metadata": {},
58
+ "outputs": [],
59
+ "source": [
60
+ "t = translate(34)"
61
+ ]
62
+ },
63
+ {
64
+ "cell_type": "code",
65
+ "execution_count": null,
66
+ "metadata": {},
67
+ "outputs": [],
68
+ "source": []
69
+ }
70
+ ],
71
+ "metadata": {
72
+ "kernelspec": {
73
+ "display_name": "transformer",
74
+ "language": "python",
75
+ "name": "python3"
76
+ },
77
+ "language_info": {
78
+ "codemirror_mode": {
79
+ "name": "ipython",
80
+ "version": 3
81
+ },
82
+ "file_extension": ".py",
83
+ "mimetype": "text/x-python",
84
+ "name": "python",
85
+ "nbconvert_exporter": "python",
86
+ "pygments_lexer": "ipython3",
87
+ "version": "3.9.0"
88
+ },
89
+ "orig_nbformat": 4
90
+ },
91
+ "nbformat": 4,
92
+ "nbformat_minor": 2
93
+ }
Others/Local_Train.ipynb ADDED
@@ -0,0 +1,1832 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {
7
+ "colab": {
8
+ "base_uri": "https://localhost:8080/",
9
+ "height": 198,
10
+ "referenced_widgets": [
11
+ "0ce327d5112b44dbb20e57752afc478a",
12
+ "423a3059ad1a4e01bd01095cf1b41e14",
13
+ "9cf2d2e2bfe24f2ab185165d79da8bdb",
14
+ "996ac47b200c427088ee7644fe886896",
15
+ "9b9addf13301466b9ef30b9d4b836a67",
16
+ "ec2051bf0e9343d394e8a0ecb4fd5ec8",
17
+ "56049bd375cd4512a0deaf69b7dae245",
18
+ "140f33387db341398bc39e9c47703df4",
19
+ "b3a8424c0b584a37ad2ede748085425c",
20
+ "cb7d88a70af746f2ae31416b4b670c63",
21
+ "4837276e5cf248449e287b1eeaef30ec",
22
+ "3ab0f2022e654458875c2c091908e8c9",
23
+ "f74bdeb79a224de8b1c85f4ca8657331",
24
+ "4eb62038f89d4a8cb2c46e6a7cc70150",
25
+ "9055fd09043642e0ae3d8a7a7c0ab31b",
26
+ "4a2ead337d5c4ded9f28c93a70db1f08",
27
+ "888a323362ae4daeac99915bcb3dcf10",
28
+ "4d0e364e9f274e8ea7447e4e01c7f28f",
29
+ "78a32764678a42f0a5a892f5275d88de",
30
+ "aa17c3a834694a978046808fc5d29da1",
31
+ "11e011e4acb24519bd41a054ddecbfb1",
32
+ "5d1a9518abd44c18b122e575a7548ed2",
33
+ "76e80fb236f5491597c992d1a809be33",
34
+ "f7359467b0214c5385de8ee4334f7ba3",
35
+ "a58ac736aa884eb9a27264cb04bb36ce",
36
+ "6e6f7b7cccaa4f0cbfc9311db257bea1",
37
+ "0656eee26364487f81580c3864e7a159",
38
+ "05240e68c55a458286f43967e7f90889",
39
+ "8cfa6df0ee654643bfdb4a3825e8fbbe",
40
+ "96baa91869eb478eb492754b98169470",
41
+ "bbda5260ca1c450386f9191e9f9dde97",
42
+ "6fc5bec49f17469db39e0d4b535b94e9",
43
+ "67822d28f8584e69abcb041b88377a9f",
44
+ "aa082ade829247dc8ea0d75cc8a5b2a7",
45
+ "83bc41f428b7492e9defdaa177f33a3e",
46
+ "7f168d0ea11c4ea1a96202d3a36ec389",
47
+ "ebb7ee3fd084466f9667771a99e6e3b2",
48
+ "1e3c2a94251b4e75af0413a88b53bfe1",
49
+ "a1188f80f78c49c7a822d71694e47074",
50
+ "068552491889440e8a66e61b9f013786",
51
+ "c88027eb3e1c4771ab57366070ecd553",
52
+ "df75b255bfb04057b553830b59f0a153",
53
+ "f0e5024d0d054c1eb8e01c4c8b027e79",
54
+ "937ee45f4d634d189c6d95c886e97bca",
55
+ "c2d14fa4280c48e0ae04859b73c80781",
56
+ "d3104837d9734834b7c87e87289b08df",
57
+ "02b02005adf241a4a0be8173ca3a4aee",
58
+ "b317ba38f2b145f9b0b49f523547684f",
59
+ "434340d109d1401d8868498a23b291cf",
60
+ "2c95f5b81fc84ad698fe77b52cb84076",
61
+ "ca588157678e4cc09c3fd760676efd39",
62
+ "c020b38c6d2c436e8b742fd87d3b8b89",
63
+ "3dc97a04373f484d9ccd1c46646d96cc",
64
+ "4aed1fa58b7342eba35c2106ec934019",
65
+ "60c72c47a8d84f0eab652822bed1ed09"
66
+ ]
67
+ },
68
+ "id": "gGDOaOoIwGc5",
69
+ "outputId": "4180e60a-8985-4795-8e72-373deabc1ebc"
70
+ },
71
+ "outputs": [],
72
+ "source": [
73
+ "from config import get_config\n",
74
+ "cfg = get_config()\n",
75
+ "cfg['batch_size'] = 6\n",
76
+ "cfg['preload'] = None\n",
77
+ "cfg['num_epochs'] = 30\n",
78
+ "\n",
79
+ "from train import train_model\n",
80
+ "\n",
81
+ "train_model(cfg)"
82
+ ]
83
+ },
84
+ {
85
+ "cell_type": "code",
86
+ "execution_count": null,
87
+ "metadata": {},
88
+ "outputs": [],
89
+ "source": []
90
+ }
91
+ ],
92
+ "metadata": {
93
+ "accelerator": "GPU",
94
+ "colab": {
95
+ "gpuType": "T4",
96
+ "provenance": []
97
+ },
98
+ "gpuClass": "standard",
99
+ "kernelspec": {
100
+ "display_name": "Python 3",
101
+ "name": "python3"
102
+ },
103
+ "language_info": {
104
+ "codemirror_mode": {
105
+ "name": "ipython",
106
+ "version": 3
107
+ },
108
+ "file_extension": ".py",
109
+ "mimetype": "text/x-python",
110
+ "name": "python",
111
+ "nbconvert_exporter": "python",
112
+ "pygments_lexer": "ipython3",
113
+ "version": "3.10.6"
114
+ },
115
+ "widgets": {
116
+ "application/vnd.jupyter.widget-state+json": {
117
+ "02b02005adf241a4a0be8173ca3a4aee": {
118
+ "model_module": "@jupyter-widgets/controls",
119
+ "model_module_version": "1.5.0",
120
+ "model_name": "FloatProgressModel",
121
+ "state": {
122
+ "_dom_classes": [],
123
+ "_model_module": "@jupyter-widgets/controls",
124
+ "_model_module_version": "1.5.0",
125
+ "_model_name": "FloatProgressModel",
126
+ "_view_count": null,
127
+ "_view_module": "@jupyter-widgets/controls",
128
+ "_view_module_version": "1.5.0",
129
+ "_view_name": "ProgressView",
130
+ "bar_style": "",
131
+ "description": "",
132
+ "description_tooltip": null,
133
+ "layout": "IPY_MODEL_c020b38c6d2c436e8b742fd87d3b8b89",
134
+ "max": 32332,
135
+ "min": 0,
136
+ "orientation": "horizontal",
137
+ "style": "IPY_MODEL_3dc97a04373f484d9ccd1c46646d96cc",
138
+ "value": 32332
139
+ }
140
+ },
141
+ "05240e68c55a458286f43967e7f90889": {
142
+ "model_module": "@jupyter-widgets/base",
143
+ "model_module_version": "1.2.0",
144
+ "model_name": "LayoutModel",
145
+ "state": {
146
+ "_model_module": "@jupyter-widgets/base",
147
+ "_model_module_version": "1.2.0",
148
+ "_model_name": "LayoutModel",
149
+ "_view_count": null,
150
+ "_view_module": "@jupyter-widgets/base",
151
+ "_view_module_version": "1.2.0",
152
+ "_view_name": "LayoutView",
153
+ "align_content": null,
154
+ "align_items": null,
155
+ "align_self": null,
156
+ "border": null,
157
+ "bottom": null,
158
+ "display": null,
159
+ "flex": null,
160
+ "flex_flow": null,
161
+ "grid_area": null,
162
+ "grid_auto_columns": null,
163
+ "grid_auto_flow": null,
164
+ "grid_auto_rows": null,
165
+ "grid_column": null,
166
+ "grid_gap": null,
167
+ "grid_row": null,
168
+ "grid_template_areas": null,
169
+ "grid_template_columns": null,
170
+ "grid_template_rows": null,
171
+ "height": null,
172
+ "justify_content": null,
173
+ "justify_items": null,
174
+ "left": null,
175
+ "margin": null,
176
+ "max_height": null,
177
+ "max_width": null,
178
+ "min_height": null,
179
+ "min_width": null,
180
+ "object_fit": null,
181
+ "object_position": null,
182
+ "order": null,
183
+ "overflow": null,
184
+ "overflow_x": null,
185
+ "overflow_y": null,
186
+ "padding": null,
187
+ "right": null,
188
+ "top": null,
189
+ "visibility": null,
190
+ "width": null
191
+ }
192
+ },
193
+ "0656eee26364487f81580c3864e7a159": {
194
+ "model_module": "@jupyter-widgets/base",
195
+ "model_module_version": "1.2.0",
196
+ "model_name": "LayoutModel",
197
+ "state": {
198
+ "_model_module": "@jupyter-widgets/base",
199
+ "_model_module_version": "1.2.0",
200
+ "_model_name": "LayoutModel",
201
+ "_view_count": null,
202
+ "_view_module": "@jupyter-widgets/base",
203
+ "_view_module_version": "1.2.0",
204
+ "_view_name": "LayoutView",
205
+ "align_content": null,
206
+ "align_items": null,
207
+ "align_self": null,
208
+ "border": null,
209
+ "bottom": null,
210
+ "display": null,
211
+ "flex": null,
212
+ "flex_flow": null,
213
+ "grid_area": null,
214
+ "grid_auto_columns": null,
215
+ "grid_auto_flow": null,
216
+ "grid_auto_rows": null,
217
+ "grid_column": null,
218
+ "grid_gap": null,
219
+ "grid_row": null,
220
+ "grid_template_areas": null,
221
+ "grid_template_columns": null,
222
+ "grid_template_rows": null,
223
+ "height": null,
224
+ "justify_content": null,
225
+ "justify_items": null,
226
+ "left": null,
227
+ "margin": null,
228
+ "max_height": null,
229
+ "max_width": null,
230
+ "min_height": null,
231
+ "min_width": null,
232
+ "object_fit": null,
233
+ "object_position": null,
234
+ "order": null,
235
+ "overflow": null,
236
+ "overflow_x": null,
237
+ "overflow_y": null,
238
+ "padding": null,
239
+ "right": null,
240
+ "top": null,
241
+ "visibility": null,
242
+ "width": null
243
+ }
244
+ },
245
+ "068552491889440e8a66e61b9f013786": {
246
+ "model_module": "@jupyter-widgets/controls",
247
+ "model_module_version": "1.5.0",
248
+ "model_name": "DescriptionStyleModel",
249
+ "state": {
250
+ "_model_module": "@jupyter-widgets/controls",
251
+ "_model_module_version": "1.5.0",
252
+ "_model_name": "DescriptionStyleModel",
253
+ "_view_count": null,
254
+ "_view_module": "@jupyter-widgets/base",
255
+ "_view_module_version": "1.2.0",
256
+ "_view_name": "StyleView",
257
+ "description_width": ""
258
+ }
259
+ },
260
+ "0ce327d5112b44dbb20e57752afc478a": {
261
+ "model_module": "@jupyter-widgets/controls",
262
+ "model_module_version": "1.5.0",
263
+ "model_name": "HBoxModel",
264
+ "state": {
265
+ "_dom_classes": [],
266
+ "_model_module": "@jupyter-widgets/controls",
267
+ "_model_module_version": "1.5.0",
268
+ "_model_name": "HBoxModel",
269
+ "_view_count": null,
270
+ "_view_module": "@jupyter-widgets/controls",
271
+ "_view_module_version": "1.5.0",
272
+ "_view_name": "HBoxView",
273
+ "box_style": "",
274
+ "children": [
275
+ "IPY_MODEL_423a3059ad1a4e01bd01095cf1b41e14",
276
+ "IPY_MODEL_9cf2d2e2bfe24f2ab185165d79da8bdb",
277
+ "IPY_MODEL_996ac47b200c427088ee7644fe886896"
278
+ ],
279
+ "layout": "IPY_MODEL_9b9addf13301466b9ef30b9d4b836a67"
280
+ }
281
+ },
282
+ "11e011e4acb24519bd41a054ddecbfb1": {
283
+ "model_module": "@jupyter-widgets/base",
284
+ "model_module_version": "1.2.0",
285
+ "model_name": "LayoutModel",
286
+ "state": {
287
+ "_model_module": "@jupyter-widgets/base",
288
+ "_model_module_version": "1.2.0",
289
+ "_model_name": "LayoutModel",
290
+ "_view_count": null,
291
+ "_view_module": "@jupyter-widgets/base",
292
+ "_view_module_version": "1.2.0",
293
+ "_view_name": "LayoutView",
294
+ "align_content": null,
295
+ "align_items": null,
296
+ "align_self": null,
297
+ "border": null,
298
+ "bottom": null,
299
+ "display": null,
300
+ "flex": null,
301
+ "flex_flow": null,
302
+ "grid_area": null,
303
+ "grid_auto_columns": null,
304
+ "grid_auto_flow": null,
305
+ "grid_auto_rows": null,
306
+ "grid_column": null,
307
+ "grid_gap": null,
308
+ "grid_row": null,
309
+ "grid_template_areas": null,
310
+ "grid_template_columns": null,
311
+ "grid_template_rows": null,
312
+ "height": null,
313
+ "justify_content": null,
314
+ "justify_items": null,
315
+ "left": null,
316
+ "margin": null,
317
+ "max_height": null,
318
+ "max_width": null,
319
+ "min_height": null,
320
+ "min_width": null,
321
+ "object_fit": null,
322
+ "object_position": null,
323
+ "order": null,
324
+ "overflow": null,
325
+ "overflow_x": null,
326
+ "overflow_y": null,
327
+ "padding": null,
328
+ "right": null,
329
+ "top": null,
330
+ "visibility": null,
331
+ "width": null
332
+ }
333
+ },
334
+ "140f33387db341398bc39e9c47703df4": {
335
+ "model_module": "@jupyter-widgets/base",
336
+ "model_module_version": "1.2.0",
337
+ "model_name": "LayoutModel",
338
+ "state": {
339
+ "_model_module": "@jupyter-widgets/base",
340
+ "_model_module_version": "1.2.0",
341
+ "_model_name": "LayoutModel",
342
+ "_view_count": null,
343
+ "_view_module": "@jupyter-widgets/base",
344
+ "_view_module_version": "1.2.0",
345
+ "_view_name": "LayoutView",
346
+ "align_content": null,
347
+ "align_items": null,
348
+ "align_self": null,
349
+ "border": null,
350
+ "bottom": null,
351
+ "display": null,
352
+ "flex": null,
353
+ "flex_flow": null,
354
+ "grid_area": null,
355
+ "grid_auto_columns": null,
356
+ "grid_auto_flow": null,
357
+ "grid_auto_rows": null,
358
+ "grid_column": null,
359
+ "grid_gap": null,
360
+ "grid_row": null,
361
+ "grid_template_areas": null,
362
+ "grid_template_columns": null,
363
+ "grid_template_rows": null,
364
+ "height": null,
365
+ "justify_content": null,
366
+ "justify_items": null,
367
+ "left": null,
368
+ "margin": null,
369
+ "max_height": null,
370
+ "max_width": null,
371
+ "min_height": null,
372
+ "min_width": null,
373
+ "object_fit": null,
374
+ "object_position": null,
375
+ "order": null,
376
+ "overflow": null,
377
+ "overflow_x": null,
378
+ "overflow_y": null,
379
+ "padding": null,
380
+ "right": null,
381
+ "top": null,
382
+ "visibility": null,
383
+ "width": null
384
+ }
385
+ },
386
+ "1e3c2a94251b4e75af0413a88b53bfe1": {
387
+ "model_module": "@jupyter-widgets/base",
388
+ "model_module_version": "1.2.0",
389
+ "model_name": "LayoutModel",
390
+ "state": {
391
+ "_model_module": "@jupyter-widgets/base",
392
+ "_model_module_version": "1.2.0",
393
+ "_model_name": "LayoutModel",
394
+ "_view_count": null,
395
+ "_view_module": "@jupyter-widgets/base",
396
+ "_view_module_version": "1.2.0",
397
+ "_view_name": "LayoutView",
398
+ "align_content": null,
399
+ "align_items": null,
400
+ "align_self": null,
401
+ "border": null,
402
+ "bottom": null,
403
+ "display": null,
404
+ "flex": null,
405
+ "flex_flow": null,
406
+ "grid_area": null,
407
+ "grid_auto_columns": null,
408
+ "grid_auto_flow": null,
409
+ "grid_auto_rows": null,
410
+ "grid_column": null,
411
+ "grid_gap": null,
412
+ "grid_row": null,
413
+ "grid_template_areas": null,
414
+ "grid_template_columns": null,
415
+ "grid_template_rows": null,
416
+ "height": null,
417
+ "justify_content": null,
418
+ "justify_items": null,
419
+ "left": null,
420
+ "margin": null,
421
+ "max_height": null,
422
+ "max_width": null,
423
+ "min_height": null,
424
+ "min_width": null,
425
+ "object_fit": null,
426
+ "object_position": null,
427
+ "order": null,
428
+ "overflow": null,
429
+ "overflow_x": null,
430
+ "overflow_y": null,
431
+ "padding": null,
432
+ "right": null,
433
+ "top": null,
434
+ "visibility": null,
435
+ "width": null
436
+ }
437
+ },
438
+ "2c95f5b81fc84ad698fe77b52cb84076": {
439
+ "model_module": "@jupyter-widgets/base",
440
+ "model_module_version": "1.2.0",
441
+ "model_name": "LayoutModel",
442
+ "state": {
443
+ "_model_module": "@jupyter-widgets/base",
444
+ "_model_module_version": "1.2.0",
445
+ "_model_name": "LayoutModel",
446
+ "_view_count": null,
447
+ "_view_module": "@jupyter-widgets/base",
448
+ "_view_module_version": "1.2.0",
449
+ "_view_name": "LayoutView",
450
+ "align_content": null,
451
+ "align_items": null,
452
+ "align_self": null,
453
+ "border": null,
454
+ "bottom": null,
455
+ "display": null,
456
+ "flex": null,
457
+ "flex_flow": null,
458
+ "grid_area": null,
459
+ "grid_auto_columns": null,
460
+ "grid_auto_flow": null,
461
+ "grid_auto_rows": null,
462
+ "grid_column": null,
463
+ "grid_gap": null,
464
+ "grid_row": null,
465
+ "grid_template_areas": null,
466
+ "grid_template_columns": null,
467
+ "grid_template_rows": null,
468
+ "height": null,
469
+ "justify_content": null,
470
+ "justify_items": null,
471
+ "left": null,
472
+ "margin": null,
473
+ "max_height": null,
474
+ "max_width": null,
475
+ "min_height": null,
476
+ "min_width": null,
477
+ "object_fit": null,
478
+ "object_position": null,
479
+ "order": null,
480
+ "overflow": null,
481
+ "overflow_x": null,
482
+ "overflow_y": null,
483
+ "padding": null,
484
+ "right": null,
485
+ "top": null,
486
+ "visibility": null,
487
+ "width": null
488
+ }
489
+ },
490
+ "3ab0f2022e654458875c2c091908e8c9": {
491
+ "model_module": "@jupyter-widgets/controls",
492
+ "model_module_version": "1.5.0",
493
+ "model_name": "HBoxModel",
494
+ "state": {
495
+ "_dom_classes": [],
496
+ "_model_module": "@jupyter-widgets/controls",
497
+ "_model_module_version": "1.5.0",
498
+ "_model_name": "HBoxModel",
499
+ "_view_count": null,
500
+ "_view_module": "@jupyter-widgets/controls",
501
+ "_view_module_version": "1.5.0",
502
+ "_view_name": "HBoxView",
503
+ "box_style": "",
504
+ "children": [
505
+ "IPY_MODEL_f74bdeb79a224de8b1c85f4ca8657331",
506
+ "IPY_MODEL_4eb62038f89d4a8cb2c46e6a7cc70150",
507
+ "IPY_MODEL_9055fd09043642e0ae3d8a7a7c0ab31b"
508
+ ],
509
+ "layout": "IPY_MODEL_4a2ead337d5c4ded9f28c93a70db1f08"
510
+ }
511
+ },
512
+ "3dc97a04373f484d9ccd1c46646d96cc": {
513
+ "model_module": "@jupyter-widgets/controls",
514
+ "model_module_version": "1.5.0",
515
+ "model_name": "ProgressStyleModel",
516
+ "state": {
517
+ "_model_module": "@jupyter-widgets/controls",
518
+ "_model_module_version": "1.5.0",
519
+ "_model_name": "ProgressStyleModel",
520
+ "_view_count": null,
521
+ "_view_module": "@jupyter-widgets/base",
522
+ "_view_module_version": "1.2.0",
523
+ "_view_name": "StyleView",
524
+ "bar_color": null,
525
+ "description_width": ""
526
+ }
527
+ },
528
+ "423a3059ad1a4e01bd01095cf1b41e14": {
529
+ "model_module": "@jupyter-widgets/controls",
530
+ "model_module_version": "1.5.0",
531
+ "model_name": "HTMLModel",
532
+ "state": {
533
+ "_dom_classes": [],
534
+ "_model_module": "@jupyter-widgets/controls",
535
+ "_model_module_version": "1.5.0",
536
+ "_model_name": "HTMLModel",
537
+ "_view_count": null,
538
+ "_view_module": "@jupyter-widgets/controls",
539
+ "_view_module_version": "1.5.0",
540
+ "_view_name": "HTMLView",
541
+ "description": "",
542
+ "description_tooltip": null,
543
+ "layout": "IPY_MODEL_ec2051bf0e9343d394e8a0ecb4fd5ec8",
544
+ "placeholder": "​",
545
+ "style": "IPY_MODEL_56049bd375cd4512a0deaf69b7dae245",
546
+ "value": "Downloading builder script: 100%"
547
+ }
548
+ },
549
+ "434340d109d1401d8868498a23b291cf": {
550
+ "model_module": "@jupyter-widgets/base",
551
+ "model_module_version": "1.2.0",
552
+ "model_name": "LayoutModel",
553
+ "state": {
554
+ "_model_module": "@jupyter-widgets/base",
555
+ "_model_module_version": "1.2.0",
556
+ "_model_name": "LayoutModel",
557
+ "_view_count": null,
558
+ "_view_module": "@jupyter-widgets/base",
559
+ "_view_module_version": "1.2.0",
560
+ "_view_name": "LayoutView",
561
+ "align_content": null,
562
+ "align_items": null,
563
+ "align_self": null,
564
+ "border": null,
565
+ "bottom": null,
566
+ "display": null,
567
+ "flex": null,
568
+ "flex_flow": null,
569
+ "grid_area": null,
570
+ "grid_auto_columns": null,
571
+ "grid_auto_flow": null,
572
+ "grid_auto_rows": null,
573
+ "grid_column": null,
574
+ "grid_gap": null,
575
+ "grid_row": null,
576
+ "grid_template_areas": null,
577
+ "grid_template_columns": null,
578
+ "grid_template_rows": null,
579
+ "height": null,
580
+ "justify_content": null,
581
+ "justify_items": null,
582
+ "left": null,
583
+ "margin": null,
584
+ "max_height": null,
585
+ "max_width": null,
586
+ "min_height": null,
587
+ "min_width": null,
588
+ "object_fit": null,
589
+ "object_position": null,
590
+ "order": null,
591
+ "overflow": null,
592
+ "overflow_x": null,
593
+ "overflow_y": null,
594
+ "padding": null,
595
+ "right": null,
596
+ "top": null,
597
+ "visibility": "hidden",
598
+ "width": null
599
+ }
600
+ },
601
+ "4837276e5cf248449e287b1eeaef30ec": {
602
+ "model_module": "@jupyter-widgets/controls",
603
+ "model_module_version": "1.5.0",
604
+ "model_name": "DescriptionStyleModel",
605
+ "state": {
606
+ "_model_module": "@jupyter-widgets/controls",
607
+ "_model_module_version": "1.5.0",
608
+ "_model_name": "DescriptionStyleModel",
609
+ "_view_count": null,
610
+ "_view_module": "@jupyter-widgets/base",
611
+ "_view_module_version": "1.2.0",
612
+ "_view_name": "StyleView",
613
+ "description_width": ""
614
+ }
615
+ },
616
+ "4a2ead337d5c4ded9f28c93a70db1f08": {
617
+ "model_module": "@jupyter-widgets/base",
618
+ "model_module_version": "1.2.0",
619
+ "model_name": "LayoutModel",
620
+ "state": {
621
+ "_model_module": "@jupyter-widgets/base",
622
+ "_model_module_version": "1.2.0",
623
+ "_model_name": "LayoutModel",
624
+ "_view_count": null,
625
+ "_view_module": "@jupyter-widgets/base",
626
+ "_view_module_version": "1.2.0",
627
+ "_view_name": "LayoutView",
628
+ "align_content": null,
629
+ "align_items": null,
630
+ "align_self": null,
631
+ "border": null,
632
+ "bottom": null,
633
+ "display": null,
634
+ "flex": null,
635
+ "flex_flow": null,
636
+ "grid_area": null,
637
+ "grid_auto_columns": null,
638
+ "grid_auto_flow": null,
639
+ "grid_auto_rows": null,
640
+ "grid_column": null,
641
+ "grid_gap": null,
642
+ "grid_row": null,
643
+ "grid_template_areas": null,
644
+ "grid_template_columns": null,
645
+ "grid_template_rows": null,
646
+ "height": null,
647
+ "justify_content": null,
648
+ "justify_items": null,
649
+ "left": null,
650
+ "margin": null,
651
+ "max_height": null,
652
+ "max_width": null,
653
+ "min_height": null,
654
+ "min_width": null,
655
+ "object_fit": null,
656
+ "object_position": null,
657
+ "order": null,
658
+ "overflow": null,
659
+ "overflow_x": null,
660
+ "overflow_y": null,
661
+ "padding": null,
662
+ "right": null,
663
+ "top": null,
664
+ "visibility": null,
665
+ "width": null
666
+ }
667
+ },
668
+ "4aed1fa58b7342eba35c2106ec934019": {
669
+ "model_module": "@jupyter-widgets/base",
670
+ "model_module_version": "1.2.0",
671
+ "model_name": "LayoutModel",
672
+ "state": {
673
+ "_model_module": "@jupyter-widgets/base",
674
+ "_model_module_version": "1.2.0",
675
+ "_model_name": "LayoutModel",
676
+ "_view_count": null,
677
+ "_view_module": "@jupyter-widgets/base",
678
+ "_view_module_version": "1.2.0",
679
+ "_view_name": "LayoutView",
680
+ "align_content": null,
681
+ "align_items": null,
682
+ "align_self": null,
683
+ "border": null,
684
+ "bottom": null,
685
+ "display": null,
686
+ "flex": null,
687
+ "flex_flow": null,
688
+ "grid_area": null,
689
+ "grid_auto_columns": null,
690
+ "grid_auto_flow": null,
691
+ "grid_auto_rows": null,
692
+ "grid_column": null,
693
+ "grid_gap": null,
694
+ "grid_row": null,
695
+ "grid_template_areas": null,
696
+ "grid_template_columns": null,
697
+ "grid_template_rows": null,
698
+ "height": null,
699
+ "justify_content": null,
700
+ "justify_items": null,
701
+ "left": null,
702
+ "margin": null,
703
+ "max_height": null,
704
+ "max_width": null,
705
+ "min_height": null,
706
+ "min_width": null,
707
+ "object_fit": null,
708
+ "object_position": null,
709
+ "order": null,
710
+ "overflow": null,
711
+ "overflow_x": null,
712
+ "overflow_y": null,
713
+ "padding": null,
714
+ "right": null,
715
+ "top": null,
716
+ "visibility": null,
717
+ "width": null
718
+ }
719
+ },
720
+ "4d0e364e9f274e8ea7447e4e01c7f28f": {
721
+ "model_module": "@jupyter-widgets/controls",
722
+ "model_module_version": "1.5.0",
723
+ "model_name": "DescriptionStyleModel",
724
+ "state": {
725
+ "_model_module": "@jupyter-widgets/controls",
726
+ "_model_module_version": "1.5.0",
727
+ "_model_name": "DescriptionStyleModel",
728
+ "_view_count": null,
729
+ "_view_module": "@jupyter-widgets/base",
730
+ "_view_module_version": "1.2.0",
731
+ "_view_name": "StyleView",
732
+ "description_width": ""
733
+ }
734
+ },
735
+ "4eb62038f89d4a8cb2c46e6a7cc70150": {
736
+ "model_module": "@jupyter-widgets/controls",
737
+ "model_module_version": "1.5.0",
738
+ "model_name": "FloatProgressModel",
739
+ "state": {
740
+ "_dom_classes": [],
741
+ "_model_module": "@jupyter-widgets/controls",
742
+ "_model_module_version": "1.5.0",
743
+ "_model_name": "FloatProgressModel",
744
+ "_view_count": null,
745
+ "_view_module": "@jupyter-widgets/controls",
746
+ "_view_module_version": "1.5.0",
747
+ "_view_name": "ProgressView",
748
+ "bar_style": "success",
749
+ "description": "",
750
+ "description_tooltip": null,
751
+ "layout": "IPY_MODEL_78a32764678a42f0a5a892f5275d88de",
752
+ "max": 161154,
753
+ "min": 0,
754
+ "orientation": "horizontal",
755
+ "style": "IPY_MODEL_aa17c3a834694a978046808fc5d29da1",
756
+ "value": 161154
757
+ }
758
+ },
759
+ "56049bd375cd4512a0deaf69b7dae245": {
760
+ "model_module": "@jupyter-widgets/controls",
761
+ "model_module_version": "1.5.0",
762
+ "model_name": "DescriptionStyleModel",
763
+ "state": {
764
+ "_model_module": "@jupyter-widgets/controls",
765
+ "_model_module_version": "1.5.0",
766
+ "_model_name": "DescriptionStyleModel",
767
+ "_view_count": null,
768
+ "_view_module": "@jupyter-widgets/base",
769
+ "_view_module_version": "1.2.0",
770
+ "_view_name": "StyleView",
771
+ "description_width": ""
772
+ }
773
+ },
774
+ "5d1a9518abd44c18b122e575a7548ed2": {
775
+ "model_module": "@jupyter-widgets/controls",
776
+ "model_module_version": "1.5.0",
777
+ "model_name": "DescriptionStyleModel",
778
+ "state": {
779
+ "_model_module": "@jupyter-widgets/controls",
780
+ "_model_module_version": "1.5.0",
781
+ "_model_name": "DescriptionStyleModel",
782
+ "_view_count": null,
783
+ "_view_module": "@jupyter-widgets/base",
784
+ "_view_module_version": "1.2.0",
785
+ "_view_name": "StyleView",
786
+ "description_width": ""
787
+ }
788
+ },
789
+ "60c72c47a8d84f0eab652822bed1ed09": {
790
+ "model_module": "@jupyter-widgets/controls",
791
+ "model_module_version": "1.5.0",
792
+ "model_name": "DescriptionStyleModel",
793
+ "state": {
794
+ "_model_module": "@jupyter-widgets/controls",
795
+ "_model_module_version": "1.5.0",
796
+ "_model_name": "DescriptionStyleModel",
797
+ "_view_count": null,
798
+ "_view_module": "@jupyter-widgets/base",
799
+ "_view_module_version": "1.2.0",
800
+ "_view_name": "StyleView",
801
+ "description_width": ""
802
+ }
803
+ },
804
+ "67822d28f8584e69abcb041b88377a9f": {
805
+ "model_module": "@jupyter-widgets/controls",
806
+ "model_module_version": "1.5.0",
807
+ "model_name": "DescriptionStyleModel",
808
+ "state": {
809
+ "_model_module": "@jupyter-widgets/controls",
810
+ "_model_module_version": "1.5.0",
811
+ "_model_name": "DescriptionStyleModel",
812
+ "_view_count": null,
813
+ "_view_module": "@jupyter-widgets/base",
814
+ "_view_module_version": "1.2.0",
815
+ "_view_name": "StyleView",
816
+ "description_width": ""
817
+ }
818
+ },
819
+ "6e6f7b7cccaa4f0cbfc9311db257bea1": {
820
+ "model_module": "@jupyter-widgets/controls",
821
+ "model_module_version": "1.5.0",
822
+ "model_name": "HTMLModel",
823
+ "state": {
824
+ "_dom_classes": [],
825
+ "_model_module": "@jupyter-widgets/controls",
826
+ "_model_module_version": "1.5.0",
827
+ "_model_name": "HTMLModel",
828
+ "_view_count": null,
829
+ "_view_module": "@jupyter-widgets/controls",
830
+ "_view_module_version": "1.5.0",
831
+ "_view_name": "HTMLView",
832
+ "description": "",
833
+ "description_tooltip": null,
834
+ "layout": "IPY_MODEL_6fc5bec49f17469db39e0d4b535b94e9",
835
+ "placeholder": "​",
836
+ "style": "IPY_MODEL_67822d28f8584e69abcb041b88377a9f",
837
+ "value": " 20.5k/20.5k [00:00&lt;00:00, 1.34MB/s]"
838
+ }
839
+ },
840
+ "6fc5bec49f17469db39e0d4b535b94e9": {
841
+ "model_module": "@jupyter-widgets/base",
842
+ "model_module_version": "1.2.0",
843
+ "model_name": "LayoutModel",
844
+ "state": {
845
+ "_model_module": "@jupyter-widgets/base",
846
+ "_model_module_version": "1.2.0",
847
+ "_model_name": "LayoutModel",
848
+ "_view_count": null,
849
+ "_view_module": "@jupyter-widgets/base",
850
+ "_view_module_version": "1.2.0",
851
+ "_view_name": "LayoutView",
852
+ "align_content": null,
853
+ "align_items": null,
854
+ "align_self": null,
855
+ "border": null,
856
+ "bottom": null,
857
+ "display": null,
858
+ "flex": null,
859
+ "flex_flow": null,
860
+ "grid_area": null,
861
+ "grid_auto_columns": null,
862
+ "grid_auto_flow": null,
863
+ "grid_auto_rows": null,
864
+ "grid_column": null,
865
+ "grid_gap": null,
866
+ "grid_row": null,
867
+ "grid_template_areas": null,
868
+ "grid_template_columns": null,
869
+ "grid_template_rows": null,
870
+ "height": null,
871
+ "justify_content": null,
872
+ "justify_items": null,
873
+ "left": null,
874
+ "margin": null,
875
+ "max_height": null,
876
+ "max_width": null,
877
+ "min_height": null,
878
+ "min_width": null,
879
+ "object_fit": null,
880
+ "object_position": null,
881
+ "order": null,
882
+ "overflow": null,
883
+ "overflow_x": null,
884
+ "overflow_y": null,
885
+ "padding": null,
886
+ "right": null,
887
+ "top": null,
888
+ "visibility": null,
889
+ "width": null
890
+ }
891
+ },
892
+ "76e80fb236f5491597c992d1a809be33": {
893
+ "model_module": "@jupyter-widgets/controls",
894
+ "model_module_version": "1.5.0",
895
+ "model_name": "HBoxModel",
896
+ "state": {
897
+ "_dom_classes": [],
898
+ "_model_module": "@jupyter-widgets/controls",
899
+ "_model_module_version": "1.5.0",
900
+ "_model_name": "HBoxModel",
901
+ "_view_count": null,
902
+ "_view_module": "@jupyter-widgets/controls",
903
+ "_view_module_version": "1.5.0",
904
+ "_view_name": "HBoxView",
905
+ "box_style": "",
906
+ "children": [
907
+ "IPY_MODEL_f7359467b0214c5385de8ee4334f7ba3",
908
+ "IPY_MODEL_a58ac736aa884eb9a27264cb04bb36ce",
909
+ "IPY_MODEL_6e6f7b7cccaa4f0cbfc9311db257bea1"
910
+ ],
911
+ "layout": "IPY_MODEL_0656eee26364487f81580c3864e7a159"
912
+ }
913
+ },
914
+ "78a32764678a42f0a5a892f5275d88de": {
915
+ "model_module": "@jupyter-widgets/base",
916
+ "model_module_version": "1.2.0",
917
+ "model_name": "LayoutModel",
918
+ "state": {
919
+ "_model_module": "@jupyter-widgets/base",
920
+ "_model_module_version": "1.2.0",
921
+ "_model_name": "LayoutModel",
922
+ "_view_count": null,
923
+ "_view_module": "@jupyter-widgets/base",
924
+ "_view_module_version": "1.2.0",
925
+ "_view_name": "LayoutView",
926
+ "align_content": null,
927
+ "align_items": null,
928
+ "align_self": null,
929
+ "border": null,
930
+ "bottom": null,
931
+ "display": null,
932
+ "flex": null,
933
+ "flex_flow": null,
934
+ "grid_area": null,
935
+ "grid_auto_columns": null,
936
+ "grid_auto_flow": null,
937
+ "grid_auto_rows": null,
938
+ "grid_column": null,
939
+ "grid_gap": null,
940
+ "grid_row": null,
941
+ "grid_template_areas": null,
942
+ "grid_template_columns": null,
943
+ "grid_template_rows": null,
944
+ "height": null,
945
+ "justify_content": null,
946
+ "justify_items": null,
947
+ "left": null,
948
+ "margin": null,
949
+ "max_height": null,
950
+ "max_width": null,
951
+ "min_height": null,
952
+ "min_width": null,
953
+ "object_fit": null,
954
+ "object_position": null,
955
+ "order": null,
956
+ "overflow": null,
957
+ "overflow_x": null,
958
+ "overflow_y": null,
959
+ "padding": null,
960
+ "right": null,
961
+ "top": null,
962
+ "visibility": null,
963
+ "width": null
964
+ }
965
+ },
966
+ "7f168d0ea11c4ea1a96202d3a36ec389": {
967
+ "model_module": "@jupyter-widgets/controls",
968
+ "model_module_version": "1.5.0",
969
+ "model_name": "FloatProgressModel",
970
+ "state": {
971
+ "_dom_classes": [],
972
+ "_model_module": "@jupyter-widgets/controls",
973
+ "_model_module_version": "1.5.0",
974
+ "_model_name": "FloatProgressModel",
975
+ "_view_count": null,
976
+ "_view_module": "@jupyter-widgets/controls",
977
+ "_view_module_version": "1.5.0",
978
+ "_view_name": "ProgressView",
979
+ "bar_style": "success",
980
+ "description": "",
981
+ "description_tooltip": null,
982
+ "layout": "IPY_MODEL_c88027eb3e1c4771ab57366070ecd553",
983
+ "max": 3295251,
984
+ "min": 0,
985
+ "orientation": "horizontal",
986
+ "style": "IPY_MODEL_df75b255bfb04057b553830b59f0a153",
987
+ "value": 3295251
988
+ }
989
+ },
990
+ "83bc41f428b7492e9defdaa177f33a3e": {
991
+ "model_module": "@jupyter-widgets/controls",
992
+ "model_module_version": "1.5.0",
993
+ "model_name": "HTMLModel",
994
+ "state": {
995
+ "_dom_classes": [],
996
+ "_model_module": "@jupyter-widgets/controls",
997
+ "_model_module_version": "1.5.0",
998
+ "_model_name": "HTMLModel",
999
+ "_view_count": null,
1000
+ "_view_module": "@jupyter-widgets/controls",
1001
+ "_view_module_version": "1.5.0",
1002
+ "_view_name": "HTMLView",
1003
+ "description": "",
1004
+ "description_tooltip": null,
1005
+ "layout": "IPY_MODEL_a1188f80f78c49c7a822d71694e47074",
1006
+ "placeholder": "​",
1007
+ "style": "IPY_MODEL_068552491889440e8a66e61b9f013786",
1008
+ "value": "Downloading data: 100%"
1009
+ }
1010
+ },
1011
+ "888a323362ae4daeac99915bcb3dcf10": {
1012
+ "model_module": "@jupyter-widgets/base",
1013
+ "model_module_version": "1.2.0",
1014
+ "model_name": "LayoutModel",
1015
+ "state": {
1016
+ "_model_module": "@jupyter-widgets/base",
1017
+ "_model_module_version": "1.2.0",
1018
+ "_model_name": "LayoutModel",
1019
+ "_view_count": null,
1020
+ "_view_module": "@jupyter-widgets/base",
1021
+ "_view_module_version": "1.2.0",
1022
+ "_view_name": "LayoutView",
1023
+ "align_content": null,
1024
+ "align_items": null,
1025
+ "align_self": null,
1026
+ "border": null,
1027
+ "bottom": null,
1028
+ "display": null,
1029
+ "flex": null,
1030
+ "flex_flow": null,
1031
+ "grid_area": null,
1032
+ "grid_auto_columns": null,
1033
+ "grid_auto_flow": null,
1034
+ "grid_auto_rows": null,
1035
+ "grid_column": null,
1036
+ "grid_gap": null,
1037
+ "grid_row": null,
1038
+ "grid_template_areas": null,
1039
+ "grid_template_columns": null,
1040
+ "grid_template_rows": null,
1041
+ "height": null,
1042
+ "justify_content": null,
1043
+ "justify_items": null,
1044
+ "left": null,
1045
+ "margin": null,
1046
+ "max_height": null,
1047
+ "max_width": null,
1048
+ "min_height": null,
1049
+ "min_width": null,
1050
+ "object_fit": null,
1051
+ "object_position": null,
1052
+ "order": null,
1053
+ "overflow": null,
1054
+ "overflow_x": null,
1055
+ "overflow_y": null,
1056
+ "padding": null,
1057
+ "right": null,
1058
+ "top": null,
1059
+ "visibility": null,
1060
+ "width": null
1061
+ }
1062
+ },
1063
+ "8cfa6df0ee654643bfdb4a3825e8fbbe": {
1064
+ "model_module": "@jupyter-widgets/controls",
1065
+ "model_module_version": "1.5.0",
1066
+ "model_name": "DescriptionStyleModel",
1067
+ "state": {
1068
+ "_model_module": "@jupyter-widgets/controls",
1069
+ "_model_module_version": "1.5.0",
1070
+ "_model_name": "DescriptionStyleModel",
1071
+ "_view_count": null,
1072
+ "_view_module": "@jupyter-widgets/base",
1073
+ "_view_module_version": "1.2.0",
1074
+ "_view_name": "StyleView",
1075
+ "description_width": ""
1076
+ }
1077
+ },
1078
+ "9055fd09043642e0ae3d8a7a7c0ab31b": {
1079
+ "model_module": "@jupyter-widgets/controls",
1080
+ "model_module_version": "1.5.0",
1081
+ "model_name": "HTMLModel",
1082
+ "state": {
1083
+ "_dom_classes": [],
1084
+ "_model_module": "@jupyter-widgets/controls",
1085
+ "_model_module_version": "1.5.0",
1086
+ "_model_name": "HTMLModel",
1087
+ "_view_count": null,
1088
+ "_view_module": "@jupyter-widgets/controls",
1089
+ "_view_module_version": "1.5.0",
1090
+ "_view_name": "HTMLView",
1091
+ "description": "",
1092
+ "description_tooltip": null,
1093
+ "layout": "IPY_MODEL_11e011e4acb24519bd41a054ddecbfb1",
1094
+ "placeholder": "​",
1095
+ "style": "IPY_MODEL_5d1a9518abd44c18b122e575a7548ed2",
1096
+ "value": " 161k/161k [00:00&lt;00:00, 865kB/s]"
1097
+ }
1098
+ },
1099
+ "937ee45f4d634d189c6d95c886e97bca": {
1100
+ "model_module": "@jupyter-widgets/controls",
1101
+ "model_module_version": "1.5.0",
1102
+ "model_name": "DescriptionStyleModel",
1103
+ "state": {
1104
+ "_model_module": "@jupyter-widgets/controls",
1105
+ "_model_module_version": "1.5.0",
1106
+ "_model_name": "DescriptionStyleModel",
1107
+ "_view_count": null,
1108
+ "_view_module": "@jupyter-widgets/base",
1109
+ "_view_module_version": "1.2.0",
1110
+ "_view_name": "StyleView",
1111
+ "description_width": ""
1112
+ }
1113
+ },
1114
+ "96baa91869eb478eb492754b98169470": {
1115
+ "model_module": "@jupyter-widgets/base",
1116
+ "model_module_version": "1.2.0",
1117
+ "model_name": "LayoutModel",
1118
+ "state": {
1119
+ "_model_module": "@jupyter-widgets/base",
1120
+ "_model_module_version": "1.2.0",
1121
+ "_model_name": "LayoutModel",
1122
+ "_view_count": null,
1123
+ "_view_module": "@jupyter-widgets/base",
1124
+ "_view_module_version": "1.2.0",
1125
+ "_view_name": "LayoutView",
1126
+ "align_content": null,
1127
+ "align_items": null,
1128
+ "align_self": null,
1129
+ "border": null,
1130
+ "bottom": null,
1131
+ "display": null,
1132
+ "flex": null,
1133
+ "flex_flow": null,
1134
+ "grid_area": null,
1135
+ "grid_auto_columns": null,
1136
+ "grid_auto_flow": null,
1137
+ "grid_auto_rows": null,
1138
+ "grid_column": null,
1139
+ "grid_gap": null,
1140
+ "grid_row": null,
1141
+ "grid_template_areas": null,
1142
+ "grid_template_columns": null,
1143
+ "grid_template_rows": null,
1144
+ "height": null,
1145
+ "justify_content": null,
1146
+ "justify_items": null,
1147
+ "left": null,
1148
+ "margin": null,
1149
+ "max_height": null,
1150
+ "max_width": null,
1151
+ "min_height": null,
1152
+ "min_width": null,
1153
+ "object_fit": null,
1154
+ "object_position": null,
1155
+ "order": null,
1156
+ "overflow": null,
1157
+ "overflow_x": null,
1158
+ "overflow_y": null,
1159
+ "padding": null,
1160
+ "right": null,
1161
+ "top": null,
1162
+ "visibility": null,
1163
+ "width": null
1164
+ }
1165
+ },
1166
+ "996ac47b200c427088ee7644fe886896": {
1167
+ "model_module": "@jupyter-widgets/controls",
1168
+ "model_module_version": "1.5.0",
1169
+ "model_name": "HTMLModel",
1170
+ "state": {
1171
+ "_dom_classes": [],
1172
+ "_model_module": "@jupyter-widgets/controls",
1173
+ "_model_module_version": "1.5.0",
1174
+ "_model_name": "HTMLModel",
1175
+ "_view_count": null,
1176
+ "_view_module": "@jupyter-widgets/controls",
1177
+ "_view_module_version": "1.5.0",
1178
+ "_view_name": "HTMLView",
1179
+ "description": "",
1180
+ "description_tooltip": null,
1181
+ "layout": "IPY_MODEL_cb7d88a70af746f2ae31416b4b670c63",
1182
+ "placeholder": "​",
1183
+ "style": "IPY_MODEL_4837276e5cf248449e287b1eeaef30ec",
1184
+ "value": " 6.08k/6.08k [00:00&lt;00:00, 279kB/s]"
1185
+ }
1186
+ },
1187
+ "9b9addf13301466b9ef30b9d4b836a67": {
1188
+ "model_module": "@jupyter-widgets/base",
1189
+ "model_module_version": "1.2.0",
1190
+ "model_name": "LayoutModel",
1191
+ "state": {
1192
+ "_model_module": "@jupyter-widgets/base",
1193
+ "_model_module_version": "1.2.0",
1194
+ "_model_name": "LayoutModel",
1195
+ "_view_count": null,
1196
+ "_view_module": "@jupyter-widgets/base",
1197
+ "_view_module_version": "1.2.0",
1198
+ "_view_name": "LayoutView",
1199
+ "align_content": null,
1200
+ "align_items": null,
1201
+ "align_self": null,
1202
+ "border": null,
1203
+ "bottom": null,
1204
+ "display": null,
1205
+ "flex": null,
1206
+ "flex_flow": null,
1207
+ "grid_area": null,
1208
+ "grid_auto_columns": null,
1209
+ "grid_auto_flow": null,
1210
+ "grid_auto_rows": null,
1211
+ "grid_column": null,
1212
+ "grid_gap": null,
1213
+ "grid_row": null,
1214
+ "grid_template_areas": null,
1215
+ "grid_template_columns": null,
1216
+ "grid_template_rows": null,
1217
+ "height": null,
1218
+ "justify_content": null,
1219
+ "justify_items": null,
1220
+ "left": null,
1221
+ "margin": null,
1222
+ "max_height": null,
1223
+ "max_width": null,
1224
+ "min_height": null,
1225
+ "min_width": null,
1226
+ "object_fit": null,
1227
+ "object_position": null,
1228
+ "order": null,
1229
+ "overflow": null,
1230
+ "overflow_x": null,
1231
+ "overflow_y": null,
1232
+ "padding": null,
1233
+ "right": null,
1234
+ "top": null,
1235
+ "visibility": null,
1236
+ "width": null
1237
+ }
1238
+ },
1239
+ "9cf2d2e2bfe24f2ab185165d79da8bdb": {
1240
+ "model_module": "@jupyter-widgets/controls",
1241
+ "model_module_version": "1.5.0",
1242
+ "model_name": "FloatProgressModel",
1243
+ "state": {
1244
+ "_dom_classes": [],
1245
+ "_model_module": "@jupyter-widgets/controls",
1246
+ "_model_module_version": "1.5.0",
1247
+ "_model_name": "FloatProgressModel",
1248
+ "_view_count": null,
1249
+ "_view_module": "@jupyter-widgets/controls",
1250
+ "_view_module_version": "1.5.0",
1251
+ "_view_name": "ProgressView",
1252
+ "bar_style": "success",
1253
+ "description": "",
1254
+ "description_tooltip": null,
1255
+ "layout": "IPY_MODEL_140f33387db341398bc39e9c47703df4",
1256
+ "max": 6081,
1257
+ "min": 0,
1258
+ "orientation": "horizontal",
1259
+ "style": "IPY_MODEL_b3a8424c0b584a37ad2ede748085425c",
1260
+ "value": 6081
1261
+ }
1262
+ },
1263
+ "a1188f80f78c49c7a822d71694e47074": {
1264
+ "model_module": "@jupyter-widgets/base",
1265
+ "model_module_version": "1.2.0",
1266
+ "model_name": "LayoutModel",
1267
+ "state": {
1268
+ "_model_module": "@jupyter-widgets/base",
1269
+ "_model_module_version": "1.2.0",
1270
+ "_model_name": "LayoutModel",
1271
+ "_view_count": null,
1272
+ "_view_module": "@jupyter-widgets/base",
1273
+ "_view_module_version": "1.2.0",
1274
+ "_view_name": "LayoutView",
1275
+ "align_content": null,
1276
+ "align_items": null,
1277
+ "align_self": null,
1278
+ "border": null,
1279
+ "bottom": null,
1280
+ "display": null,
1281
+ "flex": null,
1282
+ "flex_flow": null,
1283
+ "grid_area": null,
1284
+ "grid_auto_columns": null,
1285
+ "grid_auto_flow": null,
1286
+ "grid_auto_rows": null,
1287
+ "grid_column": null,
1288
+ "grid_gap": null,
1289
+ "grid_row": null,
1290
+ "grid_template_areas": null,
1291
+ "grid_template_columns": null,
1292
+ "grid_template_rows": null,
1293
+ "height": null,
1294
+ "justify_content": null,
1295
+ "justify_items": null,
1296
+ "left": null,
1297
+ "margin": null,
1298
+ "max_height": null,
1299
+ "max_width": null,
1300
+ "min_height": null,
1301
+ "min_width": null,
1302
+ "object_fit": null,
1303
+ "object_position": null,
1304
+ "order": null,
1305
+ "overflow": null,
1306
+ "overflow_x": null,
1307
+ "overflow_y": null,
1308
+ "padding": null,
1309
+ "right": null,
1310
+ "top": null,
1311
+ "visibility": null,
1312
+ "width": null
1313
+ }
1314
+ },
1315
+ "a58ac736aa884eb9a27264cb04bb36ce": {
1316
+ "model_module": "@jupyter-widgets/controls",
1317
+ "model_module_version": "1.5.0",
1318
+ "model_name": "FloatProgressModel",
1319
+ "state": {
1320
+ "_dom_classes": [],
1321
+ "_model_module": "@jupyter-widgets/controls",
1322
+ "_model_module_version": "1.5.0",
1323
+ "_model_name": "FloatProgressModel",
1324
+ "_view_count": null,
1325
+ "_view_module": "@jupyter-widgets/controls",
1326
+ "_view_module_version": "1.5.0",
1327
+ "_view_name": "ProgressView",
1328
+ "bar_style": "success",
1329
+ "description": "",
1330
+ "description_tooltip": null,
1331
+ "layout": "IPY_MODEL_96baa91869eb478eb492754b98169470",
1332
+ "max": 20464,
1333
+ "min": 0,
1334
+ "orientation": "horizontal",
1335
+ "style": "IPY_MODEL_bbda5260ca1c450386f9191e9f9dde97",
1336
+ "value": 20464
1337
+ }
1338
+ },
1339
+ "aa082ade829247dc8ea0d75cc8a5b2a7": {
1340
+ "model_module": "@jupyter-widgets/controls",
1341
+ "model_module_version": "1.5.0",
1342
+ "model_name": "HBoxModel",
1343
+ "state": {
1344
+ "_dom_classes": [],
1345
+ "_model_module": "@jupyter-widgets/controls",
1346
+ "_model_module_version": "1.5.0",
1347
+ "_model_name": "HBoxModel",
1348
+ "_view_count": null,
1349
+ "_view_module": "@jupyter-widgets/controls",
1350
+ "_view_module_version": "1.5.0",
1351
+ "_view_name": "HBoxView",
1352
+ "box_style": "",
1353
+ "children": [
1354
+ "IPY_MODEL_83bc41f428b7492e9defdaa177f33a3e",
1355
+ "IPY_MODEL_7f168d0ea11c4ea1a96202d3a36ec389",
1356
+ "IPY_MODEL_ebb7ee3fd084466f9667771a99e6e3b2"
1357
+ ],
1358
+ "layout": "IPY_MODEL_1e3c2a94251b4e75af0413a88b53bfe1"
1359
+ }
1360
+ },
1361
+ "aa17c3a834694a978046808fc5d29da1": {
1362
+ "model_module": "@jupyter-widgets/controls",
1363
+ "model_module_version": "1.5.0",
1364
+ "model_name": "ProgressStyleModel",
1365
+ "state": {
1366
+ "_model_module": "@jupyter-widgets/controls",
1367
+ "_model_module_version": "1.5.0",
1368
+ "_model_name": "ProgressStyleModel",
1369
+ "_view_count": null,
1370
+ "_view_module": "@jupyter-widgets/base",
1371
+ "_view_module_version": "1.2.0",
1372
+ "_view_name": "StyleView",
1373
+ "bar_color": null,
1374
+ "description_width": ""
1375
+ }
1376
+ },
1377
+ "b317ba38f2b145f9b0b49f523547684f": {
1378
+ "model_module": "@jupyter-widgets/controls",
1379
+ "model_module_version": "1.5.0",
1380
+ "model_name": "HTMLModel",
1381
+ "state": {
1382
+ "_dom_classes": [],
1383
+ "_model_module": "@jupyter-widgets/controls",
1384
+ "_model_module_version": "1.5.0",
1385
+ "_model_name": "HTMLModel",
1386
+ "_view_count": null,
1387
+ "_view_module": "@jupyter-widgets/controls",
1388
+ "_view_module_version": "1.5.0",
1389
+ "_view_name": "HTMLView",
1390
+ "description": "",
1391
+ "description_tooltip": null,
1392
+ "layout": "IPY_MODEL_4aed1fa58b7342eba35c2106ec934019",
1393
+ "placeholder": "​",
1394
+ "style": "IPY_MODEL_60c72c47a8d84f0eab652822bed1ed09",
1395
+ "value": " 32332/32332 [00:01&lt;00:00, 27628.23 examples/s]"
1396
+ }
1397
+ },
1398
+ "b3a8424c0b584a37ad2ede748085425c": {
1399
+ "model_module": "@jupyter-widgets/controls",
1400
+ "model_module_version": "1.5.0",
1401
+ "model_name": "ProgressStyleModel",
1402
+ "state": {
1403
+ "_model_module": "@jupyter-widgets/controls",
1404
+ "_model_module_version": "1.5.0",
1405
+ "_model_name": "ProgressStyleModel",
1406
+ "_view_count": null,
1407
+ "_view_module": "@jupyter-widgets/base",
1408
+ "_view_module_version": "1.2.0",
1409
+ "_view_name": "StyleView",
1410
+ "bar_color": null,
1411
+ "description_width": ""
1412
+ }
1413
+ },
1414
+ "bbda5260ca1c450386f9191e9f9dde97": {
1415
+ "model_module": "@jupyter-widgets/controls",
1416
+ "model_module_version": "1.5.0",
1417
+ "model_name": "ProgressStyleModel",
1418
+ "state": {
1419
+ "_model_module": "@jupyter-widgets/controls",
1420
+ "_model_module_version": "1.5.0",
1421
+ "_model_name": "ProgressStyleModel",
1422
+ "_view_count": null,
1423
+ "_view_module": "@jupyter-widgets/base",
1424
+ "_view_module_version": "1.2.0",
1425
+ "_view_name": "StyleView",
1426
+ "bar_color": null,
1427
+ "description_width": ""
1428
+ }
1429
+ },
1430
+ "c020b38c6d2c436e8b742fd87d3b8b89": {
1431
+ "model_module": "@jupyter-widgets/base",
1432
+ "model_module_version": "1.2.0",
1433
+ "model_name": "LayoutModel",
1434
+ "state": {
1435
+ "_model_module": "@jupyter-widgets/base",
1436
+ "_model_module_version": "1.2.0",
1437
+ "_model_name": "LayoutModel",
1438
+ "_view_count": null,
1439
+ "_view_module": "@jupyter-widgets/base",
1440
+ "_view_module_version": "1.2.0",
1441
+ "_view_name": "LayoutView",
1442
+ "align_content": null,
1443
+ "align_items": null,
1444
+ "align_self": null,
1445
+ "border": null,
1446
+ "bottom": null,
1447
+ "display": null,
1448
+ "flex": null,
1449
+ "flex_flow": null,
1450
+ "grid_area": null,
1451
+ "grid_auto_columns": null,
1452
+ "grid_auto_flow": null,
1453
+ "grid_auto_rows": null,
1454
+ "grid_column": null,
1455
+ "grid_gap": null,
1456
+ "grid_row": null,
1457
+ "grid_template_areas": null,
1458
+ "grid_template_columns": null,
1459
+ "grid_template_rows": null,
1460
+ "height": null,
1461
+ "justify_content": null,
1462
+ "justify_items": null,
1463
+ "left": null,
1464
+ "margin": null,
1465
+ "max_height": null,
1466
+ "max_width": null,
1467
+ "min_height": null,
1468
+ "min_width": null,
1469
+ "object_fit": null,
1470
+ "object_position": null,
1471
+ "order": null,
1472
+ "overflow": null,
1473
+ "overflow_x": null,
1474
+ "overflow_y": null,
1475
+ "padding": null,
1476
+ "right": null,
1477
+ "top": null,
1478
+ "visibility": null,
1479
+ "width": null
1480
+ }
1481
+ },
1482
+ "c2d14fa4280c48e0ae04859b73c80781": {
1483
+ "model_module": "@jupyter-widgets/controls",
1484
+ "model_module_version": "1.5.0",
1485
+ "model_name": "HBoxModel",
1486
+ "state": {
1487
+ "_dom_classes": [],
1488
+ "_model_module": "@jupyter-widgets/controls",
1489
+ "_model_module_version": "1.5.0",
1490
+ "_model_name": "HBoxModel",
1491
+ "_view_count": null,
1492
+ "_view_module": "@jupyter-widgets/controls",
1493
+ "_view_module_version": "1.5.0",
1494
+ "_view_name": "HBoxView",
1495
+ "box_style": "",
1496
+ "children": [
1497
+ "IPY_MODEL_d3104837d9734834b7c87e87289b08df",
1498
+ "IPY_MODEL_02b02005adf241a4a0be8173ca3a4aee",
1499
+ "IPY_MODEL_b317ba38f2b145f9b0b49f523547684f"
1500
+ ],
1501
+ "layout": "IPY_MODEL_434340d109d1401d8868498a23b291cf"
1502
+ }
1503
+ },
1504
+ "c88027eb3e1c4771ab57366070ecd553": {
1505
+ "model_module": "@jupyter-widgets/base",
1506
+ "model_module_version": "1.2.0",
1507
+ "model_name": "LayoutModel",
1508
+ "state": {
1509
+ "_model_module": "@jupyter-widgets/base",
1510
+ "_model_module_version": "1.2.0",
1511
+ "_model_name": "LayoutModel",
1512
+ "_view_count": null,
1513
+ "_view_module": "@jupyter-widgets/base",
1514
+ "_view_module_version": "1.2.0",
1515
+ "_view_name": "LayoutView",
1516
+ "align_content": null,
1517
+ "align_items": null,
1518
+ "align_self": null,
1519
+ "border": null,
1520
+ "bottom": null,
1521
+ "display": null,
1522
+ "flex": null,
1523
+ "flex_flow": null,
1524
+ "grid_area": null,
1525
+ "grid_auto_columns": null,
1526
+ "grid_auto_flow": null,
1527
+ "grid_auto_rows": null,
1528
+ "grid_column": null,
1529
+ "grid_gap": null,
1530
+ "grid_row": null,
1531
+ "grid_template_areas": null,
1532
+ "grid_template_columns": null,
1533
+ "grid_template_rows": null,
1534
+ "height": null,
1535
+ "justify_content": null,
1536
+ "justify_items": null,
1537
+ "left": null,
1538
+ "margin": null,
1539
+ "max_height": null,
1540
+ "max_width": null,
1541
+ "min_height": null,
1542
+ "min_width": null,
1543
+ "object_fit": null,
1544
+ "object_position": null,
1545
+ "order": null,
1546
+ "overflow": null,
1547
+ "overflow_x": null,
1548
+ "overflow_y": null,
1549
+ "padding": null,
1550
+ "right": null,
1551
+ "top": null,
1552
+ "visibility": null,
1553
+ "width": null
1554
+ }
1555
+ },
1556
+ "ca588157678e4cc09c3fd760676efd39": {
1557
+ "model_module": "@jupyter-widgets/controls",
1558
+ "model_module_version": "1.5.0",
1559
+ "model_name": "DescriptionStyleModel",
1560
+ "state": {
1561
+ "_model_module": "@jupyter-widgets/controls",
1562
+ "_model_module_version": "1.5.0",
1563
+ "_model_name": "DescriptionStyleModel",
1564
+ "_view_count": null,
1565
+ "_view_module": "@jupyter-widgets/base",
1566
+ "_view_module_version": "1.2.0",
1567
+ "_view_name": "StyleView",
1568
+ "description_width": ""
1569
+ }
1570
+ },
1571
+ "cb7d88a70af746f2ae31416b4b670c63": {
1572
+ "model_module": "@jupyter-widgets/base",
1573
+ "model_module_version": "1.2.0",
1574
+ "model_name": "LayoutModel",
1575
+ "state": {
1576
+ "_model_module": "@jupyter-widgets/base",
1577
+ "_model_module_version": "1.2.0",
1578
+ "_model_name": "LayoutModel",
1579
+ "_view_count": null,
1580
+ "_view_module": "@jupyter-widgets/base",
1581
+ "_view_module_version": "1.2.0",
1582
+ "_view_name": "LayoutView",
1583
+ "align_content": null,
1584
+ "align_items": null,
1585
+ "align_self": null,
1586
+ "border": null,
1587
+ "bottom": null,
1588
+ "display": null,
1589
+ "flex": null,
1590
+ "flex_flow": null,
1591
+ "grid_area": null,
1592
+ "grid_auto_columns": null,
1593
+ "grid_auto_flow": null,
1594
+ "grid_auto_rows": null,
1595
+ "grid_column": null,
1596
+ "grid_gap": null,
1597
+ "grid_row": null,
1598
+ "grid_template_areas": null,
1599
+ "grid_template_columns": null,
1600
+ "grid_template_rows": null,
1601
+ "height": null,
1602
+ "justify_content": null,
1603
+ "justify_items": null,
1604
+ "left": null,
1605
+ "margin": null,
1606
+ "max_height": null,
1607
+ "max_width": null,
1608
+ "min_height": null,
1609
+ "min_width": null,
1610
+ "object_fit": null,
1611
+ "object_position": null,
1612
+ "order": null,
1613
+ "overflow": null,
1614
+ "overflow_x": null,
1615
+ "overflow_y": null,
1616
+ "padding": null,
1617
+ "right": null,
1618
+ "top": null,
1619
+ "visibility": null,
1620
+ "width": null
1621
+ }
1622
+ },
1623
+ "d3104837d9734834b7c87e87289b08df": {
1624
+ "model_module": "@jupyter-widgets/controls",
1625
+ "model_module_version": "1.5.0",
1626
+ "model_name": "HTMLModel",
1627
+ "state": {
1628
+ "_dom_classes": [],
1629
+ "_model_module": "@jupyter-widgets/controls",
1630
+ "_model_module_version": "1.5.0",
1631
+ "_model_name": "HTMLModel",
1632
+ "_view_count": null,
1633
+ "_view_module": "@jupyter-widgets/controls",
1634
+ "_view_module_version": "1.5.0",
1635
+ "_view_name": "HTMLView",
1636
+ "description": "",
1637
+ "description_tooltip": null,
1638
+ "layout": "IPY_MODEL_2c95f5b81fc84ad698fe77b52cb84076",
1639
+ "placeholder": "​",
1640
+ "style": "IPY_MODEL_ca588157678e4cc09c3fd760676efd39",
1641
+ "value": "Generating train split: 100%"
1642
+ }
1643
+ },
1644
+ "df75b255bfb04057b553830b59f0a153": {
1645
+ "model_module": "@jupyter-widgets/controls",
1646
+ "model_module_version": "1.5.0",
1647
+ "model_name": "ProgressStyleModel",
1648
+ "state": {
1649
+ "_model_module": "@jupyter-widgets/controls",
1650
+ "_model_module_version": "1.5.0",
1651
+ "_model_name": "ProgressStyleModel",
1652
+ "_view_count": null,
1653
+ "_view_module": "@jupyter-widgets/base",
1654
+ "_view_module_version": "1.2.0",
1655
+ "_view_name": "StyleView",
1656
+ "bar_color": null,
1657
+ "description_width": ""
1658
+ }
1659
+ },
1660
+ "ebb7ee3fd084466f9667771a99e6e3b2": {
1661
+ "model_module": "@jupyter-widgets/controls",
1662
+ "model_module_version": "1.5.0",
1663
+ "model_name": "HTMLModel",
1664
+ "state": {
1665
+ "_dom_classes": [],
1666
+ "_model_module": "@jupyter-widgets/controls",
1667
+ "_model_module_version": "1.5.0",
1668
+ "_model_name": "HTMLModel",
1669
+ "_view_count": null,
1670
+ "_view_module": "@jupyter-widgets/controls",
1671
+ "_view_module_version": "1.5.0",
1672
+ "_view_name": "HTMLView",
1673
+ "description": "",
1674
+ "description_tooltip": null,
1675
+ "layout": "IPY_MODEL_f0e5024d0d054c1eb8e01c4c8b027e79",
1676
+ "placeholder": "​",
1677
+ "style": "IPY_MODEL_937ee45f4d634d189c6d95c886e97bca",
1678
+ "value": " 3.30M/3.30M [00:01&lt;00:00, 2.77MB/s]"
1679
+ }
1680
+ },
1681
+ "ec2051bf0e9343d394e8a0ecb4fd5ec8": {
1682
+ "model_module": "@jupyter-widgets/base",
1683
+ "model_module_version": "1.2.0",
1684
+ "model_name": "LayoutModel",
1685
+ "state": {
1686
+ "_model_module": "@jupyter-widgets/base",
1687
+ "_model_module_version": "1.2.0",
1688
+ "_model_name": "LayoutModel",
1689
+ "_view_count": null,
1690
+ "_view_module": "@jupyter-widgets/base",
1691
+ "_view_module_version": "1.2.0",
1692
+ "_view_name": "LayoutView",
1693
+ "align_content": null,
1694
+ "align_items": null,
1695
+ "align_self": null,
1696
+ "border": null,
1697
+ "bottom": null,
1698
+ "display": null,
1699
+ "flex": null,
1700
+ "flex_flow": null,
1701
+ "grid_area": null,
1702
+ "grid_auto_columns": null,
1703
+ "grid_auto_flow": null,
1704
+ "grid_auto_rows": null,
1705
+ "grid_column": null,
1706
+ "grid_gap": null,
1707
+ "grid_row": null,
1708
+ "grid_template_areas": null,
1709
+ "grid_template_columns": null,
1710
+ "grid_template_rows": null,
1711
+ "height": null,
1712
+ "justify_content": null,
1713
+ "justify_items": null,
1714
+ "left": null,
1715
+ "margin": null,
1716
+ "max_height": null,
1717
+ "max_width": null,
1718
+ "min_height": null,
1719
+ "min_width": null,
1720
+ "object_fit": null,
1721
+ "object_position": null,
1722
+ "order": null,
1723
+ "overflow": null,
1724
+ "overflow_x": null,
1725
+ "overflow_y": null,
1726
+ "padding": null,
1727
+ "right": null,
1728
+ "top": null,
1729
+ "visibility": null,
1730
+ "width": null
1731
+ }
1732
+ },
1733
+ "f0e5024d0d054c1eb8e01c4c8b027e79": {
1734
+ "model_module": "@jupyter-widgets/base",
1735
+ "model_module_version": "1.2.0",
1736
+ "model_name": "LayoutModel",
1737
+ "state": {
1738
+ "_model_module": "@jupyter-widgets/base",
1739
+ "_model_module_version": "1.2.0",
1740
+ "_model_name": "LayoutModel",
1741
+ "_view_count": null,
1742
+ "_view_module": "@jupyter-widgets/base",
1743
+ "_view_module_version": "1.2.0",
1744
+ "_view_name": "LayoutView",
1745
+ "align_content": null,
1746
+ "align_items": null,
1747
+ "align_self": null,
1748
+ "border": null,
1749
+ "bottom": null,
1750
+ "display": null,
1751
+ "flex": null,
1752
+ "flex_flow": null,
1753
+ "grid_area": null,
1754
+ "grid_auto_columns": null,
1755
+ "grid_auto_flow": null,
1756
+ "grid_auto_rows": null,
1757
+ "grid_column": null,
1758
+ "grid_gap": null,
1759
+ "grid_row": null,
1760
+ "grid_template_areas": null,
1761
+ "grid_template_columns": null,
1762
+ "grid_template_rows": null,
1763
+ "height": null,
1764
+ "justify_content": null,
1765
+ "justify_items": null,
1766
+ "left": null,
1767
+ "margin": null,
1768
+ "max_height": null,
1769
+ "max_width": null,
1770
+ "min_height": null,
1771
+ "min_width": null,
1772
+ "object_fit": null,
1773
+ "object_position": null,
1774
+ "order": null,
1775
+ "overflow": null,
1776
+ "overflow_x": null,
1777
+ "overflow_y": null,
1778
+ "padding": null,
1779
+ "right": null,
1780
+ "top": null,
1781
+ "visibility": null,
1782
+ "width": null
1783
+ }
1784
+ },
1785
+ "f7359467b0214c5385de8ee4334f7ba3": {
1786
+ "model_module": "@jupyter-widgets/controls",
1787
+ "model_module_version": "1.5.0",
1788
+ "model_name": "HTMLModel",
1789
+ "state": {
1790
+ "_dom_classes": [],
1791
+ "_model_module": "@jupyter-widgets/controls",
1792
+ "_model_module_version": "1.5.0",
1793
+ "_model_name": "HTMLModel",
1794
+ "_view_count": null,
1795
+ "_view_module": "@jupyter-widgets/controls",
1796
+ "_view_module_version": "1.5.0",
1797
+ "_view_name": "HTMLView",
1798
+ "description": "",
1799
+ "description_tooltip": null,
1800
+ "layout": "IPY_MODEL_05240e68c55a458286f43967e7f90889",
1801
+ "placeholder": "​",
1802
+ "style": "IPY_MODEL_8cfa6df0ee654643bfdb4a3825e8fbbe",
1803
+ "value": "Downloading readme: 100%"
1804
+ }
1805
+ },
1806
+ "f74bdeb79a224de8b1c85f4ca8657331": {
1807
+ "model_module": "@jupyter-widgets/controls",
1808
+ "model_module_version": "1.5.0",
1809
+ "model_name": "HTMLModel",
1810
+ "state": {
1811
+ "_dom_classes": [],
1812
+ "_model_module": "@jupyter-widgets/controls",
1813
+ "_model_module_version": "1.5.0",
1814
+ "_model_name": "HTMLModel",
1815
+ "_view_count": null,
1816
+ "_view_module": "@jupyter-widgets/controls",
1817
+ "_view_module_version": "1.5.0",
1818
+ "_view_name": "HTMLView",
1819
+ "description": "",
1820
+ "description_tooltip": null,
1821
+ "layout": "IPY_MODEL_888a323362ae4daeac99915bcb3dcf10",
1822
+ "placeholder": "​",
1823
+ "style": "IPY_MODEL_4d0e364e9f274e8ea7447e4e01c7f28f",
1824
+ "value": "Downloading metadata: 100%"
1825
+ }
1826
+ }
1827
+ }
1828
+ }
1829
+ },
1830
+ "nbformat": 4,
1831
+ "nbformat_minor": 0
1832
+ }
Others/attention_visual.ipynb ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import torch\n",
10
+ "import torch.nn as nn\n",
11
+ "from model import Transformer\n",
12
+ "from config import get_config, get_weights_file_path\n",
13
+ "from train import get_model, get_ds, greedy_decode\n",
14
+ "import altair as alt\n",
15
+ "import pandas as pd\n",
16
+ "import numpy as np\n",
17
+ "import warnings\n",
18
+ "warnings.filterwarnings(\"ignore\")"
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": null,
24
+ "metadata": {},
25
+ "outputs": [],
26
+ "source": [
27
+ "# Define the device\n",
28
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
29
+ "print(\"Using device:\", device)"
30
+ ]
31
+ },
32
+ {
33
+ "cell_type": "code",
34
+ "execution_count": null,
35
+ "metadata": {},
36
+ "outputs": [],
37
+ "source": [
38
+ "config = get_config()\n",
39
+ "train_dataloader, val_dataloader, vocab_src, vocab_tgt = get_ds(config)\n",
40
+ "model = get_model(config, vocab_src.get_vocab_size(), vocab_tgt.get_vocab_size()).to(device)\n",
41
+ "\n",
42
+ "# Load the pretrained weights\n",
43
+ "model_filename = get_weights_file_path(config, f\"29\")\n",
44
+ "state = torch.load(model_filename)\n",
45
+ "model.load_state_dict(state['model_state_dict'])"
46
+ ]
47
+ },
48
+ {
49
+ "cell_type": "code",
50
+ "execution_count": null,
51
+ "metadata": {},
52
+ "outputs": [],
53
+ "source": [
54
+ "def load_next_batch():\n",
55
+ " # Load a sample batch from the validation set\n",
56
+ " batch = next(iter(val_dataloader))\n",
57
+ " encoder_input = batch[\"encoder_input\"].to(device)\n",
58
+ " encoder_mask = batch[\"encoder_mask\"].to(device)\n",
59
+ " decoder_input = batch[\"decoder_input\"].to(device)\n",
60
+ " decoder_mask = batch[\"decoder_mask\"].to(device)\n",
61
+ "\n",
62
+ " encoder_input_tokens = [vocab_src.id_to_token(idx) for idx in encoder_input[0].cpu().numpy()]\n",
63
+ " decoder_input_tokens = [vocab_tgt.id_to_token(idx) for idx in decoder_input[0].cpu().numpy()]\n",
64
+ "\n",
65
+ " # check that the batch size is 1\n",
66
+ " assert encoder_input.size(\n",
67
+ " 0) == 1, \"Batch size must be 1 for validation\"\n",
68
+ "\n",
69
+ " model_out = greedy_decode(\n",
70
+ " model, encoder_input, encoder_mask, vocab_src, vocab_tgt, config['seq_len'], device)\n",
71
+ " \n",
72
+ " return batch, encoder_input_tokens, decoder_input_tokens"
73
+ ]
74
+ },
75
+ {
76
+ "cell_type": "code",
77
+ "execution_count": null,
78
+ "metadata": {},
79
+ "outputs": [],
80
+ "source": [
81
+ "def mtx2df(m, max_row, max_col, row_tokens, col_tokens):\n",
82
+ " return pd.DataFrame(\n",
83
+ " [\n",
84
+ " (\n",
85
+ " r,\n",
86
+ " c,\n",
87
+ " float(m[r, c]),\n",
88
+ " \"%.3d %s\" % (r, row_tokens[r] if len(row_tokens) > r else \"<blank>\"),\n",
89
+ " \"%.3d %s\" % (c, col_tokens[c] if len(col_tokens) > c else \"<blank>\"),\n",
90
+ " )\n",
91
+ " for r in range(m.shape[0])\n",
92
+ " for c in range(m.shape[1])\n",
93
+ " if r < max_row and c < max_col\n",
94
+ " ],\n",
95
+ " columns=[\"row\", \"column\", \"value\", \"row_token\", \"col_token\"],\n",
96
+ " )\n",
97
+ "\n",
98
+ "def get_attn_map(attn_type: str, layer: int, head: int):\n",
99
+ " if attn_type == \"encoder\":\n",
100
+ " attn = model.encoder.layers[layer].self_attention_block.attention_scores\n",
101
+ " elif attn_type == \"decoder\":\n",
102
+ " attn = model.decoder.layers[layer].self_attention_block.attention_scores\n",
103
+ " elif attn_type == \"encoder-decoder\":\n",
104
+ " attn = model.decoder.layers[layer].cross_attention_block.attention_scores\n",
105
+ " return attn[0, head].data\n",
106
+ "\n",
107
+ "def attn_map(attn_type, layer, head, row_tokens, col_tokens, max_sentence_len):\n",
108
+ " df = mtx2df(\n",
109
+ " get_attn_map(attn_type, layer, head),\n",
110
+ " max_sentence_len,\n",
111
+ " max_sentence_len,\n",
112
+ " row_tokens,\n",
113
+ " col_tokens,\n",
114
+ " )\n",
115
+ " return (\n",
116
+ " alt.Chart(data=df)\n",
117
+ " .mark_rect()\n",
118
+ " .encode(\n",
119
+ " x=alt.X(\"col_token\", axis=alt.Axis(title=\"\")),\n",
120
+ " y=alt.Y(\"row_token\", axis=alt.Axis(title=\"\")),\n",
121
+ " color=\"value\",\n",
122
+ " tooltip=[\"row\", \"column\", \"value\", \"row_token\", \"col_token\"],\n",
123
+ " )\n",
124
+ " #.title(f\"Layer {layer} Head {head}\")\n",
125
+ " .properties(height=400, width=400, title=f\"Layer {layer} Head {head}\")\n",
126
+ " .interactive()\n",
127
+ " )\n",
128
+ "\n",
129
+ "def get_all_attention_maps(attn_type: str, layers: list[int], heads: list[int], row_tokens: list, col_tokens, max_sentence_len: int):\n",
130
+ " charts = []\n",
131
+ " for layer in layers:\n",
132
+ " rowCharts = []\n",
133
+ " for head in heads:\n",
134
+ " rowCharts.append(attn_map(attn_type, layer, head, row_tokens, col_tokens, max_sentence_len))\n",
135
+ " charts.append(alt.hconcat(*rowCharts))\n",
136
+ " return alt.vconcat(*charts)"
137
+ ]
138
+ },
139
+ {
140
+ "cell_type": "code",
141
+ "execution_count": null,
142
+ "metadata": {},
143
+ "outputs": [],
144
+ "source": [
145
+ "batch, encoder_input_tokens, decoder_input_tokens = load_next_batch()\n",
146
+ "print(f'Source: {batch[\"src_text\"][0]}')\n",
147
+ "print(f'Target: {batch[\"tgt_text\"][0]}')\n",
148
+ "sentence_len = encoder_input_tokens.index(\"[PAD]\")"
149
+ ]
150
+ },
151
+ {
152
+ "cell_type": "code",
153
+ "execution_count": null,
154
+ "metadata": {},
155
+ "outputs": [],
156
+ "source": [
157
+ "layers = [0, 1, 2]\n",
158
+ "heads = [0, 1, 2, 3, 4, 5, 6, 7]\n",
159
+ "\n",
160
+ "# Encoder Self-Attention\n",
161
+ "get_all_attention_maps(\"encoder\", layers, heads, encoder_input_tokens, encoder_input_tokens, min(20, sentence_len))\n"
162
+ ]
163
+ },
164
+ {
165
+ "cell_type": "code",
166
+ "execution_count": null,
167
+ "metadata": {},
168
+ "outputs": [],
169
+ "source": [
170
+ "# Encoder Self-Attention\n",
171
+ "get_all_attention_maps(\"decoder\", layers, heads, decoder_input_tokens, decoder_input_tokens, min(20, sentence_len))"
172
+ ]
173
+ },
174
+ {
175
+ "cell_type": "code",
176
+ "execution_count": null,
177
+ "metadata": {},
178
+ "outputs": [],
179
+ "source": [
180
+ "# Encoder Self-Attention\n",
181
+ "get_all_attention_maps(\"encoder-decoder\", layers, heads, encoder_input_tokens, decoder_input_tokens, min(20, sentence_len))"
182
+ ]
183
+ }
184
+ ],
185
+ "metadata": {
186
+ "kernelspec": {
187
+ "display_name": "transformer",
188
+ "language": "python",
189
+ "name": "python3"
190
+ },
191
+ "language_info": {
192
+ "codemirror_mode": {
193
+ "name": "ipython",
194
+ "version": 3
195
+ },
196
+ "file_extension": ".py",
197
+ "mimetype": "text/x-python",
198
+ "name": "python",
199
+ "nbconvert_exporter": "python",
200
+ "pygments_lexer": "ipython3",
201
+ "version": "3.10.6"
202
+ },
203
+ "orig_nbformat": 4
204
+ },
205
+ "nbformat": 4,
206
+ "nbformat_minor": 2
207
+ }
Others/conda.txt ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This file may be used to create an environment using:
2
+ # $ conda create --name <env> --file <this file>
3
+ # platform: linux-64
4
+ @EXPLICIT
5
+ https://repo.anaconda.com/pkgs/main/linux-64/_libgcc_mutex-0.1-main.conda
6
+ https://repo.anaconda.com/pkgs/main/linux-64/ca-certificates-2023.08.22-h06a4308_0.conda
7
+ https://repo.anaconda.com/pkgs/main/linux-64/ld_impl_linux-64-2.38-h1181459_1.conda
8
+ https://repo.anaconda.com/pkgs/main/linux-64/libstdcxx-ng-11.2.0-h1234567_1.conda
9
+ https://repo.anaconda.com/pkgs/main/noarch/tzdata-2023c-h04d1e81_0.conda
10
+ https://repo.anaconda.com/pkgs/main/linux-64/libgomp-11.2.0-h1234567_1.conda
11
+ https://repo.anaconda.com/pkgs/main/linux-64/_openmp_mutex-5.1-1_gnu.conda
12
+ https://repo.anaconda.com/pkgs/main/linux-64/libgcc-ng-11.2.0-h1234567_1.conda
13
+ https://repo.anaconda.com/pkgs/main/linux-64/libffi-3.4.4-h6a678d5_0.conda
14
+ https://repo.anaconda.com/pkgs/main/linux-64/ncurses-6.4-h6a678d5_0.conda
15
+ https://repo.anaconda.com/pkgs/main/linux-64/openssl-3.0.12-h7f8727e_0.conda
16
+ https://repo.anaconda.com/pkgs/main/linux-64/xz-5.4.5-h5eee18b_0.conda
17
+ https://repo.anaconda.com/pkgs/main/linux-64/zlib-1.2.13-h5eee18b_0.conda
18
+ https://repo.anaconda.com/pkgs/main/linux-64/readline-8.2-h5eee18b_0.conda
19
+ https://repo.anaconda.com/pkgs/main/linux-64/tk-8.6.12-h1ccaba5_0.conda
20
+ https://repo.anaconda.com/pkgs/main/linux-64/sqlite-3.41.2-h5eee18b_0.conda
21
+ https://repo.anaconda.com/pkgs/main/linux-64/python-3.9.18-h955ad1f_0.conda
22
+ https://repo.anaconda.com/pkgs/main/linux-64/setuptools-68.0.0-py39h06a4308_0.conda
23
+ https://repo.anaconda.com/pkgs/main/linux-64/wheel-0.41.2-py39h06a4308_0.conda
24
+ https://repo.anaconda.com/pkgs/main/linux-64/pip-23.3.1-py39h06a4308_0.conda
Others/requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Use python 3.9
2
+
3
+ torch==2.0.1
4
+ torchvision==0.15.2
5
+ torchaudio==2.0.2
6
+ torchtext==0.15.2
7
+ datasets==2.15.0
8
+ tokenizers==0.13.3
9
+ torchmetrics==1.0.3
10
+ tensorboard==2.13.0
11
+ altair==5.1.1
12
+ wandb==0.15.9
Others/train_wb.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from model import build_transformer
2
+ from dataset import BilingualDataset, causal_mask
3
+ from config import get_config, get_weights_file_path
4
+
5
+ import torchtext.datasets as datasets
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch.utils.data import Dataset, DataLoader, random_split
9
+ from torch.optim.lr_scheduler import LambdaLR
10
+
11
+ import warnings
12
+ from tqdm import tqdm
13
+ import os
14
+ from pathlib import Path
15
+
16
+ # Huggingface datasets and tokenizers
17
+ from datasets import load_dataset
18
+ from tokenizers import Tokenizer
19
+ from tokenizers.models import WordLevel
20
+ from tokenizers.trainers import WordLevelTrainer
21
+ from tokenizers.pre_tokenizers import Whitespace
22
+
23
+ import wandb
24
+
25
+ import torchmetrics
26
+
27
+ def greedy_decode(model, source, source_mask, tokenizer_src, tokenizer_tgt, max_len, device):
28
+ sos_idx = tokenizer_tgt.token_to_id('[SOS]')
29
+ eos_idx = tokenizer_tgt.token_to_id('[EOS]')
30
+
31
+ # Precompute the encoder output and reuse it for every step
32
+ encoder_output = model.encode(source, source_mask)
33
+ # Initialize the decoder input with the sos token
34
+ decoder_input = torch.empty(1, 1).fill_(sos_idx).type_as(source).to(device)
35
+ while True:
36
+ if decoder_input.size(1) == max_len:
37
+ break
38
+
39
+ # build mask for target
40
+ decoder_mask = causal_mask(decoder_input.size(1)).type_as(source_mask).to(device)
41
+
42
+ # calculate output
43
+ out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask)
44
+
45
+ # get next token
46
+ prob = model.project(out[:, -1])
47
+ _, next_word = torch.max(prob, dim=1)
48
+ decoder_input = torch.cat(
49
+ [decoder_input, torch.empty(1, 1).type_as(source).fill_(next_word.item()).to(device)], dim=1
50
+ )
51
+
52
+ if next_word == eos_idx:
53
+ break
54
+
55
+ return decoder_input.squeeze(0)
56
+
57
+
58
+ def run_validation(model, validation_ds, tokenizer_src, tokenizer_tgt, max_len, device, print_msg, global_step, num_examples=2):
59
+ model.eval()
60
+ count = 0
61
+
62
+ source_texts = []
63
+ expected = []
64
+ predicted = []
65
+
66
+ try:
67
+ # get the console window width
68
+ with os.popen('stty size', 'r') as console:
69
+ _, console_width = console.read().split()
70
+ console_width = int(console_width)
71
+ except:
72
+ # If we can't get the console width, use 80 as default
73
+ console_width = 80
74
+
75
+ with torch.no_grad():
76
+ for batch in validation_ds:
77
+ count += 1
78
+ encoder_input = batch["encoder_input"].to(device) # (b, seq_len)
79
+ encoder_mask = batch["encoder_mask"].to(device) # (b, 1, 1, seq_len)
80
+
81
+ # check that the batch size is 1
82
+ assert encoder_input.size(
83
+ 0) == 1, "Batch size must be 1 for validation"
84
+
85
+ model_out = greedy_decode(model, encoder_input, encoder_mask, tokenizer_src, tokenizer_tgt, max_len, device)
86
+
87
+ source_text = batch["src_text"][0]
88
+ target_text = batch["tgt_text"][0]
89
+ model_out_text = tokenizer_tgt.decode(model_out.detach().cpu().numpy())
90
+
91
+ source_texts.append(source_text)
92
+ expected.append(target_text)
93
+ predicted.append(model_out_text)
94
+
95
+ # Print the source, target and model output
96
+ print_msg('-'*console_width)
97
+ print_msg(f"{f'SOURCE: ':>12}{source_text}")
98
+ print_msg(f"{f'TARGET: ':>12}{target_text}")
99
+ print_msg(f"{f'PREDICTED: ':>12}{model_out_text}")
100
+
101
+ if count == num_examples:
102
+ print_msg('-'*console_width)
103
+ break
104
+
105
+
106
+ # Evaluate the character error rate
107
+ # Compute the char error rate
108
+ metric = torchmetrics.CharErrorRate()
109
+ cer = metric(predicted, expected)
110
+ wandb.log({'validation/cer': cer, 'global_step': global_step})
111
+
112
+ # Compute the word error rate
113
+ metric = torchmetrics.WordErrorRate()
114
+ wer = metric(predicted, expected)
115
+ wandb.log({'validation/wer': wer, 'global_step': global_step})
116
+
117
+ # Compute the BLEU metric
118
+ metric = torchmetrics.BLEUScore()
119
+ bleu = metric(predicted, expected)
120
+ wandb.log({'validation/BLEU': bleu, 'global_step': global_step})
121
+
122
+ def get_all_sentences(ds, lang):
123
+ for item in ds:
124
+ yield item['translation'][lang]
125
+
126
+ def get_or_build_tokenizer(config, ds, lang):
127
+ tokenizer_path = Path(config['tokenizer_file'].format(lang))
128
+ if not Path.exists(tokenizer_path):
129
+ # Most code taken from: https://huggingface.co/docs/tokenizers/quicktour
130
+ tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
131
+ tokenizer.pre_tokenizer = Whitespace()
132
+ trainer = WordLevelTrainer(special_tokens=["[UNK]", "[PAD]", "[SOS]", "[EOS]"], min_frequency=2)
133
+ tokenizer.train_from_iterator(get_all_sentences(ds, lang), trainer=trainer)
134
+ tokenizer.save(str(tokenizer_path))
135
+ else:
136
+ tokenizer = Tokenizer.from_file(str(tokenizer_path))
137
+ return tokenizer
138
+
139
+ def get_ds(config):
140
+ # It only has the train split, so we divide it overselves
141
+ ds_raw = load_dataset('opus_books', f"{config['lang_src']}-{config['lang_tgt']}", split='train')
142
+
143
+ # Build tokenizers
144
+ tokenizer_src = get_or_build_tokenizer(config, ds_raw, config['lang_src'])
145
+ tokenizer_tgt = get_or_build_tokenizer(config, ds_raw, config['lang_tgt'])
146
+
147
+ # Keep 90% for training, 10% for validation
148
+ train_ds_size = int(0.9 * len(ds_raw))
149
+ val_ds_size = len(ds_raw) - train_ds_size
150
+ train_ds_raw, val_ds_raw = random_split(ds_raw, [train_ds_size, val_ds_size])
151
+
152
+ train_ds = BilingualDataset(train_ds_raw, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq_len'])
153
+ val_ds = BilingualDataset(val_ds_raw, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq_len'])
154
+
155
+ # Find the maximum length of each sentence in the source and target sentence
156
+ max_len_src = 0
157
+ max_len_tgt = 0
158
+
159
+ for item in ds_raw:
160
+ src_ids = tokenizer_src.encode(item['translation'][config['lang_src']]).ids
161
+ tgt_ids = tokenizer_tgt.encode(item['translation'][config['lang_tgt']]).ids
162
+ max_len_src = max(max_len_src, len(src_ids))
163
+ max_len_tgt = max(max_len_tgt, len(tgt_ids))
164
+
165
+ print(f'Max length of source sentence: {max_len_src}')
166
+ print(f'Max length of target sentence: {max_len_tgt}')
167
+
168
+
169
+ train_dataloader = DataLoader(train_ds, batch_size=config['batch_size'], shuffle=True)
170
+ val_dataloader = DataLoader(val_ds, batch_size=1, shuffle=True)
171
+
172
+ return train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt
173
+
174
+ def get_model(config, vocab_src_len, vocab_tgt_len):
175
+ model = build_transformer(vocab_src_len, vocab_tgt_len, config["seq_len"], config['seq_len'], d_model=config['d_model'])
176
+ return model
177
+
178
+ def train_model(config):
179
+ # Define the device
180
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
181
+ print("Using device:", device)
182
+
183
+ # Make sure the weights folder exists
184
+ Path(config['model_folder']).mkdir(parents=True, exist_ok=True)
185
+
186
+ train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)
187
+ model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)
188
+
189
+ optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], eps=1e-9)
190
+
191
+ # If the user specified a model to preload before training, load it
192
+ initial_epoch = 0
193
+ global_step = 0
194
+ if config['preload']:
195
+ model_filename = get_weights_file_path(config, config['preload'])
196
+ print(f'Preloading model {model_filename}')
197
+ state = torch.load(model_filename)
198
+ model.load_state_dict(state['model_state_dict'])
199
+ initial_epoch = state['epoch'] + 1
200
+ optimizer.load_state_dict(state['optimizer_state_dict'])
201
+ global_step = state['global_step']
202
+ del state
203
+
204
+ loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer_src.token_to_id('[PAD]'), label_smoothing=0.1).to(device)
205
+
206
+ # define our custom x axis metric
207
+ wandb.define_metric("global_step")
208
+ # define which metrics will be plotted against it
209
+ wandb.define_metric("validation/*", step_metric="global_step")
210
+ wandb.define_metric("train/*", step_metric="global_step")
211
+
212
+ for epoch in range(initial_epoch, config['num_epochs']):
213
+ torch.cuda.empty_cache()
214
+ model.train()
215
+ batch_iterator = tqdm(train_dataloader, desc=f"Processing Epoch {epoch:02d}")
216
+ for batch in batch_iterator:
217
+
218
+ encoder_input = batch['encoder_input'].to(device) # (b, seq_len)
219
+ decoder_input = batch['decoder_input'].to(device) # (B, seq_len)
220
+ encoder_mask = batch['encoder_mask'].to(device) # (B, 1, 1, seq_len)
221
+ decoder_mask = batch['decoder_mask'].to(device) # (B, 1, seq_len, seq_len)
222
+
223
+ # Run the tensors through the encoder, decoder and the projection layer
224
+ encoder_output = model.encode(encoder_input, encoder_mask) # (B, seq_len, d_model)
225
+ decoder_output = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask) # (B, seq_len, d_model)
226
+ proj_output = model.project(decoder_output) # (B, seq_len, vocab_size)
227
+
228
+ # Compare the output with the label
229
+ label = batch['label'].to(device) # (B, seq_len)
230
+
231
+ # Compute the loss using a simple cross entropy
232
+ loss = loss_fn(proj_output.view(-1, tokenizer_tgt.get_vocab_size()), label.view(-1))
233
+ batch_iterator.set_postfix({"loss": f"{loss.item():6.3f}"})
234
+
235
+ # Log the loss
236
+ wandb.log({'train/loss': loss.item(), 'global_step': global_step})
237
+
238
+ # Backpropagate the loss
239
+ loss.backward()
240
+
241
+ # Update the weights
242
+ optimizer.step()
243
+ optimizer.zero_grad(set_to_none=True)
244
+
245
+ global_step += 1
246
+
247
+ # Run validation at the end of every epoch
248
+ run_validation(model, val_dataloader, tokenizer_src, tokenizer_tgt, config['seq_len'], device, lambda msg: batch_iterator.write(msg), global_step)
249
+
250
+ # Save the model at the end of every epoch
251
+ model_filename = get_weights_file_path(config, f"{epoch:02d}")
252
+ torch.save({
253
+ 'epoch': epoch,
254
+ 'model_state_dict': model.state_dict(),
255
+ 'optimizer_state_dict': optimizer.state_dict(),
256
+ 'global_step': global_step
257
+ }, model_filename)
258
+
259
+
260
+ if __name__ == '__main__':
261
+ warnings.filterwarnings("ignore")
262
+ config = get_config()
263
+ config['num_epochs'] = 30
264
+ config['preload'] = None
265
+
266
+ wandb.init(
267
+ # set the wandb project where this run will be logged
268
+ project="pytorch-transformer",
269
+
270
+ # track hyperparameters and run metadata
271
+ config=config
272
+ )
273
+
274
+ train_model(config)
Others/translate.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from config import get_config, latest_weights_file_path
3
+ from model import build_transformer
4
+ from tokenizers import Tokenizer
5
+ from datasets import load_dataset
6
+ from dataset import BilingualDataset
7
+ import torch
8
+ import sys
9
+
10
+ def translate(sentence: str):
11
+ # Define the device, tokenizers, and model
12
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
+ print("Using device:", device)
14
+ config = get_config()
15
+ tokenizer_src = Tokenizer.from_file(str(Path(config['tokenizer_file'].format(config['lang_src']))))
16
+ tokenizer_tgt = Tokenizer.from_file(str(Path(config['tokenizer_file'].format(config['lang_tgt']))))
17
+ model = build_transformer(tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size(), config["seq_len"], config['seq_len'], d_model=config['d_model']).to(device)
18
+
19
+ # Load the pretrained weights
20
+ model_filename = latest_weights_file_path(config)
21
+ state = torch.load(model_filename)
22
+ model.load_state_dict(state['model_state_dict'])
23
+
24
+ # if the sentence is a number use it as an index to the test set
25
+ label = ""
26
+ if type(sentence) == int or sentence.isdigit():
27
+ id = int(sentence)
28
+ ds = load_dataset(f"{config['datasource']}", f"{config['lang_src']}-{config['lang_tgt']}", split='all')
29
+ ds = BilingualDataset(ds, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq_len'])
30
+ sentence = ds[id]['src_text']
31
+ label = ds[id]["tgt_text"]
32
+ seq_len = config['seq_len']
33
+
34
+ # translate the sentence
35
+ model.eval()
36
+ with torch.no_grad():
37
+ # Precompute the encoder output and reuse it for every generation step
38
+ source = tokenizer_src.encode(sentence)
39
+ source = torch.cat([
40
+ torch.tensor([tokenizer_src.token_to_id('[SOS]')], dtype=torch.int64),
41
+ torch.tensor(source.ids, dtype=torch.int64),
42
+ torch.tensor([tokenizer_src.token_to_id('[EOS]')], dtype=torch.int64),
43
+ torch.tensor([tokenizer_src.token_to_id('[PAD]')] * (seq_len - len(source.ids) - 2), dtype=torch.int64)
44
+ ], dim=0).to(device)
45
+ source_mask = (source != tokenizer_src.token_to_id('[PAD]')).unsqueeze(0).unsqueeze(0).int().to(device)
46
+ encoder_output = model.encode(source, source_mask)
47
+
48
+ # Initialize the decoder input with the sos token
49
+ decoder_input = torch.empty(1, 1).fill_(tokenizer_tgt.token_to_id('[SOS]')).type_as(source).to(device)
50
+
51
+ # Print the source sentence and target start prompt
52
+ if label != "": print(f"{f'ID: ':>12}{id}")
53
+ print(f"{f'SOURCE: ':>12}{sentence}")
54
+ if label != "": print(f"{f'TARGET: ':>12}{label}")
55
+ print(f"{f'PREDICTED: ':>12}", end='')
56
+
57
+ # Generate the translation word by word
58
+ while decoder_input.size(1) < seq_len:
59
+ # build mask for target and calculate output
60
+ decoder_mask = torch.triu(torch.ones((1, decoder_input.size(1), decoder_input.size(1))), diagonal=1).type(torch.int).type_as(source_mask).to(device)
61
+ out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask)
62
+
63
+ # project next token
64
+ prob = model.project(out[:, -1])
65
+ _, next_word = torch.max(prob, dim=1)
66
+ decoder_input = torch.cat([decoder_input, torch.empty(1, 1).type_as(source).fill_(next_word.item()).to(device)], dim=1)
67
+
68
+ # print the translated word
69
+ print(f"{tokenizer_tgt.decode([next_word.item()])}", end=' ')
70
+
71
+ # break if we predict the end of sentence token
72
+ if next_word == tokenizer_tgt.token_to_id('[EOS]'):
73
+ break
74
+
75
+ # convert ids to tokens
76
+ return tokenizer_tgt.decode(decoder_input[0].tolist())
77
+
78
+ #read sentence from argument
79
+ translate(sys.argv[1] if len(sys.argv) > 1 else "I am not a very good a student.")