Spaces:
Runtime error
Runtime error
Hector Lopez
committed on
Commit
·
9fbf078
1
Parent(s):
161f9af
feature: Objects classification
Browse files
- app.py +38 -5
- classifier.py +45 -0
- model.py +18 -1
app.py
CHANGED
@@ -3,25 +3,55 @@ import matplotlib.pyplot as plt
|
|
3 |
import numpy as np
|
4 |
import cv2
|
5 |
import PIL
|
|
|
6 |
|
7 |
-
from
|
|
|
8 |
|
9 |
print('Creating the model')
|
10 |
model = get_model('checkpoint.ckpt')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
-
def plot_img_no_mask(image, boxes):
|
13 |
# Show image
|
14 |
boxes = boxes.cpu().detach().numpy().astype(np.int32)
|
15 |
fig, ax = plt.subplots(1, 1, figsize=(12, 6))
|
16 |
|
17 |
-
for i, box in enumerate(boxes):
|
|
|
|
|
18 |
[x1, y1, x2, y2] = np.array(box).astype(int)
|
19 |
# Si no se hace la copia da error en cv2.rectangle
|
20 |
image = np.array(image).copy()
|
21 |
|
22 |
pt1 = (x1, y1)
|
23 |
pt2 = (x2, y2)
|
24 |
-
cv2.rectangle(image, pt1, pt2,
|
|
|
|
|
|
|
25 |
|
26 |
plt.axis('off')
|
27 |
ax.imshow(image)
|
@@ -79,8 +109,11 @@ if image_file is not None:
|
|
79 |
pred_dict = predict(model, data, detection_threshold)
|
80 |
print('Fixing the preds')
|
81 |
boxes, image = prepare_prediction(pred_dict, nms_threshold)
|
|
|
|
|
|
|
82 |
print('Plotting')
|
83 |
-
plot_img_no_mask(image, boxes)
|
84 |
|
85 |
img = PIL.Image.open('img.png')
|
86 |
st.image(img,width=750)
|
|
|
3 |
import numpy as np
|
4 |
import cv2
|
5 |
import PIL
|
6 |
+
import torch
|
7 |
|
8 |
+
from classifier import CustomEfficientNet
|
9 |
+
from model import get_model, predict, prepare_prediction, predict_class
|
10 |
|
11 |
print('Creating the model')
|
12 |
model = get_model('checkpoint.ckpt')
|
13 |
+
print('Loading the classifier')
|
14 |
+
classifier = CustomEfficientNet(target_size=7, pretrained=False)
|
15 |
+
classifier.load_state_dict(torch.load('class_efficientB0_taco_7_class.pth'))
|
16 |
+
|
17 |
+
def plot_img_no_mask(image, boxes, labels):
|
18 |
+
colors = {
|
19 |
+
0: (255,255,0),
|
20 |
+
1: (255, 0, 0),
|
21 |
+
2: (0, 0, 255),
|
22 |
+
3: (0,128,0),
|
23 |
+
4: (255,165,0),
|
24 |
+
5: (230,230,250),
|
25 |
+
6: (192,192,192)
|
26 |
+
}
|
27 |
+
|
28 |
+
texts = {
|
29 |
+
0: 'plastic',
|
30 |
+
1: 'dangerous',
|
31 |
+
2: 'carton',
|
32 |
+
3: 'glass',
|
33 |
+
4: 'organic',
|
34 |
+
5: 'rest',
|
35 |
+
6: 'other'
|
36 |
+
}
|
37 |
|
|
|
38 |
# Show image
|
39 |
boxes = boxes.cpu().detach().numpy().astype(np.int32)
|
40 |
fig, ax = plt.subplots(1, 1, figsize=(12, 6))
|
41 |
|
42 |
+
for i, box in enumerate(boxes):
|
43 |
+
color = colors[labels[i]]
|
44 |
+
|
45 |
[x1, y1, x2, y2] = np.array(box).astype(int)
|
46 |
# Si no se hace la copia da error en cv2.rectangle
|
47 |
image = np.array(image).copy()
|
48 |
|
49 |
pt1 = (x1, y1)
|
50 |
pt2 = (x2, y2)
|
51 |
+
cv2.rectangle(image, pt1, pt2, color, thickness=5)
|
52 |
+
cv2.putText(image, texts[labels[i]], (x1, y1-10),
|
53 |
+
cv2.FONT_HERSHEY_SIMPLEX, 4, thickness=5, color=color)
|
54 |
+
|
55 |
|
56 |
plt.axis('off')
|
57 |
ax.imshow(image)
|
|
|
109 |
pred_dict = predict(model, data, detection_threshold)
|
110 |
print('Fixing the preds')
|
111 |
boxes, image = prepare_prediction(pred_dict, nms_threshold)
|
112 |
+
|
113 |
+
print('Predicting classes')
|
114 |
+
labels = predict_class(classifier, image, boxes)
|
115 |
print('Plotting')
|
116 |
+
plot_img_no_mask(image, boxes, labels)
|
117 |
|
118 |
img = PIL.Image.open('img.png')
|
119 |
st.image(img,width=750)
|
classifier.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import timm
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
import torch
|
5 |
+
|
6 |
+
def get_efficientnet(model_name):
    """
    Create a timm model by name with pretrained weights requested.

    Parameters
    ----------
    model_name : str
        Architecture identifier forwarded to ``timm.create_model``.

    Returns
    -------
    The model object returned by ``timm.create_model``.
    """
    return timm.create_model(model_name, pretrained=True)
|
10 |
+
|
11 |
+
class CustomEfficientNet(nn.Module):
    """
    EfficientNet backbone from timm with a replaced classification head.

    The backbone's ``classifier`` layer is swapped for a two-layer head:
    ``Linear(in_features, 256) -> ReLU -> Linear(256, target_size)``.

    Parameters
    ----------
    model_name : str
        Architecture identifier forwarded to ``timm.create_model``. The
        created model must expose a ``classifier`` layer that has an
        ``in_features`` attribute.
    target_size : int
        Number of units for the output layer.
    pretrained : bool
        Forwarded to ``timm.create_model`` to select pretrained weights.

    Attributes
    ----------
    model : nn.Module
        The timm model with its ``classifier`` replaced by the new head.
    """

    def __init__(self, model_name: str = 'efficientnet_b0',
                 target_size: int = 4, pretrained: bool = True):
        super().__init__()
        backbone = timm.create_model(model_name, pretrained=pretrained)

        # Width of the features feeding the stock classifier; the new head
        # consumes the same features.
        feature_dim = backbone.classifier.in_features

        # Layers are built in the same order as they sit in the Sequential so
        # parameter initialisation and state_dict keys (0 and 2) are unchanged.
        hidden_layer = nn.Linear(feature_dim, 256)
        activation = nn.ReLU()
        output_layer = nn.Linear(256, target_size)
        backbone.classifier = nn.Sequential(hidden_layer, activation, output_layer)

        self.model = backbone

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the wrapped model and return its output."""
        return self.model(x)
|
model.py
CHANGED
@@ -72,4 +72,21 @@ def prepare_prediction(pred_dict, threshold):
|
|
72 |
fixed_boxes = torchvision.ops.batched_nms(boxes, scores, labels, threshold)
|
73 |
boxes = boxes[fixed_boxes, :]
|
74 |
|
75 |
-
return boxes, image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
72 |
fixed_boxes = torchvision.ops.batched_nms(boxes, scores, labels, threshold)
|
73 |
boxes = boxes[fixed_boxes, :]
|
74 |
|
75 |
+
return boxes, image
|
76 |
+
|
77 |
+
def predict_class(model, image, bboxes):
    """
    Predict a class index for the image region inside each bounding box.

    Parameters
    ----------
    model : callable
        Classifier called once per box with a float tensor built from the
        crop (channel axis moved first, batch axis of size 1 prepended).
        Its output is soft-maxed over dimension 1. If it exposes the
        ``nn.Module`` ``training`` flag and is in training mode, it is put
        in eval mode for the duration of the call and restored afterwards.
    image : array-like
        Image accepted by ``PIL.Image.fromarray``.
        NOTE(review): presumably an H x W x 3 uint8 array -- confirm against
        what ``prepare_prediction`` returns.
    bboxes : iterable
        Boxes given as four numbers in the order ``PIL.Image.crop`` expects
        (left, upper, right, lower). A ``torch.Tensor`` on any device, a
        NumPy array or a plain sequence is accepted.

    Returns
    -------
    numpy.ndarray
        One predicted class index per box, in box order. An empty integer
        array is returned when there are no boxes.
    """
    # Bring tensor boxes to host memory once so per-box conversion cannot fail
    # on non-CPU tensors.
    if isinstance(bboxes, torch.Tensor):
        bboxes = bboxes.detach().cpu().numpy()
    boxes = list(bboxes)

    # np.concatenate raises on an empty list, so handle "nothing detected"
    # explicitly instead of crashing the caller.
    if not boxes:
        return np.empty(0, dtype=np.intp)

    # The source image is loop-invariant; crop() returns a new image and never
    # mutates it, so no per-box copy is needed.
    source = PIL.Image.fromarray(image)

    # Inference must not run in training mode: layers such as BatchNorm would
    # use batch statistics and update their running stats.
    was_training = getattr(model, 'training', False)
    if was_training:
        model.eval()

    preds = []
    try:
        with torch.no_grad():
            for bbox in boxes:
                left, upper, right, lower = (
                    int(v) for v in np.array(bbox).astype(int)
                )
                crop = np.array(source.crop((left, upper, right, lower)))
                # Move the channel axis first, then add the batch axis.
                crop = crop.transpose(2, 0, 1)
                batch = torch.as_tensor(crop, dtype=torch.float).unsqueeze(0)
                # NOTE(review): no resize/normalisation is applied before the
                # classifier -- confirm this matches how it was trained.
                y_preds = model(batch)
                preds.append(y_preds.softmax(1).cpu().numpy())
    finally:
        # Leave the caller's model in the mode it was handed over in.
        if was_training:
            model.train()

    return np.concatenate(preds).argmax(1)
|