Spaces:
Runtime error
Runtime error
Trained SqueezeNet with torch weights
Browse files- services/backend/model/model.pt +2 -2
- services/backend/model/transform.pt +1 -1
- services/backend/src/__pycache__/data.cpython-311.pyc +0 -0
- services/backend/src/__pycache__/main.cpython-311.pyc +0 -0
- services/backend/src/__pycache__/model.cpython-311.pyc +0 -0
- services/backend/src/architectures/__pycache__/squeeze_net.cpython-311.pyc +0 -0
- services/backend/src/architectures/squeeze_net.py +27 -1
- services/backend/src/data.py +1 -1
- services/backend/src/main.py +7 -6
- services/backend/src/model.py +9 -3
services/backend/model/model.pt
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5541ad1fb46690d0d0de99beafa579f013e8cfcda6ffa1e7b616acc246e69eb3
|
3 |
+
size 3215922
|
services/backend/model/transform.pt
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 4363
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:471b2bb67abb7f23f5e283cd84643e3148776dfc6a7fb0597775b397bdb8a710
|
3 |
size 4363
|
services/backend/src/__pycache__/data.cpython-311.pyc
CHANGED
Binary files a/services/backend/src/__pycache__/data.cpython-311.pyc and b/services/backend/src/__pycache__/data.cpython-311.pyc differ
|
|
services/backend/src/__pycache__/main.cpython-311.pyc
CHANGED
Binary files a/services/backend/src/__pycache__/main.cpython-311.pyc and b/services/backend/src/__pycache__/main.cpython-311.pyc differ
|
|
services/backend/src/__pycache__/model.cpython-311.pyc
CHANGED
Binary files a/services/backend/src/__pycache__/model.cpython-311.pyc and b/services/backend/src/__pycache__/model.cpython-311.pyc differ
|
|
services/backend/src/architectures/__pycache__/squeeze_net.cpython-311.pyc
CHANGED
Binary files a/services/backend/src/architectures/__pycache__/squeeze_net.cpython-311.pyc and b/services/backend/src/architectures/__pycache__/squeeze_net.cpython-311.pyc differ
|
|
services/backend/src/architectures/squeeze_net.py
CHANGED
@@ -6,6 +6,7 @@ import torch
|
|
6 |
from torch import nn
|
7 |
from collections import OrderedDict
|
8 |
from src.architectures.deep_cnn import CNNBlock
|
|
|
9 |
|
10 |
|
11 |
class FireBlock(nn.Module):
|
@@ -38,7 +39,7 @@ class FireBlock(nn.Module):
|
|
38 |
return out
|
39 |
|
40 |
|
41 |
-
class
|
42 |
def __init__(
|
43 |
self,
|
44 |
in_channels: int = 3,
|
@@ -101,3 +102,28 @@ class SqueezeNet(nn.Module):
|
|
101 |
@property
|
102 |
def name(self):
|
103 |
return "SqueezeNet"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
from torch import nn
|
7 |
from collections import OrderedDict
|
8 |
from src.architectures.deep_cnn import CNNBlock
|
9 |
+
from src.architectures.base import Backbone
|
10 |
|
11 |
|
12 |
class FireBlock(nn.Module):
|
|
|
39 |
return out
|
40 |
|
41 |
|
42 |
+
class SqueezeNet1_0(nn.Module):
|
43 |
def __init__(
|
44 |
self,
|
45 |
in_channels: int = 3,
|
|
|
102 |
@property
|
103 |
def name(self):
|
104 |
return "SqueezeNet"
|
105 |
+
|
106 |
+
|
107 |
+
from typing import Literal
|
108 |
+
|
109 |
+
|
110 |
+
class SqueezeNet:
|
111 |
+
def __new__(
|
112 |
+
cls,
|
113 |
+
in_channels: int,
|
114 |
+
version: Literal["squeezenet1_0", "squeezenet1_1"],
|
115 |
+
load_from_torch: bool = False,
|
116 |
+
pretrained: bool = False,
|
117 |
+
freeze_extractor: bool = False,
|
118 |
+
):
|
119 |
+
if load_from_torch:
|
120 |
+
_net = torch.hub.load("pytorch/vision:v0.10.0", version, pretrained=pretrained)
|
121 |
+
_net.classifier[0] = torch.nn.Identity()
|
122 |
+
_net.classifier[1] = torch.nn.Identity()
|
123 |
+
_net.classifier[2] = torch.nn.Identity()
|
124 |
+
net = Backbone(_net, out_channels=512, name=version)
|
125 |
+
if freeze_extractor:
|
126 |
+
net.freeze()
|
127 |
+
else:
|
128 |
+
net = SqueezeNet1_0(in_channels=in_channels)
|
129 |
+
return net
|
services/backend/src/data.py
CHANGED
@@ -41,7 +41,7 @@ class FlowersDataset(torchvision.datasets.Flowers102):
|
|
41 |
self.idx2label = {i: label for i, label in enumerate(labels)}
|
42 |
|
43 |
self._named_labels = np.array([self.idx2label[target] for target in self._labels])
|
44 |
-
self.classes = list(
|
45 |
|
46 |
def get_raw_img(self, idx: int):
|
47 |
image_file = self._image_files[idx]
|
|
|
41 |
self.idx2label = {i: label for i, label in enumerate(labels)}
|
42 |
|
43 |
self._named_labels = np.array([self.idx2label[target] for target in self._labels])
|
44 |
+
self.classes = list(self.idx2label.values())
|
45 |
|
46 |
def get_raw_img(self, idx: int):
|
47 |
image_file = self._image_files[idx]
|
services/backend/src/main.py
CHANGED
@@ -21,7 +21,7 @@ app.add_middleware(
|
|
21 |
|
22 |
model_path = "model/model.pt"
|
23 |
transform_path = "model/transform.pt"
|
24 |
-
mapping_path = "model/mapping.
|
25 |
|
26 |
model = torch.jit.load(model_path)
|
27 |
model.eval()
|
@@ -29,9 +29,10 @@ model.to("cpu")
|
|
29 |
transform = torch.jit.load(transform_path)
|
30 |
|
31 |
|
32 |
-
with open(mapping_path
|
33 |
-
|
34 |
-
|
|
|
35 |
|
36 |
|
37 |
def load_image_into_numpy_array(data):
|
@@ -45,8 +46,8 @@ async def predict(file: UploadFile = File(...)):
|
|
45 |
img = transform(img).unsqueeze(0)
|
46 |
log_probs = model(img)[0]
|
47 |
probs = torch.exp(log_probs)
|
48 |
-
|
49 |
-
return
|
50 |
|
51 |
|
52 |
@app.get("/")
|
|
|
21 |
|
22 |
model_path = "model/model.pt"
|
23 |
transform_path = "model/transform.pt"
|
24 |
+
mapping_path = "model/mapping.txt"
|
25 |
|
26 |
model = torch.jit.load(model_path)
|
27 |
model.eval()
|
|
|
29 |
transform = torch.jit.load(transform_path)
|
30 |
|
31 |
|
32 |
+
with open(mapping_path) as f:
|
33 |
+
labels = f.readlines()
|
34 |
+
labels = [label.strip() for label in labels]
|
35 |
+
idx2label = {i: label for i, label in enumerate(labels)}
|
36 |
|
37 |
|
38 |
def load_image_into_numpy_array(data):
|
|
|
46 |
img = transform(img).unsqueeze(0)
|
47 |
log_probs = model(img)[0]
|
48 |
probs = torch.exp(log_probs)
|
49 |
+
confidences = {idx2label[i]: float(probs[i]) for i in range(len(idx2label))}
|
50 |
+
return confidences
|
51 |
|
52 |
|
53 |
@app.get("/")
|
services/backend/src/model.py
CHANGED
@@ -136,14 +136,20 @@ def load_net(num_classes: int) -> nn.Module:
|
|
136 |
# pool_kernels=[2, 1, 2, 1, 2],
|
137 |
# )
|
138 |
|
139 |
-
|
140 |
-
backbone = ResNet(
|
141 |
in_channels=3,
|
142 |
-
version="
|
143 |
load_from_torch=True,
|
144 |
pretrained=True,
|
145 |
freeze_extractor=True,
|
146 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
147 |
|
148 |
return nn.Sequential(
|
149 |
backbone,
|
|
|
136 |
# pool_kernels=[2, 1, 2, 1, 2],
|
137 |
# )
|
138 |
|
139 |
+
backbone = SqueezeNet(
|
|
|
140 |
in_channels=3,
|
141 |
+
version="squeezenet1_0",
|
142 |
load_from_torch=True,
|
143 |
pretrained=True,
|
144 |
freeze_extractor=True,
|
145 |
)
|
146 |
+
# backbone = ResNet(
|
147 |
+
# in_channels=3,
|
148 |
+
# version="resnet18",
|
149 |
+
# load_from_torch=True,
|
150 |
+
# pretrained=True,
|
151 |
+
# freeze_extractor=True,
|
152 |
+
# )
|
153 |
|
154 |
return nn.Sequential(
|
155 |
backbone,
|