import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import albumentations as albu
import argparse
import datetime
from utils.utils import open_json, weights_init, weights_init_spectr, generate_mask
from model.models import Colorizer, Generator, Content, Discriminator
from model.extractor import get_seresnext_extractor
from dataset.datasets import TrainDataset, FineTuningDataset
from PIL import Image


def parse_args():
    """Parse the command-line arguments of the training script.

    Returns:
        argparse.Namespace with ``path`` (str, required), ``fine_tuning``
        (bool, default False) and ``gpu`` (bool, default False).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-p", "--path", required=True, help="dataset path")
    # ``store_true`` flags already default to False.
    parser.add_argument('-ft', '--fine_tuning', dest='fine_tuning', action='store_true')
    parser.add_argument('-g', '--gpu', dest='gpu', action='store_true')
    return parser.parse_args()


def get_transforms():
    """Build the augmentation pipeline: a 512x512 random crop, then a horizontal flip with p=0.5."""
    return albu.Compose(
        [
            albu.RandomCrop(512, 512, always_apply=True),
            albu.HorizontalFlip(p=0.5),
        ],
        p=1.,
    )


def get_dataloaders(data_path, transforms, batch_size, fine_tuning, mult_number):
    """Create the shuffled training dataloader and, optionally, the fine-tuning one.

    Args:
        data_path: dataset location passed to the dataset constructors.
        transforms: augmentation pipeline passed to the dataset constructors.
        batch_size: batch size used by both dataloaders.
        fine_tuning: when true, a ``FineTuningDataset`` dataloader is also built.
        mult_number: accepted for interface compatibility; not used here.

    Returns:
        ``(train_dataloader, finetuning_dataloader)``; the second element is
        ``None`` when ``fine_tuning`` is false.
    """
    train_dataset = TrainDataset(data_path, transforms)
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True)

    # Must be bound even without fine-tuning, otherwise the return below raises
    # UnboundLocalError on a regular training run.
    finetuning_dataloader = None
    if fine_tuning:
        finetuning_dataset = FineTuningDataset(data_path, transforms)
        finetuning_dataloader = torch.utils.data.DataLoader(
            finetuning_dataset, batch_size=batch_size, shuffle=True)

    return train_dataloader, finetuning_dataloader


def get_models(device):
    """Instantiate the colorizer, discriminator and frozen content network on ``device``.

    Returns:
        ``(colorizer, discriminator, content)``. ``content`` is put in eval
        mode and all of its parameters have ``requires_grad`` disabled.
    """
    generator = Generator()
    extractor = get_seresnext_extractor()
    colorizer = Colorizer(generator, extractor)
    colorizer.extractor_eval()
    colorizer = colorizer.to(device)

    discriminator = Discriminator().to(device)

    content = Content('model/vgg16-397923af.pth').eval().to(device)
    # The content network is only used inside the loss; it is never trained.
    for param in content.parameters():
        param.requires_grad = False

    return colorizer, discriminator, content


def set_weights(colorizer, discriminator):
    """Initialise generator/discriminator weights and load the extractor weights from disk."""
    colorizer.generator.apply(weights_init)
    colorizer.load_extractor_weights(torch.load('model/extractor.pth'))
    discriminator.apply(weights_init_spectr)


def generator_loss(disc_output, true_labels, main_output, guide_output, real_image,
                   content_gen, content_true, dist_loss=nn.L1Loss(),
                   content_dist_loss=nn.MSELoss(), class_loss=nn.BCEWithLogitsLoss()):
    """Combine similarity, adversarial and content terms into the generator loss.

    The result is
    ``10 * (dist(main, real) + 0.9 * dist(guide, real)) + class(disc, labels)
    + content_dist(content_gen, content_true)``.

    Args:
        disc_output: discriminator logits for the generated image.
        true_labels: target labels for the adversarial term.
        main_output: main generator output.
        guide_output: second ("guide") generator output.
        real_image: ground-truth image both outputs are compared against.
        content_gen: content-network output for the generated image.
        content_true: content-network output for the ground-truth image.
        dist_loss: similarity criterion (L1 by default).
        content_dist_loss: content criterion (MSE by default).
        class_loss: adversarial criterion (BCE-with-logits by default).

    Returns:
        Scalar loss tensor.
    """
    sim_loss_full = dist_loss(main_output, real_image)
    sim_loss_guide = dist_loss(guide_output, real_image)
    adv_loss = class_loss(disc_output, true_labels)
    content_loss = content_dist_loss(content_gen, content_true)

    return 10 * (sim_loss_full + 0.9 * sim_loss_guide) + adv_loss + content_loss


def get_optimizers(colorizer, discriminator, generator_lr, discriminator_lr):
    """Create Adam optimizers (betas 0.5/0.9) for the generator and the discriminator.

    Returns:
        ``(optimizerG, optimizerD)``.
    """
    optimizerG = optim.Adam(colorizer.generator.parameters(), lr=generator_lr, betas=(0.5, 0.9))
    optimizerD = optim.Adam(discriminator.parameters(), lr=discriminator_lr, betas=(0.5, 0.9))
    return optimizerG, optimizerD


def generator_step(inputs, colorizer, discriminator, content, loss_function, optimizer,
                   device, white_penalty=True):
    """Run one generator update and return its loss value.

    Args:
        inputs: batch unpacked as ``(bw, color, hint, dfm)`` tensors.
        colorizer: model whose ``generator`` sub-module is being trained.
        discriminator: discriminator, frozen for the duration of the step.
        content: content network used for the content term of the loss.
        loss_function: callable with the signature of ``generator_loss``.
        optimizer: optimizer over the generator parameters.
        device: device the batch is moved to.
        white_penalty: when true, an extra penalty term (see below) is added.

    Returns:
        The total loss of this step as a Python float.
    """
    for p in discriminator.parameters():
        p.requires_grad = False
    for p in colorizer.generator.parameters():
        p.requires_grad = True

    colorizer.generator.zero_grad()

    bw, color, hint, dfm = (t.to(device) for t in inputs)

    fake, guide = colorizer(torch.cat([bw, dfm, hint], 1))
    logits_fake = discriminator(fake)
    # The generator is rewarded when the discriminator labels its output as real.
    y_real = torch.ones((bw.size(0), 1), device=device)

    content_fake = content(fake)
    with torch.no_grad():
        content_true = content(color)

    loss = loss_function(logits_fake, y_real, fake, guide, color, content_fake, content_true)

    if white_penalty:
        # mask is 1 where the target pixel does NOT have all three channels above 0.85.
        all_channels_bright = (color > 0.85).float().sum(dim=1) == 3
        mask = (~all_channels_bright.unsqueeze(1).repeat((1, 3, 1, 1))).float()
        # NOTE(review): (fake + 1) / 2 presumably maps a [-1, 1] output to [0, 1] - confirm.
        white_zones = mask * (fake + 1) / 2
        # Squared per-pixel channel sum inside the mask, normalised by the mask
        # size (+1 to avoid division by zero), averaged over the batch.
        penalty = (torch.pow(white_zones.sum(dim=1), 2).sum(dim=(1, 2))
                   / (mask.sum(dim=(1, 2, 3)) + 1)).mean()
        loss = loss + penalty

    loss.backward()
    optimizer.step()

    return loss.item()


def discriminator_step(inputs, colorizer, discriminator, optimizer, device,
                       loss_function=nn.BCEWithLogitsLoss()):
    """Run one discriminator update and return its loss value.

    Real targets are smoothed to 0.9 and fake targets are 0. The generated
    batch is produced under ``torch.no_grad()`` so no gradient reaches the
    generator.

    Args:
        inputs: batch unpacked as ``(bw, color, hint, dfm)`` tensors.
        colorizer: model producing the fake images.
        discriminator: model being trained.
        optimizer: optimizer over the discriminator parameters.
        device: device the batch is moved to.
        loss_function: classification criterion applied to the logits.

    Returns:
        Sum of the real and fake losses as a Python float.
    """
    for p in discriminator.parameters():
        p.requires_grad = True
    for p in colorizer.generator.parameters():
        p.requires_grad = False

    discriminator.zero_grad()

    bw, color, hint, dfm = (t.to(device) for t in inputs)

    y_real = torch.full((bw.size(0), 1), 0.9, device=device)
    y_fake = torch.zeros((bw.size(0), 1), device=device)

    with torch.no_grad():
        fake_color, _ = colorizer(torch.cat([bw, dfm, hint], 1))

    logits_fake = discriminator(fake_color)
    logits_real = discriminator(color)

    fake_loss = loss_function(logits_fake, y_fake)
    real_loss = loss_function(logits_real, y_real)
    discriminator_loss = real_loss + fake_loss

    discriminator_loss.backward()
    optimizer.step()

    return discriminator_loss.item()


def decrease_lr(optimizer, rate):
    """Divide the learning rate of every parameter group of ``optimizer`` by ``rate``."""
    for group in optimizer.param_groups:
        group['lr'] /= rate


def set_lr(optimizer, value):
    """Set the learning rate of every parameter group of ``optimizer`` to ``value``."""
    for group in optimizer.param_groups:
        group['lr'] = value


def train(colorizer, discriminator, content, dataloader, epochs, colorizer_optimizer,
          discriminator_optimizer, lr_decay_epoch=-1, device='cpu'):
    """Train the GAN, alternating discriminator and generator steps batch by batch.

    Progress (current time and running mean losses) is printed every 5 batches
    and at the end of every epoch.

    Args:
        colorizer: model whose ``generator`` is trained.
        discriminator: discriminator model.
        content: content network forwarded to ``generator_step``.
        dataloader: iterable of ``(bw, color, hint, dfm)`` batches.
        epochs: number of passes over ``dataloader``.
        colorizer_optimizer: optimizer for the generator.
        discriminator_optimizer: optimizer for the discriminator.
        lr_decay_epoch: epoch index at which both learning rates are divided
            by 10; the default -1 never matches an epoch index.
        device: device the batches are moved to.
    """
    colorizer.generator.train()
    discriminator.train()

    # Alternation state deliberately survives across epochs.
    disc_step = True

    for epoch in range(epochs):
        if epoch == lr_decay_epoch:
            decrease_lr(colorizer_optimizer, 10)
            decrease_lr(discriminator_optimizer, 10)

        sum_disc_loss = 0
        sum_gen_loss = 0
        # Explicit counters: the two networks do not necessarily take the same
        # number of steps, so each mean is taken over its own step count.
        disc_steps = 0
        gen_steps = 0

        for n, inputs in enumerate(dataloader):
            if n % 5 == 0:
                print(datetime.datetime.now().time())
                print('Step : %d Discr loss: %.4f Gen loss : %.4f \n' % (
                    n, sum_disc_loss / max(disc_steps, 1), sum_gen_loss / max(gen_steps, 1)))

            if disc_step:
                sum_disc_loss += discriminator_step(
                    inputs, colorizer, discriminator, discriminator_optimizer, device)
                disc_steps += 1
            else:
                sum_gen_loss += generator_step(
                    inputs, colorizer, discriminator, content, generator_loss,
                    colorizer_optimizer, device)
                gen_steps += 1

            disc_step = not disc_step

        print(datetime.datetime.now().time())
        print('Epoch : %d Discr loss: %.4f Gen loss : %.4f \n' % (
            epoch, sum_disc_loss / max(disc_steps, 1), sum_gen_loss / max(gen_steps, 1)))


def fine_tuning_step(data_iter, colorizer, discriminator, gen_optimizer, disc_optimizer,
                     device, loss_function=nn.BCEWithLogitsLoss()):
    """Run 5 discriminator updates followed by one adversarial-only generator update.

    Every batch is drawn from ``data_iter`` and colorized with an all-zero
    4-channel hint.

    Args:
        data_iter: iterator yielding ``(bw, dfm, color_for_real)`` batches.
            Six batches are consumed per call.
        colorizer: model whose ``generator`` is fine-tuned.
        discriminator: discriminator model.
        gen_optimizer: optimizer for the generator.
        disc_optimizer: optimizer for the discriminator.
        device: device the batches are moved to.
        loss_function: classification criterion applied to the logits.

    Raises:
        StopIteration: if ``data_iter`` runs out of batches.
    """
    for p in discriminator.parameters():
        p.requires_grad = True
    for p in colorizer.generator.parameters():
        p.requires_grad = False

    for _ in range(5):
        discriminator.zero_grad()

        # Built-in next(): iterators have no .next() method in Python 3.
        bw, dfm, color_for_real = next(data_iter)
        bw, dfm, color_for_real = bw.to(device), dfm.to(device), color_for_real.to(device)

        y_real = torch.full((bw.size(0), 1), 0.9, device=device)
        y_fake = torch.zeros((bw.size(0), 1), device=device)

        empty_hint = torch.zeros(bw.shape[0], 4, bw.shape[2], bw.shape[3]).float().to(device)

        with torch.no_grad():
            fake_color_manga, _ = colorizer(torch.cat([bw, dfm, empty_hint], 1))

        logits_fake = discriminator(fake_color_manga)
        logits_real = discriminator(color_for_real)

        fake_loss = loss_function(logits_fake, y_fake)
        real_loss = loss_function(logits_real, y_real)
        discriminator_loss = real_loss + fake_loss

        discriminator_loss.backward()
        disc_optimizer.step()

    for p in discriminator.parameters():
        p.requires_grad = False
    for p in colorizer.generator.parameters():
        p.requires_grad = True

    colorizer.generator.zero_grad()

    bw, dfm, _ = next(data_iter)
    bw, dfm = bw.to(device), dfm.to(device)

    y_real = torch.ones((bw.size(0), 1), device=device)
    empty_hint = torch.zeros(bw.shape[0], 4, bw.shape[2], bw.shape[3]).float().to(device)

    fake_manga, _ = colorizer(torch.cat([bw, dfm, empty_hint], 1))
    logits_fake = discriminator(fake_manga)

    # Only the adversarial term is used for the generator during fine-tuning.
    adv_loss = loss_function(logits_fake, y_real)
    adv_loss.backward()
    gen_optimizer.step()


def fine_tuning(colorizer, discriminator, content, dataloader, iterations,
                colorizer_optimizer, discriminator_optimizer, data_iter, device='cpu'):
    """Continue alternating GAN training while interleaving fine-tuning steps.

    For each batch of ``dataloader`` a regular discriminator or generator step
    is taken (alternating, discriminator first); whenever ``n % 10 == 5`` an
    additional ``fine_tuning_step`` is run on batches from ``data_iter``.

    Args:
        colorizer: model whose ``generator`` is trained.
        discriminator: discriminator model.
        content: content network forwarded to ``generator_step``.
        dataloader: iterable of ``(bw, color, hint, dfm)`` batches.
        iterations: number of batches to process before returning.
        colorizer_optimizer: optimizer for the generator.
        discriminator_optimizer: optimizer for the discriminator.
        data_iter: iterator over fine-tuning batches.
        device: device the batches are moved to.
    """
    colorizer.generator.train()
    discriminator.train()

    disc_step = True

    # NOTE(review): a single pass over ``dataloader`` is made, so the loop also
    # ends early if it holds fewer than ``iterations`` batches.
    for n, inputs in enumerate(dataloader):
        if n == iterations:
            return

        if disc_step:
            discriminator_step(inputs, colorizer, discriminator, discriminator_optimizer, device)
        else:
            generator_step(inputs, colorizer, discriminator, content, generator_loss,
                           colorizer_optimizer, device)

        disc_step = not disc_step

        if n % 10 == 5:
            fine_tuning_step(data_iter, colorizer, discriminator, colorizer_optimizer,
                             discriminator_optimizer, device)


if __name__ == '__main__':
    args = parse_args()
    config = open_json('configs/train_config.json')

    device = 'cuda' if args.gpu else 'cpu'

    augmentations = get_transforms()

    train_dataloader, ft_dataloader = get_dataloaders(
        args.path, augmentations, config['batch_size'], args.fine_tuning,
        config['number_of_mults'])

    colorizer, discriminator, content = get_models(device)
    set_weights(colorizer, discriminator)

    gen_optimizer, disc_optimizer = get_optimizers(
        colorizer, discriminator, config['generator_lr'], config['discriminator_lr'])

    train(colorizer, discriminator, content, train_dataloader, config['epochs'],
          gen_optimizer, disc_optimizer, config['lr_decrease_epoch'], device)

    if args.fine_tuning:
        set_lr(gen_optimizer, config["finetuning_generator_lr"])
        fine_tuning(colorizer, discriminator, content, train_dataloader,
                    config['finetuning_iterations'], gen_optimizer, disc_optimizer,
                    iter(ft_dataloader), device)

    # NOTE(review): the checkpoint name is the current time of day (contains
    # ':' and has no extension), which is not a valid filename on every OS.
    torch.save(colorizer.generator.state_dict(), str(datetime.datetime.now().time()))