import argparse
import logging
import os
from pathlib import Path
from typing import Any, Dict

import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_info
from transformers import (
    AdamW,
    AutoConfig,
    AutoModel,
    AutoModelForPreTraining,
    AutoModelForQuestionAnswering,
    AutoModelForSeq2SeqLM,
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification,
    AutoModelWithLMHead,
    AutoTokenizer,
    PretrainedConfig,
    PreTrainedTokenizer,
)
from transformers.optimization import (
    Adafactor,
    get_cosine_schedule_with_warmup,
    get_cosine_with_hard_restarts_schedule_with_warmup,
    get_linear_schedule_with_warmup,
    get_polynomial_decay_schedule_with_warmup,
)
from transformers.utils.versions import require_version


logger = logging.getLogger(__name__)

require_version("pytorch_lightning>=1.0.4")

# Maps the `mode` argument of BaseTransformer to the auto-model class used to load weights.
MODEL_MODES = {
    "base": AutoModel,
    "sequence-classification": AutoModelForSequenceClassification,
    "question-answering": AutoModelForQuestionAnswering,
    "pretraining": AutoModelForPreTraining,
    "token-classification": AutoModelForTokenClassification,
    "language-modeling": AutoModelWithLMHead,
    "summarization": AutoModelForSeq2SeqLM,
    "translation": AutoModelForSeq2SeqLM,
}


# update this and the import above to support new schedulers from transformers.optimization
arg_to_scheduler = {
    "linear": get_linear_schedule_with_warmup,
    "cosine": get_cosine_schedule_with_warmup,
    "cosine_w_restarts": get_cosine_with_hard_restarts_schedule_with_warmup,
    "polynomial": get_polynomial_decay_schedule_with_warmup,
    # '': get_constant_schedule,             # not supported for now
    # '': get_constant_schedule_with_warmup, # not supported for now
}
arg_to_scheduler_choices = sorted(arg_to_scheduler.keys())
arg_to_scheduler_metavar = "{" + ", ".join(arg_to_scheduler_choices) + "}"


