nehulagrawal commited on
Commit
eac9d61
1 Parent(s): d341b67

Update detection.py

Browse files
Files changed (1) hide show
  1. detection.py +72 -62
detection.py CHANGED
@@ -1,63 +1,73 @@
1
- import cv2
2
- import IPython
3
- from PIL import ImageColor
4
- from ultralytics import YOLO
5
-
6
- class ObjectDetection:
7
- def __init__(self, model_name='Yolov8'):
8
- self.model_name = model_name
9
- self.model = self.load_model()
10
- self.classes = self.model.names
11
- self.device = 'cpu'
12
-
13
- def load_model(self):
14
- model = YOLO(f"weights/{self.model_name}_best.pt")
15
- return model
16
-
17
- def v8_score_frame(self, frame):
18
- results = self.model(frame)
19
-
20
- labels = []
21
- confidences = []
22
- coords = []
23
-
24
- for result in results:
25
- boxes = result.boxes.cpu().numpy()
26
-
27
- label = boxes.cls
28
- conf = boxes.conf
29
- coord = boxes.xyxy
30
-
31
- labels.extend(label)
32
- confidences.extend(conf)
33
- coords.extend(coord)
34
-
35
- return labels, confidences, coords
36
-
37
- def get_coords(self, frame, row):
38
- return int(row[0]), int(row[1]), int(row[2]), int(row[3])
39
-
40
- def class_to_label(self, x):
41
- return self.classes[int(x)]
42
-
43
- def get_color(self, code):
44
- rgb = ImageColor.getcolor(code, "RGB")
45
- return rgb
46
-
47
- def plot_bboxes(self, results, frame, threshold=0.5, box_color='red', text_color='white'):
48
- labels, conf, coord = results
49
-
50
- frame = frame.copy()
51
- box_color = self.get_color(box_color)
52
- text_color = self.get_color(text_color)
53
-
54
- for i in range(len(labels)):
55
- if conf[i] >= threshold:
56
- x1, y1, x2, y2 = self.get_coords(frame, coord[i])
57
- class_name = self.class_to_label(labels[i])
58
-
59
- cv2.rectangle(frame, (x1, y1), (x2, y2), box_color, 2)
60
- cv2.putText(frame, f"{class_name} - {conf[i]*100:.2f}%", (x1, y1), cv2.FONT_HERSHEY_COMPLEX, 0.5, text_color)
61
-
62
- return frame
63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from detection import ObjectDetection
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
+ examples = [
5
+ ['test-images/plant1.jpeg', 0.23],
6
+ ['test-images/plant2.jpeg', 0.45],
7
+ ['test-images/plant3.webp', 0.43],
8
+ ]
9
+
10
+ def get_predictions(img, threshold, box_color, text_color):
11
+ v8_results = yolov8_detector.v8_score_frame(img)
12
+ v8_frame = yolov8_detector.plot_bboxes(v8_results, img, float(threshold), box_color, text_color)
13
+ return v8_frame
14
+
15
+ # Load the YOLOv8 model for plant leaf detection and classification
16
+ yolov8_detector = ObjectDetection('Yolov8')
17
+
18
+ interface = gr.Interface(
19
+ fn=get_predictions,
20
+ inputs=[
21
+ gr.Image(shape=(824, 824), label="Input Image"),
22
+ gr.Slider(maximum=1, step=0.01, value=0.4, label="Confidence Threshold", interactive=True),
23
+ gr.ColorPicker(label="Box Color", value="#FF8C00"),
24
+ gr.ColorPicker(label="Prediction Color", value="#000000"),
25
+ ],
26
+ outputs=gr.Image(label="YOLOv8 Prediction"),
27
+ examples=examples,
28
+ live=True,
29
+ title="Plant Leaf Detection and Classification",
30
+ )
31
+
32
+ # Custom CSS to create a dark mode appearance
33
+ custom_css = """
34
+ <style>
35
+ body {
36
+ background-color: #222222;
37
+ color: #FFFFFF;
38
+ }
39
+
40
+ h1, h2, h3, h4, h5, h6 {
41
+ color: #FF8C00;
42
+ }
43
+
44
+ .gradio-interface {
45
+ border: 1px solid #FF8C00;
46
+ border-radius: 10px;
47
+ box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2);
48
+ }
49
+
50
+ .gradio-interface > .title {
51
+ background-color: #FF8C00;
52
+ color: #FFFFFF;
53
+ padding: 12px;
54
+ border-top-left-radius: 10px;
55
+ border-top-right-radius: 10px;
56
+ }
57
+
58
+ .gradio-interface > .content {
59
+ padding: 20px;
60
+ }
61
+
62
+ .gradio-interface > .footer {
63
+ background-color: #FF8C00;
64
+ color: #FFFFFF;
65
+ padding: 12px;
66
+ border-bottom-left-radius: 10px;
67
+ border-bottom-right-radius: 10px;
68
+ }
69
+ </style>
70
+ """
71
+
72
+ # Inject custom CSS into the interface
73
+ interface.launch(share=False, custom_css=custom_css)