# Spaces:
# Runtime error
# Runtime error
import os, glob, datetime, time | |
import argparse, json | |
import torch | |
import torch.optim as optim | |
from torch.autograd import Variable | |
import torchvision | |
from torch.utils.data import DataLoader | |
from torch.backends import cudnn | |
from model.base_module import tensor2array | |
from model.model import ResHalf | |
from model.loss import * | |
from utils.dataset import HalftoneVOC2012 as Dataset | |
from utils.util import ensure_dir, save_list, save_images_from_batch | |
class Trainer():
    """Config-driven training loop with per-epoch validation and checkpointing.

    The constructor builds the model, optimizer, LR scheduler, data loaders
    and loss terms from ``config`` and then loads a checkpoint: the one given
    by ``resume`` if set, otherwise ``config['initial_ckpt']``.
    """

    def __init__(self, config, resume):
        """Set up the workspace and every training component.

        Args:
            config: parsed JSON configuration (dict). It is also written back
                to ``<save_dir>/<name>/config.json``.
            resume: path of a checkpoint to resume from, or a falsy value to
                start from ``config['initial_ckpt']``.

        Raises:
            AssertionError: if the selected checkpoint path does not exist.
        """
        self.config = config
        self.name = config['name']
        self.resume_path = resume
        self.n_epochs = config['trainer']['epochs']
        self.with_cuda = config['cuda'] and torch.cuda.is_available()
        self.seed = config['seed']
        self.start_epoch = 0
        # Best (lowest) validation metric so far. Initialised here, before the
        # checkpoint is loaded, so that a resumed value is not reset later.
        self.monitor_best = 999.
        self.save_freq = config['trainer']['save_epochs']
        self.checkpoint_dir = os.path.join(config['save_dir'], self.name)
        ensure_dir(self.checkpoint_dir)
        # Context manager so the config file is flushed and closed right away.
        with open(os.path.join(self.checkpoint_dir, 'config.json'), 'w') as config_file:
            json.dump(config, config_file, indent=4, sort_keys=False)
        print("@Workspace: %s *************"%self.checkpoint_dir)
        self.cache = os.path.join(self.checkpoint_dir, 'train_cache')
        self.val_halftone = os.path.join(self.cache, 'halftone')
        self.val_restored = os.path.join(self.cache, 'restored')
        ensure_dir(self.val_halftone)
        ensure_dir(self.val_restored)

        ## model
        # NOTE(security): eval() executes whatever text the config contains
        # (here and for the loss entries below); only run trusted configs.
        self.model = eval(config['model'])()
        if self.config['multi-gpus']:
            self.model = torch.nn.DataParallel(self.model).cuda()
        elif self.with_cuda:
            self.model = self.model.cuda()

        ## optimizer
        optimizer_cls = getattr(optim, config['optimizer_type'])
        self.optimizer = optimizer_cls(self.model.parameters(), **config['optimizer'])
        self.lr_sheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, **config['lr_sheduler'])

        ## dataset loader
        loader_cfg = config['data_loader']
        with open(os.path.join(config['data_dir'], loader_cfg['dataset'])) as f:
            dataset = json.load(f)
        self.train_data_loader = self._make_loader(dataset['train'], loader_cfg['shuffle'])
        # validation order is kept fixed
        self.valid_data_loader = self._make_loader(dataset['val'], False)
        # special dataloader: constant color images
        with open(os.path.join(config['data_dir'], loader_cfg['special_set'])) as f:
            dataset = json.load(f)
        self.specialDataloader = self._make_loader(dataset['train'], loader_cfg['shuffle'])

        ## loss function
        self.quantizeLoss = eval(config['quantizeLoss'])
        self.quantizeLossWeight = config['quantizeLossWeight']
        self.toneLoss = eval(config['toneLoss'])
        self.toneLossWeight = config['toneLossWeight']
        self.structureLoss = eval(config['structureLoss'])
        self.structureLossWeight = config['structureLossWeight']
        self.restoreLoss = eval(config['restoreLoss'])
        self.restoreLossWeight = config['restoreLossWeight']
        # quantize [-1,1] data to be {-1,1}
        self.quantizer = lambda x: Quantize.apply(0.5 * (x + 1.)) * 2. - 1.
        self.blueNoiseLossWeight = config['blueNoiseLossWeight']
        self.vggloss = Vgg19Loss()
        self.vggLossWeight = config['vggLossWeight']

        # resume checkpoint or load warm-up checkpoint
        checkpt_path = self.config['initial_ckpt']
        if self.resume_path:
            checkpt_path = self.resume_path
        assert os.path.exists(checkpt_path), 'Invalid checkpoint Path: %s' % checkpt_path
        self.load_checkpoint(checkpt_path)

    def _make_loader(self, file_list, shuffle):
        """Wrap ``file_list`` in a Dataset and a DataLoader using the configured
        batch size and worker count."""
        loader_cfg = self.config['data_loader']
        return DataLoader(Dataset(file_list),
                          batch_size=loader_cfg['batch_size'],
                          shuffle=shuffle,
                          num_workers=loader_cfg['num_workers'])

    def _train(self):
        """Run epochs ``start_epoch .. n_epochs-1``.

        Each epoch trains, steps the LR scheduler on the mean training loss,
        validates, logs the losses, and saves checkpoints: periodically, on the
        final epoch, and whenever the validation metric improves.
        """
        torch.manual_seed(self.config['seed'])
        torch.cuda.manual_seed(self.config['seed'])
        cudnn.benchmark = True
        start_time = time.time()
        # The last epoch index is n_epochs-1; this matches the final-save
        # check below and the "epoch+1/n_epochs" progress message.
        for epoch in range(self.start_epoch, self.n_epochs):
            ep_st = time.time()
            epoch_loss = self._train_epoch(epoch)
            # perform lr_sheduler
            self.lr_sheduler.step(epoch_loss['total_loss'])
            epoch_lr = self.optimizer.state_dict()['param_groups'][0]['lr']
            epoch_metric = self._valid_epoch(epoch)
            elapsed = time.time() - ep_st
            print("[*] --- epoch: %d/%d | loss: %4.4f | metric: %4.4f | time-consumed: %4.2f ---" %
                  (epoch+1, self.n_epochs, epoch_loss['total_loss'], epoch_metric, elapsed))
            # save losses and learning rate
            epoch_loss['metric'] = epoch_metric
            epoch_loss['lr'] = epoch_lr
            self.save_loss(epoch_loss, epoch)
            is_last_epoch = epoch == (self.n_epochs-1)
            if (epoch+1) % self.save_freq == 0 or is_last_epoch:
                print('---------- saving model ...')
                self.save_checkpoint(epoch)
            if self.monitor_best > epoch_metric:
                self.monitor_best = epoch_metric
                self.save_checkpoint(epoch, save_best=True)
        print("Training finished! consumed %f sec" % (time.time() - start_time))

    def _to_variable(self, data, target):
        """Move a (data, target) pair to the GPU when CUDA is enabled.

        The name is historical; tensors no longer need wrapping in
        ``Variable`` to take part in autograd.
        """
        if self.with_cuda:
            return data.cuda(), target.cuda()
        return data, target

    def _train_epoch(self, epoch):
        """Train for one pass over the training loader.

        Every step also draws one batch from the special (constant color)
        loader, restarting that loader whenever it runs out.

        Returns:
            dict mapping loss names to their mean over the epoch's batches
            (all zeros if the training loader yields no batch).
        """
        self.model.train()
        sums = {
            'total_loss': 0,
            'quantize_loss': 0,
            'tone_loss': 0,
            'structure_loss': 0,
            'bluenoise_loss': 0,
            'restore_loss': 0,
        }
        num_batches = 0
        specialIter = iter(self.specialDataloader)
        for batch_idx, (color, halftone) in enumerate(self.train_data_loader):
            num_batches = batch_idx + 1
            color, halftone = self._to_variable(color, halftone)
            # special data
            try:
                specialColor, specialHalftone = next(specialIter)
            except StopIteration:
                # reinitialize data loader
                specialIter = iter(self.specialDataloader)
                specialColor, specialHalftone = next(specialIter)
            specialColor, specialHalftone = self._to_variable(specialColor, specialHalftone)

            self.optimizer.zero_grad()
            # output[0] is treated as the halftone and output[-1] as the
            # restored color image (see how _valid_epoch saves them).
            output = self.model(color, halftone)
            quantizeLoss = self.quantizeLoss(output[0])
            toneLoss = self.toneLoss(output[0], color)
            structureLoss = self.structureLoss(output[0], color)
            # restore
            restoredColor = output[-1]
            restoreLoss = self.restoreLoss(restoredColor, color)
            # inputs are shifted from [-1,1] to [0,1] for the VGG loss
            vggLoss = self.vggloss(restoredColor / 2. + 0.5, color / 2. + 0.5)

            # special data
            output = self.model(specialColor, specialHalftone)
            toneLossSpecial = self.toneLoss(output[0], specialColor)
            # NOTE(review): the meaning of output[1] / output[2] is defined by
            # the model, which is not visible here -- confirm.
            blueNoiseLoss = l1_loss(output[1], output[2])
            quantizeLossSpecial = self.quantizeLoss(output[0])

            # The tone loss on the special batch is weighted with
            # blueNoiseLossWeight, as in the original formulation.
            loss = (self.toneLossWeight * toneLoss + self.blueNoiseLossWeight*toneLossSpecial) \
                + self.quantizeLossWeight * (0.5*quantizeLoss + 0.5*quantizeLossSpecial) \
                + self.structureLossWeight * structureLoss \
                + self.blueNoiseLossWeight * blueNoiseLoss \
                + self.vggLossWeight * vggLoss \
                + self.restoreLossWeight * restoreLoss
            loss.backward()
            # apply grad clip to make training robust
            # torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.0001)
            self.optimizer.step()

            sums['total_loss'] += loss.item()
            sums['quantize_loss'] += quantizeLoss.item()
            sums['restore_loss'] += (self.restoreLossWeight*restoreLoss + self.vggLossWeight*vggLoss).item()
            sums['tone_loss'] += toneLoss.item()
            sums['structure_loss'] += structureLoss.item()
            sums['bluenoise_loss'] += blueNoiseLoss.item()
            if batch_idx % 100 == 0:
                tm = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')
                print("%s >> [%d/%d] iter:%d loss:%4.4f "%(tm, epoch+1, self.n_epochs, batch_idx+1, loss.item()))

        # Guard the division so an empty loader cannot raise on an unbound
        # loop variable / zero divisor.
        denom = max(num_batches, 1)
        return {key: value / denom for key, value in sums.items()}

    def _valid_epoch(self, epoch):
        """Evaluate on the validation loader and dump intermediate images.

        Returns:
            The SUM (not the mean) of the weighted per-batch validation losses.
        """
        self.model.eval()
        total_loss = 0
        with torch.no_grad():
            for batch_idx, (color, halftone) in enumerate(self.valid_data_loader):
                color, halftone = self._to_variable(color, halftone)
                output = self.model(color, halftone)
                quantizeLoss = self.quantizeLoss(output[0])
                toneLoss = self.toneLoss(output[0], color)
                structureLoss = self.structureLoss(output[0], color)
                # restore
                restoredColor = output[-1]
                restoreLoss = self.restoreLoss(restoredColor, color)
                vggLoss = self.vggloss(restoredColor / 2. + 0.5, color / 2. + 0.5)
                loss = self.toneLossWeight * toneLoss \
                    + self.quantizeLossWeight * quantizeLoss \
                    + self.structureLossWeight * structureLoss \
                    + self.vggLossWeight * vggLoss \
                    + self.restoreLossWeight * restoreLoss
                total_loss += loss.item()
                #! save intermediate images
                gray_imgs = tensor2array(output[0])
                color_imgs = tensor2array(output[-1])
                save_images_from_batch(gray_imgs, self.val_halftone, None, batch_idx)
                save_images_from_batch(color_imgs, self.val_restored, None, batch_idx)
        return total_loss

    def save_loss(self, epoch_loss, epoch):
        """Write each entry of ``epoch_loss`` to ``<cache>/<key>``.

        Epoch 0 is written with ``append_mode=False``; every later epoch is
        written with ``append_mode=True``.
        """
        append = epoch != 0
        for key, value in epoch_loss.items():
            save_list(os.path.join(self.cache, key), [value], append_mode=append)

    def load_checkpoint(self, checkpt_path):
        """Load weights from ``checkpt_path``.

        When resuming (``self.resume_path`` is set) the epoch counter, best
        metric and optimizer state are restored as well and the weights must
        match exactly. Otherwise only the weights are loaded, non-strictly, as
        a warm-up initialisation.
        """
        print("-loading checkpoint from: {} ...".format(checkpt_path))
        # Deserialize onto the CPU so a checkpoint written on a GPU machine can
        # still be opened without CUDA; load_state_dict then copies the values
        # onto whatever device the model / optimizer parameters live on.
        checkpoint = torch.load(checkpt_path, map_location='cpu')
        if self.resume_path:
            self.start_epoch = checkpoint['epoch'] + 1
            self.monitor_best = checkpoint['monitor_best']
            self.model.load_state_dict(checkpoint['state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer'])
        else:
            self.model.load_state_dict(checkpoint['state_dict'], strict=False)
        print("-pretrained checkpoint loaded.")

    def save_checkpoint(self, epoch, save_best=False):
        """Save model / optimizer state for ``epoch``.

        Writes ``model_best.pth.tar`` when ``save_best`` is true, otherwise
        ``model_last.pth.tar``; both live in the checkpoint directory and are
        overwritten on each call.
        """
        state = {
            'epoch': epoch,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'monitor_best': self.monitor_best
        }
        filename = 'model_best.pth.tar' if save_best else 'model_last.pth.tar'
        torch.save(state, os.path.join(self.checkpoint_dir, filename))
if __name__ == '__main__':
    # Command-line entry point: read the JSON config, build the trainer
    # (optionally resuming from a checkpoint) and run training.
    parser = argparse.ArgumentParser(description='InvHalf')
    parser.add_argument('-c', '--config', default=None, type=str,
                        help='config file path (default: None)')
    parser.add_argument('-r', '--resume', default=None, type=str,
                        help='path to latest checkpoint (default: None)')
    args = parser.parse_args()
    # Without a config there is nothing to train; report it as a usage error
    # instead of failing inside open(None).
    if args.config is None:
        parser.error('a config file is required: pass it with -c/--config')
    # Context manager so the config file handle is closed after parsing.
    with open(args.config) as config_file:
        config_dict = json.load(config_file)
    node = Trainer(config_dict, resume=args.resume)
    node._train()