felipesanma
commited on
Commit
•
9213f5f
1
Parent(s):
8fe843e
add: dataset gen and training scripts
Browse files- utils/dataset_gen.py +61 -0
- utils/requirements.txt +7 -0
- utils/t5_train_model.py +220 -0
utils/dataset_gen.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from datasets import load_dataset
|
2 |
+
import pandas as pd
|
3 |
+
from tqdm.notebook import tqdm
|
4 |
+
from sklearn.utils import shuffle
|
5 |
+
|
6 |
+
|
7 |
+
train_dataset = load_dataset("squad", split="train")
|
8 |
+
valid_dataset = load_dataset("squad", split="validation")
|
9 |
+
|
10 |
+
df_train = pd.DataFrame(columns=["context", "answer", "question"])
|
11 |
+
df_validation = pd.DataFrame(columns=["context", "answer", "question"])
|
12 |
+
|
13 |
+
|
14 |
+
count_long = 0
|
15 |
+
count_short = 0
|
16 |
+
|
17 |
+
|
18 |
+
for index, val in enumerate(tqdm(train_dataset)):
|
19 |
+
print(index)
|
20 |
+
passage = val["context"]
|
21 |
+
question = val["question"]
|
22 |
+
answer = val["answers"]["text"][0]
|
23 |
+
no_of_words = len(answer.split())
|
24 |
+
if no_of_words >= 7:
|
25 |
+
count_long = count_long + 1
|
26 |
+
continue
|
27 |
+
else:
|
28 |
+
df_train.loc[count_short] = [passage] + [answer] + [question]
|
29 |
+
count_short = count_short + 1
|
30 |
+
|
31 |
+
print("count_long train dataset: ", count_long)
|
32 |
+
print("count_short train dataset: ", count_short)
|
33 |
+
|
34 |
+
|
35 |
+
count_long = 0
|
36 |
+
count_short = 0
|
37 |
+
|
38 |
+
|
39 |
+
for index, val in enumerate(tqdm(valid_dataset)):
|
40 |
+
print(index)
|
41 |
+
passage = val["context"]
|
42 |
+
question = val["question"]
|
43 |
+
answer = val["answers"]["text"][0]
|
44 |
+
no_of_words = len(answer.split())
|
45 |
+
if no_of_words >= 7:
|
46 |
+
count_long = count_long + 1
|
47 |
+
continue
|
48 |
+
else:
|
49 |
+
df_validation.loc[count_short] = [passage] + [answer] + [question]
|
50 |
+
count_short = count_short + 1
|
51 |
+
|
52 |
+
print("count_long validation dataset: ", count_long)
|
53 |
+
print("count_short validation dataset: ", count_short)
|
54 |
+
|
55 |
+
df_train = shuffle(df_train)
|
56 |
+
df_validation = shuffle(df_validation)
|
57 |
+
|
58 |
+
train_save_path = "squad_t5_train.csv"
|
59 |
+
validation_save_path = "squad_t5_validaton.csv"
|
60 |
+
df_train.to_csv(train_save_path, index=False)
|
61 |
+
df_validation.to_csv(validation_save_path, index=False)
|
utils/requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
datasets==2.14.5
|
2 |
+
pandas==2.1.1
|
3 |
+
pytorch_lightning==2.1.0
|
4 |
+
scikit_learn==1.3.1
|
5 |
+
torch==2.1.0
|
6 |
+
tqdm==4.66.1
|
7 |
+
transformers==4.34.0
|
utils/t5_train_model.py
ADDED
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
|
3 |
+
import pandas as pd
|
4 |
+
from torch.utils.data import Dataset, DataLoader
|
5 |
+
|
6 |
+
from transformers import AdamW, T5ForConditionalGeneration, T5Tokenizer
|
7 |
+
|
8 |
+
from tqdm.notebook import tqdm
|
9 |
+
import copy
|
10 |
+
import pytorch_lightning as pl
|
11 |
+
|
12 |
+
|
13 |
+
class QuestionGenerationDataset(Dataset):
|
14 |
+
def __init__(self, tokenizer, filepath, max_len_inp=512, max_len_out=96):
|
15 |
+
self.path = filepath
|
16 |
+
|
17 |
+
self.passage_column = "context"
|
18 |
+
self.answer = "answer"
|
19 |
+
self.question = "question"
|
20 |
+
|
21 |
+
# self.data = pd.read_csv(self.path)
|
22 |
+
self.data = pd.read_csv(self.path, nrows=1000)
|
23 |
+
|
24 |
+
self.max_len_input = max_len_inp
|
25 |
+
self.max_len_output = max_len_out
|
26 |
+
self.tokenizer = tokenizer
|
27 |
+
self.inputs = []
|
28 |
+
self.targets = []
|
29 |
+
self.skippedcount = 0
|
30 |
+
self._build()
|
31 |
+
|
32 |
+
def __len__(self):
|
33 |
+
return len(self.inputs)
|
34 |
+
|
35 |
+
def __getitem__(self, index):
|
36 |
+
source_ids = self.inputs[index]["input_ids"].squeeze()
|
37 |
+
target_ids = self.targets[index]["input_ids"].squeeze()
|
38 |
+
|
39 |
+
src_mask = self.inputs[index][
|
40 |
+
"attention_mask"
|
41 |
+
].squeeze() # might need to squeeze
|
42 |
+
target_mask = self.targets[index][
|
43 |
+
"attention_mask"
|
44 |
+
].squeeze() # might need to squeeze
|
45 |
+
|
46 |
+
labels = copy.deepcopy(target_ids)
|
47 |
+
labels[labels == 0] = -100
|
48 |
+
|
49 |
+
return {
|
50 |
+
"source_ids": source_ids,
|
51 |
+
"source_mask": src_mask,
|
52 |
+
"target_ids": target_ids,
|
53 |
+
"target_mask": target_mask,
|
54 |
+
"labels": labels,
|
55 |
+
}
|
56 |
+
|
57 |
+
def _build(self):
|
58 |
+
for idx in tqdm(range(len(self.data))):
|
59 |
+
passage, answer, target = (
|
60 |
+
self.data.loc[idx, self.passage_column],
|
61 |
+
self.data.loc[idx, self.answer],
|
62 |
+
self.data.loc[idx, self.question],
|
63 |
+
)
|
64 |
+
|
65 |
+
input_ = "context: %s answer: %s </s>" % (passage, answer)
|
66 |
+
target = "question: %s </s>" % (str(target))
|
67 |
+
|
68 |
+
# get encoding length of input. If it is greater than self.max_len skip it
|
69 |
+
test_input_encoding = self.tokenizer.encode_plus(
|
70 |
+
input_, truncation=False, return_tensors="pt"
|
71 |
+
)
|
72 |
+
|
73 |
+
length_of_input_encoding = len(test_input_encoding["input_ids"][0])
|
74 |
+
|
75 |
+
if length_of_input_encoding > self.max_len_input:
|
76 |
+
self.skippedcount = self.skippedcount + 1
|
77 |
+
continue
|
78 |
+
|
79 |
+
# tokenize inputs
|
80 |
+
tokenized_inputs = self.tokenizer.batch_encode_plus(
|
81 |
+
[input_],
|
82 |
+
max_length=self.max_len_input,
|
83 |
+
pad_to_max_length=True,
|
84 |
+
return_tensors="pt",
|
85 |
+
)
|
86 |
+
# tokenize targets
|
87 |
+
tokenized_targets = self.tokenizer.batch_encode_plus(
|
88 |
+
[target],
|
89 |
+
max_length=self.max_len_output,
|
90 |
+
pad_to_max_length=True,
|
91 |
+
return_tensors="pt",
|
92 |
+
)
|
93 |
+
|
94 |
+
self.inputs.append(tokenized_inputs)
|
95 |
+
self.targets.append(tokenized_targets)
|
96 |
+
|
97 |
+
|
98 |
+
class T5FineTuner(pl.LightningModule):
|
99 |
+
def __init__(self, hparams, t5model, t5tokenizer):
|
100 |
+
super(T5FineTuner, self).__init__()
|
101 |
+
self.save_hyperparameters(hparams)
|
102 |
+
# self.hparams = hparams
|
103 |
+
self.model = t5model
|
104 |
+
self.tokenizer = t5tokenizer
|
105 |
+
|
106 |
+
def forward(
|
107 |
+
self,
|
108 |
+
input_ids,
|
109 |
+
attention_mask=None,
|
110 |
+
decoder_input_ids=None,
|
111 |
+
decoder_attention_mask=None,
|
112 |
+
lm_labels=None,
|
113 |
+
):
|
114 |
+
outputs = self.model(
|
115 |
+
input_ids=input_ids,
|
116 |
+
attention_mask=attention_mask,
|
117 |
+
decoder_attention_mask=decoder_attention_mask,
|
118 |
+
labels=lm_labels,
|
119 |
+
)
|
120 |
+
|
121 |
+
return outputs
|
122 |
+
|
123 |
+
def training_step(self, batch, batch_idx):
|
124 |
+
outputs = self.forward(
|
125 |
+
input_ids=batch["source_ids"],
|
126 |
+
attention_mask=batch["source_mask"],
|
127 |
+
decoder_input_ids=batch["target_ids"],
|
128 |
+
decoder_attention_mask=batch["target_mask"],
|
129 |
+
lm_labels=batch["labels"],
|
130 |
+
)
|
131 |
+
|
132 |
+
loss = outputs[0]
|
133 |
+
self.log("train_loss", loss)
|
134 |
+
return loss
|
135 |
+
|
136 |
+
def validation_step(self, batch, batch_idx):
|
137 |
+
outputs = self.forward(
|
138 |
+
input_ids=batch["source_ids"],
|
139 |
+
attention_mask=batch["source_mask"],
|
140 |
+
decoder_input_ids=batch["target_ids"],
|
141 |
+
decoder_attention_mask=batch["target_mask"],
|
142 |
+
lm_labels=batch["labels"],
|
143 |
+
)
|
144 |
+
|
145 |
+
loss = outputs[0]
|
146 |
+
self.log("val_loss", loss)
|
147 |
+
return loss
|
148 |
+
|
149 |
+
def train_dataloader(self):
|
150 |
+
return DataLoader(
|
151 |
+
train_dataset, batch_size=self.hparams.batch_size, num_workers=4
|
152 |
+
)
|
153 |
+
|
154 |
+
def val_dataloader(self):
|
155 |
+
return DataLoader(
|
156 |
+
validation_dataset, batch_size=self.hparams.batch_size, num_workers=4
|
157 |
+
)
|
158 |
+
|
159 |
+
def configure_optimizers(self):
|
160 |
+
optimizer = AdamW(self.parameters(), lr=3e-4, eps=1e-8)
|
161 |
+
return optimizer
|
162 |
+
|
163 |
+
|
164 |
+
if __name__ == "__main__":
|
165 |
+
pl.seed_everything(42)
|
166 |
+
train_file_path = "question_generator/dataset/squad_t5_train.csv"
|
167 |
+
validation_file_path = "question_generator/dataset/squad_t5_validaton.csv"
|
168 |
+
|
169 |
+
t5_tokenizer = T5Tokenizer.from_pretrained("t5-base")
|
170 |
+
t5_model = T5ForConditionalGeneration.from_pretrained("t5-base")
|
171 |
+
|
172 |
+
sample_encoding = t5_tokenizer.encode_plus(
|
173 |
+
"My name is Pipe San Martin",
|
174 |
+
max_length=64,
|
175 |
+
pad_to_max_length=True,
|
176 |
+
truncation=True,
|
177 |
+
return_tensors="pt",
|
178 |
+
)
|
179 |
+
|
180 |
+
print(sample_encoding.keys())
|
181 |
+
print(sample_encoding["input_ids"].shape)
|
182 |
+
print(sample_encoding["input_ids"].squeeze().shape)
|
183 |
+
print(sample_encoding["input_ids"])
|
184 |
+
tokenized_output = t5_tokenizer.convert_ids_to_tokens(
|
185 |
+
sample_encoding["input_ids"].squeeze()
|
186 |
+
)
|
187 |
+
print(f"Tokenized output: {tokenized_output}")
|
188 |
+
decoded_output = t5_tokenizer.decode(
|
189 |
+
sample_encoding["input_ids"].squeeze(),
|
190 |
+
skip_special_tokens=True,
|
191 |
+
clean_up_tokenization_spaces=True,
|
192 |
+
)
|
193 |
+
print(f"Decoded output: {decoded_output}")
|
194 |
+
train_dataset = QuestionGenerationDataset(t5_tokenizer, train_file_path)
|
195 |
+
|
196 |
+
train_sample = train_dataset[50]
|
197 |
+
decoded_train_input = t5_tokenizer.decode(train_sample["source_ids"])
|
198 |
+
decoded_train_output = t5_tokenizer.decode(train_sample["target_ids"])
|
199 |
+
|
200 |
+
print(decoded_train_input)
|
201 |
+
print(decoded_train_output)
|
202 |
+
|
203 |
+
validation_dataset = QuestionGenerationDataset(t5_tokenizer, validation_file_path)
|
204 |
+
args_dict = dict(
|
205 |
+
batch_size=4,
|
206 |
+
)
|
207 |
+
|
208 |
+
args = argparse.Namespace(**args_dict)
|
209 |
+
|
210 |
+
model = T5FineTuner(args, t5_model, t5_tokenizer)
|
211 |
+
|
212 |
+
trainer = pl.Trainer(max_epochs=1)
|
213 |
+
|
214 |
+
trainer.fit(model)
|
215 |
+
|
216 |
+
#print("Saving model")
|
217 |
+
#save_path_model = "question_generator/model/"
|
218 |
+
#save_path_tokenizer = "question_generator/tokenizer/"
|
219 |
+
#model.model.save_pretrained(save_path_model)
|
220 |
+
#t5_tokenizer.save_pretrained(save_path_tokenizer)
|