#!/usr/bin/env python3

import gradio as gr

import numpy as np
import torch
import pickle
import PIL.Image
import types

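# project() runs the latent-space optimization; _MODELS appears to be the
# registry of CLIP backbones used for the perceptual loss (projector also
# re-exports imageio for the video output).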
from projector import project, imageio, _MODELS

from huggingface_hub import hf_hub_url, cached_download

# with open("../models/gamma500/network-snapshot-010000.pkl", "rb") as f:
# with open("../models/gamma400/network-snapshot-010600.pkl", "rb") as f:
# with open("../models/gamma400/network-snapshot-019600.pkl", "rb") as f:
with open(cached_download(hf_hub_url('ykilcher/apes', 'gamma500/network-snapshot-010000.pkl')), 'rb') as f:
    G = pickle.load(f)["G_ema"]  # torch.nn.Module

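# Prefer the GPU when available. On CPU, the generator's mixed-precision
# (fp16) layers are not usable, so monkey-patch both forward methods to
# always pass force_fp32=True.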
device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")
    G = G.to(device)
else:
    _old_forward = G.forward

    def _new_forward(self, *args, **kwargs):
        kwargs["force_fp32"] = True
        return _old_forward(*args, **kwargs)

    G.forward = types.MethodType(_new_forward, G)

    _old_synthesis_forward = G.synthesis.forward

    def _new_synthesis_forward(self, *args, **kwargs):
        kwargs["force_fp32"] = True
        return _old_synthesis_forward(*args, **kwargs)

    G.synthesis.forward = types.MethodType(_new_synthesis_forward, G.synthesis)


def generate(
    target_image_upload,
    # target_image_webcam,
    num_steps,
    seed,
    learning_rate,
    model_name,
    normalize_for_clip,
    loss_type,
    regularize_noise_weight,
    initial_noise_factor,
):
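    """Project the target image into G's W latent space and render the
    optimization trajectory as an MP4 alongside the target."""
    # Gradio Number inputs arrive as floats; coerce to int and seed all RNGs.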
    seed = round(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    target_image = target_image_upload
    # if target_image is None:
        # target_image = target_image_webcam
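    # The upload input is declared optional, so guard against a missing image.
    if target_image is None:
        raise ValueError("Please provide a target image.")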
    num_steps = round(num_steps)
    # Center-crop to a square, then resize to the generator's native resolution.
    target_pil = PIL.Image.fromarray(target_image).convert("RGB")
    w, h = target_pil.size
    s = min(w, h)
    target_pil = target_pil.crop(
        ((w - s) // 2, (h - s) // 2, (w + s) // 2, (h + s) // 2)
    )
    target_pil = target_pil.resize(
        (G.img_resolution, G.img_resolution), PIL.Image.LANCZOS
    )
    target_uint8 = np.array(target_pil, dtype=np.uint8)
    target_image = torch.from_numpy(target_uint8.transpose([2, 0, 1])).to(device)
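    # Optimize per-step latents so that G(w) reproduces the target image.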
    projected_w_steps = project(
        G,
        target=target_image,
        num_steps=num_steps,
        device=device,
        verbose=True,
        initial_learning_rate=learning_rate,
        model_name=model_name,
        normalize_for_clip=normalize_for_clip,
        loss_type=loss_type,
        regularize_noise_weight=regularize_noise_weight,
        initial_noise_factor=initial_noise_factor,
    )
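    # Render each optimization step side by side with the target as an H.264 video.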
    with torch.no_grad():
        video = imageio.get_writer('proj.mp4', mode='I', fps=10, codec='libx264', bitrate='16M')
        for w in projected_w_steps:
            synth_image = G.synthesis(w.to(device).unsqueeze(0), noise_mode="const")
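            # Map the generator output from [-1, 1] to [0, 255].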
            synth_image = (synth_image + 1) * (255 / 2)
            synth_image = (
                synth_image.permute(0, 2, 3, 1)
                .clamp(0, 255)
                .to(torch.uint8)[0]
                .cpu()
                .numpy()
            )
            video.append_data(np.concatenate([target_uint8, synth_image], axis=1))
        video.close()
    return synth_image, "proj.mp4"


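# Gradio front end exposing the projection hyperparameters.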
iface = gr.Interface(
    fn=generate,
    inputs=[
        gr.inputs.Image(source="upload", optional=True),
        # gr.inputs.Image(source="webcam", optional=True),
        gr.inputs.Number(default=250, label="steps"),
        gr.inputs.Number(default=69420, label="seed"),
        gr.inputs.Number(default=0.05, label="learning_rate"),
        gr.inputs.Dropdown(default='RN50', label="model_name", choices=['vgg16', *_MODELS.keys()]),
        gr.inputs.Checkbox(default=True, label="normalize_for_clip"),
        gr.inputs.Dropdown(
            default="l2", label="loss_type", choices=["l2", "l1", "cosine"]
        ),
        gr.inputs.Number(default=1e5, label="regularize_noise_weight"),
        gr.inputs.Number(default=0.05, label="initial_noise_factor"),
    ],
    outputs=["image", "video"],
)
iface.launch(inbrowser=True)