File size: 5,285 Bytes
3a19a1a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bbbf71e
 
3a19a1a
 
 
210c702
 
 
 
 
 
3a19a1a
 
 
 
 
 
 
 
 
210c702
 
 
 
 
 
 
 
 
3a19a1a
720df82
 
16c6665
bbbf71e
 
 
 
 
 
 
16c6665
bbbf71e
951f6cb
3a19a1a
210c702
 
16c6665
3a19a1a
 
bbbf71e
3a19a1a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bbbf71e
3a19a1a
 
bbbf71e
 
 
 
 
3a19a1a
 
 
 
 
 
210c702
3a19a1a
210c702
3a19a1a
 
 
 
 
210c702
3a19a1a
 
 
 
 
 
 
 
 
 
 
7a331ca
3a19a1a
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
'''
Tool for generating editing videos across different domains.

Given a set of latent codes and pre-trained models, it will interpolate between the different codes in each of the target domains
and combine the resulting images into a video.

Example run command:

python generate_videos.py --ckpt /model_dir/pixar.pt \
                                 /model_dir/ukiyoe.pt \
                                 /model_dir/edvard_munch.pt \
                                 /model_dir/botero.pt \
                          --out_dir /output/video/ \
                          --source_latent /latents/latent000.npy \
                          --target_latents /latents/

'''

import os
import argparse

import torch
from torchvision import utils

from model.sg2_model import Generator
from tqdm import tqdm
from pathlib import Path

import numpy as np

import subprocess
import shutil
import copy

from styleclip.styleclip_global import style_tensor_to_style_dict, style_dict_to_style_tensor

# Names of the supported edits; every entry has a matching key in
# SUGGESTED_DISTANCES below.
VALID_EDITS = [
    "pose",
    "age",
    "smile",
    "gender",
    "hair_length",
    "beard",
]

# Base step size per edit. project_code_by_edit_name multiplies the value by
# the caller-supplied strength before moving the latent code.
# NOTE(review): "hair_length" is the only negative entry; presumably its
# boundary points the opposite way from the intended edit - confirm.
SUGGESTED_DISTANCES = {
    "pose": 3.0,
    "smile": 2.0,
    "age": 4.0,
    "gender": 3.0,
    "hair_length": -4.0,
    "beard": 2.0,
}
                    
def project_code(latent_code, boundary, distance=3.0):
    '''
    Move a latent code along an editing direction.

    Returns ``latent_code + distance * boundary``. When ``len(boundary)`` is 2
    the boundary is first flattened to shape ``(1, 1, -1)``; otherwise it is
    used unchanged and normal broadcasting applies.
    '''
    # NOTE(review): len() measures the size of the first axis, not the number
    # of dimensions. Possibly a rank check was intended - confirm before
    # changing, since other 2-D boundaries currently rely on plain broadcasting.
    direction = boundary.reshape(1, 1, -1) if len(boundary) == 2 else boundary

    return latent_code + distance * direction

def project_code_by_edit_name(latent_code, name, strength):
    '''
    Apply a named edit to a latent code.

    The step size is ``SUGGESTED_DISTANCES[name] * strength`` and the direction
    is loaded from ``editing/interfacegan_boundaries/<name>.pt`` next to this
    file. The result of ``project_code`` is returned.

    Raises KeyError when ``name`` has no entry in SUGGESTED_DISTANCES.
    '''
    module_dir = Path(os.path.abspath(__file__)).parent

    # Looked up before the file is touched, so an unknown name fails fast.
    scaled_distance = SUGGESTED_DISTANCES[name] * strength

    boundary_file = os.path.join(module_dir, "editing", "interfacegan_boundaries", f'{name}.pt')
    # Loaded on the CPU and converted to a NumPy array.
    direction = torch.load(boundary_file, map_location="cpu").numpy()

    return project_code(latent_code, direction, scaled_distance)

