cpraschl commited on
Commit
6f74262
·
verified ·
1 Parent(s): 5dbce6d

Upload inference.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. inference.py +160 -0
inference.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Wildlife Detection with YOLOv26 — Inference Script
3
+ ===================================================
4
+ Supports RGB and thermal drone imagery.
5
+
6
+ Usage:
7
+ python inference.py --model rgb --source path/to/image.jpg
8
+ python inference.py --model thermal_merged --source path/to/thermal/ --save
9
+ python inference.py --model matched_rgb --source image.jpg --conf 0.3 --show
10
+
11
+ Available models:
12
+ thermal_original — Baseline thermal model
13
+ thermal_merged — Refined thermal model (more training data)
14
+ rgb — Primary RGB model
15
+ matched_rgb — RGB model trained on matched RGB/thermal pairs
16
+ matched_thermal — Thermal model trained on matched RGB/thermal pairs
17
+ """
18
+
19
+ import argparse
20
+ from pathlib import Path
21
+ from ultralytics import YOLO
22
+
23
+
24
+ MODELS = {
25
+ "thermal_original": "thermal_original/weights/best.pt",
26
+ "thermal_merged": "thermal_merged/weights/best.pt",
27
+ "rgb": "rgb/weights/best.pt",
28
+ "matched_rgb": "matched_rgb/weights/best.pt",
29
+ "matched_thermal": "matched_thermal/weights/best.pt",
30
+ }
31
+
32
+
33
+ def load_model(name: str) -> YOLO:
34
+ """Load a model by name or direct path."""
35
+ path = MODELS.get(name, name)
36
+ print(f"Loading model: {path}")
37
+ return YOLO(path)
38
+
39
+
40
+ def run_inference(
41
+ model_name: str = "rgb",
42
+ source: str = "0",
43
+ imgsz: int = 1024,
44
+ conf: float = 0.25,
45
+ iou: float = 0.45,
46
+ show: bool = False,
47
+ save: bool = False,
48
+ save_txt: bool = False,
49
+ project: str = "detections",
50
+ name: str = "predict",
51
+ device: str = "",
52
+ ):
53
+ """Run inference and return results."""
54
+ model = load_model(model_name)
55
+
56
+ results = model.predict(
57
+ source=source,
58
+ imgsz=imgsz,
59
+ conf=conf,
60
+ iou=iou,
61
+ show=show,
62
+ save=save,
63
+ save_txt=save_txt,
64
+ project=project,
65
+ name=name,
66
+ device=device if device else None,
67
+ )
68
+
69
+ for i, result in enumerate(results):
70
+ n = len(result.boxes)
71
+ print(f"[Image {i+1}] {n} detection(s)")
72
+ for box in result.boxes:
73
+ cls_id = int(box.cls.item())
74
+ cls_name = result.names[cls_id]
75
+ conf_val = box.conf.item()
76
+ xyxy = [round(v, 1) for v in box.xyxy[0].tolist()]
77
+ print(f" {cls_name:15s} conf={conf_val:.2f} bbox={xyxy}")
78
+
79
+ return results
80
+
81
+
82
+ def compare_modalities(
83
+ rgb_source: str,
84
+ thermal_source: str,
85
+ conf: float = 0.25,
86
+ imgsz: int = 1024,
87
+ ):
88
+ """
89
+ Compare RGB vs thermal detections on co-registered image pairs.
90
+ Useful for the matched dataset experiments.
91
+ """
92
+ rgb_model = load_model("matched_rgb")
93
+ thermal_model = load_model("matched_thermal")
94
+
95
+ rgb_results = rgb_model.predict(rgb_source, imgsz=imgsz, conf=conf, verbose=False)
96
+ thermal_results = thermal_model.predict(thermal_source, imgsz=imgsz, conf=conf, verbose=False)
97
+
98
+ for i, (r_rgb, r_thm) in enumerate(zip(rgb_results, thermal_results)):
99
+ print(f"\n--- Pair {i+1} ---")
100
+ print(f" RGB detections: {len(r_rgb.boxes)}")
101
+ print(f" Thermal detections: {len(r_thm.boxes)}")
102
+
103
+ return rgb_results, thermal_results
104
+
105
+
106
+ def main():
107
+ parser = argparse.ArgumentParser(description="Wildlife YOLOv26 Inference")
108
+ parser.add_argument(
109
+ "--model",
110
+ default="rgb",
111
+ choices=list(MODELS.keys()) + ["custom"],
112
+ help="Model to use. Pass a file path with --model custom --weights <path>.",
113
+ )
114
+ parser.add_argument("--weights", default=None, help="Direct path to .pt weights file.")
115
+ parser.add_argument("--source", default="0", help="Image/video/folder path or webcam index.")
116
+ parser.add_argument("--imgsz", type=int, default=1024, help="Inference image size.")
117
+ parser.add_argument("--conf", type=float, default=0.25, help="Confidence threshold.")
118
+ parser.add_argument("--iou", type=float, default=0.45, help="NMS IoU threshold.")
119
+ parser.add_argument("--show", action="store_true", help="Display results.")
120
+ parser.add_argument("--save", action="store_true", help="Save annotated images.")
121
+ parser.add_argument("--save-txt", action="store_true", help="Save YOLO-format labels.")
122
+ parser.add_argument("--project", default="detections", help="Output project folder.")
123
+ parser.add_argument("--name", default="predict", help="Output run name.")
124
+ parser.add_argument("--device", default="", help="CUDA device, e.g. '0' or 'cpu'.")
125
+ parser.add_argument(
126
+ "--compare",
127
+ nargs=2,
128
+ metavar=("RGB_SOURCE", "THERMAL_SOURCE"),
129
+ help="Compare RGB and thermal models on co-registered pairs.",
130
+ )
131
+ args = parser.parse_args()
132
+
133
+ if args.compare:
134
+ compare_modalities(
135
+ rgb_source=args.compare[0],
136
+ thermal_source=args.compare[1],
137
+ conf=args.conf,
138
+ imgsz=args.imgsz,
139
+ )
140
+ return
141
+
142
+ model_name = args.weights if (args.model == "custom" and args.weights) else args.model
143
+
144
+ run_inference(
145
+ model_name=model_name,
146
+ source=args.source,
147
+ imgsz=args.imgsz,
148
+ conf=args.conf,
149
+ iou=args.iou,
150
+ show=args.show,
151
+ save=args.save,
152
+ save_txt=args.save_txt,
153
+ project=args.project,
154
+ name=args.name,
155
+ device=args.device,
156
+ )
157
+
158
+
159
+ if __name__ == "__main__":
160
+ main()