NimaBoscarino commited on
Commit
1e67ab5
1 Parent(s): ad7f9c9

Simplify space

Browse files
Files changed (2) hide show
  1. .gitignore +3 -0
  2. app.py +6 -50
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ flagged
2
+ gradio_cached_examples
3
+ punks.png
app.py CHANGED
@@ -1,12 +1,6 @@
1
- import subprocess
2
- from pathlib import Path
3
-
4
- import einops
5
  import gradio as gr
6
- import numpy as np
7
  import torch
8
  from huggingface_hub import hf_hub_download
9
- from PIL import Image
10
  from torch import nn
11
  from torchvision.utils import save_image
12
 
@@ -38,60 +32,22 @@ weights_path = hf_hub_download('nateraw/cryptopunks-gan', 'generator.pth')
38
  model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))
39
 
40
 
41
- @torch.no_grad()
42
- def interpolate(save_dir='./lerp/', frames=100, rows=8, cols=8):
43
- save_dir = Path(save_dir)
44
- save_dir.mkdir(exist_ok=True, parents=True)
45
-
46
- z1 = torch.randn(rows * cols, 100, 1, 1)
47
- z2 = torch.randn(rows * cols, 100, 1, 1)
48
-
49
- zs = []
50
- for i in range(frames):
51
- alpha = i / frames
52
- z = (1 - alpha) * z1 + alpha * z2
53
- zs.append(z)
54
-
55
- zs += zs[::-1] # also go in reverse order to complete loop
56
-
57
- for i, z in enumerate(zs):
58
- imgs = model(z)
59
-
60
- # normalize
61
- imgs = (imgs + 1) / 2
62
-
63
- imgs = (imgs.permute(0, 2, 3, 1).cpu().numpy() * 255).astype(np.uint8)
64
-
65
- # create grid
66
- imgs = einops.rearrange(imgs, "(b1 b2) h w c -> (b1 h) (b2 w) c", b1=rows, b2=cols)
67
-
68
- Image.fromarray(imgs).save(save_dir / f"{i:03}.png")
69
-
70
- subprocess.call(f"convert -dispose previous -delay 10 -loop 0 {save_dir}/*.png out.gif".split())
71
-
72
-
73
- def predict(choice, seed):
74
  torch.manual_seed(seed)
75
-
76
- if choice == 'interpolation':
77
- interpolate()
78
- return 'out.gif'
79
- else:
80
- z = torch.randn(64, 100, 1, 1)
81
- punks = model(z)
82
- save_image(punks, "punks.png", normalize=True)
83
- return 'punks.png'
84
 
85
 
86
  gr.Interface(
87
  predict,
88
  inputs=[
89
- gr.inputs.Dropdown(['image', 'interpolation'], label='Output Type'),
90
  gr.inputs.Slider(label='Seed', minimum=0, maximum=1000, default=42),
91
  ],
92
  outputs="image",
93
  title="Cryptopunks GAN",
94
  description="These CryptoPunks do not exist. You have the choice of either generating random punks, or a gif showing the interpolation between two random punk grids.",
95
  article="<p style='text-align: center'><a href='https://arxiv.org/pdf/1511.06434.pdf'>Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks</a> | <a href='https://github.com/teddykoker/cryptopunks-gan'>Github Repo</a></p>",
96
- examples=[["interpolation", 123], ["interpolation", 42], ["image", 456], ["image", 42]],
97
  ).launch(cache_examples=True)
 
 
 
 
 
1
  import gradio as gr
 
2
  import torch
3
  from huggingface_hub import hf_hub_download
 
4
  from torch import nn
5
  from torchvision.utils import save_image
6
 
 
32
  model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))
33
 
34
 
35
+ def predict(seed):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  torch.manual_seed(seed)
37
+ z = torch.randn(64, 100, 1, 1)
38
+ punks = model(z)
39
+ save_image(punks, "punks.png", normalize=True)
40
+ return 'punks.png'
 
 
 
 
 
41
 
42
 
43
  gr.Interface(
44
  predict,
45
  inputs=[
 
46
  gr.inputs.Slider(label='Seed', minimum=0, maximum=1000, default=42),
47
  ],
48
  outputs="image",
49
  title="Cryptopunks GAN",
50
  description="These CryptoPunks do not exist. You have the choice of either generating random punks, or a gif showing the interpolation between two random punk grids.",
51
  article="<p style='text-align: center'><a href='https://arxiv.org/pdf/1511.06434.pdf'>Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks</a> | <a href='https://github.com/teddykoker/cryptopunks-gan'>Github Repo</a></p>",
52
+ examples=[[123], [42], [456], [1337]],
53
  ).launch(cache_examples=True)