diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..8d2c8b7e49d146d5326baaff4e4c02c3bf635713 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +.cache +__pycache__ +output +.token diff --git a/README.md b/README.md index 244d9012149351ce2bddb80612c6525951a05dfd..ff725893bf83b615896f1243533471a5df9bd485 100644 --- a/README.md +++ b/README.md @@ -11,3 +11,4 @@ license: mit --- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference + \ No newline at end of file diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..06353e719ea982de238662f50689752d69de579f --- /dev/null +++ b/app.py @@ -0,0 +1,95 @@ +import os +import cv2 +import numpy as np +import gradio as gr +from inference import Predictor +from utils.image_processing import resize_image + +os.makedirs('output', exist_ok=True) + + +def inference( + image: np.ndarray, + style, + imgsz=None, +): + retain_color = False + + weight = { + "AnimeGAN_Hayao": "hayao", + "AnimeGAN_Shinkai": "shinkai", + "AnimeGANv2_Hayao": "hayao:v2", + "AnimeGANv2_Shinkai": "shinkai:v2", + "AnimeGANv2_Arcane": "arcane:v2", + }[style] + predictor = Predictor( + weight, + device='cpu', + retain_color=retain_color, + imgsz=imgsz, + ) + + save_path = f"output/out.jpg" + image = resize_image(image, width=imgsz) + anime_image = predictor.transform(image)[0] + cv2.imwrite(save_path, anime_image[..., ::-1]) + return anime_image, save_path + + +title = "AnimeGANv2: To produce your own animation." +description = r"""Turn your photo into anime style 😊""" +article = r""" +[![GitHub Stars](https://img.shields.io/github/stars/ptran1203/pytorch-animeGAN?style=social)](https://github.com/ptran1203/pytorch-animeGAN) +### 🗻 Demo + +""" + +gr.Interface( + fn=inference, + inputs=[ + gr.components.Image(label="Input"), + gr.Dropdown( + [ + 'AnimeGAN_Hayao', + 'AnimeGAN_Shinkai', + 'AnimeGANv2_Hayao', + 'AnimeGANv2_Shinkai', + 'AnimeGANv2_Arcane', + ], + type="value", + value='AnimeGANv2_Hayao', + label='Style' + ), + gr.Dropdown( + [ + None, + 416, + 512, + 768, + 1024, + 1536, + ], + type="value", + value=None, + label='Image size' + ) + ], + outputs=[ + gr.components.Image(type="numpy", label="Output (The whole image)"), + gr.components.File(label="Download the output image") + ], + title=title, + description=description, + article=article, + allow_flagging="never", + examples=[ + ['example/arcane/girl4.jpg', 'AnimeGANv2_Arcane', "Yes"], + ['example/arcane/leo.jpg', 'AnimeGANv2_Arcane', "Yes"], + ['example/arcane/girl.jpg', 'AnimeGANv2_Arcane', "Yes"], + ['example/arcane/anne.jpg', 'AnimeGANv2_Arcane', "Yes"], + # ['example/boy2.jpg', 'AnimeGANv3_Arcane', "No"], + # ['example/cap.jpg', 'AnimeGANv3_Arcane', "No"], + ['example/more/hayao_v2/pexels-camilacarneiro-6318793.jpg', 'AnimeGANv2_Hayao', "Yes"], + ['example/more/hayao_v2/pexels-nandhukumar-450441.jpg', 'AnimeGANv2_Hayao', "Yes"], + ] +).launch() \ No newline at end of file diff --git a/example/arcane/anne.jpg b/example/arcane/anne.jpg new file mode 100644 index 0000000000000000000000000000000000000000..66b1ec251c3a4d4435fa38809f8a6c1e2c8944fd Binary files /dev/null and b/example/arcane/anne.jpg differ diff --git a/example/arcane/boy2.jpg b/example/arcane/boy2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b98d1c2aa069fd309c4a8a3979726633e08edf0d Binary files /dev/null and b/example/arcane/boy2.jpg differ diff --git a/example/arcane/cap.jpg b/example/arcane/cap.jpg new 
file mode 100644 index 0000000000000000000000000000000000000000..eff95e21cfd90d3d2221f1bc7eb89c89d2d1dd87 Binary files /dev/null and b/example/arcane/cap.jpg differ diff --git a/example/arcane/dune2.jpg b/example/arcane/dune2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..284b7bfe12c6ae257e79c21c3c27a85175c20263 Binary files /dev/null and b/example/arcane/dune2.jpg differ diff --git a/example/arcane/elon.jpg b/example/arcane/elon.jpg new file mode 100644 index 0000000000000000000000000000000000000000..88595752235795d57cd11c462036a7ebe16545e2 Binary files /dev/null and b/example/arcane/elon.jpg differ diff --git a/example/arcane/girl.jpg b/example/arcane/girl.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1a6f3f7cf5e65e301eeea016c4a719c4a5d0c423 Binary files /dev/null and b/example/arcane/girl.jpg differ diff --git a/example/arcane/girl4.jpg b/example/arcane/girl4.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c88e40a27a18723ac67d0a2f6a2945c0e2c6c904 Binary files /dev/null and b/example/arcane/girl4.jpg differ diff --git a/example/arcane/girl6.jpg b/example/arcane/girl6.jpg new file mode 100644 index 0000000000000000000000000000000000000000..fc4d9f01b594b300458bdf77791d20da126b4666 Binary files /dev/null and b/example/arcane/girl6.jpg differ diff --git a/example/arcane/leo.jpg b/example/arcane/leo.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a24c1b8ab7574e4f5fbae17065f4ac534e88f623 Binary files /dev/null and b/example/arcane/leo.jpg differ diff --git a/example/arcane/man2.jpg b/example/arcane/man2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..2c0ed6292ed3d7db6b122eec8d660f87e77294ad Binary files /dev/null and b/example/arcane/man2.jpg differ diff --git a/example/arcane/nat_.jpg b/example/arcane/nat_.jpg new file mode 100644 index 0000000000000000000000000000000000000000..abb91d68593e67e29340ed4356ac4d27286e59f1 Binary files /dev/null and b/example/arcane/nat_.jpg differ diff --git a/example/arcane/seydoux.jpg b/example/arcane/seydoux.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d04443ff8395d7c5aaece14724eb5ea7668a78a9 Binary files /dev/null and b/example/arcane/seydoux.jpg differ diff --git a/example/arcane/tobey.jpg b/example/arcane/tobey.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e10938f40e8a01a8036b512a1a3b8eaf29fea654 Binary files /dev/null and b/example/arcane/tobey.jpg differ diff --git a/example/face/anne.jpg b/example/face/anne.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8e49331bfea4121a18db7da9f8e4eb634210df86 Binary files /dev/null and b/example/face/anne.jpg differ diff --git a/example/face/boy2.jpg b/example/face/boy2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..290c902434a9dc37896691e8552b18abb9fdc104 Binary files /dev/null and b/example/face/boy2.jpg differ diff --git a/example/face/cap.jpg b/example/face/cap.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e04bd807b24b7ce28ada60f5688884b746bb035b Binary files /dev/null and b/example/face/cap.jpg differ diff --git a/example/face/dune2.jpg b/example/face/dune2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..cbeb623613d409ee48ab16acd2decdc0da8145ca Binary files /dev/null and b/example/face/dune2.jpg differ diff --git a/example/face/elon.jpg b/example/face/elon.jpg new file mode 100644 index 
0000000000000000000000000000000000000000..a0120cc4079a4c4de08bc75e7ab692e053f1fdc2 Binary files /dev/null and b/example/face/elon.jpg differ diff --git a/example/face/girl.jpg b/example/face/girl.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f15f2b89eddeeeaed9d9d8b6f94c2ced552777ab Binary files /dev/null and b/example/face/girl.jpg differ diff --git a/example/face/girl4.jpg b/example/face/girl4.jpg new file mode 100644 index 0000000000000000000000000000000000000000..aa12aa2debe290a736b002223ca96ae0bba401c3 Binary files /dev/null and b/example/face/girl4.jpg differ diff --git a/example/face/girl6.jpg b/example/face/girl6.jpg new file mode 100644 index 0000000000000000000000000000000000000000..0bfc3c92b5c19680c43044935dc02d9d67f91f15 Binary files /dev/null and b/example/face/girl6.jpg differ diff --git a/example/face/leo.jpg b/example/face/leo.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ca9b06133a26d78e9c1127acfb4726a4248417c6 Binary files /dev/null and b/example/face/leo.jpg differ diff --git a/example/face/man2.jpg b/example/face/man2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d96e1a392401a4ff9e923703fbd33bc567df2027 Binary files /dev/null and b/example/face/man2.jpg differ diff --git a/example/face/nat_.jpg b/example/face/nat_.jpg new file mode 100644 index 0000000000000000000000000000000000000000..045721714417879f3ea04d77bf2844b44062c1fa Binary files /dev/null and b/example/face/nat_.jpg differ diff --git a/example/face/seydoux.jpg b/example/face/seydoux.jpg new file mode 100644 index 0000000000000000000000000000000000000000..36d0cc07b34a779750791779203be52f53543b29 Binary files /dev/null and b/example/face/seydoux.jpg differ diff --git a/example/face/tobey.jpg b/example/face/tobey.jpg new file mode 100644 index 0000000000000000000000000000000000000000..4611ca5576f7e98704c3aeb02d1b55d3ec943389 Binary files /dev/null and b/example/face/tobey.jpg differ diff --git a/example/generate_examples.py b/example/generate_examples.py new file mode 100644 index 0000000000000000000000000000000000000000..9330ea52753dc56f44e0ecd7a5fa396a732a1885 --- /dev/null +++ b/example/generate_examples.py @@ -0,0 +1,49 @@ +import os +import cv2 +import re + +REG = re.compile(r"[0-9]{3}") +dir_ = './example/result' +readme = './README.md' + + +def anime_2_input(fi): + return fi.replace("_anime", "") + +def rename(f): + return f.replace(" ", "").replace("(", "").replace(")", "") + +def rename_back(f): + nums = REG.search(f) + if nums: + nums = nums.group() + return f.replace(nums, f"{nums[0]} ({nums[1:]})") + + return f.replace('jpeg', 'jpg') + +def copyfile(src, dest): + # copy and resize + im = cv2.imread(src) + + if im is None: + raise FileNotFoundError(src) + + h, w = im.shape[1], im.shape[0] + + s = 448 + size = (s, round(s * w / h)) + im = cv2.resize(im, size) + + print(w, h, im.shape) + cv2.imwrite(dest, im) + +files = os.listdir(dir_) +new_files = [] +for f in files: + input_ver = os.path.join(dir_, anime_2_input(f)) + copyfile(f"dataset/test/HR_photo/{rename_back(anime_2_input(f))}", rename(input_ver)) + + os.rename( + os.path.join(dir_, f), + os.path.join(dir_, rename(f)) + ) diff --git a/example/more/hayao_v2/pexels-arnie-chou-304906-1004122.jpg b/example/more/hayao_v2/pexels-arnie-chou-304906-1004122.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8713a249814505e5a7b3cc4ad7eb8e582f1de53d Binary files /dev/null and b/example/more/hayao_v2/pexels-arnie-chou-304906-1004122.jpg differ diff --git 
a/example/more/hayao_v2/pexels-camilacarneiro-6318793.jpg b/example/more/hayao_v2/pexels-camilacarneiro-6318793.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3a98b2cf6241c364419038002f1d8feb45a52e18 Binary files /dev/null and b/example/more/hayao_v2/pexels-camilacarneiro-6318793.jpg differ diff --git a/example/more/hayao_v2/pexels-haohd-19859127.jpg b/example/more/hayao_v2/pexels-haohd-19859127.jpg new file mode 100644 index 0000000000000000000000000000000000000000..46a40f28a051cd8d9c2aac6f5ba62c1de4b003fa Binary files /dev/null and b/example/more/hayao_v2/pexels-haohd-19859127.jpg differ diff --git a/example/more/hayao_v2/pexels-huy-nguyen-748440234-19838813.jpg b/example/more/hayao_v2/pexels-huy-nguyen-748440234-19838813.jpg new file mode 100644 index 0000000000000000000000000000000000000000..965411759b3f65991e659086c1f03aa0dd0dad7c Binary files /dev/null and b/example/more/hayao_v2/pexels-huy-nguyen-748440234-19838813.jpg differ diff --git a/example/more/hayao_v2/pexels-huy-phan-316220-1422386.jpg b/example/more/hayao_v2/pexels-huy-phan-316220-1422386.jpg new file mode 100644 index 0000000000000000000000000000000000000000..48eb010f757a7a45e2ef54bd91e2198ad7e4c47e Binary files /dev/null and b/example/more/hayao_v2/pexels-huy-phan-316220-1422386.jpg differ diff --git a/example/more/hayao_v2/pexels-jimmy-teoh-294331-951531.jpg b/example/more/hayao_v2/pexels-jimmy-teoh-294331-951531.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a042a51770e3380c3a11c1d8a13515e40bdc1907 Binary files /dev/null and b/example/more/hayao_v2/pexels-jimmy-teoh-294331-951531.jpg differ diff --git a/example/more/hayao_v2/pexels-nandhukumar-450441.jpg b/example/more/hayao_v2/pexels-nandhukumar-450441.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9f7599ec4635ef9e433d1cb7d2a1935af8c40096 Binary files /dev/null and b/example/more/hayao_v2/pexels-nandhukumar-450441.jpg differ diff --git a/example/more/hayao_v2/pexels-sevenstormphotography-575362.jpg b/example/more/hayao_v2/pexels-sevenstormphotography-575362.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a52a6fb1cbdc581ae7199fc3b1ee5cfd45b9b0da Binary files /dev/null and b/example/more/hayao_v2/pexels-sevenstormphotography-575362.jpg differ diff --git a/inference.py b/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..24463fd567fbd918d2c9f1de4f1e57f3b945deeb --- /dev/null +++ b/inference.py @@ -0,0 +1,410 @@ +import os +import time +import shutil + +import torch +import cv2 +import numpy as np + +from models.anime_gan import GeneratorV1 +from models.anime_gan_v2 import GeneratorV2 +from models.anime_gan_v3 import GeneratorV3 +from utils.common import load_checkpoint, RELEASED_WEIGHTS +from utils.image_processing import resize_image, normalize_input, denormalize_input +from utils import read_image, is_image_file, is_video_file +from tqdm import tqdm +from color_transfer import color_transfer_pytorch + + +try: + import matplotlib.pyplot as plt +except ImportError: + plt = None + +try: + import moviepy.video.io.ffmpeg_writer as ffmpeg_writer + from moviepy.video.io.VideoFileClip import VideoFileClip +except ImportError: + ffmpeg_writer = None + VideoFileClip = None + + +def profile(func): + def wrap(*args, **kwargs): + started_at = time.time() + result = func(*args, **kwargs) + elapsed = time.time() - started_at + print(f"Processed in {elapsed:.3f}s") + return result + return wrap + + +def auto_load_weight(weight, version=None, 
map_location=None): + """Auto load Generator version from weight.""" + weight_name = os.path.basename(weight).lower() + if version is not None: + version = version.lower() + assert version in {"v1", "v2", "v3"}, f"Version {version} does not exist" + # If version is provided, use it. + cls = { + "v1": GeneratorV1, + "v2": GeneratorV2, + "v3": GeneratorV3 + }[version] + else: + # Try to get class by name of weight file + # For convenenice, weight should start with classname + # e.g: Generatorv2_{anything}.pt + if weight_name in RELEASED_WEIGHTS: + version = RELEASED_WEIGHTS[weight_name][0] + return auto_load_weight(weight, version=version, map_location=map_location) + + elif weight_name.startswith("generatorv2"): + cls = GeneratorV2 + elif weight_name.startswith("generatorv3"): + cls = GeneratorV3 + elif weight_name.startswith("generator"): + cls = GeneratorV1 + else: + raise ValueError((f"Can not get Model from {weight_name}, " + "you might need to explicitly specify version")) + model = cls() + load_checkpoint(model, weight, strip_optimizer=True, map_location=map_location) + model.eval() + return model + + +class Predictor: + """ + Generic class for transfering Image to anime like image. + """ + def __init__( + self, + weight='hayao', + device='cuda', + amp=True, + retain_color=False, + imgsz=None, + ): + if not torch.cuda.is_available(): + device = 'cpu' + # Amp not working on cpu + amp = False + print("Use CPU device") + else: + print(f"Use GPU {torch.cuda.get_device_name()}") + + self.imgsz = imgsz + self.retain_color = retain_color + self.amp = amp # Automatic Mixed Precision + self.device_type = 'cuda' if device.startswith('cuda') else 'cpu' + self.device = torch.device(device) + self.G = auto_load_weight(weight, map_location=device) + self.G.to(self.device) + + def transform_and_show( + self, + image_path, + figsize=(18, 10), + save_path=None + ): + image = resize_image(read_image(image_path)) + anime_img = self.transform(image) + anime_img = anime_img.astype('uint8') + + fig = plt.figure(figsize=figsize) + fig.add_subplot(1, 2, 1) + # plt.title("Input") + plt.imshow(image) + plt.axis('off') + fig.add_subplot(1, 2, 2) + # plt.title("Anime style") + plt.imshow(anime_img[0]) + plt.axis('off') + plt.tight_layout() + plt.show() + if save_path is not None: + plt.savefig(save_path) + + def transform(self, image, denorm=True): + ''' + Transform a image to animation + + @Arguments: + - image: np.array, shape = (Batch, width, height, channels) + + @Returns: + - anime version of image: np.array + ''' + with torch.no_grad(): + image = self.preprocess_images(image) + # image = image.to(self.device) + # with autocast(self.device_type, enabled=self.amp): + # print(image.dtype, self.G) + fake = self.G(image) + # Transfer color of fake image look similiar color as image + if self.retain_color: + fake = color_transfer_pytorch(fake, image) + fake = (fake / 0.5) - 1.0 # remap to [-1. 
1] + fake = fake.detach().cpu().numpy() + # Channel last + fake = fake.transpose(0, 2, 3, 1) + + if denorm: + fake = denormalize_input(fake, dtype=np.uint8) + return fake + + def read_and_resize(self, path, max_size=1536): + image = read_image(path) + _, ext = os.path.splitext(path) + h, w = image.shape[:2] + if self.imgsz is not None: + image = resize_image(image, width=self.imgsz) + elif max(h, w) > max_size: + print(f"Image {os.path.basename(path)} is too big ({h}x{w}), resize to max size {max_size}") + image = resize_image( + image, + width=max_size if w > h else None, + height=max_size if w < h else None, + ) + cv2.imwrite(path.replace(ext, ".jpg"), image[:,:,::-1]) + else: + image = resize_image(image) + # image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) + # image = np.stack([image, image, image], -1) + # cv2.imwrite(path.replace(ext, ".jpg"), image[:,:,::-1]) + return image + + @profile + def transform_file(self, file_path, save_path): + if not is_image_file(save_path): + raise ValueError(f"{save_path} is not valid") + + image = self.read_and_resize(file_path) + anime_img = self.transform(image)[0] + cv2.imwrite(save_path, anime_img[..., ::-1]) + print(f"Anime image saved to {save_path}") + return anime_img + + @profile + def transform_gif(self, file_path, save_path, batch_size=4): + import imageio + + def _preprocess_gif(img): + if img.shape[-1] == 4: + img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB) + return resize_image(img) + + images = imageio.mimread(file_path) + images = np.stack([ + _preprocess_gif(img) + for img in images + ]) + + print(images.shape) + + anime_gif = np.zeros_like(images) + + for i in tqdm(range(0, len(images), batch_size)): + end = i + batch_size + anime_gif[i: end] = self.transform( + images[i: end] + ) + + if end < len(images) - 1: + # transform last frame + print("LAST", images[end: ].shape) + anime_gif[end:] = self.transform(images[end:]) + + print(anime_gif.shape) + imageio.mimsave( + save_path, + anime_gif, + + ) + print(f"Anime image saved to {save_path}") + + @profile + def transform_in_dir(self, img_dir, dest_dir, max_images=0, img_size=(512, 512)): + ''' + Read all images from img_dir, transform and write the result + to dest_dir + + ''' + os.makedirs(dest_dir, exist_ok=True) + + files = os.listdir(img_dir) + files = [f for f in files if is_image_file(f)] + print(f'Found {len(files)} images in {img_dir}') + + if max_images: + files = files[:max_images] + + bar = tqdm(files) + for fname in bar: + path = os.path.join(img_dir, fname) + image = self.read_and_resize(path) + anime_img = self.transform(image)[0] + # anime_img = resize_image(anime_img, width=320) + ext = fname.split('.')[-1] + fname = fname.replace(f'.{ext}', '') + cv2.imwrite(os.path.join(dest_dir, f'{fname}.jpg'), anime_img[..., ::-1]) + bar.set_description(f"{fname} {image.shape}") + + def transform_video(self, input_path, output_path, batch_size=4, start=0, end=0): + ''' + Transform a video to animation version + https://github.com/lengstrom/fast-style-transfer/blob/master/evaluate.py#L21 + ''' + if VideoFileClip is None: + raise ImportError("moviepy is not installed, please install with `pip install moviepy>=1.0.3`") + # Force to None + end = end or None + + if not os.path.isfile(input_path): + raise FileNotFoundError(f'{input_path} does not exist') + + output_dir = os.path.dirname(output_path) + if output_dir: + os.makedirs(output_dir, exist_ok=True) + + is_gg_drive = '/drive/' in output_path + temp_file = '' + + if is_gg_drive: + # Writing directly into google drive can be inefficient + 
temp_file = f'tmp_anime.{output_path.split(".")[-1]}' + + def transform_and_write(frames, count, writer): + anime_images = self.transform(frames) + for i in range(0, count): + img = np.clip(anime_images[i], 0, 255) + writer.write_frame(img) + + video_clip = VideoFileClip(input_path, audio=False) + if start or end: + video_clip = video_clip.subclip(start, end) + + video_writer = ffmpeg_writer.FFMPEG_VideoWriter( + temp_file or output_path, + video_clip.size, video_clip.fps, + codec="libx264", + # preset="medium", bitrate="2000k", + ffmpeg_params=None) + + total_frames = round(video_clip.fps * video_clip.duration) + print(f'Transfroming video {input_path}, {total_frames} frames, size: {video_clip.size}') + + batch_shape = (batch_size, video_clip.size[1], video_clip.size[0], 3) + frame_count = 0 + frames = np.zeros(batch_shape, dtype=np.float32) + for frame in tqdm(video_clip.iter_frames(), total=total_frames): + try: + frames[frame_count] = frame + frame_count += 1 + if frame_count == batch_size: + transform_and_write(frames, frame_count, video_writer) + frame_count = 0 + except Exception as e: + print(e) + break + + # The last frames + if frame_count != 0: + transform_and_write(frames, frame_count, video_writer) + + if temp_file: + # move to output path + shutil.move(temp_file, output_path) + + print(f'Animation video saved to {output_path}') + video_writer.close() + + def preprocess_images(self, images): + ''' + Preprocess image for inference + + @Arguments: + - images: np.ndarray + + @Returns + - images: torch.tensor + ''' + images = images.astype(np.float32) + + # Normalize to [-1, 1] + images = normalize_input(images) + images = torch.from_numpy(images) + + images = images.to(self.device) + + # Add batch dim + if len(images.shape) == 3: + images = images.unsqueeze(0) + + # channel first + images = images.permute(0, 3, 1, 2) + + return images + + +def parse_args(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument( + '--weight', + type=str, + default="hayao:v2", + help=f'Model weight, can be path or pretrained {tuple(RELEASED_WEIGHTS.keys())}' + ) + parser.add_argument('--src', type=str, help='Source, can be directory contains images, image file or video file.') + parser.add_argument('--device', type=str, default='cuda', help='Device, cuda or cpu') + parser.add_argument('--imgsz', type=int, default=None, help='Resize image to specified size if provided') + parser.add_argument('--out', type=str, default='inference_images', help='Output, can be directory or file') + parser.add_argument( + '--retain-color', + action='store_true', + help='If provided the generated image will retain original color of input image') + # Video params + parser.add_argument('--batch-size', type=int, default=4, help='Batch size when inference video') + parser.add_argument('--start', type=int, default=0, help='Start time of video (second)') + parser.add_argument('--end', type=int, default=0, help='End time of video (second), 0 if not set') + + return parser.parse_args() + +if __name__ == '__main__': + args = parse_args() + + predictor = Predictor( + args.weight, + args.device, + retain_color=args.retain_color, + imgsz=args.imgsz, + ) + + if not os.path.exists(args.src): + raise FileNotFoundError(args.src) + + if is_video_file(args.src): + predictor.transform_video( + args.src, + args.out, + args.batch_size, + start=args.start, + end=args.end + ) + elif os.path.isdir(args.src): + predictor.transform_in_dir(args.src, args.out) + elif os.path.isfile(args.src): + save_path = args.out + if not 
is_image_file(args.out): + os.makedirs(args.out, exist_ok=True) + save_path = os.path.join(args.out, os.path.basename(args.src)) + + if args.src.endswith('.gif'): + # GIF file + predictor.transform_gif(args.src, save_path, args.batch_size) + else: + predictor.transform_file(args.src, save_path) + else: + raise NotImplementedError(f"{args.src} is not supported") diff --git a/losses.py b/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..c9c6e1f83cd0efeb0e42c2a88a0c4d1fae9105c0 --- /dev/null +++ b/losses.py @@ -0,0 +1,248 @@ +import torch +import torch.nn.functional as F +import torch.nn as nn +from models.vgg import Vgg19 +from utils.image_processing import gram + + +def to_gray_scale(image): + # https://github.com/pytorch/vision/blob/main/torchvision/transforms/v2/functional/_color.py#L33 + # Image are assum in range 1, -1 + image = (image + 1.0) / 2.0 # To [0, 1] + r, g, b = image.unbind(dim=-3) + l_img = r.mul(0.2989).add_(g, alpha=0.587).add_(b, alpha=0.114) + l_img = l_img.unsqueeze(dim=-3) + l_img = l_img.to(image.dtype) + l_img = l_img.expand(image.shape) + l_img = l_img / 0.5 - 1.0 # To [-1, 1] + return l_img + + +class ColorLoss(nn.Module): + def __init__(self): + super(ColorLoss, self).__init__() + self.l1 = nn.L1Loss() + self.huber = nn.SmoothL1Loss() + # self._rgb_to_yuv_kernel = torch.tensor([ + # [0.299, -0.14714119, 0.61497538], + # [0.587, -0.28886916, -0.51496512], + # [0.114, 0.43601035, -0.10001026] + # ]).float() + + self._rgb_to_yuv_kernel = torch.tensor([ + [0.299, 0.587, 0.114], + [-0.14714119, -0.28886916, 0.43601035], + [0.61497538, -0.51496512, -0.10001026], + ]).float() + + def to(self, device): + new_self = super(ColorLoss, self).to(device) + new_self._rgb_to_yuv_kernel = new_self._rgb_to_yuv_kernel.to(device) + return new_self + + def rgb_to_yuv(self, image): + ''' + https://en.wikipedia.org/wiki/YUV + + output: Image of shape (H, W, C) (channel last) + ''' + # -1 1 -> 0 1 + image = (image + 1.0) / 2.0 + image = image.permute(0, 2, 3, 1) # To channel last + + yuv_img = image @ self._rgb_to_yuv_kernel.T + + return yuv_img + + def forward(self, image, image_g): + image = self.rgb_to_yuv(image) + image_g = self.rgb_to_yuv(image_g) + # After convert to yuv, both images have channel last + return ( + self.l1(image[:, :, :, 0], image_g[:, :, :, 0]) + + self.huber(image[:, :, :, 1], image_g[:, :, :, 1]) + + self.huber(image[:, :, :, 2], image_g[:, :, :, 2]) + ) + + +class AnimeGanLoss: + def __init__(self, args, device, gray_adv=False): + if isinstance(device, str): + device = torch.device(device) + + self.content_loss = nn.L1Loss().to(device) + self.gram_loss = nn.L1Loss().to(device) + self.color_loss = ColorLoss().to(device) + self.wadvg = args.wadvg + self.wadvd = args.wadvd + self.wcon = args.wcon + self.wgra = args.wgra + self.wcol = args.wcol + self.wtvar = args.wtvar + # If true, use gray scale image to calculate adversarial loss + self.gray_adv = gray_adv + self.vgg19 = Vgg19().to(device).eval() + self.adv_type = args.gan_loss + self.bce_loss = nn.BCEWithLogitsLoss() + + def compute_loss_G(self, fake_img, img, fake_logit, anime_gray): + ''' + Compute loss for Generator + + @Args: + - fake_img: generated image + - img: real image + - fake_logit: output of Discriminator given fake image + - anime_gray: grayscale of anime image + + @Returns: + - Adversarial Loss of fake logits + - Content loss between real and fake features (vgg19) + - Gram loss between anime and fake features (Vgg19) + - Color loss between image and fake image + - 
Total variation loss of fake image + ''' + fake_feat = self.vgg19(fake_img) + gray_feat = self.vgg19(anime_gray) + img_feat = self.vgg19(img) + # fake_gray_feat = self.vgg19(to_gray_scale(fake_img)) + + return [ + # Want to be real image. + self.wadvg * self.adv_loss_g(fake_logit), + self.wcon * self.content_loss(img_feat, fake_feat), + self.wgra * self.gram_loss(gram(gray_feat), gram(fake_feat)), + self.wcol * self.color_loss(img, fake_img), + self.wtvar * self.total_variation_loss(fake_img) + ] + + def compute_loss_D( + self, + fake_img_d, + real_anime_d, + real_anime_gray_d, + real_anime_smooth_gray_d=None + ): + if self.gray_adv: + # Treat gray scale image as real + return ( + self.adv_loss_d_real(real_anime_gray_d) + + self.adv_loss_d_fake(fake_img_d) + + 0.3 * self.adv_loss_d_fake(real_anime_smooth_gray_d) + ) + else: + return ( + # Classify real anime as real + self.adv_loss_d_real(real_anime_d) + # Classify generated as fake + + self.adv_loss_d_fake(fake_img_d) + # Classify real anime gray as fake + # + self.adv_loss_d_fake(real_anime_gray_d) + # Classify real anime as fake + # + 0.1 * self.adv_loss_d_fake(real_anime_smooth_gray_d) + ) + + def total_variation_loss(self, fake_img): + """ + A smooth loss in fact. Like the smooth prior in MRF. + V(y) = || y_{n+1} - y_n ||_2 + """ + # Channel first -> channel last + fake_img = fake_img.permute(0, 2, 3, 1) + def _l2(x): + # sum(t ** 2) / 2 + return torch.sum(x ** 2) / 2 + + dh = fake_img[:, :-1, ...] - fake_img[:, 1:, ...] + dw = fake_img[:, :, :-1, ...] - fake_img[:, :, 1:, ...] + return _l2(dh) / dh.numel() + _l2(dw) / dw.numel() + + def content_loss_vgg(self, image, recontruction): + feat = self.vgg19(image) + re_feat = self.vgg19(recontruction) + feature_loss = self.content_loss(feat, re_feat) + content_loss = self.content_loss(image, recontruction) + return feature_loss# + 0.5 * content_loss + + def adv_loss_d_real(self, pred): + """Push pred to class 1 (real)""" + if self.adv_type == 'hinge': + return torch.mean(F.relu(1.0 - pred)) + + elif self.adv_type == 'lsgan': + # pred = torch.sigmoid(pred) + return torch.mean(torch.square(pred - 1.0)) + + elif self.adv_type == 'bce': + return self.bce_loss(pred, torch.ones_like(pred)) + + raise ValueError(f'Do not support loss type {self.adv_type}') + + def adv_loss_d_fake(self, pred): + """Push pred to class 0 (fake)""" + if self.adv_type == 'hinge': + return torch.mean(F.relu(1.0 + pred)) + + elif self.adv_type == 'lsgan': + # pred = torch.sigmoid(pred) + return torch.mean(torch.square(pred)) + + elif self.adv_type == 'bce': + return self.bce_loss(pred, torch.zeros_like(pred)) + + raise ValueError(f'Do not support loss type {self.adv_type}') + + def adv_loss_g(self, pred): + """Push pred to class 1 (real)""" + if self.adv_type == 'hinge': + return -torch.mean(pred) + + elif self.adv_type == 'lsgan': + # pred = torch.sigmoid(pred) + return torch.mean(torch.square(pred - 1.0)) + + elif self.adv_type == 'bce': + return self.bce_loss(pred, torch.ones_like(pred)) + + raise ValueError(f'Do not support loss type {self.adv_type}') + + +class LossSummary: + def __init__(self): + self.reset() + + def reset(self): + self.loss_g_adv = [] + self.loss_content = [] + self.loss_gram = [] + self.loss_color = [] + self.loss_d_adv = [] + + def update_loss_G(self, adv, gram, color, content): + self.loss_g_adv.append(adv.cpu().detach().numpy()) + self.loss_gram.append(gram.cpu().detach().numpy()) + self.loss_color.append(color.cpu().detach().numpy()) + self.loss_content.append(content.cpu().detach().numpy()) 
+ + def update_loss_D(self, loss): + self.loss_d_adv.append(loss.cpu().detach().numpy()) + + def avg_loss_G(self): + return ( + self._avg(self.loss_g_adv), + self._avg(self.loss_gram), + self._avg(self.loss_color), + self._avg(self.loss_content), + ) + + def avg_loss_D(self): + return self._avg(self.loss_d_adv) + + def get_loss_description(self): + avg_adv, avg_gram, avg_color, avg_content = self.avg_loss_G() + avg_adv_d = self.avg_loss_D() + return f'loss G: adv {avg_adv:2f} con {avg_content:2f} gram {avg_gram:2f} color {avg_color:2f} / loss D: {avg_adv_d:2f}' + + @staticmethod + def _avg(losses): + return sum(losses) / len(losses) diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..791c1165f382aa885ba350022a24a09867355f4e --- /dev/null +++ b/models/__init__.py @@ -0,0 +1,3 @@ +from .anime_gan import GeneratorV1 +from .anime_gan_v2 import GeneratorV2 +from .anime_gan_v3 import GeneratorV3 diff --git a/models/anime_gan.py b/models/anime_gan.py new file mode 100644 index 0000000000000000000000000000000000000000..cf74b0a8de28d1f1123e06857326050c84daa8d2 --- /dev/null +++ b/models/anime_gan.py @@ -0,0 +1,112 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.utils import spectral_norm +from .conv_blocks import DownConv +from .conv_blocks import UpConv +from .conv_blocks import SeparableConv2D +from .conv_blocks import InvertedResBlock +from .conv_blocks import ConvBlock +from .layers import get_norm +from utils.common import initialize_weights + + +class GeneratorV1(nn.Module): + def __init__(self, dataset=''): + super(GeneratorV1, self).__init__() + self.name = f'{self.__class__.__name__}_{dataset}' + bias = False + + self.encode_blocks = nn.Sequential( + ConvBlock(3, 64, bias=bias), + ConvBlock(64, 128, bias=bias), + DownConv(128, bias=bias), + ConvBlock(128, 128, bias=bias), + SeparableConv2D(128, 256, bias=bias), + DownConv(256, bias=bias), + ConvBlock(256, 256, bias=bias), + ) + + self.res_blocks = nn.Sequential( + InvertedResBlock(256, 256), + InvertedResBlock(256, 256), + InvertedResBlock(256, 256), + InvertedResBlock(256, 256), + InvertedResBlock(256, 256), + InvertedResBlock(256, 256), + InvertedResBlock(256, 256), + InvertedResBlock(256, 256), + ) + + self.decode_blocks = nn.Sequential( + ConvBlock(256, 128, bias=bias), + UpConv(128, bias=bias), + SeparableConv2D(128, 128, bias=bias), + ConvBlock(128, 128, bias=bias), + UpConv(128, bias=bias), + ConvBlock(128, 64, bias=bias), + ConvBlock(64, 64, bias=bias), + nn.Conv2d(64, 3, kernel_size=1, stride=1, padding=0, bias=bias), + nn.Tanh(), + ) + + initialize_weights(self) + + def forward(self, x): + out = self.encode_blocks(x) + out = self.res_blocks(out) + img = self.decode_blocks(out) + + return img + + +class Discriminator(nn.Module): + def __init__( + self, + dataset=None, + num_layers=1, + use_sn=False, + norm_type="instance", + ): + super(Discriminator, self).__init__() + self.name = f'discriminator_{dataset}' + self.bias = False + channels = 32 + + layers = [ + nn.Conv2d(3, channels, kernel_size=3, stride=1, padding=1, bias=self.bias), + nn.LeakyReLU(0.2, True) + ] + + in_channels = channels + for i in range(num_layers): + layers += [ + nn.Conv2d(in_channels, channels * 2, kernel_size=3, stride=2, padding=1, bias=self.bias), + nn.LeakyReLU(0.2, True), + nn.Conv2d(channels * 2, channels * 4, kernel_size=3, stride=1, padding=1, bias=self.bias), + get_norm(norm_type, channels * 4), + nn.LeakyReLU(0.2, True), + ] + in_channels = 
channels * 4 + channels *= 2 + + channels *= 2 + layers += [ + nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1, bias=self.bias), + get_norm(norm_type, channels), + nn.LeakyReLU(0.2, True), + nn.Conv2d(channels, 1, kernel_size=3, stride=1, padding=1, bias=self.bias), + ] + + if use_sn: + for i in range(len(layers)): + if isinstance(layers[i], nn.Conv2d): + layers[i] = spectral_norm(layers[i]) + + self.discriminate = nn.Sequential(*layers) + + initialize_weights(self) + + def forward(self, img): + logits = self.discriminate(img) + return logits diff --git a/models/anime_gan_v2.py b/models/anime_gan_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..e03136a83a4ef585d39eadb96fc80da5231cf537 --- /dev/null +++ b/models/anime_gan_v2.py @@ -0,0 +1,65 @@ + +import torch.nn as nn +import torch.nn.functional as F +from models.conv_blocks import InvertedResBlock +from models.conv_blocks import ConvBlock +from models.conv_blocks import UpConvLNormLReLU +from utils.common import initialize_weights + + +class GeneratorV2(nn.Module): + def __init__(self, dataset=''): + super(GeneratorV2, self).__init__() + self.name = f'{self.__class__.__name__}_{dataset}' + + self.conv_block1 = nn.Sequential( + ConvBlock(3, 32, kernel_size=7, stride=1, padding=3, norm_type="layer"), + ConvBlock(32, 64, kernel_size=3, stride=2, padding=(0, 1, 0, 1), norm_type="layer"), + ConvBlock(64, 64, kernel_size=3, stride=1, norm_type="layer"), + ) + + self.conv_block2 = nn.Sequential( + ConvBlock(64, 128, kernel_size=3, stride=2, padding=(0, 1, 0, 1), norm_type="layer"), + ConvBlock(128, 128, kernel_size=3, stride=1, norm_type="layer"), + ) + + self.res_blocks = nn.Sequential( + ConvBlock(128, 128, kernel_size=3, stride=1, norm_type="layer"), + InvertedResBlock(128, 256, expand_ratio=2, norm_type="layer"), + InvertedResBlock(256, 256, expand_ratio=2, norm_type="layer"), + InvertedResBlock(256, 256, expand_ratio=2, norm_type="layer"), + InvertedResBlock(256, 256, expand_ratio=2, norm_type="layer"), + ConvBlock(256, 128, kernel_size=3, stride=1, norm_type="layer"), + ) + + self.conv_block3 = nn.Sequential( + # UpConvLNormLReLU(128, 128, norm_type="layer"), + ConvBlock(128, 128, kernel_size=3, stride=1, norm_type="layer"), + ConvBlock(128, 128, kernel_size=3, stride=1, norm_type="layer"), + ) + + self.conv_block4 = nn.Sequential( + # UpConvLNormLReLU(128, 64, norm_type="layer"), + ConvBlock(128, 64, kernel_size=3, stride=1, norm_type="layer"), + ConvBlock(64, 64, kernel_size=3, stride=1, norm_type="layer"), + ConvBlock(64, 32, kernel_size=7, padding=3, stride=1, norm_type="layer"), + ) + + self.decode_blocks = nn.Sequential( + nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0), + nn.Tanh(), + ) + + initialize_weights(self) + + def forward(self, x): + out = self.conv_block1(x) + out = self.conv_block2(out) + out = self.res_blocks(out) + out = F.interpolate(out, scale_factor=2, mode="bilinear") + out = self.conv_block3(out) + out = F.interpolate(out, scale_factor=2, mode="bilinear") + out = self.conv_block4(out) + img = self.decode_blocks(out) + + return img diff --git a/models/anime_gan_v3.py b/models/anime_gan_v3.py new file mode 100644 index 0000000000000000000000000000000000000000..2ec9108c5dd29ad978714161d78d2c93591b28b8 --- /dev/null +++ b/models/anime_gan_v3.py @@ -0,0 +1,14 @@ + +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.utils import spectral_norm +from models.conv_blocks import DownConv +from models.conv_blocks import UpConv +from models.conv_blocks import 
SeparableConv2D +from models.conv_blocks import InvertedResBlock +from models.conv_blocks import ConvBlock +from utils.common import initialize_weights + + +class GeneratorV3(nn.Module): + pass \ No newline at end of file diff --git a/models/conv_blocks.py b/models/conv_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..b54b42fe469770ec152d6ff69a32777d457ba858 --- /dev/null +++ b/models/conv_blocks.py @@ -0,0 +1,171 @@ +import torch.nn as nn +import torch.nn.functional as F +from utils.common import initialize_weights +from .layers import LayerNorm2d, get_norm + + +class DownConv(nn.Module): + + def __init__(self, channels, bias=False): + super(DownConv, self).__init__() + + self.conv1 = SeparableConv2D(channels, channels, stride=2, bias=bias) + self.conv2 = SeparableConv2D(channels, channels, stride=1, bias=bias) + + def forward(self, x): + out1 = self.conv1(x) + out2 = F.interpolate(x, scale_factor=0.5, mode='bilinear') + out2 = self.conv2(out2) + + return out1 + out2 + + +class UpConv(nn.Module): + def __init__(self, channels, bias=False): + super(UpConv, self).__init__() + + self.conv = SeparableConv2D(channels, channels, stride=1, bias=bias) + + def forward(self, x): + out = F.interpolate(x, scale_factor=2.0, mode='bilinear') + out = self.conv(out) + return out + + +class UpConvLNormLReLU(nn.Module): + """Upsample Conv block with Layer Norm and Leaky ReLU""" + def __init__(self, in_channels, out_channels, norm_type="instance", bias=False): + super(UpConvLNormLReLU, self).__init__() + + self.conv_block = ConvBlock( + in_channels, + out_channels, + kernel_size=3, + norm_type=norm_type, + bias=bias, + ) + + def forward(self, x): + out = F.interpolate(x, scale_factor=2.0, mode='bilinear') + out = self.conv_block(out) + return out + +class SeparableConv2D(nn.Module): + def __init__(self, in_channels, out_channels, stride=1, bias=False): + super(SeparableConv2D, self).__init__() + self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size=3, + stride=stride, padding=1, groups=in_channels, bias=bias) + self.pointwise = nn.Conv2d(in_channels, out_channels, + kernel_size=1, stride=1, bias=bias) + # self.pad = + self.ins_norm1 = nn.InstanceNorm2d(in_channels) + self.activation1 = nn.LeakyReLU(0.2, True) + self.ins_norm2 = nn.InstanceNorm2d(out_channels) + self.activation2 = nn.LeakyReLU(0.2, True) + + initialize_weights(self) + + def forward(self, x): + out = self.depthwise(x) + out = self.ins_norm1(out) + out = self.activation1(out) + + out = self.pointwise(out) + out = self.ins_norm2(out) + + return self.activation2(out) + + +class ConvBlock(nn.Module): + """Stack of Conv2D + Norm + LeakyReLU""" + def __init__( + self, + channels, + out_channels, + kernel_size=3, + stride=1, + groups=1, + padding=1, + bias=False, + norm_type="instance" + ): + super(ConvBlock, self).__init__() + + # if kernel_size == 3 and stride == 1: + # self.pad = nn.ReflectionPad2d((1, 1, 1, 1)) + # elif kernel_size == 7 and stride == 1: + # self.pad = nn.ReflectionPad2d((3, 3, 3, 3)) + # elif stride == 2: + # self.pad = nn.ReflectionPad2d((0, 1, 1, 0)) + # else: + # self.pad = None + + self.pad = nn.ReflectionPad2d(padding) + self.conv = nn.Conv2d( + channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + groups=groups, + padding=0, + bias=bias + ) + self.ins_norm = get_norm(norm_type, out_channels) + self.activation = nn.LeakyReLU(0.2, True) + + # initialize_weights(self) + + def forward(self, x): + if self.pad is not None: + x = self.pad(x) + out = self.conv(x) + out = 
self.ins_norm(out) + out = self.activation(out) + return out + + +class InvertedResBlock(nn.Module): + def __init__( + self, + channels=256, + out_channels=256, + expand_ratio=2, + norm_type="instance", + ): + super(InvertedResBlock, self).__init__() + bottleneck_dim = round(expand_ratio * channels) + self.conv_block = ConvBlock( + channels, + bottleneck_dim, + kernel_size=1, + padding=0, + norm_type=norm_type, + bias=False + ) + self.conv_block2 = ConvBlock( + bottleneck_dim, + bottleneck_dim, + groups=bottleneck_dim, + norm_type=norm_type, + bias=True + ) + self.conv = nn.Conv2d( + bottleneck_dim, + out_channels, + kernel_size=1, + padding=0, + bias=False + ) + self.norm = get_norm(norm_type, out_channels) + + def forward(self, x): + out = self.conv_block(x) + out = self.conv_block2(out) + # out = self.activation(out) + out = self.conv(out) + out = self.norm(out) + + if out.shape[1] != x.shape[1]: + # Only concate if same shape + return out + return out + x diff --git a/models/layers.py b/models/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..78b688ea68c87950e4c314823388b3a71abb92c2 --- /dev/null +++ b/models/layers.py @@ -0,0 +1,28 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial + + +class LayerNorm2d(nn.LayerNorm): + """ LayerNorm for channels of '2D' spatial NCHW tensors """ + def __init__(self, num_channels, eps=1e-6, affine=True): + super().__init__(num_channels, eps=eps, elementwise_affine=affine) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # https://pytorch.org/vision/0.12/_modules/torchvision/models/convnext.html + x = x.permute(0, 2, 3, 1) + x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + x = x.permute(0, 3, 1, 2) + return x + + +def get_norm(norm_type, channels): + if norm_type == "instance": + return nn.InstanceNorm2d(channels) + elif norm_type == "layer": + # return LayerNorm2d + return nn.GroupNorm(num_groups=1, num_channels=channels, affine=True) + # return partial(nn.GroupNorm, 1, out_ch, 1e-5, True) + else: + raise ValueError(norm_type) diff --git a/models/vgg.py b/models/vgg.py new file mode 100644 index 0000000000000000000000000000000000000000..29c3a9b57f6b5d715e6d68be367ff63586b6e0c3 --- /dev/null +++ b/models/vgg.py @@ -0,0 +1,80 @@ +from numpy.lib.arraysetops import isin +import torchvision.models as models +import torch.nn as nn +import torch + + + +class Vgg19(nn.Module): + def __init__(self): + super(Vgg19, self).__init__() + self.vgg19 = self.get_vgg19().eval() + vgg_mean = torch.tensor([0.485, 0.456, 0.406]).float() + vgg_std = torch.tensor([0.229, 0.224, 0.225]).float() + self.mean = vgg_mean.view(-1, 1 ,1) + self.std = vgg_std.view(-1, 1, 1) + + def to(self, device): + new_self = super(Vgg19, self).to(device) + new_self.mean = new_self.mean.to(device) + new_self.std = new_self.std.to(device) + return new_self + + def forward(self, x): + return self.vgg19(self.normalize_vgg(x)) + + @staticmethod + def get_vgg19(last_layer='conv4_4'): + vgg = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features + model_list = [] + + i = 0 + j = 1 + for layer in vgg.children(): + if isinstance(layer, nn.MaxPool2d): + i = 0 + j += 1 + + elif isinstance(layer, nn.Conv2d): + i += 1 + + name = f'conv{j}_{i}' + + if name == last_layer: + model_list.append(layer) + break + + model_list.append(layer) + + + model = nn.Sequential(*model_list) + return model + + + def normalize_vgg(self, image): + ''' + Expect input in range -1 1 + ''' + image = 
(image + 1.0) / 2.0 + return (image - self.mean) / self.std + + +if __name__ == '__main__': + from PIL import Image + import numpy as np + from utils.image_processing import normalize_input + + image = Image.open("example/10.jpg") + image = image.resize((224, 224)) + np_img = np.array(image).astype('float32') + np_img = normalize_input(np_img) + + img = torch.from_numpy(np_img) + img = img.permute(2, 0, 1) + img = img.unsqueeze(0) + + vgg = Vgg19() + + feat = vgg(img) + + print(feat.shape) \ No newline at end of file diff --git a/predict.py b/predict.py new file mode 100644 index 0000000000000000000000000000000000000000..66efef0d56e298f9750fdb98ca1d3ccdbae6d2ab --- /dev/null +++ b/predict.py @@ -0,0 +1,35 @@ +from pathlib import Path +from inference import Predictor as MyPredictor +from utils import read_image +import cv2 +import tempfile +from utils.image_processing import resize_image, normalize_input, denormalize_input +import numpy as np +from cog import BasePredictor, Path, Input + + +class Predictor(BasePredictor): + def setup(self): + pass + + def predict( + self, + image: Path = Input(description="Image"), + model: str = Input( + description="Style", + default='Hayao:v2', + choices=[ + 'Hayao', + 'Shinkai', + 'Hayao:v2' + ] + ) + ) -> Path: + version = model.split(":")[-1] + predictor = MyPredictor(model, version) + img = read_image(str(image)) + anime_img = predictor.transform(resize_image(img))[0] + out_path = Path(tempfile.mkdtemp()) / "out.png" + cv2.imwrite(str(out_path), anime_img[..., ::-1]) + return out_path + diff --git a/train.py b/train.py new file mode 100644 index 0000000000000000000000000000000000000000..b9cfe84142c9d59709c9a219bda06688b4a5751b --- /dev/null +++ b/train.py @@ -0,0 +1,163 @@ +import torch +import argparse +import os +from models.anime_gan import GeneratorV1 +from models.anime_gan_v2 import GeneratorV2 +from models.anime_gan_v3 import GeneratorV3 +from models.anime_gan import Discriminator +from datasets import AnimeDataSet +from utils.common import load_checkpoint +from trainer import Trainer +from utils.logger import get_logger + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--real_image_dir', type=str, default='dataset/train_photo') + parser.add_argument('--anime_image_dir', type=str, default='dataset/Hayao') + parser.add_argument('--test_image_dir', type=str, default='dataset/test/HR_photo') + parser.add_argument('--model', type=str, default='v1', help="AnimeGAN version, can be {'v1', 'v2', 'v3'}") + parser.add_argument('--epochs', type=int, default=70) + parser.add_argument('--init_epochs', type=int, default=10) + parser.add_argument('--batch_size', type=int, default=8) + parser.add_argument('--exp_dir', type=str, default='runs', help="Experiment directory") + parser.add_argument('--gan_loss', type=str, default='lsgan', help='lsgan / hinge / bce') + parser.add_argument('--resume', action='store_true', help="Continue from current dir") + parser.add_argument('--resume_G_init', type=str, default='False') + parser.add_argument('--resume_G', type=str, default='False') + parser.add_argument('--resume_D', type=str, default='False') + parser.add_argument('--device', type=str, default='cuda') + parser.add_argument('--use_sn', action='store_true') + parser.add_argument('--cache', action='store_true', help="Turn on disk cache") + parser.add_argument('--amp', action='store_true', help="Turn on Automatic Mixed Precision") + parser.add_argument('--save_interval', type=int, default=1) + parser.add_argument('--debug_samples', 
type=int, default=0) + parser.add_argument('--num_workers', type=int, default=2) + parser.add_argument('--imgsz', type=int, nargs="+", default=[256], + help="Image sizes, can provide multiple values, image size will increase after a proportion of epochs") + parser.add_argument('--resize_method', type=str, default="crop", + help="Resize image method if origin photo larger than imgsz") + # Loss stuff + parser.add_argument('--lr_g', type=float, default=2e-5) + parser.add_argument('--lr_d', type=float, default=4e-5) + parser.add_argument('--init_lr', type=float, default=1e-4) + parser.add_argument('--wadvg', type=float, default=300.0, help='Adversarial loss weight for G') + parser.add_argument('--wadvd', type=float, default=300.0, help='Adversarial loss weight for D') + parser.add_argument( + '--gray_adv', action='store_true', + help="If given, train adversarial with gray scale image instead of RGB image to reduce color effect of anime style") + # Loss weight VGG19 + parser.add_argument('--wcon', type=float, default=1.5, help='Content loss weight') # 1.5 for Hayao, 2.0 for Paprika, 1.2 for Shinkai + parser.add_argument('--wgra', type=float, default=5.0, help='Gram loss weight') # 2.5 for Hayao, 0.6 for Paprika, 2.0 for Shinkai + parser.add_argument('--wcol', type=float, default=30.0, help='Color loss weight') # 15. for Hayao, 50. for Paprika, 10. for Shinkai + parser.add_argument('--wtvar', type=float, default=1.0, help='Total variation loss') # 1. for Hayao, 0.1 for Paprika, 1. for Shinkai + parser.add_argument('--d_layers', type=int, default=2, help='Discriminator conv layers') + parser.add_argument('--d_noise', action='store_true') + + # DDP + parser.add_argument('--ddp', action='store_true') + parser.add_argument("--local-rank", default=0, type=int) + parser.add_argument("--world-size", default=2, type=int) + + return parser.parse_args() + + +def check_params(args): + # dataset/Hayao + dataset/train_photo -> train_photo_Hayao + args.dataset = f"{os.path.basename(args.real_image_dir)}_{os.path.basename(args.anime_image_dir)}" + assert args.gan_loss in {'lsgan', 'hinge', 'bce'}, f'{args.gan_loss} is not supported' + + +def main(args, logger): + check_params(args) + + if not torch.cuda.is_available(): + logger.info("CUDA not found, use CPU") + # Just for debugging purpose, set to minimum config + # to avoid 🔥 the computer... 
+ args.device = 'cpu' + args.debug_samples = 10 + args.batch_size = 2 + else: + logger.info(f"Use GPU: {torch.cuda.get_device_name(0)}") + + norm_type = "instance" + if args.model == 'v1': + G = GeneratorV1(args.dataset) + elif args.model == 'v2': + G = GeneratorV2(args.dataset) + norm_type = "layer" + elif args.model == 'v3': + G = GeneratorV3(args.dataset) + + D = Discriminator( + args.dataset, + num_layers=args.d_layers, + use_sn=args.use_sn, + norm_type=norm_type, + ) + + start_e = 0 + start_e_init = 0 + + trainer = Trainer( + generator=G, + discriminator=D, + config=args, + logger=logger, + ) + + if args.resume_G_init.lower() != 'false': + start_e_init = load_checkpoint(G, args.resume_G_init) + 1 + if args.local_rank == 0: + logger.info(f"G content weight loaded from {args.resume_G_init}") + elif args.resume_G.lower() != 'false' and args.resume_D.lower() != 'false': + # You should provide both + try: + start_e = load_checkpoint(G, args.resume_G) + if args.local_rank == 0: + logger.info(f"G weight loaded from {args.resume_G}") + load_checkpoint(D, args.resume_D) + if args.local_rank == 0: + logger.info(f"D weight loaded from {args.resume_D}") + # If loaded both weight, turn off init G phrase + args.init_epochs = 0 + + except Exception as e: + print('Could not load checkpoint, train from scratch', e) + elif args.resume: + # Try to load from working dir + logger.info(f"Loading weight from {trainer.checkpoint_path_G}") + start_e = load_checkpoint(G, trainer.checkpoint_path_G) + logger.info(f"Loading weight from {trainer.checkpoint_path_D}") + load_checkpoint(D, trainer.checkpoint_path_D) + args.init_epochs = 0 + + dataset = AnimeDataSet( + args.anime_image_dir, + args.real_image_dir, + args.debug_samples, + args.cache, + imgsz=args.imgsz, + resize_method=args.resize_method, + ) + if args.local_rank == 0: + logger.info(f"Start from epoch {start_e}, {start_e_init}") + trainer.train(dataset, start_e, start_e_init) + +if __name__ == '__main__': + args = parse_args() + real_name = os.path.basename(args.real_image_dir) + anime_name = os.path.basename(args.anime_image_dir) + args.exp_dir = f"{args.exp_dir}_{real_name}_{anime_name}" + + os.makedirs(args.exp_dir, exist_ok=True) + logger = get_logger(os.path.join(args.exp_dir, "train.log")) + + if args.local_rank == 0: + logger.info("# ==== Train Config ==== #") + for arg in vars(args): + logger.info(f"{arg} {getattr(args, arg)}") + logger.info("==========================") + + main(args, logger) diff --git a/trainer/__init__.py b/trainer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0f62643f82f3271add547b47aac31c1d3db7a0c9 --- /dev/null +++ b/trainer/__init__.py @@ -0,0 +1,437 @@ +import os +import time +import shutil + +import torch +import cv2 +import torch.optim as optim +import numpy as np +from glob import glob +from torch.cuda.amp import GradScaler, autocast +from torch.nn.parallel.distributed import DistributedDataParallel +from torch.utils.data import Dataset, DataLoader +from tqdm import tqdm +from utils.image_processing import denormalize_input, preprocess_images, resize_image +from losses import LossSummary, AnimeGanLoss, to_gray_scale +from utils import load_checkpoint, save_checkpoint, read_image +from utils.common import set_lr +from color_transfer import color_transfer_pytorch + + +def transfer_color_and_rescale(src, target): + """Transfer color from src image to target then rescale to [-1, 1]""" + out = color_transfer_pytorch(src, target) # [0, 1] + out = (out / 0.5) - 1 + return out + +def 
gaussian_noise(): + gaussian_mean = torch.tensor(0.0) + gaussian_std = torch.tensor(0.1) + return torch.normal(gaussian_mean, gaussian_std) + +def convert_to_readable(seconds): + return time.strftime('%H:%M:%S', time.gmtime(seconds)) + + +def revert_to_np_image(image_tensor): + image = image_tensor.cpu().numpy() + # CHW + image = image.transpose(1, 2, 0) + image = denormalize_input(image, dtype=np.int16) + return image[..., ::-1] # to RGB + + +def save_generated_images(images: torch.Tensor, save_dir: str): + """Save generated images `(*, 3, H, W)` range [-1, 1] into disk""" + os.makedirs(save_dir, exist_ok=True) + images = images.clone().detach().cpu().numpy() + images = images.transpose(0, 2, 3, 1) + n_images = len(images) + + for i in range(n_images): + img = images[i] + img = denormalize_input(img, dtype=np.int16) + img = img[..., ::-1] + cv2.imwrite(os.path.join(save_dir, f"G{i}.jpg"), img) + + +class DDPTrainer: + def _init_distributed(self): + if self.cfg.ddp: + self.logger.info("Setting up DDP") + self.pg = torch.distributed.init_process_group( + backend="nccl", + rank=self.cfg.local_rank, + world_size=self.cfg.world_size + ) + self.G = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.G, self.pg) + self.D = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.D, self.pg) + torch.cuda.set_device(self.cfg.local_rank) + self.G.cuda(self.cfg.local_rank) + self.D.cuda(self.cfg.local_rank) + self.logger.info("Setting up DDP Done") + + def _init_amp(self, enabled=False): + # self.scaler = torch.cuda.amp.GradScaler(enabled=enabled, growth_interval=100) + self.scaler_g = GradScaler(enabled=enabled) + self.scaler_d = GradScaler(enabled=enabled) + if self.cfg.ddp: + self.G = DistributedDataParallel( + self.G, device_ids=[self.cfg.local_rank], + output_device=self.cfg.local_rank, + find_unused_parameters=False) + + self.D = DistributedDataParallel( + self.D, device_ids=[self.cfg.local_rank], + output_device=self.cfg.local_rank, + find_unused_parameters=False) + self.logger.info("Set DistributedDataParallel") + + +class Trainer(DDPTrainer): + """ + Base Trainer class + """ + + def __init__( + self, + generator, + discriminator, + config, + logger, + ) -> None: + self.G = generator + self.D = discriminator + self.cfg = config + self.max_norm = 10 + self.device_type = 'cuda' if self.cfg.device.startswith('cuda') else 'cpu' + self.optimizer_g = optim.Adam(self.G.parameters(), lr=self.cfg.lr_g, betas=(0.5, 0.999)) + self.optimizer_d = optim.Adam(self.D.parameters(), lr=self.cfg.lr_d, betas=(0.5, 0.999)) + self.loss_tracker = LossSummary() + if self.cfg.ddp: + self.device = torch.device(f"cuda:{self.cfg.local_rank}") + logger.info(f"---------{self.cfg.local_rank} {self.device}") + else: + self.device = torch.device(self.cfg.device) + self.loss_fn = AnimeGanLoss(self.cfg, self.device, self.cfg.gray_adv) + self.logger = logger + self._init_working_dir() + self._init_distributed() + self._init_amp(enabled=self.cfg.amp) + + def _init_working_dir(self): + """Init working directory for saving checkpoint, ...""" + os.makedirs(self.cfg.exp_dir, exist_ok=True) + Gname = self.G.name + Dname = self.D.name + self.checkpoint_path_G_init = os.path.join(self.cfg.exp_dir, f"{Gname}_init.pt") + self.checkpoint_path_G = os.path.join(self.cfg.exp_dir, f"{Gname}.pt") + self.checkpoint_path_D = os.path.join(self.cfg.exp_dir, f"{Dname}.pt") + self.save_image_dir = os.path.join(self.cfg.exp_dir, "generated_images") + self.example_image_dir = os.path.join(self.cfg.exp_dir, "train_images") + os.makedirs(self.save_image_dir, 
exist_ok=True) + os.makedirs(self.example_image_dir, exist_ok=True) + + def init_weight_G(self, weight: str): + """Init Generator weight""" + return load_checkpoint(self.G, weight) + + def init_weight_D(self, weight: str): + """Init Discriminator weight""" + return load_checkpoint(self.D, weight) + + def pretrain_generator(self, train_loader, start_epoch): + """ + Pretrain Generator to recontruct input image. + """ + init_losses = [] + set_lr(self.optimizer_g, self.cfg.init_lr) + for epoch in range(start_epoch, self.cfg.init_epochs): + # Train with content loss only + + pbar = tqdm(train_loader) + for data in pbar: + img = data["image"].to(self.device) + + self.optimizer_g.zero_grad() + + with autocast(enabled=self.cfg.amp): + fake_img = self.G(img) + loss = self.loss_fn.content_loss_vgg(img, fake_img) + + self.scaler_g.scale(loss).backward() + self.scaler_g.step(self.optimizer_g) + self.scaler_g.update() + + if self.cfg.ddp: + torch.distributed.barrier() + + init_losses.append(loss.cpu().detach().numpy()) + avg_content_loss = sum(init_losses) / len(init_losses) + pbar.set_description(f'[Init Training G] content loss: {avg_content_loss:2f}') + + save_checkpoint(self.G, self.checkpoint_path_G_init, self.optimizer_g, epoch) + if self.cfg.local_rank == 0: + self.generate_and_save(self.cfg.test_image_dir, subname='initg') + self.logger.info(f"Epoch {epoch}/{self.cfg.init_epochs}") + + set_lr(self.optimizer_g, self.cfg.lr_g) + + def train_epoch(self, epoch, train_loader): + pbar = tqdm(train_loader, total=len(train_loader)) + for data in pbar: + img = data["image"].to(self.device) + anime = data["anime"].to(self.device) + anime_gray = data["anime_gray"].to(self.device) + anime_smt_gray = data["smooth_gray"].to(self.device) + + # ---------------- TRAIN D ---------------- # + self.optimizer_d.zero_grad() + + with autocast(enabled=self.cfg.amp): + fake_img = self.G(img) + # Add some Gaussian noise to images before feeding to D + if self.cfg.d_noise: + fake_img += gaussian_noise() + anime += gaussian_noise() + anime_gray += gaussian_noise() + anime_smt_gray += gaussian_noise() + + if self.cfg.gray_adv: + fake_img = to_gray_scale(fake_img) + + fake_d = self.D(fake_img) + real_anime_d = self.D(anime) + real_anime_gray_d = self.D(anime_gray) + real_anime_smt_gray_d = self.D(anime_smt_gray) + + loss_d = self.loss_fn.compute_loss_D( + fake_d, + real_anime_d, + real_anime_gray_d, + real_anime_smt_gray_d + ) + + self.scaler_d.scale(loss_d).backward() + self.scaler_d.unscale_(self.optimizer_d) + torch.nn.utils.clip_grad_norm_(self.D.parameters(), max_norm=self.max_norm) + self.scaler_d.step(self.optimizer_d) + self.scaler_d.update() + if self.cfg.ddp: + torch.distributed.barrier() + self.loss_tracker.update_loss_D(loss_d) + + # ---------------- TRAIN G ---------------- # + self.optimizer_g.zero_grad() + + with autocast(enabled=self.cfg.amp): + fake_img = self.G(img) + + if self.cfg.gray_adv: + fake_d = self.D(to_gray_scale(fake_img)) + else: + fake_d = self.D(fake_img) + + ( + adv_loss, con_loss, + gra_loss, col_loss, + tv_loss + ) = self.loss_fn.compute_loss_G( + fake_img, + img, + fake_d, + anime_gray, + ) + loss_g = adv_loss + con_loss + gra_loss + col_loss + tv_loss + if torch.isnan(adv_loss).any(): + self.logger.info("----------------------------------------------") + self.logger.info(fake_d) + self.logger.info(adv_loss) + self.logger.info("----------------------------------------------") + raise ValueError("NAN loss!!") + + self.scaler_g.scale(loss_g).backward() + 
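# AMP note: loss_g was scaled by scaler_g, so its gradients must be unscaled + # with that same scaler before clipping; otherwise clip_grad_norm_ would see + # scaled gradients and scaler_g.step() would unscale them a second time. +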
self.scaler_g.unscale_(self.optimizer_g) + grad = torch.nn.utils.clip_grad_norm_(self.G.parameters(), max_norm=self.max_norm) + self.scaler_g.step(self.optimizer_g) + self.scaler_g.update() + if self.cfg.ddp: + torch.distributed.barrier() + + self.loss_tracker.update_loss_G(adv_loss, gra_loss, col_loss, con_loss) + pbar.set_description(f"{self.loss_tracker.get_loss_description()} - {grad:.3f}") + + def get_train_loader(self, dataset): + if self.cfg.ddp: + train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) + else: + train_sampler = None + return DataLoader( + dataset, + batch_size=self.cfg.batch_size, + num_workers=self.cfg.num_workers, + pin_memory=True, + shuffle=train_sampler is None, + sampler=train_sampler, + drop_last=True, + # collate_fn=collate_fn, + ) + + def maybe_increase_imgsz(self, epoch, train_dataset): + """ + Increase the training image size at specific epochs. + + The first 50% of epochs train at imgsz[0]; over the remaining 50%, the size steps up every `epochs / 2 / (len(imgsz) - 1)` epochs. + + Args: + epoch: Current epoch + train_dataset: Dataset + + Examples: + ``` + epochs = 100 + imgsz = [256, 352, 416, 512] + => [(0, 256), (50, 352), (66, 416), (82, 512)] + ``` + """ + epochs = self.cfg.epochs + imgsz = self.cfg.imgsz + num_size_remains = len(imgsz) - 1 + half_epochs = epochs // 2 + + if len(imgsz) == 1: + new_size = imgsz[0] + elif epoch < half_epochs: + new_size = imgsz[0] + else: + per_epoch_increment = int(half_epochs / num_size_remains) + found = None + for i, size in enumerate(imgsz[:]): + if epoch < half_epochs + per_epoch_increment * i: + found = size + break + if not found: + found = imgsz[-1] + new_size = found + + self.logger.info(f"Check {imgsz}, {new_size}, {train_dataset.imgsz}") + if new_size != train_dataset.imgsz: + train_dataset.set_imgsz(new_size) + self.logger.info(f"Increase image size to {new_size} at epoch {epoch}") + + def train(self, train_dataset: Dataset, start_epoch=0, start_epoch_g=0): + """ + Train Generator and Discriminator.
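+ + Args: + train_dataset: Dataset yielding dicts with "image", "anime", "anime_gray" and "smooth_gray" tensors. + start_epoch: Epoch to resume GAN training from. + start_epoch_g: Epoch to resume generator pretraining from.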
+ """ + self.logger.info(self.device) + self.G.to(self.device) + self.D.to(self.device) + + self.pretrain_generator(self.get_train_loader(train_dataset), start_epoch_g) + + if self.cfg.local_rank == 0: + self.logger.info(f"Start training for {self.cfg.epochs} epochs") + + for i, data in enumerate(train_dataset): + for k in data.keys(): + image = data[k] + cv2.imwrite( + os.path.join(self.example_image_dir, f"data_{k}_{i}.jpg"), + revert_to_np_image(image) + ) + if i == 2: + break + + end = None + num_iter = 0 + per_epoch_times = [] + for epoch in range(start_epoch, self.cfg.epochs): + self.maybe_increase_imgsz(epoch, train_dataset) + + start = time.time() + self.train_epoch(epoch, self.get_train_loader(train_dataset)) + + if epoch % self.cfg.save_interval == 0 and self.cfg.local_rank == 0: + save_checkpoint(self.G, self.checkpoint_path_G,self.optimizer_g, epoch) + save_checkpoint(self.D, self.checkpoint_path_D, self.optimizer_d, epoch) + self.generate_and_save(self.cfg.test_image_dir) + + if epoch % 10 == 0: + self.copy_results(epoch) + + num_iter += 1 + + if self.cfg.local_rank == 0: + end = time.time() + if end is None: + eta = 9999 + else: + per_epoch_time = (end - start) + per_epoch_times.append(per_epoch_time) + eta = np.mean(per_epoch_times) * (self.cfg.epochs - epoch) + eta = convert_to_readable(eta) + self.logger.info(f"epoch {epoch}/{self.cfg.epochs}, ETA: {eta}") + + def generate_and_save( + self, + image_dir, + max_imgs=15, + subname='gen' + ): + ''' + Generate and save images + ''' + start = time.time() + self.G.eval() + + max_iter = max_imgs + fake_imgs = [] + real_imgs = [] + image_files = glob(os.path.join(image_dir, "*")) + + for i, image_file in enumerate(image_files): + image = read_image(image_file) + image = resize_image(image) + real_imgs.append(image.copy()) + image = preprocess_images(image) + image = image.to(self.device) + with torch.no_grad(): + with autocast(enabled=self.cfg.amp): + fake_img = self.G(image) + # fake_img = to_gray_scale(fake_img) + fake_img = fake_img.detach().cpu().numpy() + # Channel first -> channel last + fake_img = fake_img.transpose(0, 2, 3, 1) + fake_imgs.append(denormalize_input(fake_img, dtype=np.int16)[0]) + + if i + 1 == max_iter: + break + + # fake_imgs = np.concatenate(fake_imgs, axis=0) + + for i, (real_img, fake_img) in enumerate(zip(real_imgs, fake_imgs)): + img = np.concatenate((real_img, fake_img), axis=1) # Concate aross width + save_path = os.path.join(self.save_image_dir, f'{subname}_{i}.jpg') + if not cv2.imwrite(save_path, img[..., ::-1]): + self.logger.info(f"Save generated image failed, {save_path}, {img.shape}") + elapsed = time.time() - start + self.logger.info(f"Generated {len(fake_imgs)} images in {elapsed:.3f}s.") + + def copy_results(self, epoch): + """Copy result (Weight + Generated images) to each epoch folder + Every N epoch + """ + copy_dir = os.path.join(self.cfg.exp_dir, f"epoch_{epoch}") + os.makedirs(copy_dir, exist_ok=True) + + shutil.copy2( + self.checkpoint_path_G, + copy_dir + ) + + dest = os.path.join(copy_dir, os.path.basename(self.save_image_dir)) + shutil.copytree( + self.save_image_dir, + dest, + dirs_exist_ok=True + ) diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8f5eb1d53c3e0a53bf5c9c57edbebea70516fb5a --- /dev/null +++ b/utils/__init__.py @@ -0,0 +1,21 @@ +from .common import * +from .image_processing import * + +class DefaultArgs: + dataset ='Hayao' + data_dir ='/content' + epochs = 10 + batch_size = 1 + checkpoint_dir 
='/content/checkpoints' + save_image_dir ='/content/images' + display_image =True + save_interval =2 + debug_samples =0 + lr_g = 0.001 + lr_d = 0.002 + wadvg = 300.0 + wadvd = 300.0 + wcon = 1.5 + wgra = 3 + wcol = 10 + use_sn = False diff --git a/utils/common.py b/utils/common.py new file mode 100644 index 0000000000000000000000000000000000000000..478d989b0495e246e4027e66dce75c49356ab369 --- /dev/null +++ b/utils/common.py @@ -0,0 +1,188 @@ +import torch +import gc +import os +import torch.nn as nn +import urllib.request +import cv2 +from tqdm import tqdm + +HTTP_PREFIXES = [ + 'http', + 'data:image/jpeg', +] + + +RELEASED_WEIGHTS = { + "hayao:v1": ( + "v1", + "https://github.com/ptran1203/pytorch-animeGAN/releases/download/v1.0/generator_hayao.pth" + ), + "hayao": ( + "v1", + "https://github.com/ptran1203/pytorch-animeGAN/releases/download/v1.0/generator_hayao.pth" + ), + "shinkai:v1": ( + "v1", + "https://github.com/ptran1203/pytorch-animeGAN/releases/download/v1.0/generator_shinkai.pth" + ), + "shinkai": ( + "v1", + "https://github.com/ptran1203/pytorch-animeGAN/releases/download/v1.0/generator_shinkai.pth" + ), + + ## VER 2 ## + "hayao:v2": ( + # Dataset trained on Google Landmark micro as training real photo + "v2", + "https://github.com/ptran1203/pytorch-animeGAN/releases/download/v1.2/GeneratorV2_gldv2_Hayao.pt" + ), + "shinkai:v2": ( + # Dataset trained on Google Landmark micro as training real photo + "v2", + "https://github.com/ptran1203/pytorch-animeGAN/releases/download/v1.2/GeneratorV2_gldv2_Shinkai.pt" + ), + ## Face portrait + "arcane:v2": ( + "v2", + "https://github.com/ptran1203/pytorch-animeGAN/releases/download/v1.2/GeneratorV2_ffhq_Arcane_512.pt" + ) +} + +def is_image_file(path): + _, ext = os.path.splitext(path) + return ext.lower() in (".png", ".jpg", ".jpeg", ".webp") + +def is_video_file(path): + # https://moviepy-tburrows13.readthedocs.io/en/improve-docs/ref/VideoClip/VideoFileClip.html + _, ext = os.path.splitext(path) + return ext.lower() in (".mp4", ".mov", ".ogv", ".avi", ".mpeg") + + +def read_image(path): + """ + Read image from given path + """ + + if any(path.startswith(p) for p in HTTP_PREFIXES): + urllib.request.urlretrieve(path, "temp.jpg") + path = "temp.jpg" + + img = cv2.imread(path) + if img.shape[-1] == 4: + # 4 channels image + img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB) + else: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + return img + + +def save_checkpoint(model, path, optimizer=None, epoch=None): + checkpoint = { + 'model_state_dict': model.state_dict(), + 'epoch': epoch, + } + if optimizer is not None: + checkpoint['optimizer_state_dict'] = optimizer.state_dict() + + torch.save(checkpoint, path) + +def maybe_remove_module(state_dict): + # Remove added module ins state_dict in ddp training + # https://discuss.pytorch.org/t/why-are-state-dict-keys-getting-prepended-with-the-string-module/104627/3 + new_state_dict = {} + module_str = 'module.' 
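+ # Checkpoints saved from a DDP-wrapped model have keys prefixed with + # 'module.' (e.g. 'module.conv.weight'); strip the prefix so the weights + # load into a bare, non-DDP model.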
+ for k, v in state_dict.items(): + + if k.startswith(module_str): + k = k[len(module_str):] + new_state_dict[k] = v + return new_state_dict + + +def load_checkpoint(model, path, optimizer=None, strip_optimizer=False, map_location=None) -> int: + state_dict, path = load_state_dict(path, map_location) + model_state_dict = maybe_remove_module(state_dict['model_state_dict']) + model.load_state_dict( + model_state_dict, + strict=True + ) + if 'optimizer_state_dict' in state_dict: + if optimizer is not None: + optimizer.load_state_dict(state_dict['optimizer_state_dict']) + if strip_optimizer: + del state_dict["optimizer_state_dict"] + torch.save(state_dict, path) + print(f"Optimizer stripped and saved to {path}") + + epoch = state_dict.get('epoch', 0) + return epoch + + +def load_state_dict(weight, map_location) -> dict: + if weight.lower() in RELEASED_WEIGHTS: + weight = _download_weight(weight.lower()) + + if map_location is None: + # auto select + map_location = 'cuda' if torch.cuda.is_available() else 'cpu' + state_dict = torch.load(weight, map_location=map_location) + + return state_dict, weight + + +def initialize_weights(net): + for m in net.modules(): + try: + if isinstance(m, nn.Conv2d): + # m.weight.data.normal_(0, 0.02) + torch.nn.init.xavier_uniform_(m.weight) + m.bias.data.zero_() + elif isinstance(m, nn.ConvTranspose2d): + # m.weight.data.normal_(0, 0.02) + torch.nn.init.xavier_uniform_(m.weight) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + # m.weight.data.normal_(0, 0.02) + torch.nn.init.xavier_uniform_(m.weight) + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + except Exception as e: + # print(f'SKip layer {m}, {e}') + pass + + +def set_lr(optimizer, lr): + for param_group in optimizer.param_groups: + param_group['lr'] = lr + + +class DownloadProgressBar(tqdm): + ''' + https://stackoverflow.com/questions/15644964/python-progress-bar-and-downloads + ''' + def update_to(self, b=1, bsize=1, tsize=None): + if tsize is not None: + self.total = tsize + self.update(b * bsize - self.n) + + +def _download_weight(weight): + ''' + Download weight and save to local file + ''' + os.makedirs('.cache', exist_ok=True) + url = RELEASED_WEIGHTS[weight][1] + filename = os.path.basename(url) + save_path = f'.cache/{filename}' + + if os.path.isfile(save_path): + return save_path + + desc = f'Downloading {url} to {save_path}' + with DownloadProgressBar(unit='B', unit_scale=True, miniters=1, desc=desc) as t: + urllib.request.urlretrieve(url, save_path, reporthook=t.update_to) + + return save_path + diff --git a/utils/fast_numpyio.py b/utils/fast_numpyio.py new file mode 100644 index 0000000000000000000000000000000000000000..ee4ebf822fa3b874c8c2c596da6dd9f6385b9108 --- /dev/null +++ b/utils/fast_numpyio.py @@ -0,0 +1,43 @@ +# code from https://github.com/divideconcept/fastnumpyio/blob/main/fastnumpyio.py + +import sys +import numpy as np +import numpy.lib.format +import struct + +def save(file, array): + magic_string=b"\x93NUMPY\x01\x00v\x00" + header=bytes(("{'descr': '"+array.dtype.descr[0][1]+"', 'fortran_order': False, 'shape': "+str(array.shape)+", }").ljust(127-len(magic_string))+"\n",'utf-8') + if type(file) == str: + file=open(file,"wb") + file.write(magic_string) + file.write(header) + file.write(array.data) + +def pack(array): + size=len(array.shape) + return bytes(array.dtype.byteorder.replace('=','<' if sys.byteorder == 'little' else 
'>')+array.dtype.kind,'utf-8')+array.dtype.itemsize.to_bytes(1,byteorder='little')+struct.pack(f' 1.0e2] = 1.0e2 + # x[x < -1.0e2] = -1.0e2 + + G = torch.mm(x, x.T) + G = torch.clamp(G, -64990.0, 64990.0) + # normalize by total elements + result = G.div(b * c * w * h) + return result + + + +def divisible(dim): + ''' + Make width and height divisible by 32 + ''' + width, height = dim + return width - (width % 32), height - (height % 32) + + +def resize_image(image, width=None, height=None, inter=cv2.INTER_AREA): + dim = None + h, w = image.shape[:2] + + if width and height: + return cv2.resize(image, divisible((width, height)), interpolation=inter) + + if width is None and height is None: + return cv2.resize(image, divisible((w, h)), interpolation=inter) + + if width is None: + r = height / float(h) + dim = (int(w * r), height) + + else: + r = width / float(w) + dim = (width, int(h * r)) + + return cv2.resize(image, divisible(dim), interpolation=inter) + + +def normalize_input(images): + ''' + [0, 255] -> [-1, 1] + ''' + return images / 127.5 - 1.0 + + +def denormalize_input(images, dtype=None): + ''' + [-1, 1] -> [0, 255] + ''' + images = images * 127.5 + 127.5 + + if dtype is not None: + if isinstance(images, torch.Tensor): + images = images.type(dtype) + else: + # numpy.ndarray + images = images.astype(dtype) + + return images + + +def preprocess_images(images): + ''' + Preprocess image for inference + + @Arguments: + - images: np.ndarray + + @Returns + - images: torch.tensor + ''' + images = images.astype(np.float32) + + # Normalize to [-1, 1] + images = normalize_input(images) + images = torch.from_numpy(images) + + # Add batch dim + if len(images.shape) == 3: + images = images.unsqueeze(0) + + # channel first + images = images.permute(0, 3, 1, 2) + + return images + +def compute_data_mean(data_folder): + if not os.path.exists(data_folder): + raise FileNotFoundError(f'Folder {data_folder} does not exits') + + image_files = os.listdir(data_folder) + total = np.zeros(3) + + print(f"Compute mean (R, G, B) from {len(image_files)} images") + + for img_file in tqdm(image_files): + path = os.path.join(data_folder, img_file) + image = cv2.imread(path) + total += image.mean(axis=(0, 1)) + + channel_mean = total / len(image_files) + mean = np.mean(channel_mean) + + return mean - channel_mean[...,::-1] # Convert to BGR for training + + +if __name__ == '__main__': + t = torch.rand(2, 14, 32, 32) + + with torch.autocast("cpu"): + print(gram(t)) diff --git a/utils/logger.py b/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..f72fcbffaafa4b9c2d32e161a2a149e7b22ea556 --- /dev/null +++ b/utils/logger.py @@ -0,0 +1,24 @@ +import logging + + +def get_logger(path, *args, **kwargs): + # logger = logging.getLogger('train') + # logger.setLevel(logging.NOTSET) + # formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + # # add filehandler + # fh = logging.FileHandler(path) + # fh.setLevel(logging.NOTSET) + # fh.setFormatter(formatter) + # ch = logging.StreamHandler() + # ch.setLevel(logging.ERROR) + # logger.addHandler(fh) + # logger.addHandler(ch) + # return logger + logging.basicConfig(format = '%(asctime)s %(message)s', + datefmt = '%m/%d/%Y %I:%M:%S %p', + handlers=[ + logging.FileHandler(path), + logging.StreamHandler() + ], + level=logging.DEBUG) + return logging \ No newline at end of file
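For reference, a minimal usage sketch (not part of the patch) of how the helpers added above could be combined for single-image stylization. The generator module is assumed to be constructed elsewhere, since its class is not defined in this diff; the `stylize` function name and file paths are hypothetical, and `"hayao:v2"` is one of the `RELEASED_WEIGHTS` keys from `utils/common.py`.

```
import cv2
import numpy as np
import torch

# All of these are re-exported by utils/__init__.py via star imports.
from utils import denormalize_input, load_checkpoint, preprocess_images, read_image, resize_image


def stylize(generator: torch.nn.Module, image_path: str, out_path: str, weight: str = "hayao:v2"):
    """Hypothetical helper: stylize one image with a pretrained generator (CPU assumed)."""
    load_checkpoint(generator, weight)   # resolves RELEASED_WEIGHTS keys and caches the download
    generator.eval()

    image = read_image(image_path)       # RGB, HWC, uint8
    image = resize_image(image)          # make width/height divisible by 32
    batch = preprocess_images(image)     # [0, 255] -> [-1, 1], NCHW float tensor

    with torch.no_grad():
        fake = generator(batch)          # expected output range [-1, 1]

    fake = fake[0].permute(1, 2, 0).cpu().numpy()          # CHW -> HWC
    fake = np.clip(denormalize_input(fake), 0, 255).astype(np.uint8)
    cv2.imwrite(out_path, fake[..., ::-1])                  # RGB -> BGR for OpenCV
```

Under these assumptions, `stylize(G, "photo.jpg", "out.jpg")` mirrors what `Trainer.generate_and_save` does for its test images, minus the side-by-side concatenation of the real and generated frames.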