NimaBoscarino commited on
Commit
d015ad6
1 Parent(s): cb42512

Part 1 demo for Gradio guide

Browse files
Files changed (1) hide show
  1. app.py +10 -12
app.py CHANGED
@@ -1,11 +1,8 @@
1
- import gradio as gr
2
- import torch
3
- from huggingface_hub import hf_hub_download
4
  from torch import nn
5
- from torchvision.utils import save_image
6
-
7
 
8
  class Generator(nn.Module):
 
 
9
  def __init__(self, nc=4, nz=100, ngf=64):
10
  super(Generator, self).__init__()
11
  self.network = nn.Sequential(
@@ -26,19 +23,24 @@ class Generator(nn.Module):
26
  output = self.network(input)
27
  return output
28
 
 
 
29
 
30
  model = Generator()
31
  weights_path = hf_hub_download('nateraw/cryptopunks-gan', 'generator.pth')
32
- model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))
33
 
 
34
 
35
  def predict(seed):
 
36
  torch.manual_seed(seed)
37
- z = torch.randn(64, 100, 1, 1)
38
  punks = model(z)
39
  save_image(punks, "punks.png", normalize=True)
40
  return 'punks.png'
41
 
 
42
 
43
  gr.Interface(
44
  predict,
@@ -46,8 +48,4 @@ gr.Interface(
46
  gr.inputs.Slider(label='Seed', minimum=0, maximum=1000, default=42),
47
  ],
48
  outputs="image",
49
- title="Cryptopunks GAN",
50
- description="These CryptoPunks do not exist. Generate random punks with an initial seed!",
51
- article="<p style='text-align: center'><a href='https://arxiv.org/pdf/1511.06434.pdf'>Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks</a> | <a href='https://github.com/teddykoker/cryptopunks-gan'>Github Repo</a></p>",
52
- examples=[[123], [42], [456], [1337]],
53
- ).launch(cache_examples=True)
 
 
 
 
1
  from torch import nn
 
 
2
 
3
  class Generator(nn.Module):
4
+ # Refer to the link below for explanations about nc, nz, and ngf
5
+ # https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html#inputs
6
  def __init__(self, nc=4, nz=100, ngf=64):
7
  super(Generator, self).__init__()
8
  self.network = nn.Sequential(
 
23
  output = self.network(input)
24
  return output
25
 
26
+ from huggingface_hub import hf_hub_download
27
+ import torch
28
 
29
  model = Generator()
30
  weights_path = hf_hub_download('nateraw/cryptopunks-gan', 'generator.pth')
31
+ model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu'))) # Use 'cuda' if you have a GPU available
32
 
33
+ from torchvision.utils import save_image
34
 
35
  def predict(seed):
36
+ num_punks = 4
37
  torch.manual_seed(seed)
38
+ z = torch.randn(num_punks, 100, 1, 1)
39
  punks = model(z)
40
  save_image(punks, "punks.png", normalize=True)
41
  return 'punks.png'
42
 
43
+ import gradio as gr
44
 
45
  gr.Interface(
46
  predict,
 
48
  gr.inputs.Slider(label='Seed', minimum=0, maximum=1000, default=42),
49
  ],
50
  outputs="image",
51
+ ).launch()