Hector Lopez committed on
Commit
9fbf078
·
1 Parent(s): 161f9af

feature: Objects classification

Browse files
Files changed (3) hide show
  1. app.py +38 -5
  2. classifier.py +45 -0
  3. model.py +18 -1
app.py CHANGED
@@ -3,25 +3,55 @@ import matplotlib.pyplot as plt
3
  import numpy as np
4
  import cv2
5
  import PIL
 
6
 
7
- from model import get_model, predict, prepare_prediction
 
8
 
9
  print('Creating the model')
10
  model = get_model('checkpoint.ckpt')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
- def plot_img_no_mask(image, boxes):
13
  # Show image
14
  boxes = boxes.cpu().detach().numpy().astype(np.int32)
15
  fig, ax = plt.subplots(1, 1, figsize=(12, 6))
16
 
17
- for i, box in enumerate(boxes):
 
 
18
  [x1, y1, x2, y2] = np.array(box).astype(int)
19
  # Si no se hace la copia da error en cv2.rectangle
20
  image = np.array(image).copy()
21
 
22
  pt1 = (x1, y1)
23
  pt2 = (x2, y2)
24
- cv2.rectangle(image, pt1, pt2, (220,0,0), thickness=5)
 
 
 
25
 
26
  plt.axis('off')
27
  ax.imshow(image)
@@ -79,8 +109,11 @@ if image_file is not None:
79
  pred_dict = predict(model, data, detection_threshold)
80
  print('Fixing the preds')
81
  boxes, image = prepare_prediction(pred_dict, nms_threshold)
 
 
 
82
  print('Plotting')
83
- plot_img_no_mask(image, boxes)
84
 
85
  img = PIL.Image.open('img.png')
86
  st.image(img,width=750)
 
3
  import numpy as np
4
  import cv2
5
  import PIL
6
+ import torch
7
 
8
+ from classifier import CustomEfficientNet
9
+ from model import get_model, predict, prepare_prediction, predict_class
10
 
11
  print('Creating the model')
12
  model = get_model('checkpoint.ckpt')
13
print('Loading the classifier')
classifier = CustomEfficientNet(target_size=7, pretrained=False)
# Remap the checkpoint to CPU: the classifier is fed CPU tensors and its
# output is converted with .numpy(), so the weights must live on the CPU even
# if the file was saved from a GPU run.
classifier.load_state_dict(
    torch.load('class_efficientB0_taco_7_class.pth', map_location='cpu'))
# The classifier is only used for inference here; eval mode makes layers such
# as BatchNorm/Dropout use their inference behaviour.
classifier.eval()
16
+
17
+ def plot_img_no_mask(image, boxes, labels):
18
+ colors = {
19
+ 0: (255,255,0),
20
+ 1: (255, 0, 0),
21
+ 2: (0, 0, 255),
22
+ 3: (0,128,0),
23
+ 4: (255,165,0),
24
+ 5: (230,230,250),
25
+ 6: (192,192,192)
26
+ }
27
+
28
+ texts = {
29
+ 0: 'plastic',
30
+ 1: 'dangerous',
31
+ 2: 'carton',
32
+ 3: 'glass',
33
+ 4: 'organic',
34
+ 5: 'rest',
35
+ 6: 'other'
36
+ }
37
 
 
38
  # Show image
39
  boxes = boxes.cpu().detach().numpy().astype(np.int32)
40
  fig, ax = plt.subplots(1, 1, figsize=(12, 6))
41
 
42
+ for i, box in enumerate(boxes):
43
+ color = colors[labels[i]]
44
+
45
  [x1, y1, x2, y2] = np.array(box).astype(int)
46
  # Si no se hace la copia da error en cv2.rectangle
47
  image = np.array(image).copy()
48
 
49
  pt1 = (x1, y1)
50
  pt2 = (x2, y2)
51
+ cv2.rectangle(image, pt1, pt2, color, thickness=5)
52
+ cv2.putText(image, texts[labels[i]], (x1, y1-10),
53
+ cv2.FONT_HERSHEY_SIMPLEX, 4, thickness=5, color=color)
54
+
55
 
56
  plt.axis('off')
57
  ax.imshow(image)
 
109
  pred_dict = predict(model, data, detection_threshold)
110
  print('Fixing the preds')
111
  boxes, image = prepare_prediction(pred_dict, nms_threshold)
112
+
113
+ print('Predicting classes')
114
+ labels = predict_class(classifier, image, boxes)
115
  print('Plotting')
