Spaces:
Running
Running
Upload folder using huggingface_hub
Browse files- .github/workflows/ci_cd.yml +56 -0
- .gitignore +144 -0
- README.md +220 -8
- api/__init__.py +1 -0
- api/logger.py +109 -0
- api/main.py +562 -0
- api/schemas.py +134 -0
- api/startup.py +118 -0
- app.py +353 -0
- bug_log.md +40 -0
- conftest.py +6 -0
- docker/Dockerfile +31 -0
- mlops/evaluate_retrieval.py +181 -0
- mlops/evidently_drift.py +117 -0
- mlops/optuna_tuner.py +167 -0
- mlops/promote_model.py +130 -0
- model_card.md +64 -0
- requirements.txt +24 -0
- src/__init__.py +1 -0
- src/cache.py +84 -0
- src/depth.py +155 -0
- src/enrichment.py +71 -0
- src/graph.py +118 -0
- src/llm.py +195 -0
- src/orchestrator.py +258 -0
- src/patchcore.py +197 -0
- src/retriever.py +171 -0
- src/xai.py +280 -0
- start.sh +9 -0
.github/workflows/ci_cd.yml
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# CI/CD: test on every push touching app code, then deploy to HF Spaces.
name: CI/CD Pipeline

on:
  push:
    branches: [main]
    # Only re-run when deployable code or dependencies change.
    paths:
      - "src/**"
      - "api/**"
      - "docker/**"
      - "app.py"
      - "requirements.txt"
  # Allow manual re-deploys from the Actions tab.
  workflow_dispatch:

jobs:
  test-and-deploy:
    runs-on: ubuntu-latest

    steps:
      - name: Checkout
        uses: actions/checkout@v4

      - name: Set up Python 3.11
        uses: actions/setup-python@v5
        with:
          python-version: "3.11"

      # Deliberately installs only the API-test deps, not the full ML stack,
      # to keep CI fast. tests/test_api.py must not import heavy models.
      - name: Install test dependencies only
        run: |
          pip install pytest fastapi httpx pydantic numpy pillow python-multipart

      - name: Run pytest
        run: PYTHONPATH=. pytest tests/test_api.py -v --tb=short

      # Pushes the working tree to the Space repo; the Space rebuilds itself.
      - name: Deploy to HuggingFace Spaces
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        run: |
          pip install huggingface_hub
          python -c "
          from huggingface_hub import HfApi
          import os
          api = HfApi(token=os.environ['HF_TOKEN'])
          api.upload_folder(
              folder_path='.',
              repo_id='CaffeinatedCoding/anomalyos',
              repo_type='space',
              ignore_patterns=['*.pyc','__pycache__','.git','tests/','notebooks/','data/','models/','logs/','reports/']
          )
          print('Deployed to HF Spaces')
          "

      # Best-effort by design: the trailing `|| echo` means a cold Space
      # never fails the workflow — the deploy already succeeded above.
      - name: Smoke test
        run: |
          sleep 60
          curl --fail --retry 5 --retry-delay 30 \
            https://caffeinatedcoding-anomalyos.hf.space/health || echo "Space still warming up"
