Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3 | |
import gradio as gr | |
import numpy as np | |
import torch | |
import pickle | |
import PIL.Image | |
import types | |
from projector import project, imageio, _MODELS | |
from huggingface_hub import hf_hub_url, cached_download | |
# with open("../models/gamma500/network-snapshot-010000.pkl", "rb") as f: | |
# with open("../models/gamma400/network-snapshot-010600.pkl", "rb") as f: | |
# with open("../models/gamma400/network-snapshot-019600.pkl", "rb") as f: | |
# Download (or reuse the locally cached copy of) the generator checkpoint.
# ``cached_download`` is deprecated and was removed in huggingface_hub 0.26;
# ``hf_hub_download`` is its supported replacement. The legacy call is kept
# only as a fallback for very old huggingface_hub releases that lack it.
_CHECKPOINT_REPO = "ykilcher/apes"
_CHECKPOINT_FILE = "gamma500/network-snapshot-010000.pkl"
try:
    from huggingface_hub import hf_hub_download
except ImportError:
    _checkpoint_path = cached_download(hf_hub_url(_CHECKPOINT_REPO, _CHECKPOINT_FILE))
else:
    _checkpoint_path = hf_hub_download(
        repo_id=_CHECKPOINT_REPO, filename=_CHECKPOINT_FILE
    )
# SECURITY NOTE: pickle.load executes arbitrary code embedded in the file.
# Only load checkpoints from a repository you trust.
with open(_checkpoint_path, "rb") as f:
    G = pickle.load(f)["G_ema"]  # torch.nn.Module
def _patch_forward_fp32(module):
    """Replace ``module.forward`` with a wrapper that always passes ``force_fp32=True``.

    Any ``force_fp32`` value supplied by the caller is overridden. The wrapper is
    installed as an instance attribute, so ``module(...)`` picks it up.
    """
    wrapped_forward = module.forward

    def forward_fp32(self, *args, **kwargs):
        kwargs["force_fp32"] = True
        return wrapped_forward(*args, **kwargs)

    module.forward = types.MethodType(forward_fp32, module)


if torch.cuda.is_available():
    device = torch.device("cuda")
    G = G.to(device)
else:
    device = torch.device("cpu")
    # On CPU, force every forward pass of the generator and of its synthesis
    # sub-network (which generate() also calls directly) to run in float32.
    # NOTE(review): presumably because reduced-precision layers in the
    # checkpoint are not supported on CPU — confirm.
    _patch_forward_fp32(G)
    _patch_forward_fp32(G.synthesis)
def _prepare_target(image, resolution):
    """Center-crop *image* to a square and resize it to ``resolution`` x ``resolution``.

    Returns the result as a ``(resolution, resolution, 3)`` uint8 RGB array.
    """
    target_pil = PIL.Image.fromarray(image).convert("RGB")
    width, height = target_pil.size
    side = min(width, height)
    target_pil = target_pil.crop(
        (
            (width - side) // 2,
            (height - side) // 2,
            (width + side) // 2,
            (height + side) // 2,
        )
    )
    target_pil = target_pil.resize((resolution, resolution), PIL.Image.LANCZOS)
    return np.array(target_pil, dtype=np.uint8)


def _synthesize_frame(ws):
    """Render one latent *ws* with ``G.synthesis`` and return an HxWxC uint8 array."""
    image = G.synthesis(ws.to(device).unsqueeze(0), noise_mode="const")
    # Presumably maps the generator's roughly [-1, 1] output range to [0, 255].
    image = (image + 1) * (255 / 2)
    return (
        image.permute(0, 2, 3, 1)  # NCHW -> NHWC
        .clamp(0, 255)
        .to(torch.uint8)[0]
        .cpu()
        .numpy()
    )


