sensura commited on
Commit
5c5412c
·
verified ·
1 Parent(s): 0a43a36

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -40
app.py CHANGED
@@ -5,62 +5,90 @@ from huggingface_hub import snapshot_download
5
  import os
6
  import tempfile
7
  import cv2
 
 
8
 
9
  def load_model(repo_id):
10
  download_dir = snapshot_download(repo_id)
11
  path = os.path.join(download_dir, "best_int8_openvino_model")
12
  return YOLO(path, task='detect')
13
 
14
- def predict_image(pilimg, conf_threshold, iou_threshold):
15
- result = detection_model.predict(pilimg, conf=conf_threshold, iou=iou_threshold)
16
- img_bgr = result[0].plot()
17
- out_img = Image.fromarray(img_bgr[..., ::-1]) # Convert BGR to RGB for PIL
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
- # Save to temp file and return path
20
- tmp_img = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
21
- out_img.save(tmp_img.name)
22
- return tmp_img.name
 
 
 
23
 
24
- def predict_video(video_path, conf_threshold, iou_threshold):
25
- cap = cv2.VideoCapture(video_path)
26
- width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
27
- height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
28
- fps = cap.get(cv2.CAP_PROP_FPS)
29
-
30
- out_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
31
- out_writer = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
32
-
33
- while cap.isOpened():
34
- ret, frame = cap.read()
35
- if not ret:
36
- break
37
- result = detection_model.predict(frame, conf=conf_threshold, iou=iou_threshold)
38
- annotated = result[0].plot()
39
- out_writer.write(annotated)
40
 
41
- cap.release()
42
- out_writer.release()
43
- return out_path
 
 
 
 
 
 
 
 
44
 
45
- def predict(file, conf_threshold, iou_threshold):
46
- ext = os.path.splitext(file.name)[1].lower()
47
- if ext in ['.jpg', '.jpeg', '.png']:
48
- img = Image.open(file).convert("RGB")
49
- return predict_image(img, conf_threshold, iou_threshold)
50
- elif ext in ['.mp4', '.mov', '.avi']:
51
- return predict_video(file.name, conf_threshold, iou_threshold)
52
- else:
53
- return "Unsupported file type. Please upload an image or video."
54
 
55
  REPO_ID = "sensura/belisha-beacon-zebra-crossing-yoloV8"
56
  detection_model = load_model(REPO_ID)
57
 
58
- gr.Interface(
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  fn=predict,
60
  inputs=[
61
- gr.File(label="Upload Image or Video"),
62
  gr.Slider(0.1, 1.0, 0.5, step=0.05, label="Confidence Threshold"),
63
  gr.Slider(0.1, 1.0, 0.6, step=0.05, label="IoU Threshold")
64
  ],
65
- outputs=gr.File(label="Processed Output")
66
- ).launch(share=True)
 
 
 
 
5
  import os
6
  import tempfile
7
  import cv2
8
+ import zipfile
9
+ import shutil
10
 
11
  def load_model(repo_id):
12
  download_dir = snapshot_download(repo_id)
13
  path = os.path.join(download_dir, "best_int8_openvino_model")
14
  return YOLO(path, task='detect')
15
 
16
+ def predict(files, conf_threshold, iou_threshold):
17
+ if len(files) == 1:
18
+ ext = os.path.splitext(files[0].name)[1].lower()
19
+ if ext in ['.jpg', '.jpeg', '.png']:
20
+ img = Image.open(files[0]).convert("RGB")
21
+ result = detection_model.predict(img, conf=conf_threshold, iou=iou_threshold)
22
+ img_bgr = result[0].plot()
23
+ out_img = Image.fromarray(img_bgr[..., ::-1])
24
+ tmp_path = tempfile.NamedTemporaryFile(suffix=".png", delete=False).name
25
+ out_img.save(tmp_path)
26
+ return tmp_path
27
+ elif ext in ['.mp4', '.mov', '.avi']:
28
+ cap = cv2.VideoCapture(files[0].name)
29
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
30
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
31
+ fps = cap.get(cv2.CAP_PROP_FPS)
32
+ out_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
33
+ out_writer = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
34
 
35
+ while cap.isOpened():
36
+ ret, frame = cap.read()
37
+ if not ret:
38
+ break
39
+ result = detection_model.predict(frame, conf=conf_threshold, iou=iou_threshold)
40
+ annotated = result[0].plot()
41
+ out_writer.write(annotated)
42
 
43
+ cap.release()
44
+ out_writer.release()
45
+ return out_path
46
+ else:
47
+ return "Unsupported file type."
48
+ else:
49
+ output_dir = tempfile.mkdtemp()
50
+ annotated_images = []
 
 
 
 
 
 
 
 
51
 
52
+ for file in files:
53
+ try:
54
+ img = Image.open(file).convert("RGB")
55
+ result = detection_model.predict(img, conf=conf_threshold, iou=iou_threshold)
56
+ img_bgr = result[0].plot()
57
+ out_img = Image.fromarray(img_bgr[..., ::-1])
58
+ out_path = os.path.join(output_dir, os.path.basename(file.name))
59
+ out_img.save(out_path)
60
+ annotated_images.append(out_img)
61
+ except Exception as e:
62
+ print(f"Failed to process {file.name}: {e}")
63
 
64
+ zip_path = shutil.make_archive(output_dir, 'zip', output_dir)
65
+ return annotated_images, zip_path
 
 
 
 
 
 
 
66
 
67
  REPO_ID = "sensura/belisha-beacon-zebra-crossing-yoloV8"
68
  detection_model = load_model(REPO_ID)
69
 
70
+ def dynamic_output(file_list):
71
+ if len(file_list) == 1:
72
+ ext = os.path.splitext(file_list[0].name)[1].lower()
73
+ if ext in ['.jpg', '.jpeg', '.png']:
74
+ return gr.Image(type="filepath", label="Annotated Image")
75
+ elif ext in ['.mp4', '.mov', '.avi']:
76
+ return gr.Video(label="Annotated Video")
77
+ else:
78
+ return [
79
+ gr.Gallery(label="Annotated Images").style(grid=3, height="auto"),
80
+ gr.File(label="Download All Annotated Images (ZIP)")
81
+ ]
82
+
83
+ interface = gr.Interface(
84
  fn=predict,
85
  inputs=[
86
+ gr.File(file_types=["image", "video"], file_count="multiple", label="Upload Image(s) or Video"),
87
  gr.Slider(0.1, 1.0, 0.5, step=0.05, label="Confidence Threshold"),
88
  gr.Slider(0.1, 1.0, 0.6, step=0.05, label="IoU Threshold")
89
  ],
90
+ outputs=dynamic_output,
91
+ live=False
92
+ )
93
+
94
+ interface.launch(share=True)