# wikipedia-assistant / training / run_seq2seq_no_trainer.py
# Rschmaelzle's picture
# Duplicate from deepset/wikipedia-assistant
# 2024325
import argparse
import logging
import math
import re
import numpy as np
import torch
from accelerate import Accelerator
from accelerate.utils import set_seed
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import get_scheduler, AutoTokenizer, AdamW, SchedulerType, AutoModelForSeq2SeqLM, \
DataCollatorWithPadding
from datasets import load_dataset
logger = logging.getLogger(__name__)
def get_parser():
    """Build the command-line parser for the ELI5 seq2seq training script.

    Returns:
        argparse.ArgumentParser: parser exposing dataset, model, optimizer,
        scheduler and tokenization options. Boolean-valued options
        (``--ignore_pad_token_for_loss``, ``--overwrite_cache``) accept
        explicit values such as ``true``/``false``.
    """
    parser = argparse.ArgumentParser(description="Train ELI5 seq2seq answer generation model")
    parser.add_argument(
        "--dataset_name",
        type=str,
        default="vblagoje/lfqa",
        help="The name of the dataset to use (via the datasets library).",
    )
    parser.add_argument(
        "--per_device_train_batch_size",
        type=int,
        default=4,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument(
        "--per_device_eval_batch_size",
        type=int,
        default=4,
        help="Batch size (per device) for the evaluation dataloader.",
    )
    parser.add_argument(
        "--pretrained_model_name",
        type=str,
        default="facebook/bart-large",
        help="Name or path of the pretrained model and tokenizer to start from.",
    )
    parser.add_argument(
        "--model_save_name",
        type=str,
        default="eli5_bart_model",
        help="Prefix of the checkpoint file names written during training.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=2e-4,
        help="Learning rate passed to the optimizer.",
    )
    parser.add_argument(
        "--weight_decay",
        type=float,
        default=0.0,
        help="Weight decay to use."
    )
    parser.add_argument(
        "--log_freq",
        type=int,
        default=100,
        help="Log train/validation loss every log_freq update steps"
    )
    parser.add_argument(
        "--ignore_pad_token_for_loss",
        # `type=bool` would turn every non-empty string (even "False") into True.
        type=_str_to_bool,
        default=True,
        help="Whether to ignore the tokens corresponding to " "padded labels in the loss computation or not.",
    )
    parser.add_argument(
        "--num_train_epochs",
        type=int,
        default=3,
        help="Total number of training epochs to perform.",
    )
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=16,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--pad_to_max_length",
        action="store_true",
        help="If passed, pad all samples to `max_length`. Otherwise, dynamic padding is used.",
    )
    parser.add_argument(
        "--overwrite_cache",
        # Same pitfall as above: parse the value explicitly instead of using `bool`.
        type=_str_to_bool,
        default=None,
        help="Overwrite the cached training and evaluation sets",
    )
    parser.add_argument(
        "--max_source_length",
        type=int,
        default=1024,
        help="The maximum total input sequence length after "
             "tokenization.Sequences longer than this will be truncated, sequences shorter will be padded.",
    )
    parser.add_argument(
        "--max_target_length",
        type=int,
        default=360,
        help="The maximum total sequence length for target text after "
             "tokenization. Sequences longer than this will be truncated, sequences shorter will be padded."
    )
    parser.add_argument(
        "--lr_scheduler_type",
        type=SchedulerType,
        default="linear",  # this is linear with warmup
        help="The scheduler type to use.",
        choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
    )
    parser.add_argument(
        "--num_warmup_steps",
        type=int,
        default=None,
        help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--warmup_percentage",
        type=float,
        default=0.08,
        help="Fraction of the total training steps used for warmup when --num_warmup_steps is not given."
    )
    return parser


