jameslahm committed on
Commit 1e8e71b
1 Parent(s): 8a3330e

Update app.py

Files changed (1)
  1. app.py +76 -30
app.py CHANGED
@@ -1,32 +1,61 @@
-import PIL.Image as Image
 import gradio as gr
-import spaces
-
+import cv2
+import tempfile
 from ultralytics import YOLOv10
+import spaces
 
 @spaces.GPU
-def predict_image(img, model_id, image_size, conf_threshold):
+def yolov10_inference(image, video, model_id, image_size, conf_threshold):
     model = YOLOv10.from_pretrained(f'jameslahm/{model_id}')
-    results = model.predict(
-        source=img,
-        conf=conf_threshold,
-        show_labels=True,
-        show_conf=True,
-        imgsz=image_size,
-    )
+    if image:
+        results = model.predict(source=image, imgsz=image_size, conf=conf_threshold)
+        annotated_image = results[0].plot()
+        return annotated_image[:, :, ::-1], None
+    else:
+        video_path = tempfile.mktemp(suffix=".webm")
+        with open(video_path, "wb") as f:
+            with open(video, "rb") as g:
+                f.write(g.read())
+
+        cap = cv2.VideoCapture(video_path)
+        fps = cap.get(cv2.CAP_PROP_FPS)
+        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+
+        output_video_path = tempfile.mktemp(suffix=".webm")
+        out = cv2.VideoWriter(output_video_path, cv2.VideoWriter_fourcc(*'vp80'), fps, (frame_width, frame_height))
+
+        while cap.isOpened():
+            ret, frame = cap.read()
+            if not ret:
+                break
 
-    for r in results:
-        im_array = r.plot()
-        im = Image.fromarray(im_array[..., ::-1])
+            results = model.predict(source=frame, imgsz=image_size, conf=conf_threshold)
+            annotated_frame = results[0].plot()
+            out.write(annotated_frame)
+
+        cap.release()
+        out.release()
+
+        return None, output_video_path
+
+@spaces.GPU
+def yolov10_inference_for_examples(image, model_path, image_size, conf_threshold):
+    annotated_image, _ = yolov10_inference(image, None, model_path, image_size, conf_threshold)
+    return annotated_image
 
-    return im
 
 def app():
     with gr.Blocks():
         with gr.Row():
             with gr.Column():
-                image = gr.Image(type="pil", label="Image")
-
+                image = gr.Image(type="pil", label="Image", visible=True)
+                video = gr.Video(label="Video", visible=False)
+                input_type = gr.Radio(
+                    choices=["Image", "Video"],
+                    value="Image",
+                    label="Input Type",
+                )
                 model_id = gr.Dropdown(
                     label="Model",
                     choices=[
@@ -56,35 +85,52 @@ def app():
                 yolov10_infer = gr.Button(value="Detect Objects")
 
             with gr.Column():
-                output_image = gr.Image(type="pil", label="Annotated Image")
+                output_image = gr.Image(type="numpy", label="Annotated Image", visible=True)
+                output_video = gr.Video(label="Annotated Video", visible=False)
+
+        def update_visibility(input_type):
+            image = gr.update(visible=True) if input_type == "Image" else gr.update(visible=False)
+            video = gr.update(visible=False) if input_type == "Image" else gr.update(visible=True)
+            output_image = gr.update(visible=True) if input_type == "Image" else gr.update(visible=False)
+            output_video = gr.update(visible=False) if input_type == "Image" else gr.update(visible=True)
+
+            return image, video, output_image, output_video
+
+        input_type.change(
+            fn=update_visibility,
+            inputs=[input_type],
+            outputs=[image, video, output_image, output_video],
+        )
+
+        def run_inference(image, video, model_id, image_size, conf_threshold, input_type):
+            if input_type == "Image":
+                return yolov10_inference(image, None, model_id, image_size, conf_threshold)
+            else:
+                return yolov10_inference(None, video, model_id, image_size, conf_threshold)
+
 
         yolov10_infer.click(
-            fn=predict_image,
-            inputs=[
-                image,
-                model_id,
-                image_size,
-                conf_threshold,
-            ],
-            outputs=[output_image],
+            fn=run_inference,
+            inputs=[image, video, model_id, image_size, conf_threshold, input_type],
+            outputs=[output_image, output_video],
        )
 
         gr.Examples(
             examples=[
                 [
-                    "bus.jpg",
+                    "ultralytics/assets/bus.jpg",
                     "yolov10s",
                     640,
                     0.25,
                 ],
                 [
-                    "zidane.jpg",
+                    "ultralytics/assets/zidane.jpg",
                     "yolov10s",
                     640,
                     0.25,
                 ],
             ],
-            fn=predict_image,
+            fn=yolov10_inference_for_examples,
             inputs=[
                 image,
                 model_id,
@@ -113,4 +159,4 @@ with gradio_app:
         with gr.Column():
             app()
 if __name__ == '__main__':
-    gradio_app.launch()
+    gradio_app.launch()
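
A minimal usage sketch (not part of the commit): calling the new yolov10_inference() entry point directly, outside the Gradio UI. It assumes the Space's dependencies (the YOLOv10 fork of ultralytics, the spaces package, opencv-python, Pillow) are installed, that app.py above is importable from the working directory, and that "bus.jpg" is any local test image; the image path and output filename are illustrative.

    # Minimal sketch, not part of the commit: drive the new yolov10_inference()
    # directly. "bus.jpg" is an assumed local test image.
    import PIL.Image as Image
    from app import yolov10_inference

    img = Image.open("bus.jpg")
    # Image branch of the function: returns (annotated RGB numpy array, None)
    annotated, video_path = yolov10_inference(
        image=img,
        video=None,
        model_id="yolov10s",
        image_size=640,
        conf_threshold=0.25,
    )
    assert video_path is None
    Image.fromarray(annotated).save("bus_annotated.jpg")

Passing image=None and a video file path instead exercises the video branch, which returns (None, path_to_annotated_webm).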