# -*- coding: utf-8 -*-
"""train.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1nXacyY7r1lbMC9m9aZvuSOLc343bPtrV
"""
import os

import cv2
import easydict
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact

from models.models import create_model


def build_esrgan(model_name='RealESRGAN_x4plus_anime_6B',
                 outscale=4,
                 suffix='out',
                 tile=0,
                 tile_pad=10,
                 pre_pad=0,
                 face_enhance=False,
                 half=False,
                 alpha_upsampler='realesrgan',
                 ext='png'):
    """Build a Real-ESRGAN upsampler for inference."""
    args = easydict.EasyDict({
        'model_name': model_name,
        'outscale': outscale,
        'suffix': suffix,
        'tile': tile,
        'tile_pad': tile_pad,
        'pre_pad': pre_pad,
        'face_enhance': face_enhance,
        'half': half,
        'alpha_upsampler': alpha_upsampler,
        'ext': ext,
    })

    # Determine the network architecture from the model name.
    args.model_name = args.model_name.split('.')[0]
    if args.model_name in ['RealESRGAN_x4plus', 'RealESRNet_x4plus']:
        # x4 RRDBNet model
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
        netscale = 4
    elif args.model_name in ['RealESRGAN_x4plus_anime_6B']:
        # x4 RRDBNet model with 6 blocks
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
        netscale = 4
    elif args.model_name in ['RealESRGAN_x2plus']:
        # x2 RRDBNet model
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
        netscale = 2
    elif args.model_name in ['RealESRGANv2-anime-xsx2', 'RealESRGANv2-animevideo-xsx2-nousm',
                             'RealESRGANv2-animevideo-xsx2']:
        # x2 VGG-style model (XS size)
        model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=2, act_type='prelu')
        netscale = 2
    elif args.model_name in ['RealESRGANv2-anime-xsx4', 'RealESRGANv2-animevideo-xsx4-nousm',
                             'RealESRGANv2-animevideo-xsx4']:
        # x4 VGG-style model (XS size)
        model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
        netscale = 4
    else:
        # Fail early instead of hitting an UnboundLocalError below.
        raise ValueError(f'Unknown model name: {args.model_name}')

    # Determine the model path.
    model_path = os.path.join('experiments/pretrained_models', args.model_name + '.pth')
    if not os.path.isfile(model_path):
        model_path = os.path.join('realesrgan/weights', args.model_name + '.pth')
    if not os.path.isfile(model_path):
        raise ValueError(f'Model {args.model_name} does not exist.')

    # Restorer.
    upsampler = RealESRGANer(
        scale=netscale,
        model_path=model_path,
        model=model,
        tile=args.tile,
        tile_pad=args.tile_pad,
        pre_pad=args.pre_pad,
        half=args.half)
    return upsampler
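# Usage sketch (assumed, not part of the original notebook): RealESRGANer
# exposes `enhance(img, outscale=...)`, which takes a BGR uint8 array and
# returns `(output, img_mode)`. The file paths here are hypothetical.
#
#   upsampler = build_esrgan(model_name='RealESRGAN_x4plus_anime_6B')
#   lr = cv2.imread('input.png', cv2.IMREAD_COLOR)
#   sr, _ = upsampler.enhance(lr, outscale=4)  # e.g. 128x128 -> 512x512
#   cv2.imwrite('input_out.png', sr)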
def build_pix2pix():
    opt = easydict.EasyDict({
        'isTrain': False,
        'name': 'anime2cheek',
        'gpu_ids': [],
        'checkpoints_dir': 'experiments',
        'model': 'pix2pixHD',
        'norm': 'instance',
        'use_dropout': False,
        'data_type': 32,
        'verbose': False,
        'fp16': False,
        'local_rank': 0,
        'batchSize': 1,
        'loadSize': 512,
        'fineSize': 512,
        'label_nc': 0,
        'input_nc': 3,
        'output_nc': 3,
        'resize_or_crop': [],
        'serial_batches': True,
        'no_flip': True,
        'nThreads': 1,
        'max_dataset_size': 50000,
        'display_winsize': 512,
        'tf_log': False,
        'netG': 'global',
        'ngf': 64,
        'n_downsample_global': 4,
        'n_blocks_global': 9,
        'n_blocks_local': 3,
        'n_local_enhancers': 1,
        'niter_fix_global': 0,
        'no_instance': True,
        'instance_feat': False,
        'label_feat': False,
        'feat_num': 3,
        'load_features': False,
        'n_downsample_E': 4,
        'nef': 16,
        'n_clusters': 10,
        'initialized': True,
        'ntest': float('inf'),
        'aspect_ratio': 1.0,
        'phase': 'test',
        'which_epoch': 'latest',
        'cluster_path': None,
        'use_encoded_image': False,
        'export_onnx': None,
        'engine': None,
        'onnx': None,
    })
    model = create_model(opt)  # create a model given opt.model and other options
    model.eval()
    return model


def image_preprosses(img, vivid):
    """Flatten alpha, center-crop to a square, and resize for the pipeline."""
    # Flatten transparent images onto a white background. Converting to RGBA
    # first also handles 'P' images, whose split() has no alpha band.
    if img.mode in ('RGBA', 'P'):
        img = img.convert('RGBA')
        background = Image.new('RGB', img.size, (255, 255, 255))
        background.paste(img, mask=img.split()[3])  # band 3 is the alpha channel
        img = background
    assert img.mode == 'RGB'

    # Center-crop to a square.
    width, height = img.size
    if width != height:
        minsize = min(width, height)
        left = (width - minsize) // 2
        top = (height - minsize) // 2
        img = img.crop((left, top, left + minsize, top + minsize))
    assert img.width == img.height

    # Small or "vivid" inputs are downscaled to 128x128 and returned as a BGR
    # array for Real-ESRGAN; larger inputs go straight to 512x512 for pix2pixHD.
    if img.width < 400 or vivid:
        img = np.array(img.resize((128, 128)))
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    else:
        img = img.resize((512, 512))
    return img


def test_pix2pix(img, pix2pix):
    """Run one image through the pix2pixHD generator; return an HWC uint8 array."""
    pretransform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    img = pretransform(img)
    img = img.unsqueeze(dim=0)
    with torch.no_grad():
        img = pix2pix.netG(img)
    # Map the generator output from [-1, 1] back to [0, 255].
    img = img.data[0].float().numpy()
    img = (np.transpose(img, (1, 2, 0)) + 1) / 2.0 * 255.0
    img = img.astype(np.uint8)
    return img
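# End-to-end sketch (assumed wiring, not from the original notebook): small or
# "vivid" inputs come back from image_preprosses as 128x128 BGR arrays and are
# upscaled with Real-ESRGAN before translation; large inputs are already
# 512x512 PIL images and go straight to pix2pixHD. 'input.png' and 'result.png'
# are hypothetical paths.
if __name__ == '__main__':
    upsampler = build_esrgan()
    pix2pix = build_pix2pix()

    img = image_preprosses(Image.open('input.png'), vivid=False)
    if isinstance(img, np.ndarray):
        # BGR uint8 array: upscale x4 (128 -> 512), then convert back to RGB PIL.
        img, _ = upsampler.enhance(img, outscale=4)
        img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

    out = test_pix2pix(img, pix2pix)  # HWC uint8 RGB array
    Image.fromarray(out).save('result.png')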