from v2a_model import V2AModel
from lib.model.networks import ImplicitNet
from lib.datasets import create_dataset

import torch
import torch.nn as nn
import pytorch_lightning as pl
import torch.nn.functional as F
from torchvision import models
from pytorch_lightning.loggers import WandbLogger
import os
import glob
import yaml

class AGen_model(pl.LightningModule):
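    """Outer LightningModule that fits one V2AModel per training video.

    Each batch handed to ``training_step`` is the path to a single video
    directory; a nested Trainer then runs the per-video reconstruction, while
    a single ImplicitNet instance is shared across all videos.
    """
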
    def __init__(self, opt):
        super(AGen_model, self).__init__()

        # Configuration options
        self.opt = opt

        # Implicit network
        self.implicit_network = ImplicitNet(opt.model.implicit_network)
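        # Because the network is registered here as a submodule, its parameters
        # are the ones returned by self.parameters() and optimized in
        # configure_optimizers().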

    def training_step(self, batch, batch_idx):
        # Each batch contains the path to one training video
        video_path = batch

        metainfo_path = os.path.join(video_path, 'confs', 'metainfo.yaml')
        with open(metainfo_path, 'r') as file:
            self.opt.dataset.metainfo = yaml.safe_load(file)

        # Video reconstruction training step
        checkpoint_callback = pl.callbacks.ModelCheckpoint(
            dirpath="checkpoints/",
            filename="{epoch:04d}-{loss}",
            save_on_train_epoch_end=True,
            save_last=True)
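        # Note: dirpath="checkpoints/" is shared by every video processed in
        # this training_step; the resume logic below globs the same directory.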
        logger = WandbLogger(project=self.opt.project_name, name=f"{self.opt.exp}/{self.opt.run}")

        v2a_trainer = pl.Trainer(
            accelerator="gpu",
            devices=1,
            callbacks=[checkpoint_callback],
            max_epochs=8000,
            check_val_every_n_epoch=50,
            logger=logger,
            log_every_n_steps=1,
            num_sanity_val_steps=0
        )

        model = V2AModel(self.opt, self.implicit_network)
        trainset = create_dataset(self.opt.dataset.metainfo, self.opt.dataset.train)
        validset = create_dataset(self.opt.dataset.metainfo, self.opt.dataset.valid)

        if self.opt.model.is_continue:
            # Resume from the most recent checkpoint in checkpoints/
            checkpoint = sorted(glob.glob("checkpoints/*.ckpt"))[-1]
            v2a_trainer.fit(model, trainset, validset, ckpt_path=checkpoint)
        else:
            v2a_trainer.fit(model, trainset, validset)

        # Inference on the V2AModel after fitting
        model.eval()

        # Inference on the implicit network after fitting (not yet implemented)

        # Returning None lets Lightning skip the outer optimization step for this batch
        return None
    
    def configure_optimizers(self):
        # Adam over this module's own parameters (i.e. the shared implicit network)
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
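

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). It assumes an OmegaConf-style
# config compatible with `opt` and a plain text file listing video
# directories; the paths "confs/agen.yaml" and "video_paths.txt" are
# hypothetical placeholders, not part of the original project.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from omegaconf import OmegaConf
    from torch.utils.data import DataLoader

    opt = OmegaConf.load("confs/agen.yaml")  # hypothetical config path
    with open("video_paths.txt") as f:       # hypothetical list of video dirs
        video_paths = [line.strip() for line in f if line.strip()]

    agen = AGen_model(opt)

    # batch_size=None disables collation, so training_step receives each video
    # path as a bare string (matching how `batch` is used above).
    loader = DataLoader(video_paths, batch_size=None)

    outer_trainer = pl.Trainer(accelerator="gpu", devices=1, max_epochs=1)
    outer_trainer.fit(agen, loader)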