feng2022's picture
anothertry
89d1ee7
raw
history blame contribute delete
No virus
4.74 kB
from argparse import ArgumentParser, Namespace
from typing import (
List,
Tuple,
)
import numpy as np
from PIL import Image
import torch
from torch import nn
import torch.nn.functional as F
from torchvision.transforms import (
Compose,
Grayscale,
Resize,
ToTensor,
)
from models.encoder import Encoder
from models.encoder4editing import (
get_latents as get_e4e_latents,
setup_model as setup_e4e_model,
)
from utils.misc import (
optional_string,
iterable_to_str,
stem,
)
class ColorEncoderArguments:
    """Command-line options for the feed-forward color encoder.

    Instantiating the class builds a standalone ``ArgumentParser`` that is
    exposed as ``self.parser``. ``add_arguments`` can also be called on its
    own to register the same options on a parser owned by the caller.
    """

    def __init__(self):
        # The text must be passed as ``description``: the first positional
        # parameter of ``ArgumentParser`` is ``prog``, so a positional string
        # would replace the program name in the usage line.
        parser = ArgumentParser(
            description="Encode an image via a feed-forward encoder",
        )
        self.add_arguments(parser)
        self.parser = parser

    @staticmethod
    def add_arguments(parser: ArgumentParser):
        """Register ``--encoder_ckpt`` and ``--encoder_size`` on *parser*."""
        parser.add_argument("--encoder_ckpt", default=None,
                            help="encoder checkpoint path. initialize w with encoder output if specified")
        parser.add_argument("--encoder_size", type=int, default=256,
                            help="Resize to this size to pass as input to the encoder")
class InitializerArguments:
    """Command-line options that configure the latent initializer."""

    @classmethod
    def add_arguments(cls, parser: ArgumentParser):
        """Register color-encoder, e4e, layer-mixing and init-latent options.

        Options are added in that order, which is also the order in which
        they appear in the generated help text.
        """
        ColorEncoderArguments.add_arguments(parser)
        cls.add_e4e_arguments(parser)
        parser.add_argument(
            "--mix_layer_range",
            nargs=2,
            type=int,
            default=[10, 18],
            help="replace layers <start> to <end> in the e4e code by the color code",
        )
        parser.add_argument(
            "--init_latent",
            default=None,
            help="path to init wp",
        )

    @staticmethod
    def to_string(args: Namespace):
        """Return a short tag describing how the latent is initialized.

        When ``args.init_latent`` is set (truthy), the tag is built from the
        stem of that path with leading ``'0'`` characters stripped and cut to
        at most 10 characters. Otherwise it is built from
        ``args.mix_layer_range``.
        """
        # NOTE: an extra init-noise suffix used to be appended here; it is
        # disabled in the current code.
        if args.init_latent:
            tag = stem(args.init_latent).lstrip('0')[:10]
            return f"init{tag}"
        layer_range = iterable_to_str(args.mix_layer_range)
        return f"init({layer_range})"

    @staticmethod
    def add_e4e_arguments(parser: ArgumentParser):
        """Register ``--e4e_ckpt`` and ``--e4e_size`` on *parser*."""
        parser.add_argument(
            "--e4e_ckpt",
            default='checkpoint/e4e_ffhq_encode.pt',
            help="e4e checkpoint path.",
        )
        parser.add_argument(
            "--e4e_size",
            type=int,
            default=256,
            help="Resize to this size to pass as input to the e4e",
        )
def create_color_encoder(args: Namespace):
    """Build the color encoder and load its weights from ``args.encoder_ckpt``.

    Args:
        args: parsed options; reads ``encoder_ckpt`` (checkpoint path) and
            ``encoder_size`` (passed through to ``Encoder``).

    Returns:
        The ``Encoder`` instance with the checkpoint's ``"model"`` state dict
        loaded. It is left on the device ``Encoder`` was constructed on and in
        whatever train/eval mode the constructor leaves it in.

    Raises:
        ValueError: if ``args.encoder_ckpt`` is ``None``.
    """
    # Fail early with a clear message; otherwise torch.load(None) raises an
    # unrelated-looking error deep inside the deserializer.
    if args.encoder_ckpt is None:
        raise ValueError("--encoder_ckpt is required to create the color encoder")

    # NOTE(review): the meaning of the literals 1 and 512 is defined by
    # models.encoder.Encoder, which is not visible here (presumably input
    # channels and latent width — confirm against that class).
    encoder = Encoder(1, args.encoder_size, 512)

    # map_location="cpu" lets checkpoints saved from a GPU be loaded on hosts
    # without CUDA; load_state_dict copies the values into the model's own
    # parameters, so the final placement of the model is unaffected.
    ckpt = torch.load(args.encoder_ckpt, map_location="cpu")
    encoder.load_state_dict(ckpt["model"])
    return encoder