def generate_frames(source_latent, target_latents, g_ema_list, output_dir):
    '''
    Render the interpolation frames and save them as numbered JPEG files.

    The source code is interpolated towards every target code (forward, hold,
    backward; see interpolate_with_target_latents). Each resulting code is fed
    to a generator and the image is written to ``<output_dir>/<idx>.jpg`` with
    the index zero-padded to three digits.

    When ``g_ema_list`` holds more than one model, the frame sequence is split
    into ``len(g_ema_list) - 1`` equal segments and the generator weights are
    linearly blended between consecutive models over each segment.

    Args:
        source_latent: torch tensor with the starting latent code.
        target_latents: non-empty sequence of torch tensors; the first one
            decides whether codes are handled as S codes (second dim == 9088).
        g_ema_list: non-empty list of generator models.
        output_dir: directory receiving the frames.
            NOTE(review): presumably must already exist - confirm with caller.

    Returns:
        None. Frames are written to disk as a side effect.
    '''

    # NOTE(review): latents are moved to this device; the generators are
    # presumably already on it - confirm with the caller.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Codes whose second dimension is 9088 are treated as S codes.
    code_is_s = target_latents[0].size()[1] == 9088

    if code_is_s:
        # Convert the source code to an S-code tensor using the first model.
        source_s_dict = g_ema_list[0].get_s_code(source_latent, input_is_latent=True)[0]
        np_latent = style_dict_to_style_tensor(source_s_dict, g_ema_list[0]).cpu().detach().numpy()
    else:
        # Drop the leading axis of the source code (only if it has size 1).
        np_latent = source_latent.squeeze(0).cpu().detach().numpy()

    np_target_latents = [target_latent.cpu().detach().numpy() for target_latent in target_latents]

    # Interpolation steps per direction: 20 for S codes, otherwise at most 10
    # and fewer when there are more than three targets.
    num_alphas = 20 if code_is_s else min(10, 30 // len(target_latents))

    alphas = np.linspace(0, 1, num=num_alphas)
    
    latents = interpolate_with_target_latents(np_latent, np_target_latents, alphas)

    # Number of model-to-model transitions; 0 means a single fixed generator.
    segments = len(g_ema_list) - 1

    if segments:
        # Frames per transition; a float, it need not divide evenly.
        segment_length = len(latents) / segments

        # Work on a copy so the blended weights never overwrite an input model.
        g_ema = copy.deepcopy(g_ema_list[0])

        src_pars = dict(g_ema.named_parameters())
        mix_pars = [dict(model.named_parameters()) for model in g_ema_list]
    else:
        g_ema = g_ema_list[0]

    print("Generating frames for video...")
    for idx, latent in tqdm(enumerate(latents), total=len(latents)):

        if segments:
            # Fractional position inside the current segment, in [0, 1).
            mix_alpha = (idx % segment_length) * 1.0 / segment_length
            segment_id = int(idx // segment_length)

            # Overwrite the working copy's weights in place with the linear
            # blend of the two models bounding this segment.
            for k in src_pars.keys():
                src_pars[k].data.copy_(mix_pars[segment_id][k] * (1 - mix_alpha) + mix_pars[segment_id + 1][k] * mix_alpha)

        # With a single model, a frame whose latent is the very same array
        # object as the previous one (the hold frames repeat one object) would
        # produce the same image, so generation is skipped and the previous
        # `img` is saved again. `idx == 0` avoids comparing against latents[-1].
        if idx == 0 or segments or latent is not latents[idx - 1]:
            latent_tensor = torch.from_numpy(latent).float().to(device)

            with torch.no_grad():
                if code_is_s:
                    latent_for_gen = style_tensor_to_style_dict(latent_tensor, g_ema)
                    img, _ = g_ema(latent_for_gen, input_is_s_code=True, input_is_latent=True, truncation=1, randomize_noise=False)
                else:
                    img, _ = g_ema([latent_tensor], input_is_latent=True, truncation=1, randomize_noise=False)

        # NOTE(review): newer torchvision releases name this keyword
        # `value_range` instead of `range` - confirm against the pinned version.
        utils.save_image(img, f"{output_dir}/{str(idx).zfill(3)}.jpg", nrow=1, normalize=True, scale_each=True, range=(-1, 1))

def interpolate_forward_backward(source_latent, target_latent, alphas):
    '''
    Build a source -> target -> source sequence of latent codes.

    The returned list has ``3 * len(alphas)`` entries: the forward blends
    ``a * target + (1 - a) * source`` for each ``a`` in ``alphas``, then
    ``len(alphas)`` references to ``target_latent`` itself (a short pause at
    the target), then the forward blends again in reverse order.
    '''
    outbound = []
    for weight in alphas:
        outbound.append(weight * target_latent + (1 - weight) * source_latent)

    # The pause repeats the same object rather than copies, so consecutive
    # hold frames are identical by identity.
    hold = [target_latent] * len(alphas)

    # The way back reuses the forward entries, last one first.
    inbound = list(reversed(outbound))

    return [*outbound, *hold, *inbound]

def interpolate_with_target_latents(source_latent, target_latents, alphas):
    '''
    Chain the forward/hold/backward sequences for every target.

    For each entry of ``target_latents``, in order, the codes produced by
    ``interpolate_forward_backward`` are appended to one flat list, which is
    returned. An empty ``target_latents`` yields an empty list.
    '''
    print("Interpolating latent codes...")

    return [
        code
        for target in target_latents
        for code in interpolate_forward_backward(source_latent, target, alphas)
    ]

def video_from_interpolations(fps, output_dir):
    '''
    Combine the saved frames into ``<output_dir>/out.mp4`` using ffmpeg.

    Frames are read from ``<output_dir>/%03d.jpg`` and encoded with libx264 in
    the yuv420p pixel format, with ``fps`` passed both as ``-r`` and as the
    ``fps`` video filter. The ffmpeg exit status is not inspected.
    '''
    frame_pattern = f"{output_dir}/%03d.jpg"
    video_path = f"{output_dir}/out.mp4"

    input_args = ["-r", f"{fps}", "-i", frame_pattern]
    encode_args = ["-c:v", "libx264", "-vf", f"fps={fps}", "-pix_fmt", "yuv420p"]

    # Argument list (no shell), so paths are passed to ffmpeg verbatim.
    subprocess.call(["ffmpeg", *input_args, *encode_args, video_path])