Caroline Mai Chan commited on
Commit
41bee7b
1 Parent(s): 72026fd

add model to code

Browse files
Files changed (1) hide show
  1. app.py +61 -0
app.py CHANGED
@@ -3,8 +3,69 @@ import torch
3
  import torch.nn as nn
4
  import gradio as gr
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  def sepia(input_img):
 
8
  sepia_filter = np.array(
9
  [[0.393, 0.769, 0.189], [0.349, 0.686, 0.168], [0.272, 0.534, 0.131]]
10
  )
 
3
  import torch.nn as nn
4
  import gradio as gr
5
 
6
+ class Generator(nn.Module):
7
+ def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True):
8
+ super(Generator, self).__init__()
9
+
10
+ # Initial convolution block
11
+ model0 = [ nn.ReflectionPad2d(3),
12
+ nn.Conv2d(input_nc, 64, 7),
13
+ norm_layer(64),
14
+ nn.ReLU(inplace=True) ]
15
+ self.model0 = nn.Sequential(*model0)
16
+
17
+ # Downsampling
18
+ model1 = []
19
+ in_features = 64
20
+ out_features = in_features*2
21
+ for _ in range(2):
22
+ model1 += [ nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
23
+ norm_layer(out_features),
24
+ nn.ReLU(inplace=True) ]
25
+ in_features = out_features
26
+ out_features = in_features*2
27
+ self.model1 = nn.Sequential(*model1)
28
+
29
+ model2 = []
30
+ # Residual blocks
31
+ for _ in range(n_residual_blocks):
32
+ model2 += [ResidualBlock(in_features)]
33
+ self.model2 = nn.Sequential(*model2)
34
+
35
+ # Upsampling
36
+ model3 = []
37
+ out_features = in_features//2
38
+ for _ in range(2):
39
+ model3 += [ nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
40
+ norm_layer(out_features),
41
+ nn.ReLU(inplace=True) ]
42
+ in_features = out_features
43
+ out_features = in_features//2
44
+ self.model3 = nn.Sequential(*model3)
45
+
46
+ # Output layer
47
+ model4 = [ nn.ReflectionPad2d(3),
48
+ nn.Conv2d(64, output_nc, 7)]
49
+ if sigmoid:
50
+ model4 += [nn.Sigmoid()]
51
+
52
+ self.model4 = nn.Sequential(*model4)
53
+
54
+ def forward(self, x, cond=None):
55
+ out = self.model0(x)
56
+ out = self.model1(out)
57
+ out = self.model2(out)
58
+ out = self.model3(out)
59
+ out = self.model4(out)
60
+
61
+ return out
62
+
63
+ # model = Generator(3, 1, 3)
64
+ # model.load_state_dict(torch.load('model.pth'))
65
+ # model.eval()
66
 
67
  def sepia(input_img):
68
+ print(input_img.shape, np.max(input_img), np.min(input_img))
69
  sepia_filter = np.array(
70
  [[0.393, 0.769, 0.189], [0.349, 0.686, 0.168], [0.272, 0.534, 0.131]]
71
  )