Upload 15 files

- Dockerfile +26 -0
- Notebooks/EfficientNet_ConvNext_Fusion.ipynb +0 -0
- Notebooks/Model_Compression.ipynb +276 -0
- Notebooks/Resnet18_fine_tuning.ipynb +0 -0
- Notebooks/damage_detector_yolo.ipynb +0 -0
- app.py +107 -0
- checkpoints/best_fusion_model_fp16.pth +3 -0
- checkpoints/best_resnet_model.pt +3 -0
- checkpoints/damage_detector.pt +3 -0
- index.html +419 -0
- requirements.txt +12 -0
- scripts/gradcam.py +100 -0
- scripts/model_loader.py +38 -0
- scripts/prediction_helper.py +175 -0
- scripts/yolo.py +57 -0
Dockerfile
ADDED
@@ -0,0 +1,26 @@
FROM python:3.10-slim

ENV PYTHONDONTWRITEBYTECODE=1
ENV PYTHONUNBUFFERED=1

WORKDIR /app

# --- SYSTEM DEPENDENCIES (CRITICAL FOR OPENCV / YOLO) ---
RUN apt-get update && apt-get install -y \
    build-essential \
    gcc \
    libgl1 \
    libglib2.0-0 \
    && rm -rf /var/lib/apt/lists/*

# --- PYTHON DEPENDENCIES ---
COPY requirements.txt .
RUN pip install --no-cache-dir --upgrade pip \
    && pip install --no-cache-dir -r requirements.txt

# --- APP CODE ---
COPY . .

EXPOSE 7860

CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
Notebooks/EfficientNet_ConvNext_Fusion.ipynb
ADDED
The diff for this file is too large to render. See raw diff.
Notebooks/Model_Compression.ipynb
ADDED
@@ -0,0 +1,276 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "671818be",
   "metadata": {},
   "source": [
    "# Model Conversion or Compression \n",
    "**This notebook demonstrates how to convert a PyTorch model to FP16 precision, which can reduce the model size and potentially speed up inference on compatible hardware. We will use the `FusionClassifier` as an example, but the same approach can be applied to other models as well.**\n",
    "\n",
    "**From FP32 to FP16**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "b1715593",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading weights: 100%|██████████| 342/342 [00:00<00:00, 2845.51it/s]\n",
      "[transformers] \u001b[1mConvNextModel LOAD REPORT\u001b[0m from: facebook/convnext-small-224\n",
      "Key               | Status     |  | \n",
      "------------------+------------+--+-\n",
      "classifier.bias   | UNEXPECTED |  | \n",
      "classifier.weight | UNEXPECTED |  | \n",
      "\n",
      "Notes:\n",
      "- UNEXPECTED:\tcan be ignored when loading from different task/architecture; not ok if you expect identical arch.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "============================================================\n",
      "Initializing model...\n",
      "============================================================\n",
      "Model weights loaded successfully.\n",
      "Model converted to FP16.\n",
      "============================================================\n",
      "FP16 model saved successfully.\n",
      "Saved Path : D:\\DamageLens\\checkpoints\\best_fusion_model_fp16.pth\n",
      "FP16 Model Size : 135.77 MB\n",
      "============================================================\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torchvision.models as models\n",
    "from transformers import ConvNextModel\n",
    "\n",
    "\n",
    "# =========================================================\n",
    "# FUSION MODEL\n",
    "# =========================================================\n",
    "\n",
    "class FusionClassifier(nn.Module):\n",
    "    def __init__(self, num_classes, convnext_model_name=\"facebook/convnext-small-224\"):\n",
    "        super().__init__()\n",
    "\n",
    "        # -------------------------------------------------\n",
    "        # EfficientNet-V2-S\n",
    "        # -------------------------------------------------\n",
    "        eff = models.efficientnet_v2_s(\n",
    "            weights=models.EfficientNet_V2_S_Weights.IMAGENET1K_V1\n",
    "        )\n",
    "\n",
    "        # Freeze all\n",
    "        for param in eff.parameters():\n",
    "            param.requires_grad = False\n",
    "\n",
    "        # Unfreeze last stages\n",
    "        for param in eff.features[5].parameters():\n",
    "            param.requires_grad = True\n",
    "\n",
    "        for param in eff.features[6].parameters():\n",
    "            param.requires_grad = True\n",
    "\n",
    "        for param in eff.features[7].parameters():\n",
    "            param.requires_grad = True\n",
    "\n",
    "        self.eff_features = eff.features\n",
    "        self.eff_avgpool = eff.avgpool\n",
    "        self.eff_out_dim = eff.classifier[1].in_features  # 1280\n",
    "\n",
    "        # -------------------------------------------------\n",
    "        # ConvNeXt Small\n",
    "        # -------------------------------------------------\n",
    "        cnx = ConvNextModel.from_pretrained(convnext_model_name)\n",
    "\n",
    "        # Freeze all\n",
    "        for param in cnx.parameters():\n",
    "            param.requires_grad = False\n",
    "\n",
    "        # Unfreeze stages\n",
    "        for param in cnx.encoder.stages[2].parameters():\n",
    "            param.requires_grad = True\n",
    "\n",
    "        for param in cnx.encoder.stages[3].parameters():\n",
    "            param.requires_grad = True\n",
    "\n",
    "        for param in cnx.layernorm.parameters():\n",
    "            param.requires_grad = True\n",
    "\n",
    "        self.cnx_backbone = cnx\n",
    "        self.cnx_out_dim = 768\n",
    "\n",
    "        # -------------------------------------------------\n",
    "        # Fusion Head\n",
    "        # -------------------------------------------------\n",
    "        fused_dim = self.eff_out_dim + self.cnx_out_dim\n",
    "\n",
    "        self.fusion_head = nn.Sequential(\n",
    "            nn.Dropout(0.4),\n",
    "\n",
    "            nn.Linear(fused_dim, 512),\n",
    "            nn.LayerNorm(512),\n",
    "            nn.GELU(),\n",
    "\n",
    "            nn.Dropout(0.3),\n",
    "\n",
    "            nn.Linear(512, 256),\n",
    "            nn.LayerNorm(256),\n",
    "            nn.GELU(),\n",
    "\n",
    "            nn.Dropout(0.2),\n",
    "\n",
    "            nn.Linear(256, num_classes)\n",
    "        )\n",
    "\n",
    "    def forward(self, pixel_values_eff, pixel_values_cnx):\n",
    "\n",
    "        # EfficientNet branch\n",
    "        x_eff = self.eff_features(pixel_values_eff)\n",
    "        x_eff = self.eff_avgpool(x_eff)\n",
    "        x_eff = torch.flatten(x_eff, 1)\n",
    "\n",
    "        # ConvNeXt branch\n",
    "        cnx_out = self.cnx_backbone(\n",
    "            pixel_values=pixel_values_cnx,\n",
    "            return_dict=True\n",
    "        )\n",
    "\n",
    "        x_cnx = cnx_out.pooler_output\n",
    "\n",
    "        # Fusion\n",
    "        fused = torch.cat([x_eff, x_cnx], dim=1)\n",
    "\n",
    "        logits = self.fusion_head(fused)\n",
    "\n",
    "        return logits\n",
    "\n",
    "\n",
    "# =========================================================\n",
    "# CONFIG\n",
    "# =========================================================\n",
    "\n",
    "class_map = {\n",
    "    0: \"Front Breakage\",\n",
    "    1: \"Front Crushed\",\n",
    "    2: \"Front Normal\",\n",
    "    3: \"Rear Breakage\",\n",
    "    4: \"Rear Crushed\",\n",
    "    5: \"Rear Normal\"\n",
    "}\n",
    "\n",
    "device = torch.device(\"cpu\")\n",
    "\n",
    "CHECKPOINT_PATH = r\"D:\\DamageLens\\checkpoints\\best_fusion_model.pt\"\n",
    "\n",
    "SAVE_FP16_PATH = r\"D:\\DamageLens\\checkpoints\\best_fusion_model_fp16.pth\"\n",
    "\n",
    "NUM_CLASSES = len(class_map)\n",
    "\n",
    "CONVNEXT_MODEL_NAME = \"facebook/convnext-small-224\"\n",
    "\n",
    "\n",
    "# =========================================================\n",
    "# INITIALIZE MODEL\n",
    "# =========================================================\n",
    "\n",
    "model = FusionClassifier(\n",
    "    num_classes=NUM_CLASSES,\n",
    "    convnext_model_name=CONVNEXT_MODEL_NAME\n",
    ")\n",
    "\n",
    "print(\"=\" * 60)\n",
    "print(\"Initializing model...\")\n",
    "print(\"=\" * 60)\n",
    "\n",
    "\n",
    "# =========================================================\n",
    "# LOAD TRAINED WEIGHTS\n",
    "# =========================================================\n",
    "\n",
    "checkpoint = torch.load(\n",
    "    CHECKPOINT_PATH,\n",
    "    map_location=device\n",
    ")\n",
    "\n",
    "# If checkpoint contains state_dict\n",
    "if \"model_state_dict\" in checkpoint:\n",
    "    model.load_state_dict(checkpoint[\"model_state_dict\"])\n",
    "\n",
    "# If checkpoint is directly state_dict\n",
    "else:\n",
    "    model.load_state_dict(checkpoint)\n",
    "\n",
    "print(\"Model weights loaded successfully.\")\n",
    "\n",
    "\n",
    "# =========================================================\n",
    "# CONVERT TO FP16\n",
    "# =========================================================\n",
    "\n",
    "model = model.half()\n",
    "\n",
    "print(\"Model converted to FP16.\")\n",
    "\n",
    "\n",
    "# =========================================================\n",
    "# CREATE CHECKPOINT DIRECTORY\n",
    "# =========================================================\n",
    "\n",
    "os.makedirs(\"checkpoints\", exist_ok=True)\n",
    "\n",
    "\n",
    "# =========================================================\n",
    "# SAVE FP16 MODEL\n",
    "# =========================================================\n",
    "\n",
    "torch.save(\n",
    "    model.state_dict(),\n",
    "    SAVE_FP16_PATH\n",
    ")\n",
    "\n",
    "print(\"=\" * 60)\n",
    "print(\"FP16 model saved successfully.\")\n",
    "print(f\"Saved Path : {SAVE_FP16_PATH}\")\n",
    "\n",
    "size_mb = os.path.getsize(SAVE_FP16_PATH) / (1024 * 1024)\n",
    "\n",
    "print(f\"FP16 Model Size : {size_mb:.2f} MB\")\n",
    "print(\"=\" * 60)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "myvenv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
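The FP16 checkpoint saved above holds only a half-precision state_dict, so it has to be loaded back into a model that has already been converted with .half(). A minimal reload sketch (not part of the upload; it assumes the FusionClassifier definition from the cell above and the checkpoint location used in this repo):

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = FusionClassifier(num_classes=6)
model = model.half()  # match the dtype of the saved FP16 weights before loading
state_dict = torch.load("checkpoints/best_fusion_model_fp16.pth", map_location=device)
model.load_state_dict(state_dict)
model.to(device)
model.eval()

# Inputs must also be FP16 once the model is half-precision.
# Note: some half-precision ops only have GPU kernels; pure-CPU FP16 inference may fail.
dummy_eff = torch.randn(1, 3, 260, 260, device=device).half()
dummy_cnx = torch.randn(1, 3, 224, 224, device=device).half()
with torch.no_grad():
    logits = model(dummy_eff, dummy_cnx)
print(logits.shape)  # torch.Size([1, 6])

scripts/prediction_helper.py applies the same pattern in FusionCarDamagePredictor, detecting the checkpoint dtype and calling .half() automatically.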
Notebooks/Resnet18_fine_tuning.ipynb
ADDED
The diff for this file is too large to render. See raw diff.
Notebooks/damage_detector_yolo.ipynb
ADDED
The diff for this file is too large to render. See raw diff.
app.py
ADDED
@@ -0,0 +1,107 @@
import os
import uuid
import shutil
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.staticfiles import StaticFiles
from PIL import Image
from fastapi.middleware.cors import CORSMiddleware
from dotenv import load_dotenv
from scripts.gradcam import get_resnet_gradcam, get_fusion_gradcam
from scripts.yolo import get_yolo_damage_boxes
from scripts.model_loader import initialize_models

load_dotenv()
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
UPLOAD_DIR = "static/uploads"
RESULT_DIR = "static/results"
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(RESULT_DIR, exist_ok=True)
app.mount("/static", StaticFiles(directory="static"), name="static")
class_map = {
    0: "Front Breakage",
    1: "Front Crushed",
    2: "Front Normal",
    3: "Rear Breakage",
    4: "Rear Crushed",
    5: "Rear Normal"
}
resnet_predictor, fusion_predictor = initialize_models(class_map)

@app.get("/")
def api_status():
    return {"status": "API is running"}

@app.post("/predict")
async def predict_and_generate_cams(file: UploadFile = File(...), mode: str = "resnet"):
    mode = mode.lower()
    if mode not in {"resnet", "fusion"}:
        raise HTTPException(status_code=400, detail="mode must be 'resnet' or 'fusion'")
    unique_id = str(uuid.uuid4())
    input_filename = f"{unique_id}_input.jpg"
    input_path = os.path.join(UPLOAD_DIR, input_filename)
    with open(input_path, "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)
    if mode == "resnet":
        output_name = f"{unique_id}_resnet.jpg"
        output_path = os.path.join(RESULT_DIR, output_name)
        get_resnet_gradcam(input_path, resnet_predictor, output_path)
        selected_viz = f"/static/results/{output_name}"
        resnet_viz = selected_viz
        fusion_viz = None
    else:
        output_name = f"{unique_id}_fusion.jpg"
        output_path = os.path.join(RESULT_DIR, output_name)
        get_fusion_gradcam(input_path, fusion_predictor, output_path)
        selected_viz = f"/static/results/{output_name}"
        resnet_viz = None
        fusion_viz = selected_viz
    return {
        "status": "success",
        "original_image": f"/static/uploads/{input_filename}",
        "selected_viz": selected_viz,
        "resnet_viz": resnet_viz,
        "fusion_viz": fusion_viz,
        "mode": mode
    }

@app.post("/predict/resnet")
async def resnet_prediction(image: UploadFile = File(...)):
    try:
        image = Image.open(image.file).convert("RGB")
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid image file")
    return resnet_predictor.resnet_predict(image_input=image)

@app.post("/predict/fusion")
async def fusion_prediction(image: UploadFile = File(...)):
    try:
        image = Image.open(image.file).convert("RGB")
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid image file")
    return fusion_predictor.predict(image_input=image)

@app.post("/predict/yolo")
async def yolo_detection(file: UploadFile = File(...)):
    unique_id = str(uuid.uuid4())
    input_filename = f"{unique_id}_input.jpg"
    yolo_out_name = f"{unique_id}_yolo.jpg"
    input_path = os.path.join(UPLOAD_DIR, input_filename)
    yolo_path = os.path.join(RESULT_DIR, yolo_out_name)
    with open(input_path, "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)
    result = get_yolo_damage_boxes(input_path, yolo_path)
    return {
        "status": "success",
        "original_image": f"/static/uploads/{input_filename}",
        "yolo_image": f"/static/results/{yolo_out_name}",
        "detections": result["detections"],
        "total_detections": result["total_detections"],
        "message": result["message"]
    }
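For reference, the endpoints defined above can be exercised with a small client script. This is a sketch, not part of the upload; it assumes the API is running locally on port 7860 (the port exposed in the Dockerfile) and that a test image exists at the given path:

import requests

BASE_URL = "http://127.0.0.1:7860"  # assumption: local uvicorn/Docker instance
IMAGE_PATH = "test_car.jpg"         # assumption: any local car photo

# Classification: /predict/resnet and /predict/fusion expect the multipart field "image"
with open(IMAGE_PATH, "rb") as f:
    probs = requests.post(f"{BASE_URL}/predict/fusion", files={"image": f}).json()
print(probs)  # class -> probability, sorted in descending order

# Grad-CAM: /predict expects the field "file" plus a "mode" query parameter
with open(IMAGE_PATH, "rb") as f:
    cam = requests.post(f"{BASE_URL}/predict", params={"mode": "fusion"}, files={"file": f}).json()
print(cam["selected_viz"])  # served under /static/results/

# YOLO damage localization
with open(IMAGE_PATH, "rb") as f:
    yolo = requests.post(f"{BASE_URL}/predict/yolo", files={"file": f}).json()
print(yolo["total_detections"], yolo["message"])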
checkpoints/best_fusion_model_fp16.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:59cede9aca6c4b39b6447458ddb9cdc3e3ba06c5d972ad62b6807bfcd0afa466
size 142369497
checkpoints/best_resnet_model.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:688cbd4f9eb2e97b6e67287b23f5f750b0367dfb08844704d49075fb086bbdd5
size 130360907
checkpoints/damage_detector.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7c3b9222d9977b5bfd78d65ea6be9d609c81de473349bb3f362088a86ba07f9f
size 51189913
index.html
ADDED
@@ -0,0 +1,419 @@
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Car Damage AI</title>
    <script src="https://cdn.plot.ly/plotly-2.27.0.min.js"></script>
    <style>
        @import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;600;800&display=swap');

        :root {
            --bg-dark: #09090b;
            --bg-card: #18181b;
            --text-primary: #e2e8f0;
            --text-secondary: #a1a1aa;
            --accent: #00c6ff;
            --accent-hover: #0072ff;
            --glass: rgba(255, 255, 255, 0.03);
            --card-border: #27272a;
        }

        * { margin: 0; padding: 0; box-sizing: border-box; font-family: 'Inter', sans-serif; }

        body {
            background-color: var(--bg-dark);
            color: var(--text-primary);
            min-height: 100vh;
            display: flex;
            justify-content: center;
            align-items: flex-start;
            padding: 40px 20px;
            background-image: radial-gradient(circle at top right, rgba(0, 198, 255, 0.05) 0%, transparent 40%);
        }

        .container {
            width: 100%;
            max-width: 850px;
            background: var(--bg-card);
            border-radius: 20px;
            padding: 35px;
            box-shadow: 0 20px 40px rgba(0,0,0,0.6);
            animation: slideUpFade 0.6s ease-out forwards;
            border: 1px solid var(--card-border);
        }

        @keyframes slideUpFade { from { opacity: 0; transform: translateY(30px); } to { opacity: 1; transform: translateY(0); } }

        /* Shimmering Main Title */
        .shimmer-text {
            text-align: center;
            font-size: 2.5rem;
            font-weight: 800;
            background: linear-gradient(90deg, #e2e8f0 0%, #ffffff 25%, #00c6ff 50%, #e2e8f0 75%, #e2e8f0 100%);
            background-size: 200% auto;
            color: transparent;
            -webkit-background-clip: text;
            background-clip: text;
            animation: shimmer 4s linear infinite;
            margin-bottom: 0.2rem;
        }
        @keyframes shimmer { 0% { background-position: -200% center; } 100% { background-position: 200% center; } }

        .subtitle { text-align: center; color: var(--text-secondary); font-size: 1rem; margin-bottom: 25px; }

        /* Warning Box */
        .warning-box {
            background: rgba(0, 198, 255, 0.1);
            border-left: 4px solid var(--accent);
            color: var(--text-primary);
            padding: 12px 15px;
            border-radius: 8px;
            margin-bottom: 25px;
            font-size: 0.9rem;
            display: flex;
            align-items: center;
            gap: 12px;
        }

        /* Controls Section */
        .controls-grid {
            display: grid;
            grid-template-columns: 1fr 1fr;
            gap: 20px;
            margin-bottom: 25px;
        }

        .file-wrapper {
            position: relative; height: 160px; border: 2px dashed #444; border-radius: 16px;
            display: flex; justify-content: center; align-items: center; cursor: pointer;
            transition: all 0.3s ease; background: var(--glass); overflow: hidden;
        }
        .file-wrapper:hover { border-color: var(--accent); background: rgba(0, 198, 255, 0.05); }
        .file-wrapper input { position: absolute; width: 100%; height: 100%; opacity: 0; cursor: pointer; z-index: 2; }

        .settings-card {
            background: rgba(0,0,0,0.2);
            border-radius: 16px;
            padding: 20px;
            border: 1px solid var(--card-border);
            display: flex;
            flex-direction: column;
            justify-content: center;
        }

        select {
            width: 100%; background: #27272a; border: 1px solid #3f3f46; padding: 14px;
            border-radius: 12px; color: white; outline: none; margin-top: 10px; font-size: 1rem;
        }
        select:focus { border-color: var(--accent); }

        /* Preview Area & Animations */
        .image-area {
            width: 100%; height: 350px; background: #09090b; border-radius: 16px;
            margin-bottom: 25px; display: none; justify-content: center; align-items: center;
            overflow: hidden; position: relative; border: 1px solid var(--card-border);
        }
        .image-area img { max-width: 100%; max-height: 100%; object-fit: contain; z-index: 1;}

        /* Scanner Animation */
        .scan-line {
            position: absolute; top: -10%; left: 0; width: 100%; height: 5px;
            background: var(--accent); box-shadow: 0 0 15px var(--accent), 0 0 30px var(--accent);
            z-index: 5; opacity: 0.8; display: none; animation: scanMove 2s ease-in-out infinite; filter: blur(1px);
        }
        @keyframes scanMove { 0% { top: -10%; opacity: 0.5; } 50% { opacity: 1; } 100% { top: 110%; opacity: 0.5; } }

        /* Loader Overlay */
        .loader-overlay {
            position: absolute; top: 0; left: 0; width: 100%; height: 100%;
            background: rgba(0,0,0,0.65); backdrop-filter: blur(4px);
            display: none; flex-direction: column; justify-content: center; align-items: center; z-index: 10;
        }
        .spinner {
            width: 50px; height: 50px; border: 4px solid rgba(0, 198, 255, 0.2);
            border-top: 4px solid var(--accent); border-radius: 50%;
            animation: spin 1s cubic-bezier(0.68, -0.55, 0.27, 1.55) infinite; margin-bottom: 15px;
        }
        @keyframes spin { 100% { transform: rotate(360deg); } }

        /* Buttons */
        .btn {
            width: 100%; padding: 16px; background: linear-gradient(135deg, var(--accent) 0%, var(--accent-hover) 100%);
            color: white; border: none; border-radius: 12px; cursor: pointer; font-weight: 700; font-size: 1rem;
            transition: all 0.3s ease; box-shadow: 0 4px 15px rgba(0, 114, 255, 0.3);
        }
        .btn:hover:not(:disabled) { transform: scale(1.02); box-shadow: 0 8px 25px rgba(0, 198, 255, 0.5); }
        .btn:disabled { background: #444; color: #888; box-shadow: none; transform: none; cursor: not-allowed;}

        /* Results Tabs */
        .results-section { display: none; margin-top: 30px; animation: slideUpFade 0.5s ease-out; }
        .tabs { display: flex; gap: 10px; margin-bottom: 20px; border-bottom: 1px solid var(--card-border); padding-bottom: 10px; overflow-x: auto; }
        .tab {
            padding: 10px 20px; cursor: pointer; border-radius: 8px; color: var(--text-secondary);
            font-weight: 600; transition: all 0.3s ease; white-space: nowrap;
        }
        .tab.active { background: rgba(0, 198, 255, 0.1); color: var(--accent); }
        .tab-content { display: none; }
        .tab-content.active { display: block; animation: slideUpFade 0.4s ease-out; }

        /* Progress Bar */
        .progress-wrapper { background: #27272a; border-radius: 20px; overflow: hidden; height: 12px; margin: 10px 0 20px 0; box-shadow: inset 0 2px 4px rgba(0,0,0,0.5); }
        .progress-fill { height: 100%; background: linear-gradient(90deg, var(--accent), var(--accent-hover)); border-radius: 20px; width: 0%; transition: width 1.5s cubic-bezier(0.22, 1, 0.36, 1); }

        /* Final Prediction Text */
        .big-text { font-size: 2.5rem; font-weight: 800; background: -webkit-linear-gradient(45deg, #00c6ff, #0072ff); -webkit-background-clip: text; -webkit-text-fill-color: transparent; margin-bottom: 5px; }

        /* Images Grid (Attention Maps) */
        .img-grid { display: grid; grid-template-columns: repeat(3, 1fr); gap: 15px; }
        .img-card { background: rgba(0,0,0,0.3); border: 1px solid var(--card-border); border-radius: 12px; padding: 10px; text-align: center; }
        .img-card img { width: 100%; border-radius: 8px; margin-top: 10px; }

        /* YOLO Grid */
        .yolo-grid { display: grid; grid-template-columns: 1.5fr 1fr; gap: 20px; }
        .log-box { background: rgba(0,0,0,0.3); border: 1px solid var(--card-border); border-radius: 12px; padding: 20px; height: 100%; }
        .detection-item { background: #27272a; padding: 12px; border-radius: 8px; margin-bottom: 10px; border-left: 4px solid var(--accent); box-shadow: 0 2px 4px rgba(0,0,0,0.2); }

        @media (max-width: 768px) {
            .controls-grid, .img-grid, .yolo-grid { grid-template-columns: 1fr; }
            .shimmer-text { font-size: 2rem; }
        }
    </style>
</head>
<body>

<div class="container">
    <div class="shimmer-text">🚗 Car Damage AI</div>
    <div class="subtitle">Fusion Intelligence: ResNet + YOLO</div>

    <div class="warning-box">
        <span style="font-size: 1.2rem;">⏱️</span>
        <span><b>Note:</b> The first analysis may take up to 3-4 mins while models warm up. Subsequent requests are faster!</span>
    </div>

    <div class="controls-grid">
        <div class="file-wrapper">
            <input type="file" id="fileInput" accept="image/jpeg, image/png, image/jpg">
            <div style="text-align: center;">
                <p style="font-size: 2.5rem; margin-bottom: 5px;">📷</p>
                <p style="color:#a1a1aa; font-weight: 500;">Tap or Drag & Drop Vehicle Image</p>
            </div>
        </div>

        <div class="settings-card">
            <h3 style="font-size: 1.1rem; margin-bottom: 5px;">⚙️ Analysis Settings</h3>
            <p style="font-size: 0.85rem; color: var(--text-secondary);">Select the neural network pipeline.</p>
            <select id="engineMode">
                <option value="fusion">Fusion</option>
                <option value="resnet">ResNet</option>
            </select>
        </div>
    </div>

    <div class="image-area" id="previewBox">
        <img id="displayImage" src="" alt="Car Image">
        <div class="scan-line" id="scanLine"></div>
        <div class="loader-overlay" id="loader">
            <div class="spinner"></div>
            <p style="color:white; font-weight:600; letter-spacing: 1px; margin-bottom: 5px;">🧠 ANALYZING...</p>
            <p id="loaderStatusText" style="color:#00c6ff; font-size:0.9rem;">Extracting features...</p>
        </div>
    </div>

    <button class="btn" id="analyzeBtn" onclick="analyze()">🚀 Run AI Analysis</button>

    <div class="results-section" id="resultsSection">
        <div class="tabs">
            <div class="tab active" onclick="switchResultTab('tab-pred')">📊 Prediction</div>
            <div class="tab" onclick="switchResultTab('tab-attention')">👀 Attention Maps</div>
            <div class="tab" onclick="switchResultTab('tab-yolo')">🎯 Localization</div>
        </div>

        <div id="tab-pred" class="tab-content active">
            <div class="settings-card">
                <div id="finalPredText" class="big-text">--</div>
                <div style="font-weight: 600; margin-top: 5px;" id="confText">Confidence Score: 0%</div>
                <div class="progress-wrapper">
                    <div class="progress-fill" id="confBar"></div>
                </div>
                <h3 style="margin: 15px 0 5px 0; font-size: 1.1rem;">Probability Distribution</h3>
                <div id="plotlyChart" style="width:100%; height:300px;"></div>
            </div>
        </div>

        <div id="tab-attention" class="tab-content">
            <div class="img-grid">
                <div class="img-card">
                    <div style="font-weight:600; color:#e2e8f0;">Original Image</div>
                    <img id="camOriginal" src="" alt="Original Image">
                </div>
                <div class="img-card">
                    <div id="camSelectedLabel" style="font-weight:600; color:#e2e8f0;">Selected Grad-CAM</div>
                    <img id="camSelected" src="" alt="Selected Grad-CAM">
                </div>
            </div>
        </div>

        <div id="tab-yolo" class="tab-content">
            <div class="yolo-grid">
                <div class="settings-card">
                    <h3 style="margin-bottom: 10px;">Bounding Boxes</h3>
                    <img id="yoloImage" src="" alt="YOLO Output" style="width: 100%; border-radius: 8px;">
                </div>
                <div class="log-box">
                    <h3 style="margin-bottom: 15px;">Detection Log</h3>
                    <div id="yoloLogContainer">
                    </div>
                </div>
            </div>
        </div>
    </div>

</div>

<script>
    const API_URL = "http://127.0.0.1:8000";
    let currentFile = null;

    // DOM Elements
    const fileInput = document.getElementById('fileInput');
    const displayImage = document.getElementById('displayImage');
    const previewBox = document.getElementById('previewBox');
    const resultsSection = document.getElementById('resultsSection');
    const loader = document.getElementById('loader');
    const loaderStatusText = document.getElementById('loaderStatusText');
    const scanLine = document.getElementById('scanLine');
    const analyzeBtn = document.getElementById('analyzeBtn');

    fileInput.addEventListener('change', e => {
        if(e.target.files[0]) {
            currentFile = e.target.files[0];
            const reader = new FileReader();
            reader.onload = x => {
                displayImage.src = x.target.result;
                previewBox.style.display = 'flex';
                resultsSection.style.display = 'none'; // Hide old results
            };
            reader.readAsDataURL(currentFile);
        }
    });

    // --- BUG FIX IS HERE ---
    function switchResultTab(tabId) {
        // 1. Remove active state from all tabs and panels
        document.querySelectorAll('.tab').forEach(t => t.classList.remove('active'));
        document.querySelectorAll('.tab-content').forEach(c => c.classList.remove('active'));

        // 2. Find the tab button that corresponds to this panel and make it active
        const tabButton = document.querySelector(`.tab[onclick*="${tabId}"]`);
        if(tabButton) {
            tabButton.classList.add('active');
        }

        // 3. Make the specific panel active
        document.getElementById(tabId).classList.add('active');

        // 4. Resize Plotly chart if switching back to its tab to prevent layout squash
        if(tabId === 'tab-pred') {
            window.dispatchEvent(new Event('resize'));
        }
    }

    // Plotly Chart Helper
    function drawChart(dataObj, title) {
        const labels = Object.keys(dataObj);
        const values = Object.values(dataObj);

        const trace = {
            x: labels,
            y: values,
            type: 'bar',
            marker: { color: '#00c6ff', line: { color: '#0072ff', width: 1.5 } },
            opacity: 0.85
        };

        const layout = {
            title: title || '',
            paper_bgcolor: 'rgba(0,0,0,0)',
            plot_bgcolor: 'rgba(0,0,0,0)',
            font: { family: 'Inter', color: '#a1a1aa' },
            margin: { l: 40, r: 10, t: 30, b: 40 },
            xaxis: { title: 'Classes' },
            yaxis: { title: 'Probability', range: [0, 1] }
        };

        Plotly.newPlot('plotlyChart', [trace], layout, {displayModeBar: false, responsive: true});
    }

    async function analyze() {
        if(!currentFile) return alert("Please upload an image first.");

        const engineMode = document.getElementById('engineMode').value; // fusion or resnet

        // UI Prep
        loader.style.display = 'flex';
        scanLine.style.display = 'block';
        analyzeBtn.disabled = true;
        analyzeBtn.innerText = "Processing...";
        resultsSection.style.display = 'none';

        const formData = new FormData();
        formData.append('image', currentFile);

        try {
            loaderStatusText.innerText = "Extracting features...";
            const predRes = await fetch(`${API_URL}/predict/${engineMode}`, { method: 'POST', body: formData });
            if (!predRes.ok) throw new Error("Prediction API failed");
            const predData = await predRes.json();

            loaderStatusText.innerText = "Generating Grad-CAM...";
            const camForm = new FormData();
            camForm.append('file', currentFile);
            const camRes = await fetch(`${API_URL}/predict?mode=${engineMode}`, { method: 'POST', body: camForm });
            if (!camRes.ok) throw new Error("Grad-CAM API failed");
            const camData = await camRes.json();

            loaderStatusText.innerText = "Running YOLO detection...";
            const yoloRes = await fetch(`${API_URL}/predict/yolo`, { method: 'POST', body: camForm });
            if (!yoloRes.ok) throw new Error("YOLO API failed");
            const yoloData = await yoloRes.json();

            const highestClass = Object.keys(predData).reduce((a, b) => predData[a] > predData[b] ? a : b);
            const highestScore = predData[highestClass] || 0;
            document.getElementById('finalPredText').innerText = highestClass;
            document.getElementById('confText').innerText = `Confidence Score: ${(highestScore * 100).toFixed(2)}%`;
            drawChart(predData, `${engineMode.toUpperCase()} Output`);
            setTimeout(() => { document.getElementById('confBar').style.width = `${(highestScore * 100).toFixed(2)}%`; }, 100);

            document.getElementById('camOriginal').src = `${API_URL}${camData.original_image}`;
            document.getElementById('camSelected').src = `${API_URL}${camData.selected_viz}`;
            document.getElementById('camSelectedLabel').innerText = engineMode === 'fusion' ? 'Fusion Grad-CAM' : 'ResNet Grad-CAM';

            document.getElementById('yoloImage').src = `${API_URL}${yoloData.yolo_image}`;
            const logContainer = document.getElementById('yoloLogContainer');
            if (!yoloData.detections || yoloData.detections.length === 0) {
                logContainer.innerHTML = '<div style="color: #a1a1aa; padding: 10px;">🟢 No damage boxes detected.</div>';
            } else {
                let logHTML = `<div style="color: #ffcc00; margin-bottom: 10px; font-weight:600;">🔴 Found ${yoloData.total_detections} damage region(s).</div>`;
                yoloData.detections.forEach((det, idx) => {
                    logHTML += `<div class="detection-item"><b style="color: #e2e8f0;">Region ${idx + 1}</b><br><span style="color: #a1a1aa; font-size: 0.9em;">${det.label} · ${(det.confidence * 100).toFixed(1)}%</span></div>`;
                });
                logContainer.innerHTML = logHTML;
            }

            resultsSection.style.display = 'block';
            switchResultTab('tab-pred');
        } catch (error) {
            alert(`Error connecting to AI server. Details: ${error.message}`);
            console.error(error);
        } finally {
            loader.style.display = 'none';
            scanLine.style.display = 'none';
            analyzeBtn.disabled = false;
            analyzeBtn.innerText = "🚀 Run AI Analysis";
        }
    }
</script>

</body>
</html>
requirements.txt
ADDED
@@ -0,0 +1,12 @@
torch
torchvision
transformers
fastapi
uvicorn
python-dotenv
matplotlib
opencv-python
python-multipart
ultralytics
plotly
pandas
scripts/gradcam.py
ADDED
@@ -0,0 +1,100 @@
import cv2
import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F

def get_resnet_gradcam(image_path, predictor, output_path):
    model = predictor.model
    device = predictor.device
    model.eval()
    features, gradients = [], []

    def forward_hook(module, input, output):
        features.append(output)
    def backward_hook(module, grad_in, grad_out):
        gradients.append(grad_out[0])

    target_layer = model.model.layer4[-1]
    handle_fw = target_layer.register_forward_hook(forward_hook)
    handle_bw = target_layer.register_full_backward_hook(backward_hook)

    original_img = Image.open(image_path).convert("RGB")
    input_tensor = predictor.test_transforms(original_img).unsqueeze(0).to(device)
    model.zero_grad()
    output = model(input_tensor)
    pred_class_idx = output.argmax(dim=1).item()
    score = output[0, pred_class_idx]
    score.backward()

    handle_fw.remove()
    handle_bw.remove()

    acts = features[0].cpu().data.numpy()[0]
    grads = gradients[0].cpu().data.numpy()[0]
    weights = np.mean(grads, axis=(1, 2))
    cam = np.zeros(acts.shape[1:], dtype=np.float32)
    for i, w in enumerate(weights):
        cam += w * acts[i]

    cam = np.maximum(cam, 0)
    cam = cv2.resize(cam, (original_img.width, original_img.height))
    cam = (cam - np.min(cam)) / (np.max(cam) - np.min(cam) + 1e-8)
    heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
    original_np = np.array(original_img)
    overlay = cv2.addWeighted(cv2.cvtColor(original_np, cv2.COLOR_RGB2BGR), 0.6, heatmap, 0.4, 0)
    cv2.imwrite(output_path, overlay)
    return True

def get_fusion_gradcam(image_path, predictor, output_path):
    model = predictor.model
    device = predictor.device
    model.eval()
    target_layer = model.eff_features[-1]
    activation = None

    def forward_hook(module, inp, out):
        nonlocal activation
        activation = out
        activation.retain_grad()

    handle = target_layer.register_forward_hook(forward_hook)
    original_img = Image.open(image_path).convert("RGB")
    pixel_eff = predictor.eff_normalize(original_img).unsqueeze(0).to(device)
    inputs_cnx = predictor.convnext_processor(images=original_img, return_tensors="pt")
    pixel_cnx = inputs_cnx["pixel_values"].to(device)

    if next(model.parameters()).dtype == torch.float16:
        pixel_eff = pixel_eff.half()
        pixel_cnx = pixel_cnx.half()

    model.zero_grad()
    output = model(pixel_eff, pixel_cnx)
    pred_class_idx = output.argmax(dim=1).item()
    score = output[0, pred_class_idx]
    score.backward()
    handle.remove()

    if activation is None or activation.grad is None:
        raise RuntimeError("Gradients could not be extracted. Ensure requires_grad=True is properly set.")

    acts = activation[0].detach().float()
    grads = activation.grad[0].detach().float()
    weights = grads.mean(dim=(1, 2), keepdim=True)
    cam = torch.sum(weights * acts, dim=0)
    cam = F.relu(cam)
    cam = cam.cpu().numpy()

    if cam.max() > cam.min():
        cam = (cam - cam.min()) / (cam.max() - cam.min())
    else:
        cam = np.zeros_like(cam)

    cam = np.uint8(255 * cam)
    cam_resized = cv2.resize(cam, (original_img.width, original_img.height), interpolation=cv2.INTER_LINEAR)
    heatmap = cv2.applyColorMap(cam_resized, cv2.COLORMAP_JET)
    original_np = np.array(original_img)
    original_bgr = cv2.cvtColor(original_np, cv2.COLOR_RGB2BGR)
    overlay = cv2.addWeighted(original_bgr, 0.5, heatmap, 0.6, 0)
    cv2.imwrite(output_path, overlay)
    return True
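A short usage sketch for these helpers (not part of the upload), mirroring how app.py drives them; the sample and output paths are placeholders:

from scripts.model_loader import initialize_models
from scripts.gradcam import get_resnet_gradcam, get_fusion_gradcam

class_map = {0: "Front Breakage", 1: "Front Crushed", 2: "Front Normal",
             3: "Rear Breakage", 4: "Rear Crushed", 5: "Rear Normal"}
resnet_predictor, fusion_predictor = initialize_models(class_map)

# Each call writes a JET-colormap Grad-CAM overlay for the predicted class to the output path
get_resnet_gradcam("samples/car.jpg", resnet_predictor, "results/resnet_cam.jpg")
get_fusion_gradcam("samples/car.jpg", fusion_predictor, "results/fusion_cam.jpg")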
scripts/model_loader.py
ADDED
@@ -0,0 +1,38 @@
import os
from pathlib import Path
from .prediction_helper import ResnetCarDamagePredictor, FusionCarDamagePredictor

CHECKPOINT_DIR = Path(__file__).resolve().parents[1] / "checkpoints"
MODEL_FILES = {
    "resnet": "best_resnet_model.pt",
    "fusion": "best_fusion_model_fp16.pth",
    "yolo": "damage_detector.pt",
}


def get_checkpoint_path(model_key: str) -> Path:
    if model_key not in MODEL_FILES:
        raise ValueError(f"Unknown model key: {model_key}")

    path = CHECKPOINT_DIR / MODEL_FILES[model_key]
    if not path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {path}")
    return path


class ModelLoader:
    def __init__(self):
        self.base_dir = CHECKPOINT_DIR

    def get_model_path(self, model_key: str) -> Path:
        return get_checkpoint_path(model_key)


def initialize_models(class_map):
    resnet_path = get_checkpoint_path("resnet")
    fusion_path = get_checkpoint_path("fusion")

    resnet_predictor = ResnetCarDamagePredictor(resnet_path, class_map)
    fusion_predictor = FusionCarDamagePredictor(fusion_path, class_map)

    return resnet_predictor, fusion_predictor
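Checkpoints are resolved relative to the repository root, so the same paths work locally and inside the Docker image. A quick sketch of the lookup helpers (not part of the upload; the unknown key is deliberate, to show the error path):

from scripts.model_loader import get_checkpoint_path, ModelLoader

print(get_checkpoint_path("yolo"))             # .../checkpoints/damage_detector.pt
print(ModelLoader().get_model_path("fusion"))  # .../checkpoints/best_fusion_model_fp16.pth

try:
    get_checkpoint_path("segmentation")        # not in MODEL_FILES
except ValueError as err:
    print(err)                                 # Unknown model key: segmentation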
scripts/prediction_helper.py
ADDED
@@ -0,0 +1,175 @@
import os
import torch
import torch.nn as nn
from torchvision import transforms, models
from PIL import Image, UnidentifiedImageError
from transformers import ConvNextModel, ConvNextImageProcessor

class Car_Classifier_Resnet(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.model = models.resnet18(weights="DEFAULT")
        for param in self.model.parameters():
            param.requires_grad = False
        for param in self.model.layer3.parameters():
            param.requires_grad = True
        for param in self.model.layer4.parameters():
            param.requires_grad = True
        self.model.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(self.model.fc.in_features, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        return self.model(x)

class ResnetCarDamagePredictor:
    def __init__(self, checkpoint_path, class_map):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.class_map = class_map
        self.test_transforms = transforms.Compose([
            transforms.Resize((128, 128)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225])
        ])
        try:
            self.model = Car_Classifier_Resnet(num_classes=len(class_map))
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
            state_dict = checkpoint.get("model_state_dict", checkpoint)
            self.model.load_state_dict(state_dict)
            self.model.to(self.device)
            self.model.eval()
        except Exception as e:
            raise RuntimeError(f"Failed to load ResNet model: {str(e)}")

    def resnet_predict(self, image_input):
        try:
            if isinstance(image_input, str):
                image = Image.open(image_input).convert("RGB")
            elif isinstance(image_input, Image.Image):
                image = image_input.convert("RGB")
            else:
                raise TypeError("image_input must be a file path or PIL.Image")
            image = self.test_transforms(image)
            image = image.unsqueeze(0).to(self.device)
            with torch.no_grad():
                outputs = self.model(image)
                probs = torch.nn.functional.softmax(outputs, dim=1)[0]
            class_probs = {
                self.class_map[i]: float(probs[i].item())
                for i in range(len(self.class_map))
            }
            return dict(sorted(class_probs.items(), key=lambda x: x[1], reverse=True))
        except UnidentifiedImageError:
            raise ValueError("Invalid image file provided")
        except Exception as e:
            raise RuntimeError(f"ResNet prediction failed: {str(e)}")

class FusionClassifier(nn.Module):
    def __init__(self, num_classes, convnext_model_name="facebook/convnext-small-224"):
        super().__init__()
        eff = models.efficientnet_v2_s(weights=models.EfficientNet_V2_S_Weights.IMAGENET1K_V1)
        for param in eff.parameters():
            param.requires_grad = False
        for param in eff.features[5].parameters():
            param.requires_grad = True
        for param in eff.features[6].parameters():
            param.requires_grad = True
        for param in eff.features[7].parameters():
            param.requires_grad = True
        self.eff_features = eff.features
        self.eff_avgpool = eff.avgpool
        self.eff_out_dim = eff.classifier[1].in_features
        cnx = ConvNextModel.from_pretrained(convnext_model_name)
        for param in cnx.parameters():
            param.requires_grad = False
        for param in cnx.encoder.stages[2].parameters():
            param.requires_grad = True
        for param in cnx.encoder.stages[3].parameters():
            param.requires_grad = True
        for param in cnx.layernorm.parameters():
            param.requires_grad = True
        self.cnx_backbone = cnx
        self.cnx_out_dim = 768
        fused_dim = self.eff_out_dim + self.cnx_out_dim
        self.fusion_head = nn.Sequential(
            nn.Dropout(p=0.4),
            nn.Linear(fused_dim, 512),
            nn.LayerNorm(512),
            nn.GELU(),
            nn.Dropout(p=0.3),
            nn.Linear(512, 256),
            nn.LayerNorm(256),
            nn.GELU(),
            nn.Dropout(p=0.2),
            nn.Linear(256, num_classes)
        )

    def forward(self, pixel_values_eff, pixel_values_cnx):
        x_eff = self.eff_features(pixel_values_eff)
        x_eff = self.eff_avgpool(x_eff)
        x_eff = torch.flatten(x_eff, 1)
        cnx_out = self.cnx_backbone(pixel_values=pixel_values_cnx, return_dict=True)
        x_cnx = cnx_out.pooler_output
        fused = torch.cat([x_eff, x_cnx], dim=1)
        return self.fusion_head(fused)

class FusionCarDamagePredictor:
    def __init__(self, checkpoint_path, class_map, convnext_model_name="facebook/convnext-small-224"):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.class_map = class_map
        self.eff_normalize = transforms.Compose([
            transforms.Resize((260, 260)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225])
        ])
        self.convnext_processor = ConvNextImageProcessor.from_pretrained(convnext_model_name)
        try:
            self.model = FusionClassifier(
                num_classes=len(class_map),
                convnext_model_name=convnext_model_name
            )
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
            state_dict = checkpoint.get("model_state_dict", checkpoint)
            first_tensor = next(iter(state_dict.values()))
            if first_tensor.dtype == torch.float16:
                self.model = self.model.half()
            self.model.load_state_dict(state_dict)
            self.model.to(self.device)
            self.model.eval()
        except Exception as e:
            raise RuntimeError(f"Failed to load Fusion model: {str(e)}")

    def predict(self, image_input):
        try:
            if isinstance(image_input, str):
                image = Image.open(image_input).convert("RGB")
            elif isinstance(image_input, Image.Image):
                image = image_input.convert("RGB")
            else:
                raise TypeError("image_input must be a file path or PIL.Image")
            pixel_eff = self.eff_normalize(image)
            pixel_eff = pixel_eff.unsqueeze(0).to(self.device)
            inputs_cnx = self.convnext_processor(images=image, return_tensors="pt")
            pixel_cnx = inputs_cnx["pixel_values"].to(self.device)
            if next(self.model.parameters()).dtype == torch.float16:
                pixel_eff = pixel_eff.half()
                pixel_cnx = pixel_cnx.half()
            with torch.no_grad():
                logits = self.model(pixel_eff, pixel_cnx)
                probs = torch.nn.functional.softmax(logits, dim=1)[0]
            class_probs = {
                self.class_map[i]: float(probs[i].item())
                for i in range(len(self.class_map))
            }
            return dict(sorted(class_probs.items(), key=lambda x: x[1], reverse=True))
        except UnidentifiedImageError:
            raise ValueError("Invalid image file provided")
        except Exception as e:
            raise RuntimeError(f"Fusion prediction failed: {str(e)}")
scripts/yolo.py
ADDED
@@ -0,0 +1,57 @@
import cv2
import numpy as np
from PIL import Image
from ultralytics import YOLO
from scripts.model_loader import ModelLoader

yolo_model = None

def get_yolo_model():
    global yolo_model
    if yolo_model is None:
        loader = ModelLoader()
        yolo_path = loader.get_model_path("yolo")
        yolo_model = YOLO(str(yolo_path))
    return yolo_model


def get_yolo_damage_boxes(image_path, output_path):
    try:
        image = Image.open(image_path).convert("RGB")
        model = get_yolo_model()
        results = model.predict(
            source=image,
            conf=0.05,
            imgsz=640,
            verbose=False
        )

        result = results[0]
        boxes = result.boxes
        detections = []

        if boxes is not None and len(boxes) > 0:
            for box in boxes:
                conf = float(box.conf[0])
                cls_id = int(box.cls[0])
                label = yolo_model.names[cls_id]
                x1, y1, x2, y2 = map(int, box.xyxy[0])

                detections.append({
                    "label": label,
                    "confidence": round(conf, 4),
                    "box": [x1, y1, x2, y2]
                })

        # result.plot() already returns a BGR array, which is what cv2.imwrite expects,
        # so it is written directly (reversing the channels here would swap the colors)
        plotted_bgr = result.plot()
        cv2.imwrite(output_path, plotted_bgr)

        return {
            "detections": detections,
            "total_detections": len(detections),
            "message": "No damage detected" if len(detections) == 0 else "Detections found"
        }

    except Exception as e:
        raise RuntimeError(f"YOLO failed: {str(e)}")
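A usage sketch for the detector (not part of the upload; paths are placeholders). The YOLO weights are loaded lazily on the first call and cached in the module-level yolo_model:

from scripts.yolo import get_yolo_damage_boxes

result = get_yolo_damage_boxes("samples/car.jpg", "results/car_yolo.jpg")
print(result["message"], result["total_detections"])
for det in result["detections"]:
    print(det["label"], det["confidence"], det["box"])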