|
.gitignore
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# C extensions
|
| 7 |
+
*.so
|
| 8 |
+
|
| 9 |
+
# Distribution / packaging
|
| 10 |
+
.Python
|
| 11 |
+
build/
|
| 12 |
+
develop-eggs/
|
| 13 |
+
dist/
|
| 14 |
+
downloads/
|
| 15 |
+
eggs/
|
| 16 |
+
.eggs/
|
| 17 |
+
lib/
|
| 18 |
+
lib64/
|
| 19 |
+
parts/
|
| 20 |
+
sdist/
|
| 21 |
+
var/
|
| 22 |
+
wheels/
|
| 23 |
+
pip-wheel-metadata/
|
| 24 |
+
share/python-wheels/
|
| 25 |
+
*.egg-info/
|
| 26 |
+
.installed.cfg
|
| 27 |
+
*.egg
|
| 28 |
+
MANIFEST
|
| 29 |
+
|
| 30 |
+
# PyInstaller
|
| 31 |
+
*.manifest
|
| 32 |
+
*.spec
|
| 33 |
+
|
| 34 |
+
# Installer logs
|
| 35 |
+
pip-log.txt
|
| 36 |
+
pip-delete-this-directory.txt
|
| 37 |
+
|
| 38 |
+
# Unit test / coverage reports
|
| 39 |
+
htmlcov/
|
| 40 |
+
.tox/
|
| 41 |
+
.nox/
|
| 42 |
+
.coverage
|
| 43 |
+
.coverage.*
|
| 44 |
+
.cache
|
| 45 |
+
nosetests.xml
|
| 46 |
+
coverage.xml
|
| 47 |
+
*.cover
|
| 48 |
+
*.py,cover
|
| 49 |
+
.hypothesis/
|
| 50 |
+
.pytest_cache/
|
| 51 |
+
|
| 52 |
+
# Translations
|
| 53 |
+
*.mo
|
| 54 |
+
*.pot
|
| 55 |
+
|
| 56 |
+
# Django stuff:
|
| 57 |
+
*.log
|
| 58 |
+
local_settings.py
|
| 59 |
+
db.sqlite3
|
| 60 |
+
db.sqlite3-journal
|
| 61 |
+
|
| 62 |
+
# Flask stuff:
|
| 63 |
+
instance/
|
| 64 |
+
.webassets-cache
|
| 65 |
+
|
| 66 |
+
# Scrapy stuff:
|
| 67 |
+
.scrapy
|
| 68 |
+
|
| 69 |
+
# Sphinx documentation
|
| 70 |
+
docs/_build/
|
| 71 |
+
|
| 72 |
+
# PyBuilder
|
| 73 |
+
target/
|
| 74 |
+
|
| 75 |
+
# Jupyter Notebook
|
| 76 |
+
.ipynb_checkpoints
|
| 77 |
+
|
| 78 |
+
# IPython
|
| 79 |
+
profile_default/
|
| 80 |
+
ipython_config.py
|
| 81 |
+
|
| 82 |
+
# pyenv
|
| 83 |
+
.python-version
|
| 84 |
+
|
| 85 |
+
# pipenv
|
| 86 |
+
Pipfile.lock
|
| 87 |
+
|
| 88 |
+
# PEP 582
|
| 89 |
+
__pypackages__/
|
| 90 |
+
|
| 91 |
+
# Celery stuff
|
| 92 |
+
celerybeat-schedule
|
| 93 |
+
celerybeat.pid
|
| 94 |
+
|
| 95 |
+
# SageMath parsed files
|
| 96 |
+
*.sage.py
|
| 97 |
+
|
| 98 |
+
# Environments
|
| 99 |
+
.env
|
| 100 |
+
.venv
|
| 101 |
+
env/
|
| 102 |
+
venv/
|
| 103 |
+
ENV/
|
| 104 |
+
env.bak/
|
| 105 |
+
venv.bak/
|
| 106 |
+
|
| 107 |
+
# Spyder project settings
|
| 108 |
+
.spyderproject
|
| 109 |
+
.spyproject
|
| 110 |
+
|
| 111 |
+
# Rope project settings
|
| 112 |
+
.ropeproject
|
| 113 |
+
|
| 114 |
+
# mkdocs documentation
|
| 115 |
+
/site
|
| 116 |
+
|
| 117 |
+
# mypy
|
| 118 |
+
.mypy_cache/
|
| 119 |
+
.dmypy.json
|
| 120 |
+
dmypy.json
|
| 121 |
+
|
| 122 |
+
# Pyre type checker
|
| 123 |
+
.pyre/
|
| 124 |
+
|
| 125 |
+
# IDE
|
| 126 |
+
.vscode/
|
| 127 |
+
.idea/
|
| 128 |
+
*.swp
|
| 129 |
+
*.swo
|
| 130 |
+
*~
|
| 131 |
+
|
| 132 |
+
# Models and data
|
| 133 |
+
models/
|
| 134 |
+
*.faiss
|
| 135 |
+
data/*.faiss
|
| 136 |
+
data/*.faiss.dvc
|
| 137 |
+
|
| 138 |
+
# Logs
|
| 139 |
+
logs/
|
| 140 |
+
*.log
|
| 141 |
+
|
| 142 |
+
# DVC
|
| 143 |
+
.dvc/
|
| 144 |
+
.dvcignore
|
README.md
CHANGED
|
@@ -1,11 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
---
|
| 10 |
|
| 11 |
-
|
|
|
|
| 1 |
+
# AnomalyOS 🔍
|
| 2 |
+
### Industrial Visual Intelligence Platform
|
| 3 |
+
|
| 4 |
+
> Zero training on defects. The AI only knows normal.
|
| 5 |
+
|
| 6 |
+
[](https://huggingface.co/spaces/CaffeinatedCoding/anomalyos)
|
| 7 |
+
[](https://github.com/devangmishra1424/AnomalyOS/actions)
|
| 8 |
+
[](https://python.org)
|
| 9 |
+
[]()
|
| 10 |
+
[]()
|
| 11 |
+
|
| 12 |
+
---
|
| 13 |
+
|
| 14 |
+
## What Is This
|
| 15 |
+
|
| 16 |
+
AnomalyOS is a five-mode industrial visual inspection platform built on PatchCore (CVPR 2022), implemented from scratch in PyTorch. The system detects defects in manufactured products using only normal training images — no defect labels required. Every anomalous image is explained through four independent XAI methods, retrieved against a historical defect knowledge base, traced through a root-cause graph, and reported via a grounded LLM.
|
| 17 |
+
|
| 18 |
+
**The 15-second demo:** The AI has never seen a defective product. It only knows what normal looks like. Show it anything broken and it finds the fault, explains why using four independent methods, retrieves the five most similar historical defects from its memory, traces their root causes through a knowledge graph, and generates a remediation report.
|
| 19 |
+
|
| 20 |
+
---
|
| 21 |
+
|
| 22 |
+
## Architecture
|
| 23 |
+
```mermaid
|
| 24 |
+
graph LR
|
| 25 |
+
A[User Image] --> B[Inspector Mode]
|
| 26 |
+
B --> C[CLIP Full-Image]
|
| 27 |
+
B --> D[WideResNet Patches]
|
| 28 |
+
C --> E[Index 1: Category Routing]
|
| 29 |
+
D --> F[Index 3: PatchCore Scoring]
|
| 30 |
+
F --> G{Anomalous?}
|
| 31 |
+
G -- No --> H[Normal Result]
|
| 32 |
+
G -- Yes --> I[Defect Crop]
|
| 33 |
+
I --> J[CLIP Crop Embedding]
|
| 34 |
+
J --> K[Index 2: Similar Cases]
|
| 35 |
+
K --> L[Knowledge Graph]
|
| 36 |
+
L --> M[Groq LLM Report]
|
| 37 |
+
F --> N[XAI Layer]
|
| 38 |
+
N --> O[Heatmap + GradCAM++ + SHAP + Retrieval Trace]
|
| 39 |
+
```
|
| 40 |
+
|
| 41 |
+
**Three FAISS indexes, three granularities:**
|
| 42 |
+
- **Index 1** — CLIP full-image, 15 vectors, category routing
|
| 43 |
+
- **Index 2** — CLIP defect-crop, 5354 vectors, historical retrieval
|
| 44 |
+
- **Index 3** — WideResNet patches, per-category coreset, anomaly scoring
|
| 45 |
+
|
| 46 |
+
---
|
| 47 |
+
|
| 48 |
+
## Five Modes
|
| 49 |
+
|
| 50 |
+
| Mode | Purpose |
|
| 51 |
+
|------|---------|
|
| 52 |
+
| 🔬 Inspector | Upload image → defect detection + heatmap + report |
|
| 53 |
+
| 🧬 Forensics | Deep XAI on any past case (GradCAM++, SHAP, retrieval trace) |
|
| 54 |
+
| 📊 Analytics | Aggregated stats, Evidently drift monitoring |
|
| 55 |
+
| 🏟️ Arena | Competitive game — beat the AI at defect detection |
|
| 56 |
+
| 📚 Knowledge Base | Browse defect graph, natural language search |
|
| 57 |
+
|
| 58 |
+
---
|
| 59 |
+
|
| 60 |
+
## Technical Decisions
|
| 61 |
+
|
| 62 |
+
**Why PatchCore over a trained classifier?**
|
| 63 |
+
Real manufacturing lines do not have labelled defect datasets. Defects are rare, varied, and novel. PatchCore requires only normal samples and learns the distribution of normal patch features. Any deviation at inference is flagged. The scoring mechanism is a nearest-neighbour distance — inherently interpretable with no post-hoc XAI required for localisation.
|
| 64 |
+
|
| 65 |
+
**Why hierarchical RAG over flat search?**
|
| 66 |
+
A flat index over all 5354 images confuses product categories — a bottle scratch and a carpet scratch share visual similarities that cause cross-category retrieval noise. Hierarchical routing first identifies the category via full-image CLIP embeddings, then retrieves within the category-specific subset. Validated on a 50-question evaluation set: flat search Precision@5 = 61%, hierarchical = 93%.
|
| 67 |
+
|
| 68 |
+
**Why three FAISS indexes?**
|
| 69 |
+
Each index operates at a different granularity and serves a different purpose. Index 1 routes at category level. Index 2 retrieves visually similar historical defects for RAG context. Index 3 IS the PatchCore scoring mechanism — one coreset per product category, because each category has its own definition of normal.
|
| 70 |
+
|
| 71 |
+
**Why GradCAM++ over basic Grad-CAM?**
|
| 72 |
+
Basic Grad-CAM uses only positive gradients and produces fragmented activation maps. GradCAM++ uses a weighted combination of both positive and negative gradients, resulting in more focused and anatomically precise localisation maps. Implementation complexity is nearly identical — it is a direct upgrade.
|
| 73 |
+
|
| 74 |
+
**Why SHAP over LIME?**
|
| 75 |
+
SHAP provides theoretically grounded attribution values with the efficiency axiom (values sum to the prediction). LIME is slower and produces less consistent results across runs. For five interpretable features, SHAP is the correct choice.
|
| 76 |
+
|
| 77 |
+
**Why MiDaS-small not MiDaS-large?**
|
| 78 |
+
The depth signal feeds a five-value statistical summary, not a pixel-level task. MiDaS-small produces identical summary statistics at ~80ms CPU vs ~800ms for large. The architecture is model-agnostic — swapping to DPT-Large is one line change when GPU budget allows.
|
| 79 |
+
|
| 80 |
+
**Why coreset subsampling?**
|
| 81 |
+
2.8M patch vectors across all normal training images cannot all live in RAM or be searched efficiently. The greedy k-center coreset selects M representative patches such that every original patch is within bounded distance of a centre. At 1% coreset: 97.81% average AUROC at <5s CPU latency. At 10%: marginal AUROC gain for 10x the storage and latency.
|
| 82 |
+
|
| 83 |
+
**Why DagsHub over plain MLflow?**
|
| 84 |
+
DagsHub provides free hosted MLflow tracking and DVC remote storage under one account. No self-hosted MLflow server required. All experiment runs, model weights, and FAISS indexes are versioned and reproducible from a single `dvc pull`.
|
| 85 |
+
|
| 86 |
+
---
|
| 87 |
+
|
| 88 |
+
## Performance
|
| 89 |
+
|
| 90 |
+
### Image AUROC per Category (PatchCore, 1% coreset)
|
| 91 |
+
|
| 92 |
+
| Category | AUROC | | Category | AUROC |
|
| 93 |
+
|----------|-------|-|----------|-------|
|
| 94 |
+
| bottle | 1.0000 ✓ | | pill | 0.9722 ✓ |
|
| 95 |
+
| hazelnut | 1.0000 ✓ | | grid | 0.9816 ✓ |
|
| 96 |
+
| leather | 1.0000 ✓ | | cable | 0.9828 ✓ |
|
| 97 |
+
| tile | 1.0000 ✓ | | carpet | 0.9835 ✓ |
|
| 98 |
+
| metal_nut | 0.9976 ✓ | | wood | 0.9877 ✓ |
|
| 99 |
+
| transistor | 0.9929 ✓ | | capsule | 0.9813 ✓ |
|
| 100 |
+
| zipper | 0.9659 ✓ | | screw | 0.9545 ⚠ |
|
| 101 |
+
| | | | toothbrush | 0.8722 ⚠ |
|
| 102 |
+
|
| 103 |
+
**Average AUROC: 0.9781** (target ≥0.97 ✓)
|
| 104 |
+
|
| 105 |
+
Toothbrush and screw score lower across all PatchCore implementations in the literature — toothbrush has only 60 training images (thin coreset), screw has highly regular fine-grained thread patterns that challenge patch-level matching.
|
| 106 |
+
|
| 107 |
+
### Retrieval Quality
|
| 108 |
+
- **Precision@5 (hierarchical):** 0.9307
|
| 109 |
+
- **Precision@5 (flat baseline):** ~0.61
|
| 110 |
+
- **Improvement:** +32 percentage points from hierarchical routing
|
| 111 |
+
|
| 112 |
+
### Inference Latency (CPU, HF Spaces)
|
| 113 |
+
- End-to-end (excl. LLM): ~3-5s
|
| 114 |
+
- FAISS k-NN search: <5ms
|
| 115 |
+
- CLIP encoding: ~150ms
|
| 116 |
+
- WideResNet extraction: ~200ms
|
| 117 |
+
|
| 118 |
+
---
|
| 119 |
+
|
| 120 |
+
## MLOps
|
| 121 |
+
|
| 122 |
+
### Experiment Tracking (MLflow on DagsHub)
|
| 123 |
+
> Screenshot: [DagsHub MLflow Dashboard](https://dagshub.com/devangmishra1424/AnomalyOS)
|
| 124 |
+
|
| 125 |
+
15+ logged runs across three experiments:
|
| 126 |
+
- PatchCore ablation (coreset % vs AUROC/latency)
|
| 127 |
+
- EfficientNet fine-tuning (10 Optuna trials)
|
| 128 |
+
- Retrieval quality evaluation (Precision@1, Precision@5, MRR)
|
| 129 |
+
|
| 130 |
+
### CI/CD (GitHub Actions)
|
| 131 |
+
Three-stage smoke test on every deploy:
|
| 132 |
+
1. `GET /health` → 200 OK
|
| 133 |
+
2. `POST /inspect` with 224×224 test image → valid response
|
| 134 |
+
3. `GET /metrics` → 200 OK
|
| 135 |
+
|
| 136 |
+
### Data Versioning (DVC + DagsHub)
|
| 137 |
+
All artifacts versioned and reproducible:
|
| 138 |
+
```
|
| 139 |
+
dvc pull # pulls all FAISS indexes, PCA model, thresholds, graph
|
| 140 |
+
```
|
| 141 |
+
|
| 142 |
+
### Drift Monitoring (Evidently AI)
|
| 143 |
+
Reference: first 200 inference records.
|
| 144 |
+
Current: most recent 200 records.
|
| 145 |
+
Metrics: anomaly score distribution, predicted category distribution.
|
| 146 |
+
**Note: drift simulation uses injected OOD records for portfolio demonstration.**
|
| 147 |
+
|
| 148 |
+
---
|
| 149 |
+
|
| 150 |
+
## Limitations
|
| 151 |
+
|
| 152 |
+
- **Dataset bias:** MVTec AD contains Austrian/European industrial products. Performance on other product types or manufacturing contexts is unknown and likely degraded.
|
| 153 |
+
- **Category specificity:** PatchCore builds one coreset per product category. A category not in the 15 MVTec classes requires retraining from scratch.
|
| 154 |
+
- **Retrieval degradation:** Index 2 retrieval precision degrades on novel defect types not present in the training set.
|
| 155 |
+
- **LLM reports unverified:** Groq Llama-3 reports are grounded in retrieved context but not verified by domain experts. Do not use for real industrial decisions.
|
| 156 |
+
- **Drift monitoring simulated:** Evidently drift reports use artificially injected OOD records. Not real production drift.
|
| 157 |
+
- **CPU latency:** 3-5s end-to-end on HF Spaces free tier (no GPU). Architecture is GPU-ready.
|
| 158 |
+
- **Not for production use:** This is a portfolio demonstration project. Not suitable for safety-critical industrial deployment under any circumstances.
|
| 159 |
+
|
| 160 |
---
|
| 161 |
+
|
| 162 |
+
## Bug Log
|
| 163 |
+
|
| 164 |
+
### Bug 1 — Greedy coreset RAM explosion
|
| 165 |
+
**What:** Naive pairwise distance computation over 2.8M patch vectors caused OOM crash during coreset construction. The 2.8M×256 feature matrix alone occupies ~2.9GB in float32 (~5.7GB if promoted to float64), before any pairwise-distance buffers are allocated.
|
| 166 |
+
**Found:** Kaggle notebook killed with OOM error during first coreset build attempt.
|
| 167 |
+
**Fixed:** Batched distance computation in chunks of 10,000 vectors. Peak RAM reduced from ~6GB to ~400MB. Added to `greedy_coreset()` as `batched_l2_distance()`.
|
| 168 |
+
|
| 169 |
+
### Bug 2 — FAISS IndexFlatIP vs IndexFlatL2 for CLIP
|
| 170 |
+
**What:** Used IndexFlatL2 for CLIP embeddings initially. CLIP embeddings are L2-normalised, so L2 distance and inner product yield the same nearest-neighbour ranking — but the raw L2 distance values are not cosine similarities, so the similarity score display showed misleading numbers.
|
| 171 |
+
**Found:** Similarity scores in Index 2 retrieval were showing values >1.0 in the UI.
|
| 172 |
+
**Fixed:** Changed Index 1 and Index 2 to IndexFlatIP. Inner product on L2-normalised vectors equals cosine similarity, bounded in [-1, 1] (and non-negative in practice for CLIP image embeddings).
|
| 173 |
+
|
| 174 |
+
### Bug 3 — `grayscale_lbp` import error in enrichment pipeline
|
| 175 |
+
**What:** Cell 2 of notebook 01 imported `grayscale_lbp` from `skimage.feature`. This function does not exist — the correct function is `local_binary_pattern`.
|
| 176 |
+
**Found:** ImportError on Cell 2 execution.
|
| 177 |
+
**Fixed:** Replaced all `grayscale_lbp` imports with `from skimage.feature import local_binary_pattern`.
|
| 178 |
+
|
| 179 |
+
---
|
| 180 |
+
|
| 181 |
+
## Setup & Reproduction
|
| 182 |
+
```bash
|
| 183 |
+
# 1. Clone
|
| 184 |
+
git clone https://github.com/devangmishra1424/AnomalyOS.git
|
| 185 |
+
cd AnomalyOS
|
| 186 |
+
|
| 187 |
+
# 2. Pull all artifacts (FAISS indexes, PCA model, thresholds, graph)
|
| 188 |
+
dvc pull
|
| 189 |
+
|
| 190 |
+
# 3. Install dependencies
|
| 191 |
+
pip install -r requirements.txt
|
| 192 |
+
|
| 193 |
+
# 4. Set environment variables
|
| 194 |
+
export HF_TOKEN=your_token
|
| 195 |
+
export GROQ_API_KEY=your_key
|
| 196 |
+
export DAGSHUB_TOKEN=your_token
|
| 197 |
+
|
| 198 |
+
# 5. Launch API
|
| 199 |
+
uvicorn api.main:app --host 0.0.0.0 --port 7860
|
| 200 |
+
|
| 201 |
+
# 6. Launch Gradio (separate terminal)
|
| 202 |
+
python app.py
|
| 203 |
+
```
|
| 204 |
+
|
| 205 |
+
---
|
| 206 |
+
|
| 207 |
+
## Project Structure
|
| 208 |
+
```
|
| 209 |
+
AnomalyOS/
|
| 210 |
+
├── notebooks/ # Kaggle training notebooks (01-05)
|
| 211 |
+
├── src/ # Core ML: patchcore, orchestrator, xai, llm
|
| 212 |
+
├── api/ # FastAPI: endpoints, schemas, startup, logger
|
| 213 |
+
├── mlops/ # Evidently, Optuna, retrieval evaluation
|
| 214 |
+
├── tests/ # pytest suite (5 test files)
|
| 215 |
+
├── data/ # DVC-tracked: FAISS indexes, graph, thresholds
|
| 216 |
+
├── models/ # DVC-tracked: PCA model, EfficientNet weights
|
| 217 |
+
├── app.py # Gradio frontend (5 tabs)
|
| 218 |
+
└── docker/Dockerfile # python:3.11-slim, port 7860
|
| 219 |
+
```
|
| 220 |
+
|
| 221 |
---
|
| 222 |
|
| 223 |
+
*Built by Devang Pradeep Mishra | [GitHub](https://github.com/devangmishra1424) | [HuggingFace](https://huggingface.co/CaffeinatedCoding)*
|
api/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Package initializer for API module
|
api/logger.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# api/logger.py
|
| 2 |
+
# Two-layer durable logging strategy
|
| 3 |
+
#
|
| 4 |
+
# Layer 1: Local JSONL (fast write, ephemeral — wiped on HF Space restart)
|
| 5 |
+
# Used by Evidently drift scripts
|
| 6 |
+
# Layer 2: HF Dataset API push (durable, survives restarts)
|
| 7 |
+
# Called as FastAPI BackgroundTask — never blocks response
|
| 8 |
+
#
|
| 9 |
+
# If HF push fails: user unaffected, local log still written
|
| 10 |
+
|
| 11 |
+
import os
|
| 12 |
+
import json
|
| 13 |
+
import time
|
| 14 |
+
from datetime import datetime, timezone
|
| 15 |
+
from huggingface_hub import HfApi
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# Target HF dataset repo for durable (Layer-2) log pushes; overridable via env.
HF_REPO_ID = os.environ.get("HF_LOG_REPO", "CaffeinatedCoding/anomalyos-logs")
# Layer-1 local JSONL sink (fast, ephemeral on HF Space restarts).
LOCAL_LOG_DIR = "logs"
LOCAL_LOG_PATH = os.path.join(LOCAL_LOG_DIR, "inference.jsonl")

# Module-level HfApi singleton, set by init_logger(). Fix: the annotation is
# now honest — it was `HfApi` while the initial value is None.
_hf_api: HfApi | None = None
# Count of failed HF pushes, surfaced via get_push_failure_count().
_hf_push_failure_count: int = 0
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def init_logger(hf_token: str):
    """Initialise both logging layers; called once at FastAPI startup.

    Always creates the local log directory (Layer 1). When a HF token is
    provided, also builds the HfApi client used for durable Layer-2 pushes;
    otherwise only local logging is active.
    """
    global _hf_api

    os.makedirs(LOCAL_LOG_DIR, exist_ok=True)

    # Guard clause: without a token there is nothing more to set up.
    if not hf_token:
        print("WARNING: HF_TOKEN not set — only local logging active")
        return

    _hf_api = HfApi(token=hf_token)
    print(f"Logger initialised | HF repo: {HF_REPO_ID}")
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def log_inference(record: dict):
    """Persist one inference record through both logging layers.

    Layer 1 writes to the local JSONL synchronously; Layer 2 pushes the
    record to a HF Dataset if the client was initialised. Runs as a
    FastAPI BackgroundTask, so neither layer blocks the response, and a
    failed HF push never affects the user — the local log is already
    written.
    """
    global _hf_push_failure_count

    # Stamp the record if the caller didn't.
    record.setdefault("timestamp", datetime.now(timezone.utc).isoformat())

    # ── Layer 1: Local JSONL ──────────────────────────────────
    try:
        with open(LOCAL_LOG_PATH, "a") as sink:
            sink.write(json.dumps(record) + "\n")
    except Exception as exc:
        print(f"Local log write failed: {exc}")

    # ── Layer 2: HF Dataset push ─────────────────────────────
    if _hf_api is None:
        return

    try:
        ts = record.get("timestamp", datetime.now(timezone.utc).isoformat())
        # Sanitise timestamp so it is a valid filename component.
        ts_safe = ts.replace(":", "-").replace(".", "-")[:26]
        img_tag = record.get("image_hash", "unknown")[:8]
        path_in_repo = f"inference_logs/{ts_safe}_{img_tag}.json"

        payload = json.dumps(record, indent=2).encode("utf-8")
        _hf_api.upload_file(
            path_or_fileobj=payload,
            path_in_repo=path_in_repo,
            repo_id=HF_REPO_ID,
            repo_type="dataset",
        )
    except Exception as exc:
        _hf_push_failure_count += 1
        print(f"HF Dataset push failed (count={_hf_push_failure_count}): {exc}")
        # User response is completely unaffected — local log already written
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def log_arena_submission(record: dict):
    """Log Arena Mode submissions to shared leaderboard dataset."""
    # Tag the record type, then reuse the common two-layer pipeline.
    record.update({"log_type": "arena"})
    log_inference(record)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def log_correction(record: dict):
    """Log user corrections from /correct/{case_id}."""
    # Tag the record type, then reuse the common two-layer pipeline.
    record.update({"log_type": "correction"})
    log_inference(record)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def get_recent_logs(n: int = 200) -> list:
    """Return the last *n* records from the local JSONL log.

    Used by the Evidently drift scripts. A missing file yields an empty
    list; a corrupt line aborts the read and returns whatever was
    collected so far (same best-effort semantics as before).

    Improvement: records accumulate in a ``deque(maxlen=n)`` instead of a
    full list, so memory stays O(n) even as the log file grows over the
    process lifetime.
    """
    from collections import deque  # local import: only needed here

    if not os.path.exists(LOCAL_LOG_PATH):
        return []

    records: "deque" = deque(maxlen=n)
    try:
        with open(LOCAL_LOG_PATH) as f:
            for line in f:
                line = line.strip()
                if line:
                    records.append(json.loads(line))
    except Exception as e:
        print(f"Error reading local log: {e}")
    return list(records)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def get_push_failure_count() -> int:
    """Return the number of HF Dataset pushes that have failed since startup."""
    return _hf_push_failure_count
|
api/main.py
ADDED
|
@@ -0,0 +1,562 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# api/main.py
|
| 2 |
+
# FastAPI application — 9 endpoints
|
| 3 |
+
# Models loaded once at startup via lifespan, kept in memory
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import io
|
| 7 |
+
import time
|
| 8 |
+
import hashlib
|
| 9 |
+
from contextlib import asynccontextmanager
|
| 10 |
+
from contextvars import ContextVar
|
| 11 |
+
from typing import Optional
|
| 12 |
+
|
| 13 |
+
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, BackgroundTasks
|
| 14 |
+
from fastapi.responses import JSONResponse
|
| 15 |
+
from PIL import Image
|
| 16 |
+
import numpy as np
|
| 17 |
+
|
| 18 |
+
from api.startup import load_all, get_uptime, MODEL_VERSION
|
| 19 |
+
from api.schemas import (
|
| 20 |
+
InspectResponse, ReportResponse, ForensicsResponse,
|
| 21 |
+
KnowledgeSearchResponse, ArenaCase, ArenaSubmitRequest,
|
| 22 |
+
ArenaSubmitResponse, CorrectionRequest, CorrectionResponse,
|
| 23 |
+
HealthResponse, MetricsResponse
|
| 24 |
+
)
|
| 25 |
+
from api.logger import (
|
| 26 |
+
log_inference, log_arena_submission, log_correction,
|
| 27 |
+
get_push_failure_count
|
| 28 |
+
)
|
| 29 |
+
from src.orchestrator import run_inspection
|
| 30 |
+
from src.retriever import retriever
|
| 31 |
+
from src.graph import knowledge_graph
|
| 32 |
+
from src.xai import gradcam, shap_explainer, heatmap_to_base64, image_to_base64
|
| 33 |
+
from src.llm import get_report, generate_report
|
| 34 |
+
from src.cache import inference_cache, get_image_hash
|
| 35 |
+
|
| 36 |
+
import psutil
|
| 37 |
+
import random
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# ── Request-scoped state via ContextVar ──────────────────────
# Prevents race conditions under concurrent requests
# Never use global mutable state for per-request data
request_session_id: ContextVar[str] = ContextVar("session_id", default="")

# ── Metrics counters ─────────────────────────────────────────
# NOTE(review): "latencies" appears to grow without bound over the process
# lifetime — confirm whether a capped buffer is needed for long-running Spaces.
_metrics = {
    "request_count": 0,
    "latencies": [],
    "hf_push_failure_count": 0
}

# ── Precompute store (speculative CLIP encoding) ──────────────
# In-memory only; presumably keyed by image hash — verify against callers.
_precompute_store: dict = {}

# ── Arena leaderboard (in-memory, persisted to HF Dataset) ───
_arena_streaks: dict = {}
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load all models at startup. Nothing else runs before this."""
    # Code before `yield` runs once at startup; code after runs at shutdown.
    load_all()
    yield
    # Cleanup on shutdown (not critical but clean)
    inference_cache.clear()
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
# FastAPI application object — version string is the model version so the
# OpenAPI docs always reflect the deployed model.
app = FastAPI(
    title="AnomalyOS",
    description="Industrial Visual Anomaly Detection Platform",
    version=MODEL_VERSION,
    lifespan=lifespan
)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# ── Helpers ───────────────────────────────────────────────────
# The 15 MVTec-AD object/texture categories accepted as category_hint.
VALID_CATEGORIES = [
    'bottle', 'cable', 'capsule', 'carpet', 'grid', 'hazelnut',
    'leather', 'metal_nut', 'pill', 'screw', 'tile', 'toothbrush',
    'transistor', 'wood', 'zipper'
]

MAX_FILE_SIZE = 10 * 1024 * 1024  # 10MB
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def _validate_image(file: UploadFile, image_bytes: bytes) -> Image.Image:
    """
    Run every upload sanity check and decode the image to RGB.

    Raises HTTPException (413 for oversized, 422 otherwise) on any
    failure, so the model is never invoked on invalid input.
    """
    # Only the two content types the frontend can send are accepted.
    if file.content_type not in ("image/jpeg", "image/png"):
        raise HTTPException(status_code=422,
                            detail="Only jpg/png accepted")

    size = len(image_bytes)
    if size > MAX_FILE_SIZE:
        raise HTTPException(status_code=413,
                            detail="Max file size is 10MB")
    if size == 0:
        raise HTTPException(status_code=422,
                            detail="Image file is empty")

    # Decode — any PIL failure maps to a 422, never a 500.
    try:
        decoded = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except Exception:
        raise HTTPException(status_code=422,
                            detail="Could not decode image")

    # Reject images too small for the patch grid to be meaningful.
    width, height = decoded.size
    if width < 32 or height < 32:
        raise HTTPException(status_code=422,
                            detail="Image too small for inspection")
    return decoded
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def _record_latency(latency_ms: float):
    """Record one request's latency sample and bump the request counter.

    The window is bounded: once it grows past 1000 samples it is trimmed
    back to the most recent 500 so memory stays constant.
    """
    _metrics["request_count"] += 1
    window = _metrics["latencies"]
    window.append(latency_ms)
    if len(window) > 1000:
        _metrics["latencies"] = window[-500:]
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
# ── POST /inspect ─────────────────────────────────────────────
|
| 124 |
+
@app.post("/inspect", response_model=InspectResponse)
async def inspect(
    background_tasks: BackgroundTasks,
    image: UploadFile = File(...),
    category_hint: Optional[str] = Form(None),
    session_id: Optional[str] = Form(None)
):
    """
    Main inspection endpoint.
    Accepts: multipart form (image + optional metadata)
    Returns: anomaly result immediately, LLM report polled separately

    NOTE(review): session_id is accepted but never read in this handler —
    presumably intended to pair with /precompute's store; confirm.
    """
    # Validate category hint
    if category_hint and category_hint not in VALID_CATEGORIES:
        raise HTTPException(status_code=422,
                            detail=f"Invalid category_hint: {category_hint}")

    image_bytes = await image.read()
    pil_img = _validate_image(image, image_bytes)

    # Run full orchestrator pipeline
    result = run_inspection(
        pil_img=pil_img,
        image_bytes=image_bytes,
        category_hint=category_hint
    )

    # Queue LLM report generation (non-blocking) — only for anomalous
    # results that were assigned a report_id by the orchestrator.
    if result.report_id and result.is_anomalous:
        background_tasks.add_task(
            generate_report,
            result.report_id,
            result.category,
            result.score,
            result.similar_cases,
            result.graph_context
        )

    # Log inference (non-blocking)
    image_hash = get_image_hash(image_bytes)
    log_record = {
        "mode": "inspector",
        "image_hash": image_hash,
        "category": result.category,
        "anomaly_score": result.score,
        "calibrated_score": result.calibrated_score,
        "is_anomalous": result.is_anomalous,
        "latency_ms": result.latency_ms,
        "model_version": MODEL_VERSION,
        "report_id": result.report_id
    }
    background_tasks.add_task(log_inference, log_record)
    _record_latency(result.latency_ms)

    return InspectResponse(
        is_anomalous=result.is_anomalous,
        anomaly_score=result.score,
        calibrated_score=result.calibrated_score,
        score_std=result.score_std,
        category=result.category,
        model_version=MODEL_VERSION,
        heatmap_b64=result.heatmap_b64,
        defect_crop_b64=result.defect_crop_b64,
        depth_map_b64=result.depth_map_b64,
        similar_cases=result.similar_cases,
        graph_context=result.graph_context,
        shap_features=result.shap_features,
        report_id=result.report_id,
        latency_ms=result.latency_ms,
        image_hash=image_hash,
        # Threshold 0.3 matches InspectResponse.low_confidence's documented meaning.
        low_confidence=result.calibrated_score < 0.3
    )
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
# ── GET /report/{report_id} ───────────────────────────────────
|
| 199 |
+
@app.get("/report/{report_id}", response_model=ReportResponse)
async def get_report_status(report_id: str):
    """
    Poll LLM report status.
    Frontend polls every 500ms until status == 'ready'.
    """
    state = get_report(report_id)
    return ReportResponse(
        status=state["status"],
        report=state.get("report")
    )
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
# ── POST /forensics/{case_id} ─────────────────────────────────
|
| 213 |
+
@app.post("/forensics/{case_id}", response_model=ForensicsResponse)
async def forensics(
    case_id: str,
    coreset_pct: Optional[float] = None
):
    """
    Deep XAI analysis of a previously logged case.
    Loads case from cache or HF Dataset, runs full XAI suite.
    coreset_pct: optional ablation parameter (0.001-0.1)

    NOTE(review): coreset_pct is range-validated but never used below —
    confirm whether the ablation re-run is still to be wired up.
    """
    if coreset_pct is not None and not (0.001 <= coreset_pct <= 0.1):
        raise HTTPException(status_code=422,
                            detail="coreset_pct must be between 0.001 and 0.1")

    # Load case from cache
    cached = inference_cache.get(case_id)
    if not cached:
        raise HTTPException(status_code=422,
                            detail="Case not found. Run inspection first.")

    # GradCAM++ (runs here, not in Inspector) — requires the original PIL
    # image to have been cached alongside the result.
    gradcam_b64 = None
    if cached.get("_pil_img"):
        cam = gradcam.compute(cached["_pil_img"])
        if cam is not None:
            gradcam_b64 = heatmap_to_base64(cam, cached["_pil_img"])

    # Retrieval trace — enrich similar cases with similarity scores
    retrieval_trace = []
    for case in cached.get("similar_cases", []):
        retrieval_trace.append({
            # First 12 hex chars of the image hash serve as a short case id.
            "case_id": case.get("image_hash", "")[:12],
            "category": case.get("category"),
            "defect_type": case.get("defect_type"),
            "similarity_score": case.get("similarity_score"),
            "graph_path": _format_graph_path(
                case.get("category"),
                case.get("defect_type")
            )
        })

    return ForensicsResponse(
        case_id=case_id,
        category=cached.get("category", "unknown"),
        anomaly_score=cached.get("score", 0.0),
        calibrated_score=cached.get("calibrated_score", 0.0),
        patch_scores_grid=cached.get("patch_scores_grid", []),
        gradcampp_b64=gradcam_b64,
        shap_features=cached.get("shap_features", {}),
        similar_cases=cached.get("similar_cases", []),
        graph_context=cached.get("graph_context", {}),
        retrieval_trace=retrieval_trace
    )
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
def _format_graph_path(category: str, defect_type: str) -> str:
    """Format 2-hop graph path as plain text for Forensics trace."""
    if not (category and defect_type):
        return "unknown"
    context = knowledge_graph.get_context(category, defect_type)
    causes = context.get("root_causes", [])
    fixes = context.get("remediations", [])
    # Prefer the full two-hop path; fall back to cause-only, then to a
    # fixed sentinel when the graph has nothing for this pair.
    if causes:
        if fixes:
            return f"caused_by: {causes[0]} → remediated_by: {fixes[0]}"
        return f"caused_by: {causes[0]}"
    return "no graph path found"
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
# ── GET /knowledge/search ─────────────────────────────────────
|
| 283 |
+
@app.get("/knowledge/search", response_model=KnowledgeSearchResponse)
async def knowledge_search(
    category: Optional[str] = None,
    defect_type: Optional[str] = None,
    severity_min: Optional[float] = None,
    severity_max: Optional[float] = None,
    query: Optional[str] = None
):
    """
    Search defect knowledge base.
    Natural language query → MiniLM embed → Index 2 search.
    Filters: category, defect_type, severity range.

    Results are capped at 50; total_found reflects the pre-cap count.
    """
    all_defects = knowledge_graph.get_all_defect_nodes()
    results = all_defects

    # Filter by category (exact match)
    if category:
        results = [r for r in results if r.get("category") == category]

    # Filter by defect type (case-insensitive substring)
    if defect_type:
        results = [r for r in results
                   if defect_type.lower() in r.get("defect_type", "").lower()]

    # Filter by severity range
    if severity_min is not None:
        results = [r for r in results
                   if r.get("severity_min", 0) >= severity_min]
    if severity_max is not None:
        results = [r for r in results
                   if r.get("severity_max", 1) <= severity_max]

    # Natural language search via Index 2
    if query and retriever.index2 is not None:
        try:
            # BUG FIX: previously a fresh SentenceTransformer was constructed
            # on every request — an expensive model load per call, contrary to
            # the project's load-once-at-startup rule. Memoize it on the
            # function object so it is built at most once per process.
            encoder = getattr(knowledge_search, "_mini_lm", None)
            if encoder is None:
                from sentence_transformers import SentenceTransformer
                encoder = SentenceTransformer("all-MiniLM-L6-v2")
                knowledge_search._mini_lm = encoder

            query_emb = encoder.encode([query])[0].astype("float32")
            query_emb = query_emb / (np.linalg.norm(query_emb) + 1e-8)
            # Pad or truncate to 512 dims to match Index 2
            if len(query_emb) < 512:
                query_emb = np.pad(query_emb, (0, 512 - len(query_emb)))
            else:
                query_emb = query_emb[:512]
            D, I = retriever.index2.search(query_emb.reshape(1, -1), k=10)
            nl_results = [retriever.index2_metadata[i]
                          for i in I[0] if i >= 0]
            # Fall back to filter results when the NL search returns nothing.
            results = nl_results if nl_results else results
        except Exception as e:
            print(f"NL search failed: {e} — using filter results")

    return KnowledgeSearchResponse(
        results=results[:50],
        total_found=len(results),
        query=query or ""
    )
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
# ── GET /arena/next_case ──────────────────────────────────────
|
| 343 |
+
@app.get("/arena/next_case", response_model=ArenaCase)
async def arena_next_case(expert_mode: bool = False):
    """
    Returns next Arena challenge image.
    Expert mode: cases with calibrated_score in [0.45, 0.55] (maximum uncertainty)

    Picks a random MVTec category and split, caches the ground truth in
    _precompute_store keyed by case_id, and returns the image as base64.
    """
    import os

    MVTEC_PATH = os.environ.get("MVTEC_PATH", "/app/data/mvtec")

    # Pick a random category and split
    cat = random.choice(VALID_CATEGORIES)
    split = random.choice(["train", "test"])

    # BUG FIX: defect_type was previously unbound on the "train" branch,
    # so building the is_defective flag below raised NameError roughly
    # half the time. Train images are always "good".
    defect_type = "good"
    if split == "train":
        img_dir = os.path.join(MVTEC_PATH, cat, "train", "good")
    else:
        # BUG FIX: check the directory exists BEFORE listing it — the old
        # code called os.listdir first and crashed with an unhandled
        # FileNotFoundError when the dataset was not mounted.
        test_dir = os.path.join(MVTEC_PATH, cat, "test")
        if not os.path.isdir(test_dir):
            raise HTTPException(status_code=500, detail="Dataset not mounted")
        defect_type = random.choice(os.listdir(test_dir))
        img_dir = os.path.join(MVTEC_PATH, cat, "test", defect_type)

    if not os.path.exists(img_dir):
        raise HTTPException(status_code=500, detail="Dataset not mounted")

    files = [f for f in os.listdir(img_dir)
             if f.endswith((".png", ".jpg", ".jpeg"))]
    if not files:
        raise HTTPException(status_code=500, detail="No images found")

    fname = random.choice(files)
    img_path = os.path.join(img_dir, fname)
    pil_img = Image.open(img_path).convert("RGB")

    # Generate case_id from path hash
    case_id = hashlib.sha256(img_path.encode()).hexdigest()[:16]

    # Cache the ground truth for the submit endpoint
    _precompute_store[case_id] = {
        "img_path": img_path,
        "category": cat,
        "is_defective": split == "test" and defect_type != "good"
    }

    return ArenaCase(
        case_id=case_id,
        image_b64=image_to_base64(pil_img),
        expert_mode=expert_mode
    )
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
# ── POST /arena/submit/{case_id} ──────────────────────────────
|
| 399 |
+
@app.post("/arena/submit/{case_id}", response_model=ArenaSubmitResponse)
async def arena_submit(
    case_id: str,
    request: ArenaSubmitRequest,
    background_tasks: BackgroundTasks
):
    """Submit Arena answer. Returns AI result + user score + SHAP explanation."""
    # BUG FIX: pil_to_bytes was only imported *inside* arena_next_case
    # (a function-local binding), so this handler raised NameError at
    # runtime. Import it here where it is actually used.
    from src.cache import pil_to_bytes

    case_info = _precompute_store.get(case_id)
    if not case_info:
        raise HTTPException(status_code=422, detail="Case not found")

    pil_img = Image.open(case_info["img_path"]).convert("RGB")
    image_bytes = pil_to_bytes(pil_img)

    # Run the same pipeline the Inspector uses so AI vs. human is fair.
    result = run_inspection(pil_img=pil_img, image_bytes=image_bytes)

    correct_label = 1 if case_info["is_defective"] else 0
    user_correct = int(request.user_rating == correct_label)

    # Severity score: 1 if within 1 of AI severity, 0 otherwise.
    # AI severity maps calibrated_score [0,1] onto a 0-5 scale.
    ai_severity = round(result.calibrated_score * 5)
    sev_score = 1 if abs(request.user_severity - ai_severity) <= 1 else 0
    user_score = float(user_correct + sev_score * 0.5)

    # Streak tracking — resets to 0 on any wrong answer.
    session = request.session_id or "anonymous"
    streak = _arena_streaks.get(session, 0) + 1 if user_correct else 0
    _arena_streaks[session] = streak

    # Top 2 SHAP features for post-submission explanation
    top_shap = []
    shap_data = result.shap_features
    if shap_data.get("feature_names"):
        pairs = list(zip(shap_data["feature_names"],
                         shap_data["shap_values"]))
        pairs.sort(key=lambda x: abs(x[1]), reverse=True)
        top_shap = [{"feature": name, "contribution": round(value, 4)}
                    for name, value in pairs[:2]]

    # Log submission off the request path
    background_tasks.add_task(log_arena_submission, {
        "case_id": case_id,
        "user_rating": request.user_rating,
        "ai_decision": int(result.is_anomalous),
        "user_score": user_score,
        "streak": streak,
        "session_id": session
    })

    return ArenaSubmitResponse(
        correct_label=correct_label,
        ai_score=result.score,
        calibrated_score=result.calibrated_score,
        user_score=user_score,
        streak=streak,
        top_shap_features=top_shap,
        heatmap_b64=result.heatmap_b64,
        is_expert_case=0.45 <= result.calibrated_score <= 0.55
    )
|
| 462 |
+
|
| 463 |
+
|
| 464 |
+
# ── POST /correct/{case_id} ───────────────────────────────────
|
| 465 |
+
@app.post("/correct/{case_id}", response_model=CorrectionResponse)
async def submit_correction(
    case_id: str,
    request: CorrectionRequest,
    background_tasks: BackgroundTasks
):
    """
    User correction widget backend.
    Every correction logged with user_override=True flag.
    Interview line: "Corrections can seed a future active learning cycle."
    """
    payload = {
        "case_id": case_id,
        "correction_type": request.correction_type,
        "note": request.note,
        "user_override": True
    }
    # Logging happens after the response is sent — never blocks the client.
    background_tasks.add_task(log_correction, payload)
    return CorrectionResponse(status="correction_logged", case_id=case_id)
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
# ── GET /health ───────────────────────────────────────────────
|
| 486 |
+
@app.get("/health", response_model=HealthResponse)
async def health():
    """
    Health check — called by GitHub Actions smoke test after every deploy.
    Returns 503 if any critical index failed to load at startup.
    """
    index_status = retriever.get_status()

    # Critical check: Index 1 and Index 2 must be loaded
    if index_status["index1_vectors"] == 0:
        raise HTTPException(status_code=503,
                            detail="Index 1 not loaded — startup failed")
    if index_status["index2_vectors"] == 0:
        raise HTTPException(status_code=503,
                            detail="Index 2 not loaded — startup failed")

    # Idiom fix: iterate the cached index objects directly instead of
    # re-indexing the dict by key (ruff PERF102).
    coreset_total = sum(idx.ntotal for idx in retriever.index3_cache.values())

    return HealthResponse(
        status="ok",
        model_version=MODEL_VERSION,
        uptime_seconds=round(get_uptime(), 1),
        index_sizes=index_status,
        coreset_size=coreset_total,
        threshold_config_version="v1.0",
        cache_stats=inference_cache.stats()
    )
|
| 514 |
+
|
| 515 |
+
|
| 516 |
+
# ── GET /metrics ──────────────────────────────────────────────
|
| 517 |
+
@app.get("/metrics", response_model=MetricsResponse)
async def metrics():
    """
    Prometheus-style observability endpoint.
    Tracked by GitHub Actions smoke test 3.
    """
    window = _metrics["latencies"]
    if window:
        p50 = float(np.percentile(window, 50))
        p95 = float(np.percentile(window, 95))
    else:
        # No traffic yet — report zeros rather than erroring on an empty window.
        p50 = p95 = 0.0

    rss_mb = psutil.Process().memory_info().rss / 1024 / 1024

    return MetricsResponse(
        request_count=_metrics["request_count"],
        latency_p50_ms=round(p50, 1),
        latency_p95_ms=round(p95, 1),
        cache_hit_rate=inference_cache.stats()["hit_rate"],
        hf_push_failure_count=get_push_failure_count(),
        memory_usage_mb=round(rss_mb, 1)
    )
|
| 537 |
+
|
| 538 |
+
|
| 539 |
+
# ── GET /precompute ───────────────────────────────────────────
|
| 540 |
+
@app.post("/precompute")
async def precompute(
    image: UploadFile = File(...),
    session_id: str = Form(...)
):
    """
    Speculative CLIP encoding — fired by Gradio onChange before user clicks Inspect.
    Runs Index 1 category routing only.
    Result stored keyed by session_id — /inspect checks this first.

    NOTE(review): the /inspect handler in this file does not actually read
    _precompute_store — confirm the consumer, or the stored result is unused.
    """
    image_bytes = await image.read()
    try:
        pil_img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        from src.orchestrator import _get_clip_embedding
        clip_full = _get_clip_embedding(pil_img, mode="full")
        cat_result = retriever.route_category(clip_full)
        _precompute_store[session_id] = {
            "category": cat_result["category"],
            "confidence": cat_result["confidence"]
        }
    except Exception:
        pass  # Speculative — failure is silent, /inspect handles normally
    return {"status": "queued"}
|
api/schemas.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# api/schemas.py
|
| 2 |
+
# Pydantic request and response models for all 7 endpoints
|
| 3 |
+
# Validation happens here — model is never called on invalid input
|
| 4 |
+
|
| 5 |
+
from pydantic import BaseModel, Field, validator
|
| 6 |
+
from typing import Optional, List, Any
|
| 7 |
+
from enum import Enum
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
VALID_CATEGORIES = [
|
| 11 |
+
'bottle', 'cable', 'capsule', 'carpet', 'grid', 'hazelnut',
|
| 12 |
+
'leather', 'metal_nut', 'pill', 'screw', 'tile', 'toothbrush',
|
| 13 |
+
'transistor', 'wood', 'zipper'
|
| 14 |
+
]
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# ── /inspect ─────────────────────────────────────────────────
|
| 18 |
+
class InspectResponse(BaseModel):
    """Response body for POST /inspect — the full inspection result."""
    # Core result
    is_anomalous: bool
    anomaly_score: float = Field(..., ge=0.0)
    calibrated_score: float = Field(..., ge=0.0, le=1.0)
    score_std: float
    category: str
    model_version: str

    # Visuals (base64 PNG strings)
    heatmap_b64: Optional[str] = None
    defect_crop_b64: Optional[str] = None
    depth_map_b64: Optional[str] = None

    # Retrieval
    similar_cases: List[dict] = []

    # Graph context
    graph_context: dict = {}

    # XAI
    shap_features: dict = {}

    # LLM report (polled separately)
    report_id: Optional[str] = None

    # Meta
    latency_ms: float
    image_hash: str
    low_confidence: bool = False  # calibrated_score < 0.3
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
# ── /report/{report_id} ──────────────────────────────────────
|
| 51 |
+
class ReportResponse(BaseModel):
    """Response for GET /report/{report_id} — LLM report poll result."""
    status: str  # "pending" | "ready" | "not_found"
    report: Optional[str] = None  # present only when status == "ready"
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
# ── /forensics/{case_id} ─────────────────────────────────────
|
| 57 |
+
class ForensicsResponse(BaseModel):
    """Response for POST /forensics/{case_id} — deep XAI analysis of a case."""
    case_id: str
    category: str
    anomaly_score: float
    calibrated_score: float
    patch_scores_grid: List[List[float]]  # [28][28]
    gradcampp_b64: Optional[str] = None
    shap_features: dict = {}
    similar_cases: List[dict] = []
    graph_context: dict = {}
    retrieval_trace: List[dict] = []
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
# ── /knowledge/search ────────────────────────────────────────
|
| 71 |
+
class KnowledgeSearchResponse(BaseModel):
    """Response for GET /knowledge/search — filtered/ranked defect nodes."""
    results: List[dict]  # capped at 50 entries by the endpoint
    total_found: int     # count before the cap
    query: str
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
# ── /arena/next_case ─────────────────────────────────────────
|
| 78 |
+
class ArenaCase(BaseModel):
    """Response for GET /arena/next_case — one challenge image."""
    case_id: str
    image_b64: str
    expert_mode: bool = False  # True if score in [0.45, 0.55]
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
# ── /arena/submit/{case_id} ──────────────────────────────────
|
| 85 |
+
class ArenaSubmitRequest(BaseModel):
    """Request body for POST /arena/submit/{case_id}."""
    user_rating: int = Field(..., ge=0, le=1)    # 0 = good, 1 = defective
    user_severity: int = Field(..., ge=1, le=5)  # user's 1-5 severity guess
    session_id: Optional[str] = None             # groups streaks per player
|
| 89 |
+
|
| 90 |
+
class ArenaSubmitResponse(BaseModel):
    """Response for POST /arena/submit/{case_id} — AI verdict vs. user answer."""
    correct_label: int
    ai_score: float
    calibrated_score: float
    user_score: float
    streak: int
    top_shap_features: List[dict]  # top 2 features for post-submission
    heatmap_b64: Optional[str] = None
    is_expert_case: bool = False
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
# ── /correct/{case_id} ───────────────────────────────────────
|
| 102 |
+
class CorrectionType(str, Enum):
    """Closed set of correction kinds a user can submit via /correct."""
    false_positive = "false_positive"
    false_negative = "false_negative"
    wrong_category = "wrong_category"
|
| 106 |
+
|
| 107 |
+
class CorrectionRequest(BaseModel):
    """Request body for POST /correct/{case_id}."""
    correction_type: CorrectionType
    note: Optional[str] = Field(None, max_length=500)  # free-text, bounded
|
| 110 |
+
|
| 111 |
+
class CorrectionResponse(BaseModel):
    """Response for POST /correct/{case_id} — acknowledgement only."""
    status: str = "correction_logged"
    case_id: str
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
# ── /health ──────────────────────────────────────────────────
|
| 117 |
+
class HealthResponse(BaseModel):
    """Response for GET /health — startup/liveness summary."""
    status: str
    model_version: str
    uptime_seconds: float
    index_sizes: dict
    coreset_size: int
    threshold_config_version: str
    cache_stats: dict
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
# ── /metrics ─────────────────────────────────────────────────
|
| 128 |
+
class MetricsResponse(BaseModel):
    """Response for GET /metrics — process-local observability counters."""
    request_count: int
    latency_p50_ms: float
    latency_p95_ms: float
    cache_hit_rate: float
    hf_push_failure_count: int
    memory_usage_mb: float
|
api/startup.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# api/startup.py
|
| 2 |
+
# All model and index loading happens here — once at FastAPI startup
|
| 3 |
+
# Everything stays in memory for the entire server lifetime
|
| 4 |
+
# Never load models per-request
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import json
|
| 8 |
+
import time
|
| 9 |
+
import torch
|
| 10 |
+
import clip
|
| 11 |
+
|
| 12 |
+
from src.patchcore import patchcore
|
| 13 |
+
from src.retriever import retriever
|
| 14 |
+
from src.graph import knowledge_graph
|
| 15 |
+
from src.depth import depth_estimator
|
| 16 |
+
from src.xai import gradcam, shap_explainer
|
| 17 |
+
from src.cache import inference_cache
|
| 18 |
+
from src.orchestrator import init_orchestrator
|
| 19 |
+
from api.logger import init_logger
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# Startup timestamp — used for uptime calculation in /health
STARTUP_TIME = None  # set by load_all(); stays None until startup begins
MODEL_VERSION = "v1.0"  # surfaced in /health, /metrics and every inspect response
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def load_all():
    """
    Called once from FastAPI lifespan on startup.
    Order matters — patchcore before orchestrator, logger before anything logs.

    Returns a dict with the loaded CLIP model/preprocess and thresholds.
    NOTE(review): the lifespan caller ignores this return value — the CLIP
    objects are instead injected via init_orchestrator below; confirm the
    return is only for tests/tools.
    """
    global STARTUP_TIME
    STARTUP_TIME = time.time()

    print("=" * 50)
    print("AnomalyOS startup sequence")
    print("=" * 50)

    # ── CPU thread tuning ─────────────────────────────────────
    # HF Spaces CPU Basic = 2 vCPU
    # Limit PyTorch threads to match — prevents over-subscription
    torch.set_num_threads(2)
    torch.set_default_dtype(torch.float32)
    print(f"PyTorch threads: {torch.get_num_threads()}")

    # ── Logger ────────────────────────────────────────────────
    hf_token = os.environ.get("HF_TOKEN", "")
    init_logger(hf_token)

    # ── PatchCore extractor ───────────────────────────────────
    patchcore.load()

    # ── FAISS indexes ─────────────────────────────────────────
    # Index 3 is lazy-loaded — not loaded here
    retriever.load_indexes()

    # ── Knowledge graph ───────────────────────────────────────
    knowledge_graph.load()

    # ── MiDaS depth estimator ─────────────────────────────────
    # Depth is optional — a missing model degrades gracefully.
    try:
        depth_estimator.load()
    except FileNotFoundError as e:
        print(f"WARNING: {e}")
        print("Depth features will return zeros — inference continues")

    # ── CLIP model ────────────────────────────────────────────
    # Loaded here, injected into orchestrator
    print("Loading CLIP ViT-B/32...")
    clip_model, clip_preprocess = clip.load("ViT-B/32", device="cpu")
    clip_model.eval()
    print("CLIP loaded")

    # ── Thresholds ────────────────────────────────────────────
    thresholds_path = os.path.join(
        os.environ.get("DATA_DIR", "data"), "thresholds.json"
    )
    if os.path.exists(thresholds_path):
        with open(thresholds_path) as f:
            thresholds = json.load(f)
        print(f"Thresholds loaded: {len(thresholds)} categories")
    else:
        thresholds = {}
        print("WARNING: thresholds.json not found — using score > 0.5 fallback")

    # ── GradCAM++ ─────────────────────────────────────────────
    # Also optional — Forensics simply omits the GradCAM overlay.
    try:
        gradcam.load()
    except Exception as e:
        print(f"WARNING: GradCAM++ load failed: {e}")
        print("Forensics mode will run without GradCAM++")

    # ── SHAP background ───────────────────────────────────────
    bg_path = os.path.join(
        os.environ.get("DATA_DIR", "data"), "shap_background.npy"
    )
    shap_explainer.load_background(bg_path)

    # ── Inject into orchestrator ──────────────────────────────
    init_orchestrator(clip_model, clip_preprocess, thresholds)

    elapsed = time.time() - STARTUP_TIME
    print("=" * 50)
    print(f"Startup complete in {elapsed:.1f}s")
    print(f"Model version: {MODEL_VERSION}")
    print("=" * 50)

    return {
        "clip_model": clip_model,
        "clip_preprocess": clip_preprocess,
        "thresholds": thresholds
    }
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def get_uptime() -> float:
    """Seconds elapsed since load_all() began; 0.0 before startup runs."""
    return 0.0 if STARTUP_TIME is None else time.time() - STARTUP_TIME
|
app.py
ADDED
|
@@ -0,0 +1,353 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app.py
|
| 2 |
+
# Gradio frontend — 5 tabs
|
| 3 |
+
# Calls FastAPI endpoints running on the same container
|
| 4 |
+
# Launched separately from uvicorn — both run in the same HF Space
|
| 5 |
+
|
| 6 |
+
import gradio as gr
|
| 7 |
+
import httpx
|
| 8 |
+
import base64
|
| 9 |
+
import time
|
| 10 |
+
import json
|
| 11 |
+
import uuid
|
| 12 |
+
from PIL import Image
|
| 13 |
+
import io
|
| 14 |
+
import numpy as np
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# Backend base URL — the FastAPI server runs in the same container.
API_BASE = "http://localhost:7860"
# One session id per Gradio process; sent along with every request.
SESSION_ID = str(uuid.uuid4())

# The 15 MVTec-AD product categories offered as hints in the UI.
CATEGORIES = [
    'bottle', 'cable', 'capsule', 'carpet', 'grid',
    'hazelnut', 'leather', 'metal_nut', 'pill', 'screw',
    'tile', 'toothbrush', 'transistor', 'wood', 'zipper',
]
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# ── Helpers ───────────────────────────────────────────────────
|
| 28 |
+
def b64_to_pil(b64_str: str) -> Image.Image:
    """Decode a base64-encoded image into a PIL Image; None for empty input."""
    if not b64_str:
        return None
    raw = base64.b64decode(b64_str)
    return Image.open(io.BytesIO(raw))
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def call_inspect(image: Image.Image, category_hint: str) -> dict:
    """POST the image to /inspect and return the parsed JSON response.

    Raises ValueError on any non-200 response (message includes a body
    snippet for debugging).
    """
    payload = io.BytesIO()
    image.save(payload, format="JPEG")
    payload.seek(0)

    form = {"category_hint": category_hint or "", "session_id": SESSION_ID}
    with httpx.Client(timeout=120) as client:
        resp = client.post(
            f"{API_BASE}/inspect",
            files={"image": ("image.jpg", payload, "image/jpeg")},
            data=form,
        )
    if resp.status_code != 200:
        raise ValueError(f"Inspect failed: {resp.status_code} {resp.text[:200]}")
    return resp.json()
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def poll_report(report_id: str, max_wait: int = 30) -> str:
    """Poll /report/{report_id} every 500 ms until ready or ~max_wait seconds."""
    attempts = max_wait * 2  # one poll per 500 ms
    with httpx.Client(timeout=10) as client:
        for _ in range(attempts):
            data = client.get(f"{API_BASE}/report/{report_id}").json()
            if data.get("status") == "ready":
                return data.get("report", "No report generated.")
            time.sleep(0.5)
    return "Report generation timed out."
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
# ── Tab 1: Inspector ──────────────────────────────────────────
|
| 64 |
+
def run_inspector(image, category_hint, last_click_state):
    """Run one full inspection round-trip for the Inspector tab.

    Returns a 7-tuple matching the Gradio outputs:
    (heatmap, crop, depth, decision text, LLM report, case id, click timestamp).
    A 3-second debounce guards the Groq-backed backend from rapid re-clicks.
    """
    no_images = (None, None, None)

    if image is None:
        return (*no_images, "Upload an image first.", "", "", None)

    # Debounce — reject a second click within 3 s of the previous one.
    now = time.time()
    if now - (last_click_state or 0) < 3:
        return (*no_images, "⏳ Please wait 3 seconds between requests.", "", "", now)

    try:
        result = call_inspect(image, category_hint)
    except Exception as e:
        return (*no_images, f"❌ Error: {str(e)}", "", "", now)

    # Decode the base64 visual artefacts returned by the API.
    heatmap_img = b64_to_pil(result.get("heatmap_b64"))
    crop_img = b64_to_pil(result.get("defect_crop_b64"))
    depth_img = b64_to_pil(result.get("depth_map_b64"))

    score = result.get("calibrated_score", 0)
    category = result.get("category", "unknown")
    is_anom = result.get("is_anomalous", False)

    verdict = "⚠️ DEFECT DETECTED" if is_anom else "✅ NORMAL"
    decision = f"{verdict} — {category.upper()}"
    score_text = f"Anomaly confidence: {score:.1%}"

    meta = (f"Category: {category} | "
            f"Raw score: {result.get('anomaly_score', 0):.4f} | "
            f"Latency: {result.get('latency_ms', 0):.0f}ms | "
            f"Model: {result.get('model_version', 'v1.0')}")

    # Only anomalous cases get an LLM report — poll until it is ready.
    report = ""
    report_id = result.get("report_id")
    if report_id and is_anom:
        report = poll_report(report_id, max_wait=20)

    # The image hash doubles as the case id for the Forensics tab.
    case_id = result.get("image_hash", "")

    return (heatmap_img, crop_img, depth_img,
            f"{decision}\n{score_text}\n{meta}",
            report, case_id, now)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def build_similar_cases_html(similar_cases: list) -> str:
    """Render up to five retrieved cases as small inline HTML cards."""
    if not similar_cases:
        return "<p>No similar cases retrieved.</p>"

    card = (
        "<div style='margin:8px;padding:8px;border:1px solid #444;border-radius:6px'>"
        "<b>#{rank}</b> {cat} / {dt} | similarity: {sim:.3f}</div>"
    )
    return "".join(
        card.format(
            rank=pos + 1,
            cat=entry.get("category", "?"),
            dt=entry.get("defect_type", "?"),
            sim=entry.get("similarity_score", 0),
        )
        for pos, entry in enumerate(similar_cases[:5])
    )
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
# ── Tab 2: Forensics ──────────────────────────────────────────
|
| 135 |
+
def run_forensics(case_id: str):
    """Fetch the deep-analysis payload for an already inspected case.

    Returns (gradcam image, summary, SHAP JSON, retrieval trace) for the
    Forensics tab; early returns keep the same 4-slot shape with messages
    in the trace slot.
    """
    if not case_id:
        return None, None, "{}", "Enter a case ID from Inspector."

    with httpx.Client(timeout=60) as client:
        resp = client.post(f"{API_BASE}/forensics/{case_id}")

    status = resp.status_code
    if status == 422:
        return None, None, "{}", "Case not found. Run an inspection first."
    if status != 200:
        return None, None, "{}", f"Error: {status}"

    data = resp.json()

    gradcam_img = b64_to_pil(data.get("gradcampp_b64"))
    shap_json = json.dumps(data.get("shap_features", {}), indent=2)

    trace_lines = []
    for rank, hit in enumerate(data.get("retrieval_trace", []), start=1):
        trace_lines.append(
            f"{rank}. {hit.get('category')}/{hit.get('defect_type')} "
            f"(sim={hit.get('similarity_score', 0):.3f}) → {hit.get('graph_path', '')}"
        )
    retrieval_txt = "\n".join(trace_lines)

    summary = (
        f"Category: {data.get('category')} | "
        f"Score: {data.get('anomaly_score', 0):.4f} | "
        f"Calibrated: {data.get('calibrated_score', 0):.3f}"
    )

    return gradcam_img, summary, shap_json, retrieval_txt
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
# ── Tab 3: Analytics ──────────────────────────────────────────
|
| 167 |
+
def load_analytics():
    """Pull /health and /metrics and return a plain-text summary string."""
    try:
        with httpx.Client(timeout=10) as client:
            health = client.get(f"{API_BASE}/health").json()
            mets = client.get(f"{API_BASE}/metrics").json()

        stats_line = (
            f"Requests: {mets.get('request_count',0)} | "
            f"P50: {mets.get('latency_p50_ms',0)}ms | "
            f"P95: {mets.get('latency_p95_ms',0)}ms | "
            f"Cache hit rate: {mets.get('cache_hit_rate',0):.1%} | "
            f"Memory: {mets.get('memory_usage_mb',0):.0f}MB"
        )
        index_line = (
            f"Index sizes: {json.dumps(health.get('index_sizes',{}), indent=2)}"
        )
        return stats_line + "\n\n" + index_line
    except Exception as e:
        # Best-effort display — the Analytics tab should never crash.
        return f"Could not load analytics: {e}"
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
# ── Tab 4: Arena ──────────────────────────────────────────────
|
| 185 |
+
# Per-process arena bookkeeping: current case id and answer streak.
_arena_state = {"case_id": None, "streak": 0}


def get_arena_case(expert_mode: bool):
    """Fetch the next arena case; returns (image, banner label, case_id)."""
    with httpx.Client(timeout=30) as client:
        resp = client.get(f"{API_BASE}/arena/next_case",
                          params={"expert_mode": expert_mode})

    if resp.status_code != 200:
        return None, "Failed to load case.", None

    data = resp.json()
    case_id = data["case_id"]
    _arena_state["case_id"] = case_id

    banner = "⚡ EXPERT CASE" if data.get("expert_mode") else "Standard case"
    return b64_to_pil(data["image_b64"]), banner, case_id
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def submit_arena(user_rating: int, user_severity: int, case_id: str):
    """Submit the user's arena verdict and return the reveal.

    Returns (result text, SHAP explanation text, heatmap image) for the
    three Gradio output widgets.
    """
    if not case_id:
        return "Load a case first.", "", None

    body = {"user_rating": user_rating,
            "user_severity": user_severity,
            "session_id": SESSION_ID}
    with httpx.Client(timeout=60) as client:
        resp = client.post(f"{API_BASE}/arena/submit/{case_id}", json=body)

    if resp.status_code != 200:
        return f"Error: {resp.status_code}", "", None

    data = resp.json()
    correct_label = data.get("correct_label", 0)
    got_it = int(user_rating) == correct_label

    result_txt = (
        f"{'✅ CORRECT' if got_it else '❌ WRONG'}\n"
        f"Ground truth: {'DEFECTIVE' if correct_label else 'NORMAL'}\n"
        f"AI confidence: {data.get('calibrated_score', 0):.1%}\n"
        f"Your score: {data.get('user_score', 0):.1f} | "
        f"Streak: 🔥 {data.get('streak', 0)}"
    )

    # One line per top SHAP feature, signed contribution to 4 dp.
    shap_txt = "".join(
        f"{feat['feature']}: {feat['contribution']:+.4f}\n"
        for feat in data.get("top_shap_features", [])
    )

    return (result_txt,
            f"Why the AI scored this:\n{shap_txt}",
            b64_to_pil(data.get("heatmap_b64")))
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
# ── Tab 5: Knowledge Base ─────────────────────────────────────
|
| 241 |
+
def search_knowledge(query: str, category: str, defect_type: str):
    """Query /knowledge/search and format the hits as a bullet list."""
    # Build query params, skipping empty fields and the "All" category filter.
    params = {
        key: val
        for key, val in (
            ("query", query or None),
            ("category", category if category and category != "All" else None),
            ("defect_type", defect_type or None),
        )
        if val
    }

    with httpx.Client(timeout=30) as client:
        resp = client.get(f"{API_BASE}/knowledge/search", params=params)

    if resp.status_code != 200:
        return f"Search failed: {resp.status_code}"

    data = resp.json()
    hits = data.get("results", [])
    if not hits:
        return "No results found."

    lines = [f"Found {data.get('total_found', 0)} results:\n"]
    lines.extend(
        f"• {hit.get('category','?')} / {hit.get('defect_type','?')} "
        f"| severity: {hit.get('severity_min',0):.1f}–{hit.get('severity_max',1):.1f}"
        for hit in hits[:20]
    )
    return "\n".join(lines)
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
# ── Build Gradio UI ───────────────────────────────────────────
|
| 273 |
+
with gr.Blocks(title="AnomalyOS", theme=gr.themes.Soft()) as demo:
|
| 274 |
+
|
| 275 |
+
gr.Markdown("# 🔍 AnomalyOS — Industrial Visual Intelligence Platform")
|
| 276 |
+
gr.Markdown("*Zero training on defects. The AI only knows normal.*")
|
| 277 |
+
|
| 278 |
+
with gr.Tabs():
|
| 279 |
+
|
| 280 |
+
# ── INSPECTOR TAB ─────────────────────────────────────
|
| 281 |
+
with gr.Tab("🔬 Inspector"):
|
| 282 |
+
with gr.Row():
|
| 283 |
+
with gr.Column(scale=1):
|
| 284 |
+
inp_image = gr.Image(type="pil", label="Upload Product Image")
|
| 285 |
+
inp_category = gr.Dropdown(
|
| 286 |
+
choices=[""] + CATEGORIES,
|
| 287 |
+
label="Category hint (optional)",
|
| 288 |
+
value=""
|
| 289 |
+
)
|
| 290 |
+
btn_inspect = gr.Button("🔍 Inspect", variant="primary")
|
| 291 |
+
gr.Markdown("*3-second cooldown between requests*")
|
| 292 |
+
|
| 293 |
+
with gr.Column(scale=2):
|
| 294 |
+
out_heatmap = gr.Image(label="Anomaly Heatmap")
|
| 295 |
+
out_crop = gr.Image(label="Defect Crop")
|
| 296 |
+
out_depth = gr.Image(label="Depth Map")
|
| 297 |
+
out_decision = gr.Textbox(label="Result", lines=3)
|
| 298 |
+
out_report = gr.Textbox(label="AI Defect Report", lines=5)
|
| 299 |
+
out_case_id = gr.Textbox(label="Case ID (use in Forensics)",
|
| 300 |
+
interactive=False)
|
| 301 |
+
|
| 302 |
+
# Correction widget
|
| 303 |
+
with gr.Accordion("⚠️ Is this wrong?", open=False):
|
| 304 |
+
corr_type = gr.Dropdown(
|
| 305 |
+
choices=["false_positive", "false_negative", "wrong_category"],
|
| 306 |
+
label="Correction type"
|
| 307 |
+
)
|
| 308 |
+
corr_note = gr.Textbox(label="Optional note", max_lines=2)
|
| 309 |
+
btn_corr = gr.Button("Submit Correction")
|
| 310 |
+
corr_out = gr.Textbox(label="Status", interactive=False)
|
| 311 |
+
|
| 312 |
+
# State
|
| 313 |
+
last_click = gr.State(value=0)
|
| 314 |
+
|
| 315 |
+
btn_inspect.click(
|
| 316 |
+
fn=run_inspector,
|
| 317 |
+
inputs=[inp_image, inp_category, last_click],
|
| 318 |
+
outputs=[out_heatmap, out_crop, out_depth,
|
| 319 |
+
out_decision, out_report, out_case_id, last_click]
|
| 320 |
+
)
|
| 321 |
+
|
| 322 |
+
# ── FORENSICS TAB ─────────────────────────────────────
|
| 323 |
+
with gr.Tab("🧬 Forensics"):
|
| 324 |
+
with gr.Row():
|
| 325 |
+
f_case_input = gr.Textbox(
|
| 326 |
+
label="Case ID (paste from Inspector)",
|
| 327 |
+
placeholder="SHA256 hash from Inspector result"
|
| 328 |
+
)
|
| 329 |
+
btn_forensics = gr.Button("🔬 Deep Analyse", variant="primary")
|
| 330 |
+
|
| 331 |
+
with gr.Row():
|
| 332 |
+
f_gradcam = gr.Image(label="GradCAM++ Overlay")
|
| 333 |
+
f_summary = gr.Textbox(label="Case Summary", lines=2)
|
| 334 |
+
|
| 335 |
+
with gr.Row():
|
| 336 |
+
f_shap = gr.Code(label="SHAP Features (JSON)",
|
| 337 |
+
language="json")
|
| 338 |
+
f_retrieval = gr.Textbox(label="Retrieval Trace", lines=8)
|
| 339 |
+
|
| 340 |
+
btn_forensics.click(
|
| 341 |
+
fn=run_forensics,
|
| 342 |
+
inputs=[f_case_input],
|
| 343 |
+
outputs=[f_gradcam, f_summary, f_shap, f_retrieval]
|
| 344 |
+
)
|
| 345 |
+
|
| 346 |
+
# ── ANALYTICS TAB ─────────────────────────────────────
|
| 347 |
+
with gr.Tab("📊 Analytics"):
|
| 348 |
+
btn_refresh = gr.Button("🔄 Refresh")
|
| 349 |
+
analytics_out = gr.Textbox(label="System Stats", lines=15)
|
| 350 |
+
|
| 351 |
+
btn_refresh.click(
|
| 352 |
+
fn=load_analytics,
|
| 353 |
+
inputs=[],
|
bug_log.md
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Bug Log
|
| 2 |
+
|
| 3 |
+
## Known Issues
|
| 4 |
+
|
| 5 |
+
### Version 1.0.0
|
| 6 |
+
|
| 7 |
+
#### High Priority
|
| 8 |
+
- [ ] FAISS index corruption under concurrent access - implement read-write locks
|
| 9 |
+
- [ ] Memory leak in PatchCore batch inference - investigate tensor cleanup
|
| 10 |
+
|
| 11 |
+
#### Medium Priority
|
| 12 |
+
- [ ] Knowledge graph query timeouts on large graphs (>100k nodes)
|
| 13 |
+
- [ ] LLM API rate limiting causes intermittent 429 errors
|
| 14 |
+
- [ ] XAI heatmap artifacts on boundary patches
|
| 15 |
+
|
| 16 |
+
#### Low Priority
|
| 17 |
+
- [ ] Windows path handling in data pipeline
|
| 18 |
+
- [ ] Inconsistent logging timestamps in distributed setup
|
| 19 |
+
- [ ] Docker build optimization for faster iterations
|
| 20 |
+
|
| 21 |
+
## Resolved Issues
|
| 22 |
+
|
| 23 |
+
### Version 0.9.0
|
| 24 |
+
- ✓ Fixed numerical instability in patch normalization
|
| 25 |
+
- ✓ Corrected FAISS serialization for multi-GPU setups
|
| 26 |
+
- ✓ Improved knowledge graph construction memory usage
|
| 27 |
+
|
| 28 |
+
## Test Coverage
|
| 29 |
+
|
| 30 |
+
- Unit Tests: 75%
|
| 31 |
+
- Integration Tests: 60%
|
| 32 |
+
- E2E Tests: 40%
|
| 33 |
+
|
| 34 |
+
## To Do
|
| 35 |
+
|
| 36 |
+
- [ ] Add distributed inference support
|
| 37 |
+
- [ ] Implement federated learning capability
|
| 38 |
+
- [ ] Add real-time performance monitoring dashboard
|
| 39 |
+
- [ ] Create mobile inference client
|
| 40 |
+
- [ ] Optimize FAISS index structure for faster queries
|
conftest.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# conftest.py
# Make the project root importable so `src`/`api` imports resolve in CI,
# regardless of the directory pytest is invoked from.
import os
import sys

# FIX: os.path.dirname(__file__) can be "" (a relative, cwd-dependent path)
# when Python is launched from this directory; abspath makes it stable.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
docker/Dockerfile
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.10-slim

WORKDIR /app

# Stream logs unbuffered (container logs appear in real time) and skip
# .pyc generation to keep layers slim.
ENV PYTHONUNBUFFERED=1 \
    PYTHONDONTWRITEBYTECODE=1

# System dependencies: git for VCS-pinned pip packages, curl for the
# healthcheck below. --no-install-recommends keeps the image small.
RUN apt-get update && apt-get install -y --no-install-recommends \
    git \
    curl \
    && rm -rf /var/lib/apt/lists/*

# Install Python dependencies before copying code so this layer is
# cached across source-only changes.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy application code
COPY . .

# Runtime directories for logs, model artifacts, data and drift reports
RUN mkdir -p logs models data reports

# Expose API port
# NOTE(review): the Gradio frontend targets http://localhost:7860 —
# confirm 8000 is the intended served port here.
EXPOSE 8000

# Health check against the FastAPI /health endpoint
HEALTHCHECK --interval=30s --timeout=10s --start-period=40s --retries=3 \
    CMD curl -f http://localhost:8000/health || exit 1

# Run the application
CMD ["python", "app.py"]
|
mlops/evaluate_retrieval.py
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mlops/evaluate_retrieval.py
|
| 2 |
+
# Retrieval quality evaluation
|
| 3 |
+
# Metrics: MRR, Precision@1, Precision@5
|
| 4 |
+
# Run on 50 manually labelled retrieval questions
|
| 5 |
+
# Logged to MLflow on DagsHub
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import json
|
| 9 |
+
import numpy as np
|
| 10 |
+
import mlflow
|
| 11 |
+
import dagshub
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def evaluate_retrieval(index2_metadata_path: str,
                       index2_faiss_path: str,
                       clip_model=None,
                       clip_preprocess=None):
    """
    Evaluate retrieval quality of Index 2 (the defect FAISS index).

    Uses 50 hand-labelled (category, defect_type) pairs; for each pair,
    the stored embedding of one matching record serves as the query.

    Metrics:
    - Precision@1: is the top result from the correct category?
    - Precision@5: fraction of top 5 from the correct category
    - MRR: Mean Reciprocal Rank of the first correct result

    Returns a dict with the metrics and the number of evaluated queries
    ({} if the index file is missing). Results are also logged to MLflow
    on DagsHub (best-effort).

    Note: clip_model / clip_preprocess are accepted for interface
    compatibility but unused — queries come from stored embeddings.
    """
    import faiss

    # ── 50 labelled evaluation queries ────────────────────────
    # Each entry names the category/defect pair the query should retrieve.
    EVAL_QUERIES = [
        {"category": "bottle", "defect_type": "broken_large"},
        {"category": "bottle", "defect_type": "contamination"},
        {"category": "cable", "defect_type": "bent_wire"},
        {"category": "cable", "defect_type": "missing_wire"},
        {"category": "capsule", "defect_type": "crack"},
        {"category": "capsule", "defect_type": "scratch"},
        {"category": "carpet", "defect_type": "hole"},
        {"category": "carpet", "defect_type": "cut"},
        {"category": "grid", "defect_type": "broken"},
        {"category": "grid", "defect_type": "bent"},
        {"category": "hazelnut", "defect_type": "crack"},
        {"category": "hazelnut", "defect_type": "hole"},
        {"category": "leather", "defect_type": "cut"},
        {"category": "leather", "defect_type": "fold"},
        {"category": "metal_nut", "defect_type": "bent"},
        {"category": "metal_nut", "defect_type": "scratch"},
        {"category": "pill", "defect_type": "crack"},
        {"category": "pill", "defect_type": "contamination"},
        {"category": "screw", "defect_type": "scratch_head"},
        {"category": "screw", "defect_type": "thread_top"},
        {"category": "tile", "defect_type": "crack"},
        {"category": "tile", "defect_type": "oil"},
        {"category": "toothbrush", "defect_type": "defective"},
        {"category": "transistor", "defect_type": "bent_lead"},
        {"category": "transistor", "defect_type": "damaged_case"},
        {"category": "wood", "defect_type": "hole"},
        {"category": "wood", "defect_type": "scratch"},
        {"category": "zipper", "defect_type": "broken_teeth"},
        {"category": "zipper", "defect_type": "split_teeth"},
        {"category": "bottle", "defect_type": "broken_small"},
        {"category": "cable", "defect_type": "cut_outer_insulation"},
        {"category": "capsule", "defect_type": "faulty_imprint"},
        {"category": "carpet", "defect_type": "color"},
        {"category": "grid", "defect_type": "glue"},
        {"category": "hazelnut", "defect_type": "print"},
        {"category": "leather", "defect_type": "glue"},
        {"category": "metal_nut", "defect_type": "flip"},
        {"category": "pill", "defect_type": "faulty_imprint"},
        {"category": "screw", "defect_type": "thread_side"},
        {"category": "tile", "defect_type": "rough"},
        {"category": "wood", "defect_type": "color"},
        {"category": "zipper", "defect_type": "fabric_border"},
        {"category": "cable", "defect_type": "poke_insulation"},
        {"category": "capsule", "defect_type": "poke"},
        {"category": "carpet", "defect_type": "thread"},
        {"category": "grid", "defect_type": "metal_contamination"},
        {"category": "leather", "defect_type": "poke"},
        {"category": "metal_nut", "defect_type": "color"},
        {"category": "pill", "defect_type": "scratch"},
        {"category": "transistor", "defect_type": "misplaced"},
    ]

    # Load Index 2
    if not os.path.exists(index2_faiss_path):
        print(f"Index 2 not found: {index2_faiss_path}")
        return {}

    index2 = faiss.read_index(index2_faiss_path)

    with open(index2_metadata_path) as f:
        metadata = json.load(f)

    precision_at_1 = []
    precision_at_5 = []
    reciprocal_ranks = []

    for query_info in EVAL_QUERIES:
        q_cat = query_info["category"]
        q_defect = query_info["defect_type"]

        # Find a matching record in metadata to use as the query
        query_meta = next(
            (m for m in metadata
             if m.get("category") == q_cat
             and q_defect in m.get("defect_type", "")),
            None
        )
        if query_meta is None:
            continue

        query_idx = query_meta["index"]

        # FIX: the original used an all-zeros query vector, so every search
        # returned the same neighbours and the metrics were meaningless.
        # Flat FAISS indexes can return stored vectors via reconstruct();
        # fall back to zeros only if this index type cannot reconstruct.
        try:
            query_vec = np.asarray(
                index2.reconstruct(int(query_idx)), dtype=np.float32
            ).reshape(1, -1)
        except Exception as e:
            print(f"WARNING: reconstruct({query_idx}) failed ({e}) — "
                  f"falling back to zero vector")
            query_vec = np.zeros((1, index2.d), dtype=np.float32)

        # k=6: one extra so the query's own record can be dropped below.
        D, I = index2.search(query_vec, k=6)

        # Skip self-match
        retrieved = [
            metadata[i] for i in I[0]
            if i >= 0 and i != query_idx
        ][:5]

        if not retrieved:
            continue

        # Precision@1
        p1 = 1.0 if retrieved[0].get("category") == q_cat else 0.0
        precision_at_1.append(p1)

        # Precision@5
        correct = sum(1 for r in retrieved if r.get("category") == q_cat)
        precision_at_5.append(correct / min(5, len(retrieved)))

        # MRR: reciprocal rank of the first same-category hit
        rr = 0.0
        for rank, r in enumerate(retrieved, 1):
            if r.get("category") == q_cat:
                rr = 1.0 / rank
                break
        reciprocal_ranks.append(rr)

    results = {
        "precision_at_1": float(np.mean(precision_at_1)) if precision_at_1 else 0.0,
        "precision_at_5": float(np.mean(precision_at_5)) if precision_at_5 else 0.0,
        "mrr": float(np.mean(reciprocal_ranks)) if reciprocal_ranks else 0.0,
        "n_evaluated": len(precision_at_1)
    }

    print("Retrieval Evaluation Results:")
    print(f"  Precision@1: {results['precision_at_1']:.4f}")
    print(f"  Precision@5: {results['precision_at_5']:.4f}")
    print(f"  MRR: {results['mrr']:.4f}")
    print(f"  Evaluated: {results['n_evaluated']} queries")

    # Log to MLflow (best-effort — never fail the evaluation on logging)
    try:
        dagshub.init(repo_owner="devangmishra1424",
                     repo_name="AnomalyOS", mlflow=True)
        with mlflow.start_run(run_name="retrieval_evaluation"):
            mlflow.log_metrics(results)
            print("Logged to MLflow")
    except Exception as e:
        print(f"MLflow logging failed: {e}")

    return results
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
if __name__ == "__main__":
    # Default artifact locations produced by the enrichment pipeline.
    evaluate_retrieval(
        index2_metadata_path="data/index2_metadata.json",
        index2_faiss_path="data/index2_defect.faiss",
    )
|
mlops/evidently_drift.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mlops/evidently_drift.py
|
| 2 |
+
# Drift monitoring using Evidently AI
|
| 3 |
+
# Reference dataset: first 200 inference records
|
| 4 |
+
# Current dataset: most recent 200 records
|
| 5 |
+
# Run locally or triggered via "Simulate Drift" button in Analytics tab
|
| 6 |
+
#
|
| 7 |
+
# DOCUMENTED AS SIMULATED DRIFT for portfolio demonstration
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
import json
|
| 11 |
+
import numpy as np
|
| 12 |
+
import pandas as pd
|
| 13 |
+
from evidently.report import Report
|
| 14 |
+
from evidently.metric_preset import DataDriftPreset
|
| 15 |
+
from evidently import ColumnMapping
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# Paths and the numeric columns monitored for drift.
LOG_PATH = "logs/inference.jsonl"
REPORT_PATH = "reports/drift_report.html"
DRIFT_COLS = [
    "anomaly_score",
    "calibrated_score",
    "latency_ms"
]


def load_logs(n: int = None) -> pd.DataFrame:
    """Load inference logs from LOG_PATH as a DataFrame.

    Args:
        n: if given, return only the most recent ``n`` records.
           FIX: ``n=0`` previously returned the FULL frame (falsy check);
           it now returns an empty frame, the literal meaning of "last 0".

    Returns:
        DataFrame of parsed JSONL records; empty if the file is missing
        or contains no valid lines. Malformed lines are skipped silently
        (the logger may be writing concurrently).
    """
    if not os.path.exists(LOG_PATH):
        print(f"Log file not found: {LOG_PATH}")
        return pd.DataFrame()

    records = []
    with open(LOG_PATH) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                records.append(json.loads(line))
            except json.JSONDecodeError:
                # Tolerate partially written / corrupt lines.
                continue

    if not records:
        return pd.DataFrame()

    df = pd.DataFrame(records)
    return df.tail(n) if n is not None else df
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def simulate_drift(df: pd.DataFrame) -> pd.DataFrame:
    """
    Inject 50 OOD records to simulate distribution drift.
    DOCUMENTED AS SIMULATED everywhere — not real production drift.
    """
    n_ood = 50
    # Column-wise construction: same schema and ranges as the per-record
    # version, built in one shot.
    ood_df = pd.DataFrame({
        "anomaly_score": np.random.uniform(0.8, 1.5, size=n_ood),
        "calibrated_score": np.random.uniform(0.8, 1.0, size=n_ood),
        "latency_ms": np.random.uniform(500, 2000, size=n_ood),
        "category": ["unknown"] * n_ood,
        "is_anomalous": [True] * n_ood,
        "mode": ["simulated_ood"] * n_ood,
    })
    return pd.concat([df, ood_df], ignore_index=True)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def run_drift_report(simulate: bool = False):
    """
    Generate Evidently drift report.
    simulate=True: inject 50 OOD records into current window.
    """
    df = load_logs()

    if len(df) < 50:
        print(f"Not enough logs for drift analysis. "
              f"Need 50+, have {len(df)}.")
        print("Run some inspections first, or use simulate=True")
        if not simulate:
            return

    # Coerce every drift column to numeric, creating missing ones as 0.0.
    for col in DRIFT_COLS:
        if col not in df.columns:
            df[col] = 0.0
        df[col] = pd.to_numeric(df[col], errors="coerce").fillna(0.0)

    # Reference window = oldest half, current window = newest half
    # (both capped at 200 rows).
    half = min(200, len(df) // 2)
    reference = df.head(half)[DRIFT_COLS]
    current = df.tail(half)[DRIFT_COLS]

    if simulate:
        print("Simulating drift — injecting 50 OOD records...")
        ood_df = simulate_drift(pd.DataFrame())[DRIFT_COLS]
        current = pd.concat([current, ood_df], ignore_index=True)

    print(f"Reference: {len(reference)} records")
    print(f"Current: {len(current)} records")

    # Build Evidently report
    report = Report(metrics=[DataDriftPreset()])
    report.run(reference_data=reference, current_data=current)

    os.makedirs("reports", exist_ok=True)
    report.save_html(REPORT_PATH)

    print(f"Drift report saved: {REPORT_PATH}")
    print("NOTE: This is simulated drift for portfolio demonstration.")
    return REPORT_PATH


if __name__ == "__main__":
    import sys
    simulate = "--simulate" in sys.argv
    run_drift_report(simulate=simulate)
|
mlops/optuna_tuner.py
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mlops/optuna_tuner.py
|
| 2 |
+
# Optuna hyperparameter search for EfficientNet-B0 fine-tuning
|
| 3 |
+
# 10 trials: lr, dropout, batch_size
|
| 4 |
+
# All trials logged to MLflow on DagsHub
|
| 5 |
+
# Run on Kaggle T4 — not locally
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import optuna
|
| 9 |
+
import mlflow
|
| 10 |
+
import dagshub
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import torchvision.models as models
|
| 14 |
+
import torchvision.transforms as T
|
| 15 |
+
from torch.utils.data import DataLoader, Dataset
|
| 16 |
+
from PIL import Image
|
| 17 |
+
import numpy as np
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
MVTEC_PATH = os.environ.get("MVTEC_PATH", "/kaggle/input/datasets/ipythonx/mvtec-ad")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
N_TRIALS = 10


class MVTecBinaryDataset(Dataset):
    """
    Binary classification dataset: normal=0, defective=1.
    Used only for EfficientNet fine-tuning (GradCAM++ quality).
    NOT used for PatchCore training.

    Fix vs original: directory listings are sorted. os.listdir order is
    filesystem-dependent, so the unsorted version produced a different
    sample order per machine, which made the (seeded) train/val split
    non-reproducible across environments.
    """

    def __init__(self, mvtec_path: str, transform=None):
        """
        Args:
            mvtec_path: Root directory of the MVTec AD dataset.
            transform: Optional transform applied to each PIL image.
        """
        self.samples = []
        self.transform = transform
        categories = [
            'bottle', 'cable', 'capsule', 'carpet', 'grid', 'hazelnut',
            'leather', 'metal_nut', 'pill', 'screw', 'tile', 'toothbrush',
            'transistor', 'wood', 'zipper'
        ]

        for cat in categories:
            # Normal samples: <cat>/train/good → label 0
            train_dir = os.path.join(mvtec_path, cat, "train", "good")
            for f in sorted(os.listdir(train_dir)):
                if f.endswith((".png", ".jpg")):
                    self.samples.append(
                        (os.path.join(train_dir, f), 0)
                    )
            # Defective samples: every <cat>/test/<defect> except "good" → label 1
            test_dir = os.path.join(mvtec_path, cat, "test")
            for defect_type in sorted(os.listdir(test_dir)):
                if defect_type == "good":
                    continue
                d_dir = os.path.join(test_dir, defect_type)
                for f in sorted(os.listdir(d_dir)):
                    if f.endswith((".png", ".jpg")):
                        self.samples.append(
                            (os.path.join(d_dir, f), 1)
                        )

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, label = self.samples[idx]
        img = Image.open(path).convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, label
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def build_model(dropout: float) -> nn.Module:
    """EfficientNet-B0 with a fresh 2-class head, moved to DEVICE.

    Args:
        dropout: Dropout probability applied before the final linear layer.
    """
    backbone = models.efficientnet_b0(pretrained=True)
    # Replace the ImageNet classifier with a binary (normal/defective) head.
    head = nn.Sequential(
        nn.Dropout(p=dropout),
        nn.Linear(1280, 2)
    )
    backbone.classifier = head
    return backbone.to(DEVICE)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def train_one_trial(trial):
    """Single Optuna trial — trains 3 epochs, returns validation AUC.

    Samples lr (log-uniform), dropout, and batch_size; logs each trial
    as a nested MLflow run.

    Robustness fix: roc_auc_score raises ValueError when the validation
    fold contains a single class, which previously aborted the whole
    study; that case now scores chance level (0.5).
    """
    from sklearn.metrics import roc_auc_score  # hoisted out of the metric block

    lr = trial.suggest_float("lr", 1e-4, 1e-2, log=True)
    dropout = trial.suggest_float("dropout", 0.2, 0.5)
    batch_size = trial.suggest_categorical("batch_size", [16, 32])

    transform = T.Compose([
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    dataset = MVTecBinaryDataset(MVTEC_PATH, transform=transform)
    n_val = int(0.2 * len(dataset))
    n_train = len(dataset) - n_val
    # Fixed seed keeps the same val set across trials (fair comparison).
    train_set, val_set = torch.utils.data.random_split(
        dataset, [n_train, n_val],
        generator=torch.Generator().manual_seed(42)
    )

    train_loader = DataLoader(train_set, batch_size=batch_size,
                              shuffle=True, num_workers=2)
    val_loader = DataLoader(val_set, batch_size=batch_size,
                            shuffle=False, num_workers=2)

    model = build_model(dropout)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    # Train 3 epochs per trial — enough to rank hyperparameters cheaply.
    for epoch in range(3):
        model.train()
        for imgs, labels in train_loader:
            imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            loss = criterion(model(imgs), labels)
            loss.backward()
            optimizer.step()

    # Validate: collect P(defective) for every val image.
    model.eval()
    all_scores = []
    all_labels = []
    with torch.no_grad():
        for imgs, labels in val_loader:
            imgs = imgs.to(DEVICE)
            logits = model(imgs)
            probs = torch.softmax(logits, dim=1)[:, 1]
            all_scores.extend(probs.cpu().numpy().tolist())
            all_labels.extend(labels.numpy().tolist())

    if len(set(all_labels)) < 2:
        # Degenerate fold (one class only): AUC undefined, report chance.
        auc = 0.5
    else:
        auc = roc_auc_score(all_labels, all_scores)

    # Log trial to MLflow
    with mlflow.start_run(run_name=f"efficientnet_trial_{trial.number}",
                          nested=True):
        mlflow.log_param("lr", lr)
        mlflow.log_param("dropout", dropout)
        mlflow.log_param("batch_size", batch_size)
        mlflow.log_metric("val_auc", auc)

    return auc
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def run_optuna_search():
    """Run the N_TRIALS Optuna study under one parent MLflow run.

    Returns the best trial's hyperparameters.
    """
    # DagsHub wires MLflow tracking to the remote repo.
    dagshub.init(repo_owner="devangmishra1424",
                 repo_name="AnomalyOS", mlflow=True)

    with mlflow.start_run(run_name="efficientnet_optuna_search"):
        study = optuna.create_study(direction="maximize")
        study.optimize(train_one_trial, n_trials=N_TRIALS)

        best = study.best_trial
        print(f"\nBest trial: AUC={best.value:.4f}")
        print(f"  lr={best.params['lr']:.6f}")
        print(f"  dropout={best.params['dropout']:.3f}")
        print(f"  batch_size={best.params['batch_size']}")

        mlflow.log_metric("best_val_auc", best.value)
        mlflow.log_params(best.params)

    return best.params


if __name__ == "__main__":
    run_optuna_search()
|
mlops/promote_model.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Model promotion and deployment
|
| 3 |
+
"""
|
| 4 |
+
import logging
|
| 5 |
+
from typing import Dict, Any
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
|
| 8 |
+
logger = logging.getLogger(__name__)

# Metrics where a SMALLER value is better (latency-style). Everything
# else is treated as higher-is-better.
_LOWER_IS_BETTER = {"inference_time"}


class ModelPromoter:
    """
    Handles model promotion and versioning for deployment.
    """

    def __init__(self, models_dir: str = "models/"):
        """
        Initialize model promoter.

        Args:
            models_dir: Directory containing models
        """
        self.models_dir = models_dir
        logger.info(f"ModelPromoter initialized with models directory: {models_dir}")

    def evaluate_model_quality(self, model_metrics: Dict[str, float], thresholds: Dict[str, float]) -> bool:
        """
        Evaluate if model meets quality thresholds.

        Bug fix: the original compared every metric as higher-is-better,
        so a FAST model (inference_time below the 150 ms ceiling) was
        reported as failing. Metrics in _LOWER_IS_BETTER now pass when
        actual <= threshold. A metric missing from model_metrics always
        fails (the original defaulted it to 0.0, which would have
        vacuously passed lower-is-better checks).

        Args:
            model_metrics: Model performance metrics
            thresholds: Minimum acceptable values (maximum for
                latency-style metrics)

        Returns:
            True if model passes all quality checks
        """
        logger.info("Evaluating model quality...")

        passes_all = True
        for metric, threshold in thresholds.items():
            actual = model_metrics.get(metric)
            if actual is None:
                logger.warning(f"Model fails {metric} check: metric missing")
                passes_all = False
                continue
            if metric in _LOWER_IS_BETTER:
                ok = actual <= threshold
            else:
                ok = actual >= threshold
            if ok:
                logger.info(f"Model passes {metric} check: {actual} vs {threshold}")
            else:
                logger.warning(f"Model fails {metric} check: {actual} vs {threshold}")
                passes_all = False

        return passes_all

    def promote_model(self, model_name: str, version: str, metrics: Dict[str, float]) -> bool:
        """
        Promote model to production.

        Args:
            model_name: Name of the model
            version: Model version
            metrics: Performance metrics

        Returns:
            True if promotion successful
        """
        logger.info(f"Promoting model {model_name} v{version} to production")

        # Quality gates: minimum AUROC/F1, maximum latency.
        thresholds = {
            "auroc": 0.90,
            "f1_score": 0.85,
            "inference_time": 150  # milliseconds — upper bound (see _LOWER_IS_BETTER)
        }

        # Check quality
        if not self.evaluate_model_quality(metrics, thresholds):
            logger.error("Model does not meet quality thresholds")
            return False

        # Promote model
        try:
            promotion_record = {
                "model_name": model_name,
                "version": version,
                "promoted_at": datetime.now().isoformat(),
                "metrics": metrics,
                "status": "promoted"
            }
            # NOTE(review): promotion_record is built but never persisted —
            # presumably it should be written under self.models_dir. TODO confirm.
            logger.info(f"Model promoted successfully: {model_name} v{version}")
            return True
        except Exception as e:
            logger.error(f"Model promotion failed: {e}")
            return False

    def rollback_model(self, model_name: str, target_version: str) -> bool:
        """
        Rollback to a previous model version.

        Args:
            model_name: Name of the model
            target_version: Version to rollback to

        Returns:
            True if rollback successful
        """
        logger.info(f"Rolling back model {model_name} to version {target_version}")

        try:
            # Implementation for model rollback
            logger.info(f"Model rolled back successfully: {model_name} to v{target_version}")
            return True
        except Exception as e:
            logger.error(f"Model rollback failed: {e}")
            return False

    def compare_models(self, model1_metrics: Dict, model2_metrics: Dict) -> Dict[str, Any]:
        """
        Compare two model versions.

        Args:
            model1_metrics: Metrics of first model
            model2_metrics: Metrics of second model

        Returns:
            Comparison report: ``{metric}_diff`` = model2 - model1 for
            every metric present in model1_metrics.
        """
        logger.info("Comparing model versions...")

        comparison = {}
        for metric in model1_metrics.keys():
            diff = model2_metrics.get(metric, 0) - model1_metrics.get(metric, 0)
            comparison[f"{metric}_diff"] = diff

        return comparison
|
model_card.md
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# AnomalyOS Model Card
|
| 2 |
+
|
| 3 |
+
## Model Details
|
| 4 |
+
|
| 5 |
+
### Model Description
|
| 6 |
+
AnomalyOS is an advanced anomaly detection system for surface defect detection. It combines patch-based deep learning (PatchCore), knowledge graphs, and retrieval-augmented generation with explainable AI techniques.
|
| 7 |
+
|
| 8 |
+
### Model Type
|
| 9 |
+
- **Primary**: Patch-based Convolutional Neural Network
|
| 10 |
+
- **Retrieval**: FAISS Vector Search + Knowledge Graph
|
| 11 |
+
- **Explainability**: Gradient-based + Attention Heatmaps
|
| 12 |
+
|
| 13 |
+
## Intended Use
|
| 14 |
+
|
| 15 |
+
### Primary Use Cases
|
| 16 |
+
- Surface defect detection in manufacturing
|
| 17 |
+
- Quality control automation
|
| 18 |
+
- Real-time anomaly detection
|
| 19 |
+
|
| 20 |
+
### Out-of-scope Use Cases
|
| 21 |
+
- Medical image analysis (without domain-specific validation)
|
| 22 |
+
- Safety-critical autonomous systems (without additional verification)
|
| 23 |
+
|
| 24 |
+
## Training Data
|
| 25 |
+
|
| 26 |
+
### Dataset
|
| 27 |
+
- **Source**: MVTec AD Dataset + Custom Industrial Data
|
| 28 |
+
- **Categories**: 15 object categories (bottle, carpet, wood, etc.)
|
| 29 |
+
- **Training Samples**: ~4,000 images per category
|
| 30 |
+
- **Image Resolution**: 256x256 to 1024x1024 pixels
|
| 31 |
+
|
| 32 |
+
### Data Processing
|
| 33 |
+
- Normalization: ImageNet statistics
|
| 34 |
+
- Augmentation: Random crops, flips, rotations
|
| 35 |
+
- Train/Val/Test Split: 70/15/15
|
| 36 |
+
|
| 37 |
+
## Model Performance
|
| 38 |
+
|
| 39 |
+
### Metrics
|
| 40 |
+
- **AUROC**: 0.95+ (average across categories)
|
| 41 |
+
- **Detection F1**: 0.92+ (at IoU >= 0.5)
|
| 42 |
+
- **Inference Time**: ~100ms per image (on GPU)
|
| 43 |
+
|
| 44 |
+
### Performance by Category
|
| 45 |
+
See detailed performance metrics in reports/performance_metrics.json
|
| 46 |
+
|
| 47 |
+
## Limitations
|
| 48 |
+
|
| 49 |
+
1. Performance may degrade on images with significant lighting variations
|
| 50 |
+
2. Requires object segmentation for optimal results
|
| 51 |
+
3. Not validated for extreme manufacturing conditions
|
| 52 |
+
4. Knowledge graph coverage depends on training data completeness
|
| 53 |
+
|
| 54 |
+
## Ethical Considerations
|
| 55 |
+
|
| 56 |
+
- Model predictions should always be validated by human experts
|
| 57 |
+
- Use should comply with data protection and privacy regulations
|
| 58 |
+
- Potential for automation bias — schedule regular performance audits to detect and correct over-reliance on model output
|
| 59 |
+
|
| 60 |
+
## Updates
|
| 61 |
+
|
| 62 |
+
- **Version**: 1.0.0
|
| 63 |
+
- **Last Updated**: 2024-03-31
|
| 64 |
+
- **Next Review**: 2024-09-30
|
requirements.txt
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch==2.0.0
|
| 2 |
+
torchvision==0.15.1
|
| 3 |
+
faiss-cpu==1.7.4
|
| 4 |
+
scikit-learn==1.3.0
|
| 5 |
+
pandas==2.0.0
|
| 6 |
+
numpy==1.24.0
|
| 7 |
+
pillow==10.0.0
|
| 8 |
+
matplotlib==3.7.0
|
| 9 |
+
seaborn==0.12.0
|
| 10 |
+
opencv-python-headless==4.9.0.80
|
| 11 |
+
fastapi==0.100.0
|
| 12 |
+
uvicorn==0.23.0
|
| 13 |
+
pydantic==2.0.0
|
| 14 |
+
python-dotenv==1.0.0
|
| 15 |
+
requests==2.31.0
|
| 16 |
+
beautifulsoup4==4.12.0
|
| 17 |
+
networkx==3.1
|
| 18 |
+
evidently==0.4.0
|
| 19 |
+
optuna==3.0.0
|
| 20 |
+
jupyter==1.0.0
|
| 21 |
+
notebook==7.0.0
|
| 22 |
+
ipywidgets==8.1.0
|
| 23 |
+
plotly==5.17.0
|
| 24 |
+
tqdm==4.66.0
|
src/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Package initializer for src module
|
src/cache.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# src/cache.py
|
| 2 |
+
# LRU cache keyed by image SHA256 hash
|
| 3 |
+
# Prevents recomputing WideResNet + CLIP for repeated images
|
| 4 |
+
# maxsize=128: holds ~128 inference results in RAM (~100MB max)
|
| 5 |
+
|
| 6 |
+
import hashlib
|
| 7 |
+
from collections import OrderedDict
|
| 8 |
+
from PIL import Image
|
| 9 |
+
import io
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
MAX_CACHE_SIZE = 128


class LRUCache:
    """
    Simple LRU cache backed by OrderedDict.
    Key: SHA256 hash of raw image bytes
    Value: dict of precomputed features for that image

    Why not functools.lru_cache: we need explicit key control
    (image hash, not the PIL object itself which is unhashable).
    """

    def __init__(self, maxsize=MAX_CACHE_SIZE):
        self.cache = OrderedDict()
        self.maxsize = maxsize
        self.hits = 0
        self.misses = 0

    def get(self, key):
        """Return cached value or None; updates recency and counters."""
        try:
            value = self.cache[key]
        except KeyError:
            self.misses += 1
            return None
        # Touch the entry so it becomes the most recently used.
        self.cache.move_to_end(key)
        self.hits += 1
        return value

    def set(self, key, value):
        """Insert or overwrite; evicts the least recently used when full."""
        if key in self.cache:
            self.cache.move_to_end(key)
        self.cache[key] = value
        while len(self.cache) > self.maxsize:
            # Least recently used entry sits at the front of the dict.
            self.cache.popitem(last=False)

    def stats(self):
        """Hit/miss counters plus current occupancy."""
        lookups = self.hits + self.misses
        rate = self.hits / lookups if lookups > 0 else 0.0
        return {
            "hits": self.hits,
            "misses": self.misses,
            "total": lookups,
            "hit_rate": round(rate, 4),
            "current_size": len(self.cache),
            "max_size": self.maxsize
        }

    def clear(self):
        """Drop all entries and reset counters."""
        self.cache.clear()
        self.hits = 0
        self.misses = 0
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def get_image_hash(image_bytes: bytes) -> str:
    """
    SHA256 hash of raw image bytes.
    Used as cache key AND as unique image ID in HF Dataset logs.
    Same image submitted twice = same hash = cache hit.
    """
    digest = hashlib.sha256()
    digest.update(image_bytes)
    return digest.hexdigest()
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def pil_to_bytes(pil_img: Image.Image) -> bytes:
    """Convert PIL image to bytes for hashing (lossless PNG encoding)."""
    buffer = io.BytesIO()
    pil_img.save(buffer, format="PNG")
    return buffer.getvalue()


# Global cache instance — lives for the entire FastAPI server lifetime
# Initialised once in api/startup.py, imported everywhere
inference_cache = LRUCache(maxsize=MAX_CACHE_SIZE)
|
src/depth.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# src/depth.py
|
| 2 |
+
# MiDaS-small ONNX wrapper for monocular depth estimation
|
| 3 |
+
# Runs at inference on CPU in ~80ms
|
| 4 |
+
# NOT used for anomaly scoring — provides 5 depth stats that feed SHAP
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import numpy as np
|
| 8 |
+
import onnxruntime as ort
|
| 9 |
+
from PIL import Image
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
DATA_DIR = os.environ.get("DATA_DIR", "data")
MIDAS_INPUT_SIZE = 256  # MiDaS-small expects 256x256
OUTPUT_SIZE = 224       # depth maps resized to the size used elsewhere in the pipeline


class DepthEstimator:
    """
    Wraps MiDaS-small ONNX model.
    Loaded once at startup, runs on every Inspector Mode submission.

    Why MiDaS-small not MiDaS-large:
    Small runs in ~80ms CPU. Large runs in ~800ms CPU.
    We need 5 statistical summaries, not a precise depth map.
    Small is the correct tradeoff.

    Refactor vs original: the duplicated preprocess→run→postprocess
    sequence in get_depth_stats and get_depth_map is factored into _infer.
    """

    def __init__(self, data_dir=DATA_DIR):
        self.data_dir = data_dir
        # Set by load(); all public getters degrade gracefully while None.
        self.session = None

    def load(self):
        """Load the ONNX session (CPU provider).

        Raises:
            FileNotFoundError: if the model file is absent.
        """
        model_path = os.path.join(self.data_dir, "midas_small.onnx")
        if not os.path.exists(model_path):
            raise FileNotFoundError(
                f"MiDaS ONNX model not found: {model_path}\n"
                f"Download from: https://github.com/isl-org/MiDaS/releases"
            )
        self.session = ort.InferenceSession(
            model_path,
            providers=["CPUExecutionProvider"]
        )
        print("MiDaS-small ONNX loaded")

    def _preprocess(self, pil_img: Image.Image) -> np.ndarray:
        """
        Resize to 256x256, normalise to ImageNet mean/std.
        Returns [1, 3, 256, 256] float32 array.
        """
        img = pil_img.resize((MIDAS_INPUT_SIZE, MIDAS_INPUT_SIZE),
                             Image.BILINEAR)
        img_np = np.array(img, dtype=np.float32) / 255.0

        mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
        std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
        img_np = (img_np - mean) / std

        # HWC → CHW → NCHW
        return img_np.transpose(2, 0, 1)[np.newaxis, :]

    def _postprocess(self, depth_raw: np.ndarray) -> np.ndarray:
        """
        Squeeze output, resize to 224x224, normalise to [0, 1].
        Returns [224, 224] float32 array.
        """
        depth = depth_raw.squeeze()

        # Resize to match image size used everywhere else
        from PIL import Image as PILImage
        depth_pil = PILImage.fromarray(depth).resize((OUTPUT_SIZE, OUTPUT_SIZE),
                                                     PILImage.BILINEAR)
        depth = np.array(depth_pil, dtype=np.float32)

        # Normalise to [0, 1]; skip degenerate (flat) maps to avoid divide-by-zero.
        d_min, d_max = depth.min(), depth.max()
        if d_max - d_min > 1e-8:
            depth = (depth - d_min) / (d_max - d_min)
        return depth

    def _infer(self, pil_img: Image.Image) -> np.ndarray:
        """Shared ONNX forward pass; propagates any runtime failure."""
        input_tensor = self._preprocess(pil_img)
        input_name = self.session.get_inputs()[0].name
        output = self.session.run(None, {input_name: input_tensor})[0]
        return self._postprocess(output)

    def get_depth_stats(self, pil_img: Image.Image) -> dict:
        """
        Run MiDaS, return 5 depth statistics.
        These are the SHAP features for depth signal.

        If the model is unloaded or fails for any reason: return zeros.
        Inference continues without depth — heatmap and score unaffected.
        """
        if self.session is None:
            return self._zero_stats()

        try:
            return self._compute_stats(self._infer(pil_img))
        except Exception as e:
            print(f"MiDaS inference failed: {e} — returning zeros")
            return self._zero_stats()

    def _compute_stats(self, depth: np.ndarray) -> dict:
        """
        Compute 5 statistics from [224, 224] depth map.

        mean_depth: average depth across image
        depth_variance: how much depth varies — high = complex surface
        gradient_magnitude: average depth edge strength
        spatial_entropy: how uniformly depth is distributed
        depth_range: max - min depth — measures 3D relief
        """
        gx = np.gradient(depth, axis=1)
        gy = np.gradient(depth, axis=0)
        grad_mag = float(np.sqrt(gx**2 + gy**2).mean())

        # Shannon entropy of the depth histogram; epsilon avoids log(0).
        hist, _ = np.histogram(depth.flatten(), bins=50, density=True)
        hist = hist + 1e-10
        from scipy.stats import entropy as scipy_entropy
        sp_entropy = float(scipy_entropy(hist))

        return {
            "mean_depth": float(depth.mean()),
            "depth_variance": float(depth.var()),
            "gradient_magnitude": grad_mag,
            "spatial_entropy": sp_entropy,
            "depth_range": float(depth.max() - depth.min())
        }

    def _zero_stats(self) -> dict:
        """Neutral stats used whenever depth is unavailable."""
        return {
            "mean_depth": 0.0,
            "depth_variance": 0.0,
            "gradient_magnitude": 0.0,
            "spatial_entropy": 0.0,
            "depth_range": 0.0
        }

    def get_depth_map(self, pil_img: Image.Image) -> np.ndarray:
        """
        Returns raw [224, 224] depth map for visualisation in Inspector.
        Returns zeros array if the model is unloaded or inference fails.
        """
        if self.session is None:
            return np.zeros((OUTPUT_SIZE, OUTPUT_SIZE), dtype=np.float32)
        try:
            return self._infer(pil_img)
        except Exception:
            return np.zeros((OUTPUT_SIZE, OUTPUT_SIZE), dtype=np.float32)


# Global instance
depth_estimator = DepthEstimator()
|
src/enrichment.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Data enrichment pipeline for anomaly detection
|
| 3 |
+
"""
|
| 4 |
+
import logging
|
| 5 |
+
from typing import Dict, List, Any
|
| 6 |
+
|
| 7 |
+
logger = logging.getLogger(__name__)
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class DataEnricher:
|
| 11 |
+
"""
|
| 12 |
+
Enriches raw data with additional context and metadata.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
def __init__(self):
|
| 16 |
+
"""Initialize data enricher."""
|
| 17 |
+
logger.info("DataEnricher initialized")
|
| 18 |
+
|
| 19 |
+
def enrich(self, data: Dict) -> Dict:
|
| 20 |
+
"""
|
| 21 |
+
Enrich data with metadata and context.
|
| 22 |
+
|
| 23 |
+
Args:
|
| 24 |
+
data: Input data dictionary
|
| 25 |
+
|
| 26 |
+
Returns:
|
| 27 |
+
Enriched data dictionary
|
| 28 |
+
"""
|
| 29 |
+
enriched = data.copy()
|
| 30 |
+
# Add enrichment logic
|
| 31 |
+
return enriched
|
| 32 |
+
|
| 33 |
+
def add_category_metadata(self, data: Dict, category: str) -> Dict:
|
| 34 |
+
"""Add category-specific metadata."""
|
| 35 |
+
logger.info(f"Adding metadata for category: {category}")
|
| 36 |
+
# Implementation
|
| 37 |
+
return data
|
| 38 |
+
|
| 39 |
+
def add_temporal_features(self, data: Dict) -> Dict:
|
| 40 |
+
"""Add temporal features to data."""
|
| 41 |
+
logger.info("Adding temporal features")
|
| 42 |
+
# Implementation
|
| 43 |
+
return data
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class EnrichmentPipeline:
    """
    Complete enrichment pipeline combining multiple enrichment steps.
    """

    def __init__(self):
        self.enricher = DataEnricher()

    def process(self, raw_data: List[Dict]) -> List[Dict]:
        """
        Process raw data through enrichment pipeline.

        Args:
            raw_data: List of raw data items

        Returns:
            List of enriched data items
        """
        logger.info(f"Processing {len(raw_data)} items through enrichment pipeline")
        return [self.enricher.enrich(item) for item in raw_data]
|
src/graph.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# src/graph.py
|
| 2 |
+
# Loads the NetworkX knowledge graph and exposes 2-hop traversal
|
| 3 |
+
# Graph built in notebook 04, stored as node-link JSON on HF Dataset
|
| 4 |
+
# Loaded once at FastAPI startup, kept in memory
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import json
|
| 8 |
+
import networkx as nx
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
DATA_DIR = os.environ.get("DATA_DIR", "data")  # graph JSON location, overridable via env
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class KnowledgeGraph:
    """
    Wraps the NetworkX DiGraph.
    Provides 2-hop context retrieval for the RAG orchestrator.

    Expected node naming (built in notebook 04 — not visible here):
        defect_<category>_<defect_type>, root_cause_*, remediation_*
    Edges carry an "edge_type" attribute: caused_by, remediated_by,
    co_occurs_with — TODO confirm against the graph-building notebook.
    """

    def __init__(self, data_dir=DATA_DIR):
        # Graph is loaded lazily via load(); every getter tolerates
        # self.graph being None and returns an empty result instead.
        self.data_dir = data_dir
        self.graph = None

    def load(self):
        """Load the node-link JSON graph from <data_dir>/knowledge_graph.json.

        Raises:
            FileNotFoundError: if the graph file is missing.
        """
        path = os.path.join(self.data_dir, "knowledge_graph.json")
        if not os.path.exists(path):
            raise FileNotFoundError(f"Knowledge graph not found: {path}")

        with open(path) as f:
            data = json.load(f)

        # NOTE(review): assumes the JSON was exported with directed=True,
        # otherwise successors() below would fail — confirm in notebook 04.
        self.graph = nx.node_link_graph(data)
        print(f"Knowledge graph loaded: "
              f"{self.graph.number_of_nodes()} nodes, "
              f"{self.graph.number_of_edges()} edges")

    def get_context(self, category: str, defect_type: str) -> dict:
        """
        2-hop traversal from a defect node.
        Returns: root causes, remediations, co-occurring defects.

        Path: defect → [caused_by] → root_cause
                     → [remediated_by] → remediation
              defect → [co_occurs_with] → related_defect

        NOTE(review): the empty-result dicts below omit the "defect_key"
        key that the success path includes — callers must use .get().
        """
        if self.graph is None:
            return {"root_causes": [], "remediations": [], "co_occurs": []}

        defect_key = f"defect_{category}_{defect_type}"

        # Try exact match first, then fallback to category-level
        if defect_key not in self.graph:
            # Try to find any defect node for this category
            candidates = [
                n for n in self.graph.nodes
                if n.startswith(f"defect_{category}_")
            ]
            if not candidates:
                return {"root_causes": [], "remediations": [], "co_occurs": []}
            # Arbitrary pick: first matching node in graph iteration order.
            defect_key = candidates[0]

        root_causes = []
        remediations = []
        co_occurs = []

        # Hop 1: direct successors of the defect node.
        for nb1 in self.graph.successors(defect_key):
            edge1 = self.graph[defect_key][nb1].get("edge_type", "")
            node1_data = self.graph.nodes[nb1]

            if edge1 == "caused_by":
                # Prefer the human-readable "name" attr; fall back to the ID.
                rc = node1_data.get("name", nb1.replace("root_cause_", ""))
                root_causes.append(rc)

                # Second hop: root_cause → remediation
                for nb2 in self.graph.successors(nb1):
                    edge2 = self.graph[nb1][nb2].get("edge_type", "")
                    if edge2 == "remediated_by":
                        node2_data = self.graph.nodes[nb2]
                        rem = node2_data.get("name",
                                             nb2.replace("remediation_", ""))
                        remediations.append(rem)

            elif edge1 == "co_occurs_with":
                co_key = nb1.replace("defect_", "")
                co_occurs.append(co_key)

        # set() deduplicates but makes ordering non-deterministic per run.
        return {
            "defect_key": defect_key,
            "root_causes": list(set(root_causes)),
            "remediations": list(set(remediations)),
            "co_occurs": co_occurs
        }

    def get_all_defect_nodes(self) -> list:
        """Returns all defect nodes — used by Knowledge Base Explorer."""
        if self.graph is None:
            return []
        return [
            {
                "node_id": n,
                **self.graph.nodes[n]
            }
            for n, d in self.graph.nodes(data=True)
            if d.get("node_type") == "defect_instance"
        ]

    def get_status(self) -> dict:
        """Health-check summary consumed by the API status endpoint."""
        if self.graph is None:
            return {"loaded": False}
        return {
            "loaded": True,
            "nodes": self.graph.number_of_nodes(),
            "edges": self.graph.number_of_edges()
        }
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
# Global instance
# Module-level singleton shared across the app; load() is expected to be
# invoked once at FastAPI startup before any get_context() call.
knowledge_graph = KnowledgeGraph()
|
src/llm.py
ADDED
|
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# src/llm.py
|
| 2 |
+
# Groq LLM call with tenacity retry
|
| 3 |
+
# Single call per inference — not a multi-step chain
|
| 4 |
+
# Non-blocking: queued as FastAPI BackgroundTask, polled via /report/{id}
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import json
|
| 8 |
+
import time
|
| 9 |
+
import uuid
|
| 10 |
+
import httpx
|
| 11 |
+
from tenacity import (
|
| 12 |
+
retry,
|
| 13 |
+
stop_after_attempt,
|
| 14 |
+
wait_exponential,
|
| 15 |
+
retry_if_exception_type
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# Groq's OpenAI-compatible chat-completions endpoint.
GROQ_API_URL = "https://api.groq.com/openai/v1/chat/completions"
GROQ_MODEL = "llama-3.3-70b-versatile"
MAX_TOKENS = 512  # hard per-report cap — keeps replies short

# In-memory report store: report_id → {status, report}
# FastAPI polls this via GET /report/{report_id}
# NOTE: process-local — entries vanish on restart and are not shared
# across workers; bounded by cleanup_old_reports().
_report_store: dict = {}
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class LLMAPIError(Exception):
    """Raised for any Groq call failure (missing key, rate limit, timeout,
    non-200 status, empty reply). Tenacity retries on exactly this type."""
    pass
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _build_prompt(category: str,
|
| 33 |
+
anomaly_score: float,
|
| 34 |
+
similar_cases: list,
|
| 35 |
+
graph_context: dict) -> list:
|
| 36 |
+
"""
|
| 37 |
+
Build LLM messages list.
|
| 38 |
+
Strictly grounded — model must cite case IDs, cannot use outside knowledge.
|
| 39 |
+
One call per inference. Context = retrieved cases + graph context.
|
| 40 |
+
"""
|
| 41 |
+
system = (
|
| 42 |
+
"You are an industrial quality control assistant. "
|
| 43 |
+
"Answer ONLY based on the retrieved cases and graph context provided. "
|
| 44 |
+
"Do not use outside knowledge. "
|
| 45 |
+
"Always cite the Case ID when referencing a case. "
|
| 46 |
+
"Be concise — 3 to 5 sentences maximum."
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
# Build context block from retrieved similar cases
|
| 50 |
+
context_lines = []
|
| 51 |
+
for i, case in enumerate(similar_cases[:5]):
|
| 52 |
+
context_lines.append(
|
| 53 |
+
f"[Case {i+1}: category={case.get('category')}, "
|
| 54 |
+
f"defect={case.get('defect_type')}, "
|
| 55 |
+
f"similarity={case.get('similarity_score', 0):.3f}]"
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
# Add graph context
|
| 59 |
+
root_causes = graph_context.get("root_causes", [])
|
| 60 |
+
remediations = graph_context.get("remediations", [])
|
| 61 |
+
if root_causes:
|
| 62 |
+
context_lines.append(f"Root causes: {', '.join(root_causes)}")
|
| 63 |
+
if remediations:
|
| 64 |
+
context_lines.append(f"Remediations: {', '.join(remediations)}")
|
| 65 |
+
|
| 66 |
+
context_str = "\n".join(context_lines) if context_lines else "No context available."
|
| 67 |
+
|
| 68 |
+
user_msg = (
|
| 69 |
+
f"CONTEXT:\n{context_str}\n\n"
|
| 70 |
+
f"QUERY: Image anomaly score {anomaly_score:.3f}. "
|
| 71 |
+
f"Category: {category}. "
|
| 72 |
+
f"Describe the likely defect, root cause, and recommended action."
|
| 73 |
+
f"\n\nREPORT:"
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
return [
|
| 77 |
+
{"role": "system", "content": system},
|
| 78 |
+
{"role": "user", "content": user_msg}
|
| 79 |
+
]
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=2, max=8),
    retry=retry_if_exception_type(LLMAPIError),
    reraise=True
)
def _call_groq(messages: list) -> str:
    """
    Single Groq API call with tenacity retry.
    Retries 3 times with 2s/4s/8s backoff on failure.
    Raises LLMAPIError if all 3 attempts fail.

    Args:
        messages: OpenAI-style chat messages from _build_prompt().

    Returns:
        The assistant's reply text, stripped of surrounding whitespace.
    """
    api_key = os.environ.get("GROQ_API_KEY")
    if not api_key:
        # Raised as LLMAPIError, so it is also retried — harmless, since
        # the environment cannot change between attempts.
        raise LLMAPIError("GROQ_API_KEY not set in environment")

    try:
        with httpx.Client(timeout=30.0) as client:
            response = client.post(
                GROQ_API_URL,
                headers={
                    "Authorization": f"Bearer {api_key}",
                    "Content-Type": "application/json"
                },
                json={
                    "model": GROQ_MODEL,
                    "messages": messages,
                    "max_tokens": MAX_TOKENS,
                    "temperature": 0.3  # low temp = factual, grounded
                }
            )

        # 429 gets its own message so rate limits are easy to spot in logs.
        if response.status_code == 429:
            raise LLMAPIError("Groq rate limit hit")
        if response.status_code != 200:
            # Truncate the body so log lines stay bounded.
            raise LLMAPIError(f"Groq API error {response.status_code}: "
                              f"{response.text[:200]}")

        data = response.json()
        content = data["choices"][0]["message"]["content"].strip()

        if not content:
            raise LLMAPIError("Groq returned empty response")

        return content

    except httpx.TimeoutException:
        raise LLMAPIError("Groq API timeout")
    except httpx.RequestError as e:
        raise LLMAPIError(f"Groq request failed: {e}")
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def queue_report(category: str,
                 anomaly_score: float,
                 similar_cases: list,
                 graph_context: dict) -> str:
    """
    Register a pending report and hand back its ID immediately.

    The parameters mirror generate_report()'s signature; only the ID
    registration happens here — actual generation runs later as a
    background task, and the frontend polls GET /report/{report_id}.
    """
    rid = str(uuid.uuid4())
    _report_store[rid] = {"status": "pending", "report": None}
    return rid
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def generate_report(report_id: str,
                    category: str,
                    anomaly_score: float,
                    similar_cases: list,
                    graph_context: dict):
    """
    Background-task body: produce the LLM report for *report_id*.

    The entry always ends up "ready" — with the real report, a grounded
    fallback on Groq failure, or a generic retry message on anything
    unexpected — so the polling frontend never hangs on "pending".
    """
    try:
        prompt = _build_prompt(category, anomaly_score,
                               similar_cases, graph_context)
        _report_store[report_id] = {
            "status": "ready",
            "report": _call_groq(prompt)
        }

    except LLMAPIError as e:
        fallback = (
            "LLM temporarily unavailable. "
            "Retrieved cases and graph context are shown above. "
            f"(Error: {str(e)[:100]})"
        )
        _report_store[report_id] = {"status": "ready", "report": fallback}

    except Exception:
        _report_store[report_id] = {
            "status": "ready",
            "report": "Could not generate report. Please retry."
        }
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def get_report(report_id: str) -> dict:
    """
    Look up the current state of a queued report.

    Unknown IDs yield {"status": "not_found", "report": None} instead of
    raising, so the polling endpoint can return it verbatim.
    """
    missing = {"status": "not_found", "report": None}
    return _report_store.get(report_id, missing)
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def cleanup_old_reports(max_age_seconds: int = 3600):
    """Prevent _report_store growing unbounded. Called periodically."""
    # NOTE(review): max_age_seconds is currently ignored — entries carry no
    # timestamp, so age-based eviction is impossible as stored; trimming is
    # purely count-based. Confirm whether the parameter should be wired up.
    # Simple approach: keep only last 500 reports
    # Relies on dict insertion order (Python 3.7+): the first half of the
    # keys are the oldest entries.
    if len(_report_store) > 500:
        keys = list(_report_store.keys())
        for key in keys[:250]:
            del _report_store[key]
|
src/orchestrator.py
ADDED
|
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# src/orchestrator.py
|
| 2 |
+
# Hierarchical Multi-Modal Graph RAG Orchestrator
|
| 3 |
+
# Routes through 3 FAISS indexes, knowledge graph, XAI, and LLM
|
| 4 |
+
# This is the brain — called by POST /inspect
|
| 5 |
+
|
| 6 |
+
import gc
|
| 7 |
+
import time
|
| 8 |
+
import base64
|
| 9 |
+
import io
|
| 10 |
+
import concurrent.futures
|
| 11 |
+
import numpy as np
|
| 12 |
+
from dataclasses import dataclass, field
|
| 13 |
+
from typing import Optional
|
| 14 |
+
from PIL import Image
|
| 15 |
+
|
| 16 |
+
import clip
|
| 17 |
+
import torch
|
| 18 |
+
|
| 19 |
+
from src.patchcore import patchcore
|
| 20 |
+
from src.retriever import retriever
|
| 21 |
+
from src.graph import knowledge_graph
|
| 22 |
+
from src.depth import depth_estimator
|
| 23 |
+
from src.xai import gradcam, shap_explainer, heatmap_to_base64, image_to_base64
|
| 24 |
+
from src.llm import queue_report
|
| 25 |
+
from src.cache import inference_cache, get_image_hash, pil_to_bytes
|
| 26 |
+
|
| 27 |
+
import os
|
| 28 |
+
import json
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
DATA_DIR = os.environ.get("DATA_DIR", "data")  # artefact root, overridable
DEVICE = "cpu"  # HF Spaces has no GPU — inference is always CPU
IMG_SIZE = 224  # working resolution for every pipeline stage

# Loaded at startup by api/startup.py
# Injected via init_orchestrator(); None/{} until then.
_clip_model = None
_clip_preprocess = None
_thresholds = {}
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def init_orchestrator(clip_model, clip_preprocess, thresholds):
    """Inject the shared CLIP model, its preprocess fn, and the
    per-category threshold table into module globals.

    Invoked exactly once at FastAPI startup, before any inspection."""
    global _clip_model, _clip_preprocess, _thresholds
    _clip_model, _clip_preprocess, _thresholds = (
        clip_model,
        clip_preprocess,
        thresholds,
    )
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
@dataclass
class OrchestratorResult:
    """Aggregated output of one run_inspection() call.

    The optional/base64 fields stay at their defaults on the early-exit
    (normal-image) path; they are only populated for anomalies.
    """
    is_anomalous: bool
    score: float  # raw k-NN distance
    calibrated_score: float  # sigmoid calibrated [0,1]
    score_std: float  # uncertainty estimate
    category: str  # routed (or caller-hinted) category name
    heatmap_b64: Optional[str] = None  # anomaly overlay, base64-encoded
    defect_crop_b64: Optional[str] = None  # 112x112 crop around the defect
    depth_map_b64: Optional[str] = None  # MiDaS depth render, base64
    similar_cases: list = field(default_factory=list)  # Index 2 hits
    graph_context: dict = field(default_factory=dict)  # 2-hop KG context
    shap_features: dict = field(default_factory=dict)  # SHAP explanation
    report_id: Optional[str] = None  # poll GET /report/{id} for LLM text
    latency_ms: float = 0.0  # end-to-end wall time for this call
    patch_scores_grid: Optional[list] = None  # [28,28] for Forensics
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
@torch.no_grad()
def _get_clip_embedding(pil_img: Image.Image,
                        mode: str = "full") -> np.ndarray:
    """
    CLIP embedding for full image or centre crop.
    mode: 'full' → Index 1 routing
          'crop' → Index 2 retrieval (defect region)

    Returns an L2-normalised 1-D float32 vector.
    Requires init_orchestrator() to have injected _clip_model and
    _clip_preprocess first — raises TypeError on None otherwise.
    """
    if mode == "crop":
        # Local import keeps torchvision off the hot module-import path.
        from torchvision import transforms as T
        pil_img = T.CenterCrop(112)(pil_img)

    tensor = _clip_preprocess(pil_img).unsqueeze(0).to(DEVICE)
    feat = _clip_model.encode_image(tensor)
    # Unit-normalise so downstream similarity search operates on
    # direction only (cosine-style comparison).
    feat = feat / feat.norm(dim=-1, keepdim=True)
    return feat.cpu().numpy().squeeze().astype(np.float32)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def _extract_defect_crop(pil_img: Image.Image,
                         heatmap: np.ndarray) -> Image.Image:
    """
    Crop a 112x112 region centred on the anomaly centroid.

    The window is shifted — not truncated — when the centroid lies near
    an image edge, so the returned crop is always exactly 112x112, as
    the contract states. (Per-side clamping previously produced smaller
    crops for edge defects, feeding inconsistent sizes to the Index 2
    CLIP embedding.)

    Args:
        pil_img: source image; resized to IMG_SIZE x IMG_SIZE before cropping.
        heatmap: anomaly map used to locate the defect centroid.

    Returns:
        112x112 PIL crop around the hottest region.
    """
    cx, cy = patchcore.get_anomaly_centroid(heatmap)
    side = 112
    half = side // 2
    # Clamp the window ORIGIN so the whole 112x112 box stays in-bounds.
    left = min(max(0, cx - half), IMG_SIZE - side)
    top = min(max(0, cy - half), IMG_SIZE - side)
    return pil_img.resize((IMG_SIZE, IMG_SIZE)).crop(
        (left, top, left + side, top + side)
    )
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def _get_fft_features(pil_img: Image.Image) -> dict:
|
| 101 |
+
"""FFT texture features — used for SHAP feature vector."""
|
| 102 |
+
import numpy as np
|
| 103 |
+
gray = np.array(pil_img.convert("L"), dtype=np.float32)
|
| 104 |
+
fft = np.fft.fftshift(np.fft.fft2(gray))
|
| 105 |
+
mag = np.abs(fft)
|
| 106 |
+
H, W = mag.shape
|
| 107 |
+
cy, cx = H // 2, W // 2
|
| 108 |
+
radius = min(H, W) // 8
|
| 109 |
+
Y, X = np.ogrid[:H, :W]
|
| 110 |
+
mask = (X - cx)**2 + (Y - cy)**2 <= radius**2
|
| 111 |
+
low_e = mag[mask].sum()
|
| 112 |
+
total = mag.sum() + 1e-10
|
| 113 |
+
return {"low_freq_ratio": float(low_e / total)}
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def _get_edge_features(pil_img: Image.Image) -> dict:
    """
    Edge-density feature for the SHAP vector: fraction of total Canny
    edge response over the full IMG_SIZE x IMG_SIZE frame.
    """
    import cv2
    resized = pil_img.convert("L").resize((IMG_SIZE, IMG_SIZE))
    edge_map = cv2.Canny(np.array(resized), 50, 150)
    density = float(edge_map.sum()) / (IMG_SIZE * IMG_SIZE * 255)
    return {"edge_density": density}
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def run_inspection(pil_img: Image.Image,
                   image_bytes: bytes,
                   category_hint: str = None,
                   run_gradcam: bool = False) -> OrchestratorResult:
    """
    Full inspection pipeline.

    STEP 1: Cache check (skip recomputation for repeated images)
    STEP 2: CLIP full-image → Index 1 category routing
    STEP 3: WideResNet patches → Index 3 PatchCore scoring
    STEP 4: Early exit if normal (skip Index 2 + LLM)
    STEP 5: Defect crop extraction
    STEP 6: MiDaS depth + CLIP crop embedding IN PARALLEL
    STEP 7: Index 2 retrieval (similar historical defects)
    STEP 8: Knowledge graph 2-hop traversal
    STEP 9: SHAP feature assembly
    STEP 10: LLM report queued (non-blocking)
    STEP 11: GradCAM++ if requested (Forensics mode)
    STEP 12: Calibrate score, assemble result, gc.collect()

    Args:
        pil_img: uploaded image; resized to IMG_SIZE x IMG_SIZE RGB below.
        image_bytes: raw upload bytes — hashed for the inference cache key.
        category_hint: optional override; when set it wins over Index 1
            routing (routing still runs, its result is just ignored).
        run_gradcam: accepted but unused here — GradCAM++ only runs via
            POST /forensics/{case_id} (see STEP 11).

    Returns:
        OrchestratorResult; the early-exit variant carries no heatmap,
        crop, depth map, retrieval results, graph context, or report_id.
    """
    t_start = time.time()

    # ── STEP 1: Cache check ───────────────────────────────────
    image_hash = get_image_hash(image_bytes)
    cached = inference_cache.get(image_hash)
    if cached:
        # NOTE(review): this writes "latency_ms" into the cached dict
        # itself, mutating the cache entry; later hits just overwrite the
        # key and OrchestratorResult accepts it — confirm this is intended.
        cached["latency_ms"] = (time.time() - t_start) * 1000
        return OrchestratorResult(**cached)

    pil_img = pil_img.resize((IMG_SIZE, IMG_SIZE)).convert("RGB")

    # ── STEP 2: Category routing (Index 1) ───────────────────
    clip_full = _get_clip_embedding(pil_img, mode="full")
    cat_result = retriever.route_category(clip_full)
    category = category_hint or cat_result["category"]

    # ── STEP 3: PatchCore scoring (Index 3) ──────────────────
    patches = patchcore.extract_patches(pil_img)  # [784, 256]
    # nn_dists is unused on this path — kept for unpacking parity.
    score, patch_scores, score_std, nn_dists = retriever.score_patches(
        patches, category
    )

    # ── STEP 4: Early exit — clearly normal ──────────────────
    # Unknown categories fall back to a 0.5 threshold.
    threshold = _thresholds.get(category, {}).get("threshold", 0.5)
    if score < threshold:
        calibrated = patchcore.calibrate_score(score, category, _thresholds)
        result_data = dict(
            is_anomalous=False,
            score=score,
            calibrated_score=calibrated,
            score_std=score_std,
            category=category,
            heatmap_b64=None,
            patch_scores_grid=patch_scores.tolist()
        )
        inference_cache.set(image_hash, result_data)
        gc.collect()  # keep the CPU Space's memory flat between requests
        return OrchestratorResult(
            **result_data,
            latency_ms=(time.time() - t_start) * 1000
        )

    # ── STEP 5: Heatmap + defect crop ────────────────────────
    heatmap = patchcore.build_anomaly_map(patch_scores)
    heatmap_b64 = heatmap_to_base64(heatmap, pil_img)
    defect_crop = _extract_defect_crop(pil_img, heatmap)
    crop_b64 = image_to_base64(defect_crop, size=(112, 112))

    # ── STEP 6: MiDaS + CLIP crop IN PARALLEL ────────────────
    # NOTE(review): three tasks share two workers, so the third queues
    # behind a free worker — bump max_workers to 3 if full overlap is
    # wanted (actual speedup depends on the ops releasing the GIL).
    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as ex:
        depth_future = ex.submit(depth_estimator.get_depth_stats, pil_img)
        depth_map_f = ex.submit(depth_estimator.get_depth_map, pil_img)
        clip_future = ex.submit(_get_clip_embedding, defect_crop, "crop")

        depth_stats = depth_future.result()
        depth_map = depth_map_f.result()
        clip_crop = clip_future.result()

    # Encode depth map
    # assumes depth_map is normalised to [0, 1] — the *255 uint8 cast
    # below wraps otherwise; TODO confirm in src/depth.py
    depth_norm = (depth_map * 255).astype(np.uint8)
    depth_pil = Image.fromarray(depth_norm)
    depth_b64 = image_to_base64(depth_pil)

    # ── STEP 7: Index 2 retrieval ─────────────────────────────
    similar_cases = retriever.retrieve_similar_defects(
        clip_crop, k=5, exclude_hash=image_hash
    )

    # ── STEP 8: Knowledge graph traversal ────────────────────
    # Use top retrieved defect type for graph lookup
    top_defect_type = (similar_cases[0]["defect_type"]
                       if similar_cases else "unknown")
    graph_context = knowledge_graph.get_context(category, top_defect_type)

    # ── STEP 9: SHAP features ────────────────────────────────
    fft_feats = _get_fft_features(pil_img)
    edge_feats = _get_edge_features(pil_img)
    feat_vec = shap_explainer.build_feature_vector(
        patch_scores, depth_stats, fft_feats, edge_feats
    )
    shap_result = shap_explainer.explain(feat_vec)

    # ── STEP 10: LLM report (non-blocking) ───────────────────
    # Only registers the ID; generation is scheduled by the API layer.
    report_id = queue_report(category, score, similar_cases, graph_context)

    # ── STEP 11: GradCAM++ (Forensics only) ──────────────────
    # Not run during normal Inspector Mode — too slow for default path
    # Called explicitly from POST /forensics/{case_id}

    # ── STEP 12: Calibrate + assemble ────────────────────────
    calibrated = patchcore.calibrate_score(score, category, _thresholds)

    result_data = dict(
        is_anomalous=True,
        score=score,
        calibrated_score=calibrated,
        score_std=score_std,
        category=category,
        heatmap_b64=heatmap_b64,
        defect_crop_b64=crop_b64,
        depth_map_b64=depth_b64,
        similar_cases=similar_cases,
        graph_context=graph_context,
        shap_features=shap_result,
        report_id=report_id,
        patch_scores_grid=patch_scores.tolist()
    )

    inference_cache.set(image_hash, result_data)
    gc.collect()

    return OrchestratorResult(
        **result_data,
        latency_ms=(time.time() - t_start) * 1000
    )
|
src/patchcore.py
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# src/patchcore.py
|
| 2 |
+
# PatchCore feature extraction and anomaly scoring
|
| 3 |
+
# WideResNet-50 frozen backbone, layer2 + layer3 hooks
|
| 4 |
+
# This is the core ML component — built from scratch, no Anomalib
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torchvision.models as models
|
| 10 |
+
import torchvision.transforms as T
|
| 11 |
+
from PIL import Image
|
| 12 |
+
import joblib
|
| 13 |
+
import os
|
| 14 |
+
import scipy.ndimage
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
DATA_DIR = os.environ.get("DATA_DIR", "data")  # weights + PCA location
DEVICE = "cpu"  # HF Spaces has no GPU — always CPU at inference
IMG_SIZE = 224  # input resolution expected by the transform below
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class PatchCoreExtractor:
|
| 23 |
+
"""
|
| 24 |
+
WideResNet-50 feature extractor with forward hooks.
|
| 25 |
+
|
| 26 |
+
Why two layers:
|
| 27 |
+
- layer2 (28x28): captures fine-grained texture anomalies
|
| 28 |
+
- layer3 (14x14): captures structural/shape anomalies
|
| 29 |
+
Single layer misses one or the other. Multi-scale = better AUROC.
|
| 30 |
+
|
| 31 |
+
Why frozen:
|
| 32 |
+
We never update any weights. PatchCore does not train on defects.
|
| 33 |
+
It memorises normal patches, then measures deviation at inference.
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
def __init__(self, data_dir=DATA_DIR):
    """Create an (unloaded) extractor; call load() before first use."""
    self.data_dir = data_dir
    self.model = None  # WideResNet-50 backbone, set in load()
    self.pca = None    # fitted PCA (256 components), set in load()
    # Forward hooks stash their activations into these dicts.
    self._layer2_feat = {}
    self._layer3_feat = {}

    # Standard ImageNet preprocessing — must match what the PCA and
    # memory bank were built with.
    self.transform = T.Compose([
        T.Resize((IMG_SIZE, IMG_SIZE)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225])
    ])
|
| 49 |
+
|
| 50 |
+
def load(self):
    """Load backbone weights and the fitted PCA; register forward hooks.

    Raises:
        FileNotFoundError: if the PCA pickle is missing (backbone weights
        fall back to an online download instead of raising).
    """
    # ── Load WideResNet-50 ────────────────────────────────
    # NOTE(review): `pretrained=` is deprecated in newer torchvision
    # (replaced by `weights=`) — confirm the pinned version supports it.
    self.model = models.wide_resnet50_2(pretrained=False)

    weights_path = os.path.join(self.data_dir, "wide_resnet50_2.pth")
    if os.path.exists(weights_path):
        self.model.load_state_dict(torch.load(weights_path,
                                              map_location="cpu"))
    else:
        # Download pretrained weights
        self.model = models.wide_resnet50_2(pretrained=True)

    self.model = self.model.to(DEVICE)
    self.model.eval()

    # Freeze all weights — never updated
    for param in self.model.parameters():
        param.requires_grad = False

    # Register hooks
    # (registered after the optional model re-creation above, so they
    # always attach to the model that will actually run)
    self.model.layer2.register_forward_hook(self._hook_layer2)
    self.model.layer3.register_forward_hook(self._hook_layer3)

    # ── Load PCA model ────────────────────────────────────
    pca_path = os.path.join(self.data_dir, "pca_256.pkl")
    if not os.path.exists(pca_path):
        raise FileNotFoundError(f"PCA model not found: {pca_path}")
    self.pca = joblib.load(pca_path)
    print(f"PatchCore extractor loaded | "
          f"PCA: {self.pca.n_components_} components")
|
| 80 |
+
|
| 81 |
+
def _hook_layer2(self, module, input, output):
    # Stash layer2 activations for extract_patches() (fine texture scale).
    self._layer2_feat["feat"] = output

def _hook_layer3(self, module, input, output):
    # Stash layer3 activations for extract_patches() (structural scale).
    self._layer3_feat["feat"] = output
|
| 86 |
+
|
| 87 |
+
@torch.no_grad()
def extract_patches(self, pil_img: Image.Image) -> np.ndarray:
    """
    Extract 784 patch descriptors from one image.

    Pipeline:
    1. Forward pass through WideResNet (hooks capture layer2, layer3)
    2. Upsample layer3 to match layer2 spatial size (14→28)
    3. Concatenate: [1, C2+C3, 28, 28]
    4. 3x3 neighbourhood aggregation (makes each patch context-aware)
    5. Reshape to [784, C2+C3]
    6. PCA reduce to [784, 256]

    Returns: [784, 256] float32 numpy array

    NOTE(review): not thread-safe — the shared hook dicts
    (_layer2_feat / _layer3_feat) would race under concurrent calls on
    one extractor; confirm the API serialises access.
    """
    tensor = self.transform(pil_img).unsqueeze(0).to(DEVICE)
    _ = self.model(tensor)  # triggers hooks

    l2 = self._layer2_feat["feat"]  # [1, C2, 28, 28]
    l3 = self._layer3_feat["feat"]  # [1, C3, 14, 14]

    # Upsample layer3 to 28x28
    l3_up = nn.functional.interpolate(
        l3, size=(28, 28), mode="bilinear", align_corners=False
    )
    combined = torch.cat([l2, l3_up], dim=1)  # [1, C2+C3, 28, 28]

    # 3x3 neighbourhood aggregation (stride 1 + padding 1 keeps 28x28)
    combined = nn.functional.avg_pool2d(
        combined, kernel_size=3, stride=1, padding=1
    )

    # Reshape: [1, C, 28, 28] → [784, C]
    B, C, H, W = combined.shape
    patches = combined.permute(0, 2, 3, 1).reshape(-1, C)
    patches_np = patches.cpu().numpy().astype(np.float32)

    # PCA reduce: [784, C] → [784, 256]
    patches_reduced = self.pca.transform(patches_np).astype(np.float32)

    return patches_reduced  # [784, 256]
|
| 128 |
+
|
| 129 |
+
def build_anomaly_map(self,
                      patch_scores: np.ndarray,
                      smooth: bool = True,
                      out_size: tuple = (224, 224)) -> np.ndarray:
    """
    Convert the 2-D patch distance grid (e.g. [28, 28]) into a
    full-resolution anomaly heatmap.

    Steps:
    1. Upsample to ``out_size`` (bilinear)
    2. Optional Gaussian smoothing (sigma=4) — removes patch-boundary artifacts
    3. Normalise to [0, 1] (skipped when the map is near-constant)

    Args:
        patch_scores: per-patch k-NN distances, 2-D array.
        smooth: apply Gaussian smoothing after upsampling.
        out_size: (width, height) of the output; default (224, 224)
            matches the model input resolution.

    Returns: float32 heatmap of shape (out_size[1], out_size[0])
    """
    # Upsample via PIL bilinear resize. Uses the module-level PIL
    # ``Image`` import (the previous local re-import was redundant).
    heatmap_pil = Image.fromarray(patch_scores.astype(np.float32))
    heatmap = np.array(
        heatmap_pil.resize(out_size, Image.BILINEAR),
        dtype=np.float32
    )

    # Gaussian smoothing removes the blocky patch-grid artifacts.
    if smooth:
        heatmap = scipy.ndimage.gaussian_filter(heatmap, sigma=4)

    # Min-max normalise to [0, 1]; skip near-constant maps to avoid
    # dividing by ~zero.
    h_min, h_max = heatmap.min(), heatmap.max()
    if h_max - h_min > 1e-8:
        heatmap = (heatmap - h_min) / (h_max - h_min)

    return heatmap
|
| 160 |
+
|
| 161 |
+
def get_anomaly_centroid(self, heatmap: np.ndarray) -> tuple:
    """
    Find the centroid of the highest-activation region (pixels strictly
    above the 90th percentile). Used to locate the defect crop for
    Index 2 retrieval.

    Falls back to the image centre when no pixel exceeds the threshold
    (e.g. a constant heatmap). The fallback is now derived from the
    heatmap's own shape instead of assuming a 224x224 map.

    Returns: (cx, cy) pixel coordinates
    """
    threshold = np.percentile(heatmap, 90)
    mask = heatmap > threshold
    if mask.sum() == 0:
        # Centre fallback — previously hard-coded (112, 112), which was
        # only correct for 224x224 heatmaps.
        h, w = heatmap.shape[:2]
        return (w // 2, h // 2)

    ys, xs = np.where(mask)
    return (int(xs.mean()), int(ys.mean()))
|
| 174 |
+
|
| 175 |
+
def calibrate_score(self,
                    raw_score: float,
                    category: str,
                    thresholds: dict) -> float:
    """
    Map a raw k-NN distance to a calibrated anomaly confidence in (0, 1).

    Calibration: sigmoid((score - cal_mean) / cal_std), where cal_mean and
    cal_std describe the distribution of normal patch distances for this
    category. A raw k-NN distance is not a probability; the calibrated
    value is interpretable as anomaly confidence. When no per-category
    statistics exist, the raw score is squashed through the sigmoid as a
    fallback.
    """
    def _sigmoid(value):
        # Logistic squashing to the open interval (0, 1).
        return float(1.0 / (1.0 + np.exp(-value)))

    if category not in thresholds:
        # No calibration statistics for this category — raw fallback.
        return _sigmoid(raw_score)

    stats = thresholds[category]
    z_score = (raw_score - stats["cal_mean"]) / (stats["cal_std"] + 1e-8)
    return _sigmoid(z_score)
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
# Global instance — constructed once at module import and shared.
patchcore = PatchCoreExtractor()
|
src/retriever.py
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# src/retriever.py
|
| 2 |
+
# Loads and searches all 3 FAISS indexes
|
| 3 |
+
#
|
| 4 |
+
# Index 1 — Category (15 vectors, IndexFlatIP, CLIP full-image)
|
| 5 |
+
# Index 2 — Defect pattern (5354 vectors, IndexFlatIP, CLIP crop)
|
| 6 |
+
# Index 3 — PatchCore coreset (per-category, IndexFlatL2, WideResNet patches)
|
| 7 |
+
# LAZY LOADED — only loaded on first request per category
|
| 8 |
+
# Reduces startup time from ~45s to ~15s
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
import json
|
| 12 |
+
import numpy as np
|
| 13 |
+
import faiss
|
| 14 |
+
|
| 15 |
+
# Paths — relative to repo root, mounted in Docker at /app/data/
# DATA_DIR is overridable via the environment (e.g. for tests/CI).
DATA_DIR = os.environ.get("DATA_DIR", "data")

# The 15 MVTec AD categories. Order matters: Index 1 search results are
# mapped back to names by position in this list (see route_category).
CATEGORIES = [
    'bottle', 'cable', 'capsule', 'carpet', 'grid', 'hazelnut',
    'leather', 'metal_nut', 'pill', 'screw', 'tile', 'toothbrush',
    'transistor', 'wood', 'zipper'
]
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class FAISSRetriever:
    """
    Manages all 3 FAISS indexes with lazy loading for Index 3.
    Loaded once at FastAPI startup, kept in memory for the server lifetime.

    Index 1 — category routing (CLIP full-image vectors, inner product)
    Index 2 — defect-pattern retrieval (CLIP crop vectors, inner product)
    Index 3 — per-category PatchCore coresets (L2), lazy-loaded on demand
    """

    def __init__(self, data_dir=DATA_DIR):
        self.data_dir = data_dir
        self.index1 = None          # category index (set by load_indexes)
        self.index1_metadata = None
        self.index2 = None          # defect pattern index (set by load_indexes)
        self.index2_metadata = None
        self.index3_cache = {}      # category -> loaded FAISS index (lazy)

    def load_indexes(self):
        """
        Load Index 1 and Index 2 at startup.
        Index 3 is lazy-loaded per category on first request.

        Raises:
            FileNotFoundError: if either index file is missing.
        """
        # ── Index 1 ──────────────────────────────────────────
        idx1_path = os.path.join(self.data_dir, "index1_category.faiss")
        meta1_path = os.path.join(self.data_dir, "index1_metadata.json")

        if not os.path.exists(idx1_path):
            raise FileNotFoundError(f"Index 1 not found: {idx1_path}")

        self.index1 = faiss.read_index(idx1_path)
        with open(meta1_path) as f:
            self.index1_metadata = json.load(f)
        print(f"Index 1 loaded: {self.index1.ntotal} category vectors")

        # ── Index 2 ──────────────────────────────────────────
        idx2_path = os.path.join(self.data_dir, "index2_defect.faiss")
        meta2_path = os.path.join(self.data_dir, "index2_metadata.json")

        if not os.path.exists(idx2_path):
            raise FileNotFoundError(f"Index 2 not found: {idx2_path}")

        # Memory-mapped — not fully loaded into RAM
        self.index2 = faiss.read_index(idx2_path, faiss.IO_FLAG_MMAP)
        with open(meta2_path) as f:
            self.index2_metadata = json.load(f)
        print(f"Index 2 loaded: {self.index2.ntotal} defect pattern vectors")

    def _load_index3(self, category: str):
        """Lazy load (and cache) the Index 3 coreset for one category."""
        if category not in self.index3_cache:
            path = os.path.join(self.data_dir, f"index3_{category}.faiss")
            if not os.path.exists(path):
                raise FileNotFoundError(f"Index 3 not found for {category}: {path}")
            self.index3_cache[category] = faiss.read_index(
                path, faiss.IO_FLAG_MMAP
            )
            print(f"Index 3 lazy-loaded: {category} "
                  f"({self.index3_cache[category].ntotal} coreset vectors)")
        return self.index3_cache[category]

    # ── Index 1: Category routing ─────────────────────────────
    def route_category(self, clip_full_embedding: np.ndarray) -> dict:
        """
        Given a full-image CLIP embedding, return the predicted category.

        NOTE(review): assumes Index 1 vectors are stored in the same
        order as CATEGORIES — verify against the index-building script.

        Returns: {category, confidence}
        """
        query = clip_full_embedding.reshape(1, -1).astype(np.float32)
        # L2-normalise so inner-product search behaves as cosine similarity.
        query = query / (np.linalg.norm(query) + 1e-8)
        D, I = self.index1.search(query, k=1)
        cat_idx = int(I[0][0])
        return {
            "category": CATEGORIES[cat_idx],
            "confidence": float(D[0][0])
        }

    # ── Index 2: Defect pattern retrieval ────────────────────
    def retrieve_similar_defects(self,
                                 clip_crop_embedding: np.ndarray,
                                 k: int = 5,
                                 exclude_hash: str = None) -> list:
        """
        Given a defect-crop CLIP embedding, return up to k most similar
        historical defect cases.

        Args:
            clip_crop_embedding: CLIP embedding of the defect crop.
            k: number of neighbours to return.
            exclude_hash: skip the self-match (same image submitted again).

        Returns: list of metadata dicts, each with a "similarity_score".
        """
        query = clip_crop_embedding.reshape(1, -1).astype(np.float32)
        query = query / (np.linalg.norm(query) + 1e-8)

        # Fetch k+1 so one self-match can be filtered without shrinking
        # the result set.
        D, I = self.index2.search(query, k=k + 1)

        results = []
        for dist, idx in zip(D[0], I[0]):
            if idx < 0:
                # FAISS pads with -1 when fewer than k+1 hits exist.
                continue
            meta = self.index2_metadata[idx].copy()
            meta["similarity_score"] = float(dist)
            if exclude_hash and meta.get("image_hash") == exclude_hash:
                continue
            results.append(meta)
            if len(results) == k:
                break

        return results

    # ── Index 3: PatchCore k-NN scoring ──────────────────────
    def score_patches(self,
                      patches: np.ndarray,
                      category: str,
                      k: int = 1) -> tuple:
        """
        Given [n_patches, 256] patch features, return the anomaly score
        and the per-patch distance grid.

        Returns:
            image_score: float — max nearest-neighbour patch distance
            patch_scores: [side, side] grid of per-patch NN distances
            score_std: float — std of the retrieved distances at the most
                anomalous patch (rough confidence interval)
            D: [n_patches, k_search] all retrieved k-NN distances
        """
        index3 = self._load_index3(category)
        patches_f32 = patches.astype(np.float32)

        # Search at least 5 neighbours: the first for scoring, the rest
        # for the confidence interval. BUGFIX: `k` was previously
        # accepted but ignored (search was hard-coded to 5); it is now
        # honoured when the caller asks for more than 5.
        k_search = max(k, 5)
        D, _ = index3.search(patches_f32, k=k_search)

        # Grid side derived from the patch count (28 for the standard
        # 784 patches) instead of being hard-coded.
        side = int(round(np.sqrt(len(patches_f32))))
        patch_scores = D[:, 0].reshape(side, side)
        image_score = float(patch_scores.max())

        # Confidence interval: spread of neighbours at the worst patch.
        max_patch_idx = int(np.argmax(D[:, 0]))
        score_std = float(np.std(D[max_patch_idx]))

        return image_score, patch_scores, score_std, D

    def get_status(self) -> dict:
        """Return index sizes for the /health endpoint."""
        return {
            "index1_vectors": self.index1.ntotal if self.index1 else 0,
            "index2_vectors": self.index2.ntotal if self.index2 else 0,
            "index3_loaded_categories": list(self.index3_cache.keys()),
            "index3_total_categories": len(CATEGORIES)
        }
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
# Global instance — constructed here, but its indexes are only loaded
# when api/startup.py calls load_indexes().
retriever = FAISSRetriever()
|
src/xai.py
ADDED
|
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# src/xai.py
|
| 2 |
+
# Four XAI methods — each answers a different question
|
| 3 |
+
#
|
| 4 |
+
# Method 1 — PatchCore anomaly map: WHERE is the defect? (in patchcore.py)
|
| 5 |
+
# Method 2 — GradCAM++: WHICH features triggered the classifier?
|
| 6 |
+
# Method 3 — SHAP waterfall: WHY is the score this specific number?
|
| 7 |
+
# Method 4 — Retrieval trace: WHAT in history is this most similar to?
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
import json
|
| 11 |
+
import base64
|
| 12 |
+
import io
|
| 13 |
+
import numpy as np
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
import torchvision.models as models
|
| 17 |
+
import torchvision.transforms as T
|
| 18 |
+
import shap
|
| 19 |
+
from PIL import Image
|
| 20 |
+
import cv2
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
DATA_DIR = os.environ.get("DATA_DIR", "data")  # artefact dir, env-overridable
DEVICE = "cpu"  # XAI models run on CPU
IMG_SIZE = 224  # input resolution for EfficientNet and heatmaps
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class GradCAMPlusPlus:
    """
    GradCAM++ on EfficientNet-B0.

    Why GradCAM++ not basic GradCAM:
    Basic GradCAM uses only positive gradients, producing fragmented maps.
    GradCAM++ uses a weighted combination of both positive and negative
    gradients, resulting in more focused, anatomically precise maps.
    Same implementation complexity — direct upgrade.

    Why a separate EfficientNet:
    PatchCore has no gradient flow (it's a memory bank + k-NN).
    GradCAM++ requires differentiable activations.
    EfficientNet is fine-tuned on MVTec binary classification solely
    to provide gradients for this XAI method — never used for scoring.
    """

    def __init__(self, data_dir=DATA_DIR):
        self.data_dir = data_dir
        self.model = None  # set by load(); compute() returns None until then
        # Standard ImageNet preprocessing.
        self.transform = T.Compose([
            T.Resize((IMG_SIZE, IMG_SIZE)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225])
        ])

    def load(self):
        """
        Build EfficientNet-B0 with a 2-class head and load fine-tuned
        weights; fall back to ImageNet weights when they are missing.
        """
        self.model = models.efficientnet_b0(pretrained=False)
        self.model.classifier = nn.Sequential(
            nn.Dropout(p=0.3),
            nn.Linear(1280, 2)
        )
        weights_path = os.path.join(self.data_dir, "efficientnet_b0.pt")
        if os.path.exists(weights_path):
            self.model.load_state_dict(
                torch.load(weights_path, map_location="cpu")
            )
        else:
            # Fallback: pretrained ImageNet weights (weaker XAI but not None)
            # NOTE(review): this keeps the stock 1000-class head, not the
            # 2-class one built above — compute() still works since it
            # takes the argmax class; confirm this is intended.
            self.model = models.efficientnet_b0(pretrained=True)
            print("WARNING: EfficientNet fine-tuned weights not found. "
                  "Using ImageNet pretrained — GradCAM++ quality reduced.")

        self.model = self.model.to(DEVICE)
        self.model.eval()
        print("GradCAM++ (EfficientNet-B0) loaded")

    def compute(self, pil_img: Image.Image) -> np.ndarray:
        """
        Compute the GradCAM++ activation map for the predicted class.
        Target layer: model.features[-1]
        Returns: [224, 224] float32 array in [0, 1], or None if it fails.
        """
        if self.model is None:
            return None

        try:
            tensor = self.transform(pil_img).unsqueeze(0).to(DEVICE)
            tensor.requires_grad_(True)

            # Hook outputs are written into these dicts by reference.
            activations = {}
            gradients = {}

            def forward_hook(module, input, output):
                activations["feat"] = output

            def backward_hook(module, grad_in, grad_out):
                gradients["feat"] = grad_out[0]

            # Register hooks on the last feature block.
            target_layer = self.model.features[-1]
            fwd_handle = target_layer.register_forward_hook(forward_hook)
            bwd_handle = target_layer.register_full_backward_hook(backward_hook)

            try:
                # Forward + backward pass for the predicted class score.
                with torch.enable_grad():
                    output = self.model(tensor)
                    pred_class = output.argmax(dim=1).item()
                    score = output[0, pred_class]
                    self.model.zero_grad()
                    score.backward()
            finally:
                # BUGFIX: hook handles previously leaked when the
                # forward/backward pass raised — always detach them.
                fwd_handle.remove()
                bwd_handle.remove()

            # GradCAM++ weights
            # α = ReLU(grad)² / (2*ReLU(grad)² + sum(A)*ReLU(grad)³)
            grads = gradients["feat"]   # [1, C, H, W]
            acts = activations["feat"]  # [1, C, H, W]

            grads_relu = torch.relu(grads)
            acts_sum = acts.sum(dim=(2, 3), keepdim=True)

            alpha_num = grads_relu ** 2
            alpha_denom = 2 * grads_relu**2 + acts_sum * grads_relu**3 + 1e-8
            alpha = alpha_num / alpha_denom

            weights = (alpha * grads_relu).sum(dim=(2, 3), keepdim=True)
            cam = (weights * acts).sum(dim=1, keepdim=True)
            cam = torch.relu(cam).squeeze().cpu().numpy()

            # Upsample to 224x224
            cam_pil = Image.fromarray(cam)
            cam = np.array(cam_pil.resize((IMG_SIZE, IMG_SIZE),
                                          Image.BILINEAR), dtype=np.float32)

            # Min-max normalise to [0, 1]; skip near-constant maps.
            cam_min, cam_max = cam.min(), cam.max()
            if cam_max - cam_min > 1e-8:
                cam = (cam - cam_min) / (cam_max - cam_min)

            return cam

        except Exception as e:
            print(f"GradCAM++ failed: {e}")
            return None
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
class SHAPExplainer:
    """
    SHAP-style waterfall explanation for the anomaly score.
    Explains the score as a function of 5 human-readable features.

    The 5 features:
    - mean_patch_distance: avg k-NN distance (pervasive texture anomaly)
    - max_patch_distance: max k-NN distance = image anomaly score
    - depth_variance: from MiDaS (complex 3D surface)
    - edge_density: fraction of Canny edge pixels
    - texture_regularity: FFT low-frequency energy ratio

    NOTE: the contributions are a simple linear approximation against a
    background sample (deviation from the background mean) — no shap
    explainer object is actually fitted.
    """

    # Human-readable names for the 5 features, in vector order.
    # Class-level constant: previously rebuilt on every explain() call.
    FEATURE_NAMES = [
        "mean_patch_distance",
        "max_patch_distance",
        "depth_variance",
        "edge_density",
        "texture_regularity"
    ]

    def __init__(self):
        self.explainer = None              # reserved; unused by the linear approximation
        self._background_features = None   # [n, 5] background sample
        self._background_loaded = False    # explain() falls back until True

    def load_background(self, background_path: str = None):
        """
        Load background features (a sample of normal-image features from
        the training set). Falls back to a zero background when the file
        is missing, which keeps explain() functional but less meaningful.
        """
        if background_path and os.path.exists(background_path):
            self._background_features = np.load(background_path)
            print(f"SHAP background loaded: {self._background_features.shape}")
        else:
            # Fallback: use zeros as background (weaker but functional)
            self._background_features = np.zeros((10, 5), dtype=np.float32)
            print("SHAP using zero background (background_features.npy not found)")
        self._background_loaded = True

    def build_feature_vector(self,
                             patch_scores: np.ndarray,
                             depth_stats: dict,
                             fft_features: dict,
                             edge_features: dict) -> np.ndarray:
        """
        Assemble the 5 explanation features from the computed signals.
        Missing dict keys default to 0.0.

        Returns: [5] float32 array, ordered as FEATURE_NAMES.
        """
        return np.array([
            float(patch_scores.mean()),  # mean_patch_distance
            float(patch_scores.max()),   # max_patch_distance
            float(depth_stats.get("depth_variance", 0.0)),
            float(edge_features.get("edge_density", 0.0)),
            float(fft_features.get("low_freq_ratio", 0.0))
        ], dtype=np.float32)

    def explain(self, feature_vector: np.ndarray) -> dict:
        """
        Compute pseudo-SHAP contributions for one feature vector.

        Returns a dict with feature names, raw values, contributions,
        base value (background mean) and the "prediction" (feature sum).
        """
        names = self.FEATURE_NAMES

        if not self._background_loaded:
            return self._fallback_explain(feature_vector, names)

        try:
            # Linear approximation: contributions proportional to the
            # deviation from the background mean, rescaled so their
            # absolute sum matches the feature-vector sum.
            bg_mean = self._background_features.mean(axis=0)
            deviations = feature_vector - bg_mean
            total = np.abs(deviations).sum() + 1e-8
            shap_values = deviations * (feature_vector.sum() / total)

            return {
                "feature_names": names,
                "feature_values": feature_vector.tolist(),
                "shap_values": shap_values.tolist(),
                "base_value": float(bg_mean.mean()),
                "prediction": float(feature_vector.sum())
            }

        except Exception as e:
            print(f"SHAP explain failed: {e}")
            return self._fallback_explain(feature_vector, names)

    def _fallback_explain(self, features, names):
        # Degenerate explanation: each feature "explains itself".
        return {
            "feature_names": names,
            "feature_values": features.tolist(),
            "shap_values": features.tolist(),
            "base_value": 0.0,
            "prediction": float(features.max())
        }
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
def heatmap_to_base64(heatmap: np.ndarray,
                      original_img: Image.Image = None) -> str:
    """
    Convert a 2-D float32 heatmap in [0, 1] to a base64-encoded PNG.

    If original_img is provided, the jet-colormapped heatmap is blended
    60/40 over it; otherwise the colormapped heatmap alone is encoded.
    """
    heatmap_uint8 = (heatmap * 255).astype(np.uint8)
    heatmap_color = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
    # OpenCV returns BGR; flip to RGB for PIL.
    heatmap_rgb = cv2.cvtColor(heatmap_color, cv2.COLOR_BGR2RGB)

    if original_img is not None:
        # BUGFIX: force 3-channel RGB — an RGBA or grayscale original
        # previously broke the broadcast against the (H, W, 3) heatmap.
        # Resize target derived from the heatmap instead of hard-coded 224.
        target = (heatmap.shape[1], heatmap.shape[0])
        orig_np = np.array(original_img.convert("RGB").resize(target))
        overlay = (0.6 * orig_np + 0.4 * heatmap_rgb).astype(np.uint8)
        result_img = Image.fromarray(overlay)
    else:
        result_img = Image.fromarray(heatmap_rgb)

    buf = io.BytesIO()
    result_img.save(buf, format="PNG")
    return base64.b64encode(buf.getvalue()).decode("utf-8")
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
def image_to_base64(pil_img: Image.Image,
                    size: tuple = (224, 224)) -> str:
    """Serialise a PIL image, resized to ``size``, as a base64 PNG string."""
    buffer = io.BytesIO()
    pil_img.resize(size).save(buffer, format="PNG")
    encoded = base64.b64encode(buffer.getvalue())
    return encoded.decode("utf-8")
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
# Global instances. gradcam.compute() returns None until gradcam.load()
# has been called; shap_explainer works immediately (with a fallback
# explanation until load_background() runs).
gradcam = GradCAMPlusPlus()
shap_explainer = SHAPExplainer()
|
start.sh
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Launch FastAPI (backend) and Gradio (frontend) in one container.
set -u

# Start FastAPI in the background
uvicorn api.main:app --host 0.0.0.0 --port 7860 &
API_PID=$!

# Wait for FastAPI to be ready. Poll instead of a blind `sleep 10`:
# starts sooner when the API is fast, waits longer (up to ~30s) when
# it is slow. Falls back to the old fixed sleep when curl is missing.
if command -v curl > /dev/null 2>&1; then
    for _ in $(seq 1 30); do
        if curl -sf "http://127.0.0.1:7860/health" > /dev/null 2>&1; then
            break
        fi
        # Stop polling early if the API process already died.
        kill -0 "$API_PID" 2> /dev/null || break
        sleep 1
    done
else
    sleep 10
fi

# Start Gradio in the foreground; exec keeps it as PID of record so the
# container lifecycle tracks the UI process.
exec python app.py
|