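"""
Edit inverted images with StyleCLIP global directions, optionally adjusting the
per-image weight deltas produced by inference.py.

Minimal usage sketch (the script path below is a placeholder; all flags and
their defaults come from parse_args in this file):

    python path/to/this_script.py \
        --exp_dir ./experiment \
        --weight_deltas_path ./weight_deltas \
        --neutral_text "face with hair" \
        --target_text "face with long hair" \
        --beta 0.14 --alpha 4.1

Edited images are saved under <exp_dir>/styleclip_edits/<neutral_text>_to_<target_text>/.
"""
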
import argparse
import os
import pickle
import torch
import numpy as np
import torchvision

import sys
sys.path.append(".")
sys.path.append("..")

from hyperstyle_global_directions.global_direction import StyleCLIPGlobalDirection
from models.stylegan2.model import Generator


def parse_args(args_list=None):
    parser = argparse.ArgumentParser()
    parser.add_argument("--exp_dir", type=str, default="./experiment",
                        help="Path to inference results with `latents.npy` saved here (obtained with inference.py).")
    parser.add_argument("--weight_deltas_path", type=str, default="./weight_deltas",
                        help="Root path holding all weight deltas (obtained by running inference.py).")
    parser.add_argument('--n_images', type=int, default=None,
                        help="Maximum number of images to edit. If None, edit all images.")
    parser.add_argument("--neutral_text", type=str, default="face with hair")
    parser.add_argument("--target_text", type=str, default="face with long hair")
    parser.add_argument("--stylegan_weights", type=str, default='../pretrained_models/stylegan2-ffhq-config-f.pt')
    parser.add_argument("--stylegan_size", type=int, default=1024)
    parser.add_argument("--stylegan_truncation", type=int, default=1.)
    parser.add_argument("--stylegan_truncation_mean", type=int, default=4096)
    parser.add_argument("--beta", type=float, default=0.14)
    parser.add_argument("--alpha", type=float, default=4.1)
    parser.add_argument("--weight_delta_beta", type=float, default=None)
    parser.add_argument("--weight_delta_alpha", type=float, default=None)
    parser.add_argument("--delta_i_c", type=str, default='../hyperstyle_global_directions/global_directions/ffhq/fs3.npy',
                        help="path to file containing delta_i_c")
    parser.add_argument("--s_statistics", type=str, default='../hyperstyle_global_directions/global_directions/ffhq/S_mean_std',
                        help="path to file containing s statistics")
    parser.add_argument("--text_prompt_templates", default='../hyperstyle_global_directions/global_directions/templates.txt')
    args = parser.parse_args(args_list)
    return args


def load_direction_calculator(args):
    delta_i_c = torch.from_numpy(np.load(args.delta_i_c)).float().to(args.device)
    with open(args.s_statistics, "rb") as channels_statistics:
        _, s_std = pickle.load(channels_statistics)
        s_std = [torch.from_numpy(s_i).float().to(args.device) for s_i in s_std]
    with open(args.text_prompt_templates, "r") as templates:
        text_prompt_templates = templates.readlines()
    global_direction_calculator = StyleCLIPGlobalDirection(delta_i_c, s_std, text_prompt_templates, args.device)
    return global_direction_calculator


def load_stylegan_generator(args):
    stylegan_model = Generator(args.stylegan_size, 512, 8, channel_multiplier=2).to(args.device)
    checkpoint = torch.load(args.stylegan_weights, map_location=args.device)
    stylegan_model.load_state_dict(checkpoint['g_ema'])
    return stylegan_model


def run():
    args = parse_args()
    # run on GPU when available (falls back to CPU)
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    stylegan_model = load_stylegan_generator(args)
    global_direction_calculator = load_direction_calculator(args)
    # load latents obtained via inference
    latents = np.load(os.path.join(args.exp_dir, 'latents.npy'), allow_pickle=True).item()
    # prepare output directory
    args.output_path = os.path.join(args.exp_dir, "styleclip_edits", f"{args.neutral_text}_to_{args.target_text}")
    os.makedirs(args.output_path, exist_ok=True)
    # edit all images
    for idx, (image_name, latent) in enumerate(latents.items()):
        if args.n_images is not None and idx >= args.n_images:
            break
        weight_deltas = np.load(os.path.join(args.weight_deltas_path, image_name.split(".")[0] + ".npy"), allow_pickle=True)
        weight_deltas = [torch.from_numpy(w).to(args.device) if w is not None else None for w in weight_deltas]
        latent = torch.from_numpy(latent)
        results = edit_image(image_name, latent, stylegan_model, global_direction_calculator, args, weight_deltas)
        torchvision.utils.save_image(results, f"{args.output_path}/{image_name.split('.')[0]}.jpg",
                                     normalize=True, range=(-1, 1), padding=0, nrow=results.shape[0])


def edit_image(image_name, latent, stylegan_model, global_direction_calculator, args, weight_deltas=None):
    latent_code = latent.to(args.device)
    truncation = 1
    mean_latent = None
    input_is_latent = True
    latent_code_i = latent_code.unsqueeze(0)

    with torch.no_grad():
        # one forward pass through the (possibly weight-modulated) generator to obtain
        # the StyleSpace codes of the inverted latent
        source_im, _, latent_code_s = stylegan_model([latent_code_i],
                                                     input_is_latent=input_is_latent,
                                                     randomize_noise=False,
                                                     return_latents=True,
                                                     truncation=truncation,
                                                     truncation_latent=mean_latent,
                                                     weights_deltas=weight_deltas)

    results = []
    edited_latent_code_s = edit_style_code(latent_code_s, global_direction_calculator, args)
    if args.edit_weight_delta and weight_deltas is not None:
        edited_weight_deltas = edit_weight_delta(weight_deltas, global_direction_calculator, args)
    else:
        edited_weight_deltas = weight_deltas
    for b in range(0, edited_latent_code_s[0].shape[0]):
        edited_latent_code_s_batch = [s_i[b:b + 1] for s_i in edited_latent_code_s]
        edited_weight_deltas_batch = [w_i[b:b+1] if w_i is not None else None for w_i in edited_weight_deltas] if weight_deltas is not None else None
        with torch.no_grad():
            edited_image, _, _ = stylegan_model([edited_latent_code_s_batch],
                                                input_is_stylespace=True,
                                                randomize_noise=False,
                                                return_latents=True,
                                                weights_deltas=edited_weight_deltas_batch)
            results.append(edited_image)

    results = torch.cat(results)
    return results


def edit_style_code(latent_code_s, global_direction_calculator, args):
    # shift each StyleSpace channel group along the global CLIP direction, scaled by alpha
    direction = global_direction_calculator.get_delta_s(args.neutral_text, args.target_text, args.beta)
    edited_latent_code_s = [s_i + args.alpha * b_i for s_i, b_i in zip(latent_code_s, direction)]
    return edited_latent_code_s


def edit_weight_delta(weight_delta, global_direction_calculator, args):
    # dampen each per-layer weight delta in proportion to the magnitude of the edit direction
    beta = args.beta if args.weight_delta_beta is None else args.weight_delta_beta
    # alpha = args.alpha if args.weight_delta_alpha is None else args.weight_delta_alpha
    direction = global_direction_calculator.get_delta_s(args.neutral_text, args.target_text, beta)
    edited_weight_delta = [w_i * (1 - torch.abs(b_i)) if w_i is not None else None
                           for w_i, b_i in zip(weight_delta, direction)]
    return edited_weight_delta

if __name__ == "__main__":
    run()