This view is limited to 50 files because it contains too many changes; see the raw diff for the full change set.
- .gitignore +4 -0
- README.md +1 -0
- app.py +95 -0
- example/arcane/anne.jpg +0 -0
- example/arcane/boy2.jpg +0 -0
- example/arcane/cap.jpg +0 -0
- example/arcane/dune2.jpg +0 -0
- example/arcane/elon.jpg +0 -0
- example/arcane/girl.jpg +0 -0
- example/arcane/girl4.jpg +0 -0
- example/arcane/girl6.jpg +0 -0
- example/arcane/leo.jpg +0 -0
- example/arcane/man2.jpg +0 -0
- example/arcane/nat_.jpg +0 -0
- example/arcane/seydoux.jpg +0 -0
- example/arcane/tobey.jpg +0 -0
- example/face/anne.jpg +0 -0
- example/face/boy2.jpg +0 -0
- example/face/cap.jpg +0 -0
- example/face/dune2.jpg +0 -0
- example/face/elon.jpg +0 -0
- example/face/girl.jpg +0 -0
- example/face/girl4.jpg +0 -0
- example/face/girl6.jpg +0 -0
- example/face/leo.jpg +0 -0
- example/face/man2.jpg +0 -0
- example/face/nat_.jpg +0 -0
- example/face/seydoux.jpg +0 -0
- example/face/tobey.jpg +0 -0
- example/generate_examples.py +49 -0
- example/more/hayao_v2/pexels-arnie-chou-304906-1004122.jpg +0 -0
- example/more/hayao_v2/pexels-camilacarneiro-6318793.jpg +0 -0
- example/more/hayao_v2/pexels-haohd-19859127.jpg +0 -0
- example/more/hayao_v2/pexels-huy-nguyen-748440234-19838813.jpg +0 -0
- example/more/hayao_v2/pexels-huy-phan-316220-1422386.jpg +0 -0
- example/more/hayao_v2/pexels-jimmy-teoh-294331-951531.jpg +0 -0
- example/more/hayao_v2/pexels-nandhukumar-450441.jpg +0 -0
- example/more/hayao_v2/pexels-sevenstormphotography-575362.jpg +0 -0
- inference.py +410 -0
- losses.py +248 -0
- models/__init__.py +3 -0
- models/anime_gan.py +112 -0
- models/anime_gan_v2.py +65 -0
- models/anime_gan_v3.py +14 -0
- models/conv_blocks.py +171 -0
- models/layers.py +28 -0
- models/vgg.py +80 -0
- predict.py +35 -0
- train.py +163 -0
- trainer/__init__.py +437 -0
.gitignore
ADDED
@@ -0,0 +1,4 @@
+.cache
+__pycache__
+output
+.token
README.md
CHANGED
@@ -11,3 +11,4 @@ license: mit
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+<!-- https://huggingface.co/spaces/ptran1203/pytorchAnimeGAN -->
app.py
ADDED
@@ -0,0 +1,95 @@
+import os
+import cv2
+import numpy as np
+import gradio as gr
+from inference import Predictor
+from utils.image_processing import resize_image
+
+os.makedirs('output', exist_ok=True)
+
+
+def inference(
+    image: np.ndarray,
+    style,
+    imgsz=None,
+):
+    retain_color = False
+
+    weight = {
+        "AnimeGAN_Hayao": "hayao",
+        "AnimeGAN_Shinkai": "shinkai",
+        "AnimeGANv2_Hayao": "hayao:v2",
+        "AnimeGANv2_Shinkai": "shinkai:v2",
+        "AnimeGANv2_Arcane": "arcane:v2",
+    }[style]
+    predictor = Predictor(
+        weight,
+        device='cpu',
+        retain_color=retain_color,
+        imgsz=imgsz,
+    )
+
+    save_path = f"output/out.jpg"
+    image = resize_image(image, width=imgsz)
+    anime_image = predictor.transform(image)[0]
+    cv2.imwrite(save_path, anime_image[..., ::-1])
+    return anime_image, save_path
+
+
+title = "AnimeGANv2: To produce your own animation."
+description = r"""Turn your photo into anime style 😊"""
+article = r"""
+[![GitHub Stars](https://img.shields.io/github/stars/ptran1203/pytorch-animeGAN?style=social)](https://github.com/ptran1203/pytorch-animeGAN)
+### 🗻 Demo
+
+"""
+
+gr.Interface(
+    fn=inference,
+    inputs=[
+        gr.components.Image(label="Input"),
+        gr.Dropdown(
+            [
+                'AnimeGAN_Hayao',
+                'AnimeGAN_Shinkai',
+                'AnimeGANv2_Hayao',
+                'AnimeGANv2_Shinkai',
+                'AnimeGANv2_Arcane',
+            ],
+            type="value",
+            value='AnimeGANv2_Hayao',
+            label='Style'
+        ),
+        gr.Dropdown(
+            [
+                None,
+                416,
+                512,
+                768,
+                1024,
+                1536,
+            ],
+            type="value",
+            value=None,
+            label='Image size'
+        )
+    ],
+    outputs=[
+        gr.components.Image(type="numpy", label="Output (The whole image)"),
+        gr.components.File(label="Download the output image")
+    ],
+    title=title,
+    description=description,
+    article=article,
+    allow_flagging="never",
+    examples=[
+        ['example/arcane/girl4.jpg', 'AnimeGANv2_Arcane', "Yes"],
+        ['example/arcane/leo.jpg', 'AnimeGANv2_Arcane', "Yes"],
+        ['example/arcane/girl.jpg', 'AnimeGANv2_Arcane', "Yes"],
+        ['example/arcane/anne.jpg', 'AnimeGANv2_Arcane', "Yes"],
+        # ['example/boy2.jpg', 'AnimeGANv3_Arcane', "No"],
+        # ['example/cap.jpg', 'AnimeGANv3_Arcane', "No"],
+        ['example/more/hayao_v2/pexels-camilacarneiro-6318793.jpg', 'AnimeGANv2_Hayao', "Yes"],
+        ['example/more/hayao_v2/pexels-nandhukumar-450441.jpg', 'AnimeGANv2_Hayao', "Yes"],
+    ]
+).launch()
example/arcane/anne.jpg
ADDED
example/arcane/boy2.jpg
ADDED
example/arcane/cap.jpg
ADDED
example/arcane/dune2.jpg
ADDED
example/arcane/elon.jpg
ADDED
example/arcane/girl.jpg
ADDED
example/arcane/girl4.jpg
ADDED
example/arcane/girl6.jpg
ADDED
example/arcane/leo.jpg
ADDED
example/arcane/man2.jpg
ADDED
example/arcane/nat_.jpg
ADDED
example/arcane/seydoux.jpg
ADDED
example/arcane/tobey.jpg
ADDED
example/face/anne.jpg
ADDED
example/face/boy2.jpg
ADDED
example/face/cap.jpg
ADDED
example/face/dune2.jpg
ADDED
example/face/elon.jpg
ADDED
example/face/girl.jpg
ADDED
example/face/girl4.jpg
ADDED
example/face/girl6.jpg
ADDED
example/face/leo.jpg
ADDED
example/face/man2.jpg
ADDED
example/face/nat_.jpg
ADDED
example/face/seydoux.jpg
ADDED
example/face/tobey.jpg
ADDED
example/generate_examples.py
ADDED
@@ -0,0 +1,49 @@
+import os
+import cv2
+import re
+
+REG = re.compile(r"[0-9]{3}")
+dir_ = './example/result'
+readme = './README.md'
+
+
+def anime_2_input(fi):
+    return fi.replace("_anime", "")
+
+def rename(f):
+    return f.replace(" ", "").replace("(", "").replace(")", "")
+
+def rename_back(f):
+    nums = REG.search(f)
+    if nums:
+        nums = nums.group()
+        return f.replace(nums, f"{nums[0]} ({nums[1:]})")
+
+    return f.replace('jpeg', 'jpg')
+
+def copyfile(src, dest):
+    # copy and resize
+    im = cv2.imread(src)
+
+    if im is None:
+        raise FileNotFoundError(src)
+
+    h, w = im.shape[1], im.shape[0]
+
+    s = 448
+    size = (s, round(s * w / h))
+    im = cv2.resize(im, size)
+
+    print(w, h, im.shape)
+    cv2.imwrite(dest, im)
+
+files = os.listdir(dir_)
+new_files = []
+for f in files:
+    input_ver = os.path.join(dir_, anime_2_input(f))
+    copyfile(f"dataset/test/HR_photo/{rename_back(anime_2_input(f))}", rename(input_ver))
+
+    os.rename(
+        os.path.join(dir_, f),
+        os.path.join(dir_, rename(f))
+    )
example/more/hayao_v2/pexels-arnie-chou-304906-1004122.jpg
ADDED
example/more/hayao_v2/pexels-camilacarneiro-6318793.jpg
ADDED
example/more/hayao_v2/pexels-haohd-19859127.jpg
ADDED
example/more/hayao_v2/pexels-huy-nguyen-748440234-19838813.jpg
ADDED
example/more/hayao_v2/pexels-huy-phan-316220-1422386.jpg
ADDED
example/more/hayao_v2/pexels-jimmy-teoh-294331-951531.jpg
ADDED
example/more/hayao_v2/pexels-nandhukumar-450441.jpg
ADDED
example/more/hayao_v2/pexels-sevenstormphotography-575362.jpg
ADDED
inference.py
ADDED
@@ -0,0 +1,410 @@
+import os
+import time
+import shutil
+
+import torch
+import cv2
+import numpy as np
+
+from models.anime_gan import GeneratorV1
+from models.anime_gan_v2 import GeneratorV2
+from models.anime_gan_v3 import GeneratorV3
+from utils.common import load_checkpoint, RELEASED_WEIGHTS
+from utils.image_processing import resize_image, normalize_input, denormalize_input
+from utils import read_image, is_image_file, is_video_file
+from tqdm import tqdm
+from color_transfer import color_transfer_pytorch
+
+
+try:
+    import matplotlib.pyplot as plt
+except ImportError:
+    plt = None
+
+try:
+    import moviepy.video.io.ffmpeg_writer as ffmpeg_writer
+    from moviepy.video.io.VideoFileClip import VideoFileClip
+except ImportError:
+    ffmpeg_writer = None
+    VideoFileClip = None
+
+
+def profile(func):
+    def wrap(*args, **kwargs):
+        started_at = time.time()
+        result = func(*args, **kwargs)
+        elapsed = time.time() - started_at
+        print(f"Processed in {elapsed:.3f}s")
+        return result
+    return wrap
+
+
+def auto_load_weight(weight, version=None, map_location=None):
+    """Auto load Generator version from weight."""
+    weight_name = os.path.basename(weight).lower()
+    if version is not None:
+        version = version.lower()
+        assert version in {"v1", "v2", "v3"}, f"Version {version} does not exist"
+        # If version is provided, use it.
+        cls = {
+            "v1": GeneratorV1,
+            "v2": GeneratorV2,
+            "v3": GeneratorV3
+        }[version]
+    else:
+        # Try to get class by name of weight file
+        # For convenenice, weight should start with classname
+        # e.g: Generatorv2_{anything}.pt
+        if weight_name in RELEASED_WEIGHTS:
+            version = RELEASED_WEIGHTS[weight_name][0]
+            return auto_load_weight(weight, version=version, map_location=map_location)
+
+        elif weight_name.startswith("generatorv2"):
+            cls = GeneratorV2
+        elif weight_name.startswith("generatorv3"):
+            cls = GeneratorV3
+        elif weight_name.startswith("generator"):
+            cls = GeneratorV1
+        else:
+            raise ValueError((f"Can not get Model from {weight_name}, "
+                              "you might need to explicitly specify version"))
+    model = cls()
+    load_checkpoint(model, weight, strip_optimizer=True, map_location=map_location)
+    model.eval()
+    return model
+
+
+class Predictor:
+    """
+    Generic class for transfering Image to anime like image.
+    """
+    def __init__(
+        self,
+        weight='hayao',
+        device='cuda',
+        amp=True,
+        retain_color=False,
+        imgsz=None,
+    ):
+        if not torch.cuda.is_available():
+            device = 'cpu'
+            # Amp not working on cpu
+            amp = False
+            print("Use CPU device")
+        else:
+            print(f"Use GPU {torch.cuda.get_device_name()}")
+
+        self.imgsz = imgsz
+        self.retain_color = retain_color
+        self.amp = amp  # Automatic Mixed Precision
+        self.device_type = 'cuda' if device.startswith('cuda') else 'cpu'
+        self.device = torch.device(device)
+        self.G = auto_load_weight(weight, map_location=device)
+        self.G.to(self.device)
+
+    def transform_and_show(
+        self,
+        image_path,
+        figsize=(18, 10),
+        save_path=None
+    ):
+        image = resize_image(read_image(image_path))
+        anime_img = self.transform(image)
+        anime_img = anime_img.astype('uint8')
+
+        fig = plt.figure(figsize=figsize)
+        fig.add_subplot(1, 2, 1)
+        # plt.title("Input")
+        plt.imshow(image)
+        plt.axis('off')
+        fig.add_subplot(1, 2, 2)
+        # plt.title("Anime style")
+        plt.imshow(anime_img[0])
+        plt.axis('off')
+        plt.tight_layout()
+        plt.show()
+        if save_path is not None:
+            plt.savefig(save_path)
+
+    def transform(self, image, denorm=True):
+        '''
+        Transform a image to animation
+
+        @Arguments:
+            - image: np.array, shape = (Batch, width, height, channels)
+
+        @Returns:
+            - anime version of image: np.array
+        '''
+        with torch.no_grad():
+            image = self.preprocess_images(image)
+            # image = image.to(self.device)
+            # with autocast(self.device_type, enabled=self.amp):
+            # print(image.dtype, self.G)
+            fake = self.G(image)
+            # Transfer color of fake image look similiar color as image
+            if self.retain_color:
+                fake = color_transfer_pytorch(fake, image)
+                fake = (fake / 0.5) - 1.0  # remap to [-1. 1]
+            fake = fake.detach().cpu().numpy()
+            # Channel last
+            fake = fake.transpose(0, 2, 3, 1)
+
+            if denorm:
+                fake = denormalize_input(fake, dtype=np.uint8)
+        return fake
+
+    def read_and_resize(self, path, max_size=1536):
+        image = read_image(path)
+        _, ext = os.path.splitext(path)
+        h, w = image.shape[:2]
+        if self.imgsz is not None:
+            image = resize_image(image, width=self.imgsz)
+        elif max(h, w) > max_size:
+            print(f"Image {os.path.basename(path)} is too big ({h}x{w}), resize to max size {max_size}")
+            image = resize_image(
+                image,
+                width=max_size if w > h else None,
+                height=max_size if w < h else None,
+            )
+            cv2.imwrite(path.replace(ext, ".jpg"), image[:,:,::-1])
+        else:
+            image = resize_image(image)
+        # image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
+        # image = np.stack([image, image, image], -1)
+        # cv2.imwrite(path.replace(ext, ".jpg"), image[:,:,::-1])
+        return image
+
+    @profile
+    def transform_file(self, file_path, save_path):
+        if not is_image_file(save_path):
+            raise ValueError(f"{save_path} is not valid")
+
+        image = self.read_and_resize(file_path)
+        anime_img = self.transform(image)[0]
+        cv2.imwrite(save_path, anime_img[..., ::-1])
+        print(f"Anime image saved to {save_path}")
+        return anime_img
+
+    @profile
+    def transform_gif(self, file_path, save_path, batch_size=4):
+        import imageio
+
+        def _preprocess_gif(img):
+            if img.shape[-1] == 4:
+                img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)
+            return resize_image(img)
+
+        images = imageio.mimread(file_path)
+        images = np.stack([
+            _preprocess_gif(img)
+            for img in images
+        ])
+
+        print(images.shape)
+
+        anime_gif = np.zeros_like(images)
+
+        for i in tqdm(range(0, len(images), batch_size)):
+            end = i + batch_size
+            anime_gif[i: end] = self.transform(
+                images[i: end]
+            )
+
+        if end < len(images) - 1:
+            # transform last frame
+            print("LAST", images[end: ].shape)
+            anime_gif[end:] = self.transform(images[end:])
+
+        print(anime_gif.shape)
+        imageio.mimsave(
+            save_path,
+            anime_gif,
+
+        )
+        print(f"Anime image saved to {save_path}")
+
+    @profile
+    def transform_in_dir(self, img_dir, dest_dir, max_images=0, img_size=(512, 512)):
+        '''
+        Read all images from img_dir, transform and write the result
+        to dest_dir
+
+        '''
+        os.makedirs(dest_dir, exist_ok=True)
+
+        files = os.listdir(img_dir)
+        files = [f for f in files if is_image_file(f)]
+        print(f'Found {len(files)} images in {img_dir}')
+
+        if max_images:
+            files = files[:max_images]
+
+        bar = tqdm(files)
+        for fname in bar:
+            path = os.path.join(img_dir, fname)
+            image = self.read_and_resize(path)
+            anime_img = self.transform(image)[0]
+            # anime_img = resize_image(anime_img, width=320)
+            ext = fname.split('.')[-1]
+            fname = fname.replace(f'.{ext}', '')
+            cv2.imwrite(os.path.join(dest_dir, f'{fname}.jpg'), anime_img[..., ::-1])
+            bar.set_description(f"{fname} {image.shape}")
+
+    def transform_video(self, input_path, output_path, batch_size=4, start=0, end=0):
+        '''
+        Transform a video to animation version
+        https://github.com/lengstrom/fast-style-transfer/blob/master/evaluate.py#L21
+        '''
+        if VideoFileClip is None:
+            raise ImportError("moviepy is not installed, please install with `pip install moviepy>=1.0.3`")
+        # Force to None
+        end = end or None
+
+        if not os.path.isfile(input_path):
+            raise FileNotFoundError(f'{input_path} does not exist')
+
+        output_dir = os.path.dirname(output_path)
+        if output_dir:
+            os.makedirs(output_dir, exist_ok=True)
+
+        is_gg_drive = '/drive/' in output_path
+        temp_file = ''
+
+        if is_gg_drive:
+            # Writing directly into google drive can be inefficient
+            temp_file = f'tmp_anime.{output_path.split(".")[-1]}'
+
+        def transform_and_write(frames, count, writer):
+            anime_images = self.transform(frames)
+            for i in range(0, count):
+                img = np.clip(anime_images[i], 0, 255)
+                writer.write_frame(img)
+
+        video_clip = VideoFileClip(input_path, audio=False)
+        if start or end:
+            video_clip = video_clip.subclip(start, end)
+
+        video_writer = ffmpeg_writer.FFMPEG_VideoWriter(
+            temp_file or output_path,
+            video_clip.size, video_clip.fps,
+            codec="libx264",
+            # preset="medium", bitrate="2000k",
+            ffmpeg_params=None)
+
+        total_frames = round(video_clip.fps * video_clip.duration)
+        print(f'Transfroming video {input_path}, {total_frames} frames, size: {video_clip.size}')
+
+        batch_shape = (batch_size, video_clip.size[1], video_clip.size[0], 3)
+        frame_count = 0
+        frames = np.zeros(batch_shape, dtype=np.float32)
+        for frame in tqdm(video_clip.iter_frames(), total=total_frames):
+            try:
+                frames[frame_count] = frame
+                frame_count += 1
+                if frame_count == batch_size:
+                    transform_and_write(frames, frame_count, video_writer)
+                    frame_count = 0
+            except Exception as e:
+                print(e)
+                break
+
+        # The last frames
+        if frame_count != 0:
+            transform_and_write(frames, frame_count, video_writer)
+
+        if temp_file:
+            # move to output path
+            shutil.move(temp_file, output_path)
+
+        print(f'Animation video saved to {output_path}')
+        video_writer.close()
+
+    def preprocess_images(self, images):
+        '''
+        Preprocess image for inference
+
+        @Arguments:
+            - images: np.ndarray
+
+        @Returns
+            - images: torch.tensor
+        '''
+        images = images.astype(np.float32)
+
+        # Normalize to [-1, 1]
+        images = normalize_input(images)
+        images = torch.from_numpy(images)
+
+        images = images.to(self.device)
+
+        # Add batch dim
+        if len(images.shape) == 3:
+            images = images.unsqueeze(0)
+
+        # channel first
+        images = images.permute(0, 3, 1, 2)
+
+        return images
+
+
+def parse_args():
+    import argparse
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--weight',
+        type=str,
+        default="hayao:v2",
+        help=f'Model weight, can be path or pretrained {tuple(RELEASED_WEIGHTS.keys())}'
+    )
+    parser.add_argument('--src', type=str, help='Source, can be directory contains images, image file or video file.')
+    parser.add_argument('--device', type=str, default='cuda', help='Device, cuda or cpu')
+    parser.add_argument('--imgsz', type=int, default=None, help='Resize image to specified size if provided')
+    parser.add_argument('--out', type=str, default='inference_images', help='Output, can be directory or file')
+    parser.add_argument(
+        '--retain-color',
+        action='store_true',
+        help='If provided the generated image will retain original color of input image')
+    # Video params
+    parser.add_argument('--batch-size', type=int, default=4, help='Batch size when inference video')
+    parser.add_argument('--start', type=int, default=0, help='Start time of video (second)')
+    parser.add_argument('--end', type=int, default=0, help='End time of video (second), 0 if not set')
+
+    return parser.parse_args()
+
+if __name__ == '__main__':
+    args = parse_args()
+
+    predictor = Predictor(
+        args.weight,
+        args.device,
+        retain_color=args.retain_color,
+        imgsz=args.imgsz,
+    )
+
+    if not os.path.exists(args.src):
+        raise FileNotFoundError(args.src)
+
+    if is_video_file(args.src):
+        predictor.transform_video(
+            args.src,
+            args.out,
+            args.batch_size,
+            start=args.start,
+            end=args.end
+        )
+    elif os.path.isdir(args.src):
+        predictor.transform_in_dir(args.src, args.out)
+    elif os.path.isfile(args.src):
+        save_path = args.out
+        if not is_image_file(args.out):
+            os.makedirs(args.out, exist_ok=True)
+            save_path = os.path.join(args.out, os.path.basename(args.src))
+
+        if args.src.endswith('.gif'):
+            # GIF file
+            predictor.transform_gif(args.src, save_path, args.batch_size)
+        else:
+            predictor.transform_file(args.src, save_path)
+    else:
+        raise NotImplementedError(f"{args.src} is not supported")
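For reference, a minimal usage sketch of the Predictor class above, outside the CLI entry point. It is not part of the diff: the weight key "hayao:v2" and the file names are placeholders, and the repo's utils package (not shown in this 50-file view) must be importable.

# Hypothetical sketch (not part of the diff): stylize one local photo.
from inference import Predictor

predictor = Predictor("hayao:v2", device="cpu", retain_color=False)
predictor.transform_file("photo.jpg", "photo_anime.jpg")  # writes the stylized image to the given path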
losses.py
ADDED
@@ -0,0 +1,248 @@
+import torch
+import torch.nn.functional as F
+import torch.nn as nn
+from models.vgg import Vgg19
+from utils.image_processing import gram
+
+
+def to_gray_scale(image):
+    # https://github.com/pytorch/vision/blob/main/torchvision/transforms/v2/functional/_color.py#L33
+    # Image are assum in range 1, -1
+    image = (image + 1.0) / 2.0  # To [0, 1]
+    r, g, b = image.unbind(dim=-3)
+    l_img = r.mul(0.2989).add_(g, alpha=0.587).add_(b, alpha=0.114)
+    l_img = l_img.unsqueeze(dim=-3)
+    l_img = l_img.to(image.dtype)
+    l_img = l_img.expand(image.shape)
+    l_img = l_img / 0.5 - 1.0  # To [-1, 1]
+    return l_img
+
+
+class ColorLoss(nn.Module):
+    def __init__(self):
+        super(ColorLoss, self).__init__()
+        self.l1 = nn.L1Loss()
+        self.huber = nn.SmoothL1Loss()
+        # self._rgb_to_yuv_kernel = torch.tensor([
+        #     [0.299, -0.14714119, 0.61497538],
+        #     [0.587, -0.28886916, -0.51496512],
+        #     [0.114, 0.43601035, -0.10001026]
+        # ]).float()
+
+        self._rgb_to_yuv_kernel = torch.tensor([
+            [0.299, 0.587, 0.114],
+            [-0.14714119, -0.28886916, 0.43601035],
+            [0.61497538, -0.51496512, -0.10001026],
+        ]).float()
+
+    def to(self, device):
+        new_self = super(ColorLoss, self).to(device)
+        new_self._rgb_to_yuv_kernel = new_self._rgb_to_yuv_kernel.to(device)
+        return new_self
+
+    def rgb_to_yuv(self, image):
+        '''
+        https://en.wikipedia.org/wiki/YUV
+
+        output: Image of shape (H, W, C) (channel last)
+        '''
+        # -1 1 -> 0 1
+        image = (image + 1.0) / 2.0
+        image = image.permute(0, 2, 3, 1)  # To channel last
+
+        yuv_img = image @ self._rgb_to_yuv_kernel.T
+
+        return yuv_img
+
+    def forward(self, image, image_g):
+        image = self.rgb_to_yuv(image)
+        image_g = self.rgb_to_yuv(image_g)
+        # After convert to yuv, both images have channel last
+        return (
+            self.l1(image[:, :, :, 0], image_g[:, :, :, 0])
+            + self.huber(image[:, :, :, 1], image_g[:, :, :, 1])
+            + self.huber(image[:, :, :, 2], image_g[:, :, :, 2])
+        )
+
+
+class AnimeGanLoss:
+    def __init__(self, args, device, gray_adv=False):
+        if isinstance(device, str):
+            device = torch.device(device)
+
+        self.content_loss = nn.L1Loss().to(device)
+        self.gram_loss = nn.L1Loss().to(device)
+        self.color_loss = ColorLoss().to(device)
+        self.wadvg = args.wadvg
+        self.wadvd = args.wadvd
+        self.wcon = args.wcon
+        self.wgra = args.wgra
+        self.wcol = args.wcol
+        self.wtvar = args.wtvar
+        # If true, use gray scale image to calculate adversarial loss
+        self.gray_adv = gray_adv
+        self.vgg19 = Vgg19().to(device).eval()
+        self.adv_type = args.gan_loss
+        self.bce_loss = nn.BCEWithLogitsLoss()
+
+    def compute_loss_G(self, fake_img, img, fake_logit, anime_gray):
+        '''
+        Compute loss for Generator
+
+        @Args:
+            - fake_img: generated image
+            - img: real image
+            - fake_logit: output of Discriminator given fake image
+            - anime_gray: grayscale of anime image
+
+        @Returns:
+            - Adversarial Loss of fake logits
+            - Content loss between real and fake features (vgg19)
+            - Gram loss between anime and fake features (Vgg19)
+            - Color loss between image and fake image
+            - Total variation loss of fake image
+        '''
+        fake_feat = self.vgg19(fake_img)
+        gray_feat = self.vgg19(anime_gray)
+        img_feat = self.vgg19(img)
+        # fake_gray_feat = self.vgg19(to_gray_scale(fake_img))
+
+        return [
+            # Want to be real image.
+            self.wadvg * self.adv_loss_g(fake_logit),
+            self.wcon * self.content_loss(img_feat, fake_feat),
+            self.wgra * self.gram_loss(gram(gray_feat), gram(fake_feat)),
+            self.wcol * self.color_loss(img, fake_img),
+            self.wtvar * self.total_variation_loss(fake_img)
+        ]
+
+    def compute_loss_D(
+        self,
+        fake_img_d,
+        real_anime_d,
+        real_anime_gray_d,
+        real_anime_smooth_gray_d=None
+    ):
+        if self.gray_adv:
+            # Treat gray scale image as real
+            return (
+                self.adv_loss_d_real(real_anime_gray_d)
+                + self.adv_loss_d_fake(fake_img_d)
+                + 0.3 * self.adv_loss_d_fake(real_anime_smooth_gray_d)
+            )
+        else:
+            return (
+                # Classify real anime as real
+                self.adv_loss_d_real(real_anime_d)
+                # Classify generated as fake
+                + self.adv_loss_d_fake(fake_img_d)
+                # Classify real anime gray as fake
+                # + self.adv_loss_d_fake(real_anime_gray_d)
+                # Classify real anime as fake
+                # + 0.1 * self.adv_loss_d_fake(real_anime_smooth_gray_d)
+            )
+
+    def total_variation_loss(self, fake_img):
+        """
+        A smooth loss in fact. Like the smooth prior in MRF.
+        V(y) = || y_{n+1} - y_n ||_2
+        """
+        # Channel first -> channel last
+        fake_img = fake_img.permute(0, 2, 3, 1)
+        def _l2(x):
+            # sum(t ** 2) / 2
+            return torch.sum(x ** 2) / 2
+
+        dh = fake_img[:, :-1, ...] - fake_img[:, 1:, ...]
+        dw = fake_img[:, :, :-1, ...] - fake_img[:, :, 1:, ...]
+        return _l2(dh) / dh.numel() + _l2(dw) / dw.numel()
+
+    def content_loss_vgg(self, image, recontruction):
+        feat = self.vgg19(image)
+        re_feat = self.vgg19(recontruction)
+        feature_loss = self.content_loss(feat, re_feat)
+        content_loss = self.content_loss(image, recontruction)
+        return feature_loss  # + 0.5 * content_loss
+
+    def adv_loss_d_real(self, pred):
+        """Push pred to class 1 (real)"""
+        if self.adv_type == 'hinge':
+            return torch.mean(F.relu(1.0 - pred))
+
+        elif self.adv_type == 'lsgan':
+            # pred = torch.sigmoid(pred)
+            return torch.mean(torch.square(pred - 1.0))
+
+        elif self.adv_type == 'bce':
+            return self.bce_loss(pred, torch.ones_like(pred))
+
+        raise ValueError(f'Do not support loss type {self.adv_type}')
+
+    def adv_loss_d_fake(self, pred):
+        """Push pred to class 0 (fake)"""
+        if self.adv_type == 'hinge':
+            return torch.mean(F.relu(1.0 + pred))
+
+        elif self.adv_type == 'lsgan':
+            # pred = torch.sigmoid(pred)
+            return torch.mean(torch.square(pred))
+
+        elif self.adv_type == 'bce':
+            return self.bce_loss(pred, torch.zeros_like(pred))
+
+        raise ValueError(f'Do not support loss type {self.adv_type}')
+
+    def adv_loss_g(self, pred):
+        """Push pred to class 1 (real)"""
+        if self.adv_type == 'hinge':
+            return -torch.mean(pred)
+
+        elif self.adv_type == 'lsgan':
+            # pred = torch.sigmoid(pred)
+            return torch.mean(torch.square(pred - 1.0))
+
+        elif self.adv_type == 'bce':
+            return self.bce_loss(pred, torch.ones_like(pred))
+
+        raise ValueError(f'Do not support loss type {self.adv_type}')
+
+
+class LossSummary:
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        self.loss_g_adv = []
+        self.loss_content = []
+        self.loss_gram = []
+        self.loss_color = []
+        self.loss_d_adv = []
+
+    def update_loss_G(self, adv, gram, color, content):
+        self.loss_g_adv.append(adv.cpu().detach().numpy())
+        self.loss_gram.append(gram.cpu().detach().numpy())
+        self.loss_color.append(color.cpu().detach().numpy())
+        self.loss_content.append(content.cpu().detach().numpy())
+
+    def update_loss_D(self, loss):
+        self.loss_d_adv.append(loss.cpu().detach().numpy())
+
+    def avg_loss_G(self):
+        return (
+            self._avg(self.loss_g_adv),
+            self._avg(self.loss_gram),
+            self._avg(self.loss_color),
+            self._avg(self.loss_content),
+        )
+
+    def avg_loss_D(self):
+        return self._avg(self.loss_d_adv)
+
+    def get_loss_description(self):
+        avg_adv, avg_gram, avg_color, avg_content = self.avg_loss_G()
+        avg_adv_d = self.avg_loss_D()
+        return f'loss G: adv {avg_adv:2f} con {avg_content:2f} gram {avg_gram:2f} color {avg_color:2f} / loss D: {avg_adv_d:2f}'
+
+    @staticmethod
+    def _avg(losses):
+        return sum(losses) / len(losses)
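A small sanity-check sketch for ColorLoss above (not part of the diff): it converts both inputs to YUV, then applies L1 on the Y channel and Smooth-L1 (Huber) on U and V. Inputs are assumed to be NCHW tensors in [-1, 1], and the repo's utils package must be importable for losses.py to load.

# Hypothetical check (not part of the diff): random NCHW batches in [-1, 1].
import torch
from losses import ColorLoss

color_loss = ColorLoss().to(torch.device("cpu"))
real = torch.rand(2, 3, 64, 64) * 2 - 1  # stand-in for real photos
fake = torch.rand(2, 3, 64, 64) * 2 - 1  # stand-in for generated images
print(color_loss(real, fake).item())     # scalar loss value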
models/__init__.py
ADDED
@@ -0,0 +1,3 @@
+from .anime_gan import GeneratorV1
+from .anime_gan_v2 import GeneratorV2
+from .anime_gan_v3 import GeneratorV3
models/anime_gan.py
ADDED
@@ -0,0 +1,112 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.utils import spectral_norm
+from .conv_blocks import DownConv
+from .conv_blocks import UpConv
+from .conv_blocks import SeparableConv2D
+from .conv_blocks import InvertedResBlock
+from .conv_blocks import ConvBlock
+from .layers import get_norm
+from utils.common import initialize_weights
+
+
+class GeneratorV1(nn.Module):
+    def __init__(self, dataset=''):
+        super(GeneratorV1, self).__init__()
+        self.name = f'{self.__class__.__name__}_{dataset}'
+        bias = False
+
+        self.encode_blocks = nn.Sequential(
+            ConvBlock(3, 64, bias=bias),
+            ConvBlock(64, 128, bias=bias),
+            DownConv(128, bias=bias),
+            ConvBlock(128, 128, bias=bias),
+            SeparableConv2D(128, 256, bias=bias),
+            DownConv(256, bias=bias),
+            ConvBlock(256, 256, bias=bias),
+        )
+
+        self.res_blocks = nn.Sequential(
+            InvertedResBlock(256, 256),
+            InvertedResBlock(256, 256),
+            InvertedResBlock(256, 256),
+            InvertedResBlock(256, 256),
+            InvertedResBlock(256, 256),
+            InvertedResBlock(256, 256),
+            InvertedResBlock(256, 256),
+            InvertedResBlock(256, 256),
+        )
+
+        self.decode_blocks = nn.Sequential(
+            ConvBlock(256, 128, bias=bias),
+            UpConv(128, bias=bias),
+            SeparableConv2D(128, 128, bias=bias),
+            ConvBlock(128, 128, bias=bias),
+            UpConv(128, bias=bias),
+            ConvBlock(128, 64, bias=bias),
+            ConvBlock(64, 64, bias=bias),
+            nn.Conv2d(64, 3, kernel_size=1, stride=1, padding=0, bias=bias),
+            nn.Tanh(),
+        )
+
+        initialize_weights(self)
+
+    def forward(self, x):
+        out = self.encode_blocks(x)
+        out = self.res_blocks(out)
+        img = self.decode_blocks(out)
+
+        return img
+
+
+class Discriminator(nn.Module):
+    def __init__(
+        self,
+        dataset=None,
+        num_layers=1,
+        use_sn=False,
+        norm_type="instance",
+    ):
+        super(Discriminator, self).__init__()
+        self.name = f'discriminator_{dataset}'
+        self.bias = False
+        channels = 32
+
+        layers = [
+            nn.Conv2d(3, channels, kernel_size=3, stride=1, padding=1, bias=self.bias),
+            nn.LeakyReLU(0.2, True)
+        ]
+
+        in_channels = channels
+        for i in range(num_layers):
+            layers += [
+                nn.Conv2d(in_channels, channels * 2, kernel_size=3, stride=2, padding=1, bias=self.bias),
+                nn.LeakyReLU(0.2, True),
+                nn.Conv2d(channels * 2, channels * 4, kernel_size=3, stride=1, padding=1, bias=self.bias),
+                get_norm(norm_type, channels * 4),
+                nn.LeakyReLU(0.2, True),
+            ]
+            in_channels = channels * 4
+            channels *= 2
+
+        channels *= 2
+        layers += [
+            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1, bias=self.bias),
+            get_norm(norm_type, channels),
+            nn.LeakyReLU(0.2, True),
+            nn.Conv2d(channels, 1, kernel_size=3, stride=1, padding=1, bias=self.bias),
+        ]
+
+        if use_sn:
+            for i in range(len(layers)):
+                if isinstance(layers[i], nn.Conv2d):
+                    layers[i] = spectral_norm(layers[i])
+
+        self.discriminate = nn.Sequential(*layers)
+
+        initialize_weights(self)
+
+    def forward(self, img):
+        logits = self.discriminate(img)
+        return logits
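A brief shape sketch (not part of the diff) for the v1 networks above: GeneratorV1 downsamples twice and upsamples twice, so the output keeps the input resolution, while the Discriminator returns a grid of patch logits. It assumes utils.common (not shown in this 50-file view) is importable; the dataset name is a placeholder.

# Hypothetical shape check (not part of the diff).
import torch
from models.anime_gan import GeneratorV1, Discriminator

G = GeneratorV1("hayao")
D = Discriminator("hayao", num_layers=2)
x = torch.rand(1, 3, 256, 256) * 2 - 1
with torch.no_grad():
    fake = G(x)       # same spatial size as the input, Tanh output in [-1, 1]
    logits = D(fake)  # patch logits, (1, 1, 64, 64) for a 256x256 input and num_layers=2
print(fake.shape, logits.shape)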
models/anime_gan_v2.py
ADDED
@@ -0,0 +1,65 @@
+
+import torch.nn as nn
+import torch.nn.functional as F
+from models.conv_blocks import InvertedResBlock
+from models.conv_blocks import ConvBlock
+from models.conv_blocks import UpConvLNormLReLU
+from utils.common import initialize_weights
+
+
+class GeneratorV2(nn.Module):
+    def __init__(self, dataset=''):
+        super(GeneratorV2, self).__init__()
+        self.name = f'{self.__class__.__name__}_{dataset}'
+
+        self.conv_block1 = nn.Sequential(
+            ConvBlock(3, 32, kernel_size=7, stride=1, padding=3, norm_type="layer"),
+            ConvBlock(32, 64, kernel_size=3, stride=2, padding=(0, 1, 0, 1), norm_type="layer"),
+            ConvBlock(64, 64, kernel_size=3, stride=1, norm_type="layer"),
+        )
+
+        self.conv_block2 = nn.Sequential(
+            ConvBlock(64, 128, kernel_size=3, stride=2, padding=(0, 1, 0, 1), norm_type="layer"),
+            ConvBlock(128, 128, kernel_size=3, stride=1, norm_type="layer"),
+        )
+
+        self.res_blocks = nn.Sequential(
+            ConvBlock(128, 128, kernel_size=3, stride=1, norm_type="layer"),
+            InvertedResBlock(128, 256, expand_ratio=2, norm_type="layer"),
+            InvertedResBlock(256, 256, expand_ratio=2, norm_type="layer"),
+            InvertedResBlock(256, 256, expand_ratio=2, norm_type="layer"),
+            InvertedResBlock(256, 256, expand_ratio=2, norm_type="layer"),
+            ConvBlock(256, 128, kernel_size=3, stride=1, norm_type="layer"),
+        )
+
+        self.conv_block3 = nn.Sequential(
+            # UpConvLNormLReLU(128, 128, norm_type="layer"),
+            ConvBlock(128, 128, kernel_size=3, stride=1, norm_type="layer"),
+            ConvBlock(128, 128, kernel_size=3, stride=1, norm_type="layer"),
+        )
+
+        self.conv_block4 = nn.Sequential(
+            # UpConvLNormLReLU(128, 64, norm_type="layer"),
+            ConvBlock(128, 64, kernel_size=3, stride=1, norm_type="layer"),
+            ConvBlock(64, 64, kernel_size=3, stride=1, norm_type="layer"),
+            ConvBlock(64, 32, kernel_size=7, padding=3, stride=1, norm_type="layer"),
+        )
+
+        self.decode_blocks = nn.Sequential(
+            nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0),
+            nn.Tanh(),
+        )
+
+        initialize_weights(self)
+
+    def forward(self, x):
+        out = self.conv_block1(x)
+        out = self.conv_block2(out)
+        out = self.res_blocks(out)
+        out = F.interpolate(out, scale_factor=2, mode="bilinear")
+        out = self.conv_block3(out)
+        out = F.interpolate(out, scale_factor=2, mode="bilinear")
+        out = self.conv_block4(out)
+        img = self.decode_blocks(out)
+
+        return img
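A similar shape sketch (not part of the diff) for GeneratorV2 above: the two stride-2 ConvBlocks shrink the feature map 4x and the two bilinear 2x upsamples in forward() restore it, so a 256x256 input comes back as a 256x256 image in [-1, 1]. The repo's utils package is again assumed to be available.

# Hypothetical shape check (not part of the diff).
import torch
from models.anime_gan_v2 import GeneratorV2

G = GeneratorV2("shinkai")
x = torch.rand(1, 3, 256, 256) * 2 - 1
with torch.no_grad():
    y = G(x)
print(y.shape)  # expected: torch.Size([1, 3, 256, 256])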
models/anime_gan_v3.py
ADDED
@@ -0,0 +1,14 @@
+
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.utils import spectral_norm
+from models.conv_blocks import DownConv
+from models.conv_blocks import UpConv
+from models.conv_blocks import SeparableConv2D
+from models.conv_blocks import InvertedResBlock
+from models.conv_blocks import ConvBlock
+from utils.common import initialize_weights
+
+
+class GeneratorV3(nn.Module):
+    pass
models/conv_blocks.py
ADDED
@@ -0,0 +1,171 @@
+import torch.nn as nn
+import torch.nn.functional as F
+from utils.common import initialize_weights
+from .layers import LayerNorm2d, get_norm
+
+
+class DownConv(nn.Module):
+
+    def __init__(self, channels, bias=False):
+        super(DownConv, self).__init__()
+
+        self.conv1 = SeparableConv2D(channels, channels, stride=2, bias=bias)
+        self.conv2 = SeparableConv2D(channels, channels, stride=1, bias=bias)
+
+    def forward(self, x):
+        out1 = self.conv1(x)
+        out2 = F.interpolate(x, scale_factor=0.5, mode='bilinear')
+        out2 = self.conv2(out2)
+
+        return out1 + out2
+
+
+class UpConv(nn.Module):
+    def __init__(self, channels, bias=False):
+        super(UpConv, self).__init__()
+
+        self.conv = SeparableConv2D(channels, channels, stride=1, bias=bias)
+
+    def forward(self, x):
+        out = F.interpolate(x, scale_factor=2.0, mode='bilinear')
+        out = self.conv(out)
+        return out
+
+
+class UpConvLNormLReLU(nn.Module):
+    """Upsample Conv block with Layer Norm and Leaky ReLU"""
+    def __init__(self, in_channels, out_channels, norm_type="instance", bias=False):
+        super(UpConvLNormLReLU, self).__init__()
+
+        self.conv_block = ConvBlock(
+            in_channels,
+            out_channels,
+            kernel_size=3,
+            norm_type=norm_type,
+            bias=bias,
+        )
+
+    def forward(self, x):
+        out = F.interpolate(x, scale_factor=2.0, mode='bilinear')
+        out = self.conv_block(out)
+        return out
+
+class SeparableConv2D(nn.Module):
+    def __init__(self, in_channels, out_channels, stride=1, bias=False):
+        super(SeparableConv2D, self).__init__()
+        self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size=3,
+                                   stride=stride, padding=1, groups=in_channels, bias=bias)
+        self.pointwise = nn.Conv2d(in_channels, out_channels,
+                                   kernel_size=1, stride=1, bias=bias)
+        # self.pad =
+        self.ins_norm1 = nn.InstanceNorm2d(in_channels)
+        self.activation1 = nn.LeakyReLU(0.2, True)
+        self.ins_norm2 = nn.InstanceNorm2d(out_channels)
+        self.activation2 = nn.LeakyReLU(0.2, True)
+
+        initialize_weights(self)
+
+    def forward(self, x):
+        out = self.depthwise(x)
+        out = self.ins_norm1(out)
+        out = self.activation1(out)
+
+        out = self.pointwise(out)
+        out = self.ins_norm2(out)
+
+        return self.activation2(out)
+
+
+class ConvBlock(nn.Module):
+    """Stack of Conv2D + Norm + LeakyReLU"""
+    def __init__(
+        self,
+        channels,
+        out_channels,
+        kernel_size=3,
+        stride=1,
+        groups=1,
+        padding=1,
+        bias=False,
+        norm_type="instance"
+    ):
+        super(ConvBlock, self).__init__()
+
+        # if kernel_size == 3 and stride == 1:
+        #     self.pad = nn.ReflectionPad2d((1, 1, 1, 1))
+        # elif kernel_size == 7 and stride == 1:
+        #     self.pad = nn.ReflectionPad2d((3, 3, 3, 3))
+        # elif stride == 2:
+        #     self.pad = nn.ReflectionPad2d((0, 1, 1, 0))
+        # else:
+        #     self.pad = None
+
+        self.pad = nn.ReflectionPad2d(padding)
+        self.conv = nn.Conv2d(
+            channels,
+            out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            groups=groups,
+            padding=0,
+            bias=bias
+        )
+        self.ins_norm = get_norm(norm_type, out_channels)
+        self.activation = nn.LeakyReLU(0.2, True)
+
+        # initialize_weights(self)
+
+    def forward(self, x):
+        if self.pad is not None:
+            x = self.pad(x)
+        out = self.conv(x)
+        out = self.ins_norm(out)
+        out = self.activation(out)
+        return out
+
+
+class InvertedResBlock(nn.Module):
+    def __init__(
+        self,
+        channels=256,
+        out_channels=256,
+        expand_ratio=2,
+        norm_type="instance",
+    ):
+        super(InvertedResBlock, self).__init__()
+        bottleneck_dim = round(expand_ratio * channels)
+        self.conv_block = ConvBlock(
+            channels,
+            bottleneck_dim,
+            kernel_size=1,
+            padding=0,
+            norm_type=norm_type,
+            bias=False
+        )
+        self.conv_block2 = ConvBlock(
+            bottleneck_dim,
+            bottleneck_dim,
+            groups=bottleneck_dim,
+            norm_type=norm_type,
+            bias=True
+        )
+        self.conv = nn.Conv2d(
+            bottleneck_dim,
+            out_channels,
+            kernel_size=1,
+            padding=0,
+            bias=False
+        )
+        self.norm = get_norm(norm_type, out_channels)
+
+    def forward(self, x):
+        out = self.conv_block(x)
+        out = self.conv_block2(out)
+        # out = self.activation(out)
+        out = self.conv(out)
+        out = self.norm(out)
+
+        if out.shape[1] != x.shape[1]:
+            # Only concate if same shape
+            return out
+        return out + x
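A small sketch (not part of the diff) exercising two of the blocks above: SeparableConv2D is a depthwise 3x3 followed by a pointwise 1x1, each with InstanceNorm and LeakyReLU, and InvertedResBlock only adds the skip connection when the output channel count matches the input. It assumes utils.common is importable.

# Hypothetical sanity check (not part of the diff).
import torch
from models.conv_blocks import SeparableConv2D, InvertedResBlock

x = torch.randn(1, 128, 64, 64)
sep = SeparableConv2D(128, 256, stride=2)
res_same = InvertedResBlock(128, 128)  # channels match -> residual added
res_grow = InvertedResBlock(128, 256)  # channels change -> no residual
print(sep(x).shape)       # torch.Size([1, 256, 32, 32])
print(res_same(x).shape)  # torch.Size([1, 128, 64, 64])
print(res_grow(x).shape)  # torch.Size([1, 256, 64, 64])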
models/layers.py
ADDED
@@ -0,0 +1,28 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from functools import partial
+
+
+class LayerNorm2d(nn.LayerNorm):
+    """ LayerNorm for channels of '2D' spatial NCHW tensors """
+    def __init__(self, num_channels, eps=1e-6, affine=True):
+        super().__init__(num_channels, eps=eps, elementwise_affine=affine)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # https://pytorch.org/vision/0.12/_modules/torchvision/models/convnext.html
+        x = x.permute(0, 2, 3, 1)
+        x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+        x = x.permute(0, 3, 1, 2)
+        return x
+
+
+def get_norm(norm_type, channels):
+    if norm_type == "instance":
+        return nn.InstanceNorm2d(channels)
+    elif norm_type == "layer":
+        # return LayerNorm2d
+        return nn.GroupNorm(num_groups=1, num_channels=channels, affine=True)
+        # return partial(nn.GroupNorm, 1, out_ch, 1e-5, True)
+    else:
+        raise ValueError(norm_type)
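One note on get_norm above (not part of the diff): for norm_type="layer" it returns nn.GroupNorm with a single group, which normalizes each sample over all channels and spatial positions, while the currently unused LayerNorm2d class normalizes over channels only at each spatial position. A tiny sketch:

# Hypothetical sketch (not part of the diff): both options are drop-in NCHW modules.
import torch
from models.layers import get_norm

x = torch.randn(2, 64, 32, 32)
print(get_norm("instance", 64)(x).shape)  # torch.Size([2, 64, 32, 32])
print(get_norm("layer", 64)(x).shape)     # torch.Size([2, 64, 32, 32])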
models/vgg.py
ADDED
@@ -0,0 +1,80 @@
+from numpy.lib.arraysetops import isin
+import torchvision.models as models
+import torch.nn as nn
+import torch
+
+
+
+class Vgg19(nn.Module):
+    def __init__(self):
+        super(Vgg19, self).__init__()
+        self.vgg19 = self.get_vgg19().eval()
+        vgg_mean = torch.tensor([0.485, 0.456, 0.406]).float()
+        vgg_std = torch.tensor([0.229, 0.224, 0.225]).float()
+        self.mean = vgg_mean.view(-1, 1, 1)
+        self.std = vgg_std.view(-1, 1, 1)
+
+    def to(self, device):
+        new_self = super(Vgg19, self).to(device)
+        new_self.mean = new_self.mean.to(device)
+        new_self.std = new_self.std.to(device)
+        return new_self
+
+    def forward(self, x):
+        return self.vgg19(self.normalize_vgg(x))
+
+    @staticmethod
+    def get_vgg19(last_layer='conv4_4'):
+        vgg = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features
+        model_list = []
+
+        i = 0
+        j = 1
+        for layer in vgg.children():
+            if isinstance(layer, nn.MaxPool2d):
+                i = 0
+                j += 1
+
+            elif isinstance(layer, nn.Conv2d):
+                i += 1
+
+            name = f'conv{j}_{i}'
+
+            if name == last_layer:
+                model_list.append(layer)
+                break
+
+            model_list.append(layer)
+
+
+        model = nn.Sequential(*model_list)
+        return model
+
+
+    def normalize_vgg(self, image):
+        '''
+        Expect input in range -1 1
+        '''
+        image = (image + 1.0) / 2.0
+        return (image - self.mean) / self.std
+
+
+if __name__ == '__main__':
+    from PIL import Image
+    import numpy as np
+    from utils.image_processing import normalize_input
+
+    image = Image.open("example/10.jpg")
+    image = image.resize((224, 224))
+    np_img = np.array(image).astype('float32')
+    np_img = normalize_input(np_img)
+
+    img = torch.from_numpy(np_img)
+    img = img.permute(2, 0, 1)
+    img = img.unsqueeze(0)
+
+    vgg = Vgg19()
+
+    feat = vgg(img)
+
+    print(feat.shape)
predict.py
ADDED
@@ -0,0 +1,35 @@
+from pathlib import Path
+from inference import Predictor as MyPredictor
+from utils import read_image
+import cv2
+import tempfile
+from utils.image_processing import resize_image, normalize_input, denormalize_input
+import numpy as np
+from cog import BasePredictor, Path, Input
+
+
+class Predictor(BasePredictor):
+    def setup(self):
+        pass
+
+    def predict(
+        self,
+        image: Path = Input(description="Image"),
+        model: str = Input(
+            description="Style",
+            default='Hayao:v2',
+            choices=[
+                'Hayao',
+                'Shinkai',
+                'Hayao:v2'
+            ]
+        )
+    ) -> Path:
+        version = model.split(":")[-1]
+        predictor = MyPredictor(model, version)
+        img = read_image(str(image))
+        anime_img = predictor.transform(resize_image(img))[0]
+        out_path = Path(tempfile.mkdtemp()) / "out.png"
+        cv2.imwrite(str(out_path), anime_img[..., ::-1])
+        return out_path
+
train.py
ADDED
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch
import argparse
import os
from models.anime_gan import GeneratorV1
from models.anime_gan_v2 import GeneratorV2
from models.anime_gan_v3 import GeneratorV3
from models.anime_gan import Discriminator
from datasets import AnimeDataSet
from utils.common import load_checkpoint
from trainer import Trainer
from utils.logger import get_logger


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--real_image_dir', type=str, default='dataset/train_photo')
    parser.add_argument('--anime_image_dir', type=str, default='dataset/Hayao')
    parser.add_argument('--test_image_dir', type=str, default='dataset/test/HR_photo')
    parser.add_argument('--model', type=str, default='v1', help="AnimeGAN version, can be {'v1', 'v2', 'v3'}")
    parser.add_argument('--epochs', type=int, default=70)
    parser.add_argument('--init_epochs', type=int, default=10)
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--exp_dir', type=str, default='runs', help="Experiment directory")
    parser.add_argument('--gan_loss', type=str, default='lsgan', help='lsgan / hinge / bce')
    parser.add_argument('--resume', action='store_true', help="Continue from current dir")
    parser.add_argument('--resume_G_init', type=str, default='False')
    parser.add_argument('--resume_G', type=str, default='False')
    parser.add_argument('--resume_D', type=str, default='False')
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--use_sn', action='store_true')
    parser.add_argument('--cache', action='store_true', help="Turn on disk cache")
    parser.add_argument('--amp', action='store_true', help="Turn on Automatic Mixed Precision")
    parser.add_argument('--save_interval', type=int, default=1)
    parser.add_argument('--debug_samples', type=int, default=0)
    parser.add_argument('--num_workers', type=int, default=2)
    parser.add_argument('--imgsz', type=int, nargs="+", default=[256],
                        help="Image sizes, can provide multiple values, image size will increase after a proportion of epochs")
    parser.add_argument('--resize_method', type=str, default="crop",
                        help="Resize image method if origin photo larger than imgsz")
    # Loss stuff
    parser.add_argument('--lr_g', type=float, default=2e-5)
    parser.add_argument('--lr_d', type=float, default=4e-5)
    parser.add_argument('--init_lr', type=float, default=1e-4)
    parser.add_argument('--wadvg', type=float, default=300.0, help='Adversarial loss weight for G')
    parser.add_argument('--wadvd', type=float, default=300.0, help='Adversarial loss weight for D')
    parser.add_argument(
        '--gray_adv', action='store_true',
        help="If given, train adversarial with gray scale image instead of RGB image to reduce color effect of anime style")
    # Loss weight VGG19
    parser.add_argument('--wcon', type=float, default=1.5, help='Content loss weight')  # 1.5 for Hayao, 2.0 for Paprika, 1.2 for Shinkai
    parser.add_argument('--wgra', type=float, default=5.0, help='Gram loss weight')  # 2.5 for Hayao, 0.6 for Paprika, 2.0 for Shinkai
    parser.add_argument('--wcol', type=float, default=30.0, help='Color loss weight')  # 15. for Hayao, 50. for Paprika, 10. for Shinkai
    parser.add_argument('--wtvar', type=float, default=1.0, help='Total variation loss')  # 1. for Hayao, 0.1 for Paprika, 1. for Shinkai
    parser.add_argument('--d_layers', type=int, default=2, help='Discriminator conv layers')
    parser.add_argument('--d_noise', action='store_true')

    # DDP
    parser.add_argument('--ddp', action='store_true')
    parser.add_argument("--local-rank", default=0, type=int)
    parser.add_argument("--world-size", default=2, type=int)

    return parser.parse_args()


def check_params(args):
    # dataset/Hayao + dataset/train_photo -> train_photo_Hayao
    args.dataset = f"{os.path.basename(args.real_image_dir)}_{os.path.basename(args.anime_image_dir)}"
    assert args.gan_loss in {'lsgan', 'hinge', 'bce'}, f'{args.gan_loss} is not supported'


def main(args, logger):
    check_params(args)

    if not torch.cuda.is_available():
        logger.info("CUDA not found, use CPU")
        # Just for debugging purpose, set to minimum config
        # to avoid 🔥 the computer...
        args.device = 'cpu'
        args.debug_samples = 10
        args.batch_size = 2
    else:
        logger.info(f"Use GPU: {torch.cuda.get_device_name(0)}")

    norm_type = "instance"
    if args.model == 'v1':
        G = GeneratorV1(args.dataset)
    elif args.model == 'v2':
        G = GeneratorV2(args.dataset)
        norm_type = "layer"
    elif args.model == 'v3':
        G = GeneratorV3(args.dataset)

    D = Discriminator(
        args.dataset,
        num_layers=args.d_layers,
        use_sn=args.use_sn,
        norm_type=norm_type,
    )

    start_e = 0
    start_e_init = 0

    trainer = Trainer(
        generator=G,
        discriminator=D,
        config=args,
        logger=logger,
    )

    if args.resume_G_init.lower() != 'false':
        start_e_init = load_checkpoint(G, args.resume_G_init) + 1
        if args.local_rank == 0:
            logger.info(f"G content weight loaded from {args.resume_G_init}")
    elif args.resume_G.lower() != 'false' and args.resume_D.lower() != 'false':
        # You should provide both
        try:
            start_e = load_checkpoint(G, args.resume_G)
            if args.local_rank == 0:
                logger.info(f"G weight loaded from {args.resume_G}")
            load_checkpoint(D, args.resume_D)
            if args.local_rank == 0:
                logger.info(f"D weight loaded from {args.resume_D}")
            # If both weights are loaded, skip the init G phase
            args.init_epochs = 0

        except Exception as e:
            print('Could not load checkpoint, train from scratch', e)
    elif args.resume:
        # Try to load from working dir
        logger.info(f"Loading weight from {trainer.checkpoint_path_G}")
        start_e = load_checkpoint(G, trainer.checkpoint_path_G)
        logger.info(f"Loading weight from {trainer.checkpoint_path_D}")
        load_checkpoint(D, trainer.checkpoint_path_D)
        args.init_epochs = 0

    dataset = AnimeDataSet(
        args.anime_image_dir,
        args.real_image_dir,
        args.debug_samples,
        args.cache,
        imgsz=args.imgsz,
        resize_method=args.resize_method,
    )
    if args.local_rank == 0:
        logger.info(f"Start from epoch {start_e}, {start_e_init}")
    trainer.train(dataset, start_e, start_e_init)


if __name__ == '__main__':
    args = parse_args()
    real_name = os.path.basename(args.real_image_dir)
    anime_name = os.path.basename(args.anime_image_dir)
    args.exp_dir = f"{args.exp_dir}_{real_name}_{anime_name}"

    os.makedirs(args.exp_dir, exist_ok=True)
    logger = get_logger(os.path.join(args.exp_dir, "train.log"))

    if args.local_rank == 0:
        logger.info("# ==== Train Config ==== #")
        for arg in vars(args):
            logger.info(f"{arg} {getattr(args, arg)}")
        logger.info("==========================")

    main(args, logger)
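Aside (not part of the diff): a small smoke test of the argument plumbing in train.py. check_params builds the dataset tag from the basenames of the two image directories, which in turn names the experiment directory and checkpoints; the Namespace below simply stands in for parse_args() with the script's default paths.

from argparse import Namespace
from train import check_params

args = Namespace(
    real_image_dir="dataset/train_photo",
    anime_image_dir="dataset/Hayao",
    gan_loss="lsgan",
)
check_params(args)
print(args.dataset)  # -> "train_photo_Hayao"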
trainer/__init__.py
ADDED
@@ -0,0 +1,437 @@
import os
import time
import shutil

import torch
import cv2
import torch.optim as optim
import numpy as np
from glob import glob
from torch.cuda.amp import GradScaler, autocast
from torch.nn.parallel.distributed import DistributedDataParallel
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from utils.image_processing import denormalize_input, preprocess_images, resize_image
from losses import LossSummary, AnimeGanLoss, to_gray_scale
from utils import load_checkpoint, save_checkpoint, read_image
from utils.common import set_lr
from color_transfer import color_transfer_pytorch


def transfer_color_and_rescale(src, target):
    """Transfer color from src image to target then rescale to [-1, 1]"""
    out = color_transfer_pytorch(src, target)  # [0, 1]
    out = (out / 0.5) - 1
    return out

def gaussian_noise():
    gaussian_mean = torch.tensor(0.0)
    gaussian_std = torch.tensor(0.1)
    return torch.normal(gaussian_mean, gaussian_std)

def convert_to_readable(seconds):
    return time.strftime('%H:%M:%S', time.gmtime(seconds))


def revert_to_np_image(image_tensor):
    image = image_tensor.cpu().numpy()
    # CHW -> HWC
    image = image.transpose(1, 2, 0)
    image = denormalize_input(image, dtype=np.int16)
    return image[..., ::-1]  # to RGB


def save_generated_images(images: torch.Tensor, save_dir: str):
    """Save generated images `(*, 3, H, W)` range [-1, 1] into disk"""
    os.makedirs(save_dir, exist_ok=True)
    images = images.clone().detach().cpu().numpy()
    images = images.transpose(0, 2, 3, 1)
    n_images = len(images)

    for i in range(n_images):
        img = images[i]
        img = denormalize_input(img, dtype=np.int16)
        img = img[..., ::-1]
        cv2.imwrite(os.path.join(save_dir, f"G{i}.jpg"), img)


class DDPTrainer:
    def _init_distributed(self):
        if self.cfg.ddp:
            self.logger.info("Setting up DDP")
            self.pg = torch.distributed.init_process_group(
                backend="nccl",
                rank=self.cfg.local_rank,
                world_size=self.cfg.world_size
            )
            self.G = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.G, self.pg)
            self.D = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.D, self.pg)
            torch.cuda.set_device(self.cfg.local_rank)
            self.G.cuda(self.cfg.local_rank)
            self.D.cuda(self.cfg.local_rank)
            self.logger.info("Setting up DDP Done")

    def _init_amp(self, enabled=False):
        # self.scaler = torch.cuda.amp.GradScaler(enabled=enabled, growth_interval=100)
        self.scaler_g = GradScaler(enabled=enabled)
        self.scaler_d = GradScaler(enabled=enabled)
        if self.cfg.ddp:
            self.G = DistributedDataParallel(
                self.G, device_ids=[self.cfg.local_rank],
                output_device=self.cfg.local_rank,
                find_unused_parameters=False)

            self.D = DistributedDataParallel(
                self.D, device_ids=[self.cfg.local_rank],
                output_device=self.cfg.local_rank,
                find_unused_parameters=False)
            self.logger.info("Set DistributedDataParallel")


class Trainer(DDPTrainer):
    """
    Base Trainer class
    """

    def __init__(
        self,
        generator,
        discriminator,
        config,
        logger,
    ) -> None:
        self.G = generator
        self.D = discriminator
        self.cfg = config
        self.max_norm = 10
        self.device_type = 'cuda' if self.cfg.device.startswith('cuda') else 'cpu'
        self.optimizer_g = optim.Adam(self.G.parameters(), lr=self.cfg.lr_g, betas=(0.5, 0.999))
        self.optimizer_d = optim.Adam(self.D.parameters(), lr=self.cfg.lr_d, betas=(0.5, 0.999))
        self.loss_tracker = LossSummary()
        if self.cfg.ddp:
            self.device = torch.device(f"cuda:{self.cfg.local_rank}")
            logger.info(f"---------{self.cfg.local_rank} {self.device}")
        else:
            self.device = torch.device(self.cfg.device)
        self.loss_fn = AnimeGanLoss(self.cfg, self.device, self.cfg.gray_adv)
        self.logger = logger
        self._init_working_dir()
        self._init_distributed()
        self._init_amp(enabled=self.cfg.amp)

    def _init_working_dir(self):
        """Init working directory for saving checkpoints, generated images, ..."""
        os.makedirs(self.cfg.exp_dir, exist_ok=True)
        Gname = self.G.name
        Dname = self.D.name
        self.checkpoint_path_G_init = os.path.join(self.cfg.exp_dir, f"{Gname}_init.pt")
        self.checkpoint_path_G = os.path.join(self.cfg.exp_dir, f"{Gname}.pt")
        self.checkpoint_path_D = os.path.join(self.cfg.exp_dir, f"{Dname}.pt")
        self.save_image_dir = os.path.join(self.cfg.exp_dir, "generated_images")
        self.example_image_dir = os.path.join(self.cfg.exp_dir, "train_images")
        os.makedirs(self.save_image_dir, exist_ok=True)
        os.makedirs(self.example_image_dir, exist_ok=True)

    def init_weight_G(self, weight: str):
        """Init Generator weight"""
        return load_checkpoint(self.G, weight)

    def init_weight_D(self, weight: str):
        """Init Discriminator weight"""
        return load_checkpoint(self.D, weight)

    def pretrain_generator(self, train_loader, start_epoch):
        """
        Pretrain Generator to reconstruct the input image.
        """
        init_losses = []
        set_lr(self.optimizer_g, self.cfg.init_lr)
        for epoch in range(start_epoch, self.cfg.init_epochs):
            # Train with content loss only

            pbar = tqdm(train_loader)
            for data in pbar:
                img = data["image"].to(self.device)

                self.optimizer_g.zero_grad()

                with autocast(enabled=self.cfg.amp):
                    fake_img = self.G(img)
                    loss = self.loss_fn.content_loss_vgg(img, fake_img)

                self.scaler_g.scale(loss).backward()
                self.scaler_g.step(self.optimizer_g)
                self.scaler_g.update()

                if self.cfg.ddp:
                    torch.distributed.barrier()

                init_losses.append(loss.cpu().detach().numpy())
                avg_content_loss = sum(init_losses) / len(init_losses)
                pbar.set_description(f'[Init Training G] content loss: {avg_content_loss:2f}')

            save_checkpoint(self.G, self.checkpoint_path_G_init, self.optimizer_g, epoch)
            if self.cfg.local_rank == 0:
                self.generate_and_save(self.cfg.test_image_dir, subname='initg')
                self.logger.info(f"Epoch {epoch}/{self.cfg.init_epochs}")

        set_lr(self.optimizer_g, self.cfg.lr_g)

    def train_epoch(self, epoch, train_loader):
        pbar = tqdm(train_loader, total=len(train_loader))
        for data in pbar:
            img = data["image"].to(self.device)
            anime = data["anime"].to(self.device)
            anime_gray = data["anime_gray"].to(self.device)
            anime_smt_gray = data["smooth_gray"].to(self.device)

            # ---------------- TRAIN D ---------------- #
            self.optimizer_d.zero_grad()

            with autocast(enabled=self.cfg.amp):
                fake_img = self.G(img)
                # Add some Gaussian noise to images before feeding to D
                if self.cfg.d_noise:
                    fake_img += gaussian_noise()
                    anime += gaussian_noise()
                    anime_gray += gaussian_noise()
                    anime_smt_gray += gaussian_noise()

                if self.cfg.gray_adv:
                    fake_img = to_gray_scale(fake_img)

                fake_d = self.D(fake_img)
                real_anime_d = self.D(anime)
                real_anime_gray_d = self.D(anime_gray)
                real_anime_smt_gray_d = self.D(anime_smt_gray)

                loss_d = self.loss_fn.compute_loss_D(
                    fake_d,
                    real_anime_d,
                    real_anime_gray_d,
                    real_anime_smt_gray_d
                )

            self.scaler_d.scale(loss_d).backward()
            self.scaler_d.unscale_(self.optimizer_d)
            torch.nn.utils.clip_grad_norm_(self.D.parameters(), max_norm=self.max_norm)
            self.scaler_d.step(self.optimizer_d)
            self.scaler_d.update()
            if self.cfg.ddp:
                torch.distributed.barrier()
            self.loss_tracker.update_loss_D(loss_d)

            # ---------------- TRAIN G ---------------- #
            self.optimizer_g.zero_grad()

            with autocast(enabled=self.cfg.amp):
                fake_img = self.G(img)

                if self.cfg.gray_adv:
                    fake_d = self.D(to_gray_scale(fake_img))
                else:
                    fake_d = self.D(fake_img)

                (
                    adv_loss, con_loss,
                    gra_loss, col_loss,
                    tv_loss
                ) = self.loss_fn.compute_loss_G(
                    fake_img,
                    img,
                    fake_d,
                    anime_gray,
                )
                loss_g = adv_loss + con_loss + gra_loss + col_loss + tv_loss
                if torch.isnan(adv_loss).any():
                    self.logger.info("----------------------------------------------")
                    self.logger.info(fake_d)
                    self.logger.info(adv_loss)
                    self.logger.info("----------------------------------------------")
                    raise ValueError("NAN loss!!")

            self.scaler_g.scale(loss_g).backward()
            self.scaler_g.unscale_(self.optimizer_g)
            grad = torch.nn.utils.clip_grad_norm_(self.G.parameters(), max_norm=self.max_norm)
            self.scaler_g.step(self.optimizer_g)
            self.scaler_g.update()
            if self.cfg.ddp:
                torch.distributed.barrier()

            self.loss_tracker.update_loss_G(adv_loss, gra_loss, col_loss, con_loss)
            pbar.set_description(f"{self.loss_tracker.get_loss_description()} - {grad:.3f}")

    def get_train_loader(self, dataset):
        if self.cfg.ddp:
            train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        else:
            train_sampler = None
        return DataLoader(
            dataset,
            batch_size=self.cfg.batch_size,
            num_workers=self.cfg.num_workers,
            pin_memory=True,
            shuffle=train_sampler is None,
            sampler=train_sampler,
            drop_last=True,
            # collate_fn=collate_fn,
        )

    def maybe_increase_imgsz(self, epoch, train_dataset):
        """
        Increase image size at specific epochs
        + 50% of epochs train at imgsz[0]
        + the remaining 50% increase every `epochs / 2 / (len(imgsz) - 1)` epochs

        Args:
            epoch: Current epoch
            train_dataset: Dataset

        Examples:
            ```
            epochs = 100
            imgsz = [256, 352, 416, 512]
            => [(0, 256), (50, 352), (66, 416), (82, 512)]
            ```
        """
        epochs = self.cfg.epochs
        imgsz = self.cfg.imgsz
        num_size_remains = len(imgsz) - 1
        half_epochs = epochs // 2

        if len(imgsz) == 1:
            new_size = imgsz[0]
        elif epoch < half_epochs:
            new_size = imgsz[0]
        else:
            per_epoch_increment = int(half_epochs / num_size_remains)
            found = None
            for i, size in enumerate(imgsz[:]):
                if epoch < half_epochs + per_epoch_increment * i:
                    found = size
                    break
            if not found:
                found = imgsz[-1]
            new_size = found

        self.logger.info(f"Check {imgsz}, {new_size}, {train_dataset.imgsz}")
        if new_size != train_dataset.imgsz:
            train_dataset.set_imgsz(new_size)
            self.logger.info(f"Increase image size to {new_size} at epoch {epoch}")

    def train(self, train_dataset: Dataset, start_epoch=0, start_epoch_g=0):
        """
        Train Generator and Discriminator.
        """
        self.logger.info(self.device)
        self.G.to(self.device)
        self.D.to(self.device)

        self.pretrain_generator(self.get_train_loader(train_dataset), start_epoch_g)

        if self.cfg.local_rank == 0:
            self.logger.info(f"Start training for {self.cfg.epochs} epochs")

        for i, data in enumerate(train_dataset):
            for k in data.keys():
                image = data[k]
                cv2.imwrite(
                    os.path.join(self.example_image_dir, f"data_{k}_{i}.jpg"),
                    revert_to_np_image(image)
                )
            if i == 2:
                break

        end = None
        num_iter = 0
        per_epoch_times = []
        for epoch in range(start_epoch, self.cfg.epochs):
            self.maybe_increase_imgsz(epoch, train_dataset)

            start = time.time()
            self.train_epoch(epoch, self.get_train_loader(train_dataset))

            if epoch % self.cfg.save_interval == 0 and self.cfg.local_rank == 0:
                save_checkpoint(self.G, self.checkpoint_path_G, self.optimizer_g, epoch)
                save_checkpoint(self.D, self.checkpoint_path_D, self.optimizer_d, epoch)
                self.generate_and_save(self.cfg.test_image_dir)

                if epoch % 10 == 0:
                    self.copy_results(epoch)

            num_iter += 1

            if self.cfg.local_rank == 0:
                end = time.time()
                if end is None:
                    eta = 9999
                else:
                    per_epoch_time = (end - start)
                    per_epoch_times.append(per_epoch_time)
                    eta = np.mean(per_epoch_times) * (self.cfg.epochs - epoch)
                    eta = convert_to_readable(eta)
                self.logger.info(f"epoch {epoch}/{self.cfg.epochs}, ETA: {eta}")

    def generate_and_save(
        self,
        image_dir,
        max_imgs=15,
        subname='gen'
    ):
        '''
        Generate and save images
        '''
        start = time.time()
        self.G.eval()

        max_iter = max_imgs
        fake_imgs = []
        real_imgs = []
        image_files = glob(os.path.join(image_dir, "*"))

        for i, image_file in enumerate(image_files):
            image = read_image(image_file)
            image = resize_image(image)
            real_imgs.append(image.copy())
            image = preprocess_images(image)
            image = image.to(self.device)
            with torch.no_grad():
                with autocast(enabled=self.cfg.amp):
                    fake_img = self.G(image)
                    # fake_img = to_gray_scale(fake_img)
                    fake_img = fake_img.detach().cpu().numpy()
                    # Channel first -> channel last
                    fake_img = fake_img.transpose(0, 2, 3, 1)
                    fake_imgs.append(denormalize_input(fake_img, dtype=np.int16)[0])

            if i + 1 == max_iter:
                break

        # fake_imgs = np.concatenate(fake_imgs, axis=0)

        for i, (real_img, fake_img) in enumerate(zip(real_imgs, fake_imgs)):
            img = np.concatenate((real_img, fake_img), axis=1)  # Concatenate across width
            save_path = os.path.join(self.save_image_dir, f'{subname}_{i}.jpg')
            if not cv2.imwrite(save_path, img[..., ::-1]):
                self.logger.info(f"Save generated image failed, {save_path}, {img.shape}")
        elapsed = time.time() - start
        self.logger.info(f"Generated {len(fake_imgs)} images in {elapsed:.3f}s.")

    def copy_results(self, epoch):
        """Copy results (weights + generated images) into a per-epoch folder
        every N epochs
        """
        copy_dir = os.path.join(self.cfg.exp_dir, f"epoch_{epoch}")
        os.makedirs(copy_dir, exist_ok=True)

        shutil.copy2(
            self.checkpoint_path_G,
            copy_dir
        )

        dest = os.path.join(copy_dir, os.path.basename(self.save_image_dir))
        shutil.copytree(
            self.save_image_dir,
            dest,
            dirs_exist_ok=True
        )
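Aside (not part of the diff): the image-size schedule implemented by Trainer.maybe_increase_imgsz can be checked standalone. The sketch below mirrors the method's branching (first half of training at imgsz[0], the remainder split evenly over the larger sizes); imgsz_at_epoch is a hypothetical helper name, not something defined in the repository.

def imgsz_at_epoch(epoch, epochs, imgsz):
    # Same branching as Trainer.maybe_increase_imgsz, without the dataset side effect
    half_epochs = epochs // 2
    if len(imgsz) == 1 or epoch < half_epochs:
        return imgsz[0]
    per_epoch_increment = int(half_epochs / (len(imgsz) - 1))
    for i, size in enumerate(imgsz):
        if epoch < half_epochs + per_epoch_increment * i:
            return size
    return imgsz[-1]

print([imgsz_at_epoch(e, 100, [256, 352, 416, 512]) for e in (0, 50, 66, 82)])
# -> [256, 352, 416, 512], matching the docstring example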