def _str_to_bool(value):
    """Convert a command-line string such as "true" or "0" into a bool.

    Raises:
        argparse.ArgumentTypeError: if the value is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string was the only falsy input under `type=bool`; keep it falsy.
    if normalized in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")
def cleanup_references(text):
    """Remove Markdown-style link references and ``_URL_<n>_`` placeholders from text.

    Three passes are applied in order:
      1. numbered references such as ``[1](_URL_2_)`` are removed entirely,
      2. named links such as ``[link text](_URL_19_)`` are reduced to their link text,
      3. any remaining ``_URL_<n>_`` placeholder (one or more digits) is removed.

    Args:
        text: the raw string to clean.

    Returns:
        The cleaned string; whitespace is not normalised here.
    """
    # URL reference where we need to remove both the link text and URL
    # ...and this letter is used by most biographers as the cornerstone of Lee's personal
    #    views on slavery ([1](_URL_2_ & pg=PA173), [2](_URL_1_), [3](_URL_5_)).
    # ...and this letter is used by most biographers as the cornerstone of Lee's personal views on slavery.
    result = re.sub(r"[\(\s]*\[\d+\]\([^)]+\)[,)]*", "", text)

    # URL reference where we need to preserve link text but remove URL
    # At the outbreak of the Civil War, [Leyburn left his church](_URL_19_) and joined the South.
    # At the outbreak of the Civil War, Leyburn left his church and joined the South.
    result = re.sub(r"\[([^]]+)\]\([^)]+\)", r"\1", result)

    # lastly remove just dangling _URL_<n>_ references; `\d+` so that
    # multi-digit placeholders such as _URL_19_ are removed as well
    result = re.sub(r"_URL_\d+_", "", result)
    return result
def clean_answer(text):
    """Normalise an answer string.

    Strips link references, turns newlines into spaces, collapses runs of
    whitespace, removes ``BULLET::::-`` markers and trims both ends.
    """
    single_line = cleanup_references(text).replace("\n", " ")
    squeezed = re.sub(r"\s\s+", " ", single_line)
    without_bullets = re.sub(r"BULLET::::-", "", squeezed)
    return without_bullets.strip()
def clean_question(text):
    """Normalise a question string.

    Strips link references, turns newlines into spaces, collapses runs of
    whitespace, drops ``[deleted]`` markers, then lower-cases and trims.
    """
    single_line = cleanup_references(text).replace("\n", " ")
    squeezed = re.sub(r"\s\s+", " ", single_line)
    without_deleted = squeezed.replace("[deleted]", "")
    return without_deleted.lower().strip()
def prepare_support_docs(example):
    """Join the provenance passages of an example into one context string.

    Reads the ``provenance`` list of the last entry in ``example["output"]``
    and concatenates each passage's ``text``, prefixing every passage with
    the ``<P>`` marker.

    Returns:
        dict with the single key ``"context"``.
    """
    last_output = example["output"][-1]
    passage_texts = [passage["text"] for passage in last_output["provenance"]]
    joined = " <P> ".join(passage_texts)
    return {"context": f"<P> {joined}"}
def preprocess_eli5(examples, **fn_kwargs):
    """Expand a batch of ELI5 rows into (question, context, answer) triples.

    Args:
        examples: batch with the columns ``q_id``, ``title``, ``selftext`` and ``answers``;
            each ``answers`` entry is indexed by ``"score"`` and ``"text"``.
        **fn_kwargs:
            document_cache: mapping from ``q_id`` to the support-document string (required).
            training: when true (default), every answer scoring above the threshold is kept;
                otherwise, or when none qualifies, only the answer at index 0 is used.
            extra_answer_threshold: score an answer must exceed to be kept (default 3).
            include_selftext: append ``selftext`` to the title when building the question.
            exclude_answer_patterns: compiled patterns; an answer matching any of them is skipped.

    Returns:
        dict with the parallel lists ``question``, ``context`` and ``answer``.
    """
    doc_lookup = fn_kwargs["document_cache"]
    is_training = fn_kwargs.get("training", True)
    min_score = fn_kwargs.get("extra_answer_threshold", 3)
    with_selftext = fn_kwargs.get("include_selftext", False)
    reject_patterns = fn_kwargs.get("exclude_answer_patterns", [])

    out_questions, out_contexts, out_answers = [], [], []
    rows = zip(examples["q_id"], examples["title"], examples["selftext"], examples["answers"])
    for q_id, title, selftext, answer in rows:
        chosen = []
        if is_training:
            chosen = [i for i, score in enumerate(answer["score"]) if score > min_score]
        # Fall back to the first answer when evaluating or when nothing passed the threshold.
        if not chosen:
            chosen = [0]

        support_doc = doc_lookup[q_id]
        for i in chosen:
            answer_text = answer["text"][i]
            if any(pattern.search(answer_text) for pattern in reject_patterns):
                continue
            raw_question = f"{title} {selftext}" if with_selftext else title
            out_questions.append(clean_question(raw_question))
            out_contexts.append(support_doc.lower().strip())
            out_answers.append(clean_answer(answer_text))

    return {"question": out_questions, "context": out_contexts, "answer": out_answers}
def eval_qa_s2s_epoch(model, dataloader, accelerator, args):
    """Compute the mean loss of ``model`` over every batch in ``dataloader``.

    The model is switched to eval mode and left there; callers that continue
    training must call ``model.train()`` again.

    Args:
        model: model whose forward pass returns an object with a ``loss`` attribute.
        dataloader: iterable of keyword-argument batches for the model.
        accelerator: used only to decide whether the progress bar is shown.
        args: unused; kept for interface compatibility.

    Returns:
        float: average of the per-batch losses, or ``nan`` when the dataloader
        yields no batches (previously this raised ``UnboundLocalError``).
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    # The context manager closes the bar even if a batch raises.
    with tqdm(range(len(dataloader)), disable=not accelerator.is_local_main_process) as progress_bar:
        with torch.no_grad():
            for batch in dataloader:
                outputs = model(**batch)
                # Reduce the same way train() does, so a non-scalar loss is handled too.
                loss = torch.mean(outputs.loss)
                total_loss += loss.item()
                num_batches += 1
                progress_bar.update(1)
                progress_bar.set_postfix(loss=round(total_loss / num_batches, 3))
    if num_batches == 0:
        # No batches means the mean is undefined.
        return float("nan")
    return total_loss / num_batches
