"""Training base class """ import torchvision.transforms as transforms import torchvision.utils as vutils import torch.backends.cudnn as cudnn import torch.nn.functional as F import torch.fft import torch import numpy as np import argparse import wandb import math import time import os from . import flow_transforms class TrainerBase(): def __init__(self, args): """ Initialization function. """ cudnn.benchmark = True os.environ['WANDB_DIR'] = args.work_dir args.use_wandb = (args.use_wandb == 1) if args.use_wandb: wandb.login(key="d56eb81cd6396f0a181524ba214f488cf281e76b") wandb.init(project=args.project_name, name=args.exp_name) wandb.config.update(args) self.mean_values = torch.tensor([0.411, 0.432, 0.45]).view(1, 3, 1, 1).cuda() self.color_palette = np.loadtxt('data/palette.txt',dtype=np.uint8).reshape(-1,3) self.args = args def init_dataset(self): """ Initialize dataset """ if self.args.dataset == 'BSD500': from ..data import BSD500 # ========== Data loading code ============== input_transform = transforms.Compose([ flow_transforms.ArrayToTensor(), transforms.Normalize(mean=[0,0,0], std=[255,255,255]), transforms.Normalize(mean=[0.411,0.432,0.45], std=[1,1,1]) ]) val_input_transform = transforms.Compose([ flow_transforms.ArrayToTensor(), transforms.Normalize(mean=[0, 0, 0], std=[255, 255, 255]), transforms.Normalize(mean=[0.411, 0.432, 0.45], std=[1, 1, 1]) ]) target_transform = transforms.Compose([ flow_transforms.ArrayToTensor(), ]) co_transform = flow_transforms.Compose([ flow_transforms.RandomCrop((self.args.train_img_height , self.args.train_img_width)), flow_transforms.RandomVerticalFlip(), flow_transforms.RandomHorizontalFlip() ]) print("=> loading img pairs from '{}'".format(self.args.data)) train_set, val_set = BSD500(self.args.data, transform=input_transform, val_transform = val_input_transform, target_transform=target_transform, co_transform=co_transform, bi_filter=True) print('{} samples found, {} train samples and {} val samples '.format(len(val_set)+len(train_set), len(train_set), len(val_set))) self.train_loader = torch.utils.data.DataLoader( train_set, batch_size=self.args.batch_size, num_workers=self.args.workers, pin_memory=True, shuffle=True, drop_last=True) elif self.args.dataset == 'texture': from ..data.texture_v3 import Dataset dataset = Dataset(self.args.data_path, crop_size=self.args.train_img_height) self.train_loader = torch.utils.data.DataLoader(dataset = dataset, batch_size = self.args.batch_size, shuffle = True, num_workers = self.args.workers, drop_last = True) elif self.args.dataset == 'DIV2K': from basicsr.data import create_dataloader, create_dataset opt = {} opt['dist'] = False opt['phase'] = 'train' opt['name'] = 'DIV2K' opt['type'] = 'PairedImageDataset' opt['dataroot_gt'] = self.args.HR_dir opt['dataroot_lq'] = self.args.LR_dir opt['filename_tmpl'] = '{}' opt['io_backend'] = dict(type='disk') opt['gt_size'] = self.args.train_img_height opt['use_flip'] = True opt['use_rot'] = True opt['use_shuffle'] = True opt['num_worker_per_gpu'] = self.args.workers opt['batch_size_per_gpu'] = self.args.batch_size opt['scale'] = int(self.args.ratio) opt['dataset_enlarge_ratio'] = 1 dataset = create_dataset(opt) self.train_loader = create_dataloader( dataset, opt, num_gpu=1, dist=opt['dist'], sampler=None) else: raise ValueError("Unknown dataset: {}.".format(self.args.dataset)) def init_training(self): self.init_constant() self.init_dataset() self.define_model() self.define_criterion() self.define_optimizer() def adjust_learning_rate(self, iteration): """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" lr = self.args.lr * (0.95 ** (iteration // self.args.lr_decay_freq)) for param_group in self.optimizer.param_groups: param_group['lr'] = lr def logging(self, iteration, epoch): print_str = "[{}/{}][{}/{}], ".format(iteration, len(self.train_loader), epoch, self.args.nepochs) for k,v in self.losses.items(): print_str += "{}: {:4f} ".format(k, v) print_str += "time: {:2f}.".format(self.iter_time) print(print_str) def get_sp_grid(self, H, W, G, R = 1): W = int(W // R) H = int(H // R) if G > min(H, W): raise ValueError('Grid size must be smaller than image size!') grid = torch.from_numpy(np.arange(G**2)).view(1, 1, G, G) grid = torch.cat([grid]*(int(math.ceil(W/G))), dim = -1) grid = torch.cat([grid]*(int(math.ceil(H/G))), dim = -2) grid = grid[:, :, :H, :W] return grid.float() def save_network(self, name = None): cpk = {} cpk['epoch'] = self.epoch cpk['lr'] = self.optimizer.param_groups[0]['lr'] if hasattr(self.model, 'module'): cpk['model'] = self.model.module.cpu().state_dict() else: cpk['model'] = self.model.cpu().state_dict() if name is None: out_path = os.path.join(self.args.out_dir, "cpk.pth") else: out_path = os.path.join(self.args.out_dir, name + ".pth") torch.save(cpk, out_path) self.model.cuda() return def init_constant(self): return def define_model(self): raise NotImplementedError def define_criterion(self): raise NotImplementedError def define_optimizer(self): raise NotImplementedError def display(self): raise NotImplementedError def forward(self): raise NotImplementedError def train(self): args = self.args total_iteration = 0 for epoch in range(args.nepochs): self.epoch = epoch for iteration, data in enumerate(self.train_loader): if args.dataset == 'BSD500': image = data[0].cuda() self.label = data[1].cuda() elif args.dataset == 'texture': image = data[0].cuda() self.image2 = data[1].cuda() else: image = data['lq'].cuda() self.gt = data['gt'].cuda() start_time = time.time() total_iteration += 1 self.optimizer.zero_grad() image = image.cuda() if args.dataset == 'BSD500': self.image = image + self.mean_values self.gt = self.image else: self.image = image self.forward() total_loss = 0 for k,v in self.losses.items(): if hasattr(args, '{}_wt'.format(k)): total_loss += v * getattr(args, '{}_wt'.format(k)) else: total_loss += v total_loss.backward() self.optimizer.step() end_time = time.time() self.iter_time = end_time - start_time self.adjust_learning_rate(total_iteration) if((iteration + 1) % args.log_freq == 0): self.logging(iteration, epoch) if args.use_wandb: wandb.log(self.losses) if(iteration % args.display_freq == 0): example_images = self.display() if args.use_wandb: wandb.log({'images': example_images}) if((epoch + 1) % args.save_freq == 0): self.save_network()