mehdiabruee commited on
Commit
0696517
1 Parent(s): e78ca20

Upload 3 files

Browse files
Files changed (3) hide show
  1. G_A_HW4_SAVE.pt +3 -0
  2. app.py +169 -0
  3. requirements.txt +2 -0
G_A_HW4_SAVE.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f2f304a37c21b8bc1d31157b40515a1db4847b0be7b1e5c4123a4bbf0d00d81a
3
+ size 1466847
app.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """app.ipynb
3
+
4
+ Automatically generated by Colaboratory.
5
+
6
+ Original file is located at
7
+ https://colab.research.google.com/drive/13tu6v1reMxLATyBwle-BgpQrql9p4nqn
8
+ """
9
+
10
+ from fastai.vision.all import *
11
+ from fastai.basics import *
12
+ import torchvision
13
+ import gradio as gr
14
+ #dls = get_dls_from_hf("huggan/horse2zebra", load_size=286)
15
+ #cycle_gan = CycleGAN.from_pretrained('mehdiabruee/HW4')
16
+ # -*- coding: utf-8 -*-
17
+ """cyclegan_inference.ipynb
18
+
19
+ Automatically generated by Colaboratory.
20
+
21
+ Original file is located at
22
+ https://colab.research.google.com/drive/12lelsBZXqNOe7xaXI724rEHAbppRt07y
23
+ """
24
+
25
+ import gradio as gr
26
+ import torch
27
+ import torchvision
28
+ from torch import nn
29
+ from typing import List
30
+
31
+ def ifnone(a, b): # a fastai-specific (fastcore) function used below, redefined so it's independent
32
+ "`b` if `a` is None else `a`"
33
+ return b if a is None else a
34
+
35
+ class ConvBlock(torch.nn.Module):
36
+ def __init__(self,input_size,output_size,kernel_size=4,stride=2,padding=1,activation='relu',batch_norm=True):
37
+ super(ConvBlock,self).__init__()
38
+ self.conv = torch.nn.Conv2d(input_size,output_size,kernel_size,stride,padding)
39
+ self.batch_norm = batch_norm
40
+ self.bn = torch.nn.InstanceNorm2d(output_size)
41
+ self.activation = activation
42
+ self.relu = torch.nn.ReLU(True)
43
+ self.lrelu = torch.nn.LeakyReLU(0.2,True)
44
+ self.tanh = torch.nn.Tanh()
45
+ self.sigmoid = torch.nn.Sigmoid()
46
+ def forward(self,x):
47
+ if self.batch_norm:
48
+ out = self.bn(self.conv(x))
49
+ else:
50
+ out = self.conv(x)
51
+
52
+ if self.activation == 'relu':
53
+ return self.relu(out)
54
+ elif self.activation == 'lrelu':
55
+ return self.lrelu(out)
56
+ elif self.activation == 'tanh':
57
+ return self.tanh(out)
58
+ elif self.activation == 'no_act':
59
+ return out
60
+ elif self.activation =='sigmoid':
61
+ return self.sigmoid(out)
62
+
63
+
64
+ class ResnetBlock(torch.nn.Module):
65
+ def __init__(self,num_filter,kernel_size=3,stride=1,padding=0):
66
+ super(ResnetBlock,self).__init__()
67
+ conv1 = torch.nn.Conv2d(num_filter,num_filter,kernel_size,stride,padding)
68
+ conv2 = torch.nn.Conv2d(num_filter,num_filter,kernel_size,stride,padding)
69
+ bn = torch.nn.InstanceNorm2d(num_filter)
70
+ relu = torch.nn.ReLU(True)
71
+ pad = torch.nn.ReflectionPad2d(1)
72
+
73
+ self.resnet_block = torch.nn.Sequential(
74
+ pad,
75
+ conv1,
76
+ bn,
77
+ relu,
78
+ pad,
79
+ conv2,
80
+ bn
81
+ )
82
+ def forward(self,x):
83
+ out = self.resnet_block(x)
84
+ return out
85
+
86
+ class DeconvBlock(torch.nn.Module):
87
+ def __init__(self,input_size,output_size,kernel_size=4,stride=2,padding=1,activation='relu',batch_norm=True):
88
+ super(DeconvBlock,self).__init__()
89
+ self.deconv = torch.nn.ConvTranspose2d(input_size,output_size,kernel_size,stride,padding)
90
+ self.batch_norm = batch_norm
91
+ self.bn = torch.nn.InstanceNorm2d(output_size)
92
+ self.activation = activation
93
+ self.relu = torch.nn.ReLU(True)
94
+ self.tanh = torch.nn.Tanh()
95
+ def forward(self,x):
96
+ if self.batch_norm:
97
+ out = self.bn(self.deconv(x))
98
+ else:
99
+ out = self.deconv(x)
100
+ if self.activation == 'relu':
101
+ return self.relu(out)
102
+ elif self.activation == 'lrelu':
103
+ return self.lrelu(out)
104
+ elif self.activation == 'tanh':
105
+ return self.tanh(out)
106
+ elif self.activation == 'no_act':
107
+ return out
108
+
109
+ class Generator(torch.nn.Module):
110
+ def __init__(self,input_dim,num_filter,output_dim,num_resnet):
111
+ super(Generator,self).__init__()
112
+
113
+ #Reflection padding
114
+ #self.pad = torch.nn.ReflectionPad2d(3)
115
+ #Encoder
116
+ self.conv1 = ConvBlock(input_dim,num_filter,kernel_size=4,stride=2,padding=1)
117
+ self.conv2 = ConvBlock(num_filter,num_filter*2)
118
+ #self.conv3 = ConvBlock(num_filter*2,num_filter*4)
119
+ #Resnet blocks
120
+ self.resnet_blocks = []
121
+ for i in range(num_resnet):
122
+ self.resnet_blocks.append(ResnetBlock(num_filter*2))
123
+ self.resnet_blocks = torch.nn.Sequential(*self.resnet_blocks)
124
+ #Decoder
125
+ self.deconv1 = DeconvBlock(num_filter*2,num_filter)
126
+ self.deconv2 = DeconvBlock(num_filter,output_dim,activation='tanh')
127
+ #self.deconv3 = ConvBlock(num_filter,output_dim,kernel_size=7,stride=1,padding=0,activation='tanh',batch_norm=False)
128
+
129
+ def forward(self,x):
130
+ #Encoder
131
+ enc1 = self.conv1(x)
132
+ enc2 = self.conv2(enc1)
133
+ #enc3 = self.conv3(enc2)
134
+ #Resnet blocks
135
+ res = self.resnet_blocks(enc2)
136
+ #Decoder
137
+ dec1 = self.deconv1(res)
138
+ dec2 = self.deconv2(dec1)
139
+ #out = self.deconv3(self.pad(dec2))
140
+ return dec2
141
+
142
+ def normal_weight_init(self,mean=0.0,std=0.02):
143
+ for m in self.children():
144
+ if isinstance(m,ConvBlock):
145
+ torch.nn.init.normal_(m.conv.weight,mean,std)
146
+ if isinstance(m,DeconvBlock):
147
+ torch.nn.init.normal_(m.deconv.weight,mean,std)
148
+ if isinstance(m,ResnetBlock):
149
+ torch.nn.init.normal_(m.conv.weight,mean,std)
150
+ torch.nn.init.constant_(m.conv.bias,0)
151
+
152
+ model = Generator(3, 32, 3, 4).cuda() # input_dim, num_filter, output_dim, num_resnet
153
+ model.load_state_dict(torch.load('G_A_HW4_SAVE.pt',map_location=torch.device('cpu')))
154
+ print(model)
155
+ model.eval()
156
+
157
+
158
+ totensor = torchvision.transforms.ToTensor()
159
+ normalize_fn = torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
160
+ topilimage = torchvision.transforms.ToPILImage()
161
+
162
+ def predict(input):
163
+ im = normalize_fn(totensor(input))
164
+ print(im.shape)
165
+ preds = model(im.unsqueeze(0))/2 + 0.5
166
+ print(preds.shape)
167
+ return topilimage(preds.squeeze(0).detach())
168
+
169
+ gr_interface = gr.Interface(fn=predict, inputs=gr.inputs.Image(shape=(256, 256)), outputs="image", title='Horse-to-Zebra CycleGAN')
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ gradio
2
+ torchvision