sensura committed on
Commit
8e2e3e2
·
verified ·
1 Parent(s): 169e5bb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +91 -67
app.py CHANGED
@@ -8,87 +8,111 @@ import cv2
8
  import zipfile
9
  import shutil
10
 
 
11
def load_model(repo_id):
    """Download the model snapshot from the Hugging Face Hub and load its
    INT8 OpenVINO export as a YOLO detection model."""
    snapshot_dir = snapshot_download(repo_id)
    model_path = os.path.join(snapshot_dir, "best_int8_openvino_model")
    return YOLO(model_path, task='detect')
15
 
16
def predict(files, conf_threshold, iou_threshold):
    """Run detection on the uploaded file(s).

    Args:
        files: list of uploaded file objects from gr.File(file_count="multiple").
        conf_threshold: minimum detection confidence passed to the model.
        iou_threshold: IoU threshold for non-max suppression.

    Returns:
        Tuple (image_path, video_path, gallery_images, zip_path); the slots
        that do not apply to the input type are None.
    """
    image_output = None
    video_output = None
    gallery_output = None
    zip_output = None

    # BUG FIX: guard against an empty upload list (previously raised IndexError
    # in the single-file branch via files[0]).
    if not files:
        return image_output, video_output, gallery_output, zip_output

    if len(files) == 1:
        ext = os.path.splitext(files[0].name)[1].lower()
        if ext in ('.jpg', '.jpeg', '.png'):
            img = Image.open(files[0]).convert("RGB")
            result = detection_model.predict(img, conf=conf_threshold, iou=iou_threshold)
            img_bgr = result[0].plot()  # annotated frame in BGR channel order
            out_img = Image.fromarray(img_bgr[..., ::-1])  # BGR -> RGB for PIL
            tmp_path = tempfile.NamedTemporaryFile(suffix=".png", delete=False).name
            out_img.save(tmp_path)
            image_output = tmp_path
        elif ext in ('.mp4', '.mov', '.avi'):
            cap = cv2.VideoCapture(files[0].name)
            out_writer = None
            try:
                width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                # BUG FIX: broken metadata can report fps == 0, which makes
                # VideoWriter produce an unplayable file; fall back to 25 fps.
                fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
                out_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
                out_writer = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
                while cap.isOpened():
                    ret, frame = cap.read()
                    if not ret:
                        break
                    result = detection_model.predict(frame, conf=conf_threshold, iou=iou_threshold)
                    out_writer.write(result[0].plot())  # plot() is BGR, matching VideoWriter
            finally:
                # BUG FIX: release handles even if prediction raises mid-stream.
                cap.release()
                if out_writer is not None:
                    out_writer.release()
            video_output = out_path
        else:
            # BUG FIX: previously returned the string "Unsupported file type."
            # in the image slot, which gr.Image(type="filepath") cannot render.
            # Leave every output empty instead.
            return None, None, None, None
    else:
        output_dir = tempfile.mkdtemp()
        annotated_images = []

        for file in files:
            try:
                img = Image.open(file).convert("RGB")
                result = detection_model.predict(img, conf=conf_threshold, iou=iou_threshold)
                img_bgr = result[0].plot()
                out_img = Image.fromarray(img_bgr[..., ::-1])
                out_path = os.path.join(output_dir, os.path.basename(file.name))
                out_img.save(out_path)
                annotated_images.append(out_img)
            except Exception as e:
                # Best-effort batch: skip unreadable files, keep processing the rest.
                print(f"Failed to process {file.name}: {e}")

        zip_path = shutil.make_archive(output_dir, 'zip', output_dir)
        gallery_output = annotated_images
        zip_output = zip_path

    return image_output, video_output, gallery_output, zip_output
74
 
75
# Hugging Face Hub repo that hosts the YOLOv8 belisha-beacon / zebra-crossing weights.
REPO_ID = "sensura/belisha-beacon-zebra-crossing-yoloV8"
# Loaded once at import time so every request reuses the same model instance.
detection_model = load_model(REPO_ID)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
# Gradio UI: one multi-file input plus the two detection-threshold sliders,
# with one output slot per possible result kind (image / video / gallery / zip).
input_widgets = [
    gr.File(file_types=["image", "video"], file_count="multiple", label="Upload Image(s) or Video"),
    gr.Slider(0.1, 1.0, 0.5, step=0.05, label="Confidence Threshold"),
    gr.Slider(0.1, 1.0, 0.6, step=0.05, label="IoU Threshold"),
]
output_widgets = [
    gr.Image(type="filepath", label="Single Image Output"),
    gr.Video(label="Single Video Output"),
    gr.Gallery(columns=3, height="auto"),
    gr.File(label="Download ZIP for Multiple Images"),
]
interface = gr.Interface(fn=predict, inputs=input_widgets, outputs=output_widgets, live=False)

interface.launch(share=True)
 
 
 
 
 
8
  import zipfile
9
  import shutil
10
 
