m7mdal7aj commited on
Commit
4d7e92d
1 Parent(s): cb8b3fe

Update my_model/object_detection.py

Browse files
Files changed (1) hide show
  1. my_model/object_detection.py +6 -4
my_model/object_detection.py CHANGED
@@ -26,10 +26,11 @@ class ObjectDetector:
26
  """
27
  Initializes the ObjectDetector class with default values.
28
  """
29
-
30
  self.model = None
31
  self.processor = None
32
  self.model_name = None
 
33
 
34
  def load_model(self, model_name='detic', pretrained=True, model_version='yolov5s'):
35
  """
@@ -52,6 +53,7 @@ class ObjectDetector:
52
  else:
53
  raise ValueError(f"Unsupported model name: {model_name}")
54
 
 
55
  def _load_detic_model(self, pretrained):
56
  """
57
  Load the Detic model.
@@ -62,13 +64,13 @@ class ObjectDetector:
62
 
63
  try:
64
  model_path = get_model_path('deformable-detr-detic')
65
- st.write(model_path)
66
  self.processor = AutoImageProcessor.from_pretrained(model_path)
67
  self.model = AutoModelForObjectDetection.from_pretrained(model_path)
68
  except Exception as e:
69
  print(f"Error loading Detic model: {e}")
70
  raise
71
 
 
72
  def _load_yolov5_model(self, pretrained, model_version):
73
  """
74
  Load the YOLOv5 model.
@@ -80,7 +82,6 @@ class ObjectDetector:
80
 
81
  try:
82
  model_path = get_model_path ('yolov5')
83
- st.write(model_path)
84
  if model_path and os.path.exists(model_path):
85
  self.model = torch.hub.load(model_path, model_version, pretrained=pretrained, source='local')
86
  else:
@@ -89,6 +90,7 @@ class ObjectDetector:
89
  print(f"Error loading YOLOv5 model: {e}")
90
  raise
91
 
 
92
  def process_image(self, image_input):
93
  """
94
  Process the image from the given path or file-like object.
@@ -194,6 +196,7 @@ class ObjectDetector:
194
  detected_objects_list.append((label_name, box_rounded, certainty))
195
  return detected_objects_str, detected_objects_list
196
 
 
197
  def draw_boxes(self, image, detected_objects, show_confidence=True):
198
  """
199
  Draw bounding boxes around detected objects in the image.
@@ -218,7 +221,6 @@ class ObjectDetector:
218
  for label_name, box, score in detected_objects:
219
  if label_name not in label_color_map:
220
  label_color_map[label_name] = colors[len(label_color_map) % len(colors)]
221
-
222
  color = label_color_map[label_name]
223
  draw.rectangle(box, outline=color, width=3)
224
  label_text = f"{label_name}"
 
26
  """
27
  Initializes the ObjectDetector class with default values.
28
  """
29
+
30
  self.model = None
31
  self.processor = None
32
  self.model_name = None
33
+
34
 
35
  def load_model(self, model_name='detic', pretrained=True, model_version='yolov5s'):
36
  """
 
53
  else:
54
  raise ValueError(f"Unsupported model name: {model_name}")
55
 
56
+
57
  def _load_detic_model(self, pretrained):
58
  """
59
  Load the Detic model.
 
64
 
65
  try:
66
  model_path = get_model_path('deformable-detr-detic')
 
67
  self.processor = AutoImageProcessor.from_pretrained(model_path)
68
  self.model = AutoModelForObjectDetection.from_pretrained(model_path)
69
  except Exception as e:
70
  print(f"Error loading Detic model: {e}")
71
  raise
72
 
73
+
74
  def _load_yolov5_model(self, pretrained, model_version):
75
  """
76
  Load the YOLOv5 model.
 
82
 
83
  try:
84
  model_path = get_model_path ('yolov5')
 
85
  if model_path and os.path.exists(model_path):
86
  self.model = torch.hub.load(model_path, model_version, pretrained=pretrained, source='local')
87
  else:
 
90
  print(f"Error loading YOLOv5 model: {e}")
91
  raise
92
 
93
+
94
  def process_image(self, image_input):
95
  """
96
  Process the image from the given path or file-like object.
 
196
  detected_objects_list.append((label_name, box_rounded, certainty))
197
  return detected_objects_str, detected_objects_list
198
 
199
+
200
  def draw_boxes(self, image, detected_objects, show_confidence=True):
201
  """
202
  Draw bounding boxes around detected objects in the image.
 
221
  for label_name, box, score in detected_objects:
222
  if label_name not in label_color_map:
223
  label_color_map[label_name] = colors[len(label_color_map) % len(colors)]
 
224
  color = label_color_map[label_name]
225
  draw.rectangle(box, outline=color, width=3)
226
  label_text = f"{label_name}"