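# Scraped Hugging Face file-page header (page residue, not part of the script):
# author: ethanNeuralImage | commit 6fa3e0e: "fix GPU usage to be optional"
# page controls: raw, history, blame | scan: No virus | size: 7.18 kB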
import argparse
import os
import pickle
import torch
import numpy as np
import torchvision
import sys
sys.path.append(".")
sys.path.append("..")
from hyperstyle_global_directions.global_direction import StyleCLIPGlobalDirection
from models.stylegan2.model import Generator
def parse_args(args_list=None):
    """Parse the options for the StyleCLIP global-direction editing script.

    Args:
        args_list: Optional list of argument strings. When ``None``, argparse
            parses ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with every option. It includes ``device``,
        ``edit_weight_delta`` and ``num_alphas``, which the loaders,
        ``run`` and ``edit_image`` read.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--exp_dir", type=str, default="./experiment",
                        help="Path to inference results with `latents.npy` saved here (obtained with inference.py).")
    parser.add_argument("--weight_deltas_path", type=str, default="./weight_deltas",
                        help="Root path holding all weight deltas (obtained by running inference.py).")
    parser.add_argument('--n_images', type=int, default=None,
                        help="Maximum number of images to edit. If None, edit all images.")
    parser.add_argument("--neutral_text", type=str, default="face with hair")
    parser.add_argument("--target_text", type=str, default="face with long hair")
    parser.add_argument("--stylegan_weights", type=str, default='../pretrained_models/stylegan2-ffhq-config-f.pt')
    parser.add_argument("--stylegan_size", type=int, default=1024)
    # Truncation is fractional (the default is 1.0); with type=int, values like 0.7 were rejected.
    parser.add_argument("--stylegan_truncation", type=float, default=1.)
    parser.add_argument("--stylegan_truncation_mean", type=int, default=4096)
    parser.add_argument("--beta", type=float, default=0.14)
    parser.add_argument("--alpha", type=float, default=4.1)
    parser.add_argument("--weight_delta_beta", type=float, default=None)
    parser.add_argument("--weight_delta_alpha", type=float, default=None)
    parser.add_argument("--delta_i_c", type=str, default='../hyperstyle_global_directions/global_directions/ffhq/fs3.npy',
                        help="path to file containing delta_i_c")
    parser.add_argument("--s_statistics", type=str, default='../hyperstyle_global_directions/global_directions/ffhq/S_mean_std',
                        help="path to file containing s statistics")
    parser.add_argument("--text_prompt_templates", default='../hyperstyle_global_directions/global_directions/templates.txt')
    # The options below are read elsewhere in the script but were never
    # defined, so running it from the command line failed with AttributeError.
    # GPU use is optional: fall back to the CPU when CUDA is unavailable.
    parser.add_argument("--device", type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Torch device to run on (defaults to cuda when available, else cpu).")
    parser.add_argument("--edit_weight_delta", action="store_true",
                        help="Also rescale the per-image weight deltas using the text-driven direction.")
    parser.add_argument("--num_alphas", type=int, default=1,
                        help="Number of images per row in each saved output grid.")
    args = parser.parse_args(args_list)
    return args
def load_direction_calculator(args):
    """Construct a StyleCLIPGlobalDirection from the files named in ``args``.

    Reads three inputs, in order:
      * ``args.delta_i_c``: a ``.npy`` array, converted to a float tensor;
      * ``args.s_statistics``: a pickled two-item sequence whose second item
        is a sequence of numpy arrays, each converted to a float tensor;
      * ``args.text_prompt_templates``: a text file read line by line, with
        trailing newlines kept.
    Every tensor is moved to ``args.device``, which is also passed to the
    calculator.
    """
    device = args.device
    delta_i_c = torch.from_numpy(np.load(args.delta_i_c)).float().to(device)

    # NOTE(review): pickle.load can execute arbitrary code; only point
    # --s_statistics at trusted files.
    with open(args.s_statistics, "rb") as stats_file:
        _s_mean, raw_std = pickle.load(stats_file)
    s_std = []
    for std_array in raw_std:
        s_std.append(torch.from_numpy(std_array).float().to(device))

    with open(args.text_prompt_templates, "r") as template_file:
        prompt_templates = template_file.readlines()

    return StyleCLIPGlobalDirection(delta_i_c, s_std, prompt_templates, device)
def load_stylegan_generator(args):
    """Build the StyleGAN2 generator and load its ``g_ema`` weights.

    Args:
        args: Namespace providing ``stylegan_size``, ``stylegan_weights`` and
            ``device``.

    Returns:
        The ``Generator`` on ``args.device`` with the checkpoint's ``g_ema``
        state dict loaded.
    """
    stylegan_model = Generator(args.stylegan_size, 512, 8, channel_multiplier=2).to(args.device)
    # Deserialise onto the CPU. Without map_location, a checkpoint saved
    # from CUDA tensors cannot be loaded on a CPU-only machine.
    # load_state_dict then copies the values into the model, which is already
    # on args.device.
    checkpoint = torch.load(args.stylegan_weights, map_location="cpu")
    stylegan_model.load_state_dict(checkpoint['g_ema'])
    return stylegan_model
def run():
    """Command-line entry point: apply a StyleCLIP edit to every inverted image.

    Loads ``latents.npy`` from ``--exp_dir`` (a pickled dict mapping image
    name to latent) and, for each image, the matching ``<stem>.npy`` weight
    deltas from ``--weight_deltas_path``. It edits each image with
    ``edit_image`` and saves a JPEG to
    ``<exp_dir>/styleclip_edits/<neutral_text>_to_<target_text>/<stem>.jpg``.
    At most ``--n_images`` images are processed when that option is set.
    """
    args = parse_args()
    # These attributes are read below and by the helpers. If the parser does
    # not define them, fill in defaults so the script still runs from the CLI.
    runtime_defaults = (
        ("device", "cuda" if torch.cuda.is_available() else "cpu"),
        ("edit_weight_delta", False),
        ("num_alphas", 1),
    )
    for attr_name, default in runtime_defaults:
        if not hasattr(args, attr_name):
            setattr(args, attr_name, default)

    stylegan_model = load_stylegan_generator(args)
    global_direction_calculator = load_direction_calculator(args)

    # Load the latents produced by inference: a pickled {image_name: latent} dict.
    # NOTE(review): allow_pickle=True deserialises arbitrary objects; only
    # use trusted experiment directories.
    latents = np.load(os.path.join(args.exp_dir, 'latents.npy'), allow_pickle=True).item()

    # Prepare the output directory.
    args.output_path = os.path.join(args.exp_dir, "styleclip_edits", f"{args.neutral_text}_to_{args.target_text}")
    os.makedirs(args.output_path, exist_ok=True)

    for idx, (image_name, latent) in enumerate(latents.items()):
        if args.n_images is not None and idx >= args.n_images:
            break
        # The stem is the text before the first '.'. It names both the
        # weight-delta file read here and the output JPEG.
        image_stem = image_name.split(".")[0]
        weight_deltas = np.load(os.path.join(args.weight_deltas_path, image_stem + ".npy"), allow_pickle=True)
        weight_deltas = [torch.from_numpy(w).to(args.device) if w is not None else None for w in weight_deltas]
        latent = torch.from_numpy(latent)
        # edit_image returns a single tensor of edited images. The old
        # three-way unpacking iterated over its batch dimension and raised.
        results = edit_image(image_name, latent, stylegan_model, global_direction_calculator, args, weight_deltas)
        # NOTE(review): newer torchvision renamed `range` to `value_range`.
        torchvision.utils.save_image(results, f"{args.output_path}/{image_stem}.jpg",
                                     normalize=True, range=(-1, 1), padding=0, nrow=args.num_alphas)
def edit_image(image_name, latent, stylegan_model, global_direction_calculator, args, weight_deltas=None):
    """Invert ``latent`` into style space, apply the text edit and re-render.

    Args:
        image_name: Name of the source image. It is not used here and is kept
            for API compatibility.
        latent: Latent tensor for one image. A leading batch dim is added, and
            the result goes to the generator with ``input_is_latent=True``.
        stylegan_model: Generator that accepts ``weights_deltas``,
            ``input_is_latent`` and ``input_is_stylespace`` keyword arguments.
        global_direction_calculator: Object passed to ``edit_style_code`` /
            ``edit_weight_delta``. Those helpers call its ``get_delta_s``.
        args: Namespace with ``device`` and the edit options (``neutral_text``,
            ``target_text``, ``alpha``, ``beta``, ...). ``edit_weight_delta``
            is optional and is treated as False when absent.
        weight_deltas: Optional list of per-layer weight-delta tensors.
            Entries may be ``None``.

    Returns:
        The generator outputs concatenated along dim 0. The generator runs
        once per index along dim 0 of the edited style codes.
    """
    latent_code_i = latent.to(args.device).unsqueeze(0)
    with torch.no_grad():
        # This forward pass only yields the style-space codes
        # (return_latents=True). The reconstructed image is discarded.
        _, _, latent_code_s = stylegan_model([latent_code_i],
                                             input_is_latent=True,
                                             randomize_noise=False,
                                             return_latents=True,
                                             truncation=1,
                                             truncation_latent=None,
                                             weights_deltas=weight_deltas)

    edited_latent_code_s = edit_style_code(latent_code_s, global_direction_calculator, args)
    # Not every argument parser defines `edit_weight_delta`. Treat a missing
    # attribute as "leave the weight deltas unchanged" instead of crashing.
    if getattr(args, "edit_weight_delta", False) and weight_deltas is not None:
        edited_weight_deltas = edit_weight_delta(weight_deltas, global_direction_calculator, args)
    else:
        edited_weight_deltas = weight_deltas

    results = []
    for b in range(edited_latent_code_s[0].shape[0]):
        # Slice row b of every style code (and weight delta) so each image is
        # rendered on its own.
        edited_latent_code_s_batch = [s_i[b:b + 1] for s_i in edited_latent_code_s]
        if edited_weight_deltas is None:
            edited_weight_deltas_batch = None
        else:
            edited_weight_deltas_batch = [w_i[b:b + 1] if w_i is not None else None
                                          for w_i in edited_weight_deltas]
        with torch.no_grad():
            edited_image, _, _ = stylegan_model([edited_latent_code_s_batch],
                                                input_is_stylespace=True,
                                                randomize_noise=False,
                                                return_latents=True,
                                                weights_deltas=edited_weight_deltas_batch)
        results.append(edited_image)
    return torch.cat(results)
def edit_style_code(latent_code_s, global_direction_calculator, args):
    """Shift each style-space code along the text-driven global direction.

    The direction comes from ``get_delta_s(args.neutral_text,
    args.target_text, args.beta)``. Entry ``d_i`` is scaled by ``args.alpha``
    and added to the matching code ``s_i``.

    Returns:
        A new list with ``s_i + alpha * d_i`` for each pair. ``zip`` stops at
        the shorter of the two sequences.
    """
    direction = global_direction_calculator.get_delta_s(args.neutral_text, args.target_text, args.beta)
    # The addition already produces a new tensor. The previous single-element
    # torch.cat only added a copy (and rejected 0-dim tensors).
    return [s_i + args.alpha * d_i for s_i, d_i in zip(latent_code_s, direction)]
def edit_weight_delta(weight_delta, global_direction_calculator, args):
    """Scale each weight delta by ``1 - |d_i|`` from the global direction.

    The direction comes from ``get_delta_s(args.neutral_text,
    args.target_text, beta)``. ``beta`` is ``args.weight_delta_beta`` when it
    is set, otherwise ``args.beta``. Each non-``None`` delta ``w_i`` is
    multiplied element-wise (with broadcasting) by ``1 - |d_i|``. ``None``
    entries are passed through. ``args.weight_delta_alpha`` is currently not
    used.

    Returns:
        A new list of deltas. ``zip`` stops at the shorter of the two
        sequences.
    """
    beta = args.beta if args.weight_delta_beta is None else args.weight_delta_beta
    direction = global_direction_calculator.get_delta_s(args.neutral_text, args.target_text, beta)
    return [None if w_i is None else w_i * (1 - torch.abs(d_i))
            for w_i, d_i in zip(weight_delta, direction)]
if __name__ == "__main__":
    # Run the editing pipeline only when executed as a script, not on import.
    run()