import glob
import os

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import yaml
from pytorch_lightning.loggers import WandbLogger
from torchvision import models

from lib.datasets import create_dataset
from lib.model.networks import ImplicitNet
from v2a_model import V2AModel


class AGen_model(pl.LightningModule):
    """Outer Lightning module that fits one ``V2AModel`` per training step.

    The module builds a single ``ImplicitNet`` from ``opt.model.implicit_network``
    and hands that same instance to every ``V2AModel`` created in
    ``training_step``. Each training step receives the path of one video,
    loads that video's ``confs/metainfo.yaml`` and runs a complete inner
    ``pl.Trainer.fit`` on it.
    """

    def __init__(self, opt):
        """Store the configuration and build the implicit network.

        Args:
            opt: Configuration object. This class reads ``opt.model``,
                ``opt.dataset``, ``opt.project_name``, ``opt.exp`` and
                ``opt.run`` via attribute access.
        """
        super().__init__()

        # Configuration options
        self.opt = opt

        # Implicit network
        self.implicit_network = ImplicitNet(opt.model.implicit_network)

    @staticmethod
    def _unwrap_video_path(batch):
        """Return the video path carried by ``batch``.

        A single-element list/tuple (what a default-collated ``DataLoader``
        yields for a batch of one string) is unwrapped; anything else is
        returned unchanged so existing callers that already pass a plain
        path keep working.
        """
        if isinstance(batch, (list, tuple)) and len(batch) == 1:
            return batch[0]
        return batch

    @staticmethod
    def _latest_checkpoint(pattern="checkpoints/*.ckpt"):
        """Return the lexicographically last checkpoint matching ``pattern``.

        Raises:
            FileNotFoundError: If no file matches ``pattern``.
        """
        checkpoints = sorted(glob.glob(pattern))
        if not checkpoints:
            raise FileNotFoundError(
                f"is_continue is set but no checkpoint matches '{pattern}'"
            )
        return checkpoints[-1]

    def _build_trainer(self):
        """Create the inner trainer used to reconstruct a single video."""
        checkpoint_callback = pl.callbacks.ModelCheckpoint(
            dirpath="checkpoints/",
            filename="{epoch:04d}-{loss}",
            save_on_train_epoch_end=True,
            save_last=True,
        )
        logger = WandbLogger(
            project=self.opt.project_name,
            name=f"{self.opt.exp}/{self.opt.run}",
        )
        # ``devices=1`` replaces the ``gpus=1`` flag, which newer Lightning
        # releases no longer accept.
        return pl.Trainer(
            accelerator="gpu",
            devices=1,
            callbacks=[checkpoint_callback],
            max_epochs=8000,
            check_val_every_n_epoch=50,
            logger=logger,
            log_every_n_steps=1,
            num_sanity_val_steps=0,
        )

    def training_step(self, batch, batch_idx=None):
        """Fit a ``V2AModel`` on the video referenced by ``batch``.

        Args:
            batch: Path to one training video directory (or a one-element
                list/tuple holding it). The directory must contain
                ``confs/metainfo.yaml``.
            batch_idx: Index of the batch; accepted for compatibility with
                the Lightning hook signature and otherwise unused.

        Returns:
            None. No loss is returned from this step.
            NOTE(review): Lightning documents that a ``None`` return skips
            the optimizer step for this batch, so the optimizer from
            ``configure_optimizers`` is presumably never stepped — confirm
            this is intended.

        Raises:
            FileNotFoundError: If ``opt.model.is_continue`` is set and no
                checkpoint exists under ``checkpoints/``.
        """
        # Each batch contains the path to one training video
        video_path = self._unwrap_video_path(batch)
        metainfo_path = os.path.join(video_path, 'confs', 'metainfo.yaml')
        with open(metainfo_path, 'r', encoding='utf-8') as file:
            # The loaded metainfo is stored on the shared config object,
            # which is then passed to V2AModel below.
            self.opt.dataset.metainfo = yaml.safe_load(file)

        # Video reconstruction training step
        v2a_trainer = self._build_trainer()

        model = V2AModel(self.opt, self.implicit_network)
        trainset = create_dataset(self.opt.dataset.metainfo, self.opt.dataset.train)
        validset = create_dataset(self.opt.dataset.metainfo, self.opt.dataset.valid)

        if self.opt.model.is_continue:
            checkpoint = self._latest_checkpoint()
            v2a_trainer.fit(model, trainset, validset, ckpt_path=checkpoint)
        else:
            v2a_trainer.fit(model, trainset, validset)

        # Inference on the V2AModel after fitting
        model.eval()

        # TODO: inference on the implicit network after fitting
        return None

    def configure_optimizers(self):
        """Return an Adam optimizer over this module's parameters (lr=1e-3)."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer