import logging
from pathlib import Path

import numpy as np
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.utilities import rank_zero_only

from utils_rag import save_json


def count_trainable_parameters(model):
    """Return the number of trainable (requires_grad=True) parameters in `model`."""
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    return params
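
# Usage sketch (illustrative only, not called at module level): the helper works on any
# torch.nn.Module, e.g.
#   count_trainable_parameters(torch.nn.Linear(4, 2))  # -> 10 (4*2 weights + 2 biases)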


logger = logging.getLogger(__name__)


def get_checkpoint_callback(output_dir, metric):
    """Saves the top checkpoints by the given validation metric (rouge2, bleu or em)."""
    if metric == "rouge2":
        exp = "{val_avg_rouge2:.4f}-{step_count}"
    elif metric == "bleu":
        exp = "{val_avg_bleu:.4f}-{step_count}"
    elif metric == "em":
        exp = "{val_avg_em:.4f}-{step_count}"
    else:
        raise NotImplementedError(
            f"seq2seq callbacks only support rouge2, bleu and em, got {metric}. You can add your own by extending"
            " this function."
        )

    checkpoint_callback = ModelCheckpoint(
        dirpath=output_dir,
        filename=exp,
        monitor=f"val_{metric}",
        mode="max",
        save_top_k=3,
        every_n_epochs=1,  # maybe save a checkpoint every time val is run, not just end of epoch.
    )

    return checkpoint_callback


def get_early_stopping_callback(metric, patience):
    return EarlyStopping(
        monitor=f"val_{metric}",  # does this need avg?
        mode="min" if "loss" in metric else "max",
        patience=patience,
        verbose=True,
    )
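
# Wiring sketch (assumption: an `args` namespace exposing `output_dir`, `val_metric` and
# `early_stopping_patience`, as a typical finetuning script for these examples would provide):
#
#   trainer = pl.Trainer(
#       callbacks=[
#           get_checkpoint_callback(args.output_dir, args.val_metric),
#           get_early_stopping_callback(args.val_metric, args.early_stopping_patience),
#       ],
#   )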


class Seq2SeqLoggingCallback(pl.Callback):
    def on_batch_end(self, trainer, pl_module):
        # Log the current learning rate of each param group of the first optimizer after every batch.
        lrs = {f"lr_group_{i}": param["lr"] for i, param in enumerate(pl_module.trainer.optimizers[0].param_groups)}
        pl_module.logger.log_metrics(lrs)

    @rank_zero_only
    def _write_logs(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule, type_path: str, save_generations=True
    ) -> None:
        logger.info(f"***** {type_path} results at step {trainer.global_step:05d} *****")
        metrics = trainer.callback_metrics
        trainer.logger.log_metrics({k: v for k, v in metrics.items() if k not in ["log", "progress_bar", "preds"]})
        # Log results
        od = Path(pl_module.hparams.output_dir)
        if type_path == "test":
            results_file = od / "test_results.txt"
            generations_file = od / "test_generations.txt"
        else:
            # this never gets hit. I prefer not to save intermediate generations, and results are in metrics.json
            # If people want this it will be easy enough to add back.
            results_file = od / f"{type_path}_results/{trainer.global_step:05d}.txt"
            generations_file = od / f"{type_path}_generations/{trainer.global_step:05d}.txt"
            results_file.parent.mkdir(exist_ok=True)
            generations_file.parent.mkdir(exist_ok=True)
        with open(results_file, "a+") as writer:
            for key in sorted(metrics):
                if key in ["log", "progress_bar", "preds"]:
                    continue
                val = metrics[key]
                if isinstance(val, torch.Tensor):
                    val = val.item()
                msg = f"{key}: {val:.6f}\n"
                writer.write(msg)

        if not save_generations:
            return

        if "preds" in metrics:
            content = "\n".join(metrics["preds"])
            generations_file.open("w+").write(content)

    @rank_zero_only
    def on_train_start(self, trainer, pl_module):
        try:
            npars = pl_module.model.model.num_parameters()
        except AttributeError:
            npars = pl_module.model.num_parameters()

        n_trainable_pars = count_trainable_parameters(pl_module)
        # mp stands for million parameters
        trainer.logger.log_metrics({"n_params": npars, "mp": npars / 1e6, "grad_mp": n_trainable_pars / 1e6})

    @rank_zero_only
    def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        save_json(pl_module.metrics, pl_module.metrics_save_path)
        return self._write_logs(trainer, pl_module, "test")

    @rank_zero_only
    def on_validation_end(self, trainer: pl.Trainer, pl_module):
        save_json(pl_module.metrics, pl_module.metrics_save_path)
        # Uncommenting this will save val generations
        # return self._write_logs(trainer, pl_module, "valid")
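

# Usage note (sketch, inferred only from the attribute accesses above): Seq2SeqLoggingCallback
# assumes the LightningModule exposes `hparams.output_dir`, `metrics` and `metrics_save_path`, e.g.
#
#   module.metrics_save_path = Path(args.output_dir) / "metrics.json"   # hypothetical setup
#   trainer = pl.Trainer(callbacks=[Seq2SeqLoggingCallback(), ...])     # see wiring sketch above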