|
from datetime import datetime |
|
from typing import Optional |
|
from datasets import load_dataset |
|
import evaluate |
|
from transformers import ( |
|
Seq2SeqTrainer, |
|
Seq2SeqTrainingArguments, |
|
AutoTokenizer, |
|
AutoModelForSeq2SeqLM, |
|
GenerationConfig, |
|
) |
|
import wandb |
|
|
|
|
|
|
|
# Per-device batch size for training and evaluation.
BATCH_SIZE: int = 2

# Maximum answer length in tokens. It selects the dataset/model variant by name
# and bounds both the label length and the generation length.
MAX_ANSWER_LENGTH: int = 512

# Run name = setup + hyperparameters + local start time, so runs are distinguishable.
_run_started_at = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
run_name: str = f"vast-gpu_batch-size-{BATCH_SIZE}_ans-len-{MAX_ANSWER_LENGTH}-{_run_started_at}"

wandb.init(name=run_name, project="led-finetune-lfqa")
|
|
|
|
|
# ROUGE metric, consumed by compute_metrics during evaluation.
rouge = evaluate.load("rouge")

# Original base checkpoint; not referenced by the loading code below.
pretrained_model_name = "allenai/led-base-16384"
# Hub repository of this project's fine-tuned variant. Training continues from
# it, and the trainer's hub_model_id points back to it.
my_model_name = f"stefanbschneider/led-base-16384-lfqa-ans-len-{MAX_ANSWER_LENGTH}"
model_name = my_model_name

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
# Recompute activations in the backward pass instead of storing them all,
# trading extra compute for lower memory use on long inputs.
model.gradient_checkpointing_enable()
|
|
|
|
|
def process_data_to_model_inputs(batch):
    """Tokenize a batch of LFQA examples into LED encoder inputs and labels.

    Intended for ``Dataset.map(..., batched=True)``: every value in ``batch``
    is a list with one entry per example.

    Args:
        batch: Column mapping with ``"question"`` (strings), ``"context"``
            (iterables of strings, joined with spaces) and ``"answer"``
            (strings).

    Returns:
        ``batch`` extended with ``input_ids``, ``attention_mask``,
        ``global_attention_mask`` and ``labels``.
    """
    # One prompt per example: the question followed by its joined context.
    prompts = [
        f"question: {question}, context: {' '.join(context)}"
        for question, context in zip(batch["question"], batch["context"])
    ]

    # Encoder inputs are padded/truncated to a fixed 8192 tokens; answers to
    # MAX_ANSWER_LENGTH tokens.
    inputs = tokenizer(
        prompts,
        padding="max_length",
        truncation=True,
        max_length=8192,
    )
    outputs = tokenizer(
        batch["answer"],
        padding="max_length",
        truncation=True,
        max_length=MAX_ANSWER_LENGTH,
    )

    batch["input_ids"] = inputs.input_ids
    batch["attention_mask"] = inputs.attention_mask

    # Global attention on the first token of every example, 0 elsewhere.
    # Each row is built independently (not as `n * [row]`), so no row aliases
    # another, and an empty batch simply yields an empty list.
    batch["global_attention_mask"] = [
        [1 if position == 0 else 0 for position in range(len(input_ids))]
        for input_ids in inputs.input_ids
    ]

    # Padding positions in the labels become -100 so the loss ignores them.
    pad_token_id = tokenizer.pad_token_id
    batch["labels"] = [
        [-100 if token == pad_token_id else token for token in label_ids]
        for label_ids in outputs.input_ids
    ]

    return batch
|
|
|
|
|
def load_and_process_dataset(split: str, dataset_limit: Optional[int] = None):
    """Load one split of the LFQA dataset and tokenize it for LED training.

    Args:
        split: Dataset split name, e.g. ``"train"`` or ``"validation"``.
        dataset_limit: If given, keep only the first ``dataset_limit``
            examples. A limit larger than the split is clamped to the split
            size instead of raising.

    Returns:
        The tokenized dataset, formatted as torch tensors with the columns
        ``input_ids``, ``attention_mask``, ``global_attention_mask`` and
        ``labels``.
    """
    dataset = load_dataset(
        f"stefanbschneider/lfqa-max-answer-length-{MAX_ANSWER_LENGTH}", split=split
    )

    if dataset_limit is not None:
        # Clamp: Dataset.select raises IndexError for indices past the end.
        dataset = dataset.select(range(min(dataset_limit, len(dataset))))

    dataset = dataset.map(
        process_data_to_model_inputs,
        batched=True,
        batch_size=BATCH_SIZE,
        # The raw text columns are replaced by the tokenized ones.
        remove_columns=["context", "question", "answer"],
    )

    dataset.set_format(
        type="torch",
        columns=["input_ids", "attention_mask", "global_attention_mask", "labels"],
    )

    return dataset
