benschlagman commited on
Commit
e059918
1 Parent(s): 6445424

Delete fine_tuning_google_colab.py

Browse files
Files changed (1) hide show
  1. fine_tuning_google_colab.py +0 -481
fine_tuning_google_colab.py DELETED
@@ -1,481 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
- """Fine-tuning Google Colab
3
-
4
- Automatically generated by Colaboratory.
5
-
6
- Original file is located at
7
- https://colab.research.google.com/drive/1owUKmaxQ_M6Ybsc3R9pR7_3a-J24ZKwk
8
-
9
- ## Introduction: TAPAS
10
-
11
- * Original TAPAS paper (ACL 2020): https://www.aclweb.org/anthology/2020.acl-main.398/
12
- * Follow-up paper on intermediate pre-training (EMMNLP Findings 2020): https://www.aclweb.org/anthology/2020.findings-emnlp.27/
13
- * Original Github repository: https://github.com/google-research/tapas
14
- * Blog post: https://ai.googleblog.com/2020/04/using-neural-networks-to-find-answers.html
15
-
16
- TAPAS is an algorithm that (among other tasks) can answer questions about tabular data. It is essentially a BERT model with relative position embeddings and additional token type ids that encode tabular structure, and 2 classification heads on top: one for **cell selection** and one for (optionally) performing an **aggregation** among selected cells (such as summing or counting).
17
-
18
- Similar to BERT, the base `TapasModel` is pre-trained using the masked language modeling (MLM) objective on a large collection of tables from Wikipedia and associated texts. In addition, the authors further pre-trained the model on an second task (table entailment) to increase the numerical reasoning capabilities of TAPAS (as explained in the follow-up paper), which further improves performance on downstream tasks.
19
-
20
- In this notebook, we are going to fine-tune `TapasForQuestionAnswering` on [Sequential Question Answering (SQA)](https://www.microsoft.com/en-us/research/publication/search-based-neural-structured-learning-sequential-question-answering/), a dataset built by Microsoft Research which deals with asking questions related to a table in a **conversational set-up**. We are going to do so as in the original paper, by adding a randomly initialized cell selection head on top of the pre-trained base model (note that SQA does not have questions that involve aggregation and hence no aggregation head), and then fine-tuning them altogether.
21
-
22
- First, we install both the Transformers library as well as the dependency on [`torch-scatter`](https://github.com/rusty1s/pytorch_scatter), which the model requires.
23
- """
24
-
25
- ! rm -r transformers
26
- ! git clone https://github.com/huggingface/transformers.git
27
- ! cd transformers
28
- ! pip install ./transformers
29
-
30
- !pip3 install torch==1.7.1+cu101 -f https://download.pytorch.org/whl/torch_stable.html
31
- !pip3 install --no-index torch-scatter -f https://pytorch-geometric.com/whl/torch-1.7.0+cu101.html
32
-
33
- """We also install a small portion from the SQA training dataset, for demonstration purposes. This is a TSV file containing table-question pairs. Besides this, we also download the `table_csv` directory, which contains the actual tabular data.
34
-
35
- Note that you can download the entire SQA dataset on the [official website](https://www.microsoft.com/en-us/download/details.aspx?id=54253).
36
- """
37
-
38
- import requests, zipfile, io
39
- import os
40
-
41
- def download_files(dir_name):
42
- if not os.path.exists(dir_name):
43
- # 28 training examples from the SQA training set + table csv data
44
- urls = ["https://www.dropbox.com/s/j6n2kq4n7ugippc/qa.zip?dl=1",
45
- "https://www.dropbox.com/s/mlg46w8lwsn90ne/table_csv.zip?dl=1"
46
- ]
47
- for url in urls:
48
- r = requests.get(url)
49
- z = zipfile.ZipFile(io.BytesIO(r.content))
50
- z.extractall()
51
-
52
- dir_name = "sqa_data"
53
- download_files(dir_name)
54
-
55
- """## Prepare the data
56
-
57
- Let's look at the first few rows of the dataset:
58
- """
59
-
60
- import pandas as pd
61
-
62
- data = pd.read_excel("set_28_examples.xlsx")
63
- data.head()
64
-
65
- """As you can see, each row corresponds to a question related to a table.
66
- * The `position` column identifies whether the question is the first, second, ... in a sequence of questions related to a table.
67
- * The `table_file` column identifies the name of the table file, which refers to a CSV file in the `table_csv` directory.
68
- * The `answer_coordinates` and `answer_text` columns indicate the answer to the question. The `answer_coordinates` is a list of tuples, each tuple being a (row_index, column_index) pair. The `answer_text` column is a list of strings, indicating the cell values.
69
-
70
- However, the `answer_coordinates` and `answer_text` columns are currently not recognized as real Python lists of Python tuples and strings respectively. Let's do that first using the `.literal_eval()`function of the `ast` module:
71
- """
72
-
73
- import ast
74
-
75
- def _parse_answer_coordinates(answer_coordinate_str):
76
- """Parses the answer_coordinates of a question.
77
- Args:
78
- answer_coordinate_str: A string representation of a Python list of tuple
79
- strings.
80
- For example: "['(1, 4)','(1, 3)', ...]"
81
- """
82
-
83
- try:
84
- answer_coordinates = []
85
- # make a list of strings
86
- coords = ast.literal_eval(answer_coordinate_str)
87
- # parse each string as a tuple
88
- for row_index, column_index in sorted(
89
- ast.literal_eval(coord) for coord in coords):
90
- answer_coordinates.append((row_index, column_index))
91
- except SyntaxError:
92
- raise ValueError('Unable to evaluate %s' % answer_coordinate_str)
93
-
94
- return answer_coordinates
95
-
96
-
97
- def _parse_answer_text(answer_text):
98
- """Populates the answer_texts field of `answer` by parsing `answer_text`.
99
- Args:
100
- answer_text: A string representation of a Python list of strings.
101
- For example: "[u'test', u'hello', ...]"
102
- answer: an Answer object.
103
- """
104
- try:
105
- answer = []
106
- for value in ast.literal_eval(answer_text):
107
- answer.append(value)
108
- except SyntaxError:
109
- raise ValueError('Unable to evaluate %s' % answer_text)
110
-
111
- return answer
112
-
113
- data['answer_coordinates'] = data['answer_coordinates'].apply(lambda coords_str: _parse_answer_coordinates(coords_str))
114
- data['answer_text'] = data['answer_text'].apply(lambda txt: _parse_answer_text(txt))
115
-
116
- data.head(10)
117
-
118
- """Let's create a new dataframe that groups questions which are asked in a sequence related to the table. We can do this by adding a `sequence_id` column, which is a combination of the `id` and `annotator` columns:"""
119
-
120
- def get_sequence_id(example_id, annotator):
121
- if "-" in str(annotator):
122
- raise ValueError('"-" not allowed in annotator.')
123
- return f"{example_id}-{annotator}"
124
-
125
- data['sequence_id'] = data.apply(lambda x: get_sequence_id(x.id, x.annotator), axis=1)
126
- data.head()
127
-
128
- # let's group table-question pairs by sequence id, and remove some columns we don't need
129
- grouped = data.groupby(by='sequence_id').agg(lambda x: x.tolist())
130
- grouped = grouped.drop(columns=['id', 'annotator', 'position'])
131
- grouped['table_file'] = grouped['table_file'].apply(lambda x: x[0])
132
- grouped.head(10)
133
-
134
- """Each row in the dataframe above now consists of a **table and one or more questions** which are asked in a **sequence**. Let's visualize the first row, i.e. a table, together with its queries:"""
135
-
136
- # path to the directory containing all csv files
137
- table_csv_path = "table_csv"
138
-
139
- item = grouped.iloc[0]
140
- table = pd.read_csv(table_csv_path + item.table_file[9:]).astype(str)
141
-
142
- display(table)
143
- print("")
144
- print(item.question)
145
-
146
- """We can see that there are 3 sequential questions asked related to the contents of the table.
147
-
148
- We can now use `TapasTokenizer` to batch encode this, as follows:
149
- """
150
-
151
- import torch
152
- from transformers import TapasTokenizer
153
-
154
- # initialize the tokenizer
155
- tokenizer = TapasTokenizer.from_pretrained("google/tapas-base")
156
-
157
- encoding = tokenizer(table=table, queries=item.question, answer_coordinates=item.answer_coordinates, answer_text=item.answer_text,
158
- truncation=True, padding="max_length", return_tensors="pt")
159
- encoding.keys()
160
-
161
- """TAPAS basically flattens every table-question pair before feeding it into a BERT like model:"""
162
-
163
- tokenizer.decode(encoding["input_ids"][0])
164
-
165
- """The `token_type_ids` created here will be of shape (batch_size, sequence_length, 7), as TAPAS uses 7 different token types to encode tabular structure. Let's verify this:"""
166
-
167
- assert encoding["token_type_ids"].shape == (3, 512, 7)
168
-
169
- """
170
-
171
- One thing we can verify is whether the `prev_label` token type ids are created correctly. These indicate which tokens were (part of) an answer to the previous table-question pair.
172
-
173
- The prev_label token type ids of the first example in a batch must always be zero (since there's no previous table-question pair). Let's verify this:"""
174
-
175
- assert encoding["token_type_ids"][0][:,3].sum() == 0
176
-
177
- """However, the `prev_label` token type ids of the second table-question pair in the batch must be set to 1 for the tokens which were an answer to the previous (i.e. the first) table question pair in the batch. The answers to the first table-question pair are the following:"""
178
-
179
- print(item.answer_text[0])
180
-
181
- """So let's now verify whether the `prev_label` ids of the second table-question pair are set correctly:"""
182
-
183
- for id, prev_label in zip (encoding["input_ids"][1], encoding["token_type_ids"][1][:,3]):
184
- if id != 0: # we skip padding tokens
185
- print(tokenizer.decode([id]), prev_label.item())
186
-
187
- """This looks OK! Be sure to check this, because the token type ids are critical for the performance of TAPAS.
188
-
189
- Let's create a PyTorch dataset and corresponding dataloader. Note the __getitem__ method here: in order to properly set the prev_labels token types, we must check whether a table-question pair is the first in a sequence or not. In case it is, we can just encode it. In case it isn't, we need to encode it together with the previous table-question pair.
190
-
191
- Note that this is not the most efficient approach, because we're effectively tokenizing each table-question pair twice when applied on the entire dataset (feel free to ping me a more efficient solution).
192
- """
193
-
194
- class TableDataset(torch.utils.data.Dataset):
195
- def __init__(self, df, tokenizer):
196
- self.df = df
197
- self.tokenizer = tokenizer
198
-
199
- def __getitem__(self, idx):
200
- item = self.df.iloc[idx]
201
- table = pd.read_csv(table_csv_path + item.table_file[9:]).astype(str) # TapasTokenizer expects the table data to be text only
202
- if item.position != 0:
203
- # use the previous table-question pair to correctly set the prev_labels token type ids
204
- previous_item = self.df.iloc[idx-1]
205
- encoding = self.tokenizer(table=table,
206
- queries=[previous_item.question, item.question],
207
- answer_coordinates=[previous_item.answer_coordinates, item.answer_coordinates],
208
- answer_text=[previous_item.answer_text, item.answer_text],
209
- padding="max_length",
210
- truncation=True,
211
- return_tensors="pt"
212
- )
213
- # use encodings of second table-question pair in the batch
214
- encoding = {key: val[-1] for key, val in encoding.items()}
215
- else:
216
- # this means it's the first table-question pair in a sequence
217
- encoding = self.tokenizer(table=table,
218
- queries=item.question,
219
- answer_coordinates=item.answer_coordinates,
220
- answer_text=item.answer_text,
221
- padding="max_length",
222
- truncation=True,
223
- return_tensors="pt"
224
- )
225
- # remove the batch dimension which the tokenizer adds
226
- encoding = {key: val.squeeze(0) for key, val in encoding.items()}
227
- return encoding
228
-
229
- def __len__(self):
230
- return len(self.df)
231
-
232
- train_dataset = TableDataset(df=data, tokenizer=tokenizer)
233
- train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=2)
234
-
235
- train_dataset[0]["token_type_ids"].shape
236
-
237
- train_dataset[1]["input_ids"].shape
238
-
239
- batch = next(iter(train_dataloader))
240
-
241
- batch["input_ids"].shape
242
-
243
- batch["token_type_ids"].shape
244
-
245
- """Let's decode the first table-question pair:"""
246
-
247
- tokenizer.decode(batch["input_ids"][0])
248
-
249
- #first example should not have any prev_labels set
250
- assert batch["token_type_ids"][0][:,3].sum() == 0
251
-
252
- """Let's decode the second table-question pair and verify some more:"""
253
-
254
- tokenizer.decode(batch["input_ids"][1])
255
-
256
- assert batch["labels"][0].sum() == batch["token_type_ids"][1][:,3].sum()
257
- print(batch["token_type_ids"][1][:,3].sum())
258
-
259
- for id, prev_label in zip(batch["input_ids"][1], batch["token_type_ids"][1][:,3]):
260
- if id != 0:
261
- print(tokenizer.decode([id]), prev_label.item())
262
-
263
- """## Define the model
264
-
265
- Here we initialize the model with a pre-trained base and randomly initialized cell selection head, and move it to the GPU (if available).
266
-
267
- Note that the `google/tapas-base` checkpoint has (by default) an SQA configuration, so we don't need to specify any additional hyperparameters.
268
- """
269
-
270
- from transformers import TapasForQuestionAnswering
271
-
272
- model = TapasForQuestionAnswering.from_pretrained("google/tapas-base")
273
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
274
-
275
- model.to(device)
276
-
277
- """## Training the model
278
-
279
- Let's fine-tune the model in well-known PyTorch fashion:
280
- """
281
-
282
- from transformers import AdamW
283
-
284
- optimizer = AdamW(model.parameters(), lr=5e-5)
285
-
286
- for epoch in range(10): # loop over the dataset multiple times
287
- print("Epoch:", epoch)
288
- for idx, batch in enumerate(train_dataloader):
289
- # get the inputs;
290
- input_ids = batch["input_ids"].to(device)
291
- attention_mask = batch["attention_mask"].to(device)
292
- token_type_ids = batch["token_type_ids"].to(device)
293
- labels = batch["labels"].to(device)
294
-
295
- # zero the parameter gradients
296
- optimizer.zero_grad()
297
- # forward + backward + optimize
298
- outputs = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids,
299
- labels=labels)
300
- loss = outputs.loss
301
- print("Loss:", loss.item())
302
- loss.backward()
303
- optimizer.step()
304
-
305
- """## Inference
306
-
307
- As SQA is a bit different due to its conversational nature, we need to run every training example of the a batch one by one through the model (sequentially), overwriting the `prev_labels` token types (which were created by the tokenizer) by the answer predicted by the model. It is based on the [following code](https://github.com/google-research/tapas/blob/f458b6624b8aa75961a0ab78e9847355022940d3/tapas/experiments/prediction_utils.py#L92) from the official implementation:
308
- """
309
-
310
- import collections
311
- import numpy as np
312
-
313
- def compute_prediction_sequence(model, data, device):
314
- """Computes predictions using model's answers to the previous questions."""
315
-
316
- # prepare data
317
- input_ids = data["input_ids"].to(device)
318
- attention_mask = data["attention_mask"].to(device)
319
- token_type_ids = data["token_type_ids"].to(device)
320
-
321
- all_logits = []
322
- prev_answers = None
323
-
324
- num_batch = data["input_ids"].shape[0]
325
-
326
- for idx in range(num_batch):
327
-
328
- if prev_answers is not None:
329
- coords_to_answer = prev_answers[idx]
330
- # Next, set the label ids predicted by the model
331
- prev_label_ids_example = token_type_ids_example[:,3] # shape (seq_len,)
332
- model_label_ids = np.zeros_like(prev_label_ids_example.cpu().numpy()) # shape (seq_len,)
333
-
334
- # for each token in the sequence:
335
- token_type_ids_example = token_type_ids[idx] # shape (seq_len, 7)
336
- for i in range(model_label_ids.shape[0]):
337
- segment_id = token_type_ids_example[:,0].tolist()[i]
338
- col_id = token_type_ids_example[:,1].tolist()[i] - 1
339
- row_id = token_type_ids_example[:,2].tolist()[i] - 1
340
- if row_id >= 0 and col_id >= 0 and segment_id == 1:
341
- model_label_ids[i] = int(coords_to_answer[(col_id, row_id)])
342
-
343
- # set the prev label ids of the example (shape (1, seq_len) )
344
- token_type_ids_example[:,3] = torch.from_numpy(model_label_ids).type(torch.long).to(device)
345
-
346
- prev_answers = {}
347
- # get the example
348
- input_ids_example = input_ids[idx] # shape (seq_len,)
349
- attention_mask_example = attention_mask[idx] # shape (seq_len,)
350
- token_type_ids_example = token_type_ids[idx] # shape (seq_len, 7)
351
- # forward pass to obtain the logits
352
- outputs = model(input_ids=input_ids_example.unsqueeze(0),
353
- attention_mask=attention_mask_example.unsqueeze(0),
354
- token_type_ids=token_type_ids_example.unsqueeze(0))
355
- logits = outputs.logits
356
- all_logits.append(logits)
357
-
358
- # convert logits to probabilities (which are of shape (1, seq_len))
359
- dist_per_token = torch.distributions.Bernoulli(logits=logits)
360
- probabilities = dist_per_token.probs * attention_mask_example.type(torch.float32).to(dist_per_token.probs.device)
361
-
362
- # Compute average probability per cell, aggregating over tokens.
363
- # Dictionary maps coordinates to a list of one or more probabilities
364
- coords_to_probs = collections.defaultdict(list)
365
- prev_answers = {}
366
- for i, p in enumerate(probabilities.squeeze().tolist()):
367
- segment_id = token_type_ids_example[:,0].tolist()[i]
368
- col = token_type_ids_example[:,1].tolist()[i] - 1
369
- row = token_type_ids_example[:,2].tolist()[i] - 1
370
- if col >= 0 and row >= 0 and segment_id == 1:
371
- coords_to_probs[(col, row)].append(p)
372
-
373
- # Next, map cell coordinates to 1 or 0 (depending on whether the mean prob of all cell tokens is > 0.5)
374
- coords_to_answer = {}
375
- for key in coords_to_probs:
376
- coords_to_answer[key] = np.array(coords_to_probs[key]).mean() > 0.5
377
- prev_answers[idx+1] = coords_to_answer
378
-
379
- logits_batch = torch.cat(tuple(all_logits), 0)
380
-
381
- return logits_batch
382
-
383
- data = {'Actors': ["Brad Pitt", "Leonardo Di Caprio", "George Clooney"],
384
- 'Age': ["56", "45", "59"],
385
- 'Number of movies': ["87", "53", "69"],
386
- 'Date of birth': ["7 february 1967", "10 june 1996", "28 november 1967"]}
387
- queries = ["How many movies has George Clooney played in?", "How old is he?", "What's his date of birth?"]
388
-
389
- table = pd.DataFrame.from_dict(data)
390
-
391
- inputs = tokenizer(table=table, queries=queries, padding='max_length', return_tensors="pt")
392
- logits = compute_prediction_sequence(model, inputs, device)
393
-
394
- """Finally, we can use the handy `convert_logits_to_predictions` function of `TapasTokenizer` to convert the logits into predicted coordinates, and print out the result:"""
395
-
396
- predicted_answer_coordinates, = tokenizer.convert_logits_to_predictions(inputs, logits.cpu().detach())
397
-
398
- # handy helper function in case inference on Pandas dataframe
399
- answers = []
400
- for coordinates in predicted_answer_coordinates:
401
- if len(coordinates) == 1:
402
- # only a single cell:
403
- answers.append(table.iat[coordinates[0]])
404
- else:
405
- # multiple cells
406
- cell_values = []
407
- for coordinate in coordinates:
408
- cell_values.append(table.iat[coordinate])
409
- answers.append(", ".join(cell_values))
410
-
411
- display(table)
412
- print("")
413
- for query, answer in zip(queries, answers):
414
- print(query)
415
- print("Predicted answer: " + answer)
416
-
417
- """Note that the results here are not correct, that's obvious since we only trained on 28 examples and tested it on an entire different example. In reality, you should train on the entire dataset. The result of this is the `google/tapas-base-finetuned-sqa` checkpoint.
418
-
419
- ## Legacy
420
-
421
- The code below was considered during the creation of this tutorial, but eventually not used.
422
- """
423
-
424
- # grouped = data.groupby(data.position)
425
- # test = grouped.get_group(0)
426
- # test.index
427
-
428
- def custom_collate_fn(data):
429
- """
430
- A custom collate function to batch input_ids, attention_mask, token_type_ids and so on of different batch sizes.
431
-
432
- Args:
433
- data:
434
- a list of dictionaries (each dictionary is what the __getitem__ method of TableDataset returns)
435
- """
436
- result = {}
437
- for k in data[0].keys():
438
- result[k] = torch.cat([x[k] for x in data], dim=0)
439
-
440
- return result
441
-
442
- class TableDataset(torch.utils.data.Dataset):
443
- def __init__(self, df, tokenizer):
444
- self.df = df
445
- self.tokenizer = tokenizer
446
-
447
- def __getitem__(self, idx):
448
- item = self.df.iloc[idx]
449
- table = pd.read_csv(table_csv_path + item.table_file[9:]).astype(str) # TapasTokenizer expects the table data to be text only
450
- if item.position != 0:
451
- # use the previous table-question pair
452
- previous_item = self.df.iloc[idx-1]
453
- encoding = self.tokenizer(table=table,
454
- queries=[previous_item.question, item.question],
455
- answer_coordinates=[previous_item.answer_coordinates, item.answer_coordinates],
456
- answer_text=[previous_item.answer_text, item.answer_text],
457
- padding="max_length",
458
- truncation=True,
459
- return_tensors="pt"
460
- )
461
- # remove the batch dimension which the tokenizer adds
462
- encoding = {key: val[-1] for key, val in encoding.items()}
463
- #encoding = {key: val.squeeze(0) for key, val in encoding.items()}
464
- else:
465
- # this means it's the first table-question pair in a sequence
466
- encoding = self.tokenizer(table=table,
467
- queries=item.question,
468
- answer_coordinates=item.answer_coordinates,
469
- answer_text=item.answer_text,
470
- padding="max_length",
471
- truncation=True,
472
- return_tensors="pt"
473
- )
474
- return encoding
475
-
476
- def __len__(self):
477
- return len(self.df)
478
-
479
- train_dataset = TableDataset(df=grouped, tokenizer=tokenizer)
480
- train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=2, collate_fn=custom_collate_fn)
481
-