ldkong commited on
Commit
d2b9628
β€’
1 Parent(s): e423491

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -16
app.py CHANGED
@@ -1,4 +1,8 @@
 
1
  from torch import nn
 
 
 
2
 
3
  class Generator(nn.Module):
4
  # Refer to the link below for explanations about nc, nz, and ngf
@@ -23,35 +27,23 @@ class Generator(nn.Module):
23
  output = self.network(input)
24
  return output
25
 
26
-
27
-
28
- from huggingface_hub import hf_hub_download
29
- import torch
30
-
31
  model = Generator()
32
  weights_path = hf_hub_download('nateraw/cryptopunks-gan', 'generator.pth')
33
  model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu'))) # Use 'cuda' if you have a GPU available
34
 
35
-
36
-
37
- from torchvision.utils import save_image
38
-
39
- def predict(seed):
40
- num_punks = 4
41
  torch.manual_seed(seed)
42
  z = torch.randn(num_punks, 100, 1, 1)
43
  punks = model(z)
44
  save_image(punks, "punks.png", normalize=True)
45
  return 'punks.png'
46
 
47
-
48
-
49
- import gradio as gr
50
-
51
  gr.Interface(
52
  predict,
53
  inputs=[
54
  gr.Slider(0, 1000, label='Seed', default=42),
 
55
  ],
56
  outputs="image",
57
- ).launch()
 
 
1
+ import torch
2
  from torch import nn
3
+ from huggingface_hub import hf_hub_download
4
+ from torchvision.utils import save_image
5
+ import gradio as gr
6
 
7
  class Generator(nn.Module):
8
  # Refer to the link below for explanations about nc, nz, and ngf
 
27
  output = self.network(input)
28
  return output
29
 
 
 
 
 
 
30
  model = Generator()
31
  weights_path = hf_hub_download('nateraw/cryptopunks-gan', 'generator.pth')
32
  model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu'))) # Use 'cuda' if you have a GPU available
33
 
34
+ def predict(seed, num_punks):
 
 
 
 
 
35
  torch.manual_seed(seed)
36
  z = torch.randn(num_punks, 100, 1, 1)
37
  punks = model(z)
38
  save_image(punks, "punks.png", normalize=True)
39
  return 'punks.png'
40
 
 
 
 
 
41
  gr.Interface(
42
  predict,
43
  inputs=[
44
  gr.Slider(0, 1000, label='Seed', default=42),
45
+ gr.Slider(4, 64, label='Number of Punks', step=1, default=10),
46
  ],
47
  outputs="image",
48
+ examples=[[123, 15], [42, 29], [456, 8], [1337, 35]],
49
+ ).launch(cache_examples=True)