def generate(
    target_image_upload,
    # target_image_webcam,
    num_steps,
    seed,
    learning_rate,
    model_name,
    normalize_for_clip,
    loss_type,
    regularize_noise_weight,
    initial_noise_factor,
):
    """Project an uploaded image into the generator's latent space.

    Runs ``project`` to optimise latents towards the (center-cropped, resized)
    target. It then renders every returned latent with ``G.synthesis`` and writes
    a side-by-side (target | reconstruction) video to ``proj.mp4``.

    Args:
        target_image_upload: Image array from the Gradio upload widget, or
            ``None`` when nothing was uploaded (presumably uint8 HxWxC — confirm).
        num_steps: Number of optimisation steps; rounded, must be >= 1.
        seed: Seed for NumPy and PyTorch RNGs; rounded and reduced mod 2**32.
        learning_rate: Forwarded to ``project`` as ``initial_learning_rate``.
        model_name, normalize_for_clip, loss_type, regularize_noise_weight,
        initial_noise_factor: Forwarded unchanged to ``project``.

    Returns:
        ``(final_image, video_path)``. ``final_image`` is the last rendered
        frame as an HxWxC uint8 array, and ``video_path`` is ``"proj.mp4"``.

    Raises:
        ValueError: if no image was supplied or ``num_steps`` rounds below 1.
    """
    if target_image_upload is None:
        raise ValueError("Please upload a target image.")
    num_steps = round(num_steps)
    if num_steps < 1:
        # With no optimisation steps there may be no frames to render/return.
        raise ValueError(f"steps must be at least 1, got {num_steps}.")
    # np.random.seed only accepts 0 <= seed < 2**32; valid seeds are unchanged.
    seed = round(seed) % 2**32
    np.random.seed(seed)
    torch.manual_seed(seed)

    target_uint8 = _prepare_target(target_image_upload, G.img_resolution)
    target_tensor = torch.from_numpy(target_uint8.transpose([2, 0, 1])).to(device)
    projected_w_steps = project(
        G,
        target=target_tensor,
        num_steps=num_steps,
        device=device,
        verbose=True,
        initial_learning_rate=learning_rate,
        model_name=model_name,
        normalize_for_clip=normalize_for_clip,
        loss_type=loss_type,
        regularize_noise_weight=regularize_noise_weight,
        initial_noise_factor=initial_noise_factor,
    )

    video_path = "proj.mp4"
    writer = imageio.get_writer(
        video_path, mode="I", fps=10, codec="libx264", bitrate="16M"
    )
    # The context manager finalises/closes the video even if synthesis fails.
    with torch.no_grad(), writer as video:
        for ws in projected_w_steps:
            synth_image = _synthesize_frame(ws)
            video.append_data(np.concatenate([target_uint8, synth_image], axis=1))
    return synth_image, video_path
# Gradio passes the widget values to ``fn`` positionally, so this list must stay
# in the same order as generate()'s parameters.
# NOTE(review): ``gr.inputs.*`` is Gradio's legacy input namespace (removed in
# Gradio 4); pin a compatible gradio version or migrate to ``gr.Image`` etc.
_model_choices = ["vgg16", *_MODELS.keys()]
_inputs = [
    gr.inputs.Image(source="upload", optional=True),
    # gr.inputs.Image(source="webcam", optional=True),
    gr.inputs.Number(default=250, label="steps"),
    gr.inputs.Number(default=69420, label="seed"),
    gr.inputs.Number(default=0.05, label="learning_rate"),
    gr.inputs.Dropdown(default="RN50", label="model_name", choices=_model_choices),
    gr.inputs.Checkbox(default=True, label="normalize_for_clip"),
    gr.inputs.Dropdown(default="l2", label="loss_type", choices=["l2", "l1", "cosine"]),
    gr.inputs.Number(default=1e5, label="regularize_noise_weight"),
    gr.inputs.Number(default=0.05, label="initial_noise_factor"),
]
# Outputs match generate()'s return value: (final frame, path to the video).
iface = gr.Interface(fn=generate, inputs=_inputs, outputs=["image", "video"])
iface.launch(inbrowser=True)