|
|
|
|
|
def compute_metrics(pred) -> dict[str, float]:
    """Compute the ROUGE-2 score of generated answers during evaluation.

    Args:
        pred: The trainer's evaluation output. ``pred.predictions`` holds the
            generated token ids (``predict_with_generate=True``) and
            ``pred.label_ids`` the reference ids, where ``-100`` marks
            positions ignored by the loss.

    Returns:
        ``{"rouge2": score}`` with the score rounded to 4 decimals.
    """
    import numpy as np  # local import: numpy is not imported at module level

    pad_token_id = tokenizer.pad_token_id

    pred_ids = pred.predictions
    # Predictions may arrive as a tuple whose first element is the token ids.
    if isinstance(pred_ids, tuple):
        pred_ids = pred_ids[0]

    # -100 cannot be decoded, so map it to the pad token, which
    # skip_special_tokens then drops. Labels use -100 for padding, and
    # predictions may contain it when batches are padded together.
    # np.where builds new arrays, so the caller's arrays are left unmodified.
    pred_ids = np.where(np.asarray(pred_ids) != -100, pred_ids, pad_token_id)
    label_ids = np.where(
        np.asarray(pred.label_ids) != -100, pred.label_ids, pad_token_id
    )

    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    rouge2 = rouge.compute(
        predictions=pred_str, references=label_str, rouge_types=["rouge2"]
    )["rouge2"]

    return {"rouge2": round(float(rouge2), 4)}
|
|
|
|
|
|
|
# Whole training split; validation is capped at its first 64 examples
# (presumably to keep generation-based evaluation fast).
train_data = load_and_process_dataset("train")
val_data = load_and_process_dataset(split="validation", dataset_limit=64)
|
|
|
|
|
|
|
|
|
# Decoding settings used when the trainer generates answers for evaluation
# (predict_with_generate=True). Assigning a fresh GenerationConfig replaces the
# checkpoint's generation settings entirely, so every special token that
# generate() relies on is set explicitly here.
generation_config = GenerationConfig(
    max_length=MAX_ANSWER_LENGTH,
    min_length=100,
    early_stopping=True,
    num_beams=4,
    length_penalty=2.0,
    no_repeat_ngram_size=3,
    # Start decoding with the same token the model prepends when shifting the
    # labels right during training. tokenizer.cls_token_id (<s>) is not that
    # token for LED, so generation would start from a prefix never seen in
    # training.
    decoder_start_token_id=model.config.decoder_start_token_id,
    bos_token_id=tokenizer.bos_token_id,
    # Without an EOS id, beams can never finish early and min_length is not
    # enforced. The pad id is used to pad finished, shorter sequences.
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
)

model.generation_config = generation_config
|
|
|
|
|
training_args = Seq2SeqTrainingArguments(
    output_dir=f"models/{my_model_name}",
    run_name=run_name,
    # Batching and precision.
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=1,
    fp16=True,
    # Evaluation: generate full answers so compute_metrics can score ROUGE-2.
    predict_with_generate=True,
    eval_strategy="steps",
    eval_steps=500,
    # Checkpointing: save every 100 steps, keeping at most one checkpoint.
    save_steps=100,
    save_total_limit=1,
    # Logging.
    logging_steps=50,
    log_level="info",
    report_to="wandb",
    # Upload to the Hugging Face Hub repository of the fine-tuned model.
    push_to_hub=True,
    hub_model_id=my_model_name,
)
|
|
|
|
|
|
|
|
|
|
|
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=train_data,
    eval_dataset=val_data,
    compute_metrics=compute_metrics,
)

# Fine-tune, then upload the final model to the configured hub repository.
trainer.train()
trainer.push_to_hub()