class BaseTransformer(pl.LightningModule):
    """LightningModule that bundles a Hugging Face config, tokenizer and model.

    It also provides the optimizer / LR-scheduler setup, dataloader plumbing and
    the command-line arguments that these pieces read from ``self.hparams``.
    Subclasses must implement ``get_dataloader``; ``validation_step`` and
    ``validation_end`` are called here but not defined in this class.
    """

    def __init__(
        self,
        hparams: argparse.Namespace,
        num_labels=None,
        mode="base",
        config=None,
        tokenizer=None,
        model=None,
        **config_kwargs,
    ):
        """Initialize a model, tokenizer and config.

        Args:
            hparams: Parsed arguments; stored through ``save_hyperparameters``.
            num_labels: Forwarded to ``AutoConfig.from_pretrained`` when not None.
            mode: Key into ``MODEL_MODES`` selecting the auto-model class.
            config: Ready-made config; when None one is loaded from
                ``config_name`` (or ``model_name_or_path`` if that is empty).
            tokenizer: Ready-made tokenizer; when None one is loaded from
                ``tokenizer_name`` (or ``model_name_or_path`` if that is empty).
            model: Ready-made model; when None one is loaded from
                ``model_name_or_path``.
            **config_kwargs: Extra keyword arguments for ``AutoConfig.from_pretrained``.

        Raises:
            AssertionError: If an optional dropout hyperparameter is set but the
                model config has no attribute of that name.
            KeyError: If ``mode`` is not a key of ``MODEL_MODES``.
        """
        super().__init__()
        # TODO: move to self.save_hyperparameters()
        # self.save_hyperparameters()
        # can also expand arguments into trainer signature for easier reading

        self.save_hyperparameters(hparams)
        self.step_count = 0
        self.output_dir = Path(self.hparams.output_dir)
        # An empty cache_dir is normalised to None.
        cache_dir = self.hparams.cache_dir or None

        if config is None:
            self.config = AutoConfig.from_pretrained(
                self.hparams.config_name or self.hparams.model_name_or_path,
                **({"num_labels": num_labels} if num_labels is not None else {}),
                cache_dir=cache_dir,
                **config_kwargs,
            )
        else:
            self.config: PretrainedConfig = config

        extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "dropout", "attention_dropout")
        for p in extra_model_params:
            value = getattr(self.hparams, p, None)
            # Compare against None rather than truthiness so that an explicit 0.0
            # (e.g. `--dropout 0.0`) is still copied into the config.
            if value is None:
                continue
            # Raised explicitly (not via `assert`) so the check survives `python -O`.
            if not hasattr(self.config, p):
                raise AssertionError(f"model config doesn't have a `{p}` attribute")
            setattr(self.config, p, value)

        if tokenizer is None:
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.hparams.tokenizer_name or self.hparams.model_name_or_path,
                cache_dir=cache_dir,
            )
        else:
            self.tokenizer: PreTrainedTokenizer = tokenizer

        self.model_type = MODEL_MODES[mode]
        if model is None:
            self.model = self.model_type.from_pretrained(
                self.hparams.model_name_or_path,
                # A ".ckpt" substring in the name switches on `from_tf`.
                from_tf=".ckpt" in self.hparams.model_name_or_path,
                config=self.config,
                cache_dir=cache_dir,
            )
        else:
            self.model = model

    def load_hf_checkpoint(self, *args, **kwargs):
        """Replace ``self.model`` with ``self.model_type.from_pretrained(*args, **kwargs)``."""
        self.model = self.model_type.from_pretrained(*args, **kwargs)

    def get_lr_scheduler(self):
        """Build the scheduler dict for the optimizer stored in ``self.opt``.

        The schedule function is looked up in ``arg_to_scheduler`` by
        ``hparams.lr_scheduler``. ``self.opt`` is assigned in
        ``configure_optimizers``, so that method has to run first.

        Returns:
            A dict with keys ``"scheduler"``, ``"interval"`` (``"step"``) and
            ``"frequency"`` (``1``).
        """
        get_schedule_func = arg_to_scheduler[self.hparams.lr_scheduler]
        scheduler = get_schedule_func(
            self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=self.total_steps()
        )
        return {"scheduler": scheduler, "interval": "step", "frequency": 1}

    def configure_optimizers(self):
        """Prepare optimizer and schedule (linear warmup and decay).

        Parameters whose name contains ``"bias"`` or ``"LayerNorm.weight"`` get a
        weight decay of 0.0; all others use ``hparams.weight_decay``.

        Returns:
            ``([optimizer], [scheduler_dict])``.
        """
        model = self.model
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [
                    p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)
                ],  # check these named parameters
                "weight_decay": self.hparams.weight_decay,
            },
            {
                "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        if self.hparams.adafactor:
            optimizer = Adafactor(
                optimizer_grouped_parameters, lr=self.hparams.learning_rate, scale_parameter=False, relative_step=False
            )
        else:
            optimizer = AdamW(
                optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon
            )
        # Kept on the instance because get_lr_scheduler() reads it.
        self.opt = optimizer

        scheduler = self.get_lr_scheduler()

        return [optimizer], [scheduler]

    def test_step(self, batch, batch_nb):
        """Delegate to ``self.validation_step`` (not defined in this class)."""
        return self.validation_step(batch, batch_nb)

    def test_epoch_end(self, outputs):
        """Delegate to ``self.validation_end`` (not defined in this class)."""
        return self.validation_end(outputs)

    def total_steps(self) -> int:
        """The number of total training steps that will be run. Used for lr scheduler purposes.

        NOTE(review): the value comes from a true division, so it can be a float
        even though the annotation says ``int``.
        """
        # `gpus` may be None; treat that as zero GPUs instead of failing in max().
        num_devices = max(1, self.hparams.gpus or 0)  # TODO: consider num_tpu_cores
        effective_batch_size = self.hparams.train_batch_size * self.hparams.accumulate_grad_batches * num_devices
        return (self.dataset_size / effective_batch_size) * self.hparams.max_epochs

    def setup(self, stage):
        """Record ``self.dataset_size`` for the given stage.

        For ``"test"`` the size of the test dataset is used; for any other stage
        the training dataloader is built, cached in ``self.train_loader`` and
        its dataset size is used.
        """
        if stage == "test":
            self.dataset_size = len(self.test_dataloader().dataset)
        else:
            self.train_loader = self.get_dataloader("train", self.hparams.train_batch_size, shuffle=True)
            self.dataset_size = len(self.train_dataloader().dataset)

    def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False):
        """Return the dataloader for ``type_path``; must be overridden by subclasses."""
        raise NotImplementedError("You must implement this for your task")

    def train_dataloader(self):
        """Return the training dataloader cached by ``setup``."""
        return self.train_loader

    def val_dataloader(self):
        """Return an unshuffled ``"dev"`` dataloader using ``hparams.eval_batch_size``."""
        return self.get_dataloader("dev", self.hparams.eval_batch_size, shuffle=False)

    def test_dataloader(self):
        """Return an unshuffled ``"test"`` dataloader using ``hparams.eval_batch_size``."""
        return self.get_dataloader("test", self.hparams.eval_batch_size, shuffle=False)

    def _feature_file(self, mode):
        """Return ``<data_dir>/cached_<mode>_<model>_<max_seq_length>``.

        ``<model>`` is the last non-empty ``/``-separated component of
        ``hparams.model_name_or_path``.
        """
        model_name = list(filter(None, self.hparams.model_name_or_path.split("/"))).pop()
        return os.path.join(
            self.hparams.data_dir,
            "cached_{}_{}_{}".format(mode, model_name, str(self.hparams.max_seq_length)),
        )

    @pl.utilities.rank_zero_only
    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        """Save the model and tokenizer to ``<output_dir>/best_tfmr``.

        The current ``step_count`` is written to ``model.config.save_step``
        before saving. The ``checkpoint`` dict itself is not used.
        """
        save_path = self.output_dir.joinpath("best_tfmr")
        self.model.config.save_step = self.step_count
        self.model.save_pretrained(save_path)
        self.tokenizer.save_pretrained(save_path)

    @staticmethod
    def add_model_specific_args(parser, root_dir):
        """Register model, optimizer and batching arguments on ``parser``.

        ``root_dir`` is accepted but not used by this method.
        """
        parser.add_argument(
            "--model_name_or_path",
            default=None,
            type=str,
            required=True,
            help="Path to pretrained model or model identifier from huggingface.co/models",
        )
        parser.add_argument(
            "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
        )
        parser.add_argument(
            "--tokenizer_name",
            default=None,
            type=str,
            help="Pretrained tokenizer name or path if not the same as model_name",
        )
        parser.add_argument(
            "--cache_dir",
            default=str(Path(__file__).parent / "test_run" / "cache"),
            type=str,
            help="Where do you want to store the pre-trained models downloaded from huggingface.co",
        )
        parser.add_argument(
            "--encoder_layerdrop",
            type=float,
            help="Encoder layer dropout probability (Optional). Goes into model.config",
        )
        parser.add_argument(
            "--decoder_layerdrop",
            type=float,
            help="Decoder layer dropout probability (Optional). Goes into model.config",
        )
        parser.add_argument(
            "--dropout",
            type=float,
            help="Dropout probability (Optional). Goes into model.config",
        )
        parser.add_argument(
            "--attention_dropout",
            type=float,
            help="Attention dropout probability (Optional). Goes into model.config",
        )
        parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
        parser.add_argument(
            "--lr_scheduler",
            default="linear",
            choices=arg_to_scheduler_choices,
            metavar=arg_to_scheduler_metavar,
            type=str,
            help="Learning rate scheduler",
        )
        parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
        parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
        parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
        parser.add_argument("--num_workers", default=4, type=int, help="kwarg passed to DataLoader")
        parser.add_argument("--num_train_epochs", dest="max_epochs", default=3, type=int)
        parser.add_argument("--train_batch_size", default=32, type=int)
        parser.add_argument("--eval_batch_size", default=32, type=int)
        parser.add_argument("--adafactor", action="store_true")


