yourusername's picture
:zap: cache examples
458fe20
raw
history blame contribute delete
No virus
3.18 kB
import subprocess
from pathlib import Path
import einops
import gradio as gr
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from torch import nn
from torchvision.utils import save_image
class Generator(nn.Module):
    """DCGAN-style generator that maps latent vectors to small images.

    For a latent batch of shape (N, nz, 1, 1), the transposed-convolution
    stack upsamples 1 -> 3 -> 5 -> 12 -> 24 pixels. It returns an
    (N, nc, 24, 24) tensor squashed into [-1, 1] by the final Tanh. This app
    saves the default nc=4 channels as RGBA images.

    Args:
        nc: Number of output channels.
        nz: Size of the latent vector.
        ngf: Base feature-map width; the hidden stages use 4x, 2x and 1x this value.
    """

    def __init__(self, nc=4, nz=100, ngf=64):
        super().__init__()
        # (in_channels, out_channels, kernel, stride, padding) for each hidden
        # upsampling stage; every stage is followed by BatchNorm + ReLU.
        hidden_stages = (
            (nz, ngf * 4, 3, 1, 0),
            (ngf * 4, ngf * 2, 3, 2, 1),
            (ngf * 2, ngf, 4, 2, 0),
        )
        layers = []
        for c_in, c_out, kernel, stride, pad in hidden_stages:
            layers.extend((
                nn.ConvTranspose2d(c_in, c_out, kernel, stride, pad, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ))
        # The output stage projects to `nc` channels and bounds values to [-1, 1].
        layers.extend((
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        ))
        # The attribute name and layer order must stay fixed: the pretrained
        # checkpoint's keys are "network.<index>.<param>".
        self.network = nn.Sequential(*layers)

    def forward(self, input):
        """Run the latent batch `input` through the generator network."""
        return self.network(input)
# Build the generator, then load the pretrained checkpoint from the Hugging
# Face Hub, mapping every tensor onto the CPU.
model = Generator()
weights_path = hf_hub_download(repo_id='nateraw/cryptopunks-gan', filename='generator.pth')
# NOTE(review): torch.load unpickles the checkpoint, so this trusts the
# 'nateraw/cryptopunks-gan' repo; newer torch versions accept weights_only=True.
# NOTE(review): model.eval() is never called, so the BatchNorm layers stay in
# training mode (per-batch statistics, running stats updated on every forward
# pass) — confirm this is intended.
model.load_state_dict(torch.load(weights_path, map_location='cpu'))
@torch.no_grad()
def interpolate(save_dir='./lerp/', frames=100, rows=8, cols=8, out_path='out.gif'):
    """Render a looping GIF that morphs between two random grids of punks.

    Two latent batches (``rows * cols`` vectors each) are drawn from torch's
    global RNG. The generator is evaluated along the straight line from the
    first towards the second and then back again. Each step is saved as a PNG
    grid, and ImageMagick's ``convert`` stitches the frames into a GIF.

    Args:
        save_dir: Directory that receives the per-frame PNGs (created if needed).
        frames: Number of interpolation steps on the outward leg (must be >= 1).
        rows: Number of punk rows in each frame.
        cols: Number of punk columns in each frame.
        out_path: Where the GIF is written.

    Returns:
        ``out_path``, the path of the written GIF.

    Raises:
        ValueError: if ``frames`` is less than 1.
        FileNotFoundError: if the ``convert`` executable cannot be found.
        subprocess.CalledProcessError: if ``convert`` exits with an error.
    """
    if frames < 1:
        raise ValueError("frames must be at least 1")
    # NOTE(review): save_dir and out_path are shared between calls, so
    # concurrent requests would overwrite each other's frames/GIF.
    save_dir = Path(save_dir)
    save_dir.mkdir(exist_ok=True, parents=True)
    z1 = torch.randn(rows * cols, 100, 1, 1)
    z2 = torch.randn(rows * cols, 100, 1, 1)
    zs = []
    for i in range(frames):
        alpha = i / frames  # runs over [0, 1); z2 itself is never rendered
        zs.append((1 - alpha) * z1 + alpha * z2)
    # Walk back towards z1 to close the loop. The reversed leg skips both of
    # its endpoints: the turnaround frame was just rendered, and the first frame
    # is shown again when the GIF restarts. Neither is then held for two delays.
    zs += zs[-2:0:-1]
    frame_paths = []
    for i, z in enumerate(zs):
        imgs = model(z)
        # Map the Tanh output from [-1, 1] to [0, 255] bytes, channels last.
        imgs = (imgs + 1) / 2
        imgs = (imgs.permute(0, 2, 3, 1).cpu().numpy() * 255).astype(np.uint8)
        # Tile the rows * cols images into one (rows*h, cols*w, c) grid.
        grid = einops.rearrange(imgs, "(b1 b2) h w c -> (b1 h) (b2 w) c", b1=rows, b2=cols)
        frame_path = save_dir / f"{i:03}.png"
        Image.fromarray(grid).save(frame_path)
        frame_paths.append(str(frame_path))
    # Pass ImageMagick the exact frame list rather than a whitespace-split
    # command with a glob. This has three benefits:
    #   - paths containing spaces work;
    #   - PNGs left over from an earlier, longer run are not picked up;
    #   - frame order does not depend on filename sorting.
    subprocess.run(
        ["convert", "-dispose", "previous", "-delay", "10", "-loop", "0",
         *frame_paths, str(out_path)],
        check=True,
    )
    return out_path
@torch.no_grad()
def predict(choice, seed):
    """Gradio callback: generate either an interpolation GIF or a still grid of punks.

    Inference only, so autograd is disabled (as on ``interpolate``). This avoids
    building a gradient graph for the 64-sample forward pass below.

    Args:
        choice: 'interpolation' produces the animated GIF. Any other value (the
            UI offers 'image') produces a single PNG grid of 64 punks.
        seed: Seed for torch's global RNG; the same seed yields the same latents.

    Returns:
        Path of the written file: 'out.gif' or 'punks.png'.
    """
    torch.manual_seed(seed)
    if choice == 'interpolation':
        interpolate()
        return 'out.gif'
    z = torch.randn(64, 100, 1, 1)
    punks = model(z)
    # NOTE(review): normalize=True rescales by this batch's own min/max, whereas
    # interpolate maps the Tanh range with (x + 1) / 2. Confirm the difference
    # is intended.
    save_image(punks, "punks.png", normalize=True)
    return 'punks.png'
# Wire the generator into a Gradio UI. The dropdown selects the output kind and
# the slider supplies the RNG seed; both are passed positionally to predict().
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.inputs.Dropdown(['image', 'interpolation'], label='Output Type'),
        gr.inputs.Slider(label='Seed', minimum=0, maximum=1000, default=42),
    ],
    outputs="image",
    title="Cryptopunks GAN",
    description="These CryptoPunks do not exist. You have the choice of either generating random punks, or a gif showing the interpolation between two random punk grids.",
    article="<p style='text-align: center'><a href='https://arxiv.org/pdf/1511.06434.pdf'>Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks</a> | <a href='https://github.com/teddykoker/cryptopunks-gan'>Github Repo</a></p>",
    examples=[["interpolation", 123], ["interpolation", 42], ["image", 456], ["image", 42]],
)
# NOTE(review): cache_examples is passed to launch(), presumably matching the
# gradio version this file targets (it uses gr.inputs.*). If gradio is upgraded,
# verify whether it belongs on gr.Interface instead.
demo.launch(cache_examples=True)