hca97 committed
Commit eedca6c
Parent: bc611ff

adding yolov8s model as well

.gitignore CHANGED
@@ -1,2 +1,3 @@
 env
-__pycache__
+__pycache__
+gradio_cached_examples
README.md CHANGED
@@ -23,6 +23,9 @@ The target species were:
 - **Culiseta** - Genus
 - **Aedes japonicus/Aedes koreicus** - Species complex (Differentiating between these two species is particularly challenging).
 
+> ***Note:** Only one mosquito will be annotated even if there are multiple mosquitoes in the image.*
+
+
 ## Experiment Details
 
 All the details regarding the experiments and source code for the models can be found in the [GitHub repository](https://github.com/HCA97/Mosquito-Classifiction/tree/main).
app.py CHANGED
@@ -4,37 +4,44 @@ import gradio as gr
 import numpy as np
 import cv2
 
-from my_models import YOLOV5CLIPModel
+from my_models import YOLOV5CLIPModel, YOLOV8CLIPModel
 
 
 def annotated_image(
     image: np.ndarray, label: str, conf: float, bbox: list
 ) -> np.ndarray:
-
     line_thickness = int(0.005 * max(image.shape[:2]))
     image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
-    image = cv2.rectangle(image, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (255, 0, 0), thickness=line_thickness)
-    image = cv2.putText(image,
-        f"{label} {conf:.2f}",
-        (bbox[0], max(bbox[1] - 2*line_thickness, 0)),
-        cv2.FONT_HERSHEY_SIMPLEX,
-        thickness=max(line_thickness//2, 1),
-        lineType=cv2.LINE_AA,
-        color=(0, 0, 0),
-        fontScale=0.1*line_thickness)
+    image = cv2.rectangle(
+        image,
+        (bbox[0], bbox[1]),
+        (bbox[2], bbox[3]),
+        (255, 0, 0),
+        thickness=line_thickness,
+    )
+    image = cv2.putText(
+        image,
+        f"{label} {conf:.2f}",
+        (bbox[0], max(bbox[1] - 2 * line_thickness, 0)),
+        cv2.FONT_HERSHEY_SIMPLEX,
+        thickness=max(line_thickness // 2, 1),
+        lineType=cv2.LINE_AA,
+        color=(0, 0, 0),
+        fontScale=0.1 * line_thickness,
+    )
     image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
 
     return image
 
 
 def detect_mosquito(image):
-    label, conf, bbox = YOLOV5CLIPModel().predict(image)
+    label, conf, bbox = YOLOV8CLIPModel().predict(image)
     return annotated_image(image, label, conf, bbox)
 
 
 description = """# [Mosquito Alert Competition 2023](https://www.aicrowd.com/challenges/mosquitoalert-challenge-2023) - 7th Place Solution
 
-Welcome to my Hugging Face Space showcasing the performance of our model.
+Welcome to my Hugging Face Space showcasing the performance of our model.
 
 This competition focused on detecting and classifying various mosquito species.
 
@@ -46,14 +53,20 @@ The target species were:
 - **Culiseta** - Genus
 - **Aedes japonicus/Aedes koreicus** - Species complex (Differentiating between these two species is particularly challenging).
 
+> ***Note:** Only one mosquito will be annotated even if there are multiple mosquitoes in the image.*
+
 ## Experiment Details
 
 All the details regarding the experiments and source code for the models can be found in the [GitHub repository](https://github.com/HCA97/Mosquito-Classifiction/tree/main).
 """
 
 iface = gr.Interface(
-    fn=detect_mosquito, description=description, inputs=gr.Image(), outputs=gr.Image(), allow_flagging="never",
+    fn=detect_mosquito,
+    description=description,
+    inputs=gr.Image(),
+    outputs=gr.Image(),
+    allow_flagging="never",
     examples=[os.path.join("examples", f) for f in os.listdir("examples")],
-    cache_examples=True
+    cache_examples=True,
 )
 iface.launch()
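
For reference, the updated pipeline can also be exercised outside Gradio. A minimal sketch: `example.jpg` is a hypothetical path, and the RGB numpy array mirrors what `gr.Image()` hands to `detect_mosquito`:

```python
# Minimal sketch (assumed local usage): load an image as RGB and run the
# new YOLOv8 + CLIP pipeline directly.
import cv2

from my_models import YOLOV8CLIPModel

image = cv2.cvtColor(cv2.imread("example.jpg"), cv2.COLOR_BGR2RGB)  # hypothetical file
label, conf, bbox = YOLOV8CLIPModel().predict(image)
print(label, conf, bbox)  # species name, confidence, [x1, y1, x2, y2]
```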
my_models/__init__.py CHANGED
@@ -1 +1,2 @@
 from .yolov5_clip_model import YOLOV5CLIPModel
+from .yolov8_clip_model import YOLOV8CLIPModel
my_models/yolo_weights/best-yolov8-s.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:36080c806f3a8b501bc52c126d68427c61dd15b8dcff1423ae35163588a09583
+size 22484974
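
The `.pt` file is stored as a Git LFS pointer, so the repository only records the weights' hash and size. Once the actual weights have been fetched (for example with `git lfs pull`), the pointer's `oid` doubles as a checksum; a small sketch of that sanity check:

```python
# Verify the downloaded YOLOv8 weights against the sha256 oid in the LFS pointer.
import hashlib

path = "my_models/yolo_weights/best-yolov8-s.pt"
with open(path, "rb") as f:
    digest = hashlib.sha256(f.read()).hexdigest()
assert digest == "36080c806f3a8b501bc52c126d68427c61dd15b8dcff1423ae35163588a09583"
```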
my_models/yolov8_clip_model.py ADDED
@@ -0,0 +1,135 @@
+from ultralytics import YOLO
+import numpy as np
+import time
+
+import torch
+
+torch.set_num_threads(2)
+
+from my_models.clip_model.data_loader import pre_process_foo
+from my_models.clip_model.classification import MosquitoClassifier
+
+IMG_SIZE = (224, 224)
+USE_CHANNEL_LAST = False
+DATASET = "laion"
+DEVICE = "cpu"
+PRESERVE_ASPECT_RATIO = False
+SHIFT = 0
+
+
+@torch.no_grad()
+def classify_image(det: YOLO, cls: MosquitoClassifier, image: np.ndarray):
+    s = time.time()
+    labels = [
+        "albopictus",
+        "culex",
+        "japonicus-koreicus",
+        "culiseta",
+        "anopheles",
+        "aegypti",
+    ]
+
+    results = det(image, verbose=True, device=DEVICE, max_det=1)  # single best box
+    img_h, img_w, _ = image.shape  # numpy images are (height, width, channels)
+    bbox = [0, 0, img_w, img_h]  # fall back to the whole image
+    label = "albopictus"
+    conf = 0.0
+
+    for result in results:
+        _bbox = [0, 0, img_w, img_h]
+        _label = "albopictus"
+        _conf = 0.0
+
+        bboxes_tmp = result.boxes.xyxy.tolist()
+        labels_tmp = result.boxes.cls.tolist()
+        confs_tmp = result.boxes.conf.tolist()
+
+        for bbox_tmp, label_tmp, conf_tmp in zip(bboxes_tmp, labels_tmp, confs_tmp):
+            if conf_tmp > _conf:  # keep the highest-confidence detection
+                _bbox = bbox_tmp
+                _label = labels[int(label_tmp)]
+                _conf = conf_tmp
+
+        if _conf > conf:
+            bbox = _bbox
+            label = _label
+            conf = _conf
+
+    bbox = [int(float(mcb)) for mcb in bbox]
+
+    try:
+        if conf < 1e-4:
+            raise Exception("no mosquito detected")
+        image_cropped = image[bbox[1] : bbox[3], bbox[0] : bbox[2], :]
+        bbox = [bbox[0] + SHIFT, bbox[1] + SHIFT, bbox[2] - SHIFT, bbox[3] - SHIFT]
+    except Exception as e:
+        print("Error", e)
+        image_cropped = image  # classify the whole image instead
+
+    if PRESERVE_ASPECT_RATIO:
+        w, h = image_cropped.shape[:2]
+        if w > h:
+            x = torch.unsqueeze(
+                pre_process_foo(
+                    (IMG_SIZE[0], max(int(IMG_SIZE[1] * h / w), 32)), DATASET
+                )(image_cropped),
+                0,
+            )
+        else:
+            x = torch.unsqueeze(
+                pre_process_foo(
+                    (max(int(IMG_SIZE[0] * w / h), 32), IMG_SIZE[1]), DATASET
+                )(image_cropped),
+                0,
+            )
+    else:
+        x = torch.unsqueeze(pre_process_foo(IMG_SIZE, DATASET)(image_cropped), 0)
+
+    x = x.to(device=DEVICE)
+
+    if USE_CHANNEL_LAST:
+        p = cls(x.to(memory_format=torch.channels_last))
+    else:
+        p = cls(x)
+    ind = torch.argmax(p).item()
+    label = labels[ind]
+
+    e = time.time()
+
+    print("Time ", 1000 * (e - s), "ms")
+    return {"name": label, "confidence": p.max().item(), "bbox": bbox}
+
+
+# getting mosquito_class name from predicted result
+def extract_predicted_mosquito_class_name(extractedInformation):
+    return extractedInformation.get("name", "albopictus")
+
+
+def extract_predicted_mosquito_bbox(extractedInformation):
+    return extractedInformation.get("bbox", [0, 0, 0, 0])
+
+
+class YOLOV8CLIPModel:
+    def __init__(self):
+        trained_model_path = "my_models/yolo_weights/best-yolov8-s.pt"
+        clip_model_path = "my_models/clip_weights/best_clf.ckpt"
+        self.det = YOLO(trained_model_path, task="detect")
+
+        self.cls = MosquitoClassifier.load_from_checkpoint(
+            clip_model_path, head_version=7, map_location=torch.device(DEVICE)
+        ).eval()
+
+        if USE_CHANNEL_LAST:
+            self.cls.to(memory_format=torch.channels_last)
+
+    def predict(self, image):
+        predictedInformation = classify_image(self.det, self.cls, image)
+
+        mosquito_class_name_predicted = extract_predicted_mosquito_class_name(
+            predictedInformation
+        )
+        mosquito_class_bbox = extract_predicted_mosquito_bbox(predictedInformation)
+
+        bbox = [int(float(mcb)) for mcb in mosquito_class_bbox]
+
+        return mosquito_class_name_predicted, predictedInformation["confidence"], bbox
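
The two stages can also be wired together by hand via `classify_image`, which returns a dict rather than a tuple. A minimal sketch, assuming the weight files committed above are available locally and `example.jpg` is a placeholder input:

```python
# Run detector + classifier manually and inspect the raw result dict.
import cv2
import torch
from ultralytics import YOLO

from my_models.clip_model.classification import MosquitoClassifier
from my_models.yolov8_clip_model import classify_image

det = YOLO("my_models/yolo_weights/best-yolov8-s.pt", task="detect")
cls = MosquitoClassifier.load_from_checkpoint(
    "my_models/clip_weights/best_clf.ckpt",
    head_version=7,
    map_location=torch.device("cpu"),
).eval()

image = cv2.cvtColor(cv2.imread("example.jpg"), cv2.COLOR_BGR2RGB)  # placeholder
result = classify_image(det, cls, image)
print(result)  # {"name": ..., "confidence": ..., "bbox": [x1, y1, x2, y2]}
```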