# -*- coding: utf-8 -*-
"""train.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1nXacyY7r1lbMC9m9aZvuSOLc343bPtrV
"""

import os

import cv2
import easydict
import numpy as np
import torch
from basicsr.archs.rrdbnet_arch import RRDBNet
from PIL import Image
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
from torchvision import transforms

from models.models import create_model


def build_esrgan(
    model_name='RealESRGAN_x4plus_anime_6B',
    outscale=4,
    suffix='out',
    tile=0,
    tile_pad=10,
    pre_pad=0,
    face_enhance=False,
    half=False,
    alpha_upsampler='realesrgan',
    ext='png'
):
    """Build a RealESRGANer restorer for Real-ESRGAN inference."""
    args = easydict.EasyDict({
        'model_name': model_name,
        'outscale': outscale,
        'suffix': suffix,
        'tile': tile,
        'tile_pad': tile_pad,
        'pre_pad': pre_pad,
        'face_enhance': face_enhance,
        'half': half,
        'alpha_upsampler': alpha_upsampler,
        'ext': ext,
    })
    # determine models according to model names
    args.model_name = args.model_name.split('.')[0]
    if args.model_name in ['RealESRGAN_x4plus', 'RealESRNet_x4plus']:  # x4 RRDBNet model
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
        netscale = 4
    elif args.model_name in ['RealESRGAN_x4plus_anime_6B']:  # x4 RRDBNet model with 6 blocks
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
        netscale = 4
    elif args.model_name in ['RealESRGAN_x2plus']:  # x2 RRDBNet model
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
        netscale = 2
    elif args.model_name in [
        'RealESRGANv2-anime-xsx2', 'RealESRGANv2-animevideo-xsx2-nousm', 'RealESRGANv2-animevideo-xsx2'
    ]:  # x2 VGG-style model (XS size)
        model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=2, act_type='prelu')
        netscale = 2
    elif args.model_name in [
        'RealESRGANv2-anime-xsx4', 'RealESRGANv2-animevideo-xsx4-nousm', 'RealESRGANv2-animevideo-xsx4'
    ]:  # x4 VGG-style model (XS size)
        model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
        netscale = 4
    else:  # fail early instead of hitting a NameError on `model` below
        raise ValueError(f'Unknown model name: {args.model_name}')
    # determine model paths
    model_path = os.path.join('experiments/pretrained_models', args.model_name + '.pth')
    if not os.path.isfile(model_path):
        model_path = os.path.join('realesrgan/weights', args.model_name + '.pth')
        if not os.path.isfile(model_path):
            raise ValueError(f'Model {args.model_name} does not exist.')
    # restorer
    upsampler = RealESRGANer(
        scale=netscale,
        model_path=model_path,
        model=model,
        tile=args.tile,
        tile_pad=args.tile_pad,
        pre_pad=args.pre_pad,
        half=args.half)
    return upsampler
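
# Usage sketch (illustrative, not from the original notebook): RealESRGANer
# exposes an `enhance` method that takes a BGR numpy image and returns the
# upscaled image plus its detected mode. The file paths here are hypothetical.
#
#   upsampler = build_esrgan()
#   lowres = cv2.imread('input.png', cv2.IMREAD_COLOR)
#   highres, _ = upsampler.enhance(lowres, outscale=4)
#   cv2.imwrite('input_out.png', highres)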


def build_pix2pix():
    """Build the pix2pixHD model with test-time options (anime2cheek checkpoint)."""
    opt = easydict.EasyDict({
        'isTrain': False,
        'name': 'anime2cheek',
        'gpu_ids': [],
        'checkpoints_dir': 'experiments',
        'model': 'pix2pixHD',
        'norm': 'instance',
        'use_dropout': False,
        'data_type': 32,
        'verbose': False,
        'fp16': False,
        'local_rank': 0,
        'batchSize': 1,
        'loadSize': 512,
        'fineSize': 512,
        'label_nc': 0,
        'input_nc': 3,
        'output_nc': 3,
        'resize_or_crop': [],
        'serial_batches': True,
        'no_flip': True,
        'nThreads': 1,
        'max_dataset_size': 50000,
        'display_winsize': 512,
        'tf_log': False,
        'netG': 'global',
        'ngf': 64,
        'n_downsample_global': 4,
        'n_blocks_global': 9,
        'n_blocks_local': 3,
        'n_local_enhancers': 1,
        'niter_fix_global': 0,
        'no_instance': True,
        'instance_feat': False,
        'label_feat': False,
        'feat_num': 3,
        'load_features': False,
        'n_downsample_E': 4,
        'nef': 16,
        'n_clusters': 10,
        'initialized': True,
        'ntest': float('inf'),
        'aspect_ratio': 1.0,
        'phase': 'test',
        'which_epoch': 'latest',
        'cluster_path': None,
        'use_encoded_image': False,
        'export_onnx': None,
        'engine': None,
        'onnx': None,
    })
    model = create_model(opt)  # create a model given opt.model and other options
    model.eval()
    return model
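
# Usage sketch (an assumption about the surrounding repo layout): with the
# options above, create_model builds a pix2pixHD model whose generator is the
# `netG` attribute, loading weights from experiments/anime2cheek/latest_net_G.pth.
#
#   pix2pix = build_pix2pix()
#   fake = pix2pix.netG(x)  # x: 1x3x512x512 float tensor, values in [-1, 1]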


def image_preprocess(img, vivid):
    """Flatten alpha, center-crop to a square, and resize for the target model."""
    if img.mode in ('RGBA', 'P'):
        # flatten transparency onto a white background; convert first so that
        # palette ('P') images also expose an alpha band
        img = img.convert('RGBA')
        background = Image.new('RGB', img.size, (255, 255, 255))
        background.paste(img, mask=img.split()[3])  # band 3 is the alpha channel
        img = background
    assert img.mode == 'RGB'
    width, height = img.size
    if width != height:
        # center-crop to a square with side length of the shorter edge
        minsize = min(width, height)
        left = (width - minsize) // 2
        top = (height - minsize) // 2
        img = img.crop((left, top, left + minsize, top + minsize))
    assert img.width == img.height
    if img.width < 400 or vivid:
        # small or "vivid" inputs: 128x128 BGR numpy array for Real-ESRGAN (cv2 convention)
        img = np.array(img.resize((128, 128)))
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    else:
        # large inputs: 512x512 RGB PIL image for pix2pixHD
        img = img.resize((512, 512))
    return img
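
# The two return types above feed different models. Illustrative calls
# ('face.png' is a hypothetical path; the second call assumes a source image
# of at least 400px per side):
#
#   src = Image.open('face.png')
#   small = image_preprocess(src, vivid=True)    # 128x128 BGR ndarray -> Real-ESRGAN
#   large = image_preprocess(src, vivid=False)   # 512x512 RGB PIL image -> pix2pixHD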


def test_pix2pix(img, pix2pix):
    """Run the pix2pixHD generator on one RGB image and return a uint8 HWC array."""
    # normalize to [-1, 1], matching the generator's training-time input range
    pretransform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    img = pretransform(img)
    img = img.unsqueeze(dim=0)  # add a batch dimension: 1 x 3 x H x W
    with torch.no_grad():
        img = pix2pix.netG(img)
    img = img[0].float().cpu().numpy()
    # map the tanh output from [-1, 1] back to [0, 255], CHW -> HWC
    img = (np.transpose(img, (1, 2, 0)) + 1) / 2.0 * 255.0
    img = img.astype(np.uint8)
    return img
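

if __name__ == '__main__':
    # Minimal end-to-end sketch, not part of the generated notebook: one
    # plausible wiring is to let Real-ESRGAN upscale small (128x128 BGR)
    # inputs to 512x512 before pix2pixHD translation. The input/output
    # paths are hypothetical.
    pix2pix = build_pix2pix()
    img = image_preprocess(Image.open('input.png'), vivid=False)
    if isinstance(img, np.ndarray):
        # preprocessing returned a 128x128 BGR array: upscale x4, back to RGB PIL
        upsampler = build_esrgan()
        img, _ = upsampler.enhance(img, outscale=4)
        img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    result = test_pix2pix(img, pix2pix)
    Image.fromarray(result).save('result.png')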