|
|
|
import argparse |
|
import logging |
|
import math |
|
import os |
|
from datetime import datetime |
|
import datasets |
|
import torch |
|
from torch.utils.data import DataLoader |
|
from tqdm.auto import tqdm |
|
import sys |
|
import transformers |
|
from accelerate import Accelerator, DistributedType |
|
from shutil import copyfile |
|
import wandb |
|
import numpy as np |
|
|
|
from transformers import ( |
|
MODEL_MAPPING, |
|
AutoModelForMaskedLM, |
|
AutoTokenizer, |
|
DataCollatorForLanguageModeling, |
|
SchedulerType, |
|
get_scheduler |
|
) |
|
from transformers.utils.versions import require_version |
|
|
|
|
|
|
|
class TrainDataset(torch.utils.data.IterableDataset):
    """Stream a line-per-example text file, tokenizing it in chunks.

    Lines are read lazily and tokenized ``batch_size`` at a time (one tokenizer
    call per chunk, padded to the longest line of that chunk); the resulting
    examples are then yielded one at a time as dicts keyed by the tokenizer's
    output names.

    Args:
        filepath: path to a UTF-8 text file with one training example per line.
        tokenizer: callable with the Hugging Face tokenizer call signature.
        max_length: maximum number of tokens per example (longer ones are truncated).
        batch_size: number of lines tokenized together in one tokenizer call.
        train_samples: value reported by ``len()``; it is not checked against
            the actual number of lines in ``filepath``.
        max_chars: each line is cut to this many characters before
            tokenization (presumably to bound tokenizer cost on huge lines).
    """

    def __init__(self, filepath, tokenizer, max_length, batch_size, train_samples, max_chars=1000):
        # Fail at construction time if the file is missing, but do not keep a
        # handle open: the file is (re)opened on every iteration.
        if not os.path.isfile(filepath):
            raise FileNotFoundError(filepath)
        self.filepath = filepath
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.batch_size = batch_size
        self.train_samples = train_samples
        self.max_chars = max_chars

    def _encode(self, lines):
        """Tokenize ``lines`` in one call and yield one feature dict per line."""
        encoded = self.tokenizer(lines, add_special_tokens=True, truncation=True, max_length=self.max_length, return_special_tokens_mask=True, padding=True)
        for idx in range(len(lines)):
            yield {key: encoded[key][idx] for key in encoded}

    def __iter__(self):
        # Opening the file here (instead of once in __init__) lets every epoch
        # re-read it from the beginning, and closes the handle after each pass.
        batch = []
        with open(self.filepath, encoding="utf-8") as fIn:
            for sent in fIn:
                batch.append(sent.strip()[:self.max_chars])
                if len(batch) >= self.batch_size:
                    yield from self._encode(batch)
                    batch = []
        # Emit the trailing lines that did not fill a complete chunk.
        if batch:
            yield from self._encode(batch)

    def __len__(self):
        return self.train_samples
|
|
|
|
|
|
|
|
|
|
|
|
|
class DevDataset(torch.utils.data.Dataset):
    """Map-style evaluation dataset built from a line-per-example text file.

    The whole file is read and tokenized eagerly in a single tokenizer call
    (truncated to ``max_length``, no padding). Indexing returns a dict that
    maps each tokenizer output key to the entry for that line.
    """

    def __init__(self, filepath, tokenizer, max_length):
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Each line, with surrounding whitespace removed, is one example.
        with open(filepath) as handle:
            lines = list(map(str.strip, handle))

        self.num_sentences = len(lines)
        self.tokenized = self.tokenizer(
            lines,
            add_special_tokens=True,
            truncation=True,
            max_length=self.max_length,
            return_special_tokens_mask=True,
        )

    def __len__(self):
        return self.num_sentences

    def __getitem__(self, idx):
        encoding = self.tokenized
        return {field: encoding[field][idx] for field in encoding}
|
|
|
|
|
|
|
# Fail early if the installed `datasets` package is too old.
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

logger = logging.getLogger(__name__)

