Reeve commited on
Commit
cc6a646
1 Parent(s): f20d276

Upload train.py

Browse files
Files changed (1) hide show
  1. train.py +189 -0
train.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """train.ipynb
3
+
4
+ Automatically generated by Colaboratory.
5
+
6
+ Original file is located at
7
+ https://colab.research.google.com/drive/1nXacyY7r1lbMC9m9aZvuSOLc343bPtrV
8
+ """
9
+
10
+ import os
11
+ from data import create_dataset
12
+ from models import create_model
13
+ from util.visualizer import save_images
14
+ from PIL import Image
15
+ from torchvision import transforms
16
+ import easydict
17
+ import torch
18
+ import numpy as np
19
+ import cv2
20
+ from basicsr.archs.rrdbnet_arch import RRDBNet
21
+
22
+ from realesrgan import RealESRGANer
23
+ from realesrgan.archs.srvgg_arch import SRVGGNetCompact
24
+
25
+ def build_esrgan(
26
+ model_name = 'RealESRGAN_x4plus_anime_6B',
27
+ outscale = 4,
28
+ suffix = 'out',
29
+ tile = 0,
30
+ tile_pad = 10,
31
+ pre_pad = 0,
32
+ face_enhance = False,
33
+ half = False,
34
+ alpha_upsampler = 'realesrgan',
35
+ ext = 'png'
36
+
37
+ ):
38
+ """Inference demo for Real-ESRGAN.
39
+ """
40
+ args = easydict.EasyDict({
41
+ 'model_name' : model_name,
42
+ 'outscale' : outscale,
43
+ 'suffix' : suffix,
44
+ 'tile' : tile,
45
+ 'tile_pad' : tile_pad,
46
+ 'pre_pad' : pre_pad,
47
+ 'face_enhance' : face_enhance,
48
+ 'half' : half,
49
+ 'alpha_upsampler' : alpha_upsampler,
50
+ 'ext' : ext
51
+ })
52
+
53
+
54
+ # determine models according to model names
55
+ args.model_name = args.model_name.split('.')[0]
56
+ if args.model_name in ['RealESRGAN_x4plus', 'RealESRNet_x4plus']: # x4 RRDBNet model
57
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
58
+ netscale = 4
59
+ elif args.model_name in ['RealESRGAN_x4plus_anime_6B']: # x4 RRDBNet model with 6 blocks
60
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
61
+ netscale = 4
62
+ elif args.model_name in ['RealESRGAN_x2plus']: # x2 RRDBNet model
63
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
64
+ netscale = 2
65
+ elif args.model_name in [
66
+ 'RealESRGANv2-anime-xsx2', 'RealESRGANv2-animevideo-xsx2-nousm', 'RealESRGANv2-animevideo-xsx2'
67
+ ]: # x2 VGG-style model (XS size)
68
+ model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=2, act_type='prelu')
69
+ netscale = 2
70
+ elif args.model_name in [
71
+ 'RealESRGANv2-anime-xsx4', 'RealESRGANv2-animevideo-xsx4-nousm', 'RealESRGANv2-animevideo-xsx4'
72
+ ]: # x4 VGG-style model (XS size)
73
+ model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
74
+ netscale = 4
75
+
76
+ # determine model paths
77
+ model_path = os.path.join('experiments/pretrained_models', args.model_name + '.pth')
78
+ if not os.path.isfile(model_path):
79
+ model_path = os.path.join('realesrgan/weights', args.model_name + '.pth')
80
+ if not os.path.isfile(model_path):
81
+ raise ValueError(f'Model {args.model_name} does not exist.')
82
+
83
+ # restorer
84
+ upsampler = RealESRGANer(
85
+ scale=netscale,
86
+ model_path=model_path,
87
+ model=model,
88
+ tile=args.tile,
89
+ tile_pad=args.tile_pad,
90
+ pre_pad=args.pre_pad,
91
+ half=args.half)
92
+
93
+
94
+
95
+ return upsampler
96
+
97
+ def build_pix2pix(model_epoch):
98
+ opt = easydict.EasyDict({
99
+ 'isTrain' : False,
100
+ 'use_wandb' : False,
101
+ 'gpu_ids' : [],
102
+ 'checkpoints_dir' : 'experiments',
103
+ 'batch_size' : 1,
104
+
105
+
106
+ 'model' : 'pix2pix',
107
+ 'input_nc' : 3,
108
+ 'output_nc' : 3,
109
+ 'ngf' : 64,
110
+ 'ndf' : 64,
111
+ 'netD' : 'basic',
112
+ 'netG' : 'unet_256',
113
+ 'n_layers_D' : 3,
114
+ 'norm' : 'batch',
115
+ 'init_type' : 'normal',
116
+ 'init_gain': 0.02,
117
+ 'no_dropout' : False,
118
+
119
+ 'direction' : 'AtoB',
120
+ 'serial_batches' : True,
121
+ 'num_threads' : 0,
122
+ 'load_size' : 512,
123
+ 'crop_size' : 256,
124
+ 'max_dataset_size' : 50000,
125
+ 'preprocess' : [],
126
+ 'no_flip' : True,
127
+ 'display_winsize' : 512,
128
+ 'verbose' : False,
129
+ 'suffix' : '',
130
+ 'load_iter' : 0,
131
+
132
+ #test_arguments
133
+ 'aspect_ratio' : 1.0,
134
+ 'phase' : 'test',
135
+ 'eval' : True,
136
+ 'num_test' : 1,
137
+ 'model' : 'test',
138
+ 'load_size' : 512,
139
+ 'dataset_mode' : 'single',
140
+ 'model_suffix' : '',
141
+ 'epoch' : 110, #latest
142
+ 'name' : 'pix2pix',
143
+ })
144
+ opt.epoch = model_epoch
145
+ model = create_model(opt) # create a model given opt.model and other options
146
+ model.setup(opt) # regular setup: load and print networks; create schedulers
147
+ model.eval()
148
+ return model
149
+
150
+ def image_preprosses(img, vivid):
151
+ if (img.mode == 'RGBA') or (img.mode == 'P'):
152
+ img.load()
153
+ background = Image.new("RGB", img.size, (255, 255, 255))
154
+ background.paste(img, mask=img.split()[3]) # 3 is the alpha channel
155
+ img = background
156
+
157
+ assert (img.mode == 'RGB')
158
+ width, height = img.size
159
+
160
+ if not (width == height):
161
+ minsize = min(width, height)
162
+ left = (width - minsize)/2
163
+ top = (height - minsize)/2
164
+ right = (width + minsize)/2
165
+ bottom = (height + minsize)/2
166
+ img = img.crop((left, top, right, bottom))
167
+
168
+ assert img.width == img.height
169
+
170
+ if (img.width < 400) or (vivid == True):
171
+ img = np.array(img.resize((128,128)))
172
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
173
+ else : img = img.resize((512,512))
174
+
175
+ return img
176
+
177
+ def test_pix2pix(img, pix2pix):
178
+ pretransform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
179
+
180
+ img = pretransform(img)
181
+ img = img.unsqueeze(dim=0)
182
+ with torch.no_grad():
183
+ img = pix2pix.netG(img)
184
+ img = img.data[0].float().numpy()
185
+ img = (np.transpose(img, (1, 2, 0)) + 1) / 2.0 * 255.0
186
+ img = img.astype(np.uint8)
187
+ #img = Image.fromarray(img)
188
+
189
+ return img