def transform_input(img: Image.Image, encoder_size: int = 256) -> torch.Tensor:
    """Convert a PIL image into the tensor form fed to the color encoder.

    The image is converted to grayscale, resized with torchvision's
    ``Resize(encoder_size)`` and turned into a tensor with ``ToTensor``.

    Args:
        img: input PIL image.
        encoder_size: size handed to ``Resize``. Defaults to 256, the default
            of the ``--encoder_size`` option. The previous body read this
            value from an undefined name ``args`` and raised ``NameError``.

    Returns:
        The transformed image tensor.
    """
    tsfm = Compose([
        Grayscale(),
        Resize(encoder_size),
        ToTensor(),
    ])
    return tsfm(img)
def encode_color(imgs: torch.Tensor, args: Namespace) -> torch.Tensor:
    """Run a freshly created color encoder on *imgs* without tracking gradients.

    The images are resized with torchvision's ``Resize(args.encoder_size)``.
    A new encoder is then built from ``args`` via ``create_color_encoder``,
    moved to the images' device, switched to eval mode and applied.

    Returns:
        The encoder output, detached from any autograd graph.
    """
    assert args.encoder_size is not None

    resized = Resize(args.encoder_size)(imgs)

    # The encoder is rebuilt (and its checkpoint reloaded) on every call.
    model = create_color_encoder(args)
    model = model.to(resized.device)
    model.eval()

    with torch.no_grad():
        code = model(resized)
    return code.detach()
def resize(imgs: torch.Tensor, size: int) -> torch.Tensor:
    """Resample *imgs* to *size* with bilinear interpolation.

    Thin wrapper around ``torch.nn.functional.interpolate`` using
    ``mode='bilinear'`` and that function's remaining defaults.
    """
    resampled = F.interpolate(input=imgs, size=size, mode='bilinear')
    return resampled
class Initializer(nn.Module):
    """Produce an initial latent code for a batch of images.

    Two modes, selected at construction time:

    * ``args.init_latent`` is given: the latent is read from that ``.npy``
      path on every forward call and no encoder is created.
    * otherwise: an e4e model provides the base code and a color encoder
      provides a second code that overwrites the layer slice
      ``mix_layer_range`` of the e4e code.
    """

    def __init__(self, args: Namespace):
        super().__init__()

        # Path mode: remember the file and skip building any model.
        self.path = args.init_latent
        if self.path is not None:
            return

        assert args.encoder_size is not None

        color_encoder = create_color_encoder(args)
        color_encoder.eval()
        self.color_encoder = color_encoder
        self.color_encoder_size = args.encoder_size

        e4e, e4e_opts = setup_e4e_model(args.e4e_ckpt)
        self.e4e = e4e
        # e4e checkpoints whose dataset type contains 'cars_' are rejected.
        assert 'cars_' not in e4e_opts.dataset_type
        self.e4e.decoder.eval()
        self.e4e.eval()
        self.e4e_size = args.e4e_size

        self.mix_layer_range = args.mix_layer_range

    def encode_color(self, imgs: torch.Tensor) -> torch.Tensor:
        """
        Get the color W code: resize to the color encoder's input size and
        run the color encoder.
        """
        scaled = resize(imgs, self.color_encoder_size)
        return self.color_encoder(scaled)

    def encode_shape(self, imgs: torch.Tensor) -> torch.Tensor:
        """
        Get the e4e code: resize to the e4e input size, shift/scale the
        values, expand single-channel input to three channels, and run e4e.
        """
        scaled = resize(imgs, self.e4e_size)
        # Maps 0 -> -1 and 1 -> 1 (assumes inputs lie in [0, 1] — confirm
        # against callers).
        scaled = (scaled - 0.5) / 0.5
        is_single_channel = scaled.shape[1] == 1
        if is_single_channel:
            scaled = scaled.repeat(1, 3, 1, 1)
        return get_e4e_latents(self.e4e, scaled)

    def load(self, device: torch.device):
        """Read the stored latent from ``self.path`` and add a leading axis."""
        stored = np.load(self.path)
        latent = torch.tensor(stored, device=device)
        return latent[None, ...]

    def forward(self, imgs: torch.Tensor) -> torch.Tensor:
        """Return the initial latent for *imgs*.

        In path mode *imgs* is only used to pick the target device.
        """
        if self.path is not None:
            return self.load(imgs.device)

        shape_code = self.encode_shape(imgs)
        color_code = self.encode_color(imgs)

        # Style mix: write the color code into the selected layers of the
        # e4e code in place and return that same tensor.
        first, last = self.mix_layer_range
        shape_code[:, first:last] = color_code
        return shape_code