# Copyright (c) SenseTime Research. All rights reserved.

import os
import sys
import torch
import numpy as np
sys.path.append(".")
from torch_utils.models import Generator
import click
import cv2
import torchvision.transforms as transforms
from typing import List, Optional
import subprocess
import legacy
from edit.edit_helper import conv_warper, decoder, encoder_ifg, encoder_ss, encoder_sefa


"""
Edit generated images with different SOTA methods. 
 Notes:
 1. We provide some latent directions in the folder, you can play around with them.
 2. ''upper_length'' and ''bottom_length'' of ''attr_name'' are available for demo.
 3. Layers to control and editing strength are set in edit/edit_config.py.
 
Examples:

\b
# Editing with InterfaceGAN, StyleSpace, and Sefa
python edit.py --network pretrained_models/stylegan_human_v2_1024.pkl --attr_name upper_length \\
    --seeds 61531,61570,61571,61610 --outdir outputs/edit_results


# Editing using inverted latent code
python edit.py ---network outputs/pti/checkpoints/model_test.pkl --attr_name upper_length  \\
    --outdir outputs/edit_results --real True --real_w_path outputs/pti/embeddings/test/PTI/test/0.pt --real_img_path aligned_image/test.png

"""



@click.command()
@click.pass_context
@click.option('--network', 'ckpt_path', help='Network pickle filename', required=True)
@click.option('--attr_name', help="Attribute to edit: 'upper_length' or 'bottom_length'", type=str, required=True)
@click.option('--trunc', 'truncation', type=float, help='Truncation psi', default=0.8, show_default=True)
@click.option('--gen_video', type=bool, default=True, help='Whether to generate a video of the editing process')
@click.option('--combine', type=bool, default=True, help='Whether to combine the different editing results into one frame')
@click.option('--seeds', type=legacy.num_range, help='List of random seeds')
@click.option('--outdir', help='Where to save the output images', type=str, required=True, default='outputs/editing', metavar='DIR')
@click.option('--real', type=bool, help='Set True to edit a real (inverted) image', default=False)
@click.option('--real_w_path', help='Path to the latent code of the real image')
@click.option('--real_img_path', help='Path to the real image; it is concatenated with the inverted and edited results')
def main(
    ctx: click.Context,
    ckpt_path: str,
    attr_name: str,
    truncation: float,
    gen_video: bool,
    combine: bool,
    seeds: Optional[List[int]],
    outdir: str,
    real: bool,
    real_w_path: str,
    real_img_path: str
):
    ## Convert the .pkl checkpoint to a .pth state dict before loading.
    # if not os.path.exists(ckpt_path.replace('.pkl', '.pth')):
    legacy.convert(ckpt_path, ckpt_path.replace('.pkl', '.pth'), G_only=real)
    ckpt_path = ckpt_path.replace('.pkl', '.pth')
    print("start...", flush=True)
    config = {"latent" : 512, "n_mlp" : 8, "channel_multiplier": 2}
    generator = Generator(
        size=1024,
        style_dim=config["latent"],
        n_mlp=config["n_mlp"],
        channel_multiplier=config["channel_multiplier"],
    )
    
    generator.load_state_dict(torch.load(ckpt_path)['g_ema'])
    generator.eval().cuda()

    with torch.no_grad():
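        # The mean W latent (average mapping of mean_n random z samples) anchors the
        # truncation trick; cache it so later runs skip re-estimating it.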
        mean_path = os.path.join('edit','mean_latent.pkl')
        if not os.path.exists(mean_path):
            mean_n = 3000
            mean_latent = generator.mean_latent(mean_n).detach()
            legacy.save_obj(mean_latent, mean_path)
        else:
            mean_latent = legacy.load_pkl(mean_path).cuda() 
        finals = []

        ## -- selected sample seeds -- ##
        # seeds = [60948,60965,61174,61210,61511,61598,61610] # bottom -> long
        #         [60941,61064,61103,61313,61531,61570,61571] # bottom -> short
        #         [60941,60965,61064,61103,61174,61210,61531,61570,61571,61610] # upper -> long
        #         [60948,61313,61511,61598] # upper -> short
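
        # A real image has no seed; use a single placeholder so the loop below runs once.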
        if real:
            seeds = [0]

        for t in seeds:
            if real:  # currently assumes a single real image is processed
                if real_img_path:
                    real_image = cv2.imread(real_img_path)
                    real_image = cv2.cvtColor(real_image, cv2.COLOR_BGR2RGB)
                    transform = transforms.Compose(  # normalize to (-1, 1)
                        [transforms.ToTensor(),
                         transforms.Normalize(mean=(.5, .5, .5), std=(.5, .5, .5))]
                    )
                    real_image = transform(real_image).unsqueeze(0).cuda()

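                # Load the inverted latent code (the docstring example uses a PTI embedding).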
                test_input = torch.load(real_w_path)
                output, _ = generator(test_input, False, truncation=1, input_is_latent=True, real=True)

            else: # generate image from random seeds
                test_input = torch.from_numpy(np.random.RandomState(t).randn(1, 512)).float().cuda()  # torch.Size([1, 512])
                output, _ = generator([test_input], False, truncation=truncation, truncation_latent=mean_latent, real=real)
            
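            # Three editing methods are applied to the same latent code:
            # - InterFaceGAN: shift the W latent along a linear attribute boundary.
            # - StyleSpace:   edit selected style channels (S space) directly.
            # - SeFa:         move along closed-form eigen-directions of the modulation weights.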
            # InterFaceGAN
            style_space, latent, noise = encoder_ifg(generator, test_input, attr_name, truncation, mean_latent, real=real)
            image1 = decoder(generator, style_space, latent, noise)
            # StyleSpace
            style_space, latent, noise = encoder_ss(generator, test_input, attr_name, truncation, mean_latent, real=real)
            image2 = decoder(generator, style_space, latent, noise)
            # SeFa
            latent, noise = encoder_sefa(generator, test_input, attr_name, truncation, mean_latent, real=real)
            image3, _ = generator([latent], noise=noise, input_is_latent=True)
            if real_img_path:
                final = torch.cat((real_image, output, image1, image2, image3), 3)
            else:
                final = torch.cat((output, image1, image2, image3), 3)

            # legacy.visual(output, f'{outdir}/{attr_name}_{t:05d}_raw.jpg')
            # legacy.visual(image1, f'{outdir}/{attr_name}_{t:05d}_ifg.jpg')
            # legacy.visual(image2, f'{outdir}/{attr_name}_{t:05d}_ss.jpg')
            # legacy.visual(image3, f'{outdir}/{attr_name}_{t:05d}_sefa.jpg')

            if gen_video:
                total_step = 90
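                # Each step sweeps the editing strength; 90 frames at 30 fps give a ~3 s clip.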
                if real:
                    video_ifg_path = f"{outdir}/video/ifg_{attr_name}_{real_w_path.split('/')[-2]}/"
                    video_ss_path = f"{outdir}/video/ss_{attr_name}_{real_w_path.split('/')[-2]}/"
                    video_sefa_path = f"{outdir}/video/ss_{attr_name}_{real_w_path.split('/')[-2]}/"
                else:
                    video_ifg_path = f"{outdir}/video/ifg_{attr_name}_{t:05d}/"
                    video_ss_path = f"{outdir}/video/ss_{attr_name}_{t:05d}/"
                    video_sefa_path = f"{outdir}/video/ss_{attr_name}_{t:05d}/"
                video_comb_path = f"{outdir}/video/tmp"    

                if combine:
                    os.makedirs(video_comb_path, exist_ok=True)
                else:
                    os.makedirs(video_ifg_path, exist_ok=True)
                    os.makedirs(video_ss_path, exist_ok=True)
                    os.makedirs(video_sefa_path, exist_ok=True)
                for i in range(total_step):
                    style_space, latent, noise = encoder_ifg(generator, test_input, attr_name, truncation, mean_latent, step=i, total=total_step, real=real)
                    image1 = decoder(generator, style_space, latent, noise)
                    style_space, latent, noise = encoder_ss(generator, test_input, attr_name, truncation, mean_latent, step=i, total=total_step, real=real)
                    image2 = decoder(generator, style_space, latent, noise)
                    latent, noise = encoder_sefa(generator, test_input, attr_name, truncation, mean_latent, step=i, total=total_step, real=real)
                    image3, _ = generator([latent], noise=noise, input_is_latent=True)
                    if combine:
                        if real_img_path:
                            comb_img = torch.cat((real_image, output, image1, image2, image3), 3)
                        else:
                            comb_img = torch.cat((output, image1, image2, image3), 3)
                        legacy.visual(comb_img, os.path.join(video_comb_path, f'{i:05d}.jpg'))
                    else:
                        legacy.visual(image1, os.path.join(video_ifg_path, f'{i:05d}.jpg'))
                        legacy.visual(image2, os.path.join(video_ss_path, f'{i:05d}.jpg'))
                        legacy.visual(image3, os.path.join(video_sefa_path, f'{i:05d}.jpg'))
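                # Assemble the saved frames into an H.264 mp4 at 30 fps (assumes ffmpeg is on PATH).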
                if combine:
                    cmd=f"ffmpeg -hide_banner -loglevel error -y -r 30 -i {video_comb_path}/%05d.jpg -vcodec libx264 -pix_fmt yuv420p {video_ifg_path.replace('ifg_', '')[:-1] + '.mp4'}"
                    subprocess.call(cmd, shell=True)
                else:
                    cmd = f"ffmpeg -hide_banner -loglevel error -y -r 30 -i {video_ifg_path}/%05d.jpg -vcodec libx264 -pix_fmt yuv420p {video_ifg_path[:-1] + '.mp4'}"
                    subprocess.call(cmd, shell=True)
                    cmd = f"ffmpeg -hide_banner -loglevel error -y -r 30 -i {video_ss_path}/%05d.jpg -vcodec libx264 -pix_fmt yuv420p {video_ss_path[:-1] + '.mp4'}"
                    subprocess.call(cmd, shell=True)
                    cmd = f"ffmpeg -hide_banner -loglevel error -y -r 30 -i {video_sefa_path}/%05d.jpg -vcodec libx264 -pix_fmt yuv420p {video_sefa_path[:-1] + '.mp4'}"
                    subprocess.call(cmd, shell=True)

            # Collect this sample's strip (raw output plus the three edited results).
            finals.append(final)

        final = torch.cat(finals, 2)
        legacy.visual(final, os.path.join(outdir,'final.jpg'))


if __name__ == "__main__":
    main()