11
# === Load model ===
def load_model(repo_id):
    """Fetch the Hub snapshot for *repo_id* and return the YOLO detector
    backed by its INT8 OpenVINO export."""
    local_dir = snapshot_download(repo_id)
    return YOLO(os.path.join(local_dir, "best_int8_openvino_model"), task='detect')
16
 
17
# Hub repo that hosts the YOLOv8 belisha-beacon / zebra-crossing weights.
REPO_ID = "sensura/belisha-beacon-zebra-crossing-yoloV8"
# Loaded once at module import; shared by both prediction functions below.
detection_model = load_model(REPO_ID)
19
+
20
# === Single file prediction ===
def predict_single(file, conf_threshold, iou_threshold):
    """Run detection on one uploaded image or video.

    Args:
        file: uploaded file object from gr.File (or None).
        conf_threshold: minimum detection confidence passed to the model.
        iou_threshold: IoU threshold for non-max suppression.

    Returns:
        (image_path, video_path) — the slot matching the input type holds a
        filepath, the other is None; (None, None) for missing or unsupported input.
    """
    if file is None:
        return None, None

    ext = os.path.splitext(file.name)[1].lower()

    if ext in ('.jpg', '.jpeg', '.png'):
        img = Image.open(file).convert("RGB")
        result = detection_model.predict(img, conf=conf_threshold, iou=iou_threshold)
        img_bgr = result[0].plot()  # annotated frame in BGR channel order
        out_img = Image.fromarray(img_bgr[..., ::-1])  # BGR -> RGB for PIL
        tmp_path = tempfile.NamedTemporaryFile(suffix=".png", delete=False).name
        out_img.save(tmp_path)
        return tmp_path, None

    if ext in ('.mp4', '.mov', '.avi'):
        cap = cv2.VideoCapture(file.name)
        out_writer = None
        try:
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            # BUG FIX: broken metadata can report fps == 0, which makes
            # VideoWriter emit an unplayable file; fall back to 25 fps.
            fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
            out_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
            out_writer = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                result = detection_model.predict(frame, conf=conf_threshold, iou=iou_threshold)
                out_writer.write(result[0].plot())  # plot() is BGR, matching VideoWriter
        finally:
            # BUG FIX: release capture/writer even if prediction raises mid-stream.
            cap.release()
            if out_writer is not None:
                out_writer.release()
        return None, out_path

    # Unsupported extension: leave both output slots empty.
    return None, None
58
+
59
# === Multiple images prediction ===
def predict_multiple(files, conf_threshold, iou_threshold):
    """Annotate a batch of uploaded images.

    Returns:
        (annotated PIL images for the gallery, path to a zip of the saved
        annotations), or (None, None) when nothing was uploaded.
    """
    if not files:
        return None, None

    output_dir = tempfile.mkdtemp()
    annotated_images = []

    for upload in files:
        try:
            source = Image.open(upload).convert("RGB")
            detections = detection_model.predict(source, conf=conf_threshold, iou=iou_threshold)
            rgb_frame = detections[0].plot()[..., ::-1]  # plot() is BGR; flip to RGB
            annotated = Image.fromarray(rgb_frame)
            annotated.save(os.path.join(output_dir, os.path.basename(upload.name)))
            annotated_images.append(annotated)
        except Exception as e:
            # Best-effort batch: report the failure and continue with the rest.
            print(f"Failed to process {upload.name}: {e}")

    zip_path = shutil.make_archive(output_dir, 'zip', output_dir)
    return annotated_images, zip_path
81
+
82
# === Gradio Interfaces ===

def _threshold_sliders():
    """Build a fresh (confidence, IoU) slider pair — Gradio components are
    single-use, so each tab needs its own instances."""
    return [
        gr.Slider(0.1, 1.0, value=0.5, step=0.05, label="Confidence Threshold"),
        gr.Slider(0.1, 1.0, value=0.6, step=0.05, label="IoU Threshold"),
    ]

# Tab 1: Single Image or Video
single_file_tab = gr.Interface(
    fn=predict_single,
    inputs=[gr.File(file_types=["image", "video"], label="Upload Image or Video")] + _threshold_sliders(),
    outputs=[
        gr.Image(type="filepath", label="Detected Image"),
        gr.Video(label="Detected Video"),
    ],
    title="Detect from a Single Image or Video",
)

# Tab 2: Multiple Images
multi_image_tab = gr.Interface(
    fn=predict_multiple,
    inputs=[gr.File(file_types=["image"], file_count="multiple", label="Upload Multiple Images")] + _threshold_sliders(),
    outputs=[
        gr.Gallery(label="Detected Gallery", columns=3, height="auto"),
        gr.File(label="Download Annotated ZIP"),
    ],
    title="Batch Detect from Multiple Images",
)

# === Tabbed UI Launch ===
gr.TabbedInterface(
    [single_file_tab, multi_image_tab],
    tab_names=["Single File", "Multiple Images"],
).launch(share=True)