flinta commited on
Commit
a611b4e
1 Parent(s): a7390e6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +140 -72
app.py CHANGED
@@ -1,77 +1,145 @@
1
  import gradio as gr
2
- import torch
3
- from ultralyticsplus import YOLO, render_result
4
 
5
 
6
- torch.hub.download_url_to_file(
7
- 'https://external-content.duckduckgo.com/iu/?u=https%3A%2F%2Ftexashafts.com%2Fwp-content%2Fuploads%2F2016%2F04%2Fconstruction-worker.jpg', 'one.jpg')
8
- torch.hub.download_url_to_file(
9
- 'https://www.pearsonkoutcherlaw.com/wp-content/uploads/2020/06/Construction-Workers.jpg', 'two.jpg')
10
- torch.hub.download_url_to_file(
11
- 'https://nssgroup.com/wp-content/uploads/2019/02/Building-maintenance-blog.jpg', 'three.jpg')
12
 
13
-
14
- def yoloV8_func(image: gr.inputs.Image = None,
15
- image_size: gr.inputs.Slider = 640,
16
- conf_threshold: gr.inputs.Slider = 0.4,
17
- iou_threshold: gr.inputs.Slider = 0.50):
18
- """This function performs YOLOv8 object detection on the given image.
19
-
20
- Args:
21
- image (gr.inputs.Image, optional): Input image to detect objects on. Defaults to None.
22
- image_size (gr.inputs.Slider, optional): Desired image size for the model. Defaults to 640.
23
- conf_threshold (gr.inputs.Slider, optional): Confidence threshold for object detection. Defaults to 0.4.
24
- iou_threshold (gr.inputs.Slider, optional): Intersection over Union threshold for object detection. Defaults to 0.50.
25
  """
