iarbel commited on
Commit
5bb5008
·
1 Parent(s): 04f67ba

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +29 -26
handler.py CHANGED
@@ -1,16 +1,15 @@
1
  from ultralyticsplus import YOLO
2
- from typing import List, Dict, Any
3
- from sahi import ObjectPrediction
4
 
5
 
6
  DEFAULT_CONFIG = {'conf': 0.25, 'iou': 0.45, 'agnostic_nms': False, 'max_det': 1000}
7
-
8
 
9
  class EndpointHandler():
10
- def __init__(self, path=""):
11
  self.model = YOLO('ultralyticsplus/yolov8s')
12
 
13
- def __call__(self, data: str) -> List[ObjectPrediction]:
14
  """
15
  data args:
16
  image: image path to segment
@@ -19,33 +18,37 @@ class EndpointHandler():
19
  agnostic_nms - NMS class-agnostic: True / False,
20
  max_det - maximum number of detections per image)
21
  Return:
22
- object_predictions
23
  """
24
- config = DEFAULT_CONFIG
 
 
 
 
 
25
  # Set model parameters
26
  self.model.overrides['conf'] = config.get('conf')
27
  self.model.overrides['iou'] = config.get('iou')
28
  self.model.overrides['agnostic_nms'] = config.get('agnostic_nms')
29
  self.model.overrides['max_det'] = config.get('max_det')
30
 
 
 
 
31
  # perform inference
32
- inputs = data.pop("inputs", data)
33
  result = self.model.predict(inputs['image'])[0]
34
-
35
- names = self.model.model.names
36
- boxes = result.boxes
37
-
38
- object_predictions = []
39
- if boxes is not None:
40
- det_ind = 0
41
- for xyxy, conf, cls in zip(boxes.xyxy, boxes.conf, boxes.cls):
42
- object_prediction = ObjectPrediction(
43
- bbox=xyxy.tolist(),
44
- category_name=names[int(cls)],
45
- category_id=int(cls),
46
- score=conf,
47
- )
48
- object_predictions.append(object_prediction)
49
- det_ind += 1
50
- return object_predictions
51
-
 
1
  from ultralyticsplus import YOLO
2
+ from typing import Dict, Any, List
 
3
 
4
 
5
  DEFAULT_CONFIG = {'conf': 0.25, 'iou': 0.45, 'agnostic_nms': False, 'max_det': 1000}
6
+ BOX_KEYS = ['xmin', 'ymin', 'xmax', 'ymax']
7
 
8
  class EndpointHandler():
9
+ def __init__(self):
10
  self.model = YOLO('ultralyticsplus/yolov8s')
11
 
12
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
13
  """
14
  data args:
15
  image: image path to segment
 
18
  agnostic_nms - NMS class-agnostic: True / False,
19
  max_det - maximum number of detections per image)
20
  Return:
21
+ A :obj: `dict` | `dict`: {scores, labels, boxes}
22
  """
23
+ inputs = data.pop("inputs", data)
24
+ input_config = inputs.pop("config", DEFAULT_CONFIG)
25
+ config = {**DEFAULT_CONFIG, **input_config}
26
+
27
+ if config is None:
28
+ config = DEFAULT_CONFIG
29
  # Set model parameters
30
  self.model.overrides['conf'] = config.get('conf')
31
  self.model.overrides['iou'] = config.get('iou')
32
  self.model.overrides['agnostic_nms'] = config.get('agnostic_nms')
33
  self.model.overrides['max_det'] = config.get('max_det')
34
 
35
+ # Get label idx-to-name
36
+ names = model.model.names
37
+
38
  # perform inference
 
39
  result = self.model.predict(inputs['image'])[0]
40
+ prediction = []
41
+ for score, label, box in zip(result.boxes.conf, result.boxes.cls, result.boxes.xyxy):
42
+ item_score = score.item()
43
+ item_label = names[int(label)]
44
+ item_box = box.to(dtype=int).tolist()
45
+
46
+ item_prediction = {
47
+ 'score': item_score,
48
+ 'label': item_label,
49
+ 'box': dict(zip(BOX_KEYS, item_box))
50
+ }
51
+
52
+ prediction.append(item_prediction)
53
+
54
+ return prediction