| """ |
| Image classifier for Electrical Outlets. EfficientNet-B0 backbone + MLP head. |
| FINAL v5: 5 classes (no GFCI). |
| """ |
| from pathlib import Path |
| from typing import Dict, Any, Optional |
| import json |
| import torch |
| import torch.nn as nn |
| from torchvision import models |
|
|
|
|
class ElectricalOutletsImageModel(nn.Module):
    """EfficientNet-B0 feature extractor followed by a small MLP classification head.

    The torchvision EfficientNet-B0 classifier is replaced by ``nn.Identity`` so the
    backbone returns its feature vector, and ``self.head`` maps that vector to
    ``num_classes`` logits.

    Args:
        num_classes: Number of logits produced by the head.
        label_mapping_path: Optional path to a JSON file whose ``"image"`` object
            contains ``"idx_to_issue_type"`` and ``"idx_to_severity"`` entries.
            If the path is falsy or does not exist, no mapping is loaded and
            ``predict_to_schema`` falls back to placeholder labels.
        pretrained: If True, load torchvision's ``IMAGENET1K_V1`` backbone weights.
        head_hidden: Width of the hidden layer of the MLP head.
        head_dropout: Dropout probability before the first head layer; half of
            this value is used before the final layer.
    """

    def __init__(
        self,
        num_classes: int = 5,
        label_mapping_path: Optional[Path] = None,
        pretrained: bool = True,
        head_hidden: int = 256,
        head_dropout: float = 0.4,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.backbone = models.efficientnet_b0(
            weights=models.EfficientNet_B0_Weights.IMAGENET1K_V1 if pretrained else None
        )
        # Read the input width of the backbone's original final Linear layer
        # (classifier[1]) before discarding the classifier.
        in_features = self.backbone.classifier[1].in_features
        self.backbone.classifier = nn.Identity()

        self.head = nn.Sequential(
            nn.Dropout(head_dropout),
            nn.Linear(in_features, head_hidden),
            nn.ReLU(),
            nn.Dropout(head_dropout * 0.5),
            nn.Linear(head_hidden, num_classes),
        )

        # Index -> label mappings; either a list or a dict (JSON objects have
        # string keys). None means "no mapping loaded".
        self.idx_to_issue_type = None
        self.idx_to_severity = None
        mapping_path = Path(label_mapping_path) if label_mapping_path else None
        # A missing mapping file is deliberately tolerated (placeholder labels).
        if mapping_path is not None and mapping_path.exists():
            with mapping_path.open(encoding="utf-8") as f:
                lm = json.load(f)
            self.idx_to_issue_type = lm["image"]["idx_to_issue_type"]
            self.idx_to_severity = lm["image"]["idx_to_severity"]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits for ``x``.

        ``x`` is assumed to be an image batch in the layout torchvision's
        EfficientNet expects (presumably ``(N, 3, H, W)`` — confirm with caller).
        """
        features = self.backbone(x)
        return self.head(features)

    @staticmethod
    def _lookup(mapping: Any, idx: int, default: str) -> Any:
        """Return the label stored for ``idx`` in a list or dict mapping, or ``default``.

        A falsy mapping (``None`` or empty) yields ``default``. Dict mappings
        are tried with the integer key first, then its string form, because
        JSON object keys are always strings.
        """
        if not mapping:
            return default
        if isinstance(mapping, dict):
            if idx in mapping:
                return mapping[idx]
            return mapping[str(idx)]
        return mapping[idx]

    def predict_to_schema(self, logits: torch.Tensor) -> Dict[str, Any]:
        """Convert the logits of ONE sample into a prediction dict.

        Args:
            logits: Logits for a single sample, shape ``(C,)`` or ``(1, C)``
                (any shape whose leading dimensions multiply to 1).

        Returns:
            Dict with ``result`` (``"normal"`` or ``"issue_detected"``),
            ``issue_type``, ``severity``, ``confidence`` (max softmax
            probability) and ``class_idx`` (argmax index).

        Raises:
            ValueError: If ``logits`` holds more (or fewer) than one sample.
        """
        # Post-processing only: no need to track gradients.
        probs = torch.softmax(logits.detach(), dim=-1)
        probs = probs.reshape(-1, probs.shape[-1])
        if probs.shape[0] != 1:
            raise ValueError(
                "predict_to_schema expects logits for a single sample "
                f"(shape (C,) or (1, C)); got shape {tuple(logits.shape)}"
            )
        conf, pred = probs[0].max(dim=-1)
        class_idx = int(pred.item())
        confidence = float(conf.item())

        issue_type = self._lookup(self.idx_to_issue_type, class_idx, "unknown")
        severity = self._lookup(self.idx_to_severity, class_idx, "medium")
        result = "normal" if issue_type == "normal" else "issue_detected"
        return {
            "result": result,
            "issue_type": issue_type,
            "severity": severity,
            "confidence": confidence,
            "class_idx": class_idx,
        }
|
|