26
- # Load the YOLOv8 model from the 'best.pt' checkpoint
27
- model_path = "best.pt"
28
- model = YOLO(model_path)
29
-
30
- # Perform object detection on the input image using the YOLOv8 model
31
- results = model.predict(image,
32
- conf=conf_threshold,
33
- iou=iou_threshold,
34
- imgsz=image_size)
35
-
36
- # Print the detected objects' information (class, coordinates, and probability)
37
- box = results[0].boxes
38
- print("Object type:", box.cls)
39
- print("Coordinates:", box.xyxy)
40
- print("Probability:", box.conf)
41
-
42
- # Render the output image with bounding boxes around detected objects
43
- render = render_result(model=model, image=image, result=results[0])
44
- return render
45
-
46
-
47
- inputs = [
48
- gr.inputs.Image(type="filepath", label="Input Image"),
49
- gr.inputs.Slider(minimum=320, maximum=1280, default=640,
50
- step=32, label="Image Size"),
51
- gr.inputs.Slider(minimum=0.0, maximum=1.0, default=0.001,
52
- step=0.05, label="Confidence Threshold"),
53
- gr.inputs.Slider(minimum=0.0, maximum=1.0, default=0.5,
54
- step=0.05, label="IOU Threshold"),
55
- ]
56
-
57
-
58
- outputs = gr.outputs.Image(type="filepath", label="Output Image")
59
-
60
- title = "YOLOv8 101: Custom Object Detection on Construction Workers"
61
-
62
-
63
- examples = [['one.jpg', 640, 0.5, 0.7],
64
- ['two.jpg', 800, 0.5, 0.6],
65
- ['three.jpg', 900, 0.5, 0.8]]
66
-
67
- yolo_app = gr.Interface(
68
- fn=yoloV8_func,
69
- inputs=inputs,
70
- outputs=outputs,
71
- title=title,
72
- examples=examples,
73
- cache_examples=True,
74
- )
75
-
76
- # Launch the Gradio interface in debug mode with queue enabled
77
- yolo_app.launch(debug=True, enable_queue=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import spaces
3
+ from huggingface_hub import hf_hub_download
4
 
5
 
6
+ def download_models(model_id):
7
+ hf_hub_download("flinta/test", filename=f"{model_id}", local_dir=f"./")
8
+ return f"./{model_id}"
 
 
 
9
 
10
+ @spaces.GPU
11
+ def yolov9_inference(img_path, model_id, image_size, conf_threshold, iou_threshold):
 
 
 
 
 
 
 
 
 
 
12
  """
13
+ Load a YOLOv9 model, configure it, perform inference on an image, and optionally adjust
14
+ the input size and apply test time augmentation.
15
+
16
+ :param model_path: Path to the YOLOv9 model file.
17
+ :param conf_threshold: Confidence threshold for NMS.
18
+ :param iou_threshold: IoU threshold for NMS.
19
+ :param img_path: Path to the image file.
20
+ :param size: Optional, input size for inference.
21
+ :return: A tuple containing the detections (boxes, scores, categories) and the results object for further actions like displaying.
22
+ """
23
+ # Import YOLOv9
24
+ import yolov9
25
+
26
+ # Load the model
27
+ model_path = download_models(model_id)
28
+ model = yolov9.load(model_path, device="cuda:0")
29
+
30
+ # Set model parameters
31
+ model.conf = conf_threshold
32
+ model.iou = iou_threshold
33
+
34
+ # Perform inference
35
+ results = model(img_path, size=image_size)
36
+
37
+ # Optionally, show detection bounding boxes on image
38
+ output = results.render()
39
+
40
+ return output[0]
41
+
42
+
43
+ def app():
44
+ with gr.Blocks():
45
+ with gr.Row():
46
+ with gr.Column():
47
+ img_path = gr.Image(type="filepath", label="Image")
48
+ model_path = gr.Dropdown(
49
+ label="Model",
50
+ choices=[
51
+ "gelan-c.pt",
52
+ "gelan-e.pt",
53
+ "yolov9-c.pt",
54
+ "yolov9-e.pt",
55
+ ],
56
+ value="gelan-e.pt",
57
+ )
58
+ image_size = gr.Slider(
59
+ label="Image Size",
60
+ minimum=320,
61
+ maximum=1280,
62
+ step=32,
63
+ value=640,
64
+ )
65
+ conf_threshold = gr.Slider(
66
+ label="Confidence Threshold",
67
+ minimum=0.1,
68
+ maximum=1.0,
69
+ step=0.1,
70
+ value=0.4,
71
+ )
72
+ iou_threshold = gr.Slider(
73
+ label="IoU Threshold",
74
+ minimum=0.1,
75
+ maximum=1.0,
76
+ step=0.1,
77
+ value=0.5,
78
+ )
79
+ yolov9_infer = gr.Button(value="Inference")
80
+
81
+ with gr.Column():
82
+ output_numpy = gr.Image(type="numpy",label="Output")
83
+
84
+ yolov9_infer.click(
85
+ fn=yolov9_inference,
86
+ inputs=[
87
+ img_path,
88
+ model_path,
89
+ image_size,
90
+ conf_threshold,
91
+ iou_threshold,
92
+ ],
93
+ outputs=[output_numpy],
94
+ )
95
+
96
+ gr.Examples(
97
+ examples=[
98
+ [
99
+ "data/zidane.jpg",
100
+ "gelan-e.pt",
101
+ 640,
102
+ 0.4,
103
+ 0.5,
104
+ ],
105
+ [
106
+ "data/huggingface.jpg",
107
+ "yolov9-c.pt",
108
+ 640,
109
+ 0.4,
110
+ 0.5,
111
+ ],
112
+ ],
113
+ fn=yolov9_inference,
114
+ inputs=[
115
+ img_path,
116
+ model_path,
117
+ image_size,
118
+ conf_threshold,
119
+ iou_threshold,
120
+ ],
121
+ outputs=[output_numpy],
122
+ cache_examples=True,
123
+ )
124
+
125
+
126
+ gradio_app = gr.Blocks()
127
+ with gradio_app:
128
+ gr.HTML(
129
+ """
130
+ <h1 style='text-align: center'>
131
+ YOLOv9: Learning What You Want to Learn Using Programmable Gradient Information
132
+ </h1>
133
+ """)
134
+ gr.HTML(
135
+ """
136
+ <h3 style='text-align: center'>
137
+ Follow me for more!
138
+ <a href='https://twitter.com/kadirnar_ai' target='_blank'>Twitter</a> | <a href='https://github.com/kadirnar' target='_blank'>Github</a> | <a href='https://www.linkedin.com/in/kadir-nar/' target='_blank'>Linkedin</a> | <a href='https://www.huggingface.co/kadirnar/' target='_blank'>HuggingFace</a>
139
+ </h3>
140
+ """)
141
+ with gr.Row():
142
+ with gr.Column():
143
+ app()
144
+
145
+ gradio_app.launch(debug=True)