"""Fine-tune T5 for question generation (context + answer -> question) with PyTorch Lightning."""

import argparse
import copy

import pandas as pd
import pytorch_lightning as pl
from torch.optim import AdamW  # replaces the deprecated transformers.AdamW
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import T5ForConditionalGeneration, T5Tokenizer


class QuestionGenerationDataset(Dataset):
    """Turns (context, answer, question) rows from a CSV into tokenized T5 inputs and targets."""

    def __init__(self, tokenizer, filepath, max_len_inp=512, max_len_out=96):
        self.path = filepath

        self.passage_column = "context"
        self.answer = "answer"
        self.question = "question"

        # Only the first 1000 rows are read to keep preprocessing quick.
        self.data = pd.read_csv(self.path, nrows=1000)

        self.max_len_input = max_len_inp
        self.max_len_output = max_len_out
        self.tokenizer = tokenizer
        self.inputs = []
        self.targets = []
        self.skippedcount = 0
        self._build()
    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, index):
        source_ids = self.inputs[index]["input_ids"].squeeze()
        target_ids = self.targets[index]["input_ids"].squeeze()

        src_mask = self.inputs[index]["attention_mask"].squeeze()
        target_mask = self.targets[index]["attention_mask"].squeeze()

        # Pad-token ids (0 for T5) are set to -100 so the loss ignores them.
        labels = copy.deepcopy(target_ids)
        labels[labels == 0] = -100

        return {
            "source_ids": source_ids,
            "source_mask": src_mask,
            "target_ids": target_ids,
            "target_mask": target_mask,
            "labels": labels,
        }
    def _build(self):
        for idx in tqdm(range(len(self.data))):
            passage, answer, target = (
                self.data.loc[idx, self.passage_column],
                self.data.loc[idx, self.answer],
                self.data.loc[idx, self.question],
            )

            input_ = "context: %s answer: %s </s>" % (passage, answer)
            target = "question: %s </s>" % (str(target))

            # Encode once without truncation to measure the full input length.
            test_input_encoding = self.tokenizer.encode_plus(
                input_, truncation=False, return_tensors="pt"
            )
            length_of_input_encoding = len(test_input_encoding["input_ids"][0])

            # Skip examples whose context does not fit into the encoder.
            if length_of_input_encoding > self.max_len_input:
                self.skippedcount += 1
                continue

            tokenized_inputs = self.tokenizer.batch_encode_plus(
                [input_],
                max_length=self.max_len_input,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            )

            tokenized_targets = self.tokenizer.batch_encode_plus(
                [target],
                max_length=self.max_len_output,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            )

            self.inputs.append(tokenized_inputs)
            self.targets.append(tokenized_targets)


class T5FineTuner(pl.LightningModule):
    def __init__(self, hparams, t5model, t5tokenizer):
        super().__init__()
        self.save_hyperparameters(hparams)

        self.model = t5model
        self.tokenizer = t5tokenizer

    def forward(
        self,
        input_ids,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        lm_labels=None,
    ):
        # With labels supplied, T5 derives decoder_input_ids internally,
        # so only the masks and labels are forwarded to the model.
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
            labels=lm_labels,
        )
        return outputs
    def training_step(self, batch, batch_idx):
        outputs = self.forward(
            input_ids=batch["source_ids"],
            attention_mask=batch["source_mask"],
            decoder_input_ids=batch["target_ids"],
            decoder_attention_mask=batch["target_mask"],
            lm_labels=batch["labels"],
        )

        loss = outputs[0]
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        outputs = self.forward(
            input_ids=batch["source_ids"],
            attention_mask=batch["source_mask"],
            decoder_input_ids=batch["target_ids"],
            decoder_attention_mask=batch["target_mask"],
            lm_labels=batch["labels"],
        )

        loss = outputs[0]
        self.log("val_loss", loss)
        return loss
    def train_dataloader(self):
        # Uses the module-level dataset built in the __main__ block below.
        return DataLoader(
            train_dataset, batch_size=self.hparams.batch_size, num_workers=4
        )

    def val_dataloader(self):
        return DataLoader(
            validation_dataset, batch_size=self.hparams.batch_size, num_workers=4
        )
    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), lr=3e-4, eps=1e-8)
        return optimizer

if __name__ == "__main__":
    pl.seed_everything(42)
    train_file_path = "question_generator/dataset/squad_t5_train.csv"
    validation_file_path = "question_generator/dataset/squad_t5_validaton.csv"

    t5_tokenizer = T5Tokenizer.from_pretrained("t5-base")
    t5_model = T5ForConditionalGeneration.from_pretrained("t5-base")
    # Sanity-check the tokenizer on a sample sentence.
    sample_encoding = t5_tokenizer.encode_plus(
        "My name is Pipe San Martin",
        max_length=64,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    print(sample_encoding.keys())
    print(sample_encoding["input_ids"].shape)
    print(sample_encoding["input_ids"].squeeze().shape)
    print(sample_encoding["input_ids"])

    tokenized_output = t5_tokenizer.convert_ids_to_tokens(
        sample_encoding["input_ids"].squeeze()
    )
    print(f"Tokenized output: {tokenized_output}")

    decoded_output = t5_tokenizer.decode(
        sample_encoding["input_ids"].squeeze(),
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    print(f"Decoded output: {decoded_output}")
    train_dataset = QuestionGenerationDataset(t5_tokenizer, train_file_path)

    # Inspect one training example after tokenization.
    train_sample = train_dataset[50]
    decoded_train_input = t5_tokenizer.decode(train_sample["source_ids"])
    decoded_train_output = t5_tokenizer.decode(train_sample["target_ids"])

    print(decoded_train_input)
    print(decoded_train_output)

    validation_dataset = QuestionGenerationDataset(t5_tokenizer, validation_file_path)
    args_dict = dict(
        batch_size=4,
    )
    args = argparse.Namespace(**args_dict)

    model = T5FineTuner(args, t5_model, t5_tokenizer)

    trainer = pl.Trainer(max_epochs=1)
    trainer.fit(model)
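
    # Optional (not part of the original script): a minimal sketch of persisting the
    # fine-tuned weights and tokenizer after training. The output directory name
    # "t5_squad_qg" is an assumption chosen for illustration.
    model.model.save_pretrained("t5_squad_qg")
    t5_tokenizer.save_pretrained("t5_squad_qg")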