mpamt commited on
Commit
b348589
1 Parent(s): ba45d1f

Added the Generator Function

Browse files
Files changed (1) hide show
  1. app.py +24 -0
app.py CHANGED
@@ -3,6 +3,30 @@ from huggingface_hub import hf_hub_download
3
  import torch
4
  import matplotlib.pyplot as plt
5
  import numpy as np
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  path = hf_hub_download('huggan/ArtGAN', 'ArtGAN.pt')
8
  model = torch.load(path)
 
3
  import torch
4
  import matplotlib.pyplot as plt
5
  import numpy as np
6
+ from torch import nn
7
+
8
+ class Generator(nn.Module):
9
+ def __init__(self):
10
+ super(Generator, self).__init__()
11
+ self.main = nn.Sequential(
12
+ nn.ConvTranspose2d(100, 64 * 8, 4, 1, 0, bias=False),
13
+ nn.BatchNorm2d(64 * 8),
14
+ nn.ReLU(True),
15
+ nn.ConvTranspose2d(64 * 8, 64 * 4, 4, 2, 1, bias=False),
16
+ nn.BatchNorm2d(64 * 4),
17
+ nn.ReLU(True),
18
+ nn.ConvTranspose2d(64 * 4, 64 * 2, 4, 2, 1, bias=False),
19
+ nn.BatchNorm2d(64 * 2),
20
+ nn.ReLU(True),
21
+ nn.ConvTranspose2d(64 * 2, 64, 4, 2, 1, bias=False),
22
+ nn.BatchNorm2d(64),
23
+ nn.ReLU(True),
24
+ nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False),
25
+ nn.Tanh()
26
+ )
27
+
28
+ def forward(self, input):
29
+ return self.main(input)
30
 
31
  path = hf_hub_download('huggan/ArtGAN', 'ArtGAN.pt')
32
  model = torch.load(path)