Spaces:
Runtime error
Runtime error
import sys | |
import os | |
import gradio as gr | |
from PIL import Image | |
# Fetch the Projected GAN sources and put them on sys.path: the `dnnlib` and
# `legacy` imports below come from this checkout, so it must exist first.
import subprocess

if not os.path.isdir("projected_gan"):
    # Clone only once; re-cloning into an existing directory just fails noisily.
    # check=True surfaces a failed clone here instead of as a confusing
    # ImportError further down.
    subprocess.run(
        ["git", "clone", "https://github.com/autonomousvision/projected_gan.git"],
        check=True,
    )
sys.path.append("projected_gan")
# Generate images using pretrained network pickle.
import re | |
from typing import List, Optional, Tuple, Union | |
import click | |
import dnnlib | |
import numpy as np | |
import PIL.Image | |
import torch | |
import legacy | |
from huggingface_hub import hf_hub_url | |
#---------------------------------------------------------------------------- | |
def parse_range(s: Union[str, List]) -> List[int]:
    '''Parse a comma separated list of numbers or ranges and return a list of ints.

    Each comma-separated part is either a single integer or an inclusive
    range ``a-b``. Whitespace around a part is ignored, and so are empty
    parts (e.g. ``''`` or a trailing comma). A reversed range such as
    ``'10-5'`` contributes nothing. A list argument is returned unchanged.

    Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10]

    Raises:
        ValueError: if a non-empty part is neither an integer nor an ``a-b`` range.
    '''
    if isinstance(s, list):
        return s
    ranges = []
    range_re = re.compile(r'^(\d+)-(\d+)$')
    for p in s.split(','):
        p = p.strip()
        if not p:
            # Tolerate '' and stray/trailing commas instead of failing on int('').
            continue
        m = range_re.match(p)
        if m:
            # The upper bound is inclusive, hence the +1.
            ranges.extend(range(int(m.group(1)), int(m.group(2)) + 1))
        else:
            ranges.append(int(p))
    return ranges
#---------------------------------------------------------------------------- | |
def parse_vec2(s: Union[str, Tuple[float, float]]) -> Tuple[float, float]:
    '''Turn an 'a,b' string into a pair of floats.

    A tuple is passed through untouched (same object). Any string that does
    not split into exactly two comma-separated fields raises ValueError, as
    does a field that float() cannot parse.

    Example:
        '0,1' returns (0,1)
    '''
    if isinstance(s, tuple):
        return s
    fields = s.split(',')
    if len(fields) != 2:
        raise ValueError(f'cannot parse 2-vector {s}')
    first, second = fields
    return (float(first), float(second))
#---------------------------------------------------------------------------- | |
def make_transform(translate: Tuple[float,float], angle: float):
    """Build a 3x3 float64 homogeneous matrix for a 2-D rotation plus translation.

    `angle` is in degrees. The result has the layout::

        [[ cos,  sin, translate[0]],
         [-sin,  cos, translate[1]],
         [   0,    0,            1]]
    """
    # Degrees -> radians, keeping the original expression order for identical floats.
    radians = angle / 360.0 * np.pi * 2
    sin_a, cos_a = np.sin(radians), np.cos(radians)
    transform = np.eye(3)
    transform[0, 0], transform[0, 1], transform[0, 2] = cos_a, sin_a, translate[0]
    transform[1, 0], transform[1, 1], transform[1, 2] = -sin_a, cos_a, translate[1]
    return transform
#---------------------------------------------------------------------------- | |
# Prefer the GPU when CUDA is available; otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Resolve the Hub URL of the pretrained Pokemon checkpoint, open it through
# dnnlib (presumably downloading/caching it), and keep only the 'G_ema'
# generator, moved onto `device`.
# NOTE(review): load_network_pkl presumably unpickles the file; only point this
# at checkpoints from a trusted repository.
config_file_url = hf_hub_url(
    repo_id="autonomousvision/Projected_GAN_Pokemon",
    filename="pokemon.pkl",
)
with dnnlib.util.open_url(config_file_url) as f:
    G = legacy.load_network_pkl(f)['G_ema'].to(device)  # type: ignore
