vjt committed on
Commit
0f303b0
1 Parent(s): 43d7c98

Training in progress, epoch 1

.gitignore ADDED
@@ -0,0 +1 @@
1
+ checkpoint-*/
.ipynb_checkpoints/T5Train-checkpoint.ipynb ADDED
@@ -0,0 +1,561 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "3ef6a441",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stdout",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "Requirement already satisfied: nltk in c:\\users\\vjmar\\documents\\1. code\\pythonenvs\\hf-env\\lib\\site-packages (3.8.1)\n",
14
+ "Requirement already satisfied: click in c:\\users\\vjmar\\documents\\1. code\\pythonenvs\\hf-env\\lib\\site-packages (from nltk) (8.1.3)\n",
15
+ "Requirement already satisfied: tqdm in c:\\users\\vjmar\\documents\\1. code\\pythonenvs\\hf-env\\lib\\site-packages (from nltk) (4.64.1)\n",
16
+ "Requirement already satisfied: joblib in c:\\users\\vjmar\\documents\\1. code\\pythonenvs\\hf-env\\lib\\site-packages (from nltk) (1.2.0)\n",
17
+ "Requirement already satisfied: regex>=2021.8.3 in c:\\users\\vjmar\\documents\\1. code\\pythonenvs\\hf-env\\lib\\site-packages (from nltk) (2022.10.31)\n",
18
+ "Requirement already satisfied: colorama in c:\\users\\vjmar\\documents\\1. code\\pythonenvs\\hf-env\\lib\\site-packages (from click->nltk) (0.4.6)\n",
19
+ "Requirement already satisfied: rouge_score in c:\\users\\vjmar\\documents\\1. code\\pythonenvs\\hf-env\\lib\\site-packages (0.1.2)\n",
20
+ "Requirement already satisfied: numpy in c:\\users\\vjmar\\documents\\1. code\\pythonenvs\\hf-env\\lib\\site-packages (from rouge_score) (1.24.1)\n",
21
+ "Requirement already satisfied: absl-py in c:\\users\\vjmar\\documents\\1. code\\pythonenvs\\hf-env\\lib\\site-packages (from rouge_score) (1.4.0)\n",
22
+ "Requirement already satisfied: six>=1.14.0 in c:\\users\\vjmar\\documents\\1. code\\pythonenvs\\hf-env\\lib\\site-packages (from rouge_score) (1.16.0)\n",
23
+ "Requirement already satisfied: nltk in c:\\users\\vjmar\\documents\\1. code\\pythonenvs\\hf-env\\lib\\site-packages (from rouge_score) (3.8.1)\n",
24
+ "Requirement already satisfied: joblib in c:\\users\\vjmar\\documents\\1. code\\pythonenvs\\hf-env\\lib\\site-packages (from nltk->rouge_score) (1.2.0)\n",
25
+ "Requirement already satisfied: tqdm in c:\\users\\vjmar\\documents\\1. code\\pythonenvs\\hf-env\\lib\\site-packages (from nltk->rouge_score) (4.64.1)\n",
26
+ "Requirement already satisfied: regex>=2021.8.3 in c:\\users\\vjmar\\documents\\1. code\\pythonenvs\\hf-env\\lib\\site-packages (from nltk->rouge_score) (2022.10.31)\n",
27
+ "Requirement already satisfied: click in c:\\users\\vjmar\\documents\\1. code\\pythonenvs\\hf-env\\lib\\site-packages (from nltk->rouge_score) (8.1.3)\n",
28
+ "Requirement already satisfied: colorama in c:\\users\\vjmar\\documents\\1. code\\pythonenvs\\hf-env\\lib\\site-packages (from click->nltk->rouge_score) (0.4.6)\n"
29
+ ]
30
+ }
31
+ ],
32
+ "source": [
33
+ "# !pip install transformers\n",
34
+ "!pip install nltk\n",
35
+ "!pip install rouge_score\n",
36
+ "\n",
37
+ "%load_ext autoreload\n",
38
+ "%autoreload 2"
39
+ ]
40
+ },
41
+ {
42
+ "cell_type": "code",
43
+ "execution_count": null,
44
+ "id": "845c8640",
45
+ "metadata": {},
46
+ "outputs": [],
47
+ "source": []
48
+ },
49
+ {
50
+ "cell_type": "code",
51
+ "execution_count": 2,
52
+ "id": "23e534d2",
53
+ "metadata": {},
54
+ "outputs": [
55
+ {
56
+ "name": "stderr",
57
+ "output_type": "stream",
58
+ "text": [
59
+ "C:\\Users\\vjmar\\Documents\\1. Code\\PythonEnvs\\hf-env\\lib\\site-packages\\tqdm\\auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
60
+ " from .autonotebook import tqdm as notebook_tqdm\n"
61
+ ]
62
+ },
63
+ {
64
+ "name": "stdout",
65
+ "output_type": "stream",
66
+ "text": [
67
+ "| ID | GPU | MEM |\n",
68
+ "------------------\n",
69
+ "| 0 | 5% | 13% |\n",
70
+ "None\n",
71
+ "---------------------------------------------------------------\n",
72
+ "Token will not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.\n",
73
+ "Token is valid.\n",
74
+ "Your token has been saved to C:\\Users\\vjmar\\.cache\\huggingface\\token\n",
75
+ "Login successful\n"
76
+ ]
77
+ }
78
+ ],
79
+ "source": [
80
+ "import GPUtil\n",
81
+ "from huggingface_hub import HfApi, HfFolder, login\n",
82
+ "\n",
83
+ "print(GPUtil.showUtilization())\n",
84
+ "print(\"---------------------------------------------------------------\")\n",
85
+ "token = \"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\"\n",
86
+ "login(\"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\")\n",
87
+ "! git config --global credential.helper store"
88
+ ]
89
+ },
90
+ {
91
+ "cell_type": "code",
92
+ "execution_count": 3,
93
+ "id": "2b5a41be",
94
+ "metadata": {},
95
+ "outputs": [],
96
+ "source": [
97
+ "CKPT = 't5-base'\n",
98
+ "from transformers import AutoTokenizer, T5ForConditionalGeneration\n",
99
+ "model = T5ForConditionalGeneration.from_pretrained(CKPT)"
100
+ ]
101
+ },
102
+ {
103
+ "cell_type": "code",
104
+ "execution_count": 4,
105
+ "id": "75c5f40c",
106
+ "metadata": {},
107
+ "outputs": [
108
+ {
109
+ "name": "stderr",
110
+ "output_type": "stream",
111
+ "text": [
112
+ "C:\\Users\\vjmar\\Documents\\1. Code\\PythonEnvs\\hf-env\\lib\\site-packages\\transformers\\models\\t5\\tokenization_t5_fast.py:155: FutureWarning: This tokenizer was incorrectly instantiated with a model max length of 512 which will be corrected in Transformers v5.\n",
113
+ "For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.\n",
114
+ "- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.\n",
115
+ "- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.\n",
116
+ "- To avoid this warning, please instantiate this tokenizer with `model_max_length` set to your preferred value.\n",
117
+ " warnings.warn(\n"
118
+ ]
119
+ }
120
+ ],
121
+ "source": [
122
+ "tokenizer = AutoTokenizer.from_pretrained(CKPT)"
123
+ ]
124
+ },
125
+ {
126
+ "cell_type": "markdown",
127
+ "id": "ca3c201b",
128
+ "metadata": {},
129
+ "source": [
130
+ "# Data"
131
+ ]
132
+ },
133
+ {
134
+ "cell_type": "code",
135
+ "execution_count": 5,
136
+ "id": "f9ab72e4",
137
+ "metadata": {},
138
+ "outputs": [
139
+ {
140
+ "name": "stderr",
141
+ "output_type": "stream",
142
+ "text": [
143
+ "Found cached dataset wikisql (C:/Users/vjmar/.cache/huggingface/datasets/wikisql/default/0.1.0/7037bfe6a42b1ca2b6ac3ccacba5253b1825d31379e9cc626fc79a620977252d)\n",
144
+ "Found cached dataset wikisql (C:/Users/vjmar/.cache/huggingface/datasets/wikisql/default/0.1.0/7037bfe6a42b1ca2b6ac3ccacba5253b1825d31379e9cc626fc79a620977252d)\n"
145
+ ]
146
+ }
147
+ ],
148
+ "source": [
149
+ "try:\n",
150
+ " from datasets import load_dataset\n",
151
+ "except ModuleNotFoundError:\n",
152
+ " !pip install datasets\n",
153
+ " from datasets import load_dataset\n",
154
+ "\n",
155
+ "train_data = load_dataset('wikisql', split='train+validation')\n",
156
+ "test_data = load_dataset('wikisql', split='test')"
157
+ ]
158
+ },
159
+ {
160
+ "cell_type": "code",
161
+ "execution_count": 6,
162
+ "id": "0e62f295",
163
+ "metadata": {},
164
+ "outputs": [
165
+ {
166
+ "name": "stderr",
167
+ "output_type": "stream",
168
+ "text": [
169
+ "Loading cached processed dataset at C:\\Users\\vjmar\\.cache\\huggingface\\datasets\\wikisql\\default\\0.1.0\\7037bfe6a42b1ca2b6ac3ccacba5253b1825d31379e9cc626fc79a620977252d\\cache-19a43a9806773ee1.arrow\n",
170
+ "Loading cached processed dataset at C:\\Users\\vjmar\\.cache\\huggingface\\datasets\\wikisql\\default\\0.1.0\\7037bfe6a42b1ca2b6ac3ccacba5253b1825d31379e9cc626fc79a620977252d\\cache-620e43f13a2f425c.arrow\n"
171
+ ]
172
+ }
173
+ ],
174
+ "source": [
175
+ "def format_dataset(example):\n",
176
+ " try:\n",
177
+ " condition:str = example['sql']['conds']['condition'][0]\n",
178
+ " except:\n",
179
+ " condition = \"\"\n",
180
+ " target = f\"{example['sql']['human_readable']}\"\n",
181
+ " \n",
182
+ " if condition.lower() in target.lower() and condition != \"\":\n",
183
+ " target = target.lower().replace(condition.lower(), f\"'{condition}'\")\n",
184
+ "\n",
185
+ " cols = \"\"\n",
186
+ " for item in example['table']['header']:\n",
187
+ " cols = cols + item.lower() + \", \"\n",
188
+ " \n",
189
+ "\n",
190
+ " obj = {'input': f\"translate to SQL: {example['question']} | table: {cols})\".replace(\", )\", \"\" ),\n",
191
+ " \"target\": target}\n",
192
+ " return obj\n",
193
+ "\n",
194
+ "# Apply Data Formatting\n",
195
+ "train_data = train_data.map(format_dataset, remove_columns=train_data.column_names)\n",
196
+ "test_data = test_data.map(format_dataset, remove_columns=test_data.column_names)"
197
+ ]
198
+ },
199
+ {
200
+ "cell_type": "code",
201
+ "execution_count": null,
202
+ "id": "e68f9896",
203
+ "metadata": {},
204
+ "outputs": [],
205
+ "source": []
206
+ },
207
+ {
208
+ "cell_type": "markdown",
209
+ "id": "f47e6cd6",
210
+ "metadata": {},
211
+ "source": [
212
+ "# Data Format for Training"
213
+ ]
214
+ },
215
+ {
216
+ "cell_type": "code",
217
+ "execution_count": 12,
218
+ "id": "15ec294c",
219
+ "metadata": {},
220
+ "outputs": [],
221
+ "source": [
222
+ "def map_to_length(x): # map article and summary len to dict as well as if sample is longer than 512 tokens\n",
223
+ " \n",
224
+ " # from transformers import AutoTokenizer \n",
225
+ " # tokenizer = AutoTokenizer.from_pretrained(\"t5-base\") \n",
226
+ " x[\"input_len\"] = len(tokenizer(x[\"input\"]).input_ids)\n",
227
+ " x[\"input_longer_256\"] = int(x[\"input_len\"] > 256)\n",
228
+ " x[\"input_longer_128\"] = int(x[\"input_len\"] > 128)\n",
229
+ " x[\"input_longer_64\"] = int(x[\"input_len\"] > 64)\n",
230
+ " x[\"out_len\"] = len(tokenizer(x[\"target\"]).input_ids)\n",
231
+ " x[\"out_longer_256\"] = int(x[\"out_len\"] > 256)\n",
232
+ " x[\"out_longer_128\"] = int(x[\"out_len\"] > 128)\n",
233
+ " x[\"out_longer_64\"] = int(x[\"out_len\"] > 64)\n",
234
+ " return x\n"
235
+ ]
236
+ },
237
+ {
238
+ "cell_type": "code",
239
+ "execution_count": 13,
240
+ "id": "7b5df2e4",
241
+ "metadata": {},
242
+ "outputs": [
243
+ {
244
+ "name": "stdout",
245
+ "output_type": "stream",
246
+ "text": [
247
+ "<class 'datasets.arrow_dataset.Dataset'>\n"
248
+ ]
249
+ },
250
+ {
251
+ "name": "stderr",
252
+ "output_type": "stream",
253
+ "text": [
254
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [00:04<00:00, 2380.77ex/s]\n"
255
+ ]
256
+ }
257
+ ],
258
+ "source": [
259
+ "sample_size = 10000\n",
260
+ "print(type(train_data))\n",
261
+ "data_stats = train_data.select(range(sample_size)).map(map_to_length) #, num_proc=4"
262
+ ]
263
+ },
264
+ {
265
+ "cell_type": "code",
266
+ "execution_count": 14,
267
+ "id": "e4589f66",
268
+ "metadata": {},
269
+ "outputs": [
270
+ {
271
+ "name": "stderr",
272
+ "output_type": "stream",
273
+ "text": [
274
+ "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 24.68ba/s]\n",
275
+ "Loading cached processed dataset at C:\\Users\\vjmar\\.cache\\huggingface\\datasets\\wikisql\\default\\0.1.0\\7037bfe6a42b1ca2b6ac3ccacba5253b1825d31379e9cc626fc79a620977252d\\cache-aefcd3f1e400ed5a.arrow\n"
276
+ ]
277
+ },
278
+ {
279
+ "name": "stdout",
280
+ "output_type": "stream",
281
+ "text": [
282
+ "Input Mean: 46.515, %-Input > 256:0.0, %-Input > 128:0.0037, %-Input > 64:0.0712 Output Mean:19.1137, %-Output > 256:0.0, %-Output > 128:0.0002, %-Output > 64:0.0007\n"
283
+ ]
284
+ },
285
+ {
286
+ "name": "stderr",
287
+ "output_type": "stream",
288
+ "text": [
289
+ " 0%| | 0/16 [00:00<?, ?ba/s]Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.\n",
290
+ "C:\\Users\\vjmar\\Documents\\1. Code\\PythonEnvs\\hf-env\\lib\\site-packages\\transformers\\tokenization_utils_base.py:2339: FutureWarning: The `pad_to_max_length` argument is deprecated and will be removed in a future version, use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or use `padding='max_length'` to pad to a max length. In this case, you can give a specific length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the maximal input size of the model (e.g. 512 for Bert).\n",
291
+ " warnings.warn(\n",
292
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:04<00:00, 3.88ba/s]\n"
293
+ ]
294
+ }
295
+ ],
296
+ "source": [
297
+ "def compute_and_print_stats(x):\n",
298
+ " if len(x[\"input_len\"]) == sample_size:\n",
299
+ " print(\n",
300
+ " \"Input Mean: {}, %-Input > 256:{}, %-Input > 128:{}, %-Input > 64:{} Output Mean:{}, %-Output > 256:{}, %-Output > 128:{}, %-Output > 64:{}\".format(\n",
301
+ " sum(x[\"input_len\"]) / sample_size,\n",
302
+ " sum(x[\"input_longer_256\"]) / sample_size,\n",
303
+ " sum(x[\"input_longer_128\"]) / sample_size,\n",
304
+ " sum(x[\"input_longer_64\"]) / sample_size, \n",
305
+ " sum(x[\"out_len\"]) / sample_size,\n",
306
+ " sum(x[\"out_longer_256\"]) / sample_size,\n",
307
+ " sum(x[\"out_longer_128\"]) / sample_size,\n",
308
+ " sum(x[\"out_longer_64\"]) / sample_size,\n",
309
+ " )\n",
310
+ " )\n",
311
+ "\n",
312
+ "output = data_stats.map(\n",
313
+ " compute_and_print_stats, \n",
314
+ " batched=True,\n",
315
+ " batch_size=-1,\n",
316
+ ")\n",
317
+ "\n",
318
+ "# tokenize the examples\n",
319
+ "def convert_to_features(example_batch):\n",
320
+ " input_encodings = tokenizer.batch_encode_plus(example_batch['input'], pad_to_max_length=True, max_length=64)\n",
321
+ " target_encodings = tokenizer.batch_encode_plus(example_batch['target'], pad_to_max_length=True, max_length=64)\n",
322
+ "\n",
323
+ " encodings = {\n",
324
+ " 'input_ids': input_encodings['input_ids'], \n",
325
+ " 'attention_mask': input_encodings['attention_mask'],\n",
326
+ " 'labels': target_encodings['input_ids'],\n",
327
+ " 'decoder_attention_mask': target_encodings['attention_mask']\n",
328
+ " }\n",
329
+ "\n",
330
+ " return encodings\n",
331
+ "\n",
332
+ "train_data = train_data.map(convert_to_features, batched=True, remove_columns=train_data.column_names)\n",
333
+ "test_data = test_data.map(convert_to_features, batched=True, remove_columns=test_data.column_names)\n",
334
+ "\n",
335
+ "columns = ['input_ids', 'attention_mask', 'labels', 'decoder_attention_mask']\n",
336
+ "\n",
337
+ "train_data.set_format(type='torch', columns=columns)\n",
338
+ "test_data.set_format(type='torch', columns=columns)"
339
+ ]
340
+ },
341
+ {
342
+ "cell_type": "markdown",
343
+ "id": "d439da79",
344
+ "metadata": {},
345
+ "source": [
346
+ "# Trainer"
347
+ ]
348
+ },
349
+ {
350
+ "cell_type": "code",
351
+ "execution_count": 15,
352
+ "id": "f1cee70c",
353
+ "metadata": {},
354
+ "outputs": [],
355
+ "source": [
356
+ "from transformers import Seq2SeqTrainer\n",
357
+ "from transformers import Seq2SeqTrainingArguments\n",
358
+ "import os\n",
359
+ "\n",
360
+ "training_args = Seq2SeqTrainingArguments(\n",
361
+ " output_dir=str(os.getcwd()),\n",
362
+ " per_device_train_batch_size=16,\n",
363
+ " num_train_epochs=5,\n",
364
+ " per_device_eval_batch_size=16,\n",
365
+ " predict_with_generate=True,\n",
366
+ " evaluation_strategy=\"epoch\",\n",
367
+ " do_train=True,\n",
368
+ " do_eval=True,\n",
369
+ " logging_steps=500,\n",
370
+ " save_strategy=\"epoch\",\n",
371
+ " #save_steps=1000,\n",
372
+ " #eval_steps=1000,\n",
373
+ " overwrite_output_dir=True,\n",
374
+ " save_total_limit=3,\n",
375
+ " load_best_model_at_end=True,\n",
376
+ " push_to_hub=True\n",
377
+ " #fp16=True, \n",
378
+ ")"
379
+ ]
380
+ },
381
+ {
382
+ "cell_type": "code",
383
+ "execution_count": 16,
384
+ "id": "4ee61c54",
385
+ "metadata": {},
386
+ "outputs": [
387
+ {
388
+ "name": "stderr",
389
+ "output_type": "stream",
390
+ "text": [
391
+ "C:\\Users\\vjmar\\AppData\\Local\\Temp\\ipykernel_29244\\418146841.py:3: FutureWarning: load_metric is deprecated and will be removed in the next major version of datasets. Use 'evaluate.load' instead, from the new library 🤗 Evaluate: https://huggingface.co/docs/evaluate\n",
392
+ " rouge = load_metric(\"rouge\")\n"
393
+ ]
394
+ }
395
+ ],
396
+ "source": [
397
+ "from datasets import load_metric\n",
398
+ "\n",
399
+ "rouge = load_metric(\"rouge\")\n",
400
+ "\n",
401
+ "def compute_metrics(pred):\n",
402
+ " labels_ids = pred.label_ids\n",
403
+ " pred_ids = pred.predictions\n",
404
+ "\n",
405
+ " # all unnecessary tokens are removed\n",
406
+ " pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)\n",
407
+ " labels_ids[labels_ids == -100] = tokenizer.pad_token_id\n",
408
+ " label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)\n",
409
+ "\n",
410
+ " rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=[\"rouge2\"])[\"rouge2\"].mid\n",
411
+ "\n",
412
+ " return {\n",
413
+ " \"rouge2_precision\": round(rouge_output.precision, 4),\n",
414
+ " \"rouge2_recall\": round(rouge_output.recall, 4),\n",
415
+ " \"rouge2_fmeasure\": round(rouge_output.fmeasure, 4),\n",
416
+ " }"
417
+ ]
418
+ },
419
+ {
420
+ "cell_type": "markdown",
421
+ "id": "f6c0f580",
422
+ "metadata": {},
423
+ "source": [
424
+ "# Define Trainer"
425
+ ]
426
+ },
427
+ {
428
+ "cell_type": "code",
429
+ "execution_count": null,
430
+ "id": "b71acd7c",
431
+ "metadata": {},
432
+ "outputs": [
433
+ {
434
+ "name": "stderr",
435
+ "output_type": "stream",
436
+ "text": [
437
+ "Cloning https://huggingface.co/vjt/T5Training into local empty directory.\n"
438
+ ]
439
+ }
440
+ ],
441
+ "source": [
442
+ "# instantiate trainer\n",
443
+ "trainer = Seq2SeqTrainer(\n",
444
+ " model=model,\n",
445
+ " args=training_args,\n",
446
+ " compute_metrics=compute_metrics,\n",
447
+ " train_dataset=train_data,\n",
448
+ " eval_dataset=test_data,\n",
449
+ ")\n",
450
+ "import os\n",
451
+ "trainer.evaluate()\n",
452
+ "trainer.train()\n",
453
+ "trainer.save_model()\n",
454
+ "tokenizer.save_pretrained(os.getcwd())\n",
455
+ "trainer.create_model_card()\n",
456
+ "trainer.push_to_hub()"
457
+ ]
458
+ },
459
+ {
460
+ "cell_type": "markdown",
461
+ "id": "76ca29ea",
462
+ "metadata": {},
463
+ "source": [
464
+ "# Test Model"
465
+ ]
466
+ },
467
+ {
468
+ "cell_type": "code",
469
+ "execution_count": null,
470
+ "id": "d39e7e80",
471
+ "metadata": {},
472
+ "outputs": [],
473
+ "source": [
474
+ "CKPT = os.join(os.getcwd(), 't5-base-finetuned-wikisql')\n",
475
+ "from transformers import AutoTokenizer, T5ForConditionalGeneration\n",
476
+ "tokenizer = AutoTokenizer.from_pretrained(CKPT)\n",
477
+ "model = T5ForConditionalGeneration.from_pretrained(CKPT)"
478
+ ]
479
+ },
480
+ {
481
+ "cell_type": "code",
482
+ "execution_count": null,
483
+ "id": "58f4258c",
484
+ "metadata": {},
485
+ "outputs": [],
486
+ "source": [
487
+ "test_data = load_dataset('wikisql', split='test')"
488
+ ]
489
+ },
490
+ {
491
+ "cell_type": "code",
492
+ "execution_count": null,
493
+ "id": "ecb1ddde",
494
+ "metadata": {},
495
+ "outputs": [],
496
+ "source": [
497
+ "def translate_to_sql(text):\n",
498
+ " inputs = tokenizer(text, padding='longest', max_length=64, return_tensors='pt')\n",
499
+ " input_ids = inputs.input_ids\n",
500
+ " attention_mask = inputs.attention_mask\n",
501
+ " output = model.generate(input_ids, attention_mask=attention_mask, max_length=64)\n",
502
+ "\n",
503
+ " return tokenizer.decode(output[0], skip_special_tokens=True)"
504
+ ]
505
+ },
506
+ {
507
+ "cell_type": "code",
508
+ "execution_count": null,
509
+ "id": "506e28e2",
510
+ "metadata": {},
511
+ "outputs": [],
512
+ "source": [
513
+ "for i in range(0,100,10):\n",
514
+ " print('translate to SQL: ' + test_data[i]['question'])\n",
515
+ " print('Predict. :' + translate_to_sql('translate to SQL: ' + test_data[i]['question']))\n",
516
+ " print('Expected: ' + test_data[i]['sql']['human_readable'])\n",
517
+ " print('=================================\\n')"
518
+ ]
519
+ },
520
+ {
521
+ "cell_type": "code",
522
+ "execution_count": null,
523
+ "id": "18f1cdfe",
524
+ "metadata": {},
525
+ "outputs": [],
526
+ "source": [
527
+ "text = \"translate to SQL: Which employee has the highest salary? Columns: employee_id, name, year, parameters, engineer\"\n",
528
+ "translate_to_sql(text)"
529
+ ]
530
+ },
531
+ {
532
+ "cell_type": "code",
533
+ "execution_count": null,
534
+ "id": "8bd0a073",
535
+ "metadata": {},
536
+ "outputs": [],
537
+ "source": []
538
+ }
539
+ ],
540
+ "metadata": {
541
+ "kernelspec": {
542
+ "display_name": "Python 3 (ipykernel)",
543
+ "language": "python",
544
+ "name": "python3"
545
+ },
546
+ "language_info": {
547
+ "codemirror_mode": {
548
+ "name": "ipython",
549
+ "version": 3
550
+ },
551
+ "file_extension": ".py",
552
+ "mimetype": "text/x-python",
553
+ "name": "python",
554
+ "nbconvert_exporter": "python",
555
+ "pygments_lexer": "ipython3",
556
+ "version": "3.8.5"
557
+ }
558
+ },
559
+ "nbformat": 4,
560
+ "nbformat_minor": 5
561
+ }
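
The notebook above fine-tunes t5-base on WikiSQL and pushes the result to the vjt/T5Training repo that the Trainer clones during setup. As a rough guide to using that checkpoint once training finishes, here is a minimal inference sketch (assuming the push completes and the repo is accessible; the prompt layout mirrors format_dataset in the notebook, and the column names are made up for illustration):

    # Sketch: query the fine-tuned checkpoint pushed by the notebook.
    # Assumes the Hub repo "vjt/T5Training" holds a finished model;
    # the prompt format follows format_dataset above.
    from transformers import AutoTokenizer, T5ForConditionalGeneration

    repo_id = "vjt/T5Training"
    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    model = T5ForConditionalGeneration.from_pretrained(repo_id)

    prompt = ("translate to SQL: Which employee has the highest salary? "
              "| table: employee_id, name, salary")
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=64)
    output_ids = model.generate(**inputs, max_length=64)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
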
T5Train.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
config.json ADDED
@@ -0,0 +1,60 @@
1
+ {
2
+ "_name_or_path": "t5-base",
3
+ "architectures": [
4
+ "T5ForConditionalGeneration"
5
+ ],
6
+ "d_ff": 3072,
7
+ "d_kv": 64,
8
+ "d_model": 768,
9
+ "decoder_start_token_id": 0,
10
+ "dense_act_fn": "relu",
11
+ "dropout_rate": 0.1,
12
+ "eos_token_id": 1,
13
+ "feed_forward_proj": "relu",
14
+ "initializer_factor": 1.0,
15
+ "is_encoder_decoder": true,
16
+ "is_gated_act": false,
17
+ "layer_norm_epsilon": 1e-06,
18
+ "model_type": "t5",
19
+ "n_positions": 512,
20
+ "num_decoder_layers": 12,
21
+ "num_heads": 12,
22
+ "num_layers": 12,
23
+ "output_past": true,
24
+ "pad_token_id": 0,
25
+ "relative_attention_max_distance": 128,
26
+ "relative_attention_num_buckets": 32,
27
+ "task_specific_params": {
28
+ "summarization": {
29
+ "early_stopping": true,
30
+ "length_penalty": 2.0,
31
+ "max_length": 200,
32
+ "min_length": 30,
33
+ "no_repeat_ngram_size": 3,
34
+ "num_beams": 4,
35
+ "prefix": "summarize: "
36
+ },
37
+ "translation_en_to_de": {
38
+ "early_stopping": true,
39
+ "max_length": 300,
40
+ "num_beams": 4,
41
+ "prefix": "translate English to German: "
42
+ },
43
+ "translation_en_to_fr": {
44
+ "early_stopping": true,
45
+ "max_length": 300,
46
+ "num_beams": 4,
47
+ "prefix": "translate English to French: "
48
+ },
49
+ "translation_en_to_ro": {
50
+ "early_stopping": true,
51
+ "max_length": 300,
52
+ "num_beams": 4,
53
+ "prefix": "translate English to Romanian: "
54
+ }
55
+ },
56
+ "torch_dtype": "float32",
57
+ "transformers_version": "4.26.0",
58
+ "use_cache": true,
59
+ "vocab_size": 32128
60
+ }
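
config.json records a standard t5-base architecture (d_model 768, 12 encoder and 12 decoder layers, 12 heads, 32128-token vocabulary). A small sketch for inspecting it programmatically, assuming transformers is installed and the file sits in the current directory:

    # Sketch: load the committed config and print the key dimensions.
    # Assumes ./config.json is the file added in this commit.
    from transformers import T5Config

    config = T5Config.from_pretrained(".")
    print(config.d_model, config.num_layers, config.num_decoder_layers, config.num_heads)
    print(config.vocab_size, config.d_ff)
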
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0230022aa7695bf04abd9e9df4c8dc672585e9502d74cf44536af76b08d462b3
3
+ size 891702929
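
pytorch_model.bin is tracked with Git LFS, so the three lines above are only a pointer; the oid is the SHA-256 of the actual ~892 MB weights file. A quick integrity check after fetching the real file with git lfs pull (a sketch, not part of the original commit):

    # Sketch: verify the LFS object against the sha256 oid in its pointer file.
    # Assumes git lfs pull has replaced the pointer with the real weights.
    import hashlib

    expected = "0230022aa7695bf04abd9e9df4c8dc672585e9502d74cf44536af76b08d462b3"

    h = hashlib.sha256()
    with open("pytorch_model.bin", "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)

    print(h.hexdigest() == expected)
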
runs/Jan27_10-28-12_Vince-Desktop/1674815969.93369/events.out.tfevents.1674815969.Vince-Desktop.29244.1 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:153031d0285df04013da2988af5558a36310102c42fcbde5500da89058adbe4d
3
+ size 6026
runs/Jan27_10-28-12_Vince-Desktop/events.out.tfevents.1674815969.Vince-Desktop.29244.0 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bc6cbc5aab81cde1107d9d8466367cb802b080ba4d8b81e3d0a462dbd7404e78
3
+ size 6961
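
The two files under runs/ are TensorBoard event logs written by the Trainer during this epoch. They can be viewed with tensorboard --logdir runs, or read directly in Python; a sketch using TensorBoard's event reader (the tag names depend on what the run actually logged):

    # Sketch: read logged scalars from the committed run directory.
    # Assumes the tensorboard package is installed.
    from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

    ea = EventAccumulator("runs/Jan27_10-28-12_Vince-Desktop")
    ea.Reload()
    tags = ea.Tags()["scalars"]
    print(tags)  # e.g. train/loss, eval/loss (run-dependent)
    if tags:
        for event in ea.Scalars(tags[0]):
            print(event.step, event.value)
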
training_args.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7ece606258a66c05a9a0c6ef55f2f476fa63347892d0d444978458027c340a25
3
+ size 3707
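
training_args.bin is the Seq2SeqTrainingArguments object that the Trainer pickles next to each checkpoint. It can be reloaded to recover the exact hyperparameters used for this run (a sketch; transformers must be importable because unpickling reconstructs the arguments class):

    # Sketch: inspect the saved training arguments.
    # The file is a pickled Python object, not a tensor state dict, so
    # weights_only must stay False (the default on older PyTorch releases).
    import torch

    args = torch.load("training_args.bin", weights_only=False)
    print(args.num_train_epochs, args.per_device_train_batch_size, args.save_strategy)
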