gbarone77 commited on
Commit
ef52d5e
1 Parent(s): 139306f

Upload 8 files

Browse files
.gitattributes CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
config.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "google/mt5-small",
3
+ "architectures": [
4
+ "T5ForConditionalGeneration"
5
+ ],
6
+ "d_ff": 1024,
7
+ "d_kv": 64,
8
+ "d_model": 512,
9
+ "decoder_start_token_id": 0,
10
+ "dense_act_fn": "gelu_new",
11
+ "dropout_rate": 0.1,
12
+ "eos_token_id": 1,
13
+ "feed_forward_proj": "gated-gelu",
14
+ "initializer_factor": 1.0,
15
+ "is_encoder_decoder": true,
16
+ "is_gated_act": true,
17
+ "layer_norm_epsilon": 1e-06,
18
+ "model_type": "t5",
19
+ "num_decoder_layers": 8,
20
+ "num_heads": 6,
21
+ "num_layers": 8,
22
+ "pad_token_id": 0,
23
+ "relative_attention_max_distance": 128,
24
+ "relative_attention_num_buckets": 32,
25
+ "tie_word_embeddings": false,
26
+ "tokenizer_class": "T5Tokenizer",
27
+ "torch_dtype": "float32",
28
+ "transformers_version": "4.27.3",
29
+ "use_cache": true,
30
+ "vocab_size": 250112
31
+ }
generation_config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "decoder_start_token_id": 0,
3
+ "eos_token_id": 1,
4
+ "pad_token_id": 0,
5
+ "transformers_version": "4.27.3"
6
+ }
special_tokens_map.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {
2
+ "eos_token": "</s>",
3
+ "pad_token": "<pad>",
4
+ "unk_token": "<unk>"
5
+ }
spiece.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ef78f86560d809067d12bac6c09f19a462cb3af3f54d2b8acbba26e1433125d6
3
+ size 4309802
tokenizer.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:abe2b5605652fa74942bdf17bccc1636d8f92ed1b7ef33dce78d25f95ef781e0
3
+ size 16330634
tokenizer_config.json ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "additional_special_tokens": null,
3
+ "eos_token": "</s>",
4
+ "extra_ids": 0,
5
+ "model_max_length": 1000000000000000019884624838656,
6
+ "pad_token": "<pad>",
7
+ "sp_model_kwargs": {},
8
+ "special_tokens_map_file": "/root/.cache/huggingface/hub/models--google--mt5-small/snapshots/38f23af8ec210eb6c376d40e9c56bd25a80f195d/special_tokens_map.json",
9
+ "tokenizer_class": "T5Tokenizer",
10
+ "unk_token": "<unk>"
11
+ }
training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:de8cfefd5b49c071bc237de9353ff46ded13e449735463e53a54607af24a0322
3
+ size 3695
txt2sql_mt5_small_training.ipynb ADDED
@@ -0,0 +1,1620 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "e9ca44ab-68d4-4361-a7fb-1f887f1b06c0",
7
+ "metadata": {
8
+ "papermill": {
9
+ "duration": 20.056463,
10
+ "end_time": "2023-02-01T13:28:53.560235",
11
+ "exception": false,
12
+ "start_time": "2023-02-01T13:28:33.503772",
13
+ "status": "completed"
14
+ },
15
+ "tags": []
16
+ },
17
+ "outputs": [],
18
+ "source": [
19
+ "!pip install -q transformers datasets"
20
+ ]
21
+ },
22
+ {
23
+ "cell_type": "code",
24
+ "execution_count": 2,
25
+ "id": "d5482d72-f55e-4b09-befc-a0b71fb0f6b3",
26
+ "metadata": {
27
+ "papermill": {
28
+ "duration": 0.126709,
29
+ "end_time": "2023-02-01T13:28:53.696755",
30
+ "exception": false,
31
+ "start_time": "2023-02-01T13:28:53.570046",
32
+ "status": "completed"
33
+ },
34
+ "tags": []
35
+ },
36
+ "outputs": [
37
+ {
38
+ "name": "stderr",
39
+ "output_type": "stream",
40
+ "text": [
41
+ "Warning: Unexpected command-line argument -f found.\n",
42
+ "Warning: Unexpected command-line argument /root/.local/share/jupyter/runtime/kernel-92e2dce4-3520-4966-a7b3-b12619e1a0d7.json found.\n"
43
+ ]
44
+ }
45
+ ],
46
+ "source": [
47
+ "import valohai\n",
48
+ "\n",
49
+ "valohai.prepare(\n",
50
+ " step='train-model',\n",
51
+ " image='pytorch/pytorch:1.10.0-cuda11.3-cudnn8-runtime', \n",
52
+ " default_parameters={ \n",
53
+ " 'epochs': 10,\n",
54
+ " 'model': 'google/mt5-small',\n",
55
+ " }\n",
56
+ ")\n",
57
+ "output_path = valohai.outputs().path('model')"
58
+ ]
59
+ },
60
+ {
61
+ "cell_type": "code",
62
+ "execution_count": 3,
63
+ "id": "7d8321e3-caf8-4f1b-8f4e-568df5e9608c",
64
+ "metadata": {
65
+ "papermill": {
66
+ "duration": 1.139645,
67
+ "end_time": "2023-02-01T13:28:54.844272",
68
+ "exception": false,
69
+ "start_time": "2023-02-01T13:28:53.704627",
70
+ "status": "completed"
71
+ },
72
+ "tags": []
73
+ },
74
+ "outputs": [
75
+ {
76
+ "name": "stderr",
77
+ "output_type": "stream",
78
+ "text": [
79
+ "/opt/conda/lib/python3.7/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
80
+ " from .autonotebook import tqdm as notebook_tqdm\n"
81
+ ]
82
+ },
83
+ {
84
+ "name": "stdout",
85
+ "output_type": "stream",
86
+ "text": [
87
+ "cuda\n"
88
+ ]
89
+ }
90
+ ],
91
+ "source": [
92
+ "import torch\n",
93
+ "\n",
94
+ "torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
95
+ "print(torch_device)"
96
+ ]
97
+ },
98
+ {
99
+ "cell_type": "code",
100
+ "execution_count": 4,
101
+ "id": "f4484a17-8ba2-45a8-b537-24c44bb5bb7c",
102
+ "metadata": {
103
+ "papermill": {
104
+ "duration": 0.782457,
105
+ "end_time": "2023-02-01T13:28:55.633345",
106
+ "exception": false,
107
+ "start_time": "2023-02-01T13:28:54.850888",
108
+ "status": "completed"
109
+ },
110
+ "tags": []
111
+ },
112
+ "outputs": [
113
+ {
114
+ "name": "stdout",
115
+ "output_type": "stream",
116
+ "text": [
117
+ "Mon Mar 27 07:02:29 2023 \n",
118
+ "+-----------------------------------------------------------------------------+\n",
119
+ "| NVIDIA-SMI 470.129.06 Driver Version: 470.129.06 CUDA Version: 11.4 |\n",
120
+ "|-------------------------------+----------------------+----------------------+\n",
121
+ "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
122
+ "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n",
123
+ "| | | MIG M. |\n",
124
+ "|===============================+======================+======================|\n",
125
+ "| 0 NVIDIA RTX A6000 On | 00000000:05:00.0 Off | Off |\n",
126
+ "| 30% 31C P8 15W / 300W | 3MiB / 48685MiB | 0% Default |\n",
127
+ "| | | N/A |\n",
128
+ "+-------------------------------+----------------------+----------------------+\n",
129
+ " \n",
130
+ "+-----------------------------------------------------------------------------+\n",
131
+ "| Processes: |\n",
132
+ "| GPU GI CI PID Type Process name GPU Memory |\n",
133
+ "| ID ID Usage |\n",
134
+ "|=============================================================================|\n",
135
+ "| No running processes found |\n",
136
+ "+-----------------------------------------------------------------------------+\n"
137
+ ]
138
+ }
139
+ ],
140
+ "source": [
141
+ "! nvidia-smi"
142
+ ]
143
+ },
144
+ {
145
+ "cell_type": "code",
146
+ "execution_count": 5,
147
+ "id": "73334e06-3bf2-4e94-9870-fe3a487398c3",
148
+ "metadata": {
149
+ "papermill": {
150
+ "duration": 45.306651,
151
+ "end_time": "2023-02-01T13:29:40.951034",
152
+ "exception": false,
153
+ "start_time": "2023-02-01T13:28:55.644383",
154
+ "status": "completed"
155
+ },
156
+ "tags": []
157
+ },
158
+ "outputs": [
159
+ {
160
+ "name": "stderr",
161
+ "output_type": "stream",
162
+ "text": [
163
+ "Found cached dataset wikisql (/root/.cache/huggingface/datasets/wikisql/default/0.1.0/7037bfe6a42b1ca2b6ac3ccacba5253b1825d31379e9cc626fc79a620977252d)\n",
164
+ "Found cached dataset wikisql (/root/.cache/huggingface/datasets/wikisql/default/0.1.0/7037bfe6a42b1ca2b6ac3ccacba5253b1825d31379e9cc626fc79a620977252d)\n"
165
+ ]
166
+ }
167
+ ],
168
+ "source": [
169
+ "from datasets import load_dataset\n",
170
+ "\n",
171
+ "train_data = load_dataset('wikisql', split='train+validation')\n",
172
+ "test_data = load_dataset('wikisql', split='test')"
173
+ ]
174
+ },
175
+ {
176
+ "cell_type": "code",
177
+ "execution_count": 6,
178
+ "id": "cf5379de-aeb5-4a1c-8d23-9ad1e56dc445",
179
+ "metadata": {
180
+ "papermill": {
181
+ "duration": 0.038407,
182
+ "end_time": "2023-02-01T13:29:41.013026",
183
+ "exception": false,
184
+ "start_time": "2023-02-01T13:29:40.974619",
185
+ "status": "completed"
186
+ },
187
+ "tags": []
188
+ },
189
+ "outputs": [],
190
+ "source": [
191
+ "def format_dataset(example):\n",
192
+ " return {'input': 'translate to SQL: ' + example['question'] + ' table ID: ' + ', '.join(str(x) for x in example['table']['header']), 'target': example['sql']['human_readable']}"
193
+ ]
194
+ },
195
+ {
196
+ "cell_type": "code",
197
+ "execution_count": 7,
198
+ "id": "1ce6feef-eab2-4b7a-86f0-c663e5790c5d",
199
+ "metadata": {
200
+ "papermill": {
201
+ "duration": 17.729786,
202
+ "end_time": "2023-02-01T13:29:58.768354",
203
+ "exception": false,
204
+ "start_time": "2023-02-01T13:29:41.038568",
205
+ "status": "completed"
206
+ },
207
+ "tags": []
208
+ },
209
+ "outputs": [
210
+ {
211
+ "name": "stderr",
212
+ "output_type": "stream",
213
+ "text": [
214
+ "Loading cached processed dataset at /root/.cache/huggingface/datasets/wikisql/default/0.1.0/7037bfe6a42b1ca2b6ac3ccacba5253b1825d31379e9cc626fc79a620977252d/cache-1ea43016a8276f85.arrow\n"
215
+ ]
216
+ }
217
+ ],
218
+ "source": [
219
+ "train_data = train_data.map(format_dataset, remove_columns=train_data.column_names)"
220
+ ]
221
+ },
222
+ {
223
+ "cell_type": "code",
224
+ "execution_count": 8,
225
+ "id": "03862b72-56e4-40ab-aae2-81604f69d608",
226
+ "metadata": {
227
+ "papermill": {
228
+ "duration": 4.566604,
229
+ "end_time": "2023-02-01T13:30:03.373278",
230
+ "exception": false,
231
+ "start_time": "2023-02-01T13:29:58.806674",
232
+ "status": "completed"
233
+ },
234
+ "tags": []
235
+ },
236
+ "outputs": [
237
+ {
238
+ "name": "stderr",
239
+ "output_type": "stream",
240
+ "text": [
241
+ "Loading cached processed dataset at /root/.cache/huggingface/datasets/wikisql/default/0.1.0/7037bfe6a42b1ca2b6ac3ccacba5253b1825d31379e9cc626fc79a620977252d/cache-b9e3da7e258b7aa5.arrow\n"
242
+ ]
243
+ }
244
+ ],
245
+ "source": [
246
+ "test_data = test_data.map(format_dataset, remove_columns=test_data.column_names)"
247
+ ]
248
+ },
249
+ {
250
+ "cell_type": "code",
251
+ "execution_count": 9,
252
+ "id": "6246e5c3-4d91-4c65-9ee9-bfc366339e97",
253
+ "metadata": {
254
+ "tags": []
255
+ },
256
+ "outputs": [
257
+ {
258
+ "name": "stdout",
259
+ "output_type": "stream",
260
+ "text": [
261
+ "Requirement already satisfied: sentencepiece in /opt/conda/lib/python3.7/site-packages (0.1.97)\n",
262
+ "Collecting protobuf==3.20.*\n",
263
+ " Downloading protobuf-3.20.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl (1.0 MB)\n",
264
+ "\u001b[K |████████████████████████████████| 1.0 MB 4.4 MB/s eta 0:00:01\n",
265
+ "\u001b[?25hInstalling collected packages: protobuf\n",
266
+ " Attempting uninstall: protobuf\n",
267
+ " Found existing installation: protobuf 4.22.1\n",
268
+ " Uninstalling protobuf-4.22.1:\n",
269
+ " Successfully uninstalled protobuf-4.22.1\n",
270
+ "Successfully installed protobuf-3.20.3\n"
271
+ ]
272
+ }
273
+ ],
274
+ "source": [
275
+ "!pip install sentencepiece\n",
276
+ "!pip install protobuf==3.20.*"
277
+ ]
278
+ },
279
+ {
280
+ "cell_type": "code",
281
+ "execution_count": 10,
282
+ "id": "f162ac75-aeda-409a-af8c-f70f5a1d7cbd",
283
+ "metadata": {
284
+ "papermill": {
285
+ "duration": 16.204849,
286
+ "end_time": "2023-02-01T13:30:19.617815",
287
+ "exception": false,
288
+ "start_time": "2023-02-01T13:30:03.412966",
289
+ "status": "completed"
290
+ },
291
+ "tags": []
292
+ },
293
+ "outputs": [
294
+ {
295
+ "name": "stderr",
296
+ "output_type": "stream",
297
+ "text": [
298
+ "/opt/conda/lib/python3.7/site-packages/transformers/convert_slow_tokenizer.py:447: UserWarning: The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option which is not implemented in the fast tokenizers. In practice this means that the fast version of the tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these unknown tokens into a sequence of byte tokens matching the original piece of text.\n",
299
+ " \"The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option\"\n",
300
+ "You are using a model of type mt5 to instantiate a model of type t5. This is not supported for all configurations of models and can yield errors.\n",
301
+ "Downloading pytorch_model.bin: 100%|██████████| 1.20G/1.20G [00:16<00:00, 72.6MB/s]\n",
302
+ "Downloading (…)neration_config.json: 100%|██████████| 147/147 [00:00<00:00, 31.2kB/s]\n"
303
+ ]
304
+ }
305
+ ],
306
+ "source": [
307
+ "CKPT = valohai.parameters(\"model\").value\n",
308
+ "from transformers import AutoTokenizer, T5ForConditionalGeneration\n",
309
+ "tokenizer = AutoTokenizer.from_pretrained(CKPT)\n",
310
+ "model = T5ForConditionalGeneration.from_pretrained(CKPT).to(torch_device)"
311
+ ]
312
+ },
313
+ {
314
+ "cell_type": "code",
315
+ "execution_count": 11,
316
+ "id": "6e2c9c3b-dfd1-4a34-ad77-3c8f69ac4854",
317
+ "metadata": {
318
+ "papermill": {
319
+ "duration": 2.058386,
320
+ "end_time": "2023-02-01T13:30:21.722091",
321
+ "exception": false,
322
+ "start_time": "2023-02-01T13:30:19.663705",
323
+ "status": "completed"
324
+ },
325
+ "tags": []
326
+ },
327
+ "outputs": [
328
+ {
329
+ "name": "stderr",
330
+ "output_type": "stream",
331
+ "text": [
332
+ " "
333
+ ]
334
+ },
335
+ {
336
+ "name": "stdout",
337
+ "output_type": "stream",
338
+ "text": [
339
+ "Input Mean: 47.4798, %-Input > 256:0.0, %-Input > 128:0.001, %-Input > 64:0.0684 Output Mean:19.4288, %-Output > 256:0.0, %-Output > 128:0.0002, %-Output > 64:0.0004\n"
340
+ ]
341
+ },
342
+ {
343
+ "name": "stderr",
344
+ "output_type": "stream",
345
+ "text": [
346
+ "\r"
347
+ ]
348
+ }
349
+ ],
350
+ "source": [
351
+ "# map article and summary len to dict as well as if sample is longer than 512 tokens\n",
352
+ "def map_to_length(x):\n",
353
+ " x[\"input_len\"] = len(tokenizer(x[\"input\"]).input_ids)\n",
354
+ " x[\"input_longer_256\"] = int(x[\"input_len\"] > 256)\n",
355
+ " x[\"input_longer_128\"] = int(x[\"input_len\"] > 128)\n",
356
+ " x[\"input_longer_64\"] = int(x[\"input_len\"] > 64)\n",
357
+ " x[\"out_len\"] = len(tokenizer(x[\"target\"]).input_ids)\n",
358
+ " x[\"out_longer_256\"] = int(x[\"out_len\"] > 256)\n",
359
+ " x[\"out_longer_128\"] = int(x[\"out_len\"] > 128)\n",
360
+ " x[\"out_longer_64\"] = int(x[\"out_len\"] > 64)\n",
361
+ " return x\n",
362
+ "\n",
363
+ "sample_size = 10000\n",
364
+ "data_stats = train_data.select(range(sample_size)).map(map_to_length, num_proc=4)\n",
365
+ "\n",
366
+ "def compute_and_print_stats(x):\n",
367
+ " if len(x[\"input_len\"]) == sample_size:\n",
368
+ " print(\n",
369
+ " \"Input Mean: {}, %-Input > 256:{}, %-Input > 128:{}, %-Input > 64:{} Output Mean:{}, %-Output > 256:{}, %-Output > 128:{}, %-Output > 64:{}\".format(\n",
370
+ " sum(x[\"input_len\"]) / sample_size,\n",
371
+ " sum(x[\"input_longer_256\"]) / sample_size,\n",
372
+ " sum(x[\"input_longer_128\"]) / sample_size,\n",
373
+ " sum(x[\"input_longer_64\"]) / sample_size, \n",
374
+ " sum(x[\"out_len\"]) / sample_size,\n",
375
+ " sum(x[\"out_longer_256\"]) / sample_size,\n",
376
+ " sum(x[\"out_longer_128\"]) / sample_size,\n",
377
+ " sum(x[\"out_longer_64\"]) / sample_size,\n",
378
+ " )\n",
379
+ " )\n",
380
+ "\n",
381
+ "output = data_stats.map(\n",
382
+ " compute_and_print_stats, \n",
383
+ " batched=True,\n",
384
+ " batch_size=-1,\n",
385
+ ") "
386
+ ]
387
+ },
388
+ {
389
+ "cell_type": "code",
390
+ "execution_count": 12,
391
+ "id": "d6b69f36-bd57-46e4-b77e-a0017ffbf64e",
392
+ "metadata": {
393
+ "papermill": {
394
+ "duration": 0.063495,
395
+ "end_time": "2023-02-01T13:30:21.834853",
396
+ "exception": false,
397
+ "start_time": "2023-02-01T13:30:21.771358",
398
+ "status": "completed"
399
+ },
400
+ "tags": []
401
+ },
402
+ "outputs": [],
403
+ "source": [
404
+ "# tokenize the examples\n",
405
+ "def convert_to_features(example_batch):\n",
406
+ " input_encodings = tokenizer.batch_encode_plus(example_batch['input'], pad_to_max_length=True, max_length=100, truncation=True)\n",
407
+ " target_encodings = tokenizer.batch_encode_plus(example_batch['target'], pad_to_max_length=True, max_length=100, truncation=True)\n",
408
+ "\n",
409
+ " encodings = {\n",
410
+ " 'input_ids': input_encodings['input_ids'], \n",
411
+ " 'attention_mask': input_encodings['attention_mask'],\n",
412
+ " 'labels': target_encodings['input_ids'],\n",
413
+ " 'decoder_attention_mask': target_encodings['attention_mask']\n",
414
+ " }\n",
415
+ "\n",
416
+ " return encodings "
417
+ ]
418
+ },
419
+ {
420
+ "cell_type": "code",
421
+ "execution_count": 13,
422
+ "id": "67b3b61d-e1ae-435e-8f55-46fa219ea3e2",
423
+ "metadata": {
424
+ "papermill": {
425
+ "duration": 23.172287,
426
+ "end_time": "2023-02-01T13:30:45.056685",
427
+ "exception": false,
428
+ "start_time": "2023-02-01T13:30:21.884398",
429
+ "status": "completed"
430
+ },
431
+ "tags": []
432
+ },
433
+ "outputs": [
434
+ {
435
+ "name": "stderr",
436
+ "output_type": "stream",
437
+ "text": [
438
+ "Map: 0%| | 0/64776 [00:00<?, ? examples/s]/opt/conda/lib/python3.7/site-packages/transformers/tokenization_utils_base.py:2352: FutureWarning: The `pad_to_max_length` argument is deprecated and will be removed in a future version, use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or use `padding='max_length'` to pad to a max length. In this case, you can give a specific length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the maximal input size of the model (e.g. 512 for Bert).\n",
439
+ " FutureWarning,\n",
440
+ " \r"
441
+ ]
442
+ }
443
+ ],
444
+ "source": [
445
+ "train_data = train_data.map(convert_to_features, batched=True, remove_columns=train_data.column_names)\n",
446
+ "test_data = test_data.map(convert_to_features, batched=True, remove_columns=test_data.column_names)\n",
447
+ "\n",
448
+ "columns = ['input_ids', 'attention_mask', 'labels', 'decoder_attention_mask']\n",
449
+ "\n",
450
+ "train_data.set_format(type='torch', columns=columns)\n",
451
+ "test_data.set_format(type='torch', columns=columns)"
452
+ ]
453
+ },
454
+ {
455
+ "cell_type": "code",
456
+ "execution_count": 14,
457
+ "id": "69d37693-8c5a-45c2-a9fd-dfab43ed71fa",
458
+ "metadata": {
459
+ "papermill": {
460
+ "duration": 0.106751,
461
+ "end_time": "2023-02-01T13:30:45.221681",
462
+ "exception": false,
463
+ "start_time": "2023-02-01T13:30:45.114930",
464
+ "status": "completed"
465
+ },
466
+ "tags": []
467
+ },
468
+ "outputs": [],
469
+ "source": [
470
+ "from transformers import Seq2SeqTrainer\n",
471
+ "from transformers import Seq2SeqTrainingArguments"
472
+ ]
473
+ },
474
+ {
475
+ "cell_type": "code",
476
+ "execution_count": 15,
477
+ "id": "644e81ec-1c23-4a2d-a488-f9354c237815",
478
+ "metadata": {
479
+ "papermill": {
480
+ "duration": 0.069207,
481
+ "end_time": "2023-02-01T13:30:45.347009",
482
+ "exception": false,
483
+ "start_time": "2023-02-01T13:30:45.277802",
484
+ "status": "completed"
485
+ },
486
+ "tags": []
487
+ },
488
+ "outputs": [],
489
+ "source": [
490
+ "# set training arguments - Feel free to adapt it\n",
491
+ "training_args = Seq2SeqTrainingArguments(\n",
492
+ " output_dir=output_path,\n",
493
+ " per_device_train_batch_size=16,\n",
494
+ " num_train_epochs=valohai.parameters(\"epochs\").value,\n",
495
+ " per_device_eval_batch_size=16,\n",
496
+ " predict_with_generate=True,\n",
497
+ " evaluation_strategy=\"epoch\",\n",
498
+ " do_train=True,\n",
499
+ " do_eval=True,\n",
500
+ " logging_steps=500,\n",
501
+ " save_strategy=\"epoch\",\n",
502
+ " #save_steps=1000,\n",
503
+ " #eval_steps=1000,\n",
504
+ " overwrite_output_dir=True,\n",
505
+ " save_total_limit=1,\n",
506
+ " load_best_model_at_end=True,\n",
507
+ " push_to_hub=False\n",
508
+ " #fp16=True, \n",
509
+ ")"
510
+ ]
511
+ },
512
+ {
513
+ "cell_type": "code",
514
+ "execution_count": 16,
515
+ "id": "46d2344c-df83-4495-b700-71e1308f60f1",
516
+ "metadata": {
517
+ "papermill": {
518
+ "duration": 4.757895,
519
+ "end_time": "2023-02-01T13:30:50.160794",
520
+ "exception": false,
521
+ "start_time": "2023-02-01T13:30:45.402899",
522
+ "status": "completed"
523
+ },
524
+ "tags": []
525
+ },
526
+ "outputs": [],
527
+ "source": [
528
+ "! pip install -q rouge_score"
529
+ ]
530
+ },
531
+ {
532
+ "cell_type": "code",
533
+ "execution_count": 17,
534
+ "id": "63ca930c-9cd3-4880-beb6-dd44057069bb",
535
+ "metadata": {
536
+ "papermill": {
537
+ "duration": 1.098239,
538
+ "end_time": "2023-02-01T13:30:51.318015",
539
+ "exception": false,
540
+ "start_time": "2023-02-01T13:30:50.219776",
541
+ "status": "completed"
542
+ },
543
+ "tags": []
544
+ },
545
+ "outputs": [
546
+ {
547
+ "name": "stderr",
548
+ "output_type": "stream",
549
+ "text": [
550
+ "/opt/conda/lib/python3.7/site-packages/ipykernel_launcher.py:2: FutureWarning: load_metric is deprecated and will be removed in the next major version of datasets. Use 'evaluate.load' instead, from the new library 🤗 Evaluate: https://huggingface.co/docs/evaluate\n",
551
+ " \n"
552
+ ]
553
+ }
554
+ ],
555
+ "source": [
556
+ "from datasets import load_metric\n",
557
+ "rouge = load_metric(\"rouge\")\n",
558
+ "\n",
559
+ "def compute_metrics(pred):\n",
560
+ " labels_ids = pred.label_ids\n",
561
+ " pred_ids = pred.predictions\n",
562
+ "\n",
563
+ " # all unnecessary tokens are removed\n",
564
+ " pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)\n",
565
+ " labels_ids[labels_ids == -100] = tokenizer.pad_token_id\n",
566
+ " label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)\n",
567
+ "\n",
568
+ " rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=[\"rouge2\"])[\"rouge2\"].mid\n",
569
+ "\n",
570
+ " return {\n",
571
+ " \"rouge2_precision\": round(rouge_output.precision, 4),\n",
572
+ " \"rouge2_recall\": round(rouge_output.recall, 4),\n",
573
+ " \"rouge2_fmeasure\": round(rouge_output.fmeasure, 4),\n",
574
+ " }"
575
+ ]
576
+ },
577
+ {
578
+ "cell_type": "code",
579
+ "execution_count": 18,
580
+ "id": "2977e566-8714-4164-b7ad-2706dbd26be8",
581
+ "metadata": {
582
+ "papermill": {
583
+ "duration": 0.074325,
584
+ "end_time": "2023-02-01T13:30:51.451387",
585
+ "exception": false,
586
+ "start_time": "2023-02-01T13:30:51.377062",
587
+ "status": "completed"
588
+ },
589
+ "tags": []
590
+ },
591
+ "outputs": [],
592
+ "source": [
593
+ "# instantiate trainer\n",
594
+ "trainer = Seq2SeqTrainer(\n",
595
+ " model=model,\n",
596
+ " args=training_args,\n",
597
+ " compute_metrics=compute_metrics,\n",
598
+ " train_dataset=train_data,\n",
599
+ " eval_dataset=test_data,\n",
600
+ ")"
601
+ ]
602
+ },
603
+ {
604
+ "cell_type": "code",
605
+ "execution_count": 19,
606
+ "id": "8dce01a3-61b2-4cb4-b5d2-319d0e946083",
607
+ "metadata": {
608
+ "papermill": {
609
+ "duration": 227.616733,
610
+ "end_time": "2023-02-01T13:34:39.125675",
611
+ "exception": false,
612
+ "start_time": "2023-02-01T13:30:51.508942",
613
+ "status": "completed"
614
+ },
615
+ "tags": []
616
+ },
617
+ "outputs": [
618
+ {
619
+ "data": {
620
+ "text/html": [
621
+ "\n",
622
+ " <div>\n",
623
+ " \n",
624
+ " <progress value='1986' max='993' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
625
+ " [993/993 14:21]\n",
626
+ " </div>\n",
627
+ " "
628
+ ],
629
+ "text/plain": [
630
+ "<IPython.core.display.HTML object>"
631
+ ]
632
+ },
633
+ "metadata": {},
634
+ "output_type": "display_data"
635
+ },
636
+ {
637
+ "data": {
638
+ "text/plain": [
639
+ "{'eval_loss': 42.09397506713867,\n",
640
+ " 'eval_rouge2_precision': 0.002,\n",
641
+ " 'eval_rouge2_recall': 0.0009,\n",
642
+ " 'eval_rouge2_fmeasure': 0.0012,\n",
643
+ " 'eval_runtime': 77.1,\n",
644
+ " 'eval_samples_per_second': 205.94,\n",
645
+ " 'eval_steps_per_second': 12.879}"
646
+ ]
647
+ },
648
+ "execution_count": 19,
649
+ "metadata": {},
650
+ "output_type": "execute_result"
651
+ }
652
+ ],
653
+ "source": [
654
+ "trainer.evaluate()"
655
+ ]
656
+ },
657
+ {
658
+ "cell_type": "code",
659
+ "execution_count": 20,
660
+ "id": "55e1a216-8034-49ef-aea4-5fee281d07f3",
661
+ "metadata": {
662
+ "papermill": {
663
+ "duration": 6776.251162,
664
+ "end_time": "2023-02-01T15:27:35.554942",
665
+ "exception": false,
666
+ "start_time": "2023-02-01T13:34:39.303780",
667
+ "status": "completed"
668
+ },
669
+ "tags": []
670
+ },
671
+ "outputs": [
672
+ {
673
+ "name": "stderr",
674
+ "output_type": "stream",
675
+ "text": [
676
+ "/opt/conda/lib/python3.7/site-packages/transformers/optimization.py:395: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
677
+ " FutureWarning,\n"
678
+ ]
679
+ },
680
+ {
681
+ "data": {
682
+ "text/html": [
683
+ "\n",
684
+ " <div>\n",
685
+ " \n",
686
+ " <progress value='40490' max='40490' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
687
+ " [40490/40490 2:12:07, Epoch 10/10]\n",
688
+ " </div>\n",
689
+ " <table border=\"1\" class=\"dataframe\">\n",
690
+ " <thead>\n",
691
+ " <tr style=\"text-align: left;\">\n",
692
+ " <th>Epoch</th>\n",
693
+ " <th>Training Loss</th>\n",
694
+ " <th>Validation Loss</th>\n",
695
+ " <th>Rouge2 Precision</th>\n",
696
+ " <th>Rouge2 Recall</th>\n",
697
+ " <th>Rouge2 Fmeasure</th>\n",
698
+ " </tr>\n",
699
+ " </thead>\n",
700
+ " <tbody>\n",
701
+ " <tr>\n",
702
+ " <td>1</td>\n",
703
+ " <td>0.103200</td>\n",
704
+ " <td>0.051379</td>\n",
705
+ " <td>0.901000</td>\n",
706
+ " <td>0.817300</td>\n",
707
+ " <td>0.849700</td>\n",
708
+ " </tr>\n",
709
+ " <tr>\n",
710
+ " <td>2</td>\n",
711
+ " <td>0.065800</td>\n",
712
+ " <td>0.038024</td>\n",
713
+ " <td>0.917400</td>\n",
714
+ " <td>0.838200</td>\n",
715
+ " <td>0.869300</td>\n",
716
+ " </tr>\n",
717
+ " <tr>\n",
718
+ " <td>3</td>\n",
719
+ " <td>0.054700</td>\n",
720
+ " <td>0.033012</td>\n",
721
+ " <td>0.923000</td>\n",
722
+ " <td>0.844100</td>\n",
723
+ " <td>0.875000</td>\n",
724
+ " </tr>\n",
725
+ " <tr>\n",
726
+ " <td>4</td>\n",
727
+ " <td>0.045900</td>\n",
728
+ " <td>0.030169</td>\n",
729
+ " <td>0.928600</td>\n",
730
+ " <td>0.847300</td>\n",
731
+ " <td>0.880000</td>\n",
732
+ " </tr>\n",
733
+ " <tr>\n",
734
+ " <td>5</td>\n",
735
+ " <td>0.040100</td>\n",
736
+ " <td>0.028730</td>\n",
737
+ " <td>0.930800</td>\n",
738
+ " <td>0.849800</td>\n",
739
+ " <td>0.882400</td>\n",
740
+ " </tr>\n",
741
+ " <tr>\n",
742
+ " <td>6</td>\n",
743
+ " <td>0.039300</td>\n",
744
+ " <td>0.027651</td>\n",
745
+ " <td>0.931800</td>\n",
746
+ " <td>0.850700</td>\n",
747
+ " <td>0.883300</td>\n",
748
+ " </tr>\n",
749
+ " <tr>\n",
750
+ " <td>7</td>\n",
751
+ " <td>0.036000</td>\n",
752
+ " <td>0.027332</td>\n",
753
+ " <td>0.932900</td>\n",
754
+ " <td>0.852000</td>\n",
755
+ " <td>0.884600</td>\n",
756
+ " </tr>\n",
757
+ " <tr>\n",
758
+ " <td>8</td>\n",
759
+ " <td>0.033500</td>\n",
760
+ " <td>0.026453</td>\n",
761
+ " <td>0.933100</td>\n",
762
+ " <td>0.852300</td>\n",
763
+ " <td>0.884900</td>\n",
764
+ " </tr>\n",
765
+ " <tr>\n",
766
+ " <td>9</td>\n",
767
+ " <td>0.032800</td>\n",
768
+ " <td>0.026168</td>\n",
769
+ " <td>0.934200</td>\n",
770
+ " <td>0.853100</td>\n",
771
+ " <td>0.885800</td>\n",
772
+ " </tr>\n",
773
+ " <tr>\n",
774
+ " <td>10</td>\n",
775
+ " <td>0.032300</td>\n",
776
+ " <td>0.026122</td>\n",
777
+ " <td>0.934300</td>\n",
778
+ " <td>0.853100</td>\n",
779
+ " <td>0.885900</td>\n",
780
+ " </tr>\n",
781
+ " </tbody>\n",
782
+ "</table><p>"
783
+ ],
784
+ "text/plain": [
785
+ "<IPython.core.display.HTML object>"
786
+ ]
787
+ },
788
+ "metadata": {},
789
+ "output_type": "display_data"
790
+ },
791
+ {
792
+ "data": {
793
+ "text/plain": [
794
+ "TrainOutput(global_step=40490, training_loss=0.2895770631857524, metrics={'train_runtime': 7927.7437, 'train_samples_per_second': 81.708, 'train_steps_per_second': 5.107, 'total_flos': 6.689509761024e+16, 'train_loss': 0.2895770631857524, 'epoch': 10.0})"
795
+ ]
796
+ },
797
+ "execution_count": 20,
798
+ "metadata": {},
799
+ "output_type": "execute_result"
800
+ }
801
+ ],
802
+ "source": [
803
+ "trainer.train()"
804
+ ]
805
+ },
806
+ {
807
+ "cell_type": "code",
808
+ "execution_count": 21,
809
+ "id": "0bd48e53-90c3-483f-b17d-ac996f801977",
810
+ "metadata": {
811
+ "papermill": {
812
+ "duration": 1.119331,
813
+ "end_time": "2023-02-01T15:27:37.410693",
814
+ "exception": false,
815
+ "start_time": "2023-02-01T15:27:36.291362",
816
+ "status": "completed"
817
+ },
818
+ "tags": []
819
+ },
820
+ "outputs": [
821
+ {
822
+ "data": {
823
+ "text/plain": [
824
+ "('/valohai/outputs/model/tokenizer_config.json',\n",
825
+ " '/valohai/outputs/model/special_tokens_map.json',\n",
826
+ " '/valohai/outputs/model/spiece.model',\n",
827
+ " '/valohai/outputs/model/added_tokens.json',\n",
828
+ " '/valohai/outputs/model/tokenizer.json')"
829
+ ]
830
+ },
831
+ "execution_count": 21,
832
+ "metadata": {},
833
+ "output_type": "execute_result"
834
+ }
835
+ ],
836
+ "source": [
837
+ "trainer.save_model(output_path)\n",
838
+ "tokenizer.save_pretrained(output_path)"
839
+ ]
840
+ },
841
+ {
842
+ "cell_type": "code",
843
+ "execution_count": 22,
844
+ "id": "38dbdcd0-14b3-4270-8ad8-03059b6d63de",
845
+ "metadata": {
846
+ "papermill": {
847
+ "duration": 1.717893,
848
+ "end_time": "2023-02-01T15:27:39.866396",
849
+ "exception": false,
850
+ "start_time": "2023-02-01T15:27:38.148503",
851
+ "status": "completed"
852
+ },
853
+ "tags": []
854
+ },
855
+ "outputs": [],
856
+ "source": [
857
+ "CKPT = output_path\n",
858
+ "\n",
859
+ "tokenizer = AutoTokenizer.from_pretrained(CKPT, local_files_only=True)\n",
860
+ "model = T5ForConditionalGeneration.from_pretrained(CKPT, local_files_only=True).to(torch_device)"
861
+ ]
862
+ },
863
+ {
864
+ "cell_type": "code",
865
+ "execution_count": 23,
866
+ "id": "90587462-5328-4f32-bad9-4fbfb00d7bb7",
867
+ "metadata": {
868
+ "papermill": {
869
+ "duration": 18.569639,
870
+ "end_time": "2023-02-01T15:27:59.089768",
871
+ "exception": false,
872
+ "start_time": "2023-02-01T15:27:40.520129",
873
+ "status": "completed"
874
+ },
875
+ "tags": []
876
+ },
877
+ "outputs": [
878
+ {
879
+ "name": "stdout",
880
+ "output_type": "stream",
881
+ "text": [
882
+ "Requirement already satisfied: sentencepiece in /opt/conda/lib/python3.7/site-packages (0.1.97)\n",
883
+ "Collecting pandasql\n",
884
+ " Downloading pandasql-0.7.3.tar.gz (26 kB)\n",
885
+ "Requirement already satisfied: numpy in /opt/conda/lib/python3.7/site-packages (from pandasql) (1.21.2)\n",
886
+ "Requirement already satisfied: pandas in /opt/conda/lib/python3.7/site-packages (from pandasql) (1.3.5)\n",
887
+ "Collecting sqlalchemy\n",
888
+ " Downloading SQLAlchemy-2.0.7-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.7 MB)\n",
889
+ "\u001b[K |████████████████████████████████| 2.7 MB 6.9 MB/s eta 0:00:01\n",
890
+ "\u001b[?25hRequirement already satisfied: python-dateutil>=2.7.3 in /opt/conda/lib/python3.7/site-packages (from pandas->pandasql) (2.8.2)\n",
891
+ "Requirement already satisfied: pytz>=2017.3 in /opt/conda/lib/python3.7/site-packages (from pandas->pandasql) (2021.3)\n",
892
+ "Requirement already satisfied: six>=1.5 in /opt/conda/lib/python3.7/site-packages (from python-dateutil>=2.7.3->pandas->pandasql) (1.16.0)\n",
893
+ "Requirement already satisfied: typing-extensions>=4.2.0 in /opt/conda/lib/python3.7/site-packages (from sqlalchemy->pandasql) (4.5.0)\n",
894
+ "Collecting greenlet!=0.4.17\n",
895
+ " Downloading greenlet-2.0.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (566 kB)\n",
896
+ "\u001b[K |████████████████████████████████| 566 kB 111.8 MB/s eta 0:00:01\n",
897
+ "\u001b[?25hRequirement already satisfied: importlib-metadata in /opt/conda/lib/python3.7/site-packages (from sqlalchemy->pandasql) (6.1.0)\n",
898
+ "Requirement already satisfied: zipp>=0.5 in /opt/conda/lib/python3.7/site-packages (from importlib-metadata->sqlalchemy->pandasql) (3.15.0)\n",
899
+ "Building wheels for collected packages: pandasql\n",
900
+ " Building wheel for pandasql (setup.py) ... \u001b[?25ldone\n",
901
+ "\u001b[?25h Created wheel for pandasql: filename=pandasql-0.7.3-py3-none-any.whl size=26782 sha256=110b83989487b7b983fb80e3ede92a519027d1ddfd6988e2012175878ee93522\n",
902
+ " Stored in directory: /root/.cache/pip/wheels/5c/4b/ec/41f4e116c8053c3654e2c2a47c62b4fca34cc67ef7b55deb7f\n",
903
+ "Successfully built pandasql\n",
904
+ "Installing collected packages: greenlet, sqlalchemy, pandasql\n",
905
+ "Successfully installed greenlet-2.0.2 pandasql-0.7.3 sqlalchemy-2.0.7\n",
906
+ "Collecting python-Levenshtein\n",
907
+ " Downloading python_Levenshtein-0.20.9-py3-none-any.whl (9.4 kB)\n",
908
+ "Collecting Levenshtein==0.20.9\n",
909
+ " Downloading Levenshtein-0.20.9-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (175 kB)\n",
910
+ "\u001b[K |████████████████████████████████| 175 kB 4.3 MB/s eta 0:00:01\n",
911
+ "\u001b[?25hCollecting rapidfuzz<3.0.0,>=2.3.0\n",
912
+ " Downloading rapidfuzz-2.13.7-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.2 MB)\n",
913
+ "\u001b[K |████████████████████████████████| 2.2 MB 59.5 MB/s eta 0:00:01\n",
914
+ "\u001b[?25hInstalling collected packages: rapidfuzz, Levenshtein, python-Levenshtein\n",
915
+ "Successfully installed Levenshtein-0.20.9 python-Levenshtein-0.20.9 rapidfuzz-2.13.7\n",
916
+ "Collecting sacremoses\n",
917
+ " Downloading sacremoses-0.0.53.tar.gz (880 kB)\n",
918
+ "\u001b[K |████████████████████████████████| 880 kB 4.3 MB/s eta 0:00:01\n",
919
+ "\u001b[?25hRequirement already satisfied: regex in /opt/conda/lib/python3.7/site-packages (from sacremoses) (2022.10.31)\n",
920
+ "Requirement already satisfied: six in /opt/conda/lib/python3.7/site-packages (from sacremoses) (1.16.0)\n",
921
+ "Requirement already satisfied: click in /opt/conda/lib/python3.7/site-packages (from sacremoses) (8.1.3)\n",
922
+ "Requirement already satisfied: joblib in /opt/conda/lib/python3.7/site-packages (from sacremoses) (1.2.0)\n",
923
+ "Requirement already satisfied: tqdm in /opt/conda/lib/python3.7/site-packages (from sacremoses) (4.65.0)\n",
924
+ "Requirement already satisfied: importlib-metadata in /opt/conda/lib/python3.7/site-packages (from click->sacremoses) (6.1.0)\n",
925
+ "Requirement already satisfied: typing-extensions>=3.6.4 in /opt/conda/lib/python3.7/site-packages (from importlib-metadata->click->sacremoses) (4.5.0)\n",
926
+ "Requirement already satisfied: zipp>=0.5 in /opt/conda/lib/python3.7/site-packages (from importlib-metadata->click->sacremoses) (3.15.0)\n",
927
+ "Building wheels for collected packages: sacremoses\n",
928
+ " Building wheel for sacremoses (setup.py) ... \u001b[?25ldone\n",
929
+ "\u001b[?25h Created wheel for sacremoses: filename=sacremoses-0.0.53-py3-none-any.whl size=895259 sha256=0f511e2624db29b2126dc7c3aec8ba54e44a1d89fa58a852332349de3af597b3\n",
930
+ " Stored in directory: /root/.cache/pip/wheels/87/39/dd/a83eeef36d0bf98e7a4d1933a4ad2d660295a40613079bafc9\n",
931
+ "Successfully built sacremoses\n",
932
+ "Installing collected packages: sacremoses\n",
933
+ "Successfully installed sacremoses-0.0.53\n"
934
+ ]
935
+ }
936
+ ],
937
+ "source": [
938
+ "!pip install sentencepiece\n",
939
+ "!pip install pandasql\n",
940
+ "!pip install python-Levenshtein\n",
941
+ "!pip install sacremoses"
942
+ ]
943
+ },
944
+ {
945
+ "cell_type": "code",
946
+ "execution_count": 24,
947
+ "id": "403dc883-13b6-4661-bb8c-678ec22840ab",
948
+ "metadata": {
949
+ "papermill": {
950
+ "duration": 0.819632,
951
+ "end_time": "2023-02-01T15:28:00.631919",
952
+ "exception": false,
953
+ "start_time": "2023-02-01T15:27:59.812287",
954
+ "status": "completed"
955
+ },
956
+ "tags": []
957
+ },
958
+ "outputs": [],
959
+ "source": [
960
+ "import Levenshtein\n",
961
+ "import re\n",
962
+ "from collections import Counter\n",
963
+ "\n",
964
+ "#Get columns in query\n",
965
+ "def get_columns_name_in_query(query):\n",
966
+ " cols_from_select = get_cols_name_for_select(query) \n",
967
+ " cols_from_where = get_cols_name_for_where(query)\n",
968
+ " return list(set(cols_from_select + cols_from_where))\n",
969
+ "\n",
970
+ "#Translate query in natural language from italian to english (input: string; output: string)\n",
971
+ "def translate2en(query):\n",
972
+ " translated = model_t.generate(**tokenizer_t(query, return_tensors=\"pt\", padding=True))\n",
973
+ " query = [tokenizer_t.decode(t, skip_special_tokens=True) for t in translated]\n",
974
+ " return query\n",
975
+ "\n",
976
+ "# Sometime column name maybe ill-defined. This function replace weird chars with underscore (input:list; output:string)\n",
977
+ "def replace_nonalphanumeric_chars_with_us(l):\n",
978
+ " well_defined = [re.sub('[^0-9a-zA-Z]+', '_', s) for s in l]\n",
979
+ " return well_defined\n",
980
+ "\n",
981
+ "# Adjust column name using columns name from original table (input: column name in SQL query (string), \n",
982
+ "#list of columns names from table (string); output: corrected column name (if needed) (string))\n",
983
+ "def adjust_col_name(col_name, columns_available): \n",
984
+ " columns_available = [x.upper() for x in columns_available]\n",
985
+ " if col_name.upper() in set(columns_available):\n",
986
+ " return col_name\n",
987
+ " else:\n",
988
+ " max = -100\n",
989
+ " most_similar_column = 'column123456789011'\n",
990
+ " for col in columns_available: \n",
991
+ " score = -Levenshtein.distance(col_name, col) \n",
992
+ " if score > max:\n",
993
+ " most_similar_column = col \n",
994
+ " max = score \n",
995
+ " return most_similar_column\n",
996
+ "\n",
997
+ "def min_positive(a,b):\n",
998
+ " if (b < a) and (b > 0): return b\n",
999
+ " else: return a\n",
1000
+ "\n",
1001
+ "#Return corrected syntax for aggregator operators (input: string; output: string)\n",
1002
+ "#USE only for wikisql dataset\n",
1003
+ "def aggregator_parser(query): \n",
1004
+ " query = query.upper() \n",
1005
+ " if query.find('SELECT MAX') > -1:\n",
1006
+ " end = min_positive(query.find('FROM'), query.find(',')) \n",
1007
+ " adjusted_query = query.replace(query[10:end],'(' + query[11:end-1] + ') ')\n",
1008
+ " return adjusted_query\n",
1009
+ " elif query.find('SELECT COUNT') > -1:\n",
1010
+ " end = min_positive(query.find('FROM'), query.find(','))\n",
1011
+ " adjusted_query = query.replace(query[12:end],'(' + query[13:end-1] + ') ')\n",
1012
+ " return adjusted_query\n",
1013
+ " elif query.find('SELECT MIN') > -1:\n",
1014
+ " end = min_positive(query.find('FROM'), query.find(','))\n",
1015
+ " adjusted_query = query.replace(query[10:end],'(' + query[11:end-1] + ') ')\n",
1016
+ " return adjusted_query\n",
1017
+ " #elif query.find('SELECT DISTINCT') > -1:\n",
1018
+ " #end = query.find('FROM')\n",
1019
+ " #adjusted_query = query.replace(query[15:end],'(' + query[16:end-1] + ') ')\n",
1020
+ " #return adjusted_query\n",
1021
+ " else: \n",
1022
+ " return query\n",
1023
+ "\n",
1024
+ "#Return columns name from SELECT operator (input: string; output: list)\n",
1025
+ "def get_cols_name_for_select(query):\n",
1026
+ " query = query.upper() \n",
1027
+ " if query.find('SELECT DISTINCT') > -1:\n",
1028
+ " end = query.find('FROM')\n",
1029
+ " cols = query[15:end-1].split(',')\n",
1030
+ " elif query.find('SELECT MAX') > -1:\n",
1031
+ " end = query.find('FROM')\n",
1032
+ " cols = query[10:end-1].split(',') \n",
1033
+ " elif query.find('SELECT MIN') > -1:\n",
1034
+ " end = query.find('FROM')\n",
1035
+ " cols = query[10:end-1].split(',') \n",
1036
+ " elif query.find('SELECT COUNT') > -1:\n",
1037
+ " end = query.find('FROM')\n",
1038
+ " cols = query[13:end-1].split(',') \n",
1039
+ " elif query.find('SELECT') > -1:\n",
1040
+ " end = query.find('FROM')\n",
1041
+ " cols = query[7:end-1].split(',') \n",
1042
+ " else: \n",
1043
+ " cols = [''] \n",
1044
+ " return [x.replace(' ','').replace(')','').replace('(','').upper() for x in cols]\n",
1045
+ "\n",
1046
+ "def get_indexes(l):\n",
1047
+ " ops = []\n",
1048
+ " idx = []\n",
1049
+ " for i in range(len(l)):\n",
1050
+ " if l[i] in ['=', '>', '<', '>=', '<=', '<>', 'LIKE', 'AND', 'OR']:\n",
1051
+ " idx.append(i)\n",
1052
+ " return idx\n",
1053
+ "\n",
1054
+ "def add_spaces_cmp_operators(string):\n",
1055
+ " ops = ['=', '>', '<', '>=', '<=', '<>']\n",
1056
+ " for op in ops:\n",
1057
+ " string = string.replace(op, ' ' + op + ' ') \n",
1058
+ " return ' '.join(string.split())\n",
1059
+ "\n",
1060
+ "#Check if string and add quotes (input: string; output: string)\n",
1061
+ "#USE only for wikisql dataset\n",
1062
+ "def add_quotes_to_string(query):\n",
1063
+ " query = query.upper()\n",
1064
+ " if query.find('WHERE') > 0:\n",
1065
+ " query_list = query.split(' ')\n",
1066
+ " query_list = [x.replace(' ','') for x in query_list]\n",
1067
+ " query_list[:] = filter(None, query_list) \n",
1068
+ " idx_list = get_indexes(query_list) \n",
1069
+ " idx_list.append(len(query_list)) \n",
1070
+ " subs = []\n",
1071
+ " for i in range(len(idx_list)):\n",
1072
+ " if i % 2 == 0:\n",
1073
+ " b = idx_list[i] + 1\n",
1074
+ " e = idx_list[i+1] - 1\n",
1075
+ " if b != e:\n",
1076
+ " s = ''\n",
1077
+ " for ix in range(b,e + 1): \n",
1078
+ " s = s + query_list[ix] + ' ' \n",
1079
+ " s = s[:-1] \n",
1080
+ " else:\n",
1081
+ " s = query_list[b] \n",
1082
+ " if not(s.isnumeric()):\n",
1083
+ " s = \"'\" + s + \"'\"\n",
1084
+ " subs.append((idx_list[i] + 1, idx_list[i+1] - 1, s)) \n",
1085
+ " subs = subs[::-1] \n",
1086
+ " for i in range(len(subs)):\n",
1087
+ " e = subs[i]\n",
1088
+ " if e[0] == e[1]:\n",
1089
+ " query_list[e[0]] = e[2]\n",
1090
+ " else:\n",
1091
+ " query_list[e[0]] = e[2]\n",
1092
+ " idx = e[1]\n",
1093
+ " while idx > e[0]:\n",
1094
+ " query_list.pop(idx)\n",
1095
+ " idx = idx - 1\n",
1096
+ " final_query = ''\n",
1097
+ " for word in query_list:\n",
1098
+ " final_query = final_query + word + ' ' \n",
1099
+ " return final_query[:-1]\n",
1100
+ " else:\n",
1101
+ " return query\n",
1102
+ "\n",
1103
+ "#Get values from where clause (input: string; output: list)\n",
1104
+ "def get_values_for_query_filter(query):\n",
1105
+ " query = query.upper()\n",
1106
+ " if query.find('WHERE') > 0:\n",
1107
+ " query_list = query.split(' ')\n",
1108
+ " query_list = [x.replace(' ','') for x in query_list]\n",
1109
+ " query_list[:] = filter(None, query_list) \n",
1110
+ " idx_list = get_indexes(query_list) \n",
1111
+ " idx_list.append(len(query_list)) \n",
1112
+ " subs = []\n",
1113
+ " for i in range(len(idx_list)):\n",
1114
+ " if i % 2 == 0:\n",
1115
+ " b = idx_list[i] + 1\n",
1116
+ " e = idx_list[i+1] - 1\n",
1117
+ " if b != e:\n",
1118
+ " s = ''\n",
1119
+ " for ix in range(b,e + 1): \n",
1120
+ " s = s + query_list[ix] + ' ' \n",
1121
+ " s = s[:-1] \n",
1122
+ " else:\n",
1123
+ " s = query_list[b] \n",
1124
+ " subs.append(s.replace(\"'\",\"\"))\n",
1125
+ " return subs\n",
1126
+ "\n",
1127
+ "\n",
1128
+ "# Get columns name after where (input: string, output: list)\n",
1129
+ "def get_cols_name_for_where(query):\n",
1130
+ " query = query.upper()\n",
1131
+ " subs = [] \n",
1132
+ " if query.find('WHERE') > 0:\n",
1133
+ " query_list = query.split(' ')\n",
1134
+ " query_list = [x.replace(' ','') for x in query_list]\n",
1135
+ " query_list[:] = filter(None, query_list) \n",
1136
+ " idx_list = get_indexes(query_list) \n",
1137
+ " #idx_list.append(len(query_list))\n",
1138
+ " idx_list.insert(0, query_list.index('WHERE')) \n",
1139
+ " for i in range(len(idx_list)-1):\n",
1140
+ " if i % 2 == 0: \n",
1141
+ " b = idx_list[i] + 1\n",
1142
+ " e = idx_list[i+1] - 1\n",
1143
+ " if b != e:\n",
1144
+ " s = ''\n",
1145
+ " for ix in range(b,e + 1): \n",
1146
+ " s = s + query_list[ix] + ' ' \n",
1147
+ " s = s[:-1] \n",
1148
+ " else:\n",
1149
+ " s = query_list[b]\n",
1150
+ " subs.append(s) \n",
1151
+ " return subs \n",
1152
+ "\n",
1153
+ "def check_if_number(s):\n",
1154
+ " try:\n",
1155
+ " a = float(s)\n",
1156
+ " return True\n",
1157
+ " except:\n",
1158
+ " return False\n",
1159
+ "\n",
1160
+ "#Correct missing compare operator (input: string; output: string)\n",
1161
+ "#T5 seems to have problem with '<' operator so if there is none this is used.\n",
1162
+ "def check_if_correct_cmp_operators(query):\n",
1163
+ " query = query.upper()\n",
1164
+ " if query.find('WHERE') > 0:\n",
1165
+ " query = add_spaces_cmp_operators(query)\n",
1166
+ " query_list = query.split(' ')\n",
1167
+ " w = query_list.index('WHERE')\n",
1168
+ " cmp_operators = ['=', '>', '<', '>=', '<=', '<>', 'LIKE']\n",
1169
+ " s = 0\n",
1170
+ " for op in cmp_operators:\n",
1171
+ " s = s + query_list.count(op)\n",
1172
+ " if s == 0: \n",
1173
+ " if check_if_number(query_list[-1]):\n",
1174
+ " query_list.insert(len(query_list)-1,'<')\n",
1175
+ " else:\n",
1176
+ " query_list.insert(len(query_list)-1,'=')\n",
1177
+ " return ' '.join(query_list)\n",
1178
+ " else:\n",
1179
+ " return query\n",
1180
+ " else: return query\n",
1181
+ " \n",
1182
+ "\n",
1183
+ "\n",
1184
+ "#Correct SQL syntax using info from table (input: string, list; ouput:string)\n",
1185
+ "#Use only for wikisql dataset\n",
1186
+ "def correct_query(query, table_columns): \n",
1187
+ " query = check_if_correct_cmp_operators(query)\n",
1188
+ " query = add_spaces_cmp_operators(query) \n",
1189
+ " #try: \n",
1190
+ " query = aggregator_parser(query) \n",
1191
+ " #except: pass \n",
1192
+ " #try: \n",
1193
+ " query = add_quotes_to_string(query) \n",
1194
+ " #except: pass \n",
1195
+ " #try:\n",
1196
+ " cols_name = get_columns_name_in_query(query) \n",
1197
+ " for col in cols_name: \n",
1198
+ " corrected_col = adjust_col_name(col, table_columns) \n",
1199
+ " query = query.replace(col, corrected_col)\n",
1200
+ " #except: pass\n",
1201
+ " return query\n",
1202
+ "\n",
1203
+ "def correct_mispelling(question, query): \n",
1204
+ " query = query.upper()\n",
1205
+ " if query.find('WHERE') > 0:\n",
1206
+ " question = question.upper()\n",
1207
+ " corrections = []\n",
1208
+ " values = get_values_for_query_filter(query)\n",
1209
+ " for value in values: \n",
1210
+ " l = len(value.split(' '))\n",
1211
+ " tokens = question.replace(' ', ' ').split(' ')\n",
1212
+ " l_gram = ''\n",
1213
+ " max = -100\n",
1214
+ " for i in range(0, len(tokens)-l+1, 1):\n",
1215
+ " filter = ' '.join(tokens[i:i+l]).strip('.,?')\n",
1216
+ " #filter = re.sub(r\"[,.;@#?!&$]+\\ *\", \" \", filter).strip() \n",
1217
+ " score = -Levenshtein.distance(value, filter) \n",
1218
+ " if score > max:\n",
1219
+ " max = score\n",
1220
+ " correct_filter = filter \n",
1221
+ " corrections.append([value, correct_filter]) \n",
1222
+ " for corr in corrections:\n",
1223
+ " query = query.replace(corr[0], corr[1])\n",
1224
+ " return query"
1225
+ ]
1226
+ },
1227
+ {
1228
+ "cell_type": "code",
1229
+ "execution_count": 25,
1230
+ "id": "4683a145-3f8c-4e0e-a7e1-e8b8573ecc35",
1231
+ "metadata": {
1232
+ "papermill": {
1233
+ "duration": 0.740263,
1234
+ "end_time": "2023-02-01T15:28:02.036850",
1235
+ "exception": false,
1236
+ "start_time": "2023-02-01T15:28:01.296587",
1237
+ "status": "completed"
1238
+ },
1239
+ "tags": []
1240
+ },
1241
+ "outputs": [],
1242
+ "source": [
1243
+ "def translate_to_sql(text):\n",
1244
+ " inputs = tokenizer(text, padding='longest', max_length=64, return_tensors='pt').to(torch_device)\n",
1245
+ " input_ids = inputs.input_ids\n",
1246
+ " attention_mask = inputs.attention_mask\n",
1247
+ " output = model.generate(input_ids, attention_mask=attention_mask, max_length=64)\n",
1248
+ "\n",
1249
+ " return tokenizer.decode(output[0], skip_special_tokens=True)"
1250
+ ]
1251
+ },
1252
+ {
1253
+ "cell_type": "code",
1254
+ "execution_count": null,
1255
+ "id": "0be55a1e-2ad4-4a4b-9beb-57623059c768",
1256
+ "metadata": {
1257
+ "papermill": {
1258
+ "duration": 1669.9823,
1259
+ "end_time": "2023-02-01T15:55:52.681707",
1260
+ "exception": false,
1261
+ "start_time": "2023-02-01T15:28:02.699407",
1262
+ "status": "completed"
1263
+ },
1264
+ "tags": []
1265
+ },
1266
+ "outputs": [
1267
+ {
1268
+ "name": "stderr",
1269
+ "output_type": "stream",
1270
+ "text": [
1271
+ "WARNING:datasets.builder:Found cached dataset wikisql (/root/.cache/huggingface/datasets/wikisql/default/0.1.0/7037bfe6a42b1ca2b6ac3ccacba5253b1825d31379e9cc626fc79a620977252d)\n"
1272
+ ]
1273
+ },
1274
+ {
1275
+ "name": "stdout",
1276
+ "output_type": "stream",
1277
+ "text": [
1278
+ "15878\n",
1279
+ "0.0 0.01 %\n",
1280
+ "0.6595744680851063 0.51 %\n",
1281
+ "0.7263157894736842 1.01 %\n",
1282
+ "0.6928571428571428 1.51 %\n",
1283
+ "0.6878306878306878 2.01 %\n",
1284
+ "0.7058823529411765 2.51 %\n",
1285
+ "0.7142857142857143 3.01 %\n",
1286
+ "0.6795252225519288 3.51 %\n",
1287
+ "0.6909090909090909 4.01 %\n",
1288
+ "0.7020785219399538 4.51 %\n",
1289
+ "0.7 5.01 %\n",
1290
+ "0.7 5.51 %\n",
1291
+ "0.6972318339100346 6.01 %\n",
1292
+ "0.6990445859872612 6.51 %\n",
1293
+ "0.7071005917159763 7.01 %\n",
1294
+ "0.7158620689655173 7.51 %\n",
1295
+ "0.7174193548387097 8.01 %\n",
1296
+ "0.7127272727272728 8.51 %\n",
1297
+ "0.7177142857142857 9.01 %\n",
1298
+ "0.7096424702058505 9.51 %\n",
1299
+ "0.7057613168724279 10.01 %\n",
1300
+ "0.7045009784735812 10.51 %\n",
1301
+ "0.7052238805970149 11.01 %\n",
1302
+ "0.6978609625668449 11.51 %\n",
1303
+ "0.7028181041844578 12.01 %\n",
1304
+ "0.7 12.51 %\n",
1305
+ "0.7021276595744681 13.01 %\n",
1306
+ "0.7025796661608498 13.51 %\n",
1307
+ "0.706140350877193 14.01 %\n",
1308
+ "0.7080394922425952 14.51 %\n",
1309
+ "0.7100954979536153 15.01 %\n",
1310
+ "0.7058047493403694 15.51 %\n",
1311
+ "0.7049808429118773 16.01 %\n",
1312
+ "0.7054455445544554 16.51 %\n",
1313
+ "0.7063063063063063 17.01 %\n",
1314
+ "0.7046117921774664 17.51 %\n",
1315
+ "0.7060830017055145 18.01 %\n",
1316
+ "0.7037037037037037 18.51 %\n",
1317
+ "0.7002152852529602 19.01 %\n",
1318
+ "0.7017819706498952 19.51 %\n",
1319
+ "0.7017364657814096 20.01 %\n",
1320
+ "0.7016932270916335 20.51 %\n",
1321
+ "0.6987366375121478 21.01 %\n",
1322
+ "0.7009034712315739 21.51 %\n",
1323
+ "0.7007910656119125 22.01 %\n",
1324
+ "0.6995903504779244 22.51 %\n",
1325
+ "0.6994657168299199 23.01 %\n",
1326
+ "0.6980392156862745 23.51 %\n",
1327
+ "0.6961538461538461 24.01 %\n",
1328
+ "0.6961071578066137 24.51 %\n",
1329
+ "0.6978269782697827 25.01 %\n",
1330
+ "0.6994777018883086 25.51 %\n",
1331
+ "0.6992510839574301 26.01 %\n",
1332
+ "0.699265558562041 26.51 %\n",
1333
+ "0.6994307400379507 27.01 %\n",
1334
+ "0.6994413407821229 27.51 %\n",
1335
+ "0.6998171846435101 28.01 %\n",
1336
+ "0.7023339317773788 28.51 %\n",
1337
+ "0.702893436838391 29.01 %\n",
1338
+ "0.7018763029881863 29.51 %\n",
1339
+ "0.7004781420765027 30.01 %\n",
1340
+ "0.7021490933512424 30.51 %\n",
1341
+ "0.7007926023778072 31.01 %\n",
1342
+ "0.701885565669701 31.51 %\n",
1343
+ "0.7037747920665387 32.01 %\n",
1344
+ "0.7034005037783375 32.51 %\n",
1345
+ "0.703657780533168 33.01 %\n",
1346
+ "0.7042124542124543 33.51 %\n",
1347
+ "0.7037593984962406 34.01 %\n",
1348
+ "0.7045925925925925 34.51 %\n",
1349
+ "0.7033576642335766 35.01 %\n",
1350
+ "0.7020725388601037 35.51 %\n",
1351
+ "0.7021881216254617 36.01 %\n",
1352
+ "0.7045135968601065 36.51 %\n",
1353
+ "0.703816371681416 37.01 %\n",
1354
+ "0.7054009819967266 37.51 %\n",
1355
+ "0.7053283100107642 38.01 %\n",
1356
+ "0.7063197026022305 38.51 %\n",
1357
+ "0.7072851153039832 39.01 %\n",
1358
+ "0.707815734989648 39.51 %\n",
1359
+ "0.7072049054675523 40.01 %\n",
1360
+ "0.7080494574817058 40.51 %\n",
1361
+ "0.706556968337073 41.01 %\n",
1362
+ "0.7067224821472544 41.51 %\n",
1363
+ "0.7041135434207361 42.51 %\n",
1364
+ "0.7061340941512125 43.01 %\n",
1365
+ "0.7056195626616506 43.51 %\n",
1366
+ "0.7048570764582849 44.01 %\n",
1367
+ "0.70434183321847 44.51 %\n",
1368
+ "0.7051340299863699 45.01 %\n",
1369
+ "0.7056179775280899 45.51 %\n",
1370
+ "0.7052678372971771 46.01 %\n",
1371
+ "0.7047702791822379 46.51 %\n",
1372
+ "0.7036312241791693 47.01 %\n",
1373
+ "0.703807270380727 47.51 %\n",
1374
+ "0.704405192594169 48.01 %\n",
1375
+ "0.7039376710886502 48.51 %\n",
1376
+ "0.7032715148989372 49.01 %\n",
1377
+ "0.7025577557755776 49.51 %\n",
1378
+ "0.7021233156390363 50.01 %\n",
1379
+ "0.7030523549626035 50.51 %\n",
1380
+ "0.7028216930158094 51.01 %\n",
1381
+ "0.7049147839873167 51.51 %\n",
1382
+ "0.7049469964664311 52.01 %\n",
1383
+ "0.7056765163297045 52.51 %\n",
1384
+ "0.7060069310743166 53.01 %\n",
1385
+ "0.7065673921344024 53.51 %\n",
1386
+ "0.7078290468986385 54.01 %\n",
1387
+ "0.7077557137504683 54.51 %\n",
1388
+ "0.707629478373863 55.01 %\n",
1389
+ "0.709475620975161 55.51 %\n",
1390
+ "0.7108477666362808 56.01 %\n",
1391
+ "0.7125813449023861 56.51 %\n",
1392
+ "0.7136200716845879 57.01 %\n",
1393
+ "0.7140319715808171 57.51 %\n",
1394
+ "0.7151408450704225 58.01 %\n",
1395
+ "0.715782122905028 58.51 %\n",
1396
+ "0.7158186223606784 59.01 %\n",
1397
+ "0.7163289630512515 60.01 %\n",
1398
+ "0.7156366092536305 60.51 %\n",
1399
+ "0.7153382451440053 61.01 %\n",
1400
+ "0.7152108933909 61.51 %\n",
1401
+ "0.7151565074135091 62.01 %\n",
1402
+ "0.7163398692810458 62.51 %\n",
1403
+ "0.7165316045380875 63.01 %\n",
1404
+ "0.7157099212091976 63.51 %\n",
1405
+ "0.7162226830435476 64.01 %\n",
1406
+ "0.7173603418262383 64.51 %\n",
1407
+ "0.7172240540116188 65.01 %\n",
1408
+ "0.7178927680798005 65.51 %\n",
1409
+ "0.7174013921113689 66.01 %\n",
1410
+ "0.7172678434382195 66.51 %\n",
1411
+ "0.7177456207159177 67.01 %\n",
1412
+ "0.7183673469387755 67.51 %\n",
1413
+ "0.7195798949737434 68.01 %\n",
1414
+ "0.7207743857036486 68.51 %\n",
1415
+ "0.721359940872136 69.01 %\n",
1416
+ "0.7217901687454146 69.51 %\n",
1417
+ "0.7223597960670065 70.01 %\n",
1418
+ "0.7228410241573846 70.51 %\n",
1419
+ "0.722820623294557 71.01 %\n",
1420
+ "0.7224757558471192 71.51 %\n",
1421
+ "0.7225998300764656 72.01 %\n",
1422
+ "0.7231439820022497 72.51 %\n",
1423
+ "0.7228898826159866 73.01 %\n",
1424
+ "0.7233865371269952 73.51 %\n",
1425
+ "0.7239762856748931 74.01 %\n",
1426
+ "0.7236373596274993 74.51 %\n",
1427
+ "0.7238846572361263 75.01 %\n",
1428
+ "0.7239935152661443 75.51 %\n",
1429
+ "0.724771873322598 76.01 %\n",
1430
+ "0.7247400693148494 76.51 %\n",
1431
+ "0.7251655629139073 77.01 %\n",
1432
+ "0.7263157894736842 77.51 %\n",
1433
+ "0.727141922825376 78.01 %\n",
1434
+ "0.7273554256010396 78.51 %\n",
1435
+ "0.7268853305785123 79.01 %\n",
1436
+ "0.7272260713369259 79.51 %\n",
1437
+ "0.7276785714285714 80.01 %\n",
1438
+ "0.7281368821292775 80.51 %\n",
1439
+ "0.7278911564625851 81.01 %\n",
1440
+ "0.728274480340596 81.51 %\n",
1441
+ "0.7276007964161274 82.01 %\n",
1442
+ "0.7273964131106988 82.51 %\n",
1443
+ "0.7280885064535956 83.01 %\n",
1444
+ "0.7286726961623075 83.51 %\n",
1445
+ "0.7278911564625851 84.01 %\n",
1446
+ "0.7276020284955325 84.51 %\n",
1447
+ "0.728233457427645 85.01 %\n",
1448
+ "0.7277068162826787 85.51 %\n",
1449
+ "0.7281034892000949 86.01 %\n",
1450
+ "0.728259587020649 86.51 %\n",
1451
+ "0.7276246334310851 87.01 %\n",
1452
+ "0.7276630308656301 88.01 %\n",
1453
+ "0.7275245239469129 88.51 %\n",
1454
+ "0.7265006312406749 89.01 %\n",
1455
+ "0.726649920073076 89.51 %\n",
1456
+ "0.727262404905189 90.01 %\n",
1457
+ "0.7269030946464875 90.51 %\n",
1458
+ "0.7272012578616353 91.01 %\n",
1459
+ "0.7273235031277927 91.51 %\n",
1460
+ "0.7276061346965993 92.01 %\n",
1461
+ "0.7284687672747374 92.51 %\n",
1462
+ "0.7286327136728633 93.01 %\n",
1463
+ "0.7277601488127804 93.51 %\n",
1464
+ "0.7278267493742518 94.01 %\n",
1465
+ "0.7268593699253004 94.51 %\n",
1466
+ "0.7275665194441452 95.01 %\n",
1467
+ "0.7268245632836781 95.51 %\n",
1468
+ "0.7272824232081911 96.01 %\n",
1469
+ "0.7267614601018676 96.51 %\n",
1470
+ "0.7260346283783784 97.01 %\n",
1471
+ "0.7256878806973325 97.51 %\n",
1472
+ "0.7253447555369829 98.01 %\n",
1473
+ "0.7251559251559252 98.51 %\n",
1474
+ "0.7258248009101251 99.01 %\n",
1475
+ "0.7258130918073281 99.51 %\n"
1476
+ ]
1477
+ }
1478
+ ],
1479
+ "source": [
1480
+ "test_data = load_dataset('wikisql', split='test')\n",
1481
+ "\n",
1482
+ "print(len(test_data))\n",
1483
+ "n =10000\n",
1484
+ "\n",
1485
+ "count = 0\n",
1486
+ "correct_samples = 0\n",
1487
+ "for i in range(0,n,1):\n",
1488
+ " #print('processed', 100*(i+1)/n,'%') \n",
1489
+ " question = 'translate to SQL: ' + test_data[i]['question'] + ' table ID: ' + ', '.join(str(x) for x in test_data[i]['table']['header']) \n",
1490
+ " sql = translate_to_sql(question)\n",
1491
+ " #print(sql, test_data[i]['question'])\n",
1492
+ " #output = correct_query(sql, test_data[i]['table']['header']) \n",
1493
+ " #output = correct_mispelling(test_data[i]['question'], output)\n",
1494
+ " #target = correct_query(test_data[i]['sql']['human_readable'], test_data[i]['table']['header'])\n",
1495
+ " try: \n",
1496
+ " output = correct_query(sql, test_data[i]['table']['header'])\n",
1497
+ " output = correct_mispelling(test_data[i]['question'], output)\n",
1498
+ " target = correct_query(test_data[i]['sql']['human_readable'], test_data[i]['table']['header'])\n",
1499
+ " #output = sql\n",
1500
+ " #target = test_data[i]['sql']['human_readable']\n",
1501
+ " correct_samples = correct_samples + 1\n",
1502
+ " if output.lower() == target.lower():\n",
1503
+ " count = count + 1 \n",
1504
+ " else:\n",
1505
+ " #print(question)\n",
1506
+ " #print(output) \n",
1507
+ " #print(target) \n",
1508
+ " pass\n",
1509
+ " if i % 50 == 0:\n",
1510
+ " print(count/correct_samples, 100*(i+1)/n,'%') \n",
1511
+ " except Exception as err:\n",
1512
+ " #print(f\"Unexpected {err=}, {type(err)=}\")\n",
1513
+ " #print('---Error-- ') \n",
1514
+ " #print(sql) \n",
1515
+ " #print(test_data[i]['sql']['human_readable'])\n",
1516
+ " #print(test_data[i]['table']['header'])\n",
1517
+ " pass\n",
1518
+ " #output = translate_to_sql(question)\n",
1519
+ " #target = test_data[i]['sql']['human_readable']\n",
1520
+ " #print(question)\n",
1521
+ " #print(output) \n",
1522
+ " #print(target) \n",
1523
+ "print(count/n)\n",
1524
+ "print(count/correct_samples)\n",
1525
+ "print(correct_samples)"
1526
+ ]
1527
+ },
1528
+ {
1529
+ "cell_type": "code",
1530
+ "execution_count": null,
1531
+ "id": "8a64367c-63e8-4621-9dc8-f80c7944809f",
1532
+ "metadata": {
1533
+ "papermill": {
1534
+ "duration": 1.247871,
1535
+ "end_time": "2023-02-01T15:55:55.285210",
1536
+ "exception": false,
1537
+ "start_time": "2023-02-01T15:55:54.037339",
1538
+ "status": "completed"
1539
+ },
1540
+ "tags": []
1541
+ },
1542
+ "outputs": [],
1543
+ "source": [
1544
+ "with valohai.logger() as logger:\n",
1545
+ " logger.log('accuracy', count/correct_samples)\n",
1546
+ " "
1547
+ ]
1548
+ },
1549
+ {
1550
+ "cell_type": "code",
1551
+ "execution_count": null,
1552
+ "id": "94830d9c-7688-4b66-8b07-0bc5b0b0f8d1",
1553
+ "metadata": {
1554
+ "papermill": {
1555
+ "duration": 1.337796,
1556
+ "end_time": "2023-02-01T15:55:57.940900",
1557
+ "exception": false,
1558
+ "start_time": "2023-02-01T15:55:56.603104",
1559
+ "status": "completed"
1560
+ },
1561
+ "tags": []
1562
+ },
1563
+ "outputs": [],
1564
+ "source": [
1565
+ "print(count)\n",
1566
+ "print(correct_samples)"
1567
+ ]
1568
+ },
1569
+ {
1570
+ "cell_type": "code",
1571
+ "execution_count": null,
1572
+ "id": "e4d06297-007d-4d42-8fc4-e153a857a953",
1573
+ "metadata": {
1574
+ "papermill": {
1575
+ "duration": 1.317653,
1576
+ "end_time": "2023-02-01T15:56:00.574479",
1577
+ "exception": false,
1578
+ "start_time": "2023-02-01T15:55:59.256826",
1579
+ "status": "completed"
1580
+ },
1581
+ "tags": []
1582
+ },
1583
+ "outputs": [],
1584
+ "source": []
1585
+ }
1586
+ ],
1587
+ "metadata": {
1588
+ "kernelspec": {
1589
+ "display_name": "Python 3 (ipykernel)",
1590
+ "language": "python",
1591
+ "name": "python3"
1592
+ },
1593
+ "language_info": {
1594
+ "codemirror_mode": {
1595
+ "name": "ipython",
1596
+ "version": 3
1597
+ },
1598
+ "file_extension": ".py",
1599
+ "mimetype": "text/x-python",
1600
+ "name": "python",
1601
+ "nbconvert_exporter": "python",
1602
+ "pygments_lexer": "ipython3",
1603
+ "version": "3.7.11"
1604
+ },
1605
+ "papermill": {
1606
+ "default_parameters": {},
1607
+ "duration": 8853.368468,
1608
+ "end_time": "2023-02-01T15:56:05.444982",
1609
+ "environment_variables": {},
1610
+ "exception": null,
1611
+ "input_path": "/valohai/repository/txt2sql_t5_small_training.ipynb",
1612
+ "output_path": "/valohai/outputs/txt2sql_t5_small_training.ipynb",
1613
+ "parameters": {},
1614
+ "start_time": "2023-02-01T13:28:32.076514",
1615
+ "version": "2.3.3"
1616
+ }
1617
+ },
1618
+ "nbformat": 4,
1619
+ "nbformat_minor": 5
1620
+ }