mehdiabruee committed
Commit 7e0282a
1 Parent(s): 6041fbb

Upload cyclegan_inference.py

Files changed (1)
  1. cyclegan_inference.py +168 -0
cyclegan_inference.py ADDED
@@ -0,0 +1,168 @@
+ # -*- coding: utf-8 -*-
+ """cyclegan_inference.ipynb
+
+ Automatically generated by Colaboratory.
+
+ Original file is located at
+ https://colab.research.google.com/drive/12lelsBZXqNOe7xaXI724rEHAbppRt07y
+ """
+
+ import gradio as gr
+ import torch
+ import torchvision
+ from torch import nn
+ from typing import List
+
+ def ifnone(a, b):  # a fastai-specific (fastcore) function used below, redefined so it's independent
+     "`b` if `a` is None else `a`"
+     return b if a is None else a
+
+ class ConvBlock(torch.nn.Module):
+     def __init__(self, input_size, output_size, kernel_size=4, stride=2, padding=1, activation='relu', batch_norm=True):
+         super(ConvBlock, self).__init__()
+         self.conv = torch.nn.Conv2d(input_size, output_size, kernel_size, stride, padding)
+         self.batch_norm = batch_norm
+         self.bn = torch.nn.InstanceNorm2d(output_size)  # instance norm, despite the `batch_norm` flag name
+         self.activation = activation
+         self.relu = torch.nn.ReLU(True)
+         self.lrelu = torch.nn.LeakyReLU(0.2, True)
+         self.tanh = torch.nn.Tanh()
+         self.sigmoid = torch.nn.Sigmoid()
+
+     def forward(self, x):
+         if self.batch_norm:
+             out = self.bn(self.conv(x))
+         else:
+             out = self.conv(x)
+
+         if self.activation == 'relu':
+             return self.relu(out)
+         elif self.activation == 'lrelu':
+             return self.lrelu(out)
+         elif self.activation == 'tanh':
+             return self.tanh(out)
+         elif self.activation == 'no_act':
+             return out
+         elif self.activation == 'sigmoid':
+             return self.sigmoid(out)
+
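+ # Illustrative only (not part of the original script): with the default
+ # kernel_size=4, stride=2, padding=1, a ConvBlock halves the spatial size,
+ # e.g. ConvBlock(3, 32)(torch.randn(1, 3, 256, 256)) has shape (1, 32, 128, 128).
+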
+ class ResnetBlock(torch.nn.Module):
+     def __init__(self, num_filter, kernel_size=3, stride=1, padding=0):
+         super(ResnetBlock, self).__init__()
+         conv1 = torch.nn.Conv2d(num_filter, num_filter, kernel_size, stride, padding)
+         conv2 = torch.nn.Conv2d(num_filter, num_filter, kernel_size, stride, padding)
+         bn = torch.nn.InstanceNorm2d(num_filter)
+         relu = torch.nn.ReLU(True)
+         pad = torch.nn.ReflectionPad2d(1)
+
+         # pad, bn and relu hold no learnable parameters, so reusing the same
+         # instances twice in the Sequential below is safe
+         self.resnet_block = torch.nn.Sequential(
+             pad,
+             conv1,
+             bn,
+             relu,
+             pad,
+             conv2,
+             bn
+         )
+
+     def forward(self, x):
+         # returns the conv stack directly; the residual skip connection (x + out)
+         # of the standard CycleGAN block is not added here
+         out = self.resnet_block(x)
+         return out
+
+ # Unused leftover from the fastai CycleGAN implementation: it depends on helpers
+ # (pad_conv_norm_relu, convT_norm_relu) and a ResnetBlock signature that are not
+ # defined in this file, and it is never called below.
+ def resnet_generator(ch_in: int, ch_out: int, n_ftrs: int = 64, norm_layer: nn.Module = None,
+                      dropout: float = 0., n_blocks: int = 9, pad_mode: str = 'reflection') -> nn.Module:
+     norm_layer = ifnone(norm_layer, nn.InstanceNorm2d)
+     bias = (norm_layer == nn.InstanceNorm2d)
+     layers = pad_conv_norm_relu(ch_in, n_ftrs, 'reflection', norm_layer, pad=3, ks=7, bias=bias)
+     for i in range(2):
+         layers += pad_conv_norm_relu(n_ftrs, n_ftrs * 2, 'zeros', norm_layer, stride=2, bias=bias)
+         n_ftrs *= 2
+     layers += [ResnetBlock(n_ftrs, pad_mode, norm_layer, dropout, bias) for _ in range(n_blocks)]
+     for i in range(2):
+         layers += convT_norm_relu(n_ftrs, n_ftrs // 2, norm_layer, bias=bias)
+         n_ftrs //= 2
+     layers += [nn.ReflectionPad2d(3), nn.Conv2d(n_ftrs, ch_out, kernel_size=7, padding=0), nn.Tanh()]
+     return nn.Sequential(*layers)
+
+ class DeconvBlock(torch.nn.Module):
+     def __init__(self, input_size, output_size, kernel_size=4, stride=2, padding=1, activation='relu', batch_norm=True):
+         super(DeconvBlock, self).__init__()
+         self.deconv = torch.nn.ConvTranspose2d(input_size, output_size, kernel_size, stride, padding)
+         self.batch_norm = batch_norm
+         self.bn = torch.nn.InstanceNorm2d(output_size)
+         self.activation = activation
+         self.relu = torch.nn.ReLU(True)
+         self.lrelu = torch.nn.LeakyReLU(0.2, True)  # used by the 'lrelu' branch in forward()
+         self.tanh = torch.nn.Tanh()
+
+     def forward(self, x):
+         if self.batch_norm:
+             out = self.bn(self.deconv(x))
+         else:
+             out = self.deconv(x)
+         if self.activation == 'relu':
+             return self.relu(out)
+         elif self.activation == 'lrelu':
+             return self.lrelu(out)
+         elif self.activation == 'tanh':
+             return self.tanh(out)
+         elif self.activation == 'no_act':
+             return out
+
+ class Generator(torch.nn.Module):
+     def __init__(self, input_dim, num_filter, output_dim, num_resnet):
+         super(Generator, self).__init__()
+
+         # Reflection padding
+         #self.pad = torch.nn.ReflectionPad2d(3)
+         # Encoder
+         self.conv1 = ConvBlock(input_dim, num_filter, kernel_size=4, stride=2, padding=1)
+         self.conv2 = ConvBlock(num_filter, num_filter*2)
+         #self.conv3 = ConvBlock(num_filter*2,num_filter*4)
+         # Resnet blocks
+         self.resnet_blocks = []
+         for i in range(num_resnet):
+             self.resnet_blocks.append(ResnetBlock(num_filter*2))
+         self.resnet_blocks = torch.nn.Sequential(*self.resnet_blocks)
+         # Decoder
+         self.deconv1 = DeconvBlock(num_filter*2, num_filter)
+         self.deconv2 = DeconvBlock(num_filter, output_dim, activation='tanh')
+         #self.deconv3 = ConvBlock(num_filter,output_dim,kernel_size=7,stride=1,padding=0,activation='tanh',batch_norm=False)
+
+     def forward(self, x):
+         # Encoder
+         enc1 = self.conv1(x)
+         enc2 = self.conv2(enc1)
+         #enc3 = self.conv3(enc2)
+         # Resnet blocks
+         res = self.resnet_blocks(enc2)
+         # Decoder
+         dec1 = self.deconv1(res)
+         dec2 = self.deconv2(dec1)
+         #out = self.deconv3(self.pad(dec2))
+         return dec2
+
+     def normal_weight_init(self, mean=0.0, std=0.02):
+         # walk all submodules so the ResnetBlocks wrapped in the Sequential are reached
+         for m in self.modules():
+             if isinstance(m, ConvBlock):
+                 torch.nn.init.normal_(m.conv.weight, mean, std)
+             if isinstance(m, DeconvBlock):
+                 torch.nn.init.normal_(m.deconv.weight, mean, std)
+             if isinstance(m, ResnetBlock):
+                 for layer in m.resnet_block:
+                     if isinstance(layer, torch.nn.Conv2d):
+                         torch.nn.init.normal_(layer.weight, mean, std)
+                         torch.nn.init.constant_(layer.bias, 0)
+
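+ # Shape sanity check (illustrative sketch, not part of the original script): with
+ # the arguments used below the generator maps an RGB image tensor back to the same
+ # spatial size, e.g.
+ #   Generator(3, 32, 3, 4)(torch.randn(1, 3, 256, 256)).shape  # torch.Size([1, 3, 256, 256])
+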
+ # input_dim, num_filter, output_dim, num_resnet; the model stays on CPU, matching
+ # map_location below and the CPU tensors passed to it in predict()
+ model = G_A = Generator(3, 32, 3, 4)
+ model.load_state_dict(torch.load('G_A_HW4_SAVE.pt', map_location=torch.device('cpu')))
+ model.eval()
+
+
+ totensor = torchvision.transforms.ToTensor()
+ # maps [0, 1] tensors to [-1, 1], the range produced by the generator's tanh output
+ normalize_fn = torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+ topilimage = torchvision.transforms.ToPILImage()
+
+ def predict(img):
+     im = normalize_fn(totensor(img))
+     print(im.shape)
+     with torch.no_grad():
+         # undo the [-1, 1] normalization: map the tanh output back to [0, 1]
+         preds = model(im.unsqueeze(0)) / 2 + 0.5
+     print(preds.shape)
+     return topilimage(preds.squeeze(0).detach())
+
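+ # Quick local test, a sketch assuming a 256x256 RGB image named 'horse.jpg' sits
+ # next to this script (the filename is hypothetical); uncomment to try predict()
+ # without launching Gradio:
+ # from PIL import Image
+ # sample = Image.open('horse.jpg').convert('RGB').resize((256, 256))
+ # predict(sample).save('zebra_fake.png')
+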
+ gr_interface = gr.Interface(fn=predict, inputs=gr.inputs.Image(shape=(256, 256)), outputs="image", title='Horse-to-Zebra CycleGAN')
+ gr_interface.launch(inline=False, share=False)