thawro commited on
Commit
8b6e8c0
1 Parent(s): a2308cf

Trained SqueezeNet with torch weights

Browse files
services/backend/model/model.pt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:aa39db0e7695da5c388e69de4f56b3b5e5a6f33df3a9bdb197f3e59d8343a16f
3
- size 45041378
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5541ad1fb46690d0d0de99beafa579f013e8cfcda6ffa1e7b616acc246e69eb3
3
+ size 3215922
services/backend/model/transform.pt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:c26dd204ca8676a5affd07a2cbbd162b79e0e01ebba451fdf599aa6abb242a46
3
  size 4363
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:471b2bb67abb7f23f5e283cd84643e3148776dfc6a7fb0597775b397bdb8a710
3
  size 4363
services/backend/src/__pycache__/data.cpython-311.pyc CHANGED
Binary files a/services/backend/src/__pycache__/data.cpython-311.pyc and b/services/backend/src/__pycache__/data.cpython-311.pyc differ
 
services/backend/src/__pycache__/main.cpython-311.pyc CHANGED
Binary files a/services/backend/src/__pycache__/main.cpython-311.pyc and b/services/backend/src/__pycache__/main.cpython-311.pyc differ
 
services/backend/src/__pycache__/model.cpython-311.pyc CHANGED
Binary files a/services/backend/src/__pycache__/model.cpython-311.pyc and b/services/backend/src/__pycache__/model.cpython-311.pyc differ
 
services/backend/src/architectures/__pycache__/squeeze_net.cpython-311.pyc CHANGED
Binary files a/services/backend/src/architectures/__pycache__/squeeze_net.cpython-311.pyc and b/services/backend/src/architectures/__pycache__/squeeze_net.cpython-311.pyc differ
 
services/backend/src/architectures/squeeze_net.py CHANGED
@@ -6,6 +6,7 @@ import torch
6
  from torch import nn
7
  from collections import OrderedDict
8
  from src.architectures.deep_cnn import CNNBlock
 
9
 
10
 
11
  class FireBlock(nn.Module):
@@ -38,7 +39,7 @@ class FireBlock(nn.Module):
38
  return out
39
 
40
 
