felipesanma committed
Commit
9213f5f
1 Parent(s): 8fe843e

add: dataset gen and training scripts

utils/dataset_gen.py ADDED
@@ -0,0 +1,61 @@
+ from datasets import load_dataset
+ import pandas as pd
+ from tqdm.notebook import tqdm
+ from sklearn.utils import shuffle
+
+
+ train_dataset = load_dataset("squad", split="train")
+ valid_dataset = load_dataset("squad", split="validation")
+
+ df_train = pd.DataFrame(columns=["context", "answer", "question"])
+ df_validation = pd.DataFrame(columns=["context", "answer", "question"])
+
+
+ count_long = 0
+ count_short = 0
+
+
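+ # keep only examples whose answer is shorter than 7 words; longer answers are counted and skipped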
+ for index, val in enumerate(tqdm(train_dataset)):
+     print(index)
+     passage = val["context"]
+     question = val["question"]
+     answer = val["answers"]["text"][0]
+     no_of_words = len(answer.split())
+     if no_of_words >= 7:
+         count_long = count_long + 1
+         continue
+     else:
+         df_train.loc[count_short] = [passage] + [answer] + [question]
+         count_short = count_short + 1
+
+ print("count_long train dataset: ", count_long)
+ print("count_short train dataset: ", count_short)
+
+
+ count_long = 0
+ count_short = 0
+
+
+ for index, val in enumerate(tqdm(valid_dataset)):
+     print(index)
+     passage = val["context"]
+     question = val["question"]
+     answer = val["answers"]["text"][0]
+     no_of_words = len(answer.split())
+     if no_of_words >= 7:
+         count_long = count_long + 1
+         continue
+     else:
+         df_validation.loc[count_short] = [passage] + [answer] + [question]
+         count_short = count_short + 1
+
+ print("count_long validation dataset: ", count_long)
+ print("count_short validation dataset: ", count_short)
+
+ df_train = shuffle(df_train)
+ df_validation = shuffle(df_validation)
+
+ train_save_path = "squad_t5_train.csv"
+ validation_save_path = "squad_t5_validaton.csv"
+ df_train.to_csv(train_save_path, index=False)
+ df_validation.to_csv(validation_save_path, index=False)
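+ # these CSVs feed utils/t5_train_model.py, which looks for them under question_generator/dataset/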
utils/requirements.txt ADDED
@@ -0,0 +1,7 @@
+ datasets==2.14.5
+ pandas==2.1.1
+ pytorch_lightning==2.1.0
+ scikit_learn==1.3.1
+ torch==2.1.0
+ tqdm==4.66.1
+ transformers==4.34.0
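+ # note: the slow T5Tokenizer used in utils/t5_train_model.py also needs the sentencepiece package (not pinned here)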
utils/t5_train_model.py ADDED
@@ -0,0 +1,220 @@
+ import argparse
+
+ import pandas as pd
+ from torch.utils.data import Dataset, DataLoader
+
+ from transformers import AdamW, T5ForConditionalGeneration, T5Tokenizer
+
+ from tqdm.notebook import tqdm
+ import copy
+ import pytorch_lightning as pl
+
+
+ class QuestionGenerationDataset(Dataset):
+     def __init__(self, tokenizer, filepath, max_len_inp=512, max_len_out=96):
+         self.path = filepath
+
+         self.passage_column = "context"
+         self.answer = "answer"
+         self.question = "question"
+
+         # self.data = pd.read_csv(self.path)
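+         # only the first 1000 rows are loaded for a quick run; use the line above to read the full CSV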
+         self.data = pd.read_csv(self.path, nrows=1000)
+
+         self.max_len_input = max_len_inp
+         self.max_len_output = max_len_out
+         self.tokenizer = tokenizer
+         self.inputs = []
+         self.targets = []
+         self.skippedcount = 0
+         self._build()
+
+     def __len__(self):
+         return len(self.inputs)
+
+     def __getitem__(self, index):
+         source_ids = self.inputs[index]["input_ids"].squeeze()
+         target_ids = self.targets[index]["input_ids"].squeeze()
+
+         src_mask = self.inputs[index][
+             "attention_mask"
+         ].squeeze()  # might need to squeeze
+         target_mask = self.targets[index][
+             "attention_mask"
+         ].squeeze()  # might need to squeeze
+
+         labels = copy.deepcopy(target_ids)
+         labels[labels == 0] = -100
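+         # T5's pad token id is 0; those positions are set to -100 so the loss ignores padding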
+
+         return {
+             "source_ids": source_ids,
+             "source_mask": src_mask,
+             "target_ids": target_ids,
+             "target_mask": target_mask,
+             "labels": labels,
+         }
+
+     def _build(self):
+         for idx in tqdm(range(len(self.data))):
+             passage, answer, target = (
+                 self.data.loc[idx, self.passage_column],
+                 self.data.loc[idx, self.answer],
+                 self.data.loc[idx, self.question],
+             )
+
+             input_ = "context: %s answer: %s </s>" % (passage, answer)
+             target = "question: %s </s>" % (str(target))
+
+             # get encoding length of input. If it is greater than self.max_len skip it
+             test_input_encoding = self.tokenizer.encode_plus(
+                 input_, truncation=False, return_tensors="pt"
+             )
+
+             length_of_input_encoding = len(test_input_encoding["input_ids"][0])
+
+             if length_of_input_encoding > self.max_len_input:
+                 self.skippedcount = self.skippedcount + 1
+                 continue
+
+             # tokenize inputs
+             tokenized_inputs = self.tokenizer.batch_encode_plus(
+                 [input_],
+                 max_length=self.max_len_input,
+                 pad_to_max_length=True,
+                 return_tensors="pt",
+             )
+             # tokenize targets
+             tokenized_targets = self.tokenizer.batch_encode_plus(
+                 [target],
+                 max_length=self.max_len_output,
+                 pad_to_max_length=True,
+                 return_tensors="pt",
+             )
+
+             self.inputs.append(tokenized_inputs)
+             self.targets.append(tokenized_targets)
+
+
+ class T5FineTuner(pl.LightningModule):
+     def __init__(self, hparams, t5model, t5tokenizer):
+         super(T5FineTuner, self).__init__()
+         self.save_hyperparameters(hparams)
+         # self.hparams = hparams
+         self.model = t5model
+         self.tokenizer = t5tokenizer
+
+     def forward(
+         self,
+         input_ids,
+         attention_mask=None,
+         decoder_input_ids=None,
+         decoder_attention_mask=None,
+         lm_labels=None,
+     ):
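+         # decoder_input_ids is accepted but not forwarded; with only labels given,
+         # T5ForConditionalGeneration derives the decoder inputs by shifting the labels right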
+         outputs = self.model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             decoder_attention_mask=decoder_attention_mask,
+             labels=lm_labels,
+         )
+
+         return outputs
+
+     def training_step(self, batch, batch_idx):
+         outputs = self.forward(
+             input_ids=batch["source_ids"],
+             attention_mask=batch["source_mask"],
+             decoder_input_ids=batch["target_ids"],
+             decoder_attention_mask=batch["target_mask"],
+             lm_labels=batch["labels"],
+         )
+
+         loss = outputs[0]
+         self.log("train_loss", loss)
+         return loss
+
+     def validation_step(self, batch, batch_idx):
+         outputs = self.forward(
+             input_ids=batch["source_ids"],
+             attention_mask=batch["source_mask"],
+             decoder_input_ids=batch["target_ids"],
+             decoder_attention_mask=batch["target_mask"],
+             lm_labels=batch["labels"],
+         )
+
+         loss = outputs[0]
+         self.log("val_loss", loss)
+         return loss
+
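+     # the dataloaders below use the module-level train_dataset / validation_dataset created in the __main__ block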
+     def train_dataloader(self):
+         return DataLoader(
+             train_dataset, batch_size=self.hparams.batch_size, num_workers=4
+         )
+
+     def val_dataloader(self):
+         return DataLoader(
+             validation_dataset, batch_size=self.hparams.batch_size, num_workers=4
+         )
+
+     def configure_optimizers(self):
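+         # transformers' AdamW is deprecated in recent releases; torch.optim.AdamW is a drop-in replacement here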
+         optimizer = AdamW(self.parameters(), lr=3e-4, eps=1e-8)
+         return optimizer
+
+
+ if __name__ == "__main__":
+     pl.seed_everything(42)
+     train_file_path = "question_generator/dataset/squad_t5_train.csv"
+     validation_file_path = "question_generator/dataset/squad_t5_validaton.csv"
+
+     t5_tokenizer = T5Tokenizer.from_pretrained("t5-base")
+     t5_model = T5ForConditionalGeneration.from_pretrained("t5-base")
+
+     sample_encoding = t5_tokenizer.encode_plus(
+         "My name is Pipe San Martin",
+         max_length=64,
+         pad_to_max_length=True,
+         truncation=True,
+         return_tensors="pt",
+     )
+
+     print(sample_encoding.keys())
+     print(sample_encoding["input_ids"].shape)
+     print(sample_encoding["input_ids"].squeeze().shape)
+     print(sample_encoding["input_ids"])
+     tokenized_output = t5_tokenizer.convert_ids_to_tokens(
+         sample_encoding["input_ids"].squeeze()
+     )
+     print(f"Tokenized output: {tokenized_output}")
+     decoded_output = t5_tokenizer.decode(
+         sample_encoding["input_ids"].squeeze(),
+         skip_special_tokens=True,
+         clean_up_tokenization_spaces=True,
+     )
+     print(f"Decoded output: {decoded_output}")
+     train_dataset = QuestionGenerationDataset(t5_tokenizer, train_file_path)
+
+     train_sample = train_dataset[50]
+     decoded_train_input = t5_tokenizer.decode(train_sample["source_ids"])
+     decoded_train_output = t5_tokenizer.decode(train_sample["target_ids"])
+
+     print(decoded_train_input)
+     print(decoded_train_output)
+
+     validation_dataset = QuestionGenerationDataset(t5_tokenizer, validation_file_path)
+     args_dict = dict(
+         batch_size=4,
+     )
+
+     args = argparse.Namespace(**args_dict)
+
+     model = T5FineTuner(args, t5_model, t5_tokenizer)
+
+     trainer = pl.Trainer(max_epochs=1)
+
+     trainer.fit(model)
+
+     # print("Saving model")
+     # save_path_model = "question_generator/model/"
+     # save_path_tokenizer = "question_generator/tokenizer/"
+     # model.model.save_pretrained(save_path_model)
+     # t5_tokenizer.save_pretrained(save_path_tokenizer)