Spaces:
Runtime error
Runtime error
File size: 5,452 Bytes
47c46ea bca4552 47c46ea 34bf162 a965396 47c46ea 69542b0 da5331d 69542b0 8d584e9 47c46ea 4a93191 47c46ea |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 |
from argparse import Namespace
import os
from os.path import join as pjoin
import random
import sys
from typing import (
Iterable,
Optional,
)
import cv2
import numpy as np
from PIL import Image
import torch
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import (
Compose,
Grayscale,
Resize,
ToTensor,
Normalize,
)
from losses.joint_loss import JointLoss
from model import Generator
from tools.initialize import Initializer
from tools.match_skin_histogram import match_skin_histogram
from utils.projector_arguments import ProjectorArguments
from utils import torch_helpers as th
from utils.torch_helpers import make_image
from utils.misc import stem
from utils.optimize import Optimizer
from models.degrade import (
Degrade,
Downsample,
)
from huggingface_hub import hf_hub_download
# Hugging Face access token handed to hf_hub_download() in create_generator().
# SECURITY: a literal token must never live in source control. The value that
# used to be hard-coded on this line has been exposed and should be revoked.
# The token is now taken from the environment: HF_TOKEN first, then the older
# HUGGING_FACE_HUB_TOKEN name. It is None when neither variable is set.
TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
def set_random_seed(seed: int):
    """Seed every random number generator this script relies on.

    Seeds Python's `random`, NumPy's global generator and PyTorch, then asks
    cuDNN for deterministic kernels and disables its auto-tuner so that
    algorithm selection does not vary between runs.

    Args:
        seed: integer seed applied to all three generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # cuDNN may otherwise pick different (possibly non-deterministic)
    # algorithms from run to run; both flags are plain module attributes and
    # are safe to set on CPU-only machines.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # FIXME (xuanluo): some randomness may remain - presumably from ops that
    # have no deterministic implementation; TODO confirm which ones.
def read_images(paths: Iterable[str], max_size: Optional[int] = None):
    """Load image files into one stacked tensor.

    Every file is opened with PIL, optionally resized, passed through
    `Grayscale()` and `ToTensor()`, and the per-image tensors are stacked
    along a new leading dimension.

    Args:
        paths: image file paths, one per image. The argument is iterated, so
            pass a list/tuple of paths rather than a single path string.
        max_size: when given, an image whose *width* exceeds this value is
            resized to `(max_size, max_size)`. Only the width is tested and
            the aspect ratio is not preserved.

    Returns:
        The result of `torch.stack` over the transformed images (dimension 0
        indexes the images). `torch.stack` requires all images to end up with
        the same size.
    """
    transform = Compose(
        [
            Grayscale(),
            ToTensor(),
        ]
    )
    tensors = []
    for path in paths:
        # The context manager releases the file handle as soon as the pixel
        # data has been consumed, instead of leaving it to the garbage
        # collector.
        with Image.open(path) as source:
            img = source
            if max_size is not None and source.width > max_size:
                img = source.resize((max_size, max_size))
            tensors.append(transform(img))
    return torch.stack(tensors, 0)
def normalize(img: torch.Tensor, mean=0.5, std=0.5):
    """Subtract `mean` and divide by `std`.

    With the default arguments this maps the range [0, 1] onto [-1, 1].
    """
    centered = img - mean
    return centered / std
def create_generator(file_name: str, path: str, args: Namespace, device: torch.device):
    """Download a generator checkpoint from the Hugging Face Hub and load it.

    Args:
        file_name: name of the checkpoint file inside the Hub repository
            (passed as the second positional argument of `hf_hub_download`).
        path: Hub repository identifier (passed as the first positional
            argument of `hf_hub_download`).
        args: parsed options; only `args.generator_size` is read here.
        device: device the generator is moved to before it is returned.

    Returns:
        The `Generator` in eval mode, holding the checkpoint's `'g_ema'`
        weights (loaded with `strict=False`), placed on `device`.
    """
    # Kept under a separate name so the `path` parameter (the repo id) is not
    # shadowed by the local file path of the downloaded checkpoint.
    ckpt_path = hf_hub_download(path, file_name, use_auth_token=TOKEN)
    generator = Generator(args.generator_size, 512, 8)
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code - only point this at checkpoints from a trusted repository.
    with open(ckpt_path, 'rb') as f:
        # Deserialise onto the CPU so a checkpoint saved from a GPU also loads
        # on CPU-only hosts; the model is moved to `device` afterwards.
        checkpoint = torch.load(f, map_location="cpu")
    generator.load_state_dict(checkpoint['g_ema'], strict=False)
    generator.eval()
    generator.to(device)
    return generator
def save(
    path_prefixes: Iterable[str],
    imgs: torch.Tensor,  # BCHW
    latents: torch.Tensor,
    noises: torch.Tensor,
    imgs_rand: Optional[torch.Tensor] = None,
):
    """Write images plus their latent/noise tensors to disk.

    For each path prefix this writes `<prefix>.png` (the image) and
    `<prefix>.pt` (a dict with keys `"latent"` and `"noise"`, both detached
    and moved to the CPU). When `imgs_rand` is given, `<prefix>-rand.png` is
    written as well.

    Args:
        path_prefixes: output paths without extension, one per image.
        imgs: images to save; converted with `make_image` before writing.
        latents: one entry per path prefix.
        noises: iterated in lockstep with the prefixes; the element paired
            with a prefix is stored under `"noise"`.
        imgs_rand: optional second set of images, same length as `imgs`.

    Raises:
        AssertionError: if the numbers of prefixes, images, latents (and
            `imgs_rand`, when given) do not match.
    """
    # Materialise once: the prefixes are measured with len() and walked twice,
    # which a one-shot iterable (e.g. a generator) would not survive.
    path_prefixes = list(path_prefixes)
    assert len(path_prefixes) == len(imgs) and len(latents) == len(path_prefixes)
    if imgs_rand is not None:
        assert len(imgs) == len(imgs_rand)
    imgs_arr = make_image(imgs)
    for path_prefix, img, latent, noise in zip(path_prefixes, imgs_arr, latents, noises):
        out_dir = os.path.dirname(path_prefix)
        # A bare file name has an empty dirname, and os.makedirs("") raises.
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        # Reverse the last axis before handing the array to OpenCV
        # (presumably RGB -> BGR; assumes make_image returns channel-last
        # arrays - TODO confirm).
        cv2.imwrite(path_prefix + ".png", img[..., ::-1])
        torch.save(
            {"latent": latent.detach().cpu(), "noise": noise.detach().cpu()},
            path_prefix + ".pt",
        )
    if imgs_rand is not None:
        imgs_arr = make_image(imgs_rand)
        for path_prefix, img in zip(path_prefixes, imgs_arr):
            cv2.imwrite(path_prefix + "-rand.png", img[..., ::-1])
def main(args):
    """Run the full projection pipeline for the image at `args.input`.

    Steps, in order: print the option string, optionally seed the RNGs, read
    the input image, compute an initial latent with `Initializer`, build the
    generator, render the "sibling" image from the initial latent, replace the
    input with its skin-histogram-matched version, save the initialisation,
    run `Optimizer.optimize`, and save the final output together with a
    variant rendered without the optimised noise.

    Args:
        args: options produced by `ProjectorArguments`. Attributes read here:
            `rand_seed`, `input`, `generator_size`, `results_dir`,
            `spectral_sensitivity`, `coarse_min`, `wplus_step`, `log_dir`.
    """
    opt_str = ProjectorArguments.to_string(args)
    print(opt_str)
    if args.rand_seed is not None:
        set_random_seed(args.rand_seed)
    device = th.device()

    # read inputs. TODO imgs_orig has channel 1
    imgs_orig = read_images([args.input], max_size=args.generator_size).to(device)
    imgs = normalize(imgs_orig)  # actually this will be overwritten by the histogram matching result

    # initialize
    with torch.no_grad():
        init = Initializer(args).to(device)
        latent_init = init(imgs_orig)

    # create generator
    # NOTE(review): create_generator in this file is declared as
    # (file_name, path, args, device), so this two-argument call raises
    # TypeError. The checkpoint file name and Hub repo id are not available
    # in this function - confirm the intended values and pass them here.
    generator = create_generator(args, device)

    # init noises
    with torch.no_grad():
        noises_init = generator.make_noise()

    # create a new input by matching the input's histogram to the sibling image
    with torch.no_grad():
        sibling, _, sibling_rgbs = generator([latent_init], input_is_latent=True, noise=noises_init)
    mh_dir = pjoin(args.results_dir, stem(args.input))
    imgs = match_skin_histogram(
        imgs, sibling,
        args.spectral_sensitivity,
        pjoin(mh_dir, "input_sibling"),
        pjoin(mh_dir, "skin_mask"),
        matched_hist_fn=mh_dir.rstrip(os.sep) + f"_{args.spectral_sensitivity}.png",
        normalize=normalize,
    ).to(device)
    torch.cuda.empty_cache()
    # TODO imgs has channel 3

    degrade = Degrade(args).to(device)
    rgb_levels = generator.get_latent_size(args.coarse_min) // 2 + len(args.wplus_step) - 1
    criterion = JointLoss(
        args, imgs,
        sibling=sibling.detach(), sibling_rgbs=sibling_rgbs[:rgb_levels]).to(device)

    # save initialization
    save(
        [pjoin(args.results_dir, f"{stem(args.input)}-{opt_str}-init")],
        sibling, latent_init, noises_init,
    )

    writer = SummaryWriter(pjoin(args.log_dir, f"{stem(args.input)}/{opt_str}"))
    try:
        # start optimize
        latent, noises = Optimizer.optimize(
            generator, criterion, degrade, imgs, latent_init, noises_init, args,
            writer=writer,
        )
    finally:
        # Close the writer even if optimisation fails, so pending events are
        # flushed and the log file handle is released.
        writer.close()

    # generate output; the results are only written to disk, so no autograd
    # graph is needed for these two forward passes.
    with torch.no_grad():
        img_out, _, _ = generator([latent], input_is_latent=True, noise=noises)
        img_out_rand_noise, _, _ = generator([latent], input_is_latent=True)

    # save output
    save(
        [pjoin(args.results_dir, f"{stem(args.input)}-{opt_str}")],
        img_out, latent, noises,
        imgs_rand=img_out_rand_noise
    )
def parse_args():
    """Instantiate `ProjectorArguments` and return what its `parse()` yields."""
    arguments = ProjectorArguments()
    return arguments.parse()
if __name__ == "__main__":
    # Script entry point: parse the options, run the pipeline, and use main()'s
    # return value as the process exit status.
    exit_status = main(parse_args())
    sys.exit(exit_status)
|