| | import os |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from torch.nn import init |
| |
|
| | from utils1.config import CONFIGCLASS |
| | from utils1.utils import get_network |
| | from utils1.warmup import GradualWarmupScheduler |
| |
|
| |
|
class BaseModel(nn.Module):
    """Shared checkpointing / evaluation plumbing for the trainers in this file.

    Subclasses must assign ``self.model`` (the wrapped network) and
    ``self.optimizer``, and implement ``forward()``, before using
    :meth:`save_networks`, :meth:`load_networks`, :meth:`eval` or :meth:`test`.
    """

    # Checkpoint that ``load_networks(0)`` loads instead of
    # ``<save_dir>/model_epoch_0.pth``; relative to the current working directory.
    EPOCH0_CHECKPOINT = "checkpoints/optical.pth"

    def __init__(self, cfg: CONFIGCLASS):
        super().__init__()
        self.cfg = cfg
        self.total_steps = 0
        self.isTrain = cfg.isTrain
        self.save_dir = cfg.ckpt_dir
        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        # Only declared here: subclasses assign the network (and move it to
        # ``self.device``) and the optimizer.
        self.model: nn.Module
        self.optimizer: torch.optim.Optimizer

    def save_networks(self, epoch: int):
        """Save model weights, optimizer state and ``total_steps``.

        Writes ``<save_dir>/model_epoch_<epoch>.pth``, creating ``save_dir``
        first if it does not exist yet.
        """
        os.makedirs(self.save_dir, exist_ok=True)
        save_path = os.path.join(self.save_dir, f"model_epoch_{epoch}.pth")
        state_dict = {
            "model": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "total_steps": self.total_steps,
        }
        torch.save(state_dict, save_path)

    def load_networks(self, epoch: int):
        """Restore a checkpoint in the format written by :meth:`save_networks`.

        ``epoch == 0`` is special-cased to load ``EPOCH0_CHECKPOINT`` instead of
        ``<save_dir>/model_epoch_0.pth``. The checkpoint must provide the keys
        ``"model"`` and ``"total_steps"``. ``"optimizer"`` is read only when
        ``self.isTrain`` is true and ``cfg.new_optim`` is false; in that case the
        learning rate of every param group is then reset to ``cfg.lr``.
        """
        if epoch == 0:
            load_path = self.EPOCH0_CHECKPOINT
            print("loading optical path")
        else:
            load_path = os.path.join(self.save_dir, f"model_epoch_{epoch}.pth")
            print(f"loading the model from {load_path}")

        state_dict = torch.load(load_path, map_location=self.device)
        if hasattr(state_dict, "_metadata"):
            del state_dict._metadata

        self.model.load_state_dict(state_dict["model"])
        self.total_steps = state_dict["total_steps"]

        if self.isTrain and not self.cfg.new_optim:
            self.optimizer.load_state_dict(state_dict["optimizer"])
            # Put every tensor held in the optimizer state on the model's device.
            for state in self.optimizer.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = v.to(self.device)
            # Resume with the configured learning rate, not the one saved
            # together with the optimizer state.
            for g in self.optimizer.param_groups:
                g["lr"] = self.cfg.lr

    def eval(self):
        """Switch the wrapped network to evaluation mode.

        Returns ``self``, like ``nn.Module.eval``.
        """
        self.model.eval()
        self.training = False
        return self

    def test(self):
        """Run ``forward()`` with autograd disabled."""
        with torch.no_grad():
            self.forward()
| |
|
| |
|
def init_weights(net: nn.Module, init_type="normal", gain=0.02):
    """Re-initialise Conv*/Linear and BatchNorm2d layers of ``net`` in place.

    Args:
        net: network whose submodules are visited via ``net.apply``.
        init_type: weight initialiser for modules whose class name contains
            ``"Conv"`` or ``"Linear"``:
            - ``"normal"``: N(0, gain)
            - ``"xavier"``: Xavier-normal with ``gain``
            - ``"kaiming"``: Kaiming-normal, fan-in; ignores ``gain``
            - ``"orthogonal"``: orthogonal with ``gain``
        gain: as above. It is also the std of the BatchNorm2d weight, drawn
            around 1.0.

    Raises:
        NotImplementedError: if ``init_type`` is not one of the names above.
            This is checked before any layer is modified.

    Biases of the matched layers are set to 0. Parameters that are ``None``
    are skipped, e.g. the weight and bias of ``BatchNorm2d(affine=False)`` or
    a Conv/Linear built with ``bias=False``.
    """
    weight_initializers = {
        "normal": lambda w: init.normal_(w, 0.0, gain),
        "xavier": lambda w: init.xavier_normal_(w, gain=gain),
        "kaiming": lambda w: init.kaiming_normal_(w, a=0, mode="fan_in"),
        "orthogonal": lambda w: init.orthogonal_(w, gain=gain),
    }
    try:
        init_weight = weight_initializers[init_type]
    except KeyError:
        raise NotImplementedError(f"initialization method [{init_type}] is not implemented") from None

    def init_func(m: nn.Module):
        classname = m.__class__.__name__
        if hasattr(m, "weight") and ("Conv" in classname or "Linear" in classname):
            init_weight(m.weight.data)
            if getattr(m, "bias", None) is not None:
                init.constant_(m.bias.data, 0.0)
        elif "BatchNorm2d" in classname:
            # affine=False BatchNorm has weight/bias set to None.
            if m.weight is not None:
                init.normal_(m.weight.data, 1.0, gain)
            if m.bias is not None:
                init.constant_(m.bias.data, 0.0)

    print(f"initialize network with {init_type}")
    net.apply(init_func)