class InitCallback(pl.Callback):
    """Callback that initializes the RAG retriever when the sanity check starts."""

    # This process can also be done with a PL ddp plugin,
    # but that is still experimental (check original RAG, I updated that with plugin (shamanez)).
    def on_sanity_check_start(self, trainer, pl_module):
        """Call ``init_retrieval()`` on the retriever, on the global-zero rank only."""
        if (
            trainer.is_global_zero and trainer.global_rank == 0
        ):  # we initialize the retriever only on master worker with RAY. In new pytorch-lightning accelerators are removed.
            pl_module.model.rag.retriever.init_retrieval()  # better to use hook functions.
class CheckParamCallback(pl.Callback):
    """Callback that prints the names of RAG parameters that received no gradient."""

    # check whether newly added model parameters are differentiable
    def on_after_backward(self, trainer, pl_module):
        """Print every parameter name of ``pl_module.model.rag`` whose ``grad`` is None."""
        # print(pl_module.model.rag)
        for name, param in pl_module.model.rag.named_parameters():
            if param.grad is None:
                print(name)


class LoggingCallback(pl.Callback):
    """Callback that logs learning rates per batch and reports validation/test metrics."""

    def on_batch_end(self, trainer, pl_module):
        """Log the current learning rate of each group as ``lr_group_<i>``."""
        lr_scheduler = trainer.lr_schedulers[0]["scheduler"]
        lrs = {f"lr_group_{i}": lr for i, lr in enumerate(lr_scheduler.get_lr())}
        pl_module.logger.log_metrics(lrs)

    def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        """Report ``trainer.callback_metrics`` (sorted by key) through ``rank_zero_info``."""
        rank_zero_info("***** Validation results *****")
        metrics = trainer.callback_metrics
        # Log results
        for key in sorted(metrics):
            if key not in ["log", "progress_bar"]:
                rank_zero_info("{} = {}\n".format(key, str(metrics[key])))

    def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        """Report test metrics and write them to ``<output_dir>/test_results.txt``."""
        rank_zero_info("***** Test results *****")
        metrics = trainer.callback_metrics
        # Log and save results to file
        output_test_results_file = os.path.join(pl_module.hparams.output_dir, "test_results.txt")
        with open(output_test_results_file, "w") as writer:
            for key in sorted(metrics):
                if key not in ["log", "progress_bar"]:
                    # `!s` keeps the explicit str() conversion of the metric value.
                    line = f"{key} = {metrics[key]!s}\n"
                    rank_zero_info(line)
                    writer.write(line)


