| from lightning.pytorch.utilities import rank_zero_only |
|
|
| from fish_speech.utils import logger as log |
|
|
|
|
@rank_zero_only
def log_hyperparameters(object_dict: dict) -> None:
    """Gather selected config sections and parameter counts and send them to every logger.

    The collected mapping contains:
      - the ``model``, ``data`` and ``trainer`` config sections (required keys);
      - total / trainable / non-trainable parameter counts of the model;
      - the optional ``callbacks``, ``extras``, ``task_name``, ``tags``,
        ``ckpt_path`` and ``seed`` entries (``None`` when absent from ``cfg``).

    Args:
        object_dict: Must provide ``"cfg"`` (a mapping supporting ``[]`` and
            ``.get``), ``"model"`` (an object exposing ``.parameters()``) and
            ``"trainer"`` (an object exposing ``.logger`` and ``.loggers``).

    Returns:
        None. If ``trainer.logger`` is falsy, a warning is emitted and nothing
        is logged.
    """
    cfg = object_dict["cfg"]
    model = object_dict["model"]
    trainer = object_dict["trainer"]

    # Nothing to send hyperparameters to; bail out before touching the model.
    if not trainer.logger:
        log.warning("Logger not found! Skipping hyperparameter logging...")
        return

    def _numel_where(keep) -> int:
        # Sum element counts of the model parameters selected by `keep`.
        return sum(param.numel() for param in model.parameters() if keep(param))

    # Dict literals evaluate in order, so keys (and cfg lookups) follow the
    # same sequence as before: model section, parameter counts, data, trainer.
    hparams = {
        "model": cfg["model"],
        "model/params/total": _numel_where(lambda param: True),
        "model/params/trainable": _numel_where(lambda param: param.requires_grad),
        "model/params/non_trainable": _numel_where(
            lambda param: not param.requires_grad
        ),
        "data": cfg["data"],
        "trainer": cfg["trainer"],
    }

    # Optional entries: use .get so a missing key is recorded as None.
    for optional_key in (
        "callbacks",
        "extras",
        "task_name",
        "tags",
        "ckpt_path",
        "seed",
    ):
        hparams[optional_key] = cfg.get(optional_key)

    # Fan the same mapping out to every attached logger.
    for sink in trainer.loggers:
        sink.log_hyperparams(hparams)
|
|