# -*- coding: utf-8 -*-
"""train.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1nXacyY7r1lbMC9m9aZvuSOLc343bPtrV
"""
import os

import cv2
import easydict
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact

from models.models import create_model


def build_esrgan(
        model_name='RealESRGAN_x4plus_anime_6B',
        outscale=4,
        suffix='out',
        tile=0,
        tile_pad=10,
        pre_pad=0,
        face_enhance=False,
        half=False,
        alpha_upsampler='realesrgan',
        ext='png'):
    """Build a Real-ESRGAN upsampler (adapted from the official Real-ESRGAN inference demo)."""
    args = easydict.EasyDict({
        'model_name': model_name,
        'outscale': outscale,
        'suffix': suffix,
        'tile': tile,
        'tile_pad': tile_pad,
        'pre_pad': pre_pad,
        'face_enhance': face_enhance,
        'half': half,
        'alpha_upsampler': alpha_upsampler,
        'ext': ext,
    })
    # determine models according to model names
    args.model_name = args.model_name.split('.')[0]
    if args.model_name in ['RealESRGAN_x4plus', 'RealESRNet_x4plus']:  # x4 RRDBNet model
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
        netscale = 4
    elif args.model_name in ['RealESRGAN_x4plus_anime_6B']:  # x4 RRDBNet model with 6 blocks
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
        netscale = 4
    elif args.model_name in ['RealESRGAN_x2plus']:  # x2 RRDBNet model
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
        netscale = 2
    elif args.model_name in [
            'RealESRGANv2-anime-xsx2', 'RealESRGANv2-animevideo-xsx2-nousm', 'RealESRGANv2-animevideo-xsx2'
    ]:  # x2 VGG-style model (XS size)
        model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=2, act_type='prelu')
        netscale = 2
    elif args.model_name in [
            'RealESRGANv2-anime-xsx4', 'RealESRGANv2-animevideo-xsx4-nousm', 'RealESRGANv2-animevideo-xsx4'
    ]:  # x4 VGG-style model (XS size)
        model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
        netscale = 4
    else:
        # fail fast instead of falling through with `model`/`netscale` unbound
        raise ValueError(f'Unknown model name: {args.model_name}')
    # determine model paths
    model_path = os.path.join('experiments/pretrained_models', args.model_name + '.pth')
    if not os.path.isfile(model_path):
        model_path = os.path.join('realesrgan/weights', args.model_name + '.pth')
    if not os.path.isfile(model_path):
        raise ValueError(f'Model {args.model_name} does not exist.')
    # restorer
    upsampler = RealESRGANer(
        scale=netscale,
        model_path=model_path,
        model=model,
        tile=args.tile,
        tile_pad=args.tile_pad,
        pre_pad=args.pre_pad,
        half=args.half)
    return upsampler
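
# A minimal usage sketch (an addition, not part of the original script): the
# returned RealESRGANer consumes a BGR numpy array, as produced by cv2.imread,
# and enhance() returns the upscaled image together with its detected color
# mode. 'input.png' / 'input_out.png' are hypothetical paths.
#
#     upsampler = build_esrgan()
#     bgr = cv2.imread('input.png', cv2.IMREAD_COLOR)
#     restored, _ = upsampler.enhance(bgr, outscale=4)
#     cv2.imwrite('input_out.png', restored)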


def build_pix2pix():
    # hard-coded test-time options for the pix2pixHD generator
    opt = easydict.EasyDict({
        'isTrain': False,
        'name': 'anime2cheek',
        'gpu_ids': [],
        'checkpoints_dir': 'experiments',
        'model': 'pix2pixHD',
        'norm': 'instance',
        'use_dropout': False,
        'data_type': 32,
        'verbose': False,
        'fp16': False,
        'local_rank': 0,
        'batchSize': 1,
        'loadSize': 512,
        'fineSize': 512,
        'label_nc': 0,
        'input_nc': 3,
        'output_nc': 3,
        'resize_or_crop': [],
        'serial_batches': True,
        'no_flip': True,
        'nThreads': 1,
        'max_dataset_size': 50000,
        'display_winsize': 512,
        'tf_log': False,
        'netG': 'global',
        'ngf': 64,
        'n_downsample_global': 4,
        'n_blocks_global': 9,
        'n_blocks_local': 3,
        'n_local_enhancers': 1,
        'niter_fix_global': 0,
        'no_instance': True,
        'instance_feat': False,
        'label_feat': False,
        'feat_num': 3,
        'load_features': False,
        'n_downsample_E': 4,
        'nef': 16,
        'n_clusters': 10,
        'initialized': True,
        'ntest': float('inf'),
        'aspect_ratio': 1.0,
        'phase': 'test',
        'which_epoch': 'latest',
        'cluster_path': None,
        'use_encoded_image': False,
        'export_onnx': None,
        'engine': None,
        'onnx': None,
    })
    model = create_model(opt)  # create a model given opt.model and other options
    model.eval()
    return model
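
# Note (an assumption based on the pix2pixHD checkpoint convention, not stated
# in the original script): with these options, pix2pixHD resolves its generator
# weights via the usual '<which_epoch>_net_G.pth' naming under
# <checkpoints_dir>/<name>, i.e. 'experiments/anime2cheek/latest_net_G.pth'
# here; that file must exist before build_pix2pix() is called.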


def image_preprocess(img, vivid):
    """Flatten alpha onto white, center-crop to a square, and resize for the generator."""
    if img.mode in ('RGBA', 'P'):
        img.load()  # make sure pixel data is available before pasting
        # convert palette images to RGBA so the alpha split below works for them too
        img = img.convert('RGBA')
        background = Image.new('RGB', img.size, (255, 255, 255))
        background.paste(img, mask=img.split()[3])  # 3 is the alpha channel
        img = background
    assert img.mode == 'RGB'
    width, height = img.size
    if width != height:
        # center-crop the longer dimension down to a square
        minsize = min(width, height)
        left = (width - minsize) // 2
        top = (height - minsize) // 2
        img = img.crop((left, top, left + minsize, top + minsize))
    assert img.width == img.height
    if img.width < 400 or vivid:
        # small or "vivid" inputs become a 128x128 BGR numpy array
        img = np.array(img.resize((128, 128)))
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    else:
        # larger inputs stay a 512x512 RGB PIL image
        img = img.resize((512, 512))
    return img


def test_pix2pix(img, pix2pix):
    pretransform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    img = pretransform(img)
    img = img.unsqueeze(dim=0)  # add a batch dimension
    with torch.no_grad():
        img = pix2pix.netG(img)
    img = img[0].float().numpy()
    # map the generator output from [-1, 1] back to [0, 255] HWC uint8
    img = (np.transpose(img, (1, 2, 0)) + 1) / 2.0 * 255.0
    img = img.astype(np.uint8)
    return img
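

# End-to-end sketch (an addition, not part of the original script), assuming
# the pretrained weights described above are in place; 'input.png' and
# 'result.png' are hypothetical paths, and the generator is assumed to
# preserve the RGB channel order of its input.
if __name__ == '__main__':
    pix2pix = build_pix2pix()
    upsampler = build_esrgan()
    src = Image.open('input.png')
    face = image_preprocess(src, vivid=False)  # 512x512 RGB PIL image on this path
    styled = test_pix2pix(face, pix2pix)       # HWC uint8 array from the generator
    styled_bgr = cv2.cvtColor(styled, cv2.COLOR_RGB2BGR)
    restored, _ = upsampler.enhance(styled_bgr, outscale=4)
    cv2.imwrite('result.png', restored)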