def generate_images(seeds, truncation_psi=1, noise_mode='const'):
    """Render an RGB image per seed with the module-level generator ``G``.

    Each seed deterministically fixes the latent ``z`` via
    ``np.random.RandomState(seed)``. Seeds are processed in order, with a
    progress line printed for each, and only the image for the LAST seed is
    returned (the Gradio callback in this file passes a single seed).

    Args:
        seeds: non-empty sized sequence of integer seeds.
        truncation_psi: forwarded to ``G`` as ``truncation_psi``
            (defaults to 1, the value previously hard-coded).
        noise_mode: forwarded to ``G`` as ``noise_mode``
            (defaults to 'const', the value previously hard-coded).

    Returns:
        PIL.Image.Image in 'RGB' mode for the last seed.

    Raises:
        ValueError: if ``seeds`` is empty (previously an UnboundLocalError).
    """
    if len(seeds) == 0:
        raise ValueError('generate_images() needs at least one seed')

    # Unconditional generation: an all-zero label of width G.c_dim.
    label = torch.zeros([1, G.c_dim], device=device)

    with torch.no_grad():
        # Generators exposing `synthesis.input` take an inverse
        # rotation/translation matrix (inverse avoids potentially failing
        # numerical operations in the network). It does not depend on the
        # seed, so set it once. Pass a numeric tuple: the former '0,0' string
        # made make_transform store ',' into a float array and raise.
        if hasattr(G.synthesis, 'input'):
            m = np.linalg.inv(make_transform((0.0, 0.0), 0))
            G.synthesis.input.transform.copy_(torch.from_numpy(m))

        pilimg = None
        for seed_idx, seed in enumerate(seeds):
            # 1-based progress counter, e.g. "(1/1)" for a single seed.
            print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx + 1, len(seeds)))
            z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device).float()
            img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode)
            # NCHW -> NHWC, then map the output (presumably in [-1, 1]) to uint8.
            img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
            pilimg = PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB')
    return pilimg
def inference(seedin):
    """Gradio callback: return the generated image for a single seed.

    ``seedin`` is the Seed slider value (possibly a float); it is converted
    with ``int()`` and passed to ``generate_images`` as a one-element list.
    """
    return generate_images([int(seedin)])
# Text shown around the demo.
title = "Example: Pokemon GAN"
description = "Gradio demo for Pokemon GAN. To use it, provide a seed, or click one of the examples to load them. Read more at the links below."
# Footer links; the GitHub anchor previously lacked its closing </a> tag.
article = "<p style='text-align: center'><a href='http://www.cvlibs.net/publications/Sauer2021NEURIPS.pdf' target='_blank'>Projected GANs Converge Faster</a> | <a href='https://github.com/autonomousvision/projected_gan' target='_blank'>Github Repo</a></p><center><img src='https://visitor-badge.glitch.me/badge?page_id=akhaliq_projected_gan' alt='visitor badge'></center>"

# NOTE(review): gr.inputs.*, Slider(default=...), allow_screenshot and
# launch(enable_queue=..., cache_examples=...) belong to the legacy Gradio 2.x
# API and are gone in newer releases. Either pin gradio<3 or migrate to
# gr.Slider(value=...), gr.Interface(..., cache_examples=True) and
# demo.queue().launch(). Confirm against the Space's pinned version.
gr.Interface(
    inference,
    gr.inputs.Slider(label="Seed", minimum=0, maximum=5000, step=1, default=0),
    "pil",
    title=title,
    description=description,
    article=article,
    allow_screenshot=False,
    allow_flagging="never",
    live=True,  # regenerate whenever the slider moves
    examples=[[0], [1], [10], [20], [30], [42], [50], [60], [77], [102]],
).launch(enable_queue=True, cache_examples=True)