# Submission to the Interspeech 2026 Audio Encoder Capability Challenge
# (upload metadata: ltuncay, commit eca55dc, verified)
import rootutils
import hydra
from omegaconf import DictConfig
import lightning as L
import torch
from pathlib import Path
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import ModelCheckpoint
from typing import List, Dict, Any
# Setup root: locate the project root via the ".project-root" marker file and
# put it on sys.path so that `src.*` imports below resolve regardless of CWD.
root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
# Project-local helpers; must be imported after the root setup above,
# hence the E402 suppression (import not at top of file).
from src.utils import instantiate_callbacks, instantiate_loggers, RankedLogger, extras # noqa: E402
# Rank-zero-only logger so multi-GPU runs do not duplicate log lines.
log = RankedLogger(__name__, rank_zero_only=True)
@hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml")
def main(cfg: DictConfig) -> Dict[str, Any]:
    """Train and/or test a Lightning model assembled from a Hydra config.

    Instantiates the datamodule, model, callbacks, loggers, and trainer from
    ``cfg``, optionally runs ``trainer.fit`` (when ``cfg.train`` is truthy) and
    ``trainer.test`` (when ``cfg.test`` is truthy), and returns a dict of all
    instantiated objects for downstream inspection.

    :param cfg: Composed Hydra configuration (see ``configs/train.yaml``).
    :return: Dict with keys ``cfg``, ``datamodule``, ``model``, ``callbacks``,
        ``logger``, ``trainer``.
    """
    # Set seed for reproducibility across dataloader workers as well.
    if cfg.get("seed"):
        L.seed_everything(cfg.seed, workers=True)

    # Applies optional utilities (config-tree printing, tag enforcement, etc.)
    extras(cfg)

    log.info(f"Instantiating datamodule <{cfg.data._target_}>")
    datamodule: L.LightningDataModule = hydra.utils.instantiate(cfg.data)

    log.info(f"Instantiating model <{cfg.model._target_}>")
    model: L.LightningModule = hydra.utils.instantiate(cfg.model)

    log.info("Instantiating callbacks...")
    callbacks: List[L.Callback] = instantiate_callbacks(cfg.get("callbacks"))

    # Warn on two checkpointing footguns: (1) an explicit null override of
    # callbacks.model_checkpoint, (2) no ModelCheckpoint instantiated at all.
    # In both cases Lightning falls back to its default checkpoint behavior.
    callbacks_cfg = cfg.get("callbacks")
    if (
        isinstance(callbacks_cfg, DictConfig)
        and "model_checkpoint" in callbacks_cfg
        and callbacks_cfg.model_checkpoint is None
    ):
        log.warning(
            "`callbacks.model_checkpoint` is null in the composed config. "
            "Lightning will use its default ModelCheckpoint callback, which may not "
            "save `last.ckpt` and can change filename conventions. Remove the null "
            "override or set explicit checkpoint fields in the experiment config."
        )
    if cfg.get("train") and not any(
        isinstance(callback, ModelCheckpoint) for callback in callbacks
    ):
        log.warning(
            "No explicit ModelCheckpoint callback was instantiated from config; "
            "Lightning default checkpointing behavior will be used."
        )

    log.info("Instantiating loggers...")
    logger: List[L.Logger] = instantiate_loggers(cfg.get("logger"))

    # Set float32 matmul precision for Tensor Cores.
    torch.set_float32_matmul_precision("medium")

    # Log config tree and .hydra folder to wandb (best-effort: only when a
    # WandbLogger is configured and the files exist in the run output dir).
    for lg in logger:
        if isinstance(lg, WandbLogger):
            # check if config_tree.log exists
            config_tree_path = Path(cfg.paths.output_dir, "config_tree.log")
            if config_tree_path.exists():
                log.info("Logging config tree to WandB...")
                lg.experiment.save(
                    str(config_tree_path), policy="now", base_path=cfg.paths.output_dir
                )
            # Upload .hydra folder contents (composed config, overrides, etc.)
            hydra_dir = Path(cfg.paths.output_dir, ".hydra")
            if hydra_dir.exists() and hydra_dir.is_dir():
                log.info("Logging .hydra folder to WandB...")
                for hydra_file in hydra_dir.iterdir():
                    if hydra_file.is_file():
                        lg.experiment.save(
                            str(hydra_file),
                            policy="now",
                            base_path=cfg.paths.output_dir,
                        )

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    trainer: L.Trainer = hydra.utils.instantiate(
        cfg.trainer,
        callbacks=callbacks,
        logger=logger,
    )

    object_dict = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "callbacks": callbacks,
        "logger": logger,
        "trainer": trainer,
    }

    if cfg.get("train"):
        log.info("Starting training!")
        trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))

    if cfg.get("test"):
        log.info("Starting testing!")
        # BUGFIX: trainer.checkpoint_callback is None when no ModelCheckpoint
        # callback exists (the configuration this function warns about above);
        # dereferencing .best_model_path unconditionally would raise
        # AttributeError. Fall back to the "no best ckpt" path instead.
        checkpoint_cb = trainer.checkpoint_callback
        ckpt_path = checkpoint_cb.best_model_path if checkpoint_cb is not None else ""
        if ckpt_path == "":
            log.warning("Best ckpt not found! Using current weights for testing...")
            ckpt_path = None
        trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)

    return object_dict
# Script entry point: Hydra parses CLI overrides and calls main with the
# composed config.
if __name__ == "__main__":
    main()