Caroline Mai Chan commited on
Commit
3984452
1 Parent(s): 641f158

add residual block

Browse files
Files changed (1) hide show
  1. app.py +22 -1
app.py CHANGED
@@ -3,11 +3,32 @@ import torch
3
  import torch.nn as nn
4
  import gradio as gr
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  class Generator(nn.Module):
7
  def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True):
8
  super(Generator, self).__init__()
9
 
10
- norm_layer = nn.InstanceNorm2d
11
  # Initial convolution block
12
  model0 = [ nn.ReflectionPad2d(3),
13
  nn.Conv2d(input_nc, 64, 7),
 
3
  import torch.nn as nn
4
  import gradio as gr
5
 
6
+
7
+ norm_layer = nn.InstanceNorm2d
8
+
9
+ class ResidualBlock(nn.Module):
10
+ def __init__(self, in_features):
11
+ super(ResidualBlock, self).__init__()
12
+
13
+ conv_block = [ nn.ReflectionPad2d(1),
14
+ nn.Conv2d(in_features, in_features, 3),
15
+ norm_layer(in_features),
16
+ nn.ReLU(inplace=True),
17
+ nn.ReflectionPad2d(1),
18
+ nn.Conv2d(in_features, in_features, 3),
19
+ norm_layer(in_features)
20
+ ]
21
+
22
+ self.conv_block = nn.Sequential(*conv_block)
23
+
24
+ def forward(self, x):
25
+ return x + self.conv_block(x)
26
+
27
+
28
  class Generator(nn.Module):
29
  def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True):
30
  super(Generator, self).__init__()
31
 
 
32
  # Initial convolution block
33
  model0 = [ nn.ReflectionPad2d(3),
34
  nn.Conv2d(input_nc, 64, 7),