def train(config):
    """Fine-tune a seq2seq model on LFQA question/context/answer data.

    Steps: load the dataset and support documents, build cleaned
    question/context/answer rows, tokenize them, then run a manual training
    loop with gradient accumulation, periodic validation, and a checkpoint
    plus validation pass at the end of every epoch.

    Args:
        config: mapping whose ``"args"`` entry is the parsed command-line
            namespace from ``get_parser()``; the optional ``wandb`` attribute
            enables Weights & Biases logging.
    """
    set_seed(42)
    args = config["args"]
    # `--wandb` is registered by main(), not get_parser(); treat a missing flag as disabled.
    use_wandb = getattr(args, "wandb", False)

    eli5 = load_dataset(args.dataset_name)
    support_docs = load_dataset("vblagoje/lfqa_support_docs")

    # Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
    accelerator = Accelerator()
    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
    logger.info(accelerator.state)

    tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(args.pretrained_model_name)

    # Optimizer
    # Split weights in two groups, one with weight decay and the other not.
    # NOTE(review): matching is by substring of the parameter name; confirm that
    # "LayerNorm.weight" matches the layer-norm parameter names of the chosen model.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, weight_decay=args.weight_decay)

    processed_datasets = {}
    support_docs_prepared = {}
    # The main process runs the (cached) preprocessing first.
    with accelerator.main_process_first():
        for split in ["train", "validation"]:
            support_docs_prepared[split] = support_docs[split].map(
                prepare_support_docs,
                batched=False,
                cache_file_name=f"./support_docs_{split}.arrow",
                load_from_cache_file=not args.overwrite_cache,
                desc="Preparing support docs",
            )
        column_names = eli5["train"].column_names
        for split in ["train", "validation"]:
            # Lookup table from support-document id to its joined context string.
            d_cache = {
                e["id"]: e["context"]
                for e in tqdm(support_docs_prepared[split], desc=f"Adding support docs to LFQA {split}")
            }
            processed_datasets[split] = eli5[split].map(
                preprocess_eli5,
                batched=True,
                remove_columns=column_names,
                cache_file_name=f"./processed_datasets_{split}.arrow",
                load_from_cache_file=not args.overwrite_cache,
                desc="Preparing dataset for tokenization",
                fn_kwargs={"document_cache": d_cache,
                           "training": split == "train",
                           "exclude_answer_patterns": [re.compile("not sure what you"),
                                                       re.compile("\n\n >")]},
            )

    padding = "max_length" if args.pad_to_max_length else False
    # Temporarily set max_target_length for training.
    max_target_length = args.max_target_length
    label_pad_token_id = -100 if args.ignore_pad_token_for_loss else tokenizer.pad_token_id

    def tokenize_dataset(examples):
        """Tokenize a batch into encoder inputs, decoder inputs and shifted labels."""
        inputs = ["question: {} context: {}".format(q, c) for q, c in zip(examples["question"], examples["context"])]
        targets = examples["answer"]
        model_inputs = tokenizer(inputs, max_length=args.max_source_length, padding=padding, truncation=True)

        # Setup the tokenizer for targets
        # NOTE(review): targets are padded per map batch (padding=True); confirm that
        # dataloader batches never mix rows with different target lengths.
        with tokenizer.as_target_tokenizer():
            labels = tokenizer(targets, max_length=max_target_length, padding=True, truncation=True,
                               return_tensors="np")

        # Decoder inputs are the target ids without the last position ...
        model_inputs["decoder_input_ids"] = labels["input_ids"][:, :-1].tolist()

        # replace pad_token_id with label_pad_token_id to avoid loss calculation on those tokens
        labels["input_ids"] = np.where(labels["input_ids"] == tokenizer.pad_token_id,
                                       label_pad_token_id, labels["input_ids"])

        # ... and labels are the target ids without the first position.
        model_inputs["labels"] = labels["input_ids"][:, 1:].tolist()
        return model_inputs

    tokenized_datasets = {}
    with accelerator.main_process_first():
        for split, dataset in processed_datasets.items():
            tokenized_datasets[split] = dataset.map(
                tokenize_dataset,
                batched=True,
                cache_file_name=f"./tokenized_dataset_{split}.arrow",
                remove_columns=dataset.column_names,
                load_from_cache_file=not args.overwrite_cache,
                desc="Running tokenizer on dataset",
            )

    train_dataset = tokenized_datasets["train"]
    eval_dataset = tokenized_datasets["validation"]
    train_dataset.set_format(type='torch')
    eval_dataset.set_format(type='torch')

    data_collator = DataCollatorWithPadding(tokenizer, "max_length")

    # first epoch we don't shuffle
    train_dataloader = DataLoader(train_dataset, shuffle=False, batch_size=args.per_device_train_batch_size,
                                  collate_fn=data_collator)
    eval_dataloader = DataLoader(eval_dataset, batch_size=args.per_device_eval_batch_size, collate_fn=data_collator)

    # train the model
    model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(model, optimizer, train_dataloader,
                                                                              eval_dataloader)
    # Scheduler and math around the number of training steps.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    else:
        args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # Compare against None so that an explicit `--num_warmup_steps 0` disables warmup
    # instead of silently falling back to the percentage.
    if args.num_warmup_steps is not None:
        num_warmup_steps = args.num_warmup_steps
    else:
        num_warmup_steps = math.ceil(args.max_train_steps * args.warmup_percentage)

    scheduler = get_scheduler(
        name=args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=args.max_train_steps,
    )

    # Train!
    total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataset)}")
    logger.info(f" Num eval examples = {len(eval_dataset)}")
    logger.info(f" Num Epochs = {args.num_train_epochs}")
    logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {args.max_train_steps}")
    logger.info(f" Warmup steps = {num_warmup_steps}")
    logger.info(f" Logging training progress every {args.log_freq} optimization steps")

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
    completed_steps = 0
    switched_train_dataloader = False
    for epoch in range(args.num_train_epochs):
        model.train()
        # After the first (unshuffled) epoch, switch once to a shuffling dataloader.
        if epoch > 0 and not switched_train_dataloader:
            train_dataloader = DataLoader(train_dataset, batch_size=args.per_device_train_batch_size,
                                          shuffle=True, collate_fn=data_collator)
            train_dataloader = accelerator.prepare(train_dataloader)
            switched_train_dataloader = True

        for step, batch in enumerate(train_dataloader):
            outputs = model(**batch)
            loss = torch.mean(outputs.loss)
            accelerator.backward(loss)
            # Update on every accumulation boundary and on the last batch of the epoch.
            if ((step + 1) % args.gradient_accumulation_steps == 0) or (step + 1 == len(train_dataloader)):
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                progress_bar.update(1)
                progress_bar.set_postfix(loss=round(loss.item(), 3))
                completed_steps += 1

            if completed_steps >= args.max_train_steps:
                break

            if step % (args.log_freq * args.gradient_accumulation_steps) == 0:
                validation_loss = eval_qa_s2s_epoch(model, eval_dataloader, accelerator, args)
                # The evaluation helper leaves the model in eval mode.
                model.train()
                logger.info(f"Train loss {loss.item()} , validation loss {validation_loss}")
                if use_wandb and accelerator.is_local_main_process:
                    import wandb
                    wandb.log({"loss": loss.item(),
                               "lr": scheduler.get_last_lr()[0],
                               "validation_loss": validation_loss,
                               "completed_steps": completed_steps})

        logger.info("Saving model {}".format(args.model_save_name))
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        accelerator.save(unwrapped_model.state_dict(), "{}_{}.bin".format(args.model_save_name, epoch))

        # Calculating the validation loss over epoch
        validation_loss = eval_qa_s2s_epoch(model, eval_dataloader, accelerator, args)

        logger.info("Epoch: {}".format(epoch))
        logger.info("Validation loss: {}".format(validation_loss))
def main():
    """Parse command-line arguments, optionally start a W&B run, and launch training.

    Unknown command-line arguments are ignored (``parse_known_args``).
    """
    parser = get_parser()
    parser.add_argument(
        "--wandb",
        action="store_true",
        help="If true, use W&B logging",
    )
    main_args, _ = parser.parse_known_args()
    config = {"args": main_args}
    if main_args.wandb:
        # Imported lazily so wandb is only required when logging is requested.
        import wandb
        wandb.init(project="Bart_ELI5")
    train(config=config)


# Guard the entry point so that importing this module does not start a training run.
if __name__ == "__main__":
    main()