from argparse import Namespace
import os
from os.path import join as pjoin
import random
import sys
from typing import (
    Iterable,
    Optional,
    Sequence,
)

import cv2
import numpy as np
from PIL import Image
import torch
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import (
    Compose,
    Grayscale,
    Resize,
    ToTensor,
    Normalize,
)

from losses.joint_loss import JointLoss
from model import Generator
from tools.initialize import Initializer
from tools.match_skin_histogram import match_skin_histogram
from utils.projector_arguments import ProjectorArguments
from utils import torch_helpers as th
from utils.torch_helpers import make_image
from utils.misc import stem
from utils.optimize import Optimizer
from models.degrade import (
    Degrade,
    Downsample,
)
from huggingface_hub import hf_hub_download


# Hugging Face Hub access token passed to hf_hub_download() in create_generator().
# SECURITY: a token used to be hard-coded here and committed to source control;
# treat that token as leaked and revoke it. The token is now read from the
# HF_TOKEN environment variable. When unset this is None, and hf_hub_download
# presumably falls back to its own credential lookup (cached login) or anonymous
# access -- see the huggingface_hub docs.
TOKEN = os.environ.get("HF_TOKEN")


def set_random_seed(seed: int):
    """Seed the torch, Python ``random`` and NumPy RNGs with ``seed``."""
    # FIXME (xuanluo): this setup still allows randomness somehow
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)


def read_images(paths: Iterable[str], max_size: Optional[int] = None) -> torch.Tensor:
    """Load images as single-channel tensors and stack them into one batch.

    Args:
        paths: image file paths.
        max_size: if given and an image is wider than this, the image is resized
            to ``(max_size, max_size)``. Only the width is checked and the resize
            does not preserve aspect ratio.

    Returns:
        A tensor of shape (N, 1, H, W) holding ``ToTensor`` output (values in
        [0, 1]). All images must end up the same size for ``torch.stack``.

    Raises:
        ValueError: if ``paths`` yields no paths.
    """
    transform = Compose([Grayscale(), ToTensor()])
    imgs = []
    for path in paths:
        # The context manager closes the file handle once the image is decoded.
        with Image.open(path) as src:
            img = src
            if max_size is not None and img.width > max_size:
                img = src.resize((max_size, max_size))
            imgs.append(transform(img))
    if not imgs:
        raise ValueError("read_images() needs at least one image path")
    return torch.stack(imgs, 0)


def normalize(img: torch.Tensor, mean=0.5, std=0.5):
    """Return ``(img - mean) / std``; with the defaults this maps [0, 1] -> [-1, 1]."""
    return (img - mean) / std


def create_generator(file_name: str, path: str, args: Namespace, device: torch.device):
    """Download a generator checkpoint from the Hugging Face Hub and load it.

    Args:
        file_name: checkpoint file name inside the Hub repository.
        path: Hub repository id to download from.
        args: parsed options; ``args.generator_size`` is passed to ``Generator``.
        device: device the generator is moved to.

    Returns:
        A ``Generator`` in eval mode on ``device`` with the checkpoint's
        ``"g_ema"`` weights loaded (``strict=False``: mismatched keys are ignored).
    """
    ckpt_path = hf_hub_download(str(path), str(file_name), use_auth_token=TOKEN)
    generator = Generator(args.generator_size, 512, 8)
    with open(ckpt_path, "rb") as f:
        # Map tensors to CPU first so a checkpoint saved on GPU also loads on
        # CPU-only machines; the model is moved to `device` below.
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        state = torch.load(f, map_location="cpu")
    generator.load_state_dict(state["g_ema"], strict=False)
    generator.eval()
    generator.to(device)
    return generator


def _write_image(path: str, img) -> None:
    """Write one image array with OpenCV, raising if the write fails."""
    # The last (channel) axis is reversed before writing; OpenCV expects BGR, so
    # this presumably converts RGB output of make_image() -- TODO confirm.
    if not cv2.imwrite(path, img[..., ::-1]):
        raise OSError(f"cv2.imwrite failed to write {path!r}")