41
- class SqueezeNet(nn.Module):
42
  def __init__(
43
  self,
44
  in_channels: int = 3,
@@ -101,3 +102,28 @@ class SqueezeNet(nn.Module):
101
  @property
102
  def name(self):
103
  return "SqueezeNet"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  from torch import nn
7
  from collections import OrderedDict
8
  from src.architectures.deep_cnn import CNNBlock
9
+ from src.architectures.base import Backbone
10
 
11
 
12
  class FireBlock(nn.Module):
 
39
  return out
40
 
41
 
42
+ class SqueezeNet1_0(nn.Module):
43
  def __init__(
44
  self,
45
  in_channels: int = 3,
 
102
  @property
103
  def name(self):
104
  return "SqueezeNet"
105
+
106
+
107
+ from typing import Literal
108
+
109
+
110
+ class SqueezeNet:
111
+ def __new__(
112
+ cls,
113
+ in_channels: int,
114
+ version: Literal["squeezenet1_0", "squeezenet1_1"],
115
+ load_from_torch: bool = False,
116
+ pretrained: bool = False,
117
+ freeze_extractor: bool = False,
118
+ ):
119
+ if load_from_torch:
120
+ _net = torch.hub.load("pytorch/vision:v0.10.0", version, pretrained=pretrained)
121
+ _net.classifier[0] = torch.nn.Identity()
122
+ _net.classifier[1] = torch.nn.Identity()
123
+ _net.classifier[2] = torch.nn.Identity()
124
+ net = Backbone(_net, out_channels=512, name=version)
125
+ if freeze_extractor:
126
+ net.freeze()
127
+ else:
128
+ net = SqueezeNet1_0(in_channels=in_channels)
129
+ return net
services/backend/src/data.py CHANGED
@@ -41,7 +41,7 @@ class FlowersDataset(torchvision.datasets.Flowers102):
41
  self.idx2label = {i: label for i, label in enumerate(labels)}
42
 
43
  self._named_labels = np.array([self.idx2label[target] for target in self._labels])
44
- self.classes = list(label2idx.keys())
45
 
46
  def get_raw_img(self, idx: int):
47
  image_file = self._image_files[idx]
 
41
  self.idx2label = {i: label for i, label in enumerate(labels)}
42
 
43
  self._named_labels = np.array([self.idx2label[target] for target in self._labels])
44
+ self.classes = list(self.idx2label.values())
45
 
46
  def get_raw_img(self, idx: int):
47
  image_file = self._image_files[idx]
services/backend/src/main.py CHANGED
@@ -21,7 +21,7 @@ app.add_middleware(
21
 
22
  model_path = "model/model.pt"
23
  transform_path = "model/transform.pt"
24
- mapping_path = "model/mapping.json"
25
 
26
  model = torch.jit.load(model_path)
27
  model.eval()
@@ -29,9 +29,10 @@ model.to("cpu")
29
  transform = torch.jit.load(transform_path)
30
 
31
 
32
- with open(mapping_path, "rb") as f:
33
- label2idx = json.load(f)
34
- LABELS = list(label2idx.keys())
 
35
 
36
 
37
  def load_image_into_numpy_array(data):
@@ -45,8 +46,8 @@ async def predict(file: UploadFile = File(...)):
45
  img = transform(img).unsqueeze(0)
46
  log_probs = model(img)[0]
47
  probs = torch.exp(log_probs)
48
- label_probs = {LABELS[i]: float(probs[i]) for i in range(len(LABELS))}
49
- return label_probs
50
 
51
 
52
  @app.get("/")
 
21
 
22
  model_path = "model/model.pt"
23
  transform_path = "model/transform.pt"
24
+ mapping_path = "model/mapping.txt"
25
 
26
  model = torch.jit.load(model_path)
27
  model.eval()
 
29
  transform = torch.jit.load(transform_path)
30
 
31
 
32
+ with open(mapping_path) as f:
33
+ labels = f.readlines()
34
+ labels = [label.strip() for label in labels]
35
+ idx2label = {i: label for i, label in enumerate(labels)}
36
 
37
 
38
  def load_image_into_numpy_array(data):
 
46
  img = transform(img).unsqueeze(0)
47
  log_probs = model(img)[0]
48
  probs = torch.exp(log_probs)
49
+ confidences = {idx2label[i]: float(probs[i]) for i in range(len(idx2label))}
50
+ return confidences
51
 
52
 
53
  @app.get("/")
services/backend/src/model.py CHANGED
@@ -136,14 +136,20 @@ def load_net(num_classes: int) -> nn.Module:
136
  # pool_kernels=[2, 1, 2, 1, 2],
137
  # )
138
 
139
- # backbone = SqueezeNet()
140
- backbone = ResNet(
141
  in_channels=3,
142
- version="resnet18",
143
  load_from_torch=True,
144
  pretrained=True,
145
  freeze_extractor=True,
146
  )
 
 
 
 
 
 
 
147
 
148
  return nn.Sequential(
149
  backbone,
 
136
  # pool_kernels=[2, 1, 2, 1, 2],
137
  # )
138
 
139
+ backbone = SqueezeNet(
 
140
  in_channels=3,
141
+ version="squeezenet1_0",
142
  load_from_torch=True,
143
  pretrained=True,
144
  freeze_extractor=True,
145
  )
146
+ # backbone = ResNet(
147
+ # in_channels=3,
148
+ # version="resnet18",
149
+ # load_from_torch=True,
150
+ # pretrained=True,
151
+ # freeze_extractor=True,
152
+ # )
153
 
154
  return nn.Sequential(
155
  backbone,