# Config classes registered in transformers' MODEL_MAPPING, and their
# `model_type` names (offered as the choices of --model_type).
MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
MODEL_TYPES = tuple(config_cls.model_type for config_cls in MODEL_CONFIG_CLASSES)
|
|
|
|
|
def _str2bool(value):
    """Convert a command-line string such as "true"/"False"/"1"/"no" to a bool.

    argparse's ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string (including "False"), so boolean options need this.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args():
    """Parse the command-line options for masked-LM fine-tuning.

    ``--train_file`` and ``--dev_file`` are required because training reads
    both files unconditionally. ``--line_by_line`` and ``--overwrite_cache``
    accept true/false, yes/no, 1/0.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser(description="Finetune a transformers model on a Masked Language Modeling task")
    parser.add_argument(
        "--dataset_config_name",
        type=str,
        default=None,
        help="The configuration name of the dataset to use (via the datasets library).",
    )
    parser.add_argument(
        "--train_file", type=str, required=True, help="A text file data (1 text per line).."
    )
    parser.add_argument(
        "--dev_file", type=str, required=True, help="A text file data (1 text per line)."
    )
    parser.add_argument(
        "--model_name",
        default="nicoladecao/msmarco-word2vec256000-distilbert-base-uncased",
        type=str,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--per_device_batch_size",
        type=int,
        default=16,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-5,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay to use.")
    parser.add_argument("--num_train_epochs", type=int, default=1, help="Total number of training epochs to perform.")
    parser.add_argument(
        "--max_train_steps",
        type=int,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--lr_scheduler_type",
        type=SchedulerType,
        default="linear",
        help="The scheduler type to use.",
        choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
    )
    parser.add_argument(
        "--num_warmup_steps", type=int, default=1000, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--model_type",
        type=str,
        default=None,
        help="Model type to use if training from scratch.",
        choices=MODEL_TYPES,
    )
    parser.add_argument(
        "--max_seq_length",
        type=int,
        default=256,
        help="The maximum total input sequence length after tokenization. Sequences longer than this will be truncated.",
    )
    # Boolean options use _str2bool: with type=bool, "--line_by_line False" would parse as True.
    parser.add_argument(
        "--line_by_line",
        type=_str2bool,
        default=True,
        help="Whether distinct lines of text in the dataset are to be handled as distinct sequences.",
    )
    parser.add_argument(
        "--overwrite_cache", type=_str2bool, default=False, help="Overwrite the cached training and evaluation sets"
    )
    parser.add_argument(
        "--mlm_probability", type=float, default=0.15, help="Ratio of tokens to mask for masked language modeling loss"
    )
    parser.add_argument("--mixed_precision", default="fp16")
    parser.add_argument("--train_samples", required=True, type=int)
    parser.add_argument("--eval_steps", default=10000, type=int)
    parser.add_argument("--max_grad_norm", default=1.0, type=float)
    parser.add_argument("--project", default="bert-word2vec")
    parser.add_argument("--freeze_emb_layer", default=False, action='store_true')
    parser.add_argument("--log_interval", default=1000, type=int)
    parser.add_argument("--ckp_steps", default=50000, type=int)

    args = parser.parse_args()
    return args
|
|
|
|
|
def _evaluate(model, eval_dataloader, accelerator, per_device_batch_size, num_eval_samples):
    """Return the mean masked-LM loss over the dev set as a Python float.

    Must be called on every process, because ``accelerator.gather`` is a
    collective operation. Each batch's mean loss is repeated
    ``per_device_batch_size`` times before gathering. The concatenated losses
    are then cut to ``num_eval_samples`` entries, presumably to discard padding
    added to even out the last distributed batch. The model is put back into
    training mode before returning.
    """
    model.eval()
    losses = []
    with torch.no_grad():
        for eval_batch in eval_dataloader:
            eval_outputs = model(**eval_batch)
            losses.append(accelerator.gather(eval_outputs.loss.repeat(per_device_batch_size)))
    model.train()

    losses = torch.cat(losses)[:num_eval_samples]
    return torch.mean(losses).item()


def _save_checkpoint(accelerator, model, save_dir, log_line, tokenizer=None):
    """Save the unwrapped model to ``save_dir`` and append ``log_line`` to its train_steps.log.

    Meant to be called on the main process only. When ``tokenizer`` is given,
    it is saved into the same directory.
    """
    unwrapped_model = accelerator.unwrap_model(model)
    unwrapped_model.save_pretrained(save_dir, save_function=accelerator.save)
    if tokenizer is not None:
        tokenizer.save_pretrained(save_dir)
    with open(os.path.join(save_dir, "train_steps.log"), 'a') as fOut:
        fOut.write(log_line)


def main():
    """Fine-tune ``--model_name`` on ``--train_file`` with a masked-LM objective.

    Uses Accelerate for device placement and mixed precision. Training loss is
    logged to Weights & Biases, and the model is evaluated on ``--dev_file``
    every ``--eval_steps`` optimizer updates. On the main process, output goes
    under ``output-mlm/<experiment name>``:

    * the latest model in the experiment root;
    * the best model by dev loss in ``best/``;
    * periodic snapshots in ``ckp-<N>k/`` every ``--ckp_steps`` updates.
    """
    args = parse_args()

    accelerator = Accelerator(mixed_precision=args.mixed_precision)

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state)

    # Only the local main process logs verbosely; the others log errors only.
    logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    accelerator.wait_for_everyone()

    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    model = AutoModelForMaskedLM.from_pretrained(args.model_name)

    if args.freeze_emb_layer:
        # get_input_embeddings() is architecture-agnostic; for DistilBERT it is
        # model.distilbert.embeddings.word_embeddings.
        model.get_input_embeddings().requires_grad_(False)

    # Output directories exist only on the main process; every use below is
    # guarded by accelerator.is_main_process.
    output_dir = None
    best_ckp_dir = None
    if accelerator.is_main_process:
        exp_name = f'{args.model_name.replace("/", "-")}-{"freeze_emb" if args.freeze_emb_layer else "update_emb"}-{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}'
        output_dir = os.path.join("output-mlm", exp_name)
        wandb.init(project=args.project, name=exp_name, config=args)

        os.makedirs(output_dir, exist_ok=False)
        tokenizer.save_pretrained(output_dir)

        # Keep a copy of this script, plus the exact command line, next to the outputs.
        train_script_path = os.path.join(output_dir, 'train_script.py')
        copyfile(__file__, train_script_path)
        with open(train_script_path, 'a') as fOut:
            fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))

        best_ckp_dir = os.path.join(output_dir, "best")
        tokenizer.save_pretrained(best_ckp_dir)

    # TrainDataset tokenizes one global batch (all processes, all accumulation steps) per chunk.
    total_batch_size = args.per_device_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
    train_dataset = TrainDataset(args.train_file, tokenizer, args.max_seq_length, batch_size=total_batch_size, train_samples=args.train_samples)
    eval_dataset = DevDataset(args.dev_file, tokenizer, args.max_seq_length)

    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=args.mlm_probability)
    train_dataloader = DataLoader(train_dataset, collate_fn=data_collator, batch_size=args.per_device_batch_size)
    eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=args.per_device_batch_size)

    # No weight decay on biases and LayerNorm weights.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)

    model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(model, optimizer, train_dataloader, eval_dataloader)

    # On TPU the weights must be re-tied after moving the model to the device.
    if accelerator.distributed_type == DistributedType.TPU:
        model.tie_weights()

    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    else:
        args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    lr_scheduler = get_scheduler(
        name=args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=args.num_warmup_steps,
        num_training_steps=args.max_train_steps,
    )

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {args.train_samples}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.per_device_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")

    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process, smoothing=0.05)
    completed_steps = 0
    train_loss_values = []
    best_eval_loss = float("inf")
    # Defined up front so a checkpoint written before the first evaluation can log it.
    eval_loss = None

    for epoch in range(args.num_train_epochs):
        logger.info(f"Start epoch {epoch}")
        model.train()
        for step, batch in enumerate(train_dataloader):
            outputs = model(**batch)
            loss = outputs.loss / args.gradient_accumulation_steps

            if accelerator.is_main_process:
                train_loss_values.append(loss.item())

            accelerator.backward(loss)

            # Update once per full accumulation window. (step + 1) makes the
            # first update wait for gradient_accumulation_steps micro-batches.
            if (step + 1) % args.gradient_accumulation_steps == 0:
                # Clip only right before the optimizer step. Clipping every
                # micro-batch would clip partial sums, and under fp16 it would
                # unscale the gradients more than once per step, which raises.
                accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
                progress_bar.update(1)
                completed_steps += 1

                if accelerator.is_main_process and completed_steps % args.log_interval == 0:
                    wandb.log({"train/loss": np.mean(train_loss_values)}, step=completed_steps)
                    train_loss_values = []

                if completed_steps % args.eval_steps == 0:
                    # Every process takes part (the loss is gathered across processes).
                    eval_loss = _evaluate(model, eval_dataloader, accelerator, args.per_device_batch_size, len(eval_dataset))
                    try:
                        perplexity = math.exp(eval_loss)
                    except OverflowError:
                        perplexity = float("inf")

                    logger.info(f"step {completed_steps}: eval loss: {eval_loss}, perplexity: {perplexity}")
                    if accelerator.is_main_process:
                        wandb.log({"eval/loss": eval_loss, "eval/perplexity": perplexity}, step=completed_steps)

                    accelerator.wait_for_everyone()
                    if accelerator.is_main_process:
                        log_line = f"{completed_steps}: {eval_loss}\n"
                        _save_checkpoint(accelerator, model, output_dir, log_line)
                        if eval_loss < best_eval_loss:
                            best_eval_loss = eval_loss
                            _save_checkpoint(accelerator, model, best_ckp_dir, log_line)

                if accelerator.is_main_process and completed_steps % args.ckp_steps == 0:
                    ckp_dir = os.path.join(output_dir, f"ckp-{completed_steps // 1000}k")
                    _save_checkpoint(accelerator, model, ckp_dir, f"{completed_steps}: {eval_loss}\n", tokenizer=tokenizer)

            if completed_steps >= args.max_train_steps:
                break

        # Also stop the epoch loop once the step budget is exhausted.
        if completed_steps >= args.max_train_steps:
            break

    # Always persist the final weights (output_dir is created by the main process above).
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        _save_checkpoint(accelerator, model, output_dir, f"{completed_steps}\n")


if __name__ == "__main__":
    main()
|
|
|
|
|
|
|
|