116
+ plot_img_no_mask(image, boxes, labels)
117
 
118
  img = PIL.Image.open('img.png')
119
  st.image(img,width=750)
classifier.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import timm
2
+ import torch.nn as nn
3
+
4
+ import torch
5
+
6
def get_efficientnet(model_name, pretrained=True):
    """
    Create a model through ``timm.create_model``.

    Parameters
    ----------
    model_name : str
        Architecture name forwarded to ``timm.create_model``.
    pretrained : bool
        Forwarded to ``timm.create_model``. Defaults to ``True``, which is
        the value this function previously hard-coded.

    Returns
    -------
    The object returned by ``timm.create_model``.
    """
    return timm.create_model(model_name, pretrained=pretrained)
10
+
11
class CustomEfficientNet(nn.Module):
    """
    Model created with ``timm`` whose ``classifier`` layer is replaced by a
    small two-layer head.

    Parameters
    ----------
    model_name : str
        Architecture name forwarded to ``timm.create_model``.
    target_size : int
        Number of units for the output layer.
    pretrained : bool
        Determine if pretrained weights are used.
    hidden_size : int
        Number of units of the hidden layer in the replacement head.

    Attributes
    ----------
    model : nn.Module
        Model returned by ``timm.create_model`` with the new head attached.
    """
    def __init__(self, model_name: str = 'efficientnet_b0',
                 target_size: int = 4, pretrained: bool = True,
                 hidden_size: int = 256):
        super().__init__()
        self.model = timm.create_model(model_name, pretrained=pretrained)

        # Replace the classifier layer. The head stays Linear -> ReLU -> Linear
        # in this exact order so the parameter names inside the Sequential
        # (indices 0 and 2) are the same as before.
        in_features = self.model.classifier.in_features
        self.model.classifier = nn.Sequential(
            nn.Linear(in_features, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, target_size),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the wrapped model and return its output."""
        return self.model(x)
model.py CHANGED
@@ -72,4 +72,21 @@ def prepare_prediction(pred_dict, threshold):
72
  fixed_boxes = torchvision.ops.batched_nms(boxes, scores, labels, threshold)
73
  boxes = boxes[fixed_boxes, :]
74
 
75
- return boxes, image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  fixed_boxes = torchvision.ops.batched_nms(boxes, scores, labels, threshold)
73
  boxes = boxes[fixed_boxes, :]
74
 
75
+ return boxes, image
76
+
77
def predict_class(model, image, bboxes):
    """
    Predict a class index for the image region inside each bounding box.

    Parameters
    ----------
    model : nn.Module
        Classifier. It is called with one crop at a time, as a float tensor
        laid out as (1, channels, height, width), and its output is reduced
        with an argmax over dimension 1.
    image : array-like
        Image indexed as (row, column, channel).
    bboxes : torch.Tensor or array-like
        Boxes given as (x1, y1, x2, y2); coordinates are truncated to int.

    Returns
    -------
    np.ndarray
        One integer class index per box, in the order of ``bboxes``. An
        empty array is returned when there are no boxes.
    """
    if torch.is_tensor(bboxes):
        bboxes = bboxes.detach().cpu().numpy()
    boxes = np.asarray(bboxes).astype(int).reshape(-1, 4)

    # Nothing detected: return early instead of failing on an empty list.
    if len(boxes) == 0:
        return np.empty(0, dtype=np.int64)

    image = np.asarray(image)
    height, width = image.shape[:2]

    preds = []
    # Inference only: use eval mode and skip gradient tracking, then restore
    # whatever mode the caller had set on the model.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            # Crops differ in size, so they are classified one by one.
            for x1, y1, x2, y2 in boxes:
                # Keep the crop inside the image and at least one pixel wide
                # and tall, so the model never receives an empty tensor.
                x1 = min(max(x1, 0), width - 1)
                y1 = min(max(y1, 0), height - 1)
                x2 = min(max(x2, x1 + 1), width)
                y2 = min(max(y2, y1 + 1), height)

                crop = np.ascontiguousarray(
                    image[y1:y2, x1:x2].transpose(2, 0, 1))
                batch = torch.as_tensor(crop, dtype=torch.float).unsqueeze(0)

                # The argmax is taken directly on the model output: softmax
                # keeps the ordering of the scores, so it is not needed.
                preds.append(int(model(batch).argmax(dim=1).item()))
    finally:
        model.train(was_training)

    return np.array(preds, dtype=np.int64)