| """ |
| Inference wrapper: load image + audio models, run modalities present, apply fusion, return schema. |
| """ |
| from pathlib import Path |
| from typing import Optional, Dict, Any, BinaryIO |
| import json |
| import torch |
| import torchaudio |
| from torchvision import transforms |
| from PIL import Image |
|
|
| |
| import sys |
| ROOT = Path(__file__).resolve().parent.parent.parent |
| sys.path.insert(0, str(ROOT)) |
|
|
|
|
def _load_image_model(weights_path: Path, label_mapping_path: Path, device: str):
    """Restore the image classifier from a checkpoint and move it to *device*.

    Returns a ``(model, temperature)`` pair, where *temperature* is the
    calibration scalar stored in the checkpoint (defaults to 1.0 when absent).
    """
    from src.models.image_model import ElectricalOutletsImageModel

    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(weights_path, map_location=device)
    model = ElectricalOutletsImageModel(
        num_classes=checkpoint["num_classes"],
        label_mapping_path=label_mapping_path,
        pretrained=False,
    )
    model.load_state_dict(checkpoint["model_state_dict"])
    # Label decoders travel with the checkpoint; missing keys become None.
    model.idx_to_issue_type = checkpoint.get("idx_to_issue_type")
    model.idx_to_severity = checkpoint.get("idx_to_severity")
    model.eval()
    return model.to(device), checkpoint.get("temperature", 1.0)
|
|
|
|
def _load_audio_model(weights_path: Path, label_mapping_path: Path, device: str, config: dict):
    """Restore the audio classifier from a checkpoint and move it to *device*.

    *config* supplies the spectrogram geometry (``n_mels``, ``time_steps``);
    unspecified keys fall back to 64 / 128. Returns ``(model, temperature)``
    with *temperature* taken from the checkpoint (1.0 when absent).
    """
    from src.models.audio_model import ElectricalOutletsAudioModel

    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(weights_path, map_location=device)
    model = ElectricalOutletsAudioModel(
        num_classes=checkpoint["num_classes"],
        label_mapping_path=label_mapping_path,
        n_mels=config.get("n_mels", 64),
        time_steps=config.get("time_steps", 128),
    )
    model.load_state_dict(checkpoint["model_state_dict"])
    # Label decoders travel with the checkpoint; missing keys become None.
    model.idx_to_label = checkpoint.get("idx_to_label")
    model.idx_to_issue_type = checkpoint.get("idx_to_issue_type")
    model.idx_to_severity = checkpoint.get("idx_to_severity")
    model.eval()
    return model.to(device), checkpoint.get("temperature", 1.0)
|
|
|
|
def _image_modality(
    image_path: Optional[Path],
    image_fp: Optional[BinaryIO],
    weights_dir: Path,
    label_mapping_path: Path,
    device: str,
):
    """Run the image model on the given input and wrap it as a ModalityOutput."""
    from src.fusion.fusion_logic import ModalityOutput

    img = Image.open(image_path or image_fp).convert("RGB")
    tf = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # ImageNet statistics — must match what the backbone was pretrained with.
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    x = tf(img).unsqueeze(0).to(device)
    # NOTE(review): unlike the audio branch, this does not check that the
    # weights file exists — a missing file raises FileNotFoundError. Confirm
    # whether image weights are guaranteed present before "fixing" this.
    model, temperature = _load_image_model(
        weights_dir / "electrical_outlets_image_best.pt", label_mapping_path, device
    )
    with torch.no_grad():
        # Temperature scaling calibrates the confidence scores.
        logits = model(x) / temperature
    pred = model.predict_to_schema(logits)
    return ModalityOutput(
        result=pred["result"],
        issue_type=pred.get("issue_type"),
        severity=pred["severity"],
        confidence=pred["confidence"],
    )


