iarbel commited on
Commit
04f67ba
1 Parent(s): c3c7f71

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +24 -26
handler.py CHANGED
@@ -1,7 +1,6 @@
1
  from ultralyticsplus import YOLO
2
  from typing import List, Dict, Any
3
  from sahi import ObjectPrediction
4
- import torch, torchvision
5
 
6
 
7
  DEFAULT_CONFIG = {'conf': 0.25, 'iou': 0.45, 'agnostic_nms': False, 'max_det': 1000}
@@ -22,32 +21,31 @@ class EndpointHandler():
22
  Return:
23
  object_predictions
24
  """
25
- # config = DEFAULT_CONFIG
26
- # # Set model parameters
27
- # self.model.overrides['conf'] = config.get('conf')
28
- # self.model.overrides['iou'] = config.get('iou')
29
- # self.model.overrides['agnostic_nms'] = config.get('agnostic_nms')
30
- # self.model.overrides['max_det'] = config.get('max_det')
31
 
32
- # # perform inference
33
- # inputs = data.pop("inputs", data)
34
- # result = self.model.predict(inputs['image'])[0]
35
 
36
- # names = self.model.model.names
37
- # boxes = result.boxes
38
 
39
- # object_predictions = []
40
- # if boxes is not None:
41
- # det_ind = 0
42
- # for xyxy, conf, cls in zip(boxes.xyxy, boxes.conf, boxes.cls):
43
- # object_prediction = ObjectPrediction(
44
- # bbox=xyxy.tolist(),
45
- # category_name=names[int(cls)],
46
- # category_id=int(cls),
47
- # score=conf,
48
- # )
49
- # object_predictions.append(object_prediction)
50
- # det_ind += 1
51
- # return object_predictions
52
- return torch.__version__, torchvision.__version__
53
 
 
1
  from ultralyticsplus import YOLO
2
  from typing import List, Dict, Any
3
  from sahi import ObjectPrediction
 
4
 
5
 
6
  DEFAULT_CONFIG = {'conf': 0.25, 'iou': 0.45, 'agnostic_nms': False, 'max_det': 1000}
 
21
  Return:
22
  object_predictions
23
  """
24
+ config = DEFAULT_CONFIG
25
+ # Set model parameters
26
+ self.model.overrides['conf'] = config.get('conf')
27
+ self.model.overrides['iou'] = config.get('iou')
28
+ self.model.overrides['agnostic_nms'] = config.get('agnostic_nms')
29
+ self.model.overrides['max_det'] = config.get('max_det')
30
 
31
+ # perform inference
32
+ inputs = data.pop("inputs", data)
33
+ result = self.model.predict(inputs['image'])[0]
34
 
35
+ names = self.model.model.names
36
+ boxes = result.boxes
37
 
38
+ object_predictions = []
39
+ if boxes is not None:
40
+ det_ind = 0
41
+ for xyxy, conf, cls in zip(boxes.xyxy, boxes.conf, boxes.cls):
42
+ object_prediction = ObjectPrediction(
43
+ bbox=xyxy.tolist(),
44
+ category_name=names[int(cls)],
45
+ category_id=int(cls),
46
+ score=conf,
47
+ )
48
+ object_predictions.append(object_prediction)
49
+ det_ind += 1
50
+ return object_predictions
 
51