import os
import platform
import time
import yaml
import torch
import datetime
from torch.utils.tensorboard import SummaryWriter
import torch.utils.data
import numpy as np
import glob
import shutil

from utils.net_util import to_cuda

def worker_init_fn(worker_id):  # set numpy's random seed per DataLoader worker
    seed = torch.initial_seed()
    seed = seed % (2 ** 32)
    np.random.seed(seed + worker_id)

class BaseTrainer:
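    """Generic training harness.

    A concrete trainer registers a dataset, the networks to optimize (net_dict),
    their optimizers (optm_dict) and learning-rate schedules (lr_schedule_dict),
    and implements forward_one_pass(); train() then runs the epoch/batch loop
    with TensorBoard logging and periodic checkpointing.
    """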
    def __init__(self, opt):
        self.opt = opt
        self.dataset = None
        self.network = None
        self.net_dict = {}
        self.optm_dict = {}
        self.update_keys = None
        self.lr_schedule_dict = {}
        self.iter_idx = 0
        self.epoch_idx = 0
        self.iter_num = 9999999999
        self.loss_weight = self.opt['train']['loss_weight']

    def load_pretrained(self, path, dict_):
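        """Load the state_dicts stored at `path` into the matching entries of `dict_`.

        Entries missing from the checkpoint are skipped with a warning; returns the
        saved epoch index if present, otherwise None.
        """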
        data = torch.load(path)
        for k in dict_:
            if k in data:
                print('# Loading %s...' % k)
                dict_[k].load_state_dict(data[k])
            else:
                print('# %s not found!' % k)
        return data.get('epoch_idx', None)

    def load_ckpt(self, path, load_optm = True):
        epoch_idx = self.load_pretrained(path + '/net.pt', self.net_dict)
        if load_optm:
            if os.path.exists(path + '/optm.pt'):
                self.load_pretrained(path + '/optm.pt', self.optm_dict)
            else:
                print('# Optimizer not found!')
        return epoch_idx

    def save_trained(self, path, dict_):
        data = {}
        for k in dict_:
            data[k] = dict_[k].state_dict()
        data.update({
            'epoch_idx': self.epoch_idx,
        })
        torch.save(data, path)

    def save_ckpt(self, path, save_optm = True):
        self.save_trained(path + '/net.pt', self.net_dict)
        if save_optm:
            self.save_trained(path + '/optm.pt', self.optm_dict)

    def zero_grad(self):
        if self.update_keys is None:
            update_keys = self.optm_dict.keys()
        else:
            update_keys = self.update_keys
        for k in update_keys:
            self.optm_dict[k].zero_grad()

    def step(self):
        if self.update_keys is None:
            update_keys = self.optm_dict.keys()
        else:
            update_keys = self.update_keys
        for k in update_keys:
            self.optm_dict[k].step()

    def update_lr(self, iter_idx):
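        """Query each schedule for the learning rate at `iter_idx`, apply it to the
        corresponding optimizer's param_groups, and return the {key: lr} mapping."""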
        lr_dict = {}
        if self.update_keys is None:
            update_keys = self.optm_dict.keys()
        else:
            update_keys = self.update_keys
        for k in update_keys:
            lr = self.lr_schedule_dict[k].get_learning_rate(iter_idx)
            for param_group in self.optm_dict[k].param_groups:
                param_group['lr'] = lr
            lr_dict[k] = lr
        return lr_dict

    def set_dataset(self, dataset):
        self.dataset = dataset

    def set_network(self, network):
        self.network = network

    def set_net_dict(self, net_dict):
        self.net_dict = net_dict

    def set_optm_dict(self, optm_dict):
        self.optm_dict = optm_dict

    def set_update_keys(self, update_keys):
        self.update_keys = update_keys

    def set_lr_schedule_dict(self, lr_schedule_dict):
        self.lr_schedule_dict = lr_schedule_dict

    def set_train(self, flag = True):
        if flag:
            for k, net in self.net_dict.items():
                # if no update_keys are specified, all networks are set to training mode
                if self.update_keys is None or k in self.update_keys:
                    net.train()
                else:
                    net.eval()
        else:
            for k, net in self.net_dict.items():
                net.eval()

    def train(self):
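        """Run the training loop.

        Creates a time-stamped log directory with a TensorBoard writer and a config
        backup, builds the DataLoader, optionally saves an initial checkpoint, resumes
        from a previous checkpoint if one is configured or found, then iterates over
        epochs and batches: update the learning rate, call forward_one_pass, log the
        losses, and save checkpoints at the configured batch/epoch intervals.
        """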
        # log
        os.makedirs(self.opt['train']['net_ckpt_dir'], exist_ok = True)
        log_dir = self.opt['train']['net_ckpt_dir'] + '/' + datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S')
        os.makedirs(log_dir, exist_ok = True)
        writer = SummaryWriter(log_dir)
        with open(log_dir + '/config_bk.yaml', 'w') as f:
            yaml.dump(self.opt, f, sort_keys = False)

        self.set_train()
        self.dataset.training = True
        batch_size = self.opt['train'].get('batch_size', 1)
        num_workers = self.opt['train'].get('num_workers', 0)
        dataloader = torch.utils.data.DataLoader(self.dataset,
                                                 batch_size = batch_size,
                                                 shuffle = True,
                                                 num_workers = num_workers,
                                                 worker_init_fn = worker_init_fn,
                                                 drop_last = True)
        self.batch_num = len(self.dataset) // batch_size

        if self.opt['train'].get('save_init_ckpt', False) and self.opt['train'].get('start_epoch', 0) == 0:
            init_folder = self.opt['train']['net_ckpt_dir'] + '/init_ckpt'
            if not os.path.exists(init_folder) or self.opt['train'].get('start_epoch', 0) == 0:
                os.makedirs(init_folder, exist_ok = True)
                self.save_ckpt(init_folder, False)
            else:
                print('# Init checkpoint has been saved!')

        if self.opt['train']['prev_ckpt'] is not None:
            start_epoch = self.load_ckpt(self.opt['train']['prev_ckpt'])
        else:
            prev_ckpt_path = self.opt['train']['net_ckpt_dir'] + '/epoch_latest'
            if os.path.exists(prev_ckpt_path):
                start_epoch = self.load_ckpt(prev_ckpt_path)
            else:
                start_epoch = None
        if start_epoch is None:
            # no checkpoint, or the checkpoint carries no epoch index
            start_epoch = self.opt['train'].get('start_epoch', 0)
        else:
            # resume from the epoch after the checkpointed one
            start_epoch += 1
        end_epoch = self.opt['train'].get('end_epoch', 999)

        forward_one_pass = self.forward_one_pass
        for epoch_idx in range(start_epoch, end_epoch):
            self.epoch_idx = epoch_idx
            self.update_config_before_epoch(epoch_idx)
            epoch_losses = dict()
            time0 = time.time()
            for batch_idx, items in enumerate(dataloader):
                iter_idx = batch_idx + self.batch_num * epoch_idx
                self.iter_idx = iter_idx

                lr_dict = self.update_lr(iter_idx)
                items = to_cuda(items)
                loss, batch_losses = forward_one_pass(items)
                # self.zero_grad()
                # loss.backward()
                # self.step()
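                # NOTE: the calls above are commented out, i.e. the base loop itself never
                # runs zero_grad()/backward()/step(); a concrete forward_one_pass() is
                # expected to perform the optimization step internally.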
                # record batch loss
                log_info = 'epoch %d, batch %d, ' % (epoch_idx, batch_idx)
                log_info += 'lr: '
                for k in lr_dict.keys():
                    log_info += '%s %e, ' % (k, lr_dict[k])
                for key in batch_losses.keys():
                    log_info = log_info + ('%s: %f, ' % (key, batch_losses[key]))
                    writer.add_scalar('%s/Batch' % key, batch_losses[key], iter_idx)
                    if key in epoch_losses:
                        epoch_losses[key] += batch_losses[key]
                    else:
                        epoch_losses[key] = batch_losses[key]
                print(log_info)
                with open(os.path.join(log_dir, 'loss.txt'), 'a') as fp:
                    # record loss weight
                    if batch_idx == 0:
                        loss_weights_info = ''
                        for k in self.opt['train']['loss_weight'].keys():
                            loss_weights_info += '%s: %f, ' % (k, self.opt['train']['loss_weight'][k])
                        fp.write('# Loss weights: \n' + loss_weights_info + '\n')
                    fp.write(log_info + '\n')
                if iter_idx % self.opt['train']['ckpt_interval']['batch'] == 0 and iter_idx != 0:
                    # keep only the most recent batch-level checkpoint
                    for folder in glob.glob(self.opt['train']['net_ckpt_dir'] + '/batch_*'):
                        shutil.rmtree(folder)
                    model_folder = self.opt['train']['net_ckpt_dir'] + '/batch_%d' % iter_idx
                    os.makedirs(model_folder, exist_ok = True)
                    self.save_ckpt(model_folder, save_optm = False)

                if iter_idx % self.opt['train']['eval_interval'] == 0 and iter_idx != 0:
                    self.mini_test()
                    self.set_train()

                time1 = time.time()
                print('One iteration costs %f secs' % (time1 - time0))
                time0 = time1

                if iter_idx == self.iter_num:
                    writer.close()
                    return
| """ EPOCH """ | |
| # record epoch loss | |
| for key in epoch_losses.keys(): | |
| epoch_losses[key] /= self.batch_num | |
| writer.add_scalar('%s/Epoch' % key, epoch_losses[key], epoch_idx) | |
| if epoch_idx % self.opt['train']['ckpt_interval']['epoch'] == 0: | |
| model_folder = self.opt['train']['net_ckpt_dir'] + '/epoch_%d' % epoch_idx | |
| os.makedirs(model_folder, exist_ok = True) | |
| self.save_ckpt(model_folder) | |
| if self.batch_num > 50: | |
| latest_folder = self.opt['train']['net_ckpt_dir'] + '/epoch_latest' | |
| os.makedirs(latest_folder, exist_ok = True) | |
| self.save_ckpt(latest_folder) | |
| writer.close() | |
    def mini_test(self):
        """ Test during training """
        pass

    def forward_one_pass(self, items):
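        """Process one batch of `items` (already moved to GPU by train()).

        Expected to return (loss, batch_losses), where `batch_losses` is a dict of
        scalar loss values that train() logs to TensorBoard and loss.txt.
        Must be implemented by subclasses.
        """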
        raise NotImplementedError('"forward_one_pass" method is not implemented!')

    def update_config_before_epoch(self, epoch_idx):
        pass
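# ---------------------------------------------------------------------------
# Usage sketch (not part of the module above; the subclass, network, dataset,
# schedule and config keys below are placeholders): a concrete trainer
# implements forward_one_pass() and performs the optimization step inside it.
#
#   class MyTrainer(BaseTrainer):
#       def forward_one_pass(self, items):
#           pred = self.network(items['input'])
#           l1 = torch.nn.functional.l1_loss(pred, items['target'])
#           total_loss = self.loss_weight['l1'] * l1
#           self.zero_grad()
#           total_loss.backward()
#           self.step()
#           return total_loss, {'l1': l1.item()}
#
#   trainer = MyTrainer(opt)                       # opt loaded from a YAML config
#   trainer.set_dataset(dataset)
#   trainer.set_network(network)
#   trainer.set_net_dict({'network': network})
#   trainer.set_optm_dict({'network': torch.optim.Adam(network.parameters(), lr = 1e-4)})
#   trainer.set_lr_schedule_dict({'network': lr_schedule})  # needs .get_learning_rate(iter_idx)
#   trainer.train()
# ---------------------------------------------------------------------------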