#!/usr/bin/env python

import argparse
import glob
import logging
import os
import sys
import time
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import pytorch_lightning as pl
import torch
from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
from torch import nn
from torch.utils.data import DataLoader
from transformers import MBartTokenizer, T5ForConditionalGeneration
from transformers.models.bart.modeling_bart import shift_tokens_right
from utils import (
    ROUGE_KEYS,
    LegacySeq2SeqDataset,
    Seq2SeqDataset,
    assert_all_frozen,
    calculate_bleu,
    calculate_rouge,
    check_output_dir,
    flatten_list,
    freeze_embeds,
    freeze_params,
    get_git_info,
    label_smoothed_nll_loss,
    lmap,
    pickle_save,
    save_git_info,
    save_json,
    use_task_specific_params,
)

# need the parent dir module
sys.path.insert(2, str(Path(__file__).resolve().parents[1]))
from lightning_base import BaseTransformer, add_generic_args, generic_train  # noqa


logger = logging.getLogger(__name__)

class SummarizationModule(BaseTransformer):
    mode = "summarization"
    loss_names = ["loss"]
    metric_names = ROUGE_KEYS
    default_val_metric = "rouge2"

    def __init__(self, hparams, **kwargs):
        if hparams.sortish_sampler and hparams.gpus > 1:
            hparams.replace_sampler_ddp = False
        elif hparams.max_tokens_per_batch is not None:
            if hparams.gpus > 1:
                raise NotImplementedError("Dynamic Batch size does not work for multi-gpu training")
            if hparams.sortish_sampler:
                raise ValueError("--sortish_sampler and --max_tokens_per_batch may not be used simultaneously")

        super().__init__(hparams, num_labels=None, mode=self.mode, **kwargs)
        use_task_specific_params(self.model, "summarization")
        save_git_info(self.hparams.output_dir)
        self.metrics_save_path = Path(self.output_dir) / "metrics.json"
        self.hparams_save_path = Path(self.output_dir) / "hparams.pkl"
        pickle_save(self.hparams, self.hparams_save_path)
        self.step_count = 0
        self.metrics = defaultdict(list)
        self.model_type = self.config.model_type
        self.vocab_size = self.config.tgt_vocab_size if self.model_type == "fsmt" else self.config.vocab_size

        self.dataset_kwargs: dict = {
            "data_dir": self.hparams.data_dir,
            "max_source_length": self.hparams.max_source_length,
            "prefix": self.model.config.prefix or "",
        }
        n_observations_per_split = {
            "train": self.hparams.n_train,
            "val": self.hparams.n_val,
            "test": self.hparams.n_test,
        }
        self.n_obs = {k: v if v >= 0 else None for k, v in n_observations_per_split.items()}

        self.target_lens = {
            "train": self.hparams.max_target_length,
            "val": self.hparams.val_max_target_length,
            "test": self.hparams.test_max_target_length,
        }
        assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
        assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"
        if self.hparams.freeze_embeds:
            freeze_embeds(self.model)
        if self.hparams.freeze_encoder:
            freeze_params(self.model.get_encoder())
            assert_all_frozen(self.model.get_encoder())

        self.hparams.git_sha = get_git_info()["repo_sha"]
        self.num_workers = hparams.num_workers
        self.decoder_start_token_id = None  # default to config
        if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
            self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
            self.model.config.decoder_start_token_id = self.decoder_start_token_id
        self.dataset_class = (
            Seq2SeqDataset if hasattr(self.tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset
        )
        self.already_saved_batch = False
        self.eval_beams = self.model.config.num_beams if self.hparams.eval_beams is None else self.hparams.eval_beams
        if self.hparams.eval_max_gen_length is not None:
            self.eval_max_length = self.hparams.eval_max_gen_length
        else:
            self.eval_max_length = self.model.config.max_length
        self.val_metric = self.default_val_metric if self.hparams.val_metric is None else self.hparams.val_metric

    def save_readable_batch(self, batch: Dict[str, torch.Tensor]) -> Dict[str, List[str]]:
        """A debugging utility"""
        readable_batch = {
            k: self.tokenizer.batch_decode(v.tolist()) if "mask" not in k else v.shape for k, v in batch.items()
        }
        save_json(readable_batch, Path(self.output_dir) / "text_batch.json")
        save_json({k: v.tolist() for k, v in batch.items()}, Path(self.output_dir) / "tok_batch.json")

        self.already_saved_batch = True
        return readable_batch

    def forward(self, input_ids, **kwargs):
        return self.model(input_ids, **kwargs)

    def ids_to_clean_text(self, generated_ids: List[int]):
        gen_text = self.tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        return lmap(str.strip, gen_text)
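
    # Teacher-forced loss computation shared by train/val/test: decoder inputs are the labels shifted
    # right (T5 via model._shift_right, BART-style models via shift_tokens_right), and the loss is
    # either padding-ignoring cross-entropy or label-smoothed NLL depending on --label_smoothing.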
    def _step(self, batch: dict) -> Tuple:
        pad_token_id = self.tokenizer.pad_token_id
        src_ids, src_mask = batch["input_ids"], batch["attention_mask"]
        tgt_ids = batch["labels"]
        if isinstance(self.model, T5ForConditionalGeneration):
            decoder_input_ids = self.model._shift_right(tgt_ids)
        else:
            decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id)
        if not self.already_saved_batch:  # This would be slightly better if it only happened on rank zero
            batch["decoder_input_ids"] = decoder_input_ids
            self.save_readable_batch(batch)

        outputs = self(src_ids, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
        lm_logits = outputs["logits"]
        if self.hparams.label_smoothing == 0:
            # Same behavior as modeling_bart.py, besides ignoring pad_token_id
            ce_loss_fct = nn.CrossEntropyLoss(ignore_index=pad_token_id)

            assert lm_logits.shape[-1] == self.vocab_size
            loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))
        else:
            lprobs = nn.functional.log_softmax(lm_logits, dim=-1)
            loss, nll_loss = label_smoothed_nll_loss(
                lprobs, tgt_ids, self.hparams.label_smoothing, ignore_index=pad_token_id
            )
        return (loss,)

    @property
    def pad(self) -> int:
        return self.tokenizer.pad_token_id

    def training_step(self, batch, batch_idx) -> Dict:
        loss_tensors = self._step(batch)

        logs = dict(zip(self.loss_names, loss_tensors))
        # tokens per batch
        logs["tpb"] = batch["input_ids"].ne(self.pad).sum() + batch["labels"].ne(self.pad).sum()
        logs["bs"] = batch["input_ids"].shape[0]
        logs["src_pad_tok"] = batch["input_ids"].eq(self.pad).sum()
        logs["src_pad_frac"] = batch["input_ids"].eq(self.pad).float().mean()
        # TODO(SS): make a wandb summary metric for this
        return {"loss": loss_tensors[0], "log": logs}

    def validation_step(self, batch, batch_idx) -> Dict:
        return self._generative_step(batch)
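
    # Aggregates per-batch outputs into averaged losses and generation metrics, prefixes them with
    # "{prefix}_avg_", and appends them to self.metrics so the logging callback can write metrics.json.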
    def validation_epoch_end(self, outputs, prefix="val") -> Dict:
        self.step_count += 1
        losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
        loss = losses["loss"]
        generative_metrics = {
            k: np.array([x[k] for x in outputs]).mean() for k in self.metric_names + ["gen_time", "gen_len"]
        }
        metric_val = (
            generative_metrics[self.val_metric] if self.val_metric in generative_metrics else losses[self.val_metric]
        )
        metric_tensor: torch.FloatTensor = torch.tensor(metric_val).type_as(loss)
        generative_metrics.update({k: v.item() for k, v in losses.items()})
        losses.update(generative_metrics)
        all_metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
        all_metrics["step_count"] = self.step_count
        self.metrics[prefix].append(all_metrics)  # callback writes this to self.metrics_save_path
        preds = flatten_list([x["preds"] for x in outputs])
        return {
            "log": all_metrics,
            "preds": preds,
            f"{prefix}_loss": loss,
            f"{prefix}_{self.val_metric}": metric_tensor,
        }

    def calc_generative_metrics(self, preds, target) -> Dict:
        return calculate_rouge(preds, target)
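
    # Shared by validation and test: generates with beam search, decodes predictions and references to
    # text, scores them via calc_generative_metrics (ROUGE here, BLEU in TranslationModule), and also
    # records per-example generation time and mean generated length alongside the loss terms.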
    def _generative_step(self, batch: dict) -> dict:
        t0 = time.time()

        # parser.add_argument('--eval_max_gen_length', type=int, default=None, help='never generate more than n tokens')
        generated_ids = self.model.generate(
            batch["input_ids"],
            attention_mask=batch["attention_mask"],
            use_cache=True,
            decoder_start_token_id=self.decoder_start_token_id,
            num_beams=self.eval_beams,
            max_length=self.eval_max_length,
        )
        gen_time = (time.time() - t0) / batch["input_ids"].shape[0]
        preds: List[str] = self.ids_to_clean_text(generated_ids)
        target: List[str] = self.ids_to_clean_text(batch["labels"])
        loss_tensors = self._step(batch)
        base_metrics = dict(zip(self.loss_names, loss_tensors))
        rouge: Dict = self.calc_generative_metrics(preds, target)
        summ_len = np.mean(lmap(len, generated_ids))
        base_metrics.update(gen_time=gen_time, gen_len=summ_len, preds=preds, target=target, **rouge)
        return base_metrics

    def test_step(self, batch, batch_idx):
        return self._generative_step(batch)

    def test_epoch_end(self, outputs):
        return self.validation_epoch_end(outputs, prefix="test")

    def get_dataset(self, type_path) -> Seq2SeqDataset:
        n_obs = self.n_obs[type_path]
        max_target_length = self.target_lens[type_path]
        dataset = self.dataset_class(
            self.tokenizer,
            type_path=type_path,
            n_obs=n_obs,
            max_target_length=max_target_length,
            **self.dataset_kwargs,
        )
        return dataset
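
    # Training can use one of three batching strategies: an (optionally distributed) sortish sampler
    # that groups similarly sized examples, a dynamic batch sampler capped by --max_tokens_per_batch,
    # or plain shuffled batches. Validation and test always fall through to the plain DataLoader branch.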
    def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
        dataset = self.get_dataset(type_path)

        if self.hparams.sortish_sampler and type_path != "test" and type_path != "val":
            sampler = dataset.make_sortish_sampler(batch_size, distributed=self.hparams.gpus > 1)
            return DataLoader(
                dataset,
                batch_size=batch_size,
                collate_fn=dataset.collate_fn,
                shuffle=False,
                num_workers=self.num_workers,
                sampler=sampler,
            )

        elif self.hparams.max_tokens_per_batch is not None and type_path != "test" and type_path != "val":
            batch_sampler = dataset.make_dynamic_sampler(
                self.hparams.max_tokens_per_batch, distributed=self.hparams.gpus > 1
            )
            return DataLoader(
                dataset,
                batch_sampler=batch_sampler,
                collate_fn=dataset.collate_fn,
                # shuffle=False,
                num_workers=self.num_workers,
                # batch_size=None,
            )
        else:
            return DataLoader(
                dataset,
                batch_size=batch_size,
                collate_fn=dataset.collate_fn,
                shuffle=shuffle,
                num_workers=self.num_workers,
                sampler=None,
            )

    def train_dataloader(self) -> DataLoader:
        dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
        return dataloader

    def val_dataloader(self) -> DataLoader:
        return self.get_dataloader("val", batch_size=self.hparams.eval_batch_size)

    def test_dataloader(self) -> DataLoader:
        return self.get_dataloader("test", batch_size=self.hparams.eval_batch_size)
    @staticmethod
    def add_model_specific_args(parser, root_dir):
        BaseTransformer.add_model_specific_args(parser, root_dir)
        add_generic_args(parser, root_dir)
        parser.add_argument(
            "--max_source_length",
            default=1024,
            type=int,
            help=(
                "The maximum total input sequence length after tokenization. Sequences longer "
                "than this will be truncated, sequences shorter will be padded."
            ),
        )
        parser.add_argument(
            "--max_target_length",
            default=56,
            type=int,
            help=(
                "The maximum total target sequence length after tokenization. Sequences longer "
                "than this will be truncated, sequences shorter will be padded."
            ),
        )
        parser.add_argument(
            "--val_max_target_length",
            default=142,  # these defaults are optimized for CNNDM. For xsum, see README.md.
            type=int,
            help=(
                "The maximum total target sequence length after tokenization for validation. Sequences longer "
                "than this will be truncated, sequences shorter will be padded."
            ),
        )
        parser.add_argument(
            "--test_max_target_length",
            default=142,
            type=int,
            help=(
                "The maximum total target sequence length after tokenization for the test set. Sequences longer "
                "than this will be truncated, sequences shorter will be padded."
            ),
        )
        parser.add_argument("--freeze_encoder", action="store_true")
        parser.add_argument("--freeze_embeds", action="store_true")
        parser.add_argument("--sortish_sampler", action="store_true", default=False)
        parser.add_argument("--overwrite_output_dir", action="store_true", default=False)
        parser.add_argument("--max_tokens_per_batch", type=int, default=None)
        parser.add_argument("--logger_name", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
        parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
        parser.add_argument("--n_val", type=int, default=500, required=False, help="# examples. -1 means use all.")
        parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.")
        parser.add_argument(
            "--task", type=str, default="summarization", required=False, help="Task: summarization or translation."
        )
        parser.add_argument("--label_smoothing", type=float, default=0.0, required=False)
        parser.add_argument("--src_lang", type=str, default="", required=False)
        parser.add_argument("--tgt_lang", type=str, default="", required=False)
        parser.add_argument("--eval_beams", type=int, default=None, required=False)
        parser.add_argument(
            "--val_metric", type=str, default=None, required=False, choices=["bleu", "rouge2", "loss", None]
        )
        parser.add_argument("--eval_max_gen_length", type=int, default=None, help="never generate more than n tokens")
        parser.add_argument("--save_top_k", type=int, default=1, required=False, help="How many checkpoints to save")
        parser.add_argument(
            "--early_stopping_patience",
            type=int,
            default=-1,
            required=False,
            help=(
                "-1 means never early stop. early_stopping_patience is measured in validation checks, not epochs. So"
                " val_check_interval will affect it."
            ),
        )
        return parser


class TranslationModule(SummarizationModule):
    mode = "translation"
    loss_names = ["loss"]
    metric_names = ["bleu"]
    default_val_metric = "bleu"

    def __init__(self, hparams, **kwargs):
        super().__init__(hparams, **kwargs)
        self.dataset_kwargs["src_lang"] = hparams.src_lang
        self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang

    def calc_generative_metrics(self, preds, target) -> dict:
        return calculate_bleu(preds, target)
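

# Entry point: builds the module (SummarizationModule vs. TranslationModule based on --task), chooses a
# logger (default or Weights & Biases), wires up checkpoint/early-stopping callbacks, trains through
# generic_train, and, if --do_predict is set, points the trainer at the last *.ckpt (by sorted filename)
# in output_dir and runs trainer.test().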
def main(args, model=None) -> SummarizationModule:
    Path(args.output_dir).mkdir(exist_ok=True)
    check_output_dir(args, expected_items=3)

    if model is None:
        if "summarization" in args.task:
            model: SummarizationModule = SummarizationModule(args)
        else:
            model: SummarizationModule = TranslationModule(args)
    dataset = Path(args.data_dir).name
    if (
        args.logger_name == "default"
        or args.fast_dev_run
        or str(args.output_dir).startswith("/tmp")
        or str(args.output_dir).startswith("/var")
    ):
        logger = True  # don't pollute wandb logs unnecessarily
    elif args.logger_name == "wandb":
        from pytorch_lightning.loggers import WandbLogger

        project = os.environ.get("WANDB_PROJECT", dataset)
        logger = WandbLogger(name=model.output_dir.name, project=project)
    elif args.logger_name == "wandb_shared":
        from pytorch_lightning.loggers import WandbLogger

        logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")

    if args.early_stopping_patience >= 0:
        es_callback = get_early_stopping_callback(model.val_metric, args.early_stopping_patience)
    else:
        es_callback = False

    lower_is_better = args.val_metric == "loss"
    trainer: pl.Trainer = generic_train(
        model,
        args,
        logging_callback=Seq2SeqLoggingCallback(),
        checkpoint_callback=get_checkpoint_callback(
            args.output_dir, model.val_metric, args.save_top_k, lower_is_better
        ),
        early_stopping_callback=es_callback,
        logger=logger,
    )
    pickle_save(model.hparams, model.output_dir / "hparams.pkl")
    if not args.do_predict:
        return model

    model.hparams.test_checkpoint = ""
    checkpoints = sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True))
    if checkpoints:
        model.hparams.test_checkpoint = checkpoints[-1]
        trainer.resume_from_checkpoint = checkpoints[-1]
    trainer.logger.log_hyperparams(model.hparams)

    # test() without a model tests using the best checkpoint automatically
    trainer.test()

    return model
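

# Illustrative invocation (a sketch only: the data path and model name are placeholders, and flags such
# as --model_name_or_path, --gpus, and --do_predict come from the generic/BaseTransformer argument
# groups in lightning_base rather than being defined in this file):
#
#   python finetune.py \
#       --data_dir ./cnn_dm \
#       --model_name_or_path facebook/bart-large \
#       --output_dir ./bart_cnndm \
#       --task summarization \
#       --gpus 1 \
#       --do_predict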
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
    args = parser.parse_args()

    main(args)