# -*- coding: utf-8 -*-
"""train.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1nXacyY7r1lbMC9m9aZvuSOLc343bPtrV
"""

import os

from data import create_dataset
from models import create_model
from util.visualizer import save_images
from PIL import Image
from torchvision import transforms
import easydict
import torch
import numpy as np
import cv2

from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact


def build_esrgan(model_name='RealESRGAN_x4plus_anime_6B',
                 outscale=4,
                 suffix='out',
                 tile=0,
                 tile_pad=10,
                 pre_pad=0,
                 face_enhance=False,
                 half=False,
                 alpha_upsampler='realesrgan',
                 ext='png'):
    """Build a Real-ESRGAN upsampler (adapted from the Real-ESRGAN inference demo)."""
    args = easydict.EasyDict({
        'model_name': model_name,
        'outscale': outscale,
        'suffix': suffix,
        'tile': tile,
        'tile_pad': tile_pad,
        'pre_pad': pre_pad,
        'face_enhance': face_enhance,
        'half': half,
        'alpha_upsampler': alpha_upsampler,
        'ext': ext
    })

    # determine the architecture according to the model name
    args.model_name = args.model_name.split('.')[0]
    if args.model_name in ['RealESRGAN_x4plus', 'RealESRNet_x4plus']:  # x4 RRDBNet model
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
        netscale = 4
    elif args.model_name in ['RealESRGAN_x4plus_anime_6B']:  # x4 RRDBNet model with 6 blocks
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
        netscale = 4
    elif args.model_name in ['RealESRGAN_x2plus']:  # x2 RRDBNet model
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
        netscale = 2
    elif args.model_name in [
            'RealESRGANv2-anime-xsx2', 'RealESRGANv2-animevideo-xsx2-nousm',
            'RealESRGANv2-animevideo-xsx2'
    ]:  # x2 VGG-style model (XS size)
        model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=2, act_type='prelu')
        netscale = 2
    elif args.model_name in [
            'RealESRGANv2-anime-xsx4', 'RealESRGANv2-animevideo-xsx4-nousm',
            'RealESRGANv2-animevideo-xsx4'
    ]:  # x4 VGG-style model (XS size)
        model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
        netscale = 4
    else:
        # without this branch an unrecognized name would leave `model` undefined
        # and raise a confusing NameError below
        raise ValueError(f'Unknown model name: {args.model_name}')

    # determine the model path
    model_path = os.path.join('experiments/pretrained_models', args.model_name + '.pth')
    if not os.path.isfile(model_path):
        model_path = os.path.join('realesrgan/weights', args.model_name + '.pth')
    if not os.path.isfile(model_path):
        raise ValueError(f'Model {args.model_name} does not exist.')

    # restorer
    upsampler = RealESRGANer(
        scale=netscale,
        model_path=model_path,
        model=model,
        tile=args.tile,
        tile_pad=args.tile_pad,
        pre_pad=args.pre_pad,
        half=args.half)
    return upsampler
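

# A minimal usage sketch for the upsampler above, assuming the pretrained
# weights have already been downloaded to one of the paths that build_esrgan
# checks. RealESRGANer.enhance() takes a BGR numpy array (the cv2 convention)
# and returns the upscaled image along with the detected image mode.
# upscale_image and its file paths are hypothetical helpers for illustration,
# not part of the original notebook.
def upscale_image(in_path, out_path, outscale=4):
    upsampler = build_esrgan(model_name='RealESRGAN_x4plus_anime_6B', outscale=outscale)
    lr_img = cv2.imread(in_path, cv2.IMREAD_UNCHANGED)  # BGR numpy array
    sr_img, _ = upsampler.enhance(lr_img, outscale=outscale)
    cv2.imwrite(out_path, sr_img)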


def build_pix2pix(model_epoch):
    opt = easydict.EasyDict({
        'isTrain': False,
        'use_wandb': False,
        'gpu_ids': [],
        'checkpoints_dir': 'experiments',
        'batch_size': 1,
        'input_nc': 3,
        'output_nc': 3,
        'ngf': 64,
        'ndf': 64,
        'netD': 'basic',
        'netG': 'unet_256',
        'n_layers_D': 3,
        'norm': 'batch',
        'init_type': 'normal',
        'init_gain': 0.02,
        'no_dropout': False,
        'direction': 'AtoB',
        'serial_batches': True,
        'num_threads': 0,
        'load_size': 512,
        'crop_size': 256,
        'max_dataset_size': 50000,
        'preprocess': [],
        'no_flip': True,
        'display_winsize': 512,
        'verbose': False,
        'suffix': '',
        'load_iter': 0,
        # test arguments (the original dict listed 'model' and 'load_size'
        # twice; only the last value of a duplicate key takes effect, so the
        # duplicates are dropped here)
        'aspect_ratio': 1.0,
        'phase': 'test',
        'eval': True,
        'num_test': 1,
        'model': 'test',  # test-time wrapper around the pix2pix generator
        'dataset_mode': 'single',
        'model_suffix': '',
        'epoch': 110,  # latest
        'name': 'pix2pix',
    })
    opt.epoch = model_epoch
    model = create_model(opt)  # create a model given opt.model and other options
    model.setup(opt)           # regular setup: load and print networks; create schedulers
    model.eval()
    return model


def image_preprocess(img, vivid):
    # flatten transparency onto a white background; 'P' images may carry
    # transparency in the palette, so convert to RGBA first (splitting a 'P'
    # image directly yields a single band and index 3 would fail)
    if img.mode in ('RGBA', 'P'):
        img = img.convert('RGBA')
        background = Image.new('RGB', img.size, (255, 255, 255))
        background.paste(img, mask=img.split()[3])  # band 3 is the alpha channel
        img = background
    assert img.mode == 'RGB'

    # center-crop to a square
    width, height = img.size
    if width != height:
        minsize = min(width, height)
        left = (width - minsize) // 2
        top = (height - minsize) // 2
        img = img.crop((left, top, left + minsize, top + minsize))
    assert img.width == img.height

    # small or "vivid" inputs become 128x128 BGR numpy arrays; everything else
    # is resized to 512x512 and stays a PIL RGB image
    if img.width < 400 or vivid:
        img = np.array(img.resize((128, 128)))
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    else:
        img = img.resize((512, 512))
    return img


def test_pix2pix(img, pix2pix):
    # map the image to a [-1, 1] tensor, as expected by the generator
    pretransform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    img = pretransform(img)
    img = img.unsqueeze(dim=0)  # add the batch dimension
    with torch.no_grad():
        img = pix2pix.netG(img)
    # map the output back from [-1, 1] to an HWC uint8 array
    img = img.data[0].float().numpy()
    img = (np.transpose(img, (1, 2, 0)) + 1) / 2.0 * 255.0
    img = img.astype(np.uint8)
    # img = Image.fromarray(img)
    return img
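

# An end-to-end sketch tying the pieces together, inferred from the code
# above: on the vivid path image_preprocess yields a 128x128 BGR array, the
# pix2pix generator translates it, and Real-ESRGAN upscales the result x4.
# 'input.png', 'result.png', and epoch 110 are illustrative values; whether
# the generator was trained on BGR input is an assumption carried over from
# image_preprocess, not something the notebook states.
if __name__ == '__main__':
    pix2pix = build_pix2pix(model_epoch=110)
    upsampler = build_esrgan(model_name='RealESRGAN_x4plus_anime_6B', outscale=4)

    img = Image.open('input.png')
    img = image_preprocess(img, vivid=True)      # 128x128 BGR array on this path
    out = test_pix2pix(img, pix2pix)             # 128x128 generator output (uint8)
    out, _ = upsampler.enhance(out, outscale=4)  # 512x512 after x4 upscaling
    cv2.imwrite('result.png', out)               # cv2 expects BGR, matching the pipeline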