Gabolozano commited on
Commit
e3205c6
·
verified ·
1 Parent(s): b842d10

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -13
app.py CHANGED
@@ -11,9 +11,8 @@ model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50", config
11
  image_processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
12
 
13
  def load_model(threshold):
14
- # Since changing threshold at runtime for models isn't typically supported directly by the transformers pipeline,
15
- # we reinitialize the model with the desired configuration when needed.
16
- config = DetrConfig.from_pretrained("facebook/detr-resnet-50", num_labels=91, threshold=threshold)
17
  model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50", config=config)
18
  image_processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
19
  return pipeline(task='object-detection', model=model, image_processor=image_processor)
@@ -27,6 +26,7 @@ def draw_detections(image, detections):
27
  # Convert RGB to BGR for OpenCV
28
  np_image = cv2.cvtColor(np_image, cv2.COLOR_RGB2BGR)
29
 
 
30
  for detection in detections:
31
  score = detection['score']
32
  label = detection['label']
@@ -36,10 +36,10 @@ def draw_detections(image, detections):
36
  x_max = box['xmax']
37
  y_max = box['ymax']
38
 
39
- # Draw rectangles and text with a larger font
40
  cv2.rectangle(np_image, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)
41
  label_text = f'{label} {score:.2f}'
42
- cv2.putText(np_image, label_text, (x_min, y_min - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)
43
 
44
  # Convert BGR to RGB for displaying
45
  final_image = cv2.cvtColor(np_image, cv2.COLOR_BGR2RGB)
@@ -48,30 +48,39 @@ def draw_detections(image, detections):
48
 
49
  def get_pipeline_prediction(threshold, pil_image):
50
  global od_pipe
51
- if od_pipe.config.threshold != threshold:
52
- od_pipe = load_model(threshold)
53
  try:
54
- pil_image = Image.fromarray(np.array(pil_image))
 
 
 
 
 
 
 
 
 
55
  pipeline_output = od_pipe(pil_image)
56
  processed_image = draw_detections(pil_image, pipeline_output)
57
  return processed_image, pipeline_output
58
  except Exception as e:
59
- print(f"An error occurred: {str(e)}")
60
- return pil_image, {"error": str(e)}
 
61
 
62
- # Define the Gradio blocks interface
63
  with gr.Blocks() as demo:
64
  with gr.Row():
65
  with gr.Column():
66
  inp_image = gr.Image(label="Input image")
67
- slider = gr.Slider(minimum=0, maximum=1, step=0.05, label="Adjust Detection Sensitivity", value=0.5)
68
- gr.Markdown("Adjust the slider to change the detection sensitivity.")
69
  btn_run = gr.Button('Run Detection')
70
  with gr.Column():
71
  with gr.Tab("Annotated Image"):
72
  out_image = gr.Image()
73
  with gr.Tab("Detection Results"):
74
  out_json = gr.JSON()
 
75
  btn_run.click(get_pipeline_prediction, inputs=[slider, inp_image], outputs=[out_image, out_json])
76
 
77
  demo.launch()
 
11
  image_processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
12
 
13
  def load_model(threshold):
14
+ # Reinitialize the model with the desired detection threshold
15
+ config = DetrConfig.from_pretrained("facebook/detr-resnet-50", threshold=threshold)
 
16
  model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50", config=config)
17
  image_processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
18
  return pipeline(task='object-detection', model=model, image_processor=image_processor)
 
26
  # Convert RGB to BGR for OpenCV
27
  np_image = cv2.cvtColor(np_image, cv2.COLOR_RGB2BGR)
28
 
29
+ # Draw detections
30
  for detection in detections:
31
  score = detection['score']
32
  label = detection['label']
 
36
  x_max = box['xmax']
37
  y_max = box['ymax']
38
 
39
+ # Increase font size for better visibility
40
  cv2.rectangle(np_image, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)
41
  label_text = f'{label} {score:.2f}'
42
+ cv2.putText(np_image, label_text, (x_min, y_min - 10), cv2.FONT_HERSHEY_SIMPLEX, 1.5, (255, 255, 255), 4)
43
 
44
  # Convert BGR to RGB for displaying
45
  final_image = cv2.cvtColor(np_image, cv2.COLOR_BGR2RGB)
 
48
 
49
  def get_pipeline_prediction(threshold, pil_image):
50
  global od_pipe
 
 
51
  try:
52
+ # Check if the model threshold needs adjusting
53
+ if od_pipe.config.threshold != threshold:
54
+ od_pipe = load_model(threshold)
55
+ print("Model reloaded with new threshold:", threshold)
56
+
57
+ # Ensure input is a PIL image
58
+ if not isinstance(pil_image, Image.Image):
59
+ pil_image = Image.fromarray(np.array(pil_image).astype('uint8'), 'RGB')
60
+
61
+ # Run detection and return annotated image and results
62
  pipeline_output = od_pipe(pil_image)
63
  processed_image = draw_detections(pil_image, pipeline_output)
64
  return processed_image, pipeline_output
65
  except Exception as e:
66
+ error_message = f"An error occurred: {str(e)}"
67
+ print(error_message)
68
+ return pil_image, {"error": error_message}
69
 
70
+ # Gradio interface
71
  with gr.Blocks() as demo:
72
  with gr.Row():
73
  with gr.Column():
74
  inp_image = gr.Image(label="Input image")
75
+ slider = gr.Slider(minimum=0, maximum=1, step=0.05, label="Detection Sensitivity", value=0.5)
76
+ gr.Markdown("Adjust the slider to change detection sensitivity.")
77
  btn_run = gr.Button('Run Detection')
78
  with gr.Column():
79
  with gr.Tab("Annotated Image"):
80
  out_image = gr.Image()
81
  with gr.Tab("Detection Results"):
82
  out_json = gr.JSON()
83
+
84
  btn_run.click(get_pipeline_prediction, inputs=[slider, inp_image], outputs=[out_image, out_json])
85
 
86
  demo.launch()