ldkong commited on
Commit
d63c6cd
β€’
1 Parent(s): d74393f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -5
app.py CHANGED
@@ -1,9 +1,57 @@
1
- import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
 
 
 
3
 
4
- def greet(name):
5
- return "Hello " + name + "!!"
6
 
7
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
8
 
9
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+
3
+ class Generator(nn.Module):
4
+ # Refer to the link below for explanations about nc, nz, and ngf
5
+ # https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html#inputs
6
+ def __init__(self, nc=4, nz=100, ngf=64):
7
+ super(Generator, self).__init__()
8
+ self.network = nn.Sequential(
9
+ nn.ConvTranspose2d(nz, ngf * 4, 3, 1, 0, bias=False),
10
+ nn.BatchNorm2d(ngf * 4),
11
+ nn.ReLU(True),
12
+ nn.ConvTranspose2d(ngf * 4, ngf * 2, 3, 2, 1, bias=False),
13
+ nn.BatchNorm2d(ngf * 2),
14
+ nn.ReLU(True),
15
+ nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 0, bias=False),
16
+ nn.BatchNorm2d(ngf),
17
+ nn.ReLU(True),
18
+ nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
19
+ nn.Tanh(),
20
+ )
21
+
22
+ def forward(self, input):
23
+ output = self.network(input)
24
+ return output
25
+
26
+
27
+
28
+ from huggingface_hub import hf_hub_download
29
+ import torch
30
 
31
+ model = Generator()
32
+ weights_path = hf_hub_download('nateraw/cryptopunks-gan', 'generator.pth')
33
+ model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu'))) # Use 'cuda' if you have a GPU available
34
 
 
 
35
 
 
36
 
37
+ from torchvision.utils import save_image
38
+
39
+ def predict(seed):
40
+ num_punks = 4
41
+ torch.manual_seed(seed)
42
+ z = torch.randn(num_punks, 100, 1, 1)
43
+ punks = model(z)
44
+ save_image(punks, "punks.png", normalize=True)
45
+ return 'punks.png'
46
+
47
+
48
+
49
+ import gradio as gr
50
+
51
+ gr.Interface(
52
+ predict,
53
+ inputs=[
54
+ gr.Slider(0, 1000, label='Seed', default=42),
55
+ ],
56
+ outputs="image",
57
+ ).launch()