from v2a_model import V2AModel

from lib.model.networks import ImplicitNet
from lib.datasets import create_dataset

import torch
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger

import os
import glob
import yaml


class AGen_model(pl.LightningModule):
    """Outer LightningModule: each training step runs a complete inner
    V2A optimisation on a single video."""

    def __init__(self, opt):
        super().__init__()
        self.opt = opt

        self.implicit_network = ImplicitNet(opt.model.implicit_network)
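        # Built once and handed to every per-video V2AModel created in
        # training_step, so the geometry weights are shared across videos.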

    def training_step(self, batch, batch_idx):
        # Each outer "batch" is the path to one video directory; the inner
        # run below fits a V2AModel to that single video.
        video_path = batch

        # Pull in the per-video dataset settings shipped with the video.
        metainfo_path = os.path.join(video_path, 'confs', 'metainfo.yaml')
        with open(metainfo_path, 'r') as file:
            self.opt.dataset.metainfo = yaml.safe_load(file)
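        # The schema of metainfo.yaml is whatever create_dataset expects; a
        # plausible sketch (hypothetical keys, for illustration only):
        #   gender: male
        #   data_dir: data/
        #   start_frame: 0
        #   end_frame: 300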

        # Checkpoint every epoch and keep the latest weights for resuming.
        checkpoint_callback = pl.callbacks.ModelCheckpoint(
            dirpath="checkpoints/",
            filename="{epoch:04d}-{loss}",
            save_on_train_epoch_end=True,
            save_last=True)
        logger = WandbLogger(project=self.opt.project_name, name=f"{self.opt.exp}/{self.opt.run}")

        # Inner trainer that runs the full per-video optimisation.
        v2a_trainer = pl.Trainer(
            accelerator="gpu",
            devices=1,
            callbacks=[checkpoint_callback],
            max_epochs=8000,
            check_val_every_n_epoch=50,
            logger=logger,
            log_every_n_steps=1,
            num_sanity_val_steps=0
        )

        # Build the inner model around the shared implicit network and the
        # train/validation sets described by this video's metainfo.
        model = V2AModel(self.opt, self.implicit_network)
        trainset = create_dataset(self.opt.dataset.metainfo, self.opt.dataset.train)
        validset = create_dataset(self.opt.dataset.metainfo, self.opt.dataset.valid)

        if self.opt.model.is_continue:
            # Resume from the newest checkpoint; lexicographic order matches
            # the zero-padded {epoch:04d} filename pattern.
            checkpoint = sorted(glob.glob("checkpoints/*.ckpt"))[-1]
            v2a_trainer.fit(model, trainset, validset, ckpt_path=checkpoint)
        else:
            v2a_trainer.fit(model, trainset, validset)

        # Freeze the fitted inner model; the shared implicit_network keeps
        # what was learned on this video for the next outer step.
        model.eval()

        # Returning None skips Lightning's automatic optimisation for this
        # outer module: it only orchestrates the inner runs.
        return

    def configure_optimizers(self):
        # Required by the LightningModule interface; effectively idle here,
        # since training_step never returns a loss to backpropagate.
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
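

# Usage sketch: one way to drive AGen_model with an outer Trainer whose
# "batches" are single video directories. The config path "confs/agen.yaml",
# the "data/videos/*" layout, and the use of OmegaConf are assumptions for
# illustration, not defined by this repo.
if __name__ == "__main__":
    from omegaconf import OmegaConf
    from torch.utils.data import DataLoader

    opt = OmegaConf.load("confs/agen.yaml")  # hypothetical config file

    # One outer sample per video directory; collate down to a bare path
    # string so training_step receives `batch` as a single path.
    video_dirs = sorted(glob.glob("data/videos/*"))
    outer_loader = DataLoader(video_dirs, batch_size=1, collate_fn=lambda b: b[0])

    meta_model = AGen_model(opt)
    outer_trainer = pl.Trainer(accelerator="gpu", devices=1, max_epochs=1)
    outer_trainer.fit(meta_model, outer_loader)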