ptran1203 committed on
Commit f2fa83b
1 Parent(s): 774d537
This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. .gitignore +4 -0
  2. README.md +1 -0
  3. app.py +95 -0
  4. example/arcane/anne.jpg +0 -0
  5. example/arcane/boy2.jpg +0 -0
  6. example/arcane/cap.jpg +0 -0
  7. example/arcane/dune2.jpg +0 -0
  8. example/arcane/elon.jpg +0 -0
  9. example/arcane/girl.jpg +0 -0
  10. example/arcane/girl4.jpg +0 -0
  11. example/arcane/girl6.jpg +0 -0
  12. example/arcane/leo.jpg +0 -0
  13. example/arcane/man2.jpg +0 -0
  14. example/arcane/nat_.jpg +0 -0
  15. example/arcane/seydoux.jpg +0 -0
  16. example/arcane/tobey.jpg +0 -0
  17. example/face/anne.jpg +0 -0
  18. example/face/boy2.jpg +0 -0
  19. example/face/cap.jpg +0 -0
  20. example/face/dune2.jpg +0 -0
  21. example/face/elon.jpg +0 -0
  22. example/face/girl.jpg +0 -0
  23. example/face/girl4.jpg +0 -0
  24. example/face/girl6.jpg +0 -0
  25. example/face/leo.jpg +0 -0
  26. example/face/man2.jpg +0 -0
  27. example/face/nat_.jpg +0 -0
  28. example/face/seydoux.jpg +0 -0
  29. example/face/tobey.jpg +0 -0
  30. example/generate_examples.py +49 -0
  31. example/more/hayao_v2/pexels-arnie-chou-304906-1004122.jpg +0 -0
  32. example/more/hayao_v2/pexels-camilacarneiro-6318793.jpg +0 -0
  33. example/more/hayao_v2/pexels-haohd-19859127.jpg +0 -0
  34. example/more/hayao_v2/pexels-huy-nguyen-748440234-19838813.jpg +0 -0
  35. example/more/hayao_v2/pexels-huy-phan-316220-1422386.jpg +0 -0
  36. example/more/hayao_v2/pexels-jimmy-teoh-294331-951531.jpg +0 -0
  37. example/more/hayao_v2/pexels-nandhukumar-450441.jpg +0 -0
  38. example/more/hayao_v2/pexels-sevenstormphotography-575362.jpg +0 -0
  39. inference.py +410 -0
  40. losses.py +248 -0
  41. models/__init__.py +3 -0
  42. models/anime_gan.py +112 -0
  43. models/anime_gan_v2.py +65 -0
  44. models/anime_gan_v3.py +14 -0
  45. models/conv_blocks.py +171 -0
  46. models/layers.py +28 -0
  47. models/vgg.py +80 -0
  48. predict.py +35 -0
  49. train.py +163 -0
  50. trainer/__init__.py +437 -0
.gitignore ADDED
@@ -0,0 +1,4 @@
+ .cache
+ __pycache__
+ output
+ .token
README.md CHANGED
@@ -11,3 +11,4 @@ license: mit
  ---
  
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ <!-- https://huggingface.co/spaces/ptran1203/pytorchAnimeGAN -->
app.py ADDED
@@ -0,0 +1,95 @@
+ import os
+ import cv2
+ import numpy as np
+ import gradio as gr
+ from inference import Predictor
+ from utils.image_processing import resize_image
+
+ os.makedirs('output', exist_ok=True)
+
+
+ def inference(
+     image: np.ndarray,
+     style,
+     imgsz=None,
+ ):
+     retain_color = False
+
+     weight = {
+         "AnimeGAN_Hayao": "hayao",
+         "AnimeGAN_Shinkai": "shinkai",
+         "AnimeGANv2_Hayao": "hayao:v2",
+         "AnimeGANv2_Shinkai": "shinkai:v2",
+         "AnimeGANv2_Arcane": "arcane:v2",
+     }[style]
+     predictor = Predictor(
+         weight,
+         device='cpu',
+         retain_color=retain_color,
+         imgsz=imgsz,
+     )
+
+     save_path = "output/out.jpg"
+     image = resize_image(image, width=imgsz)
+     anime_image = predictor.transform(image)[0]
+     cv2.imwrite(save_path, anime_image[..., ::-1])  # RGB -> BGR for cv2
+     return anime_image, save_path
+
+
+ title = "AnimeGANv2: Produce your own animation."
+ description = r"""Turn your photo into anime style 😊"""
+ article = r"""
+ [![GitHub Stars](https://img.shields.io/github/stars/ptran1203/pytorch-animeGAN?style=social)](https://github.com/ptran1203/pytorch-animeGAN)
+ ### 🗻 Demo
+
+ """
+
+ gr.Interface(
+     fn=inference,
+     inputs=[
+         gr.components.Image(label="Input"),
+         gr.Dropdown(
+             [
+                 'AnimeGAN_Hayao',
+                 'AnimeGAN_Shinkai',
+                 'AnimeGANv2_Hayao',
+                 'AnimeGANv2_Shinkai',
+                 'AnimeGANv2_Arcane',
+             ],
+             type="value",
+             value='AnimeGANv2_Hayao',
+             label='Style'
+         ),
+         gr.Dropdown(
+             [
+                 None,
+                 416,
+                 512,
+                 768,
+                 1024,
+                 1536,
+             ],
+             type="value",
+             value=None,
+             label='Image size'
+         )
+     ],
+     outputs=[
+         gr.components.Image(type="numpy", label="Output (The whole image)"),
+         gr.components.File(label="Download the output image")
+     ],
+     title=title,
+     description=description,
+     article=article,
+     allow_flagging="never",
+     examples=[
+         # The third value feeds the "Image size" dropdown ("Yes" was not a valid size).
+         ['example/arcane/girl4.jpg', 'AnimeGANv2_Arcane', None],
+         ['example/arcane/leo.jpg', 'AnimeGANv2_Arcane', None],
+         ['example/arcane/girl.jpg', 'AnimeGANv2_Arcane', None],
+         ['example/arcane/anne.jpg', 'AnimeGANv2_Arcane', None],
+         # ['example/boy2.jpg', 'AnimeGANv3_Arcane', None],
+         # ['example/cap.jpg', 'AnimeGANv3_Arcane', None],
+         ['example/more/hayao_v2/pexels-camilacarneiro-6318793.jpg', 'AnimeGANv2_Hayao', None],
+         ['example/more/hayao_v2/pexels-nandhukumar-450441.jpg', 'AnimeGANv2_Hayao', None],
+     ]
+ ).launch()
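The Gradio handler above reduces to a three-step pipeline (load weights, resize, transform). A minimal sketch of the same calls without the UI, assuming the "hayao:v2" weights are fetchable and the example image exists locally:

import os
import cv2
from inference import Predictor
from utils.image_processing import resize_image

os.makedirs("output", exist_ok=True)
predictor = Predictor("hayao:v2", device="cpu")            # same weight key app.py maps to
image = cv2.imread("example/arcane/girl.jpg")[..., ::-1]   # BGR -> RGB
image = resize_image(image, width=512)
anime = predictor.transform(image)[0]                      # (H, W, 3) uint8
cv2.imwrite("output/out.jpg", anime[..., ::-1])            # back to BGR for cv2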
example/arcane/anne.jpg ADDED
example/arcane/boy2.jpg ADDED
example/arcane/cap.jpg ADDED
example/arcane/dune2.jpg ADDED
example/arcane/elon.jpg ADDED
example/arcane/girl.jpg ADDED
example/arcane/girl4.jpg ADDED
example/arcane/girl6.jpg ADDED
example/arcane/leo.jpg ADDED
example/arcane/man2.jpg ADDED
example/arcane/nat_.jpg ADDED
example/arcane/seydoux.jpg ADDED
example/arcane/tobey.jpg ADDED
example/face/anne.jpg ADDED
example/face/boy2.jpg ADDED
example/face/cap.jpg ADDED
example/face/dune2.jpg ADDED
example/face/elon.jpg ADDED
example/face/girl.jpg ADDED
example/face/girl4.jpg ADDED
example/face/girl6.jpg ADDED
example/face/leo.jpg ADDED
example/face/man2.jpg ADDED
example/face/nat_.jpg ADDED
example/face/seydoux.jpg ADDED
example/face/tobey.jpg ADDED
example/generate_examples.py ADDED
@@ -0,0 +1,49 @@
+ import os
+ import cv2
+ import re
+
+ REG = re.compile(r"[0-9]{3}")
+ dir_ = './example/result'
+ readme = './README.md'
+
+
+ def anime_2_input(fi):
+     return fi.replace("_anime", "")
+
+ def rename(f):
+     return f.replace(" ", "").replace("(", "").replace(")", "")
+
+ def rename_back(f):
+     nums = REG.search(f)
+     if nums:
+         nums = nums.group()
+         return f.replace(nums, f"{nums[0]} ({nums[1:]})")
+
+     return f.replace('jpeg', 'jpg')
+
+ def copyfile(src, dest):
+     # copy and resize
+     im = cv2.imread(src)
+
+     if im is None:
+         raise FileNotFoundError(src)
+
+     h, w = im.shape[:2]  # shape is (height, width, channels)
+
+     s = 448
+     size = (s, round(s * h / w))  # cv2.resize takes (width, height)
+     im = cv2.resize(im, size)
+
+     print(w, h, im.shape)
+     cv2.imwrite(dest, im)
+
+ files = os.listdir(dir_)
+ new_files = []
+ for f in files:
+     input_ver = os.path.join(dir_, anime_2_input(f))
+     copyfile(f"dataset/test/HR_photo/{rename_back(anime_2_input(f))}", rename(input_ver))
+
+     os.rename(
+         os.path.join(dir_, f),
+         os.path.join(dir_, rename(f))
+     )
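The rename/rename_back pair above encodes a subtle round trip on "N (MM)"-style file names. A standalone sketch of that round trip (logic copied here, since importing the script would execute its module-level loop):

import re

REG = re.compile(r"[0-9]{3}")

def rename(f):
    return f.replace(" ", "").replace("(", "").replace(")", "")

def rename_back(f):
    nums = REG.search(f)
    if nums:
        nums = nums.group()
        return f.replace(nums, f"{nums[0]} ({nums[1:]})")
    return f.replace('jpeg', 'jpg')

print(rename("1 (23).jpg"))               # -> "123.jpg"
print(rename_back(rename("1 (23).jpg")))  # -> "1 (23).jpg"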
example/more/hayao_v2/pexels-arnie-chou-304906-1004122.jpg ADDED
example/more/hayao_v2/pexels-camilacarneiro-6318793.jpg ADDED
example/more/hayao_v2/pexels-haohd-19859127.jpg ADDED
example/more/hayao_v2/pexels-huy-nguyen-748440234-19838813.jpg ADDED
example/more/hayao_v2/pexels-huy-phan-316220-1422386.jpg ADDED
example/more/hayao_v2/pexels-jimmy-teoh-294331-951531.jpg ADDED
example/more/hayao_v2/pexels-nandhukumar-450441.jpg ADDED
example/more/hayao_v2/pexels-sevenstormphotography-575362.jpg ADDED
inference.py ADDED
@@ -0,0 +1,410 @@
+ import os
+ import time
+ import shutil
+
+ import torch
+ import cv2
+ import numpy as np
+
+ from models.anime_gan import GeneratorV1
+ from models.anime_gan_v2 import GeneratorV2
+ from models.anime_gan_v3 import GeneratorV3
+ from utils.common import load_checkpoint, RELEASED_WEIGHTS
+ from utils.image_processing import resize_image, normalize_input, denormalize_input
+ from utils import read_image, is_image_file, is_video_file
+ from tqdm import tqdm
+ from color_transfer import color_transfer_pytorch
+
+
+ try:
+     import matplotlib.pyplot as plt
+ except ImportError:
+     plt = None
+
+ try:
+     import moviepy.video.io.ffmpeg_writer as ffmpeg_writer
+     from moviepy.video.io.VideoFileClip import VideoFileClip
+ except ImportError:
+     ffmpeg_writer = None
+     VideoFileClip = None
+
+
+ def profile(func):
+     def wrap(*args, **kwargs):
+         started_at = time.time()
+         result = func(*args, **kwargs)
+         elapsed = time.time() - started_at
+         print(f"Processed in {elapsed:.3f}s")
+         return result
+     return wrap
+
+
+ def auto_load_weight(weight, version=None, map_location=None):
+     """Automatically pick the Generator version for a given weight file."""
+     weight_name = os.path.basename(weight).lower()
+     if version is not None:
+         version = version.lower()
+         assert version in {"v1", "v2", "v3"}, f"Version {version} does not exist"
+         # If version is provided, use it.
+         cls = {
+             "v1": GeneratorV1,
+             "v2": GeneratorV2,
+             "v3": GeneratorV3
+         }[version]
+     else:
+         # Otherwise, infer the class from the name of the weight file.
+         # For convenience, weights should start with the class name,
+         # e.g. Generatorv2_{anything}.pt
+         if weight_name in RELEASED_WEIGHTS:
+             version = RELEASED_WEIGHTS[weight_name][0]
+             return auto_load_weight(weight, version=version, map_location=map_location)
+
+         elif weight_name.startswith("generatorv2"):
+             cls = GeneratorV2
+         elif weight_name.startswith("generatorv3"):
+             cls = GeneratorV3
+         elif weight_name.startswith("generator"):
+             cls = GeneratorV1
+         else:
+             raise ValueError((f"Cannot infer the model from {weight_name}, "
+                               "you might need to explicitly specify the version"))
+     model = cls()
+     load_checkpoint(model, weight, strip_optimizer=True, map_location=map_location)
+     model.eval()
+     return model
+
+
+ class Predictor:
+     """
+     Generic class for transforming an image into an anime-style image.
+     """
+     def __init__(
+         self,
+         weight='hayao',
+         device='cuda',
+         amp=True,
+         retain_color=False,
+         imgsz=None,
+     ):
+         if not torch.cuda.is_available():
+             device = 'cpu'
+             # AMP does not work on CPU
+             amp = False
+             print("Use CPU device")
+         else:
+             print(f"Use GPU {torch.cuda.get_device_name()}")
+
+         self.imgsz = imgsz
+         self.retain_color = retain_color
+         self.amp = amp  # Automatic Mixed Precision
+         self.device_type = 'cuda' if device.startswith('cuda') else 'cpu'
+         self.device = torch.device(device)
+         self.G = auto_load_weight(weight, map_location=device)
+         self.G.to(self.device)
+
+     def transform_and_show(
+         self,
+         image_path,
+         figsize=(18, 10),
+         save_path=None
+     ):
+         image = resize_image(read_image(image_path))
+         anime_img = self.transform(image)
+         anime_img = anime_img.astype('uint8')
+
+         fig = plt.figure(figsize=figsize)
+         fig.add_subplot(1, 2, 1)
+         # plt.title("Input")
+         plt.imshow(image)
+         plt.axis('off')
+         fig.add_subplot(1, 2, 2)
+         # plt.title("Anime style")
+         plt.imshow(anime_img[0])
+         plt.axis('off')
+         plt.tight_layout()
+         plt.show()
+         if save_path is not None:
+             plt.savefig(save_path)
+
+     def transform(self, image, denorm=True):
+         '''
+         Transform an image to its animation version.
+
+         @Arguments:
+             - image: np.ndarray, shape = (batch, height, width, channels)
+
+         @Returns:
+             - anime version of the image: np.ndarray
+         '''
+         with torch.no_grad():
+             image = self.preprocess_images(image)
+             # image = image.to(self.device)
+             # with autocast(self.device_type, enabled=self.amp):
+             # print(image.dtype, self.G)
+             fake = self.G(image)
+             # Transfer color so the fake image keeps colors similar to the input
+             if self.retain_color:
+                 fake = color_transfer_pytorch(fake, image)
+                 fake = (fake / 0.5) - 1.0  # remap to [-1, 1]
+             fake = fake.detach().cpu().numpy()
+             # Channel last
+             fake = fake.transpose(0, 2, 3, 1)
+
+             if denorm:
+                 fake = denormalize_input(fake, dtype=np.uint8)
+             return fake
+
+     def read_and_resize(self, path, max_size=1536):
+         image = read_image(path)
+         _, ext = os.path.splitext(path)
+         h, w = image.shape[:2]
+         if self.imgsz is not None:
+             image = resize_image(image, width=self.imgsz)
+         elif max(h, w) > max_size:
+             print(f"Image {os.path.basename(path)} is too big ({h}x{w}), resize to max size {max_size}")
+             image = resize_image(
+                 image,
+                 width=max_size if w > h else None,
+                 height=max_size if w < h else None,
+             )
+             cv2.imwrite(path.replace(ext, ".jpg"), image[:, :, ::-1])
+         else:
+             image = resize_image(image)
+         # image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
+         # image = np.stack([image, image, image], -1)
+         # cv2.imwrite(path.replace(ext, ".jpg"), image[:,:,::-1])
+         return image
+
+     @profile
+     def transform_file(self, file_path, save_path):
+         if not is_image_file(save_path):
+             raise ValueError(f"{save_path} is not valid")
+
+         image = self.read_and_resize(file_path)
+         anime_img = self.transform(image)[0]
+         cv2.imwrite(save_path, anime_img[..., ::-1])
+         print(f"Anime image saved to {save_path}")
+         return anime_img
+
+     @profile
+     def transform_gif(self, file_path, save_path, batch_size=4):
+         import imageio
+
+         def _preprocess_gif(img):
+             if img.shape[-1] == 4:
+                 img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)
+             return resize_image(img)
+
+         images = imageio.mimread(file_path)
+         images = np.stack([
+             _preprocess_gif(img)
+             for img in images
+         ])
+
+         print(images.shape)
+
+         anime_gif = np.zeros_like(images)
+
+         for i in tqdm(range(0, len(images), batch_size)):
+             end = i + batch_size
+             anime_gif[i: end] = self.transform(
+                 images[i: end]
+             )
+
+         if end < len(images) - 1:
+             # transform the last frames
+             print("LAST", images[end:].shape)
+             anime_gif[end:] = self.transform(images[end:])
+
+         print(anime_gif.shape)
+         imageio.mimsave(
+             save_path,
+             anime_gif,
+         )
+         print(f"Anime image saved to {save_path}")
+
+     @profile
+     def transform_in_dir(self, img_dir, dest_dir, max_images=0, img_size=(512, 512)):
+         '''
+         Read all images from img_dir, transform them and write the results
+         into dest_dir.
+         '''
+         os.makedirs(dest_dir, exist_ok=True)
+
+         files = os.listdir(img_dir)
+         files = [f for f in files if is_image_file(f)]
+         print(f'Found {len(files)} images in {img_dir}')
+
+         if max_images:
+             files = files[:max_images]
+
+         bar = tqdm(files)
+         for fname in bar:
+             path = os.path.join(img_dir, fname)
+             image = self.read_and_resize(path)
+             anime_img = self.transform(image)[0]
+             # anime_img = resize_image(anime_img, width=320)
+             ext = fname.split('.')[-1]
+             fname = fname.replace(f'.{ext}', '')
+             cv2.imwrite(os.path.join(dest_dir, f'{fname}.jpg'), anime_img[..., ::-1])
+             bar.set_description(f"{fname} {image.shape}")
+
+     def transform_video(self, input_path, output_path, batch_size=4, start=0, end=0):
+         '''
+         Transform a video into its animation version.
+         https://github.com/lengstrom/fast-style-transfer/blob/master/evaluate.py#L21
+         '''
+         if VideoFileClip is None:
+             raise ImportError("moviepy is not installed, please install with `pip install moviepy>=1.0.3`")
+         # Force to None
+         end = end or None
+
+         if not os.path.isfile(input_path):
+             raise FileNotFoundError(f'{input_path} does not exist')
+
+         output_dir = os.path.dirname(output_path)
+         if output_dir:
+             os.makedirs(output_dir, exist_ok=True)
+
+         is_gg_drive = '/drive/' in output_path
+         temp_file = ''
+
+         if is_gg_drive:
+             # Writing directly into Google Drive can be inefficient
+             temp_file = f'tmp_anime.{output_path.split(".")[-1]}'
+
+         def transform_and_write(frames, count, writer):
+             anime_images = self.transform(frames)
+             for i in range(0, count):
+                 img = np.clip(anime_images[i], 0, 255)
+                 writer.write_frame(img)
+
+         video_clip = VideoFileClip(input_path, audio=False)
+         if start or end:
+             video_clip = video_clip.subclip(start, end)
+
+         video_writer = ffmpeg_writer.FFMPEG_VideoWriter(
+             temp_file or output_path,
+             video_clip.size, video_clip.fps,
+             codec="libx264",
+             # preset="medium", bitrate="2000k",
+             ffmpeg_params=None)
+
+         total_frames = round(video_clip.fps * video_clip.duration)
+         print(f'Transforming video {input_path}, {total_frames} frames, size: {video_clip.size}')
+
+         batch_shape = (batch_size, video_clip.size[1], video_clip.size[0], 3)
+         frame_count = 0
+         frames = np.zeros(batch_shape, dtype=np.float32)
+         for frame in tqdm(video_clip.iter_frames(), total=total_frames):
+             try:
+                 frames[frame_count] = frame
+                 frame_count += 1
+                 if frame_count == batch_size:
+                     transform_and_write(frames, frame_count, video_writer)
+                     frame_count = 0
+             except Exception as e:
+                 print(e)
+                 break
+
+         # The last frames
+         if frame_count != 0:
+             transform_and_write(frames, frame_count, video_writer)
+
+         if temp_file:
+             # move to the output path
+             shutil.move(temp_file, output_path)
+
+         print(f'Animation video saved to {output_path}')
+         video_writer.close()
+
+     def preprocess_images(self, images):
+         '''
+         Preprocess images for inference.
+
+         @Arguments:
+             - images: np.ndarray
+
+         @Returns:
+             - images: torch.tensor
+         '''
+         images = images.astype(np.float32)
+
+         # Normalize to [-1, 1]
+         images = normalize_input(images)
+         images = torch.from_numpy(images)
+
+         images = images.to(self.device)
+
+         # Add batch dim
+         if len(images.shape) == 3:
+             images = images.unsqueeze(0)
+
+         # Channel first
+         images = images.permute(0, 3, 1, 2)
+
+         return images
+
+
+ def parse_args():
+     import argparse
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         '--weight',
+         type=str,
+         default="hayao:v2",
+         help=f'Model weight, can be a path or one of the pretrained {tuple(RELEASED_WEIGHTS.keys())}'
+     )
+     parser.add_argument('--src', type=str, help='Source, can be a directory containing images, an image file or a video file.')
+     parser.add_argument('--device', type=str, default='cuda', help='Device, cuda or cpu')
+     parser.add_argument('--imgsz', type=int, default=None, help='Resize image to the specified size if provided')
+     parser.add_argument('--out', type=str, default='inference_images', help='Output, can be a directory or a file')
+     parser.add_argument(
+         '--retain-color',
+         action='store_true',
+         help='If provided, the generated image will retain the original colors of the input image')
+     # Video params
+     parser.add_argument('--batch-size', type=int, default=4, help='Batch size for video inference')
+     parser.add_argument('--start', type=int, default=0, help='Start time of video (seconds)')
+     parser.add_argument('--end', type=int, default=0, help='End time of video (seconds), 0 if not set')
+
+     return parser.parse_args()
+
+
+ if __name__ == '__main__':
+     args = parse_args()
+
+     predictor = Predictor(
+         args.weight,
+         args.device,
+         retain_color=args.retain_color,
+         imgsz=args.imgsz,
+     )
+
+     if not os.path.exists(args.src):
+         raise FileNotFoundError(args.src)
+
+     if is_video_file(args.src):
+         predictor.transform_video(
+             args.src,
+             args.out,
+             args.batch_size,
+             start=args.start,
+             end=args.end
+         )
+     elif os.path.isdir(args.src):
+         predictor.transform_in_dir(args.src, args.out)
+     elif os.path.isfile(args.src):
+         save_path = args.out
+         if not is_image_file(args.out):
+             os.makedirs(args.out, exist_ok=True)
+             save_path = os.path.join(args.out, os.path.basename(args.src))
+
+         if args.src.endswith('.gif'):
+             # GIF file
+             predictor.transform_gif(args.src, save_path, args.batch_size)
+         else:
+             predictor.transform_file(args.src, save_path)
+     else:
+         raise NotImplementedError(f"{args.src} is not supported")
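The __main__ block dispatches on the source type; the same entry points can be called programmatically. A sketch (the weight key and paths here are hypothetical):

from inference import Predictor

predictor = Predictor(weight="shinkai:v2", device="cpu", retain_color=True)
predictor.transform_file("photo.jpg", "photo_anime.jpg")      # single image
# predictor.transform_in_dir("photos/", "photos_anime/")      # folder of images
# predictor.transform_video("clip.mp4", "clip_anime.mp4", batch_size=4)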
losses.py ADDED
@@ -0,0 +1,248 @@
+ import torch
+ import torch.nn.functional as F
+ import torch.nn as nn
+ from models.vgg import Vgg19
+ from utils.image_processing import gram
+
+
+ def to_gray_scale(image):
+     # https://github.com/pytorch/vision/blob/main/torchvision/transforms/v2/functional/_color.py#L33
+     # Images are assumed to be in range [-1, 1]
+     image = (image + 1.0) / 2.0  # To [0, 1]
+     r, g, b = image.unbind(dim=-3)
+     l_img = r.mul(0.2989).add_(g, alpha=0.587).add_(b, alpha=0.114)
+     l_img = l_img.unsqueeze(dim=-3)
+     l_img = l_img.to(image.dtype)
+     l_img = l_img.expand(image.shape)
+     l_img = l_img / 0.5 - 1.0  # To [-1, 1]
+     return l_img
+
+
+ class ColorLoss(nn.Module):
+     def __init__(self):
+         super(ColorLoss, self).__init__()
+         self.l1 = nn.L1Loss()
+         self.huber = nn.SmoothL1Loss()
+         # self._rgb_to_yuv_kernel = torch.tensor([
+         #     [0.299, -0.14714119, 0.61497538],
+         #     [0.587, -0.28886916, -0.51496512],
+         #     [0.114, 0.43601035, -0.10001026]
+         # ]).float()
+
+         self._rgb_to_yuv_kernel = torch.tensor([
+             [0.299, 0.587, 0.114],
+             [-0.14714119, -0.28886916, 0.43601035],
+             [0.61497538, -0.51496512, -0.10001026],
+         ]).float()
+
+     def to(self, device):
+         new_self = super(ColorLoss, self).to(device)
+         new_self._rgb_to_yuv_kernel = new_self._rgb_to_yuv_kernel.to(device)
+         return new_self
+
+     def rgb_to_yuv(self, image):
+         '''
+         https://en.wikipedia.org/wiki/YUV
+
+         output: Image of shape (H, W, C) (channel last)
+         '''
+         # [-1, 1] -> [0, 1]
+         image = (image + 1.0) / 2.0
+         image = image.permute(0, 2, 3, 1)  # To channel last
+
+         yuv_img = image @ self._rgb_to_yuv_kernel.T
+
+         return yuv_img
+
+     def forward(self, image, image_g):
+         image = self.rgb_to_yuv(image)
+         image_g = self.rgb_to_yuv(image_g)
+         # After converting to YUV, both images are channel last
+         return (
+             self.l1(image[:, :, :, 0], image_g[:, :, :, 0])
+             + self.huber(image[:, :, :, 1], image_g[:, :, :, 1])
+             + self.huber(image[:, :, :, 2], image_g[:, :, :, 2])
+         )
+
+
+ class AnimeGanLoss:
+     def __init__(self, args, device, gray_adv=False):
+         if isinstance(device, str):
+             device = torch.device(device)
+
+         self.content_loss = nn.L1Loss().to(device)
+         self.gram_loss = nn.L1Loss().to(device)
+         self.color_loss = ColorLoss().to(device)
+         self.wadvg = args.wadvg
+         self.wadvd = args.wadvd
+         self.wcon = args.wcon
+         self.wgra = args.wgra
+         self.wcol = args.wcol
+         self.wtvar = args.wtvar
+         # If true, use grayscale images to calculate the adversarial loss
+         self.gray_adv = gray_adv
+         self.vgg19 = Vgg19().to(device).eval()
+         self.adv_type = args.gan_loss
+         self.bce_loss = nn.BCEWithLogitsLoss()
+
+     def compute_loss_G(self, fake_img, img, fake_logit, anime_gray):
+         '''
+         Compute the losses for the Generator.
+
+         @Args:
+             - fake_img: generated image
+             - img: real image
+             - fake_logit: output of the Discriminator given the fake image
+             - anime_gray: grayscale of the anime image
+
+         @Returns:
+             - Adversarial loss of the fake logits
+             - Content loss between real and fake features (VGG19)
+             - Gram loss between anime and fake features (VGG19)
+             - Color loss between the image and the fake image
+             - Total variation loss of the fake image
+         '''
+         fake_feat = self.vgg19(fake_img)
+         gray_feat = self.vgg19(anime_gray)
+         img_feat = self.vgg19(img)
+         # fake_gray_feat = self.vgg19(to_gray_scale(fake_img))
+
+         return [
+             # Want the generated image to be classified as real.
+             self.wadvg * self.adv_loss_g(fake_logit),
+             self.wcon * self.content_loss(img_feat, fake_feat),
+             self.wgra * self.gram_loss(gram(gray_feat), gram(fake_feat)),
+             self.wcol * self.color_loss(img, fake_img),
+             self.wtvar * self.total_variation_loss(fake_img)
+         ]
+
+     def compute_loss_D(
+         self,
+         fake_img_d,
+         real_anime_d,
+         real_anime_gray_d,
+         real_anime_smooth_gray_d=None
+     ):
+         if self.gray_adv:
+             # Treat the grayscale image as real
+             return (
+                 self.adv_loss_d_real(real_anime_gray_d)
+                 + self.adv_loss_d_fake(fake_img_d)
+                 + 0.3 * self.adv_loss_d_fake(real_anime_smooth_gray_d)
+             )
+         else:
+             return (
+                 # Classify real anime as real
+                 self.adv_loss_d_real(real_anime_d)
+                 # Classify generated as fake
+                 + self.adv_loss_d_fake(fake_img_d)
+                 # Classify real anime gray as fake
+                 # + self.adv_loss_d_fake(real_anime_gray_d)
+                 # Classify real anime smooth gray as fake
+                 # + 0.1 * self.adv_loss_d_fake(real_anime_smooth_gray_d)
+             )
+
+     def total_variation_loss(self, fake_img):
+         """
+         A smoothness loss, like the smoothness prior in MRF.
+         V(y) = || y_{n+1} - y_n ||_2
+         """
+         # Channel first -> channel last
+         fake_img = fake_img.permute(0, 2, 3, 1)
+
+         def _l2(x):
+             # sum(t ** 2) / 2
+             return torch.sum(x ** 2) / 2
+
+         dh = fake_img[:, :-1, ...] - fake_img[:, 1:, ...]
+         dw = fake_img[:, :, :-1, ...] - fake_img[:, :, 1:, ...]
+         return _l2(dh) / dh.numel() + _l2(dw) / dw.numel()
+
+     def content_loss_vgg(self, image, reconstruction):
+         feat = self.vgg19(image)
+         re_feat = self.vgg19(reconstruction)
+         feature_loss = self.content_loss(feat, re_feat)
+         content_loss = self.content_loss(image, reconstruction)
+         return feature_loss  # + 0.5 * content_loss
+
+     def adv_loss_d_real(self, pred):
+         """Push pred towards class 1 (real)"""
+         if self.adv_type == 'hinge':
+             return torch.mean(F.relu(1.0 - pred))
+
+         elif self.adv_type == 'lsgan':
+             # pred = torch.sigmoid(pred)
+             return torch.mean(torch.square(pred - 1.0))
+
+         elif self.adv_type == 'bce':
+             return self.bce_loss(pred, torch.ones_like(pred))
+
+         raise ValueError(f'Loss type {self.adv_type} is not supported')
+
+     def adv_loss_d_fake(self, pred):
+         """Push pred towards class 0 (fake)"""
+         if self.adv_type == 'hinge':
+             return torch.mean(F.relu(1.0 + pred))
+
+         elif self.adv_type == 'lsgan':
+             # pred = torch.sigmoid(pred)
+             return torch.mean(torch.square(pred))
+
+         elif self.adv_type == 'bce':
+             return self.bce_loss(pred, torch.zeros_like(pred))
+
+         raise ValueError(f'Loss type {self.adv_type} is not supported')
+
+     def adv_loss_g(self, pred):
+         """Push pred towards class 1 (real)"""
+         if self.adv_type == 'hinge':
+             return -torch.mean(pred)
+
+         elif self.adv_type == 'lsgan':
+             # pred = torch.sigmoid(pred)
+             return torch.mean(torch.square(pred - 1.0))
+
+         elif self.adv_type == 'bce':
+             return self.bce_loss(pred, torch.ones_like(pred))
+
+         raise ValueError(f'Loss type {self.adv_type} is not supported')
+
+
+ class LossSummary:
+     def __init__(self):
+         self.reset()
+
+     def reset(self):
+         self.loss_g_adv = []
+         self.loss_content = []
+         self.loss_gram = []
+         self.loss_color = []
+         self.loss_d_adv = []
+
+     def update_loss_G(self, adv, gram, color, content):
+         self.loss_g_adv.append(adv.cpu().detach().numpy())
+         self.loss_gram.append(gram.cpu().detach().numpy())
+         self.loss_color.append(color.cpu().detach().numpy())
+         self.loss_content.append(content.cpu().detach().numpy())
+
+     def update_loss_D(self, loss):
+         self.loss_d_adv.append(loss.cpu().detach().numpy())
+
+     def avg_loss_G(self):
+         return (
+             self._avg(self.loss_g_adv),
+             self._avg(self.loss_gram),
+             self._avg(self.loss_color),
+             self._avg(self.loss_content),
+         )
+
+     def avg_loss_D(self):
+         return self._avg(self.loss_d_adv)
+
+     def get_loss_description(self):
+         avg_adv, avg_gram, avg_color, avg_content = self.avg_loss_G()
+         avg_adv_d = self.avg_loss_D()
+         return f'loss G: adv {avg_adv:.2f} con {avg_content:.2f} gram {avg_gram:.2f} color {avg_color:.2f} / loss D: {avg_adv_d:.2f}'
+
+     @staticmethod
+     def _avg(losses):
+         return sum(losses) / len(losses)
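ColorLoss compares images in YUV space, with L1 on luminance and Huber on the two chrominance channels. A quick sketch on random NCHW tensors in [-1, 1]:

import torch
from losses import ColorLoss

loss_fn = ColorLoss()
real = torch.rand(2, 3, 64, 64) * 2 - 1   # pretend photo batch in [-1, 1]
fake = torch.rand(2, 3, 64, 64) * 2 - 1   # pretend generated batch
print(loss_fn(real, fake))                # scalar tensor >= 0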
models/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .anime_gan import GeneratorV1
+ from .anime_gan_v2 import GeneratorV2
+ from .anime_gan_v3 import GeneratorV3
models/anime_gan.py ADDED
@@ -0,0 +1,112 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.nn.utils import spectral_norm
+ from .conv_blocks import DownConv
+ from .conv_blocks import UpConv
+ from .conv_blocks import SeparableConv2D
+ from .conv_blocks import InvertedResBlock
+ from .conv_blocks import ConvBlock
+ from .layers import get_norm
+ from utils.common import initialize_weights
+
+
+ class GeneratorV1(nn.Module):
+     def __init__(self, dataset=''):
+         super(GeneratorV1, self).__init__()
+         self.name = f'{self.__class__.__name__}_{dataset}'
+         bias = False
+
+         self.encode_blocks = nn.Sequential(
+             ConvBlock(3, 64, bias=bias),
+             ConvBlock(64, 128, bias=bias),
+             DownConv(128, bias=bias),
+             ConvBlock(128, 128, bias=bias),
+             SeparableConv2D(128, 256, bias=bias),
+             DownConv(256, bias=bias),
+             ConvBlock(256, 256, bias=bias),
+         )
+
+         self.res_blocks = nn.Sequential(
+             InvertedResBlock(256, 256),
+             InvertedResBlock(256, 256),
+             InvertedResBlock(256, 256),
+             InvertedResBlock(256, 256),
+             InvertedResBlock(256, 256),
+             InvertedResBlock(256, 256),
+             InvertedResBlock(256, 256),
+             InvertedResBlock(256, 256),
+         )
+
+         self.decode_blocks = nn.Sequential(
+             ConvBlock(256, 128, bias=bias),
+             UpConv(128, bias=bias),
+             SeparableConv2D(128, 128, bias=bias),
+             ConvBlock(128, 128, bias=bias),
+             UpConv(128, bias=bias),
+             ConvBlock(128, 64, bias=bias),
+             ConvBlock(64, 64, bias=bias),
+             nn.Conv2d(64, 3, kernel_size=1, stride=1, padding=0, bias=bias),
+             nn.Tanh(),
+         )
+
+         initialize_weights(self)
+
+     def forward(self, x):
+         out = self.encode_blocks(x)
+         out = self.res_blocks(out)
+         img = self.decode_blocks(out)
+
+         return img
+
+
+ class Discriminator(nn.Module):
+     def __init__(
+         self,
+         dataset=None,
+         num_layers=1,
+         use_sn=False,
+         norm_type="instance",
+     ):
+         super(Discriminator, self).__init__()
+         self.name = f'discriminator_{dataset}'
+         self.bias = False
+         channels = 32
+
+         layers = [
+             nn.Conv2d(3, channels, kernel_size=3, stride=1, padding=1, bias=self.bias),
+             nn.LeakyReLU(0.2, True)
+         ]
+
+         in_channels = channels
+         for i in range(num_layers):
+             layers += [
+                 nn.Conv2d(in_channels, channels * 2, kernel_size=3, stride=2, padding=1, bias=self.bias),
+                 nn.LeakyReLU(0.2, True),
+                 nn.Conv2d(channels * 2, channels * 4, kernel_size=3, stride=1, padding=1, bias=self.bias),
+                 get_norm(norm_type, channels * 4),
+                 nn.LeakyReLU(0.2, True),
+             ]
+             in_channels = channels * 4
+             channels *= 2
+
+         channels *= 2
+         layers += [
+             nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1, bias=self.bias),
+             get_norm(norm_type, channels),
+             nn.LeakyReLU(0.2, True),
+             nn.Conv2d(channels, 1, kernel_size=3, stride=1, padding=1, bias=self.bias),
+         ]
+
+         if use_sn:
+             for i in range(len(layers)):
+                 if isinstance(layers[i], nn.Conv2d):
+                     layers[i] = spectral_norm(layers[i])
+
+         self.discriminate = nn.Sequential(*layers)
+
+         initialize_weights(self)
+
+     def forward(self, img):
+         logits = self.discriminate(img)
+         return logits
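GeneratorV1 is fully convolutional with two 2x downsamples and two 2x upsamples, so inputs whose sides are divisible by 4 come back at the same resolution, while the Discriminator emits patch logits. A shape sketch:

import torch
from models.anime_gan import GeneratorV1, Discriminator

G = GeneratorV1()
D = Discriminator(num_layers=2)
x = torch.randn(1, 3, 256, 256)
print(G(x).shape)     # torch.Size([1, 3, 256, 256]), tanh output in [-1, 1]
print(D(G(x)).shape)  # patch logits, torch.Size([1, 1, 64, 64])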
models/anime_gan_v2.py ADDED
@@ -0,0 +1,65 @@
+
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from models.conv_blocks import InvertedResBlock
+ from models.conv_blocks import ConvBlock
+ from models.conv_blocks import UpConvLNormLReLU
+ from utils.common import initialize_weights
+
+
+ class GeneratorV2(nn.Module):
+     def __init__(self, dataset=''):
+         super(GeneratorV2, self).__init__()
+         self.name = f'{self.__class__.__name__}_{dataset}'
+
+         self.conv_block1 = nn.Sequential(
+             ConvBlock(3, 32, kernel_size=7, stride=1, padding=3, norm_type="layer"),
+             ConvBlock(32, 64, kernel_size=3, stride=2, padding=(0, 1, 0, 1), norm_type="layer"),
+             ConvBlock(64, 64, kernel_size=3, stride=1, norm_type="layer"),
+         )
+
+         self.conv_block2 = nn.Sequential(
+             ConvBlock(64, 128, kernel_size=3, stride=2, padding=(0, 1, 0, 1), norm_type="layer"),
+             ConvBlock(128, 128, kernel_size=3, stride=1, norm_type="layer"),
+         )
+
+         self.res_blocks = nn.Sequential(
+             ConvBlock(128, 128, kernel_size=3, stride=1, norm_type="layer"),
+             InvertedResBlock(128, 256, expand_ratio=2, norm_type="layer"),
+             InvertedResBlock(256, 256, expand_ratio=2, norm_type="layer"),
+             InvertedResBlock(256, 256, expand_ratio=2, norm_type="layer"),
+             InvertedResBlock(256, 256, expand_ratio=2, norm_type="layer"),
+             ConvBlock(256, 128, kernel_size=3, stride=1, norm_type="layer"),
+         )
+
+         self.conv_block3 = nn.Sequential(
+             # UpConvLNormLReLU(128, 128, norm_type="layer"),
+             ConvBlock(128, 128, kernel_size=3, stride=1, norm_type="layer"),
+             ConvBlock(128, 128, kernel_size=3, stride=1, norm_type="layer"),
+         )
+
+         self.conv_block4 = nn.Sequential(
+             # UpConvLNormLReLU(128, 64, norm_type="layer"),
+             ConvBlock(128, 64, kernel_size=3, stride=1, norm_type="layer"),
+             ConvBlock(64, 64, kernel_size=3, stride=1, norm_type="layer"),
+             ConvBlock(64, 32, kernel_size=7, padding=3, stride=1, norm_type="layer"),
+         )
+
+         self.decode_blocks = nn.Sequential(
+             nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0),
+             nn.Tanh(),
+         )
+
+         initialize_weights(self)
+
+     def forward(self, x):
+         out = self.conv_block1(x)
+         out = self.conv_block2(out)
+         out = self.res_blocks(out)
+         out = F.interpolate(out, scale_factor=2, mode="bilinear")
+         out = self.conv_block3(out)
+         out = F.interpolate(out, scale_factor=2, mode="bilinear")
+         out = self.conv_block4(out)
+         img = self.decode_blocks(out)
+
+         return img
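GeneratorV2 halves the resolution twice with asymmetric reflection padding and restores it with two bilinear x2 upsamples, so even-sided inputs round-trip to the same size. A sketch:

import torch
from models.anime_gan_v2 import GeneratorV2

G = GeneratorV2()
print(G(torch.randn(1, 3, 256, 256)).shape)  # torch.Size([1, 3, 256, 256])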
models/anime_gan_v3.py ADDED
@@ -0,0 +1,14 @@
+
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.nn.utils import spectral_norm
+ from models.conv_blocks import DownConv
+ from models.conv_blocks import UpConv
+ from models.conv_blocks import SeparableConv2D
+ from models.conv_blocks import InvertedResBlock
+ from models.conv_blocks import ConvBlock
+ from utils.common import initialize_weights
+
+
+ class GeneratorV3(nn.Module):
+     pass  # Placeholder: the V3 architecture is not implemented in this commit.
models/conv_blocks.py ADDED
@@ -0,0 +1,171 @@
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from utils.common import initialize_weights
+ from .layers import LayerNorm2d, get_norm
+
+
+ class DownConv(nn.Module):
+
+     def __init__(self, channels, bias=False):
+         super(DownConv, self).__init__()
+
+         self.conv1 = SeparableConv2D(channels, channels, stride=2, bias=bias)
+         self.conv2 = SeparableConv2D(channels, channels, stride=1, bias=bias)
+
+     def forward(self, x):
+         out1 = self.conv1(x)
+         out2 = F.interpolate(x, scale_factor=0.5, mode='bilinear')
+         out2 = self.conv2(out2)
+
+         return out1 + out2
+
+
+ class UpConv(nn.Module):
+     def __init__(self, channels, bias=False):
+         super(UpConv, self).__init__()
+
+         self.conv = SeparableConv2D(channels, channels, stride=1, bias=bias)
+
+     def forward(self, x):
+         out = F.interpolate(x, scale_factor=2.0, mode='bilinear')
+         out = self.conv(out)
+         return out
+
+
+ class UpConvLNormLReLU(nn.Module):
+     """Upsample + Conv block with Layer Norm and Leaky ReLU"""
+     def __init__(self, in_channels, out_channels, norm_type="instance", bias=False):
+         super(UpConvLNormLReLU, self).__init__()
+
+         self.conv_block = ConvBlock(
+             in_channels,
+             out_channels,
+             kernel_size=3,
+             norm_type=norm_type,
+             bias=bias,
+         )
+
+     def forward(self, x):
+         out = F.interpolate(x, scale_factor=2.0, mode='bilinear')
+         out = self.conv_block(out)
+         return out
+
+
+ class SeparableConv2D(nn.Module):
+     def __init__(self, in_channels, out_channels, stride=1, bias=False):
+         super(SeparableConv2D, self).__init__()
+         self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size=3,
+                                    stride=stride, padding=1, groups=in_channels, bias=bias)
+         self.pointwise = nn.Conv2d(in_channels, out_channels,
+                                    kernel_size=1, stride=1, bias=bias)
+         self.ins_norm1 = nn.InstanceNorm2d(in_channels)
+         self.activation1 = nn.LeakyReLU(0.2, True)
+         self.ins_norm2 = nn.InstanceNorm2d(out_channels)
+         self.activation2 = nn.LeakyReLU(0.2, True)
+
+         initialize_weights(self)
+
+     def forward(self, x):
+         out = self.depthwise(x)
+         out = self.ins_norm1(out)
+         out = self.activation1(out)
+
+         out = self.pointwise(out)
+         out = self.ins_norm2(out)
+
+         return self.activation2(out)
+
+
+ class ConvBlock(nn.Module):
+     """Stack of Conv2D + Norm + LeakyReLU"""
+     def __init__(
+         self,
+         channels,
+         out_channels,
+         kernel_size=3,
+         stride=1,
+         groups=1,
+         padding=1,
+         bias=False,
+         norm_type="instance"
+     ):
+         super(ConvBlock, self).__init__()
+
+         # if kernel_size == 3 and stride == 1:
+         #     self.pad = nn.ReflectionPad2d((1, 1, 1, 1))
+         # elif kernel_size == 7 and stride == 1:
+         #     self.pad = nn.ReflectionPad2d((3, 3, 3, 3))
+         # elif stride == 2:
+         #     self.pad = nn.ReflectionPad2d((0, 1, 1, 0))
+         # else:
+         #     self.pad = None
+
+         self.pad = nn.ReflectionPad2d(padding)
+         self.conv = nn.Conv2d(
+             channels,
+             out_channels,
+             kernel_size=kernel_size,
+             stride=stride,
+             groups=groups,
+             padding=0,
+             bias=bias
+         )
+         self.ins_norm = get_norm(norm_type, out_channels)
+         self.activation = nn.LeakyReLU(0.2, True)
+
+         # initialize_weights(self)
+
+     def forward(self, x):
+         if self.pad is not None:
+             x = self.pad(x)
+         out = self.conv(x)
+         out = self.ins_norm(out)
+         out = self.activation(out)
+         return out
+
+
+ class InvertedResBlock(nn.Module):
+     def __init__(
+         self,
+         channels=256,
+         out_channels=256,
+         expand_ratio=2,
+         norm_type="instance",
+     ):
+         super(InvertedResBlock, self).__init__()
+         bottleneck_dim = round(expand_ratio * channels)
+         self.conv_block = ConvBlock(
+             channels,
+             bottleneck_dim,
+             kernel_size=1,
+             padding=0,
+             norm_type=norm_type,
+             bias=False
+         )
+         self.conv_block2 = ConvBlock(
+             bottleneck_dim,
+             bottleneck_dim,
+             groups=bottleneck_dim,
+             norm_type=norm_type,
+             bias=True
+         )
+         self.conv = nn.Conv2d(
+             bottleneck_dim,
+             out_channels,
+             kernel_size=1,
+             padding=0,
+             bias=False
+         )
+         self.norm = get_norm(norm_type, out_channels)
+
+     def forward(self, x):
+         out = self.conv_block(x)
+         out = self.conv_block2(out)
+         # out = self.activation(out)
+         out = self.conv(out)
+         out = self.norm(out)
+
+         if out.shape[1] != x.shape[1]:
+             # Only add the residual when input and output shapes match
+             return out
+         return out + x
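As the forward pass above shows, InvertedResBlock only adds its skip connection when input and output channels match. A sketch:

import torch
from models.conv_blocks import InvertedResBlock

x = torch.randn(1, 128, 32, 32)
same = InvertedResBlock(128, 128, expand_ratio=2)  # residual is added
grow = InvertedResBlock(128, 256, expand_ratio=2)  # channel change, no residual
print(same(x).shape, grow(x).shape)  # [1, 128, 32, 32] [1, 256, 32, 32]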
models/layers.py ADDED
@@ -0,0 +1,28 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from functools import partial
+
+
+ class LayerNorm2d(nn.LayerNorm):
+     """ LayerNorm for channels of 2D spatial NCHW tensors """
+     def __init__(self, num_channels, eps=1e-6, affine=True):
+         super().__init__(num_channels, eps=eps, elementwise_affine=affine)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # https://pytorch.org/vision/0.12/_modules/torchvision/models/convnext.html
+         x = x.permute(0, 2, 3, 1)
+         x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+         x = x.permute(0, 3, 1, 2)
+         return x
+
+
+ def get_norm(norm_type, channels):
+     if norm_type == "instance":
+         return nn.InstanceNorm2d(channels)
+     elif norm_type == "layer":
+         # return LayerNorm2d
+         return nn.GroupNorm(num_groups=1, num_channels=channels, affine=True)
+         # return partial(nn.GroupNorm, 1, out_ch, 1e-5, True)
+     else:
+         raise ValueError(norm_type)
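Note that get_norm("layer") returns a single-group GroupNorm, which normalizes each sample over all of (C, H, W); LayerNorm2d, kept above for reference, instead normalizes the channel vector at every spatial position. A sketch of the two producing the same shape with different statistics:

import torch
from models.layers import LayerNorm2d, get_norm

x = torch.randn(2, 8, 4, 4)
print(get_norm("layer", 8)(x).shape)   # torch.Size([2, 8, 4, 4])
print(LayerNorm2d(8)(x).shape)         # same shape, different normalization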
models/vgg.py ADDED
@@ -0,0 +1,80 @@
+ import torchvision.models as models
+ import torch.nn as nn
+ import torch
+
+
+ class Vgg19(nn.Module):
+     def __init__(self):
+         super(Vgg19, self).__init__()
+         self.vgg19 = self.get_vgg19().eval()
+         vgg_mean = torch.tensor([0.485, 0.456, 0.406]).float()
+         vgg_std = torch.tensor([0.229, 0.224, 0.225]).float()
+         self.mean = vgg_mean.view(-1, 1, 1)
+         self.std = vgg_std.view(-1, 1, 1)
+
+     def to(self, device):
+         new_self = super(Vgg19, self).to(device)
+         new_self.mean = new_self.mean.to(device)
+         new_self.std = new_self.std.to(device)
+         return new_self
+
+     def forward(self, x):
+         return self.vgg19(self.normalize_vgg(x))
+
+     @staticmethod
+     def get_vgg19(last_layer='conv4_4'):
+         vgg = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features
+         model_list = []
+
+         i = 0
+         j = 1
+         for layer in vgg.children():
+             if isinstance(layer, nn.MaxPool2d):
+                 i = 0
+                 j += 1
+
+             elif isinstance(layer, nn.Conv2d):
+                 i += 1
+
+             name = f'conv{j}_{i}'
+
+             if name == last_layer:
+                 model_list.append(layer)
+                 break
+
+             model_list.append(layer)
+
+         model = nn.Sequential(*model_list)
+         return model
+
+     def normalize_vgg(self, image):
+         '''
+         Expects input in range [-1, 1]
+         '''
+         image = (image + 1.0) / 2.0
+         return (image - self.mean) / self.std
+
+
+ if __name__ == '__main__':
+     from PIL import Image
+     import numpy as np
+     from utils.image_processing import normalize_input
+
+     image = Image.open("example/10.jpg")
+     image = image.resize((224, 224))
+     np_img = np.array(image).astype('float32')
+     np_img = normalize_input(np_img)
+
+     img = torch.from_numpy(np_img)
+     img = img.permute(2, 0, 1)
+     img = img.unsqueeze(0)
+
+     vgg = Vgg19()
+
+     feat = vgg(img)
+
+     print(feat.shape)
predict.py ADDED
@@ -0,0 +1,35 @@
+ import tempfile
+
+ import cv2
+ from cog import BasePredictor, Input, Path
+
+ from inference import Predictor as MyPredictor
+ from utils import read_image
+ from utils.image_processing import resize_image
+
+
+ class Predictor(BasePredictor):
+     def setup(self):
+         pass
+
+     def predict(
+         self,
+         image: Path = Input(description="Image"),
+         model: str = Input(
+             description="Style",
+             default='Hayao:v2',
+             choices=[
+                 'Hayao',
+                 'Shinkai',
+                 'Hayao:v2'
+             ]
+         )
+     ) -> Path:
+         # The generator version is inferred from the weight name
+         # (e.g. "Hayao:v2" -> GeneratorV2) inside inference.Predictor.
+         predictor = MyPredictor(model)
+         img = read_image(str(image))
+         anime_img = predictor.transform(resize_image(img))[0]
+         out_path = Path(tempfile.mkdtemp()) / "out.png"
+         cv2.imwrite(str(out_path), anime_img[..., ::-1])
+         return out_path
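Cog normally drives this class via `cog predict`, but the handler can also be exercised directly in Python. A sketch (assumes cog is installed and the example image exists):

from predict import Predictor

p = Predictor()
p.setup()
out = p.predict(image="example/face/girl.jpg", model="Hayao:v2")
print(out)  # path to the generated PNG in a temp directory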
train.py ADDED
@@ -0,0 +1,163 @@
+ import torch
+ import argparse
+ import os
+ from models.anime_gan import GeneratorV1
+ from models.anime_gan_v2 import GeneratorV2
+ from models.anime_gan_v3 import GeneratorV3
+ from models.anime_gan import Discriminator
+ from datasets import AnimeDataSet
+ from utils.common import load_checkpoint
+ from trainer import Trainer
+ from utils.logger import get_logger
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--real_image_dir', type=str, default='dataset/train_photo')
+     parser.add_argument('--anime_image_dir', type=str, default='dataset/Hayao')
+     parser.add_argument('--test_image_dir', type=str, default='dataset/test/HR_photo')
+     parser.add_argument('--model', type=str, default='v1', help="AnimeGAN version, can be {'v1', 'v2', 'v3'}")
+     parser.add_argument('--epochs', type=int, default=70)
+     parser.add_argument('--init_epochs', type=int, default=10)
+     parser.add_argument('--batch_size', type=int, default=8)
+     parser.add_argument('--exp_dir', type=str, default='runs', help="Experiment directory")
+     parser.add_argument('--gan_loss', type=str, default='lsgan', help='lsgan / hinge / bce')
+     parser.add_argument('--resume', action='store_true', help="Continue training from the experiment directory")
+     parser.add_argument('--resume_G_init', type=str, default='False')
+     parser.add_argument('--resume_G', type=str, default='False')
+     parser.add_argument('--resume_D', type=str, default='False')
+     parser.add_argument('--device', type=str, default='cuda')
+     parser.add_argument('--use_sn', action='store_true')
+     parser.add_argument('--cache', action='store_true', help="Turn on disk cache")
+     parser.add_argument('--amp', action='store_true', help="Turn on Automatic Mixed Precision")
+     parser.add_argument('--save_interval', type=int, default=1)
+     parser.add_argument('--debug_samples', type=int, default=0)
+     parser.add_argument('--num_workers', type=int, default=2)
+     parser.add_argument('--imgsz', type=int, nargs="+", default=[256],
+                         help="Image sizes; multiple values may be provided, the image size increases after a proportion of the epochs")
+     parser.add_argument('--resize_method', type=str, default="crop",
+                         help="Resize method used if the original photo is larger than imgsz")
+     # Loss stuff
+     parser.add_argument('--lr_g', type=float, default=2e-5)
+     parser.add_argument('--lr_d', type=float, default=4e-5)
+     parser.add_argument('--init_lr', type=float, default=1e-4)
+     parser.add_argument('--wadvg', type=float, default=300.0, help='Adversarial loss weight for G')
+     parser.add_argument('--wadvd', type=float, default=300.0, help='Adversarial loss weight for D')
+     parser.add_argument(
+         '--gray_adv', action='store_true',
+         help="If given, train the adversarial loss on grayscale images instead of RGB to reduce the color influence of the anime style")
+     # Loss weights for VGG19
+     parser.add_argument('--wcon', type=float, default=1.5, help='Content loss weight')  # 1.5 for Hayao, 2.0 for Paprika, 1.2 for Shinkai
+     parser.add_argument('--wgra', type=float, default=5.0, help='Gram loss weight')  # 2.5 for Hayao, 0.6 for Paprika, 2.0 for Shinkai
+     parser.add_argument('--wcol', type=float, default=30.0, help='Color loss weight')  # 15. for Hayao, 50. for Paprika, 10. for Shinkai
+     parser.add_argument('--wtvar', type=float, default=1.0, help='Total variation loss weight')  # 1. for Hayao, 0.1 for Paprika, 1. for Shinkai
+     parser.add_argument('--d_layers', type=int, default=2, help='Discriminator conv layers')
+     parser.add_argument('--d_noise', action='store_true')
+
+     # DDP
+     parser.add_argument('--ddp', action='store_true')
+     parser.add_argument("--local-rank", default=0, type=int)
+     parser.add_argument("--world-size", default=2, type=int)
+
+     return parser.parse_args()
+
+
+ def check_params(args):
+     # dataset/Hayao + dataset/train_photo -> train_photo_Hayao
+     args.dataset = f"{os.path.basename(args.real_image_dir)}_{os.path.basename(args.anime_image_dir)}"
+     assert args.gan_loss in {'lsgan', 'hinge', 'bce'}, f'{args.gan_loss} is not supported'
+
+
+ def main(args, logger):
+     check_params(args)
+
+     if not torch.cuda.is_available():
+         logger.info("CUDA not found, use CPU")
+         # Just for debugging purposes, set to the minimum config
+         # to avoid 🔥 the computer...
+         args.device = 'cpu'
+         args.debug_samples = 10
+         args.batch_size = 2
+     else:
+         logger.info(f"Use GPU: {torch.cuda.get_device_name(0)}")
+
+     norm_type = "instance"
+     if args.model == 'v1':
+         G = GeneratorV1(args.dataset)
+     elif args.model == 'v2':
+         G = GeneratorV2(args.dataset)
+         norm_type = "layer"
+     elif args.model == 'v3':
+         G = GeneratorV3(args.dataset)
+     else:
+         raise ValueError(f"Unknown model version {args.model}")
+
+     D = Discriminator(
+         args.dataset,
+         num_layers=args.d_layers,
+         use_sn=args.use_sn,
+         norm_type=norm_type,
+     )
+
+     start_e = 0
+     start_e_init = 0
+
+     trainer = Trainer(
+         generator=G,
+         discriminator=D,
+         config=args,
+         logger=logger,
+     )
+
+     if args.resume_G_init.lower() != 'false':
+         start_e_init = load_checkpoint(G, args.resume_G_init) + 1
+         if args.local_rank == 0:
+             logger.info(f"G content weight loaded from {args.resume_G_init}")
+     elif args.resume_G.lower() != 'false' and args.resume_D.lower() != 'false':
+         # You should provide both
+         try:
+             start_e = load_checkpoint(G, args.resume_G)
+             if args.local_rank == 0:
+                 logger.info(f"G weight loaded from {args.resume_G}")
+             load_checkpoint(D, args.resume_D)
+             if args.local_rank == 0:
+                 logger.info(f"D weight loaded from {args.resume_D}")
+             # If both weights are loaded, turn off the G init phase
+             args.init_epochs = 0
+
+         except Exception as e:
+             print('Could not load checkpoint, train from scratch', e)
+     elif args.resume:
+         # Try to load from the working dir
+         logger.info(f"Loading weight from {trainer.checkpoint_path_G}")
+         start_e = load_checkpoint(G, trainer.checkpoint_path_G)
+         logger.info(f"Loading weight from {trainer.checkpoint_path_D}")
+         load_checkpoint(D, trainer.checkpoint_path_D)
+         args.init_epochs = 0
+
+     dataset = AnimeDataSet(
+         args.anime_image_dir,
+         args.real_image_dir,
+         args.debug_samples,
+         args.cache,
+         imgsz=args.imgsz,
+         resize_method=args.resize_method,
+     )
+     if args.local_rank == 0:
+         logger.info(f"Start from epoch {start_e}, {start_e_init}")
+     trainer.train(dataset, start_e, start_e_init)
+
+
+ if __name__ == '__main__':
+     args = parse_args()
+     real_name = os.path.basename(args.real_image_dir)
+     anime_name = os.path.basename(args.anime_image_dir)
+     args.exp_dir = f"{args.exp_dir}_{real_name}_{anime_name}"
+
+     os.makedirs(args.exp_dir, exist_ok=True)
+     logger = get_logger(os.path.join(args.exp_dir, "train.log"))
+
+     if args.local_rank == 0:
+         logger.info("# ==== Train Config ==== #")
+         for arg in vars(args):
+             logger.info(f"{arg} {getattr(args, arg)}")
+         logger.info("==========================")
+
+     main(args, logger)
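train.py is driven entirely by the flags above. A typical invocation against the default dataset layout (the dataset paths are hypothetical) would be: python train.py --model v2 --real_image_dir dataset/train_photo --anime_image_dir dataset/Hayao --batch_size 8 --amp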
trainer/__init__.py ADDED
@@ -0,0 +1,437 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import shutil
4
+
5
+ import torch
6
+ import cv2
7
+ import torch.optim as optim
8
+ import numpy as np
9
+ from glob import glob
10
+ from torch.cuda.amp import GradScaler, autocast
11
+ from torch.nn.parallel.distributed import DistributedDataParallel
12
+ from torch.utils.data import Dataset, DataLoader
13
+ from tqdm import tqdm
14
+ from utils.image_processing import denormalize_input, preprocess_images, resize_image
15
+ from losses import LossSummary, AnimeGanLoss, to_gray_scale
16
+ from utils import load_checkpoint, save_checkpoint, read_image
17
+ from utils.common import set_lr
18
+ from color_transfer import color_transfer_pytorch
19
+
20
+
21
+ def transfer_color_and_rescale(src, target):
22
+ """Transfer color from src image to target then rescale to [-1, 1]"""
23
+ out = color_transfer_pytorch(src, target) # [0, 1]
24
+ out = (out / 0.5) - 1
25
+ return out
26
+
+ def gaussian_noise():
+     gaussian_mean = torch.tensor(0.0)
+     gaussian_std = torch.tensor(0.1)
+     return torch.normal(gaussian_mean, gaussian_std)
+
+
+ def convert_to_readable(seconds):
+     return time.strftime('%H:%M:%S', time.gmtime(seconds))
+
+
+ def revert_to_np_image(image_tensor):
+     image = image_tensor.cpu().numpy()
+     # CHW -> HWC
+     image = image.transpose(1, 2, 0)
+     image = denormalize_input(image, dtype=np.int16)
+     return image[..., ::-1]  # RGB -> BGR for cv2.imwrite
+
+
+ def save_generated_images(images: torch.Tensor, save_dir: str):
+     """Save generated images `(*, 3, H, W)` in range [-1, 1] to disk"""
+     os.makedirs(save_dir, exist_ok=True)
+     images = images.clone().detach().cpu().numpy()
+     images = images.transpose(0, 2, 3, 1)
+     n_images = len(images)
+
+     for i in range(n_images):
+         img = images[i]
+         img = denormalize_input(img, dtype=np.int16)
+         img = img[..., ::-1]
+         cv2.imwrite(os.path.join(save_dir, f"G{i}.jpg"), img)
+
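`denormalize_input` is imported from `utils.image_processing` and is not part of this diff; judging from how it is used here, it maps `[-1, 1]` floats back to `[0, 255]` pixel values. A stand-in sketch under that assumption:

```python
import numpy as np

def denormalize_input_sketch(images, dtype=np.int16):
    # Assumed behavior: [-1, 1] float -> [0, 255] integer pixel values
    return ((images + 1.0) * 127.5).astype(dtype)
```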
+ class DDPTrainer:
+     def _init_distributed(self):
+         if self.cfg.ddp:
+             self.logger.info("Setting up DDP")
+             self.pg = torch.distributed.init_process_group(
+                 backend="nccl",
+                 rank=self.cfg.local_rank,
+                 world_size=self.cfg.world_size
+             )
+             self.G = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.G, self.pg)
+             self.D = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.D, self.pg)
+             torch.cuda.set_device(self.cfg.local_rank)
+             self.G.cuda(self.cfg.local_rank)
+             self.D.cuda(self.cfg.local_rank)
+             self.logger.info("Setting up DDP Done")
+
+     def _init_amp(self, enabled=False):
+         # self.scaler = torch.cuda.amp.GradScaler(enabled=enabled, growth_interval=100)
+         self.scaler_g = GradScaler(enabled=enabled)
+         self.scaler_d = GradScaler(enabled=enabled)
+         if self.cfg.ddp:
+             self.G = DistributedDataParallel(
+                 self.G, device_ids=[self.cfg.local_rank],
+                 output_device=self.cfg.local_rank,
+                 find_unused_parameters=False)
+
+             self.D = DistributedDataParallel(
+                 self.D, device_ids=[self.cfg.local_rank],
+                 output_device=self.cfg.local_rank,
+                 find_unused_parameters=False)
+             self.logger.info("Set DistributedDataParallel")
+
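`_init_distributed` and `_init_amp` assume one process per GPU, with `cfg.local_rank` and `cfg.world_size` supplied by a distributed launcher. A single-node sketch of that plumbing (the environment-variable names follow `torchrun` conventions and are not taken from this repo's CLI):

```python
import os
import torch

# One process per GPU; torchrun sets LOCAL_RANK / WORLD_SIZE (and
# MASTER_ADDR / MASTER_PORT) in each spawned process's environment.
local_rank = int(os.environ.get("LOCAL_RANK", 0))
world_size = int(os.environ.get("WORLD_SIZE", 1))

# On a single node the local rank doubles as the global rank.
torch.distributed.init_process_group(backend="nccl", rank=local_rank, world_size=world_size)
torch.cuda.set_device(local_rank)
```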
+ class Trainer(DDPTrainer):
+     """
+     Base Trainer class
+     """
+
+     def __init__(
+         self,
+         generator,
+         discriminator,
+         config,
+         logger,
+     ) -> None:
+         self.G = generator
+         self.D = discriminator
+         self.cfg = config
+         self.max_norm = 10
+         self.device_type = 'cuda' if self.cfg.device.startswith('cuda') else 'cpu'
+         self.optimizer_g = optim.Adam(self.G.parameters(), lr=self.cfg.lr_g, betas=(0.5, 0.999))
+         self.optimizer_d = optim.Adam(self.D.parameters(), lr=self.cfg.lr_d, betas=(0.5, 0.999))
+         self.loss_tracker = LossSummary()
+         if self.cfg.ddp:
+             self.device = torch.device(f"cuda:{self.cfg.local_rank}")
+             logger.info(f"---------{self.cfg.local_rank} {self.device}")
+         else:
+             self.device = torch.device(self.cfg.device)
+         self.loss_fn = AnimeGanLoss(self.cfg, self.device, self.cfg.gray_adv)
+         self.logger = logger
+         self._init_working_dir()
+         self._init_distributed()
+         self._init_amp(enabled=self.cfg.amp)
+
+     def _init_working_dir(self):
+         """Init working directory for saving checkpoints, generated images, ..."""
+         os.makedirs(self.cfg.exp_dir, exist_ok=True)
+         Gname = self.G.name
+         Dname = self.D.name
+         self.checkpoint_path_G_init = os.path.join(self.cfg.exp_dir, f"{Gname}_init.pt")
+         self.checkpoint_path_G = os.path.join(self.cfg.exp_dir, f"{Gname}.pt")
+         self.checkpoint_path_D = os.path.join(self.cfg.exp_dir, f"{Dname}.pt")
+         self.save_image_dir = os.path.join(self.cfg.exp_dir, "generated_images")
+         self.example_image_dir = os.path.join(self.cfg.exp_dir, "train_images")
+         os.makedirs(self.save_image_dir, exist_ok=True)
+         os.makedirs(self.example_image_dir, exist_ok=True)
+
+     def init_weight_G(self, weight: str):
+         """Init Generator weight"""
+         return load_checkpoint(self.G, weight)
+
+     def init_weight_D(self, weight: str):
+         """Init Discriminator weight"""
+         return load_checkpoint(self.D, weight)
+
+     def pretrain_generator(self, train_loader, start_epoch):
+         """
+         Pretrain Generator to reconstruct the input image.
+         """
+         init_losses = []
+         set_lr(self.optimizer_g, self.cfg.init_lr)
+         for epoch in range(start_epoch, self.cfg.init_epochs):
+             # Train with content loss only
+             pbar = tqdm(train_loader)
+             for data in pbar:
+                 img = data["image"].to(self.device)
+
+                 self.optimizer_g.zero_grad()
+
+                 with autocast(enabled=self.cfg.amp):
+                     fake_img = self.G(img)
+                     loss = self.loss_fn.content_loss_vgg(img, fake_img)
+
+                 self.scaler_g.scale(loss).backward()
+                 self.scaler_g.step(self.optimizer_g)
+                 self.scaler_g.update()
+
+                 if self.cfg.ddp:
+                     torch.distributed.barrier()
+
+                 init_losses.append(loss.cpu().detach().numpy())
+                 avg_content_loss = sum(init_losses) / len(init_losses)
+                 pbar.set_description(f'[Init Training G] content loss: {avg_content_loss:.2f}')
+
+             save_checkpoint(self.G, self.checkpoint_path_G_init, self.optimizer_g, epoch)
+             if self.cfg.local_rank == 0:
+                 self.generate_and_save(self.cfg.test_image_dir, subname='initg')
+                 self.logger.info(f"Epoch {epoch}/{self.cfg.init_epochs}")
+
+         set_lr(self.optimizer_g, self.cfg.lr_g)
+
+     def train_epoch(self, epoch, train_loader):
+         pbar = tqdm(train_loader, total=len(train_loader))
+         for data in pbar:
+             img = data["image"].to(self.device)
+             anime = data["anime"].to(self.device)
+             anime_gray = data["anime_gray"].to(self.device)
+             anime_smt_gray = data["smooth_gray"].to(self.device)
+
+             # ---------------- TRAIN D ---------------- #
+             self.optimizer_d.zero_grad()
+
+             with autocast(enabled=self.cfg.amp):
+                 fake_img = self.G(img)
+                 # Add some Gaussian noise to images before feeding to D
+                 if self.cfg.d_noise:
+                     fake_img += gaussian_noise()
+                     anime += gaussian_noise()
+                     anime_gray += gaussian_noise()
+                     anime_smt_gray += gaussian_noise()
+
+                 if self.cfg.gray_adv:
+                     fake_img = to_gray_scale(fake_img)
+
+                 fake_d = self.D(fake_img)
+                 real_anime_d = self.D(anime)
+                 real_anime_gray_d = self.D(anime_gray)
+                 real_anime_smt_gray_d = self.D(anime_smt_gray)
+
+                 loss_d = self.loss_fn.compute_loss_D(
+                     fake_d,
+                     real_anime_d,
+                     real_anime_gray_d,
+                     real_anime_smt_gray_d
+                 )
+
+             self.scaler_d.scale(loss_d).backward()
+             self.scaler_d.unscale_(self.optimizer_d)
+             torch.nn.utils.clip_grad_norm_(self.D.parameters(), max_norm=self.max_norm)
+             self.scaler_d.step(self.optimizer_d)
+             self.scaler_d.update()
+             if self.cfg.ddp:
+                 torch.distributed.barrier()
+             self.loss_tracker.update_loss_D(loss_d)
+
+             # ---------------- TRAIN G ---------------- #
+             self.optimizer_g.zero_grad()
+
+             with autocast(enabled=self.cfg.amp):
+                 fake_img = self.G(img)
+
+                 if self.cfg.gray_adv:
+                     fake_d = self.D(to_gray_scale(fake_img))
+                 else:
+                     fake_d = self.D(fake_img)
+
+                 (
+                     adv_loss, con_loss,
+                     gra_loss, col_loss,
+                     tv_loss
+                 ) = self.loss_fn.compute_loss_G(
+                     fake_img,
+                     img,
+                     fake_d,
+                     anime_gray,
+                 )
+                 loss_g = adv_loss + con_loss + gra_loss + col_loss + tv_loss
+                 if torch.isnan(adv_loss).any():
+                     self.logger.info("----------------------------------------------")
+                     self.logger.info(fake_d)
+                     self.logger.info(adv_loss)
+                     self.logger.info("----------------------------------------------")
+                     raise ValueError("NaN loss!!")
+
+             self.scaler_g.scale(loss_g).backward()
+             self.scaler_g.unscale_(self.optimizer_g)
+             grad = torch.nn.utils.clip_grad_norm_(self.G.parameters(), max_norm=self.max_norm)
+             self.scaler_g.step(self.optimizer_g)
+             self.scaler_g.update()
+             if self.cfg.ddp:
+                 torch.distributed.barrier()
+
+             self.loss_tracker.update_loss_G(adv_loss, gra_loss, col_loss, con_loss)
+             pbar.set_description(f"{self.loss_tracker.get_loss_description()} - {grad:.3f}")
+
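Both the D and G updates above follow the standard `torch.cuda.amp` recipe for gradient clipping: scale the loss, backward, unscale the optimizer's gradients, clip, then step/update so steps with inf/NaN gradients are skipped. The same pattern in isolation (toy model and optimizer are placeholders; requires a CUDA device):

```python
import torch
from torch.cuda.amp import GradScaler, autocast

model = torch.nn.Linear(8, 8).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scaler = GradScaler(enabled=True)

optimizer.zero_grad()
with autocast(enabled=True):
    loss = model(torch.randn(4, 8, device="cuda")).square().mean()

scaler.scale(loss).backward()
scaler.unscale_(optimizer)  # gradients return to their true scale ...
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)  # ... so the clip threshold is meaningful
scaler.step(optimizer)      # skipped automatically if grads contain inf/NaN
scaler.update()
```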
+     def get_train_loader(self, dataset):
+         if self.cfg.ddp:
+             train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
+         else:
+             train_sampler = None
+         return DataLoader(
+             dataset,
+             batch_size=self.cfg.batch_size,
+             num_workers=self.cfg.num_workers,
+             pin_memory=True,
+             shuffle=train_sampler is None,
+             sampler=train_sampler,
+             drop_last=True,
+             # collate_fn=collate_fn,
+         )
+
+     def maybe_increase_imgsz(self, epoch, train_dataset):
+         """
+         Increase image size at specific epochs:
+         + the first 50% of epochs train at imgsz[0]
+         + the remaining 50% step up every `epochs / 2 / (len(imgsz) - 1)` epochs
+
+         Args:
+             epoch: Current epoch
+             train_dataset: Dataset
+
+         Examples:
+             ```
+             epochs = 100
+             imgsz = [256, 352, 416, 512]
+             => [(0, 256), (50, 352), (66, 416), (82, 512)]
+             ```
+         """
+         epochs = self.cfg.epochs
+         imgsz = self.cfg.imgsz
+         num_size_remains = len(imgsz) - 1
+         half_epochs = epochs // 2
+
+         if len(imgsz) == 1:
+             new_size = imgsz[0]
+         elif epoch < half_epochs:
+             new_size = imgsz[0]
+         else:
+             per_epoch_increment = int(half_epochs / num_size_remains)
+             found = None
+             for i, size in enumerate(imgsz):
+                 if epoch < half_epochs + per_epoch_increment * i:
+                     found = size
+                     break
+             if not found:
+                 found = imgsz[-1]
+             new_size = found
+
+         self.logger.info(f"Check {imgsz}, {new_size}, {train_dataset.imgsz}")
+         if new_size != train_dataset.imgsz:
+             train_dataset.set_imgsz(new_size)
+             self.logger.info(f"Increase image size to {new_size} at epoch {epoch}")
+
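The branch logic in `maybe_increase_imgsz` is easier to see as a precomputed schedule; this standalone sketch reproduces the docstring example with the same arithmetic:

```python
def imgsz_schedule(epochs, imgsz):
    """Return [(first_epoch, size), ...] as produced by maybe_increase_imgsz."""
    if len(imgsz) == 1:
        return [(0, imgsz[0])]
    half = epochs // 2
    step = int(half / (len(imgsz) - 1))
    # imgsz[0] covers the first half; each later size starts `step` epochs apart
    return [(0, imgsz[0])] + [(half + step * i, imgsz[i + 1]) for i in range(len(imgsz) - 1)]

print(imgsz_schedule(100, [256, 352, 416, 512]))
# [(0, 256), (50, 352), (66, 416), (82, 512)]
```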
+     def train(self, train_dataset: Dataset, start_epoch=0, start_epoch_g=0):
+         """
+         Train Generator and Discriminator.
+         """
+         self.logger.info(self.device)
+         self.G.to(self.device)
+         self.D.to(self.device)
+
+         self.pretrain_generator(self.get_train_loader(train_dataset), start_epoch_g)
+
+         if self.cfg.local_rank == 0:
+             self.logger.info(f"Start training for {self.cfg.epochs} epochs")
+
+         # Dump a few training samples to disk for visual inspection
+         for i, data in enumerate(train_dataset):
+             for k in data.keys():
+                 image = data[k]
+                 cv2.imwrite(
+                     os.path.join(self.example_image_dir, f"data_{k}_{i}.jpg"),
+                     revert_to_np_image(image)
+                 )
+             if i == 2:
+                 break
+
+         end = None
+         num_iter = 0
+         per_epoch_times = []
+         for epoch in range(start_epoch, self.cfg.epochs):
+             self.maybe_increase_imgsz(epoch, train_dataset)
+
+             start = time.time()
+             self.train_epoch(epoch, self.get_train_loader(train_dataset))
+
+             if epoch % self.cfg.save_interval == 0 and self.cfg.local_rank == 0:
+                 save_checkpoint(self.G, self.checkpoint_path_G, self.optimizer_g, epoch)
+                 save_checkpoint(self.D, self.checkpoint_path_D, self.optimizer_d, epoch)
+                 self.generate_and_save(self.cfg.test_image_dir)
+
+                 if epoch % 10 == 0:
+                     self.copy_results(epoch)
+
+             num_iter += 1
+
+             if self.cfg.local_rank == 0:
+                 end = time.time()
+             if end is None:
+                 eta = 9999
+             else:
+                 per_epoch_time = end - start
+                 per_epoch_times.append(per_epoch_time)
+                 eta = np.mean(per_epoch_times) * (self.cfg.epochs - epoch)
+                 eta = convert_to_readable(eta)
+             self.logger.info(f"epoch {epoch}/{self.cfg.epochs}, ETA: {eta}")
+
+     def generate_and_save(
+         self,
+         image_dir,
+         max_imgs=15,
+         subname='gen'
+     ):
+         '''
+         Generate and save images
+         '''
+         start = time.time()
+         self.G.eval()
+
+         max_iter = max_imgs
+         fake_imgs = []
+         real_imgs = []
+         image_files = glob(os.path.join(image_dir, "*"))
+
+         for i, image_file in enumerate(image_files):
+             image = read_image(image_file)
+             image = resize_image(image)
+             real_imgs.append(image.copy())
+             image = preprocess_images(image)
+             image = image.to(self.device)
+             with torch.no_grad():
+                 with autocast(enabled=self.cfg.amp):
+                     fake_img = self.G(image)
+                 # fake_img = to_gray_scale(fake_img)
+                 fake_img = fake_img.detach().cpu().numpy()
+                 # Channel first -> channel last
+                 fake_img = fake_img.transpose(0, 2, 3, 1)
+                 fake_imgs.append(denormalize_input(fake_img, dtype=np.int16)[0])
+
+             if i + 1 == max_iter:
+                 break
+
+         # fake_imgs = np.concatenate(fake_imgs, axis=0)
+
+         for i, (real_img, fake_img) in enumerate(zip(real_imgs, fake_imgs)):
+             img = np.concatenate((real_img, fake_img), axis=1)  # Concatenate across width
+             save_path = os.path.join(self.save_image_dir, f'{subname}_{i}.jpg')
+             if not cv2.imwrite(save_path, img[..., ::-1]):
+                 self.logger.info(f"Failed to save generated image, {save_path}, {img.shape}")
+         elapsed = time.time() - start
+         self.logger.info(f"Generated {len(fake_imgs)} images in {elapsed:.3f}s.")
+
+     def copy_results(self, epoch):
+         """Copy results (weights + generated images) into a per-epoch folder.
+         Called every N epochs.
+         """
+         copy_dir = os.path.join(self.cfg.exp_dir, f"epoch_{epoch}")
+         os.makedirs(copy_dir, exist_ok=True)
+
+         shutil.copy2(
+             self.checkpoint_path_G,
+             copy_dir
+         )
+
+         dest = os.path.join(copy_dir, os.path.basename(self.save_image_dir))
+         shutil.copytree(
+             self.save_image_dir,
+             dest,
+             dirs_exist_ok=True
+         )
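For reference, wiring this Trainer up on a single GPU only needs a config object exposing the fields read in this file; a minimal sketch with illustrative values (the real config is built by train.py's `parse_args()`, and the model/dataset/logger objects come from the rest of the repo):

```python
from types import SimpleNamespace

# Illustrative values only; the field names match the cfg attributes used above.
cfg = SimpleNamespace(
    device="cuda:0", ddp=False, amp=True, local_rank=0,
    lr_g=2e-4, lr_d=4e-4, init_lr=1e-4, init_epochs=5,
    epochs=100, imgsz=[256], batch_size=8, num_workers=4,
    save_interval=1, exp_dir="runs/example",
    test_image_dir="example/face", gray_adv=False, d_noise=False,
)

# trainer = Trainer(generator, discriminator, cfg, logger)
# trainer.train(train_dataset)
```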