import argparse
import os
import pickle
import sys

import numpy as np
import torch
import torchvision

sys.path.append(".")
sys.path.append("..")

from hyperstyle_global_directions.global_direction import StyleCLIPGlobalDirection
from models.stylegan2.model import Generator


def parse_args(args_list=None):
    parser = argparse.ArgumentParser()
    parser.add_argument("--exp_dir", type=str, default="./experiment",
                        help="Path to inference results with `latents.npy` saved here (obtained with inference.py).")
    parser.add_argument("--weight_deltas_path", type=str, default="./weight_deltas",
                        help="Root path holding all weight deltas (obtained by running inference.py).")
    parser.add_argument("--n_images", type=int, default=None,
                        help="Maximum number of images to edit. If None, edit all images.")
    parser.add_argument("--neutral_text", type=str, default="face with hair")
    parser.add_argument("--target_text", type=str, default="face with long hair")
    parser.add_argument("--stylegan_weights", type=str, default="../pretrained_models/stylegan2-ffhq-config-f.pt")
    parser.add_argument("--stylegan_size", type=int, default=1024)
    parser.add_argument("--stylegan_truncation", type=float, default=1.0)
    parser.add_argument("--stylegan_truncation_mean", type=int, default=4096)
    parser.add_argument("--beta", type=float, default=0.14)
    parser.add_argument("--alpha", type=float, default=4.1)
    parser.add_argument("--weight_delta_beta", type=float, default=None)
    parser.add_argument("--weight_delta_alpha", type=float, default=None)
    parser.add_argument("--edit_weight_delta", action="store_true",
                        help="If set, also attenuate the weight deltas along the edited StyleSpace channels.")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device to run on.")
    parser.add_argument("--delta_i_c", type=str,
                        default="../hyperstyle_global_directions/global_directions/ffhq/fs3.npy",
                        help="Path to file containing delta_i_c.")
    parser.add_argument("--s_statistics", type=str,
                        default="../hyperstyle_global_directions/global_directions/ffhq/S_mean_std",
                        help="Path to file containing S statistics.")
    parser.add_argument("--text_prompt_templates",
                        default="../hyperstyle_global_directions/global_directions/templates.txt")
    args = parser.parse_args(args_list)
    return args


def load_direction_calculator(args):
    delta_i_c = torch.from_numpy(np.load(args.delta_i_c)).float().to(args.device)
    with open(args.s_statistics, "rb") as channels_statistics:
        _, s_std = pickle.load(channels_statistics)
    s_std = [torch.from_numpy(s_i).float().to(args.device) for s_i in s_std]
    with open(args.text_prompt_templates, "r") as templates:
        text_prompt_templates = templates.readlines()
    global_direction_calculator = StyleCLIPGlobalDirection(delta_i_c, s_std, text_prompt_templates, args.device)
    return global_direction_calculator


def load_stylegan_generator(args):
    stylegan_model = Generator(args.stylegan_size, 512, 8, channel_multiplier=2).to(args.device)
    checkpoint = torch.load(args.stylegan_weights)
    stylegan_model.load_state_dict(checkpoint["g_ema"])
    return stylegan_model


def run():
    args = parse_args()
    stylegan_model = load_stylegan_generator(args)
    global_direction_calculator = load_direction_calculator(args)

    # Load the latents obtained via inference.
    latents = np.load(os.path.join(args.exp_dir, "latents.npy"), allow_pickle=True).item()

    # Prepare the output directory.
    args.output_path = os.path.join(args.exp_dir, "styleclip_edits",
                                    f"{args.neutral_text}_to_{args.target_text}")
    os.makedirs(args.output_path, exist_ok=True)

    # Edit all images.
    for idx, (image_name, latent) in enumerate(latents.items()):
        if args.n_images is not None and idx >= args.n_images:
            break
        weight_deltas = np.load(os.path.join(args.weight_deltas_path, image_name.split(".")[0] + ".npy"),
                                allow_pickle=True)
        weight_deltas = [torch.from_numpy(w).to(args.device) if w is not None else None for w in weight_deltas]
        latent = torch.from_numpy(latent)
        results = edit_image(image_name, latent, stylegan_model, global_direction_calculator, args, weight_deltas)
        # Note: newer torchvision versions rename `range` to `value_range`.
        torchvision.utils.save_image(results, f"{args.output_path}/{image_name.split('.')[0]}.jpg",
                                     normalize=True, range=(-1, 1), padding=0, nrow=results.shape[0])


def edit_image(image_name, latent, stylegan_model, global_direction_calculator, args, weight_deltas=None):
    print(f"Editing {image_name}")
    latent_code = latent.to(args.device)
    truncation = 1  # truncation is effectively disabled when editing inverted latents
    mean_latent = None
    input_is_latent = True
    latent_code_i = latent_code.unsqueeze(0)

    # Run the generator once to map the W+ latent to its StyleSpace (S) code.
    with torch.no_grad():
        source_im, _, latent_code_s = stylegan_model([latent_code_i],
                                                     input_is_latent=input_is_latent,
                                                     randomize_noise=False,
                                                     return_latents=True,
                                                     truncation=truncation,
                                                     truncation_latent=mean_latent,
                                                     weights_deltas=weight_deltas)

    results = []
    edited_latent_code_s = edit_style_code(latent_code_s, global_direction_calculator, args)
    if args.edit_weight_delta and weight_deltas is not None:
        edited_weight_deltas = edit_weight_delta(weight_deltas, global_direction_calculator, args)
    else:
        edited_weight_deltas = weight_deltas

    # Synthesize one edited image per entry in the edited style code batch.
    for b in range(edited_latent_code_s[0].shape[0]):
        edited_latent_code_s_batch = [s_i[b:b + 1] for s_i in edited_latent_code_s]
        edited_weight_deltas_batch = [w_i[b:b + 1] if w_i is not None else None
                                      for w_i in edited_weight_deltas] if weight_deltas is not None else None
        with torch.no_grad():
            edited_image, _, _ = stylegan_model([edited_latent_code_s_batch],
                                                input_is_stylespace=True,
                                                randomize_noise=False,
                                                return_latents=True,
                                                weights_deltas=edited_weight_deltas_batch)
        results.append(edited_image)

    results = torch.cat(results)
    return results


def edit_style_code(latent_code_s, global_direction_calculator, args):
    direction = global_direction_calculator.get_delta_s(args.neutral_text, args.target_text, args.beta)
    # Shift each StyleSpace channel along the global direction, scaled by alpha.
    edited_latent_code_s = [s_i + args.alpha * b_i for s_i, b_i in zip(latent_code_s, direction)]
    return edited_latent_code_s


def edit_weight_delta(weight_delta, global_direction_calculator, args):
    beta = args.beta if args.weight_delta_beta is None else args.weight_delta_beta
    # alpha = args.alpha if args.weight_delta_alpha is None else args.weight_delta_alpha  # currently unused
    direction = global_direction_calculator.get_delta_s(args.neutral_text, args.target_text, beta)
    # Attenuate each weight delta in proportion to how strongly its channel is edited.
    edited_weight_delta = [w_i * (1 - torch.abs(b_i)) if w_i is not None else None
                           for w_i, b_i in zip(weight_delta, direction)]
    return edited_weight_delta


if __name__ == "__main__":
    run()
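
# Example invocation (a sketch, not part of the original script; the script name,
# paths, and prompts below are placeholders and must point at your own inference
# outputs and pretrained weights):
#
#   python edit.py \
#       --exp_dir ./experiment \
#       --weight_deltas_path ./weight_deltas \
#       --neutral_text "face with hair" \
#       --target_text "face with long hair" \
#       --beta 0.14 --alpha 4.1 \
#       --edit_weight_delta
#
# This writes edited images to ./experiment/styleclip_edits/<neutral>_to_<target>/,
# one .jpg per entry in latents.npy.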