trafaqat commited on
Commit
f651983
1 Parent(s): 9a0f383

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +141 -7
app.py CHANGED
@@ -4,15 +4,149 @@ from huggingface_hub import hf_hub_download
4
  import torch
5
  import torch.nn as nn
6
  from torchvision import transforms
7
- from PIL import Image
8
 
9
- # Define the model classes (SimpleResidualBlock, BottleneckResidualBlock, ResNet) here...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  class ImageClassifier:
12
  def __init__(self, checkpoint_path):
13
  self.checkpoint_path = checkpoint_path
14
  self.model = self.load_model(checkpoint_path)
15
- self.transform = self.get_transform((224, 224)) # Typical size for ResNet
16
  self.labels = [
17
  "airplane",
18
  "automobile",
@@ -32,7 +166,7 @@ class ImageClassifier:
32
  block="simple",
33
  num_classes=10,
34
  )
35
- classifier.load_state_dict(torch.load(checkpoint_path, map_location=torch.device('cpu')))
36
  classifier = classifier.cpu()
37
  classifier.eval()
38
  return classifier
@@ -56,18 +190,18 @@ class ImageClassifier:
56
  def classify(self, input_image):
57
  return self.predict(input_image)
58
 
 
59
  def classify(input_image):
60
  return classifier.classify(input_image)
61
 
 
62
  checkpoint_path = hf_hub_download(
63
  repo_id="SatwikKambham/resnet18-cifar10",
64
  filename="model.pt",
65
  )
66
-
67
  classifier = ImageClassifier(checkpoint_path)
