# IF3D / code /AGen_model.py
# leobcc
# AGen
# 08c9919
from v2a_model import V2AModel
from lib.model.networks import ImplicitNet
from lib.datasets import create_dataset
import torch
import torch.nn as nn
import pytorch_lightning as pl
import torch.nn.functional as F
from torchvision import models
from pytorch_lightning.loggers import WandbLogger
import os
import glob
import yaml
class AGen_model(pl.LightningModule):
    """Outer Lightning module that runs a full V2AModel fit for each training video.

    The module owns a single ``ImplicitNet``. That network is passed to every
    ``V2AModel`` built in :meth:`training_step`, and each call to
    :meth:`training_step` runs a complete, nested ``pl.Trainer.fit`` for the
    video referenced by ``batch``.
    """

    def __init__(self, opt):
        """Store the configuration and build the implicit network.

        Args:
            opt: Configuration object. This class reads
                ``opt.model.implicit_network``, ``opt.model.is_continue``,
                ``opt.project_name``, ``opt.exp``, ``opt.run`` and
                ``opt.dataset.train`` / ``opt.dataset.valid``. It also
                writes ``opt.dataset.metainfo``.
        """
        super(AGen_model, self).__init__()
        # Configuration options
        self.opt = opt
        # Implicit network; the same instance is handed to every V2AModel
        # created in training_step.
        self.implicit_network = ImplicitNet(opt.model.implicit_network)

    @staticmethod
    def _video_path_from_batch(batch):
        """Return the video directory path carried by ``batch``."""
        # PyTorch's default collate leaves a batch of strings as a list of
        # strings, so a batch of one path can arrive as ['path'].
        # NOTE(review): this assumes the loader uses default collation; confirm.
        # Unwrap that case; any other value is returned unchanged.
        if isinstance(batch, (list, tuple)) and len(batch) == 1:
            return batch[0]
        return batch

    def _load_metainfo(self, video_path):
        """Load ``<video_path>/confs/metainfo.yaml`` into ``self.opt.dataset.metainfo``."""
        metainfo_path = os.path.join(video_path, 'confs', 'metainfo.yaml')
        with open(metainfo_path, 'r', encoding='utf-8') as file:
            self.opt.dataset.metainfo = yaml.safe_load(file)

    def _build_v2a_trainer(self):
        """Create the nested Trainer (checkpointing + W&B logging) for one video."""
        checkpoint_callback = pl.callbacks.ModelCheckpoint(
            dirpath="checkpoints/",
            filename="{epoch:04d}-{loss}",
            save_on_train_epoch_end=True,
            save_last=True)
        logger = WandbLogger(project=self.opt.project_name, name=f"{self.opt.exp}/{self.opt.run}")
        # NOTE(review): `gpus=` was deprecated in Lightning 1.7 and removed in
        # 2.0 (`devices=1` replaces it). Kept as-is; confirm the installed
        # version before changing.
        return pl.Trainer(
            gpus=1,
            accelerator="gpu",
            callbacks=[checkpoint_callback],
            max_epochs=8000,
            check_val_every_n_epoch=50,
            logger=logger,
            log_every_n_steps=1,
            num_sanity_val_steps=0
        )

    @staticmethod
    def _resume_checkpoint_path():
        """Return the checkpoint to resume from, or None when none exists."""
        # NOTE(review): every video writes to the same "checkpoints/" directory,
        # so this may pick up another video's checkpoint. Confirm this is intended.
        checkpoints = sorted(glob.glob("checkpoints/*.ckpt"))
        if not checkpoints:
            import warnings
            warnings.warn(
                "is_continue is set but no checkpoint matches checkpoints/*.ckpt; "
                "training from scratch.")
            return None
        # Lexicographically last name; "last.ckpt" sorts after "epoch=..." names.
        return checkpoints[-1]

    def training_step(self, batch):
        """Fit a V2AModel on the single video referenced by ``batch``.

        Args:
            batch: Path to one training video directory, or a one-element
                list/tuple holding that path. The directory must contain
                ``confs/metainfo.yaml``.

        Returns:
            None. Under Lightning's automatic optimization a ``None`` return
            skips the optimizer step for this batch. Presumably the implicit
            network is trained only through the nested V2AModel fit (TODO
            confirm).
        """
        # Each batch contains the path to one training video
        video_path = self._video_path_from_batch(batch)
        self._load_metainfo(video_path)

        # Video reconstruction training step
        v2a_trainer = self._build_v2a_trainer()
        model = V2AModel(self.opt, self.implicit_network)
        trainset = create_dataset(self.opt.dataset.metainfo, self.opt.dataset.train)
        validset = create_dataset(self.opt.dataset.metainfo, self.opt.dataset.valid)

        ckpt_path = self._resume_checkpoint_path() if self.opt.model.is_continue else None
        # ckpt_path=None is Trainer.fit's default, i.e. start from scratch.
        v2a_trainer.fit(model, trainset, validset, ckpt_path=ckpt_path)

        # Inference on the V2AModel after fitting
        model.eval()
        # TODO: inference on the implicit network after fitting
        return None

    def configure_optimizers(self):
        """Return an Adam optimizer (lr=1e-3) over all of this module's parameters.

        This includes the parameters of ``self.implicit_network``, because it
        is registered as a submodule in ``__init__``.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer