sensura commited on
Commit
d11116c
·
verified ·
1 Parent(s): 288c8ce

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -11
app.py CHANGED
@@ -3,31 +3,60 @@ from PIL import Image
3
  import gradio as gr
4
  from huggingface_hub import snapshot_download
5
  import os
 
 
6
 
7
  def load_model(repo_id):
8
  download_dir = snapshot_download(repo_id)
9
- print(download_dir)
10
  path = os.path.join(download_dir, "best_int8_openvino_model")
11
- print(path)
12
- detection_model = YOLO(path, task='detect')
13
- return detection_model
14
 
15
- def predict(pilimg, conf_threshold, iou_threshold):
16
- source = pilimg
17
- result = detection_model.predict(source, conf=conf_threshold, iou=iou_threshold)
18
  img_bgr = result[0].plot()
19
  out_pilimg = Image.fromarray(img_bgr[..., ::-1]) # Convert BGR to RGB for PIL
20
  return out_pilimg
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  REPO_ID = "sensura/belisha-beacon-zebra-crossing-yoloV8"
23
  detection_model = load_model(REPO_ID)
24
 
25
  gr.Interface(
26
  fn=predict,
27
  inputs=[
28
- gr.Image(type="pil", label="Upload Image"),
29
- gr.Slider(minimum=0.1, maximum=1.0, value=0.5, step=0.05, label="Confidence Threshold"),
30
- gr.Slider(minimum=0.1, maximum=1.0, value=0.6, step=0.05, label="IoU Threshold")
31
  ],
32
- outputs=gr.Image(type="pil", label="Detection Output")
33
  ).launch(share=True)
 
3
  import gradio as gr
4
  from huggingface_hub import snapshot_download
5
  import os
6
+ import tempfile
7
+ import cv2
8
 
9
  def load_model(repo_id):
10
  download_dir = snapshot_download(repo_id)
 
11
  path = os.path.join(download_dir, "best_int8_openvino_model")
12
+ return YOLO(path, task='detect')
 
 
13
 
14
+ def predict_image(pilimg, conf_threshold, iou_threshold):
15
+ result = detection_model.predict(pilimg, conf=conf_threshold, iou=iou_threshold)
 
16
  img_bgr = result[0].plot()
17
  out_pilimg = Image.fromarray(img_bgr[..., ::-1]) # Convert BGR to RGB for PIL
18
  return out_pilimg
19
 
20
+ def predict_video(video_path, conf_threshold, iou_threshold):
21
+ cap = cv2.VideoCapture(video_path)
22
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
23
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
24
+ fps = cap.get(cv2.CAP_PROP_FPS)
25
+
26
+ out_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
27
+ out_writer = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
28
+
29
+ while cap.isOpened():
30
+ ret, frame = cap.read()
31
+ if not ret:
32
+ break
33
+ result = detection_model.predict(frame, conf=conf_threshold, iou=iou_threshold)
34
+ annotated = result[0].plot()
35
+ out_writer.write(annotated)
36
+
37
+ cap.release()
38
+ out_writer.release()
39
+ return out_path
40
+
41
+ def predict(file, conf_threshold, iou_threshold):
42
+ ext = os.path.splitext(file.name)[1].lower()
43
+ if ext in ['.jpg', '.jpeg', '.png']:
44
+ img = Image.open(file).convert("RGB")
45
+ return predict_image(img, conf_threshold, iou_threshold)
46
+ elif ext in ['.mp4', '.mov', '.avi']:
47
+ return predict_video(file.name, conf_threshold, iou_threshold)
48
+ else:
49
+ return "Unsupported file type. Please upload an image or video."
50
+
51
  REPO_ID = "sensura/belisha-beacon-zebra-crossing-yoloV8"
52
  detection_model = load_model(REPO_ID)
53
 
54
  gr.Interface(
55
  fn=predict,
56
  inputs=[
57
+ gr.File(label="Upload Image or Video"),
58
+ gr.Slider(0.1, 1.0, 0.5, step=0.05, label="Confidence Threshold"),
59
+ gr.Slider(0.1, 1.0, 0.6, step=0.05, label="IoU Threshold")
60
  ],
61
+ outputs=gr.outputs.Image(type="pil", label="Detected Image or Video") | gr.outputs.Video(label="Detected Video")
62
  ).launch(share=True)