def save(
    path_prefixes: Sequence[str],
    imgs: torch.Tensor,  # BCHW
    latents: torch.Tensor,
    noises: torch.Tensor,
    imgs_rand: Optional[torch.Tensor] = None,
):
    """Write images and latents to disk, one set of files per path prefix.

    For every prefix ``p`` this writes ``p + ".png"`` (from ``imgs``) and
    ``p + ".pt"`` (a dict with CPU tensors under ``"latent"`` and ``"noise"``).
    When ``imgs_rand`` is given it also writes ``p + "-rand.png"``.

    Raises:
        ValueError: if the lengths of ``path_prefixes``, ``imgs``, ``latents``
            (and ``imgs_rand``, when given) disagree.
        OSError: if OpenCV fails to write an image.
    """
    if not (len(path_prefixes) == len(imgs) and len(latents) == len(path_prefixes)):
        raise ValueError(
            f"length mismatch: {len(path_prefixes)} prefixes, "
            f"{len(imgs)} images, {len(latents)} latents"
        )
    if imgs_rand is not None and len(imgs) != len(imgs_rand):
        raise ValueError(
            f"length mismatch: {len(imgs)} images vs {len(imgs_rand)} imgs_rand"
        )

    imgs_arr = make_image(imgs)
    # NOTE(review): `noises` is zipped per prefix like `latents`. main() passes
    # generator.make_noise() here; if that is a per-layer list rather than a
    # per-sample batch, only one entry is saved per prefix -- confirm.
    for path_prefix, img, latent, noise in zip(path_prefixes, imgs_arr, latents, noises):
        out_dir = os.path.dirname(path_prefix)
        # os.makedirs("") raises, so skip it for prefixes without a directory part.
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        _write_image(path_prefix + ".png", img)
        torch.save(
            {"latent": latent.detach().cpu(), "noise": noise.detach().cpu()},
            path_prefix + ".pt",
        )

    if imgs_rand is not None:
        imgs_arr = make_image(imgs_rand)
        for path_prefix, img in zip(path_prefixes, imgs_arr):
            _write_image(path_prefix + "-rand.png", img)


def main(args):
    """Project the image at ``args.input`` and save the results.

    Writes the initialization (``<results_dir>/<stem>-<opts>-init.*``) and the
    optimized output (``<results_dir>/<stem>-<opts>.*`` plus ``-rand.png``), and
    logs optimization to TensorBoard under ``<log_dir>/<stem>/<opts>``.
    """
    opt_str = ProjectorArguments.to_string(args)
    print(opt_str)

    if args.rand_seed is not None:
        set_random_seed(args.rand_seed)
    device = th.device()

    # read inputs. TODO imgs_orig has channel 1
    imgs_orig = read_images([args.input], max_size=args.generator_size).to(device)
    # Replaced below by the histogram-matching result, but used as its input.
    imgs = normalize(imgs_orig)

    # initialize
    with torch.no_grad():
        init = Initializer(args).to(device)
        latent_init = init(imgs_orig)

    # create generator
    # NOTE(review): create_generator() is declared as
    # (file_name, path, args, device); this two-argument call raises TypeError.
    # Pass the checkpoint file name and Hub repository id here.
    generator = create_generator(args, device)

    # init noises
    with torch.no_grad():
        noises_init = generator.make_noise()

    # create a new input by matching the input's histogram to the sibling image
    with torch.no_grad():
        sibling, _, sibling_rgbs = generator(
            [latent_init], input_is_latent=True, noise=noises_init
        )
    mh_dir = pjoin(args.results_dir, stem(args.input))
    imgs = match_skin_histogram(
        imgs, sibling,
        args.spectral_sensitivity,
        pjoin(mh_dir, "input_sibling"),
        pjoin(mh_dir, "skin_mask"),
        matched_hist_fn=mh_dir.rstrip(os.sep) + f"_{args.spectral_sensitivity}.png",
        normalize=normalize,
    ).to(device)
    torch.cuda.empty_cache()
    # TODO imgs has channel 3

    degrade = Degrade(args).to(device)

    rgb_levels = generator.get_latent_size(args.coarse_min) // 2 + len(args.wplus_step) - 1
    criterion = JointLoss(
        args, imgs,
        sibling=sibling.detach(), sibling_rgbs=sibling_rgbs[:rgb_levels],
    ).to(device)

    # save initialization
    save(
        [pjoin(args.results_dir, f"{stem(args.input)}-{opt_str}-init")],
        sibling, latent_init, noises_init,
    )

    writer = SummaryWriter(pjoin(args.log_dir, f"{stem(args.input)}/{opt_str}"))
    # start optimize
    try:
        latent, noises = Optimizer.optimize(
            generator, criterion, degrade, imgs, latent_init, noises_init, args,
            writer=writer,
        )
    finally:
        # Flush pending TensorBoard events and release the event file.
        writer.close()

    # generate output; inference only, so no autograd graph is needed
    with torch.no_grad():
        img_out, _, _ = generator([latent], input_is_latent=True, noise=noises)
        img_out_rand_noise, _, _ = generator([latent], input_is_latent=True)

    # save output
    save(
        [pjoin(args.results_dir, f"{stem(args.input)}-{opt_str}")],
        img_out, latent, noises,
        imgs_rand=img_out_rand_noise,
    )


def parse_args():
    """Parse command-line options via ``ProjectorArguments``."""
    return ProjectorArguments().parse()


if __name__ == "__main__":
    sys.exit(main(parse_args()))