import os
import datetime

import torch
import pytorch_lightning as pl
import wandb
from lightning_fabric import seed_everything
from pytorch_lightning.loggers.wandb import WandbLogger

from src.callback import CALLBACK_REGISTRY
from src.loop.feature_training_loop import FeatureTrainingLoop
from src.loop.style_training_loop import StyleTrainingLoop
from src.model import MODEL_REGISTRY
from src.utils.opt import Opts
from src.utils.renderer import OctreeRender_trilinear_fast


def train(config):
    model = MODEL_REGISTRY.get(config["model"]["name"])(config)
    epoch = config["trainer"]["n_iters"]

    # Timestamp the run name so repeated runs don't collide in W&B.
    time_str = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    run_name = f"{config['global']['name']}-{time_str}"
    wandb_logger = WandbLogger(
        project=config["global"]["project_name"],
        name=run_name,
        save_dir=config["global"]["save_dir"],
        entity=config["global"]["username"],
    )
    wandb_logger.watch(model)
    wandb_logger.experiment.config.update(config)

    # Instantiate every callback listed in the config from the registry.
    callbacks = [
        CALLBACK_REGISTRY.get(mcfg["name"])(**mcfg["params"])
        for mcfg in config["callbacks"]
    ]

    use_gpu = torch.cuda.is_available()
    trainer = pl.Trainer(
        default_root_dir="src",
        check_val_every_n_epoch=config["trainer"]["evaluate_interval"],
        log_every_n_steps=config["trainer"]["log_interval"],
        enable_checkpointing=True,
        accelerator="gpu" if use_gpu else "auto",
        devices=-1 if use_gpu else 1,  # all visible GPUs, or a single CPU process
        sync_batchnorm=use_gpu,
        precision=16 if config["trainer"]["use_fp16"] else 32,
        fast_dev_run=config["trainer"]["debug"],
        logger=wandb_logger,
        callbacks=callbacks,
        num_sanity_val_steps=-1,  # full sanity validation pass, required by the visualization callbacks
        deterministic=False,
        auto_lr_find=True,
    )
    print("Trainer:", trainer)

    # Swap in the custom fit loop matching the model type. (The original read
    # the module-level `cfg` here instead of the `config` argument.)
    if config["model"]["type"] == "feature":
        trainer.fit_loop = FeatureTrainingLoop(
            epoch=epoch, cfg=config, renderer=OctreeRender_trilinear_fast
        )
    elif config["model"]["type"] == "style":
        trainer.fit_loop = StyleTrainingLoop(
            epoch=epoch, cfg=config, renderer=OctreeRender_trilinear_fast
        )
    else:
        raise NotImplementedError(
            f"Unknown model type: {config['model']['type']!r}"
        )

    trainer.fit(model, ckpt_path=config["global"]["resume"])

    # Return the checkpoint directory for this W&B run.
    return os.path.join(
        config["global"]["save_dir"],
        config["global"]["project_name"],
        wandb.run.id,
        "checkpoints",
    )


if __name__ == "__main__":
    cfg = Opts(cfg="configs/style_baseline.yml").parse_args()
    seed_everything(seed=cfg["global"]["SEED"])
    train(cfg)
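
# ---------------------------------------------------------------------------
# For reference, a minimal sketch of the config layout this script expects.
# The keys are inferred from the lookups above; the values shown are
# illustrative assumptions, not the project's actual defaults.
#
# global:
#   name: style-baseline          # used to build the W&B run name
#   project_name: my-project      # W&B project (hypothetical value)
#   username: my-wandb-entity     # W&B entity (hypothetical value)
#   save_dir: ./runs
#   resume: null                  # or a checkpoint path to resume from
#   SEED: 42
# model:
#   name: SomeModel               # key into MODEL_REGISTRY (hypothetical)
#   type: style                   # "feature" or "style"
# trainer:
#   n_iters: 30000
#   evaluate_interval: 1
#   log_interval: 50
#   use_fp16: false
#   debug: false
# callbacks:
#   - name: SomeCallback          # key into CALLBACK_REGISTRY (hypothetical)
#     params: {}
# ---------------------------------------------------------------------------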