68
-
69
  iface = gr.Interface(
70
- fn=classify,
71
  inputs=[
72
  gr.Image(label="Input Image", type="pil"),
73
  ],
 
4
  import torch
5
  import torch.nn as nn
6
  from torchvision import transforms
 
7
 
8
+
9
+ class SimpleResidualBlock(nn.Module):
10
+ def __init__(self, in_channels, out_channels, set_stride=False):
11
+ super().__init__()
12
+ stride = 2 if in_channels != out_channels and set_stride else 1
13
+
14
+ self.conv1 = nn.LazyConv2d(
15
+ out_channels,
16
+ kernel_size=3,
17
+ padding="same" if stride == 1 else 1,
18
+ stride=stride,
19
+ )
20
+ self.conv2 = nn.LazyConv2d(out_channels, kernel_size=3, padding="same")
21
+
22
+ self.bn1 = nn.LazyBatchNorm2d()
23
+ self.bn2 = nn.LazyBatchNorm2d()
24
+
25
+ self.relu = nn.ReLU()
26
+
27
+ if in_channels != out_channels:
28
+ self.residual = nn.Sequential(
29
+ nn.LazyConv2d(out_channels, kernel_size=1, stride=stride),
30
+ nn.LazyBatchNorm2d(),
31
+ )
32
+ else:
33
+ self.residual = nn.Identity()
34
+
35
+ def forward(self, x):
36
+ out = self.relu(self.bn1(self.conv1(x)))
37
+ out = self.bn2(self.conv2(out))
38
+ out += self.residual(x)
39
+ out = self.relu(out)
40
+ return out
41
+
42
+
43
+ class BottleneckResidualBlock(nn.Module):
44
+ def __init__(
45
+ self, in_channels, out_channels, identity_mapping=False, set_stride=False
46
+ ):
47
+ super().__init__()
48
+ stride = 2 if in_channels != out_channels and set_stride else 1
49
+
50
+ self.conv1 = nn.LazyConv2d(
51
+ out_channels,
52
+ kernel_size=1,
53
+ padding="same" if stride == 1 else 0,
54
+ stride=stride,
55
+ )
56
+ self.conv2 = nn.LazyConv2d(out_channels, kernel_size=3, padding="same")
57
+ self.conv3 = nn.LazyConv2d(out_channels * 4, kernel_size=1, padding="same")
58
+
59
+ self.bn1 = nn.LazyBatchNorm2d()
60
+ self.bn2 = nn.LazyBatchNorm2d()
61
+ self.bn3 = nn.LazyBatchNorm2d()
62
+
63
+ self.relu = nn.ReLU()
64
+
65
+ if in_channels != out_channels or not identity_mapping:
66
+ self.residual = nn.Sequential(
67
+ nn.LazyConv2d(out_channels * 4, kernel_size=1, stride=stride),
68
+ nn.LazyBatchNorm2d(),
69
+ )
70
+ else:
71
+ self.residual = nn.Identity()
72
+
73
+ def forward(self, x):
74
+ out = self.relu(self.bn1(self.conv1(x)))
75
+ out = self.relu(self.bn2(self.conv2(out)))
76
+ out = self.bn3(self.conv3(out))
77
+ out += self.residual(x)
78
+ out = self.relu(out)
79
+ return out
80
+
81
+
82
+ RESNET_18 = [2, 2, 2, 2]
83
+ RESNET_34 = [3, 4, 6, 3]
84
+ RESNET_50 = [3, 4, 6, 3]
85
+ RESNET_101 = [3, 4, 23, 3]
86
+ RESNET_152 = [3, 8, 36, 3]
87
+
88
+
89
+ class ResNet(nn.Module):
90
+ def __init__(self, arch=RESNET_18, block="simple", num_classes=256):
91
+ super().__init__()
92
+ self.conv1 = nn.Sequential(
93
+ nn.LazyConv2d(64, kernel_size=7, stride=2, padding=3),
94
+ nn.LazyBatchNorm2d(),
95
+ nn.ReLU(),
96
+ )
97
+ self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
98
+ self.conv2 = self._make_layer(64, 64, arch[0], set_stride=False, block=block)
99
+ self.conv3 = self._make_layer(64, 128, arch[1], block=block)
100
+ self.conv4 = self._make_layer(128, 256, arch[2], block=block)
101
+ self.conv5 = self._make_layer(256, 512, arch[3], block=block)
102
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
103
+ self.flatten = nn.Flatten()
104
+ self.fc = nn.LazyLinear(num_classes)
105
+
106
+ def _make_layer(
107
+ self, in_channels, out_channels, num_blocks, set_stride=True, block="simple"
108
+ ):
109
+ """Block is either 'simple' or 'bottleneck'"""
110
+ layers = []
111
+ for i in range(num_blocks):
112
+ layers.append(
113
+ SimpleResidualBlock(in_channels, out_channels, set_stride=set_stride)
114
+ if block == "simple"
115
+ else BottleneckResidualBlock(
116
+ in_channels if i == 0 else out_channels * 4,
117
+ out_channels,
118
+ set_stride=set_stride,
119
+ )
120
+ )
121
+ set_stride = False
122
+ return nn.Sequential(*layers)
123
+
124
+ def forward(self, x):
125
+ out = self.conv1(x)
126
+ out = self.maxpool(self.conv2(out))
127
+ out = self.conv3(out)
128
+ out = self.conv4(out)
129
+ out = self.conv5(out)
130
+ out = self.avgpool(out)
131
+ out = self.flatten(out)
132
+ out = self.fc(out)
133
+ return out
134
+
135
+ def _init_weights(module):
136
+ # Initlize weights with glorot uniform
137
+ if isinstance(module, nn.Conv2d):
138
+ nn.init.xavier_uniform_(module.weight)
139
+ nn.init.zeros_(module.bias)
140
+ elif isinstance(module, nn.Linear):
141
+ nn.init.xavier_uniform_(module.weight)
142
+ nn.init.zeros_(module.bias)
143
+
144
 
145
  class ImageClassifier:
146
  def __init__(self, checkpoint_path):
147
  self.checkpoint_path = checkpoint_path
148
  self.model = self.load_model(checkpoint_path)
149
+ self.transform = self.get_transform((244, 244))
150
  self.labels = [
151
  "airplane",
152
  "automobile",
 
166
  block="simple",
167
  num_classes=10,
168
  )
169
+ classifier.load_state_dict(torch.load(checkpoint_path))
170
  classifier = classifier.cpu()
171
  classifier.eval()
172
  return classifier
 
190
  def classify(self, input_image):
191
  return self.predict(input_image)
192
 
193
+
194
  def classify(input_image):
195
  return classifier.classify(input_image)
196
 
197
+
198
  checkpoint_path = hf_hub_download(
199
  repo_id="SatwikKambham/resnet18-cifar10",
200
  filename="model.pt",
201
  )
 
202
  classifier = ImageClassifier(checkpoint_path)
 
203
  iface = gr.Interface(
204
+ classify,
205
  inputs=[
206
  gr.Image(label="Input Image", type="pil"),
207
  ],