def add_generic_args(parser, root_dir) -> None:
    """Register the task-independent training arguments on ``parser``.

    ``root_dir`` is accepted but not used by this function.
    """
    #  To allow all pl args uncomment the following line
    #  parser = pl.Trainer.add_argparse_args(parser)
    parser.add_argument(
        "--output_dir",
        default=str(Path(__file__).parent / "test_run" / "model_checkpoints"),
        type=str,
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
    )

    parser.add_argument(
        "--fp16_opt_level",
        type=str,
        default="O2",
        help=(
            # Trailing space keeps the two concatenated sentences from running together.
            "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. "
            "See details at https://nvidia.github.io/apex/amp.html"
        ),
    )
    parser.add_argument("--n_tpu_cores", dest="tpu_cores", type=int)
    parser.add_argument("--max_grad_norm", dest="gradient_clip_val", default=1.0, type=float, help="Max gradient norm")
    parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
    parser.add_argument("--do_predict", action="store_true", help="Whether to run predictions on the test set.")
    parser.add_argument(
        "--gradient_accumulation_steps",
        dest="accumulate_grad_batches",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
    parser.add_argument(
        "--data_dir",
        default=str(Path(__file__).parent / "test_run" / "dummy-train-data"),
        type=str,
        help="The input data dir. Should contain the training files for the CoNLL-2003 NER task.",
    )


def generic_train(
    model: BaseTransformer,
    args: argparse.Namespace,
    early_stopping_callback=None,
    logger=True,  # can pass WandbLogger() here
    extra_callbacks=None,
    checkpoint_callback=None,
    logging_callback=None,
    **extra_train_kwargs,
):
    """Build a ``pl.Trainer`` from ``args`` and optionally fit ``model``.

    Args:
        model: Module to train; its ``hparams.output_dir`` is created if missing.
        args: Parsed arguments; reads ``seed``, ``output_dir``, ``fp16``, ``gpus``,
            ``accumulate_grad_batches`` and ``do_train``.
        early_stopping_callback: Appended to the callbacks when truthy.
        logger: Forwarded to the trainer as ``logger``.
        extra_callbacks: Additional callbacks. The sequence is copied, so the
            caller's object is never modified. ``None`` means no extras.
        checkpoint_callback: Defaults to a ``ModelCheckpoint`` monitoring
            ``val_loss`` (mode ``min``, ``save_top_k=1``).
        logging_callback: Defaults to ``LoggingCallback()``.
        **extra_train_kwargs: Accepted but not forwarded to the trainer.

    Returns:
        The constructed trainer (fitted when ``args.do_train`` is set).
    """
    pl.seed_everything(args.seed)

    # init model
    odir = Path(model.hparams.output_dir)
    # parents=True: the default output dir lives under a "test_run" folder that may not exist yet.
    odir.mkdir(parents=True, exist_ok=True)

    # add custom checkpoints
    if checkpoint_callback is None:
        checkpoint_callback = pl.callbacks.ModelCheckpoint(
            filepath=args.output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=1
        )

    # Work on a private copy: a shared default list (or the caller's list) must not
    # accumulate callbacks across calls.
    callbacks = list(extra_callbacks) if extra_callbacks is not None else []
    if early_stopping_callback:
        callbacks.append(early_stopping_callback)
    if logging_callback is None:
        logging_callback = LoggingCallback()

    train_params = {}
    if args.fp16:
        train_params["precision"] = 16

    # `gpus` may be None; treat that as zero GPUs instead of failing on the comparison.
    if (args.gpus or 0) > 1:
        train_params["accelerator"] = "auto"
        train_params["strategy"] = "ddp"

    train_params["accumulate_grad_batches"] = args.accumulate_grad_batches
    train_params["profiler"] = None
    train_params["devices"] = "auto"

    trainer = pl.Trainer.from_argparse_args(
        args,
        weights_summary=None,
        callbacks=[logging_callback] + callbacks + [InitCallback()] + [checkpoint_callback],
        logger=logger,
        val_check_interval=1,
        num_sanity_val_steps=2,
        **train_params,
    )

    if args.do_train:
        trainer.fit(model)
    else:
        print("RAG modeling tests with new set functions successfuly executed!")
    return trainer