Spaces:
Runtime error
Runtime error
import argparse | |
import os | |
import pickle | |
import torch | |
import numpy as np | |
import torchvision | |
import sys | |
sys.path.append(".") | |
sys.path.append("..") | |
from hyperstyle_global_directions.global_direction import StyleCLIPGlobalDirection | |
from models.stylegan2.model import Generator | |
def parse_args(args_list=None):
    """Parse options for StyleCLIP global-direction editing of inverted images.

    Args:
        args_list: Optional list of argument strings. When ``None`` (the
            default) argparse falls back to ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with every option below. In addition to the editing
        options it defines ``device`` and ``edit_weight_delta``, which the
        loader and editing helpers in this module read from ``args``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--exp_dir", type=str, default="./experiment",
                        help="Path to inference results with `latents.npy` saved here (obtained with inference.py).")
    parser.add_argument("--weight_deltas_path", type=str, default="./weight_deltas",
                        help="Root path holding all weight deltas (obtained by running inference.py).")
    parser.add_argument('--n_images', type=int, default=None,
                        help="Maximum number of images to edit. If None, edit all images.")
    parser.add_argument("--neutral_text", type=str, default="face with hair")
    parser.add_argument("--target_text", type=str, default="face with long hair")
    parser.add_argument("--stylegan_weights", type=str, default='../pretrained_models/stylegan2-ffhq-config-f.pt')
    parser.add_argument("--stylegan_size", type=int, default=1024)
    # The default is the float 1.; with type=int, fractional values given on
    # the command line (e.g. 0.7) were rejected by argparse.
    parser.add_argument("--stylegan_truncation", type=float, default=1.)
    parser.add_argument("--stylegan_truncation_mean", type=int, default=4096)
    parser.add_argument("--beta", type=float, default=0.14)
    parser.add_argument("--alpha", type=float, default=4.1)
    # When None, the weight-delta edit falls back to --beta / --alpha.
    parser.add_argument("--weight_delta_beta", type=float, default=None)
    parser.add_argument("--weight_delta_alpha", type=float, default=None)
    parser.add_argument("--delta_i_c", type=str, default='../hyperstyle_global_directions/global_directions/ffhq/fs3.npy',
                        help="path to file containing delta_i_c")
    parser.add_argument("--s_statistics", type=str, default='../hyperstyle_global_directions/global_directions/ffhq/S_mean_std',
                        help="path to file containing s statistics")
    parser.add_argument("--text_prompt_templates", default='../hyperstyle_global_directions/global_directions/templates.txt')
    # The helpers read args.device and args.edit_weight_delta; without these
    # options the CLI path fails with AttributeError.
    parser.add_argument("--device", type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Torch device for the generator, directions and weight deltas.")
    parser.add_argument("--edit_weight_delta", action="store_true",
                        help="Also scale the per-image weight deltas by (1 - |direction|) before generating.")
    args = parser.parse_args(args_list)
    return args
def load_direction_calculator(args):
    """Create the StyleCLIP global-direction helper from precomputed files.

    Reads three inputs named on ``args``:
      * ``delta_i_c``: a ``.npy`` array, converted to a float tensor;
      * ``s_statistics``: a pickle that must unpack into exactly two items.
        The first is discarded. The second is a sequence of numpy arrays
        (presumably the style-channel std, per the default ``S_mean_std``
        name), each converted to a float tensor;
      * ``text_prompt_templates``: a text file read line by line (line
        endings are kept).
    All tensors are moved to ``args.device``.

    Returns:
        A ``StyleCLIPGlobalDirection`` built from those tensors and templates.
    """
    raw_delta_i_c = np.load(args.delta_i_c)
    delta_i_c = torch.from_numpy(raw_delta_i_c).float().to(args.device)

    with open(args.s_statistics, "rb") as stats_file:
        _discarded_first, std_arrays = pickle.load(stats_file)
        s_std = []
        for std_array in std_arrays:
            s_std.append(torch.from_numpy(std_array).float().to(args.device))

    with open(args.text_prompt_templates, "r") as template_file:
        prompt_templates = list(template_file)

    return StyleCLIPGlobalDirection(delta_i_c, s_std, prompt_templates, args.device)
def load_stylegan_generator(args):
    """Build a StyleGAN2 generator and load its ``g_ema`` weights.

    Args:
        args: Namespace providing ``stylegan_size``, ``stylegan_weights`` and
            ``device``.

    Returns:
        The ``Generator`` on ``args.device`` with the checkpoint's ``'g_ema'``
        state dict loaded.
    """
    stylegan_model = Generator(args.stylegan_size, 512, 8, channel_multiplier=2).to(args.device)
    # Restore the checkpoint on the CPU: a checkpoint saved from a GPU would
    # otherwise be restored onto CUDA and fail on CPU-only hosts, and the whole
    # file is not placed on the GPU. load_state_dict copies the tensors into
    # the model's parameters on args.device.
    checkpoint = torch.load(args.stylegan_weights, map_location="cpu")
    stylegan_model.load_state_dict(checkpoint['g_ema'])
    return stylegan_model
def run():
    """Command-line entry point: apply a StyleCLIP text edit to every inverted image.

    Loads ``latents.npy`` from ``--exp_dir`` and, for each image, its weight
    deltas from ``--weight_deltas_path``. It then edits the image from
    ``--neutral_text`` towards ``--target_text`` and writes one JPEG per image
    into ``<exp_dir>/styleclip_edits/<neutral_text>_to_<target_text>/``.
    """
    args = parse_args()
    # The helpers below read these attributes. Fill in defaults if the parser
    # did not define them, so the CLI path does not die with AttributeError.
    if getattr(args, "device", None) is None:
        args.device = "cuda" if torch.cuda.is_available() else "cpu"
    if not hasattr(args, "edit_weight_delta"):
        args.edit_weight_delta = False

    stylegan_model = load_stylegan_generator(args)
    global_direction_calculator = load_direction_calculator(args)

    # latents.npy holds a pickled dict: image name -> latent code array.
    latents = np.load(os.path.join(args.exp_dir, 'latents.npy'), allow_pickle=True).item()

    args.output_path = os.path.join(args.exp_dir, "styleclip_edits", f"{args.neutral_text}_to_{args.target_text}")
    os.makedirs(args.output_path, exist_ok=True)

    for idx, (image_name, latent) in enumerate(latents.items()):
        if args.n_images is not None and idx >= args.n_images:
            break
        # Weight-delta files and outputs are keyed by the image name up to its first '.'.
        stem = image_name.split(".")[0]
        weight_deltas = np.load(os.path.join(args.weight_deltas_path, stem + ".npy"), allow_pickle=True)
        weight_deltas = [torch.from_numpy(w).to(args.device) if w is not None else None for w in weight_deltas]
        latent = torch.from_numpy(latent)

        # edit_image returns a single tensor holding all edited images.
        results = edit_image(image_name, latent, stylegan_model, global_direction_calculator, args, weight_deltas)

        # Map generator output from [-1, 1] to [0, 1] here instead of using
        # save_image(normalize=True, range=...). The `range` keyword was renamed
        # `value_range` in newer torchvision, so neither spelling works everywhere.
        images = results.clamp(-1, 1).add(1).div(2)
        # Lay out every edit of this input in a single row.
        torchvision.utils.save_image(images, os.path.join(args.output_path, f"{stem}.jpg"),
                                     padding=0, nrow=len(results))
def edit_image(image_name, latent, stylegan_model, global_direction_calculator, args, weight_deltas=None):
    """Apply the configured StyleCLIP global-direction edit to one latent code.

    Args:
        image_name: Name of the source image (not used here; kept for callers).
        latent: Latent code tensor for one image. A batch dimension is added
            here and it is fed to the generator with ``input_is_latent=True``.
        stylegan_model: Generator accepting ``weights_deltas`` and either
            latent (``input_is_latent``) or style-space
            (``input_is_stylespace``) inputs.
        global_direction_calculator: Passed to ``edit_style_code`` and, if
            enabled, ``edit_weight_delta`` to obtain the edit direction.
        args: Namespace providing ``device`` plus the fields read by those
            helpers. ``edit_weight_delta`` is optional and treated as False
            when absent.
        weight_deltas: Optional per-layer weight deltas (entries may be
            ``None``), forwarded to the generator.

    Returns:
        Tensor of edited images: one generator call per entry along dim 0 of
        the edited style codes, concatenated along dim 0.
    """
    latent_code_i = latent.to(args.device).unsqueeze(0)

    # Inference only: no gradients are needed anywhere below.
    with torch.no_grad():
        # This first pass is only used to obtain the style-space codes; its image is discarded.
        _, _, latent_code_s = stylegan_model([latent_code_i],
                                             input_is_latent=True,
                                             randomize_noise=False,
                                             return_latents=True,
                                             truncation=1,
                                             truncation_latent=None,
                                             weights_deltas=weight_deltas)

        edited_latent_code_s = edit_style_code(latent_code_s, global_direction_calculator, args)
        # parse_args may not define edit_weight_delta; default to leaving the deltas unchanged.
        if getattr(args, "edit_weight_delta", False) and weight_deltas is not None:
            edited_weight_deltas = edit_weight_delta(weight_deltas, global_direction_calculator, args)
        else:
            edited_weight_deltas = weight_deltas

        results = []
        for b in range(edited_latent_code_s[0].shape[0]):
            edited_latent_code_s_batch = [s_i[b:b + 1] for s_i in edited_latent_code_s]
            if edited_weight_deltas is None:
                edited_weight_deltas_batch = None
            else:
                edited_weight_deltas_batch = [w_i[b:b + 1] if w_i is not None else None
                                              for w_i in edited_weight_deltas]
            edited_image, _, _ = stylegan_model([edited_latent_code_s_batch],
                                                input_is_stylespace=True,
                                                randomize_noise=False,
                                                return_latents=True,
                                                weights_deltas=edited_weight_deltas_batch)
            results.append(edited_image)
    return torch.cat(results)
def edit_style_code(latent_code_s, global_direction_calculator, args):
    """Move each style-space code along the text-driven global direction.

    Args:
        latent_code_s: Sequence of per-layer style-code tensors.
        global_direction_calculator: Object whose
            ``get_delta_s(neutral_text, target_text, beta)`` returns one
            direction tensor per layer.
        args: Namespace providing ``neutral_text``, ``target_text``, ``beta``
            and ``alpha`` (the step size along the direction).

    Returns:
        List of new tensors ``s_i + args.alpha * b_i``. Codes and direction are
        paired layer by layer, and ``zip`` stops at the shorter sequence. The
        inputs are not modified.
    """
    direction = global_direction_calculator.get_delta_s(args.neutral_text, args.target_text, args.beta)
    # The arithmetic already yields a new tensor; the former torch.cat([...])
    # around a single element only made an extra copy.
    return [s_i + args.alpha * b_i for s_i, b_i in zip(latent_code_s, direction)]
def edit_weight_delta(weight_delta, global_direction_calculator, args):
    """Scale per-layer weight deltas by ``1 - |direction|`` for the current edit.

    Args:
        weight_delta: Sequence of per-layer weight-delta tensors; ``None``
            entries are passed through unchanged.
        global_direction_calculator: Object whose
            ``get_delta_s(neutral_text, target_text, beta)`` returns one
            direction tensor per layer.
        args: Namespace providing ``neutral_text``, ``target_text``, ``beta``
            and ``weight_delta_beta``. When ``weight_delta_beta`` is not
            ``None`` it overrides ``beta``. ``args.weight_delta_alpha`` is
            currently not used here.

    Returns:
        List with ``w_i * (1 - |b_i|)`` (or ``None``) for each paired layer.
        ``zip`` stops at the shorter of the two sequences.
    """
    beta = args.beta if args.weight_delta_beta is None else args.weight_delta_beta
    direction = global_direction_calculator.get_delta_s(args.neutral_text, args.target_text, beta)
    return [w_i * (1 - torch.abs(b_i)) if w_i is not None else None
            for w_i, b_i in zip(weight_delta, direction)]
# Script entry point: run the CLI only when executed directly, never on import.
if __name__ == "__main__":
    run()