def _audio_modality(
    audio_path: Optional[Path],
    audio_fp: Optional[BinaryIO],
    weights_dir: Path,
    label_mapping_path: Path,
    device: str,
):
    """Run the audio model on the given input and wrap it as a ModalityOutput."""
    from src.fusion.fusion_logic import ModalityOutput

    if audio_path:
        waveform, sr = torchaudio.load(str(audio_path))
    else:
        import io
        waveform, sr = torchaudio.load(io.BytesIO(audio_fp.read()))

    # Normalize to 16 kHz mono, exactly 5 seconds (center-crop or zero-pad),
    # matching the preprocessing the model was trained with.
    if sr != 16000:
        waveform = torchaudio.functional.resample(waveform, sr, 16000)
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    target_len = int(5.0 * 16000)
    if waveform.shape[1] >= target_len:
        start = (waveform.shape[1] - target_len) // 2
        waveform = waveform[:, start : start + target_len]
    else:
        waveform = torch.nn.functional.pad(waveform, (0, target_len - waveform.shape[1]))

    mel = torchaudio.transforms.MelSpectrogram(
        sample_rate=16000, n_fft=512, hop_length=256, win_length=512, n_mels=64,
    )(waveform)
    # Clamp before log to avoid -inf on silent frames.
    log_mel = torch.log(mel.clamp(min=1e-5)).unsqueeze(0).to(device)

    model, temperature = _load_audio_model(
        weights_dir / "electrical_outlets_audio_best.pt",
        label_mapping_path,
        device,
        {"n_mels": 64, "time_steps": 128},
    )
    with torch.no_grad():
        # Temperature scaling calibrates the confidence scores.
        logits = model(log_mel) / temperature
    pred = model.predict_to_schema(logits)
    return ModalityOutput(
        result=pred["result"],
        issue_type=pred.get("issue_type"),
        severity=pred["severity"],
        confidence=pred["confidence"],
    )


def run_electrical_outlets_inference(
    image_path: Optional[Path] = None,
    image_fp: Optional[BinaryIO] = None,
    audio_path: Optional[Path] = None,
    audio_fp: Optional[BinaryIO] = None,
    weights_dir: Optional[Path] = None,
    config_dir: Optional[Path] = None,
    device: Optional[str] = None,
) -> Dict[str, Any]:
    """Run image and/or audio inference, then fuse the results.

    Each modality is run only when an input for it is supplied (path or
    open binary file object; the path wins when both are given). The audio
    branch is additionally skipped when its weights file is absent.

    Args:
        image_path / image_fp: image input as a path or open binary file.
        audio_path / audio_fp: audio input as a path or open binary file.
        weights_dir: directory with model checkpoints (default: ROOT/weights).
        config_dir: directory with label_mapping.json and thresholds.yaml
            (default: ROOT/config).
        device: torch device string; auto-selects CUDA when available.

    Returns:
        The canonical schema dict produced by ``fuse_modalities``.
    """
    if weights_dir is None:
        weights_dir = ROOT / "weights"
    if config_dir is None:
        config_dir = ROOT / "config"
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    label_mapping_path = config_dir / "label_mapping.json"

    import yaml
    with open(config_dir / "thresholds.yaml") as f:
        # safe_load returns None for an empty file; fall back to {} so the
        # .get(...) defaults below still apply.
        thresholds = yaml.safe_load(f) or {}

    image_out = None
    if image_path or image_fp:
        image_out = _image_modality(image_path, image_fp, weights_dir, label_mapping_path, device)

    audio_out = None
    if (audio_path or audio_fp) and (weights_dir / "electrical_outlets_audio_best.pt").exists():
        audio_out = _audio_modality(audio_path, audio_fp, weights_dir, label_mapping_path, device)

    from src.fusion.fusion_logic import fuse_modalities
    return fuse_modalities(
        image_out,
        audio_out,
        confidence_issue_min=thresholds.get("confidence_issue_min", 0.6),
        confidence_normal_min=thresholds.get("confidence_normal_min", 0.75),
        uncertain_if_disagree=thresholds.get("uncertain_if_disagree", True),
        high_confidence_override=thresholds.get("high_confidence_override", 0.92),
        severity_order=thresholds.get("severity_order"),
    )
|
|