| |
|
| |
|
class Trainer(BaseModel):
    """Trainer for a binary classifier.

    - Loss: BCE-with-logits on float labels.
    - Optimizer: Adam or SGD, chosen by ``cfg.optim``.
    - Schedule: optional warmup followed by cosine annealing, when ``cfg.warmup``
      is set.
    """

    def name(self):
        return "Trainer"

    def __init__(self, cfg: CONFIGCLASS):
        super().__init__(cfg)
        self.arch = cfg.arch
        self.model = get_network(self.arch, cfg.isTrain, cfg.continue_train, cfg.init_gain, cfg.pretrained)
        # Move the network before constructing the optimizer, as the torch.optim
        # docs recommend, so the optimizer holds the device-resident parameters.
        self.model.to(self.device)

        self.loss_fn = nn.BCEWithLogitsLoss()

        if cfg.optim == "adam":
            self.optimizer = torch.optim.Adam(self.model.parameters(), lr=cfg.lr, betas=(cfg.beta1, 0.999))
        elif cfg.optim == "sgd":
            self.optimizer = torch.optim.SGD(self.model.parameters(), lr=cfg.lr, momentum=0.9, weight_decay=5e-4)
        else:
            raise ValueError("optim should be [adam, sgd]")

        if cfg.warmup:
            # Warmup for cfg.warmup_epoch epochs, then cosine annealing down to
            # 1e-6 over the remaining (nepoch - warmup_epoch) epochs.
            scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer, cfg.nepoch - cfg.warmup_epoch, eta_min=1e-6
            )
            self.scheduler = GradualWarmupScheduler(
                self.optimizer, multiplier=1, total_epoch=cfg.warmup_epoch, after_scheduler=scheduler_cosine
            )
            # NOTE(review): one step at construction, presumably to set the
            # initial warmup learning rate; confirm against GradualWarmupScheduler.
            self.scheduler.step()

        if cfg.continue_train:
            self.load_networks(cfg.epoch)

    def adjust_learning_rate(self, min_lr=1e-6):
        """Divide the learning rate of every param group by 10.

        Returns:
            ``False`` if any group's new learning rate is below ``min_lr``,
            otherwise ``True``.
        """
        groups = self.optimizer.param_groups
        for param_group in groups:
            param_group["lr"] /= 10.0
        return not any(group["lr"] < min_lr for group in groups)

    def set_input(self, input):
        """Move a batch to ``self.device`` and store it on the trainer.

        Args:
            input: ``(img, label)`` or ``(img, label, meta)``. ``meta`` is a
                mapping; its tensor values are moved to the device and its
                other values are kept as-is. The caller's ``meta`` is not
                modified: a new dict is stored in ``self.meta``.
        """
        img, label, meta = input if len(input) == 3 else (input[0], input[1], {})
        self.input = img.to(self.device)
        self.label = label.to(self.device).float()
        self.meta = {
            k: v.to(self.device) if isinstance(v, torch.Tensor) else v
            for k, v in meta.items()
        }

    def forward(self):
        """Run the network on the stored batch, passing ``meta`` as a second argument."""
        self.output = self.model(self.input, self.meta)

    def get_loss(self):
        """Return the BCE-with-logits loss of ``self.output`` against ``self.label``.

        Dim 1 of the output is squeezed first. The output is presumably shaped
        (batch, 1); confirm against the network.
        """
        return self.loss_fn(self.output.squeeze(1), self.label)

    def optimize_parameters(self):
        """Run one forward/backward/optimizer step on the stored batch.

        The loss is kept in ``self.loss``.
        """
        self.forward()
        self.loss = self.get_loss()
        self.optimizer.zero_grad()
        self.loss.backward()
        self.optimizer.step()
| |
|