import torch
import os
from datetime import datetime
from torch.utils.data import DataLoader
from torch.optim import Optimizer

from .learner import DiffproLearner


class TrainConfig:
    # These attributes are declared but never assigned in __init__;
    # concrete subclasses are expected to set them before calling train().
    model: torch.nn.Module
    train_dl: DataLoader
    val_dl: DataLoader
    optimizer: Optimizer

    def __init__(self, params, param_scheduler, output_dir) -> None:
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.params = params
        self.param_scheduler = param_scheduler
        self.output_dir = output_dir

    def train(self):
        # collect and display the total number of trainable parameters
        total_parameters = sum(
            p.numel() for p in self.model.parameters() if p.requires_grad
        )
        print(f"Total parameters: {total_parameters}")

        # handle the output directory: if a checkpoint already exists,
        # offer to resume; otherwise create a fresh timestamped log folder
        output_dir = self.output_dir
        if os.path.exists(f"{output_dir}/chkpts/weights.pt"):
            print("Checkpoint already exists.")
            if input("Resume training? (y/n)") != "y":
                return
        else:
            output_dir = f"{output_dir}/{datetime.now().strftime('%m-%d_%H%M%S')}"
            print(f"Creating new log folder as {output_dir}")

        # build the learner and run the training loop
        learner = DiffproLearner(
            output_dir, self.model, self.train_dl, self.val_dl, self.optimizer,
            self.params, self.param_scheduler
        )
        learner.train(max_epoch=self.params.max_epoch)
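

# --- Usage sketch (illustrative; everything below is an assumption, not
# --- part of the original project) ------------------------------------
# Since TrainConfig never assigns model / train_dl / val_dl / optimizer,
# a concrete subclass is expected to wire them up, e.g.:
#
#     import torch.nn as nn
#     from torch.utils.data import TensorDataset
#
#     class ToyTrainConfig(TrainConfig):
#         def __init__(self, params, param_scheduler, output_dir):
#             super().__init__(params, param_scheduler, output_dir)
#             self.model = nn.Linear(16, 16).to(self.device)
#             data = TensorDataset(torch.randn(128, 16))
#             self.train_dl = DataLoader(data, batch_size=params.batch_size)
#             self.val_dl = DataLoader(data, batch_size=params.batch_size)
#             self.optimizer = torch.optim.Adam(
#                 self.model.parameters(), lr=params.learning_rate
#             )
#
#     ToyTrainConfig(params, param_scheduler, "result/toy").train()
#
# Here `params` is assumed to carry batch_size and learning_rate in
# addition to the max_epoch attribute read by train() above; ToyTrainConfig
# and its dummy data are hypothetical stand-ins for a real subclass.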