Spaces:
Running
Running
AlekseyKorshuk
commited on
Commit
•
ae2d652
1
Parent(s):
a94775f
Update app.py
Browse files
app.py
CHANGED
@@ -9,7 +9,7 @@ from huggingface_hub import hf_hub_download
|
|
9 |
from PIL import Image
|
10 |
from torch import nn
|
11 |
from torchvision.utils import save_image
|
12 |
-
|
13 |
|
14 |
class Generator(nn.Module):
|
15 |
def __init__(self, num_channels=4, latent_dim=100, hidden_size=64):
|
@@ -42,10 +42,6 @@ class Generator(nn.Module):
|
|
42 |
|
43 |
return pixel_values
|
44 |
|
45 |
-
model = Generator()
|
46 |
-
weights_path = hf_hub_download('huggingnft/dooggies', 'pytorch_model.bin')
|
47 |
-
model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))
|
48 |
-
|
49 |
|
50 |
@torch.no_grad()
|
51 |
def interpolate(save_dir='./lerp/', frames=100, rows=8, cols=8):
|
@@ -76,7 +72,10 @@ def interpolate(save_dir='./lerp/', frames=100, rows=8, cols=8):
|
|
76 |
save_all=True, duration=100, loop=1)
|
77 |
|
78 |
|
79 |
-
def predict(choice, seed):
|
|
|
|
|
|
|
80 |
torch.manual_seed(seed)
|
81 |
|
82 |
if choice == 'interpolation':
|
@@ -92,9 +91,11 @@ def predict(choice, seed):
|
|
92 |
return 'image.png'
|
93 |
|
94 |
|
|
|
95 |
gr.Interface(
|
96 |
predict,
|
97 |
inputs=[
|
|
|
98 |
gr.inputs.Dropdown(['image', 'interpolation'], label='Output Type'),
|
99 |
gr.inputs.Slider(label='Seed', minimum=0, maximum=1000, default=42),
|
100 |
],
|
@@ -102,5 +103,5 @@ gr.Interface(
|
|
102 |
title="Cryptopunks GAN",
|
103 |
description="These CryptoPunks do not exist. You have the choice of either generating random punks, or a gif showing the interpolation between two random punk grids.",
|
104 |
article="<p style='text-align: center'><a href='https://arxiv.org/pdf/1511.06434.pdf'>Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks</a> | <a href='https://github.com/teddykoker/cryptopunks-gan'>Github Repo</a></p>",
|
105 |
-
examples=[["interpolation",
|
106 |
).launch(cache_examples=True)
|
|
|
9 |
from PIL import Image
|
10 |
from torch import nn
|
11 |
from torchvision.utils import save_image
|
12 |
+
hfapi = HfApi()
|
13 |
|
14 |
class Generator(nn.Module):
|
15 |
def __init__(self, num_channels=4, latent_dim=100, hidden_size=64):
|
|
|
42 |
|
43 |
return pixel_values
|
44 |
|
|
|
|
|
|
|
|
|
45 |
|
46 |
@torch.no_grad()
|
47 |
def interpolate(save_dir='./lerp/', frames=100, rows=8, cols=8):
|
|
|
72 |
save_all=True, duration=100, loop=1)
|
73 |
|
74 |
|
75 |
+
def predict(model_name, choice, seed):
|
76 |
+
model = Generator()
|
77 |
+
weights_path = hf_hub_download(f'huggingnft/{model_name}', 'pytorch_model.bin')
|
78 |
+
model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))
|
79 |
torch.manual_seed(seed)
|
80 |
|
81 |
if choice == 'interpolation':
|
|
|
91 |
return 'image.png'
|
92 |
|
93 |
|
94 |
+
models = [model.modelId[model.modelId.index("/") + 1:] for model in hfapi.list_models(author="huggingnft")]
|
95 |
gr.Interface(
|
96 |
predict,
|
97 |
inputs=[
|
98 |
+
gr.inputs.Dropdown(models, label='Model'),
|
99 |
gr.inputs.Dropdown(['image', 'interpolation'], label='Output Type'),
|
100 |
gr.inputs.Slider(label='Seed', minimum=0, maximum=1000, default=42),
|
101 |
],
|
|
|
103 |
title="Cryptopunks GAN",
|
104 |
description="These CryptoPunks do not exist. You have the choice of either generating random punks, or a gif showing the interpolation between two random punk grids.",
|
105 |
article="<p style='text-align: center'><a href='https://arxiv.org/pdf/1511.06434.pdf'>Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks</a> | <a href='https://github.com/teddykoker/cryptopunks-gan'>Github Repo</a></p>",
|
106 |
+
examples=[["interpolation", 100], ["interpolation", 500], ["image", 100], ["image", 500]],
|
107 |
).launch(cache_examples=True)
|