SuriRaja commited on
Commit
fdb85e9
·
1 Parent(s): fc3c42a

Update services/thermal_service.py

Browse files
Files changed (1) hide show
  1. services/thermal_service.py +20 -16
services/thermal_service.py CHANGED
@@ -1,48 +1,52 @@
1
- import torch
2
  import os
 
3
  from ultralytics import YOLO
4
  from torch.serialization import add_safe_globals
 
 
5
  from ultralytics.nn.tasks import DetectionModel
 
6
  import torch.nn.modules.container as container
7
 
8
- # ✅ Register all required globals
9
  add_safe_globals({
10
  container.Sequential: "torch.nn.modules.container.Sequential",
11
  container.ModuleList: "torch.nn.modules.container.ModuleList",
12
  container.ModuleDict: "torch.nn.modules.container.ModuleDict",
13
- DetectionModel: "ultralytics.nn.tasks.DetectionModel"
 
14
  })
15
 
16
  def load_yolo_model_safely(model_path: str = 'yolov8n.pt') -> YOLO:
17
  """
18
- Safely loads a YOLO model with necessary PyTorch 2.6+ fixes
19
- and ensures the model file exists.
20
  """
21
- # Auto-download yolov8n.pt if not present
22
  if not os.path.isfile(model_path):
23
- print(f"[INFO] Downloading {model_path}...")
24
- # Ultralytics automatically handles download inside YOLO if path missing
25
- return YOLO(model_path)
26
-
27
  try:
28
  model = YOLO(model_path)
 
29
  return model
30
  except Exception as e:
31
- print(f"[ERROR] Failed loading YOLO model: {e}")
32
  raise
33
 
34
- # ✅ Load model once at module load
35
  thermal_model = load_yolo_model_safely()
36
 
37
- def detect_thermal_anomalies(image_path):
38
  """
39
- Detects thermal-like anomalies (placeholder simulation).
40
  """
41
  results = thermal_model(image_path)
42
  flagged = []
43
  for r in results:
44
  for box in r.boxes:
45
- # Simulate flagging if confidence > 70%
46
  if box.conf > 0.7:
47
- flagged.append(box)
 
 
 
48
  return flagged
 
 
1
  import os
2
+ import torch
3
  from ultralytics import YOLO
4
  from torch.serialization import add_safe_globals
5
+
6
+ # Import all necessary classes explicitly
7
  from ultralytics.nn.tasks import DetectionModel
8
+ from ultralytics.nn.modules import Conv
9
  import torch.nn.modules.container as container
10
 
11
+ # ✅ Register all trusted classes
12
  add_safe_globals({
13
  container.Sequential: "torch.nn.modules.container.Sequential",
14
  container.ModuleList: "torch.nn.modules.container.ModuleList",
15
  container.ModuleDict: "torch.nn.modules.container.ModuleDict",
16
+ DetectionModel: "ultralytics.nn.tasks.DetectionModel",
17
+ Conv: "ultralytics.nn.modules.Conv"
18
  })
19
 
20
  def load_yolo_model_safely(model_path: str = 'yolov8n.pt') -> YOLO:
21
  """
22
+ Safely loads a YOLO model with necessary PyTorch 2.6+ safe globals registration
23
+ and ensures auto-download if missing.
24
  """
 
25
  if not os.path.isfile(model_path):
26
+ print(f"[INFO] Model {model_path} not found locally. Auto-downloading...")
 
 
 
27
  try:
28
  model = YOLO(model_path)
29
+ print(f"[INFO] YOLO model {model_path} loaded successfully.")
30
  return model
31
  except Exception as e:
32
+ print(f"[ERROR] Could not load YOLO model: {e}")
33
  raise
34
 
35
+ # ✅ Load YOLO model globally
36
  thermal_model = load_yolo_model_safely()
37
 
38
+ def detect_thermal_anomalies(image_path: str):
39
  """
40
+ Detects thermal anomalies in a given image frame using YOLO.
41
  """
42
  results = thermal_model(image_path)
43
  flagged = []
44
  for r in results:
45
  for box in r.boxes:
46
+ # Simulate thermal detection by confidence threshold
47
  if box.conf > 0.7:
48
+ flagged.append({
49
+ "confidence": float(box.conf),
50
+ "bbox": box.xyxy.tolist() if hasattr(box, 'xyxy') else []
51
+ })
52
  return flagged