AlekseyKorshuk commited on
Commit
ae2d652
1 Parent(s): a94775f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -7
app.py CHANGED
@@ -9,7 +9,7 @@ from huggingface_hub import hf_hub_download
9
  from PIL import Image
10
  from torch import nn
11
  from torchvision.utils import save_image
12
-
13
 
14
  class Generator(nn.Module):
15
  def __init__(self, num_channels=4, latent_dim=100, hidden_size=64):
@@ -42,10 +42,6 @@ class Generator(nn.Module):
42
 
43
  return pixel_values
44
 
45
- model = Generator()
46
- weights_path = hf_hub_download('huggingnft/dooggies', 'pytorch_model.bin')
47
- model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))
48
-
49
 
50
  @torch.no_grad()
51
  def interpolate(save_dir='./lerp/', frames=100, rows=8, cols=8):
@@ -76,7 +72,10 @@ def interpolate(save_dir='./lerp/', frames=100, rows=8, cols=8):
76
  save_all=True, duration=100, loop=1)
77
 
78
 
79
- def predict(choice, seed):
 
 
 
80
  torch.manual_seed(seed)
81
 
82
  if choice == 'interpolation':
@@ -92,9 +91,11 @@ def predict(choice, seed):
92
  return 'image.png'
93
 
94
 
 
95
  gr.Interface(
96
  predict,
97
  inputs=[
 
98
  gr.inputs.Dropdown(['image', 'interpolation'], label='Output Type'),
99
  gr.inputs.Slider(label='Seed', minimum=0, maximum=1000, default=42),
100
  ],
@@ -102,5 +103,5 @@ gr.Interface(
102
  title="Cryptopunks GAN",
103
  description="These CryptoPunks do not exist. You have the choice of either generating random punks, or a gif showing the interpolation between two random punk grids.",
104
  article="<p style='text-align: center'><a href='https://arxiv.org/pdf/1511.06434.pdf'>Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks</a> | <a href='https://github.com/teddykoker/cryptopunks-gan'>Github Repo</a></p>",
105
- examples=[["interpolation", 123], ["interpolation", 42], ["image", 456], ["image", 42]],
106
  ).launch(cache_examples=True)
 
9
  from PIL import Image
10
  from torch import nn
11
  from torchvision.utils import save_image
12
+ hfapi = HfApi()
13
 
14
  class Generator(nn.Module):
15
  def __init__(self, num_channels=4, latent_dim=100, hidden_size=64):
 
42
 
43
  return pixel_values
44
 
 
 
 
 
45
 
46
  @torch.no_grad()
47
  def interpolate(save_dir='./lerp/', frames=100, rows=8, cols=8):
 
72
  save_all=True, duration=100, loop=1)
73
 
74
 
75
+ def predict(model_name, choice, seed):
76
+ model = Generator()
77
+ weights_path = hf_hub_download(f'huggingnft/{model_name}', 'pytorch_model.bin')
78
+ model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))
79
  torch.manual_seed(seed)
80
 
81
  if choice == 'interpolation':
 
91
  return 'image.png'
92
 
93
 
94
+ models = [model.modelId[model.modelId.index("/") + 1:] for model in hfapi.list_models(author="huggingnft")]
95
  gr.Interface(
96
  predict,
97
  inputs=[
98
+ gr.inputs.Dropdown(models, label='Model'),
99
  gr.inputs.Dropdown(['image', 'interpolation'], label='Output Type'),
100
  gr.inputs.Slider(label='Seed', minimum=0, maximum=1000, default=42),
101
  ],
 
103
  title="Cryptopunks GAN",
104
  description="These CryptoPunks do not exist. You have the choice of either generating random punks, or a gif showing the interpolation between two random punk grids.",
105
  article="<p style='text-align: center'><a href='https://arxiv.org/pdf/1511.06434.pdf'>Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks</a> | <a href='https://github.com/teddykoker/cryptopunks-gan'>Github Repo</a></p>",
106
+ examples=[["interpolation", 100], ["interpolation", 500], ["image", 100], ["image", 500]],
107
  ).launch(cache_examples=True)