Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -4,15 +4,149 @@ from huggingface_hub import hf_hub_download
|
|
4 |
import torch
|
5 |
import torch.nn as nn
|
6 |
from torchvision import transforms
|
7 |
-
from PIL import Image
|
8 |
|
9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
|
11 |
class ImageClassifier:
|
12 |
def __init__(self, checkpoint_path):
|
13 |
self.checkpoint_path = checkpoint_path
|
14 |
self.model = self.load_model(checkpoint_path)
|
15 |
-
self.transform = self.get_transform((
|
16 |
self.labels = [
|
17 |
"airplane",
|
18 |
"automobile",
|
@@ -32,7 +166,7 @@ class ImageClassifier:
|
|
32 |
block="simple",
|
33 |
num_classes=10,
|
34 |
)
|
35 |
-
classifier.load_state_dict(torch.load(checkpoint_path
|
36 |
classifier = classifier.cpu()
|
37 |
classifier.eval()
|
38 |
return classifier
|
@@ -56,18 +190,18 @@ class ImageClassifier:
|
|
56 |
def classify(self, input_image):
|
57 |
return self.predict(input_image)
|
58 |
|
|
|
59 |
def classify(input_image):
|
60 |
return classifier.classify(input_image)
|
61 |
|
|
|
62 |
checkpoint_path = hf_hub_download(
|
63 |
repo_id="SatwikKambham/resnet18-cifar10",
|
64 |
filename="model.pt",
|
65 |
)
|
66 |
-
|
67 |
classifier = ImageClassifier(checkpoint_path)
|
68 |
-
|
69 |
iface = gr.Interface(
|
70 |
-
|
71 |
inputs=[
|
72 |
gr.Image(label="Input Image", type="pil"),
|
73 |
],
|
|
|
4 |
import torch
|
5 |
import torch.nn as nn
|
6 |
from torchvision import transforms
|
|
|
7 |
|
8 |
+
|
9 |
+
class SimpleResidualBlock(nn.Module):
    """Residual block built from two 3x3 convolutions.

    Main path: conv3x3 -> BN -> ReLU -> conv3x3 -> BN. The shortcut branch
    is added to the main path and a final ReLU is applied.

    Args:
        in_channels: Nominal input width. It is only compared against
            ``out_channels``; the lazy layers take the real channel count
            from the first tensor they receive.
        out_channels: Number of channels produced by both convolutions.
        set_stride: When truthy and the two widths differ, the first
            convolution and the shortcut use stride 2; otherwise stride 1.
    """

    def __init__(self, in_channels, out_channels, set_stride=False):
        super().__init__()
        widths_differ = in_channels != out_channels
        step = 2 if (widths_differ and set_stride) else 1

        # "same" padding is used for the unit-stride case only; the strided
        # case uses an explicit padding of 1.
        if step == 1:
            first_padding = "same"
        else:
            first_padding = 1

        self.conv1 = nn.LazyConv2d(
            out_channels,
            kernel_size=3,
            stride=step,
            padding=first_padding,
        )
        self.conv2 = nn.LazyConv2d(out_channels, kernel_size=3, padding="same")

        self.bn1 = nn.LazyBatchNorm2d()
        self.bn2 = nn.LazyBatchNorm2d()

        self.relu = nn.ReLU()

        # Equal widths -> plain skip connection; otherwise a 1x1 projection
        # (with the same stride as conv1) followed by batch norm.
        if not widths_differ:
            shortcut = nn.Identity()
        else:
            shortcut = nn.Sequential(
                nn.LazyConv2d(out_channels, kernel_size=1, stride=step),
                nn.LazyBatchNorm2d(),
            )
        self.residual = shortcut

    def forward(self, x):
        """Return ``relu(main_path(x) + residual(x))``."""
        hidden = self.conv1(x)
        hidden = self.bn1(hidden)
        hidden = self.relu(hidden)

        hidden = self.conv2(hidden)
        hidden = self.bn2(hidden)

        # The shortcut is evaluated after the main path, as in the original
        # ordering of sub-module calls.
        hidden += self.residual(x)
        return self.relu(hidden)
|
41 |
+
|
42 |
+
|
43 |
+
class BottleneckResidualBlock(nn.Module):
    """Residual block with a 1x1 -> 3x3 -> 1x1 convolution stack.

    The last convolution emits ``out_channels * 4`` channels. The shortcut
    branch is added to the main path and a final ReLU is applied.

    Args:
        in_channels: Nominal input width. It is only compared against
            ``out_channels``; the lazy layers take the real channel count
            from the first tensor they receive.
        out_channels: Width of the first two convolutions; the block's
            output has four times as many channels.
        identity_mapping: When truthy and ``in_channels == out_channels``,
            the shortcut is a plain identity. In every other case it is a
            1x1 projection to ``out_channels * 4`` followed by batch norm.
            With the identity shortcut the input itself must already have
            ``out_channels * 4`` channels for the addition to line up.
        set_stride: When truthy and the two widths differ, the first
            convolution and the projection shortcut use stride 2.
    """

    def __init__(
        self, in_channels, out_channels, identity_mapping=False, set_stride=False
    ):
        super().__init__()
        expanded = out_channels * 4
        step = 2 if (in_channels != out_channels and set_stride) else 1

        # "same" padding is used for the unit-stride case only; a strided
        # 1x1 convolution gets no padding.
        if step == 1:
            first_padding = "same"
        else:
            first_padding = 0

        self.conv1 = nn.LazyConv2d(
            out_channels,
            kernel_size=1,
            stride=step,
            padding=first_padding,
        )
        self.conv2 = nn.LazyConv2d(out_channels, kernel_size=3, padding="same")
        self.conv3 = nn.LazyConv2d(expanded, kernel_size=1, padding="same")

        self.bn1 = nn.LazyBatchNorm2d()
        self.bn2 = nn.LazyBatchNorm2d()
        self.bn3 = nn.LazyBatchNorm2d()

        self.relu = nn.ReLU()

        # Identity is used only when it was requested AND the nominal widths
        # match; anything else goes through the projection.
        if in_channels == out_channels and identity_mapping:
            shortcut = nn.Identity()
        else:
            shortcut = nn.Sequential(
                nn.LazyConv2d(expanded, kernel_size=1, stride=step),
                nn.LazyBatchNorm2d(),
            )
        self.residual = shortcut

    def forward(self, x):
        """Return ``relu(main_path(x) + residual(x))``."""
        hidden = self.relu(self.bn1(self.conv1(x)))
        hidden = self.relu(self.bn2(self.conv2(hidden)))

        hidden = self.conv3(hidden)
        hidden = self.bn3(hidden)

        # The shortcut is evaluated after the main path, as in the original
        # ordering of sub-module calls.
        hidden += self.residual(x)
        return self.relu(hidden)
|
80 |
+
|
81 |
+
|
82 |
+
# Number of residual blocks in each of the four stages (conv2..conv5).
# RESNET_34 and RESNET_50 hold the same counts; the kind of block is chosen
# separately through the ``block`` argument of ``ResNet``.
RESNET_18 = [2, 2, 2, 2]
RESNET_34 = [3, 4, 6, 3]
RESNET_50 = [3, 4, 6, 3]
RESNET_101 = [3, 4, 23, 3]
RESNET_152 = [3, 8, 36, 3]


class ResNet(nn.Module):
    """ResNet-style classifier assembled from lazily shaped layers.

    Layout: 7x7 stride-2 stem -> stage ``conv2`` -> max-pool -> stages
    ``conv3``..``conv5`` -> global average pool -> flatten -> linear head.

    Args:
        arch: Sequence of four block counts, one per stage (see the
            ``RESNET_*`` constants). It is only read, never modified.
        block: ``"simple"`` selects ``SimpleResidualBlock``; any other value
            selects ``BottleneckResidualBlock``.
        num_classes: Size of the final linear layer's output.
    """

    def __init__(self, arch=RESNET_18, block="simple", num_classes=256):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.LazyConv2d(64, kernel_size=7, stride=2, padding=3),
            nn.LazyBatchNorm2d(),
            nn.ReLU(),
        )
        self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
        # The first stage keeps the spatial resolution; later stages use the
        # _make_layer default (set_stride=True) and downsample in their
        # first block.
        self.conv2 = self._make_layer(64, 64, arch[0], set_stride=False, block=block)
        self.conv3 = self._make_layer(64, 128, arch[1], block=block)
        self.conv4 = self._make_layer(128, 256, arch[2], block=block)
        self.conv5 = self._make_layer(256, 512, arch[3], block=block)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten()
        self.fc = nn.LazyLinear(num_classes)

    def _make_layer(
        self, in_channels, out_channels, num_blocks, set_stride=True, block="simple"
    ):
        """Build one stage as an ``nn.Sequential`` of ``num_blocks`` blocks.

        Block is either 'simple' or 'bottleneck'. Only the first block of
        the stage receives ``set_stride``; the rest are built with
        ``set_stride=False``.

        NOTE(review): every simple block is given the same
        ``in_channels``/``out_channels`` pair, so when the two differ the
        blocks after the first still build a projection shortcut rather than
        an identity. This is deliberately left as is: changing it would
        change the set of parameters, and the pretrained checkpoint loaded
        elsewhere in this file presumably expects the current layout.
        """
        layers = []
        for i in range(num_blocks):
            if block == "simple":
                layer = SimpleResidualBlock(
                    in_channels, out_channels, set_stride=set_stride
                )
            else:
                layer = BottleneckResidualBlock(
                    # After the first bottleneck the nominal input width is
                    # the expanded output width of the previous block.
                    in_channels if i == 0 else out_channels * 4,
                    out_channels,
                    set_stride=set_stride,
                )
            layers.append(layer)
            # Only the first block of a stage may downsample.
            set_stride = False
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the network on ``x`` and return the output of the linear head."""
        out = self.conv1(x)
        # Max-pooling is applied AFTER the first residual stage here; the
        # order is kept exactly as written so results stay unchanged.
        out = self.maxpool(self.conv2(out))
        out = self.conv3(out)
        out = self.conv4(out)
        out = self.conv5(out)
        out = self.avgpool(out)
        out = self.flatten(out)
        out = self.fc(out)
        return out

    @staticmethod
    def _init_weights(module):
        """Initialize conv/linear weights with Glorot uniform and zero biases.

        Declared as a static method so it can be handed to ``nn.Module.apply``
        either as ``ResNet._init_weights`` or as ``self._init_weights``
        (without the decorator the bound form would receive an extra
        argument). Modules of any other type are left untouched.

        NOTE(review): every conv/linear layer in this model is lazy, so this
        presumably has to run after the first forward pass has materialized
        the parameters -- confirm before wiring it in.
        """
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            nn.init.xavier_uniform_(module.weight)
            # Layers created with bias=False expose ``bias`` as None.
            if module.bias is not None:
                nn.init.zeros_(module.bias)
|
143 |
+
|
144 |
|
145 |
class ImageClassifier:
|
146 |
def __init__(self, checkpoint_path):
|
147 |
self.checkpoint_path = checkpoint_path
|
148 |
self.model = self.load_model(checkpoint_path)
|
149 |
+
self.transform = self.get_transform((244, 244))
|
150 |
self.labels = [
|
151 |
"airplane",
|
152 |
"automobile",
|
|
|
166 |
block="simple",
|
167 |
num_classes=10,
|
168 |
)
|
169 |
+
classifier.load_state_dict(torch.load(checkpoint_path))
|
170 |
classifier = classifier.cpu()
|
171 |
classifier.eval()
|
172 |
return classifier
|
|
|
190 |
def classify(self, input_image):
|
191 |
return self.predict(input_image)
|
192 |
|
193 |
+
|
194 |
def classify(input_image):
    """Forward ``input_image`` to the module-level ``classifier`` and return its result."""
    prediction = classifier.classify(input_image)
    return prediction
|
196 |
|
197 |
+
|
198 |
checkpoint_path = hf_hub_download(
|
199 |
repo_id="SatwikKambham/resnet18-cifar10",
|
200 |
filename="model.pt",
|
201 |
)
|
|
|
202 |
classifier = ImageClassifier(checkpoint_path)
|
|
|
203 |
iface = gr.Interface(
|
204 |
+
classify,
|
205 |
inputs=[
|
206 |
gr.Image(label="Input Image", type="pil"),
|
207 |
],
|