VlaTal commited on
Commit
807a8fb
1 Parent(s): 3f72a26

added model and threshold choosing

Browse files
app.py CHANGED
@@ -1,30 +1,57 @@
 
1
  import gradio as gr
2
  from ultralytics import YOLO
3
  import numpy as np
4
- import os
5
-
6
- # Load YOLO model
7
- model = YOLO('./best.pt')
8
 
 
 
 
9
  example_list = [["examples/" + example] for example in os.listdir("examples")]
10
 
11
- def process_image(input_image):
12
- if input_image is not None:
13
- results = model(input_image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
- for r in results:
16
- im_array = r.plot()
17
- im_array = im_array.astype(np.uint8)
18
- return im_array
 
 
 
 
 
 
 
19
 
20
- # Create Gradio Interface
21
  iface = gr.Interface(
22
  fn=process_image,
23
- inputs=gr.Image(),
24
- outputs=gr.Image(), # Specify output as Gradio Image
25
- title="YOLOv8-obb ships detection",
26
- description="YOLOv8-obb trained on ShipRSImageNet_BAL-2",
27
- examples=example_list)
28
-
29
- # Launch the Gradio interface
30
- iface.launch()
 
 
 
 
 
1
+ import os
2
  import gradio as gr
3
  from ultralytics import YOLO
4
  import numpy as np
 
 
 
 
5
 
6
+ model_options = ["yolo-8n-shiprs.pt", "yolo-8s-shiprs.pt", "yolo-8m-shiprs.pt"]
7
+ model_names = ["Nano", "Small", "Medium"]
8
+ models = [YOLO(option) for option in model_options]
9
  example_list = [["examples/" + example] for example in os.listdir("examples")]
10
 
11
+ def process_image(input_image, model_name, conf):
12
+ if input_image is None:
13
+ return None, "No image"
14
+
15
+ if model_name is None:
16
+ model_name = model_names[0]
17
+
18
+ if conf is None:
19
+ conf = 0.6
20
+
21
+ model_index = model_names.index(model_name)
22
+ model = models[model_index]
23
+
24
+ results = model.predict(input_image, conf=conf)
25
+ class_counts = {}
26
+ class_counts_str = "Class Counts:\n"
27
+
28
+ for r in results:
29
+ im_array = r.plot()
30
+ im_array = im_array.astype(np.uint8)
31
 
32
+ if len(r.obb.cls) == 0: # If no objects are detected
33
+ return None, "No objects detected."
34
+
35
+ for cls in r.obb.cls:
36
+ class_name = r.names[cls.item()]
37
+ class_counts[class_name] = class_counts.get(class_name, 0) + 1
38
+
39
+ for cls, count in class_counts.items():
40
+ class_counts_str += f"\n{cls}: {count}"
41
+
42
+ return im_array, class_counts_str
43
 
 
44
  iface = gr.Interface(
45
  fn=process_image,
46
+ inputs=[
47
+ gr.Image(),
48
+ gr.Radio(model_names, label="Choose model", value=model_names[0]),
49
+ gr.Slider(minimum=0.2, maximum=1.0, step=0.1, label="Confidence Threshold", value=0.6)
50
+ ],
51
+ outputs=["image", gr.Textbox(label="More info")],
52
+ title="YOLOv8-obb aerial detection",
53
+ description='''YOLOv8-obb trained on DOTAv1.5''',
54
+ examples=example_list
55
+ )
56
+
57
+ iface.launch()
best.pt → yolo-8m-shiprs.pt RENAMED
File without changes
yolo-8n-shiprs.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5734148e15a344cf3b629a8cceed69d86ce84f2f475bcf5bf1b6e2013c857878
3
+ size 6567938
yolo-8s-shiprs.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:55d73e67f64549bbdcbfdc3785bd67d20fecf8580f3ab1171561e382b49cd954
3
+ size 23283778