junaid17 committed
Commit 1ae016f · verified · 1 Parent(s): 7822be3

Upload 15 files

Dockerfile ADDED
@@ -0,0 +1,26 @@
+ FROM python:3.10-slim
+
+ ENV PYTHONDONTWRITEBYTECODE=1
+ ENV PYTHONUNBUFFERED=1
+
+ WORKDIR /app
+
+ # --- SYSTEM DEPENDENCIES (CRITICAL FOR OPENCV / YOLO) ---
+ RUN apt-get update && apt-get install -y \
+     build-essential \
+     gcc \
+     libgl1 \
+     libglib2.0-0 \
+     && rm -rf /var/lib/apt/lists/*
+
+ # --- PYTHON DEPENDENCIES ---
+ COPY requirements.txt .
+ RUN pip install --no-cache-dir --upgrade pip \
+     && pip install --no-cache-dir -r requirements.txt
+
+ # --- APP CODE ---
+ COPY . .
+
+ EXPOSE 7860
+
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
Notebooks/EfficientNet_ConvNext_Fusion.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
Notebooks/Model_Compression.ipynb ADDED
@@ -0,0 +1,276 @@
+ {
+  "cells": [
+   {
+    "cell_type": "markdown",
+    "id": "671818be",
+    "metadata": {},
+    "source": [
+     "# Model Conversion or Compression\n",
+     "**This notebook demonstrates how to convert a PyTorch model to FP16 precision, which can reduce the model size and potentially speed up inference on compatible hardware. We will use the `FusionClassifier` as an example, but the same approach can be applied to other models as well.**\n",
+     "\n",
+     "**From FP32 to FP16**"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 7,
+    "id": "b1715593",
+    "metadata": {},
+    "outputs": [
+     {
+      "name": "stderr",
+      "output_type": "stream",
+      "text": [
+       "Loading weights: 100%|██████████| 342/342 [00:00<00:00, 2845.51it/s]\n",
+       "[transformers] \u001b[1mConvNextModel LOAD REPORT\u001b[0m from: facebook/convnext-small-224\n",
+       "Key               | Status     |  | \n",
+       "------------------+------------+--+-\n",
+       "classifier.bias   | UNEXPECTED |  | \n",
+       "classifier.weight | UNEXPECTED |  | \n",
+       "\n",
+       "Notes:\n",
+       "- UNEXPECTED:\tcan be ignored when loading from different task/architecture; not ok if you expect identical arch.\n"
+      ]
+     },
+     {
+      "name": "stdout",
+      "output_type": "stream",
+      "text": [
+       "============================================================\n",
+       "Initializing model...\n",
+       "============================================================\n",
+       "Model weights loaded successfully.\n",
+       "Model converted to FP16.\n",
+       "============================================================\n",
+       "FP16 model saved successfully.\n",
+       "Saved Path : D:\\DamageLens\\checkpoints\\best_fusion_model_fp16.pth\n",
+       "FP16 Model Size : 135.77 MB\n",
+       "============================================================\n"
+      ]
+     }
+    ],
+    "source": [
+     "import os\n",
+     "import torch\n",
+     "import torch.nn as nn\n",
+     "import torchvision.models as models\n",
+     "from transformers import ConvNextModel\n",
+     "\n",
+     "\n",
+     "# =========================================================\n",
+     "# FUSION MODEL\n",
+     "# =========================================================\n",
+     "\n",
+     "class FusionClassifier(nn.Module):\n",
+     "    def __init__(self, num_classes, convnext_model_name=\"facebook/convnext-small-224\"):\n",
+     "        super().__init__()\n",
+     "\n",
+     "        # -------------------------------------------------\n",
+     "        # EfficientNet-V2-S\n",
+     "        # -------------------------------------------------\n",
+     "        eff = models.efficientnet_v2_s(\n",
+     "            weights=models.EfficientNet_V2_S_Weights.IMAGENET1K_V1\n",
+     "        )\n",
+     "\n",
+     "        # Freeze all\n",
+     "        for param in eff.parameters():\n",
+     "            param.requires_grad = False\n",
+     "\n",
+     "        # Unfreeze last stages\n",
+     "        for param in eff.features[5].parameters():\n",
+     "            param.requires_grad = True\n",
+     "\n",
+     "        for param in eff.features[6].parameters():\n",
+     "            param.requires_grad = True\n",
+     "\n",
+     "        for param in eff.features[7].parameters():\n",
+     "            param.requires_grad = True\n",
+     "\n",
+     "        self.eff_features = eff.features\n",
+     "        self.eff_avgpool = eff.avgpool\n",
+     "        self.eff_out_dim = eff.classifier[1].in_features  # 1280\n",
+     "\n",
+     "        # -------------------------------------------------\n",
+     "        # ConvNeXt Small\n",
+     "        # -------------------------------------------------\n",
+     "        cnx = ConvNextModel.from_pretrained(convnext_model_name)\n",
+     "\n",
+     "        # Freeze all\n",
+     "        for param in cnx.parameters():\n",
+     "            param.requires_grad = False\n",
+     "\n",
+     "        # Unfreeze stages\n",
+     "        for param in cnx.encoder.stages[2].parameters():\n",
+     "            param.requires_grad = True\n",
+     "\n",
+     "        for param in cnx.encoder.stages[3].parameters():\n",
+     "            param.requires_grad = True\n",
+     "\n",
+     "        for param in cnx.layernorm.parameters():\n",
+     "            param.requires_grad = True\n",
+     "\n",
+     "        self.cnx_backbone = cnx\n",
+     "        self.cnx_out_dim = 768\n",
+     "\n",
+     "        # -------------------------------------------------\n",
+     "        # Fusion Head\n",
+     "        # -------------------------------------------------\n",
+     "        fused_dim = self.eff_out_dim + self.cnx_out_dim\n",
+     "\n",
+     "        self.fusion_head = nn.Sequential(\n",
+     "            nn.Dropout(0.4),\n",
+     "\n",
+     "            nn.Linear(fused_dim, 512),\n",
+     "            nn.LayerNorm(512),\n",
+     "            nn.GELU(),\n",
+     "\n",
+     "            nn.Dropout(0.3),\n",
+     "\n",
+     "            nn.Linear(512, 256),\n",
+     "            nn.LayerNorm(256),\n",
+     "            nn.GELU(),\n",
+     "\n",
+     "            nn.Dropout(0.2),\n",
+     "\n",
+     "            nn.Linear(256, num_classes)\n",
+     "        )\n",
+     "\n",
+     "    def forward(self, pixel_values_eff, pixel_values_cnx):\n",
+     "\n",
+     "        # EfficientNet branch\n",
+     "        x_eff = self.eff_features(pixel_values_eff)\n",
+     "        x_eff = self.eff_avgpool(x_eff)\n",
+     "        x_eff = torch.flatten(x_eff, 1)\n",
+     "\n",
+     "        # ConvNeXt branch\n",
+     "        cnx_out = self.cnx_backbone(\n",
+     "            pixel_values=pixel_values_cnx,\n",
+     "            return_dict=True\n",
+     "        )\n",
+     "\n",
+     "        x_cnx = cnx_out.pooler_output\n",
+     "\n",
+     "        # Fusion\n",
+     "        fused = torch.cat([x_eff, x_cnx], dim=1)\n",
+     "\n",
+     "        logits = self.fusion_head(fused)\n",
+     "\n",
+     "        return logits\n",
+     "\n",
+     "\n",
+     "# =========================================================\n",
+     "# CONFIG\n",
+     "# =========================================================\n",
+     "\n",
+     "class_map = {\n",
+     "    0: \"Front Breakage\",\n",
+     "    1: \"Front Crushed\",\n",
+     "    2: \"Front Normal\",\n",
+     "    3: \"Rear Breakage\",\n",
+     "    4: \"Rear Crushed\",\n",
+     "    5: \"Rear Normal\"\n",
+     "}\n",
+     "\n",
+     "device = torch.device(\"cpu\")\n",
+     "\n",
+     "CHECKPOINT_PATH = r\"D:\\DamageLens\\checkpoints\\best_fusion_model.pt\"\n",
+     "\n",
+     "SAVE_FP16_PATH = r\"D:\\DamageLens\\checkpoints\\best_fusion_model_fp16.pth\"\n",
+     "\n",
+     "NUM_CLASSES = len(class_map)\n",
+     "\n",
+     "CONVNEXT_MODEL_NAME = \"facebook/convnext-small-224\"\n",
+     "\n",
+     "\n",
+     "# =========================================================\n",
+     "# INITIALIZE MODEL\n",
+     "# =========================================================\n",
+     "\n",
+     "model = FusionClassifier(\n",
+     "    num_classes=NUM_CLASSES,\n",
+     "    convnext_model_name=CONVNEXT_MODEL_NAME\n",
+     ")\n",
+     "\n",
+     "print(\"=\" * 60)\n",
+     "print(\"Initializing model...\")\n",
+     "print(\"=\" * 60)\n",
+     "\n",
+     "\n",
+     "# =========================================================\n",
+     "# LOAD TRAINED WEIGHTS\n",
+     "# =========================================================\n",
+     "\n",
+     "checkpoint = torch.load(\n",
+     "    CHECKPOINT_PATH,\n",
+     "    map_location=device\n",
+     ")\n",
+     "\n",
+     "# If checkpoint contains state_dict\n",
+     "if \"model_state_dict\" in checkpoint:\n",
+     "    model.load_state_dict(checkpoint[\"model_state_dict\"])\n",
+     "\n",
+     "# If checkpoint is directly state_dict\n",
+     "else:\n",
+     "    model.load_state_dict(checkpoint)\n",
+     "\n",
+     "print(\"Model weights loaded successfully.\")\n",
+     "\n",
+     "\n",
+     "# =========================================================\n",
+     "# CONVERT TO FP16\n",
+     "# =========================================================\n",
+     "\n",
+     "model = model.half()\n",
+     "\n",
+     "print(\"Model converted to FP16.\")\n",
+     "\n",
+     "\n",
+     "# =========================================================\n",
+     "# CREATE CHECKPOINT DIRECTORY\n",
+     "# =========================================================\n",
+     "\n",
+     "os.makedirs(\"checkpoints\", exist_ok=True)\n",
+     "\n",
+     "\n",
+     "# =========================================================\n",
+     "# SAVE FP16 MODEL\n",
+     "# =========================================================\n",
+     "\n",
+     "torch.save(\n",
+     "    model.state_dict(),\n",
+     "    SAVE_FP16_PATH\n",
+     ")\n",
+     "\n",
+     "print(\"=\" * 60)\n",
+     "print(\"FP16 model saved successfully.\")\n",
+     "print(f\"Saved Path : {SAVE_FP16_PATH}\")\n",
+     "\n",
+     "size_mb = os.path.getsize(SAVE_FP16_PATH) / (1024 * 1024)\n",
+     "\n",
+     "print(f\"FP16 Model Size : {size_mb:.2f} MB\")\n",
+     "print(\"=\" * 60)"
+    ]
+   }
+  ],
+  "metadata": {
+   "kernelspec": {
+    "display_name": "myvenv",
+    "language": "python",
+    "name": "python3"
+   },
+   "language_info": {
+    "codemirror_mode": {
+     "name": "ipython",
+     "version": 3
+    },
+    "file_extension": ".py",
+    "mimetype": "text/x-python",
+    "name": "python",
+    "nbconvert_exporter": "python",
+    "pygments_lexer": "ipython3",
+    "version": "3.11.0"
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 5
+ }
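For reference, a minimal reload sketch (not part of the upload) showing how the saved FP16 checkpoint can be used for inference, mirroring the logic in `scripts/prediction_helper.py`: rebuild the model, cast it with `.half()` before `load_state_dict`, and cast inputs to FP16 to match. The dummy tensors and relative path are assumptions for illustration.

```python
import torch

# Assumes the FusionClassifier definition from the cell above is available.
model = FusionClassifier(num_classes=6)
model = model.half()  # cast parameters to FP16 before loading FP16 weights
state_dict = torch.load("checkpoints/best_fusion_model_fp16.pth", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()

# Inputs must match the parameter dtype; 260x260 for the EfficientNet branch and
# 224x224 for ConvNeXt, per the predictors in scripts/prediction_helper.py.
# Note: FP16 operator support on CPU varies by PyTorch version.
pixel_eff = torch.randn(1, 3, 260, 260).half()
pixel_cnx = torch.randn(1, 3, 224, 224).half()
with torch.no_grad():
    logits = model(pixel_eff, pixel_cnx)
print(logits.shape)  # torch.Size([1, 6])
```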
Notebooks/Resnet18_fine_tuning.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
Notebooks/damage_detector_yolo.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
app.py ADDED
@@ -0,0 +1,107 @@
+ import os
+ import uuid
+ import shutil
+ from fastapi import FastAPI, UploadFile, File, HTTPException
+ from fastapi.staticfiles import StaticFiles
+ from PIL import Image
+ from fastapi.middleware.cors import CORSMiddleware
+ from dotenv import load_dotenv
+ from scripts.gradcam import get_resnet_gradcam, get_fusion_gradcam
+ from scripts.yolo import get_yolo_damage_boxes
+ from scripts.model_loader import initialize_models
+
+ load_dotenv()
+ app = FastAPI()
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+ UPLOAD_DIR = "static/uploads"
+ RESULT_DIR = "static/results"
+ os.makedirs(UPLOAD_DIR, exist_ok=True)
+ os.makedirs(RESULT_DIR, exist_ok=True)
+ app.mount("/static", StaticFiles(directory="static"), name="static")
+ class_map = {
+     0: "Front Breakage",
+     1: "Front Crushed",
+     2: "Front Normal",
+     3: "Rear Breakage",
+     4: "Rear Crushed",
+     5: "Rear Normal"
+ }
+ resnet_predictor, fusion_predictor = initialize_models(class_map)
+
+ @app.get("/")
+ def api_status():
+     return {"status": "API is running"}
+
+ @app.post("/predict")
+ async def predict_and_generate_cams(file: UploadFile = File(...), mode: str = "resnet"):
+     mode = mode.lower()
+     if mode not in {"resnet", "fusion"}:
+         raise HTTPException(status_code=400, detail="mode must be 'resnet' or 'fusion'")
+     unique_id = str(uuid.uuid4())
+     input_filename = f"{unique_id}_input.jpg"
+     input_path = os.path.join(UPLOAD_DIR, input_filename)
+     with open(input_path, "wb") as buffer:
+         shutil.copyfileobj(file.file, buffer)
+     if mode == "resnet":
+         output_name = f"{unique_id}_resnet.jpg"
+         output_path = os.path.join(RESULT_DIR, output_name)
+         get_resnet_gradcam(input_path, resnet_predictor, output_path)
+         selected_viz = f"/static/results/{output_name}"
+         resnet_viz = selected_viz
+         fusion_viz = None
+     else:
+         output_name = f"{unique_id}_fusion.jpg"
+         output_path = os.path.join(RESULT_DIR, output_name)
+         get_fusion_gradcam(input_path, fusion_predictor, output_path)
+         selected_viz = f"/static/results/{output_name}"
+         resnet_viz = None
+         fusion_viz = selected_viz
+     return {
+         "status": "success",
+         "original_image": f"/static/uploads/{input_filename}",
+         "selected_viz": selected_viz,
+         "resnet_viz": resnet_viz,
+         "fusion_viz": fusion_viz,
+         "mode": mode
+     }
+
+ @app.post("/predict/resnet")
+ async def resnet_prediction(image: UploadFile = File(...)):
+     try:
+         image = Image.open(image.file).convert("RGB")
+     except Exception:
+         raise HTTPException(status_code=400, detail="Invalid image file")
+     return resnet_predictor.resnet_predict(image_input=image)
+
+ @app.post("/predict/fusion")
+ async def fusion_prediction(image: UploadFile = File(...)):
+     try:
+         image = Image.open(image.file).convert("RGB")
+     except Exception:
+         raise HTTPException(status_code=400, detail="Invalid image file")
+     return fusion_predictor.predict(image_input=image)
+
+ @app.post("/predict/yolo")
+ async def yolo_detection(file: UploadFile = File(...)):
+     unique_id = str(uuid.uuid4())
+     input_filename = f"{unique_id}_input.jpg"
+     yolo_out_name = f"{unique_id}_yolo.jpg"
+     input_path = os.path.join(UPLOAD_DIR, input_filename)
+     yolo_path = os.path.join(RESULT_DIR, yolo_out_name)
+     with open(input_path, "wb") as buffer:
+         shutil.copyfileobj(file.file, buffer)
+     result = get_yolo_damage_boxes(input_path, yolo_path)
+     return {
+         "status": "success",
+         "original_image": f"/static/uploads/{input_filename}",
+         "yolo_image": f"/static/results/{yolo_out_name}",
+         "detections": result["detections"],
+         "total_detections": result["total_detections"],
+         "message": result["message"]
+     }
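As a usage sketch (assumptions: a locally running server and a test image named `car.jpg`), the endpoints above can be exercised with `requests`. Note the multipart field names: `/predict/resnet` and `/predict/fusion` read the upload from `image`, while `/predict` and `/predict/yolo` read it from `file`. The Dockerfile serves on port 7860, while `index.html` targets a local run on port 8000.

```python
import requests

BASE = "http://127.0.0.1:8000"  # hypothetical local run; use 7860 for the Docker image

# Classification probabilities (field name: "image")
with open("car.jpg", "rb") as f:
    probs = requests.post(f"{BASE}/predict/resnet", files={"image": f}).json()

# Grad-CAM overlay (field name: "file"; mode is a query parameter)
with open("car.jpg", "rb") as f:
    cam = requests.post(f"{BASE}/predict", params={"mode": "fusion"}, files={"file": f}).json()

print(probs)                # class name -> probability, sorted descending
print(cam["selected_viz"])  # /static/results/... URL of the heatmap overlay
```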
checkpoints/best_fusion_model_fp16.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:59cede9aca6c4b39b6447458ddb9cdc3e3ba06c5d972ad62b6807bfcd0afa466
+ size 142369497
checkpoints/best_resnet_model.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:688cbd4f9eb2e97b6e67287b23f5f750b0367dfb08844704d49075fb086bbdd5
+ size 130360907
checkpoints/damage_detector.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7c3b9222d9977b5bfd78d65ea6be9d609c81de473349bb3f362088a86ba07f9f
+ size 51189913
index.html ADDED
@@ -0,0 +1,419 @@
+ <!DOCTYPE html>
+ <html lang="en">
+ <head>
+     <meta charset="UTF-8">
+     <meta name="viewport" content="width=device-width, initial-scale=1.0">
+     <title>Car Damage AI</title>
+     <script src="https://cdn.plot.ly/plotly-2.27.0.min.js"></script>
+     <style>
+         @import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;600;800&display=swap');
+
+         :root {
+             --bg-dark: #09090b;
+             --bg-card: #18181b;
+             --text-primary: #e2e8f0;
+             --text-secondary: #a1a1aa;
+             --accent: #00c6ff;
+             --accent-hover: #0072ff;
+             --glass: rgba(255, 255, 255, 0.03);
+             --card-border: #27272a;
+         }
+
+         * { margin: 0; padding: 0; box-sizing: border-box; font-family: 'Inter', sans-serif; }
+
+         body {
+             background-color: var(--bg-dark);
+             color: var(--text-primary);
+             min-height: 100vh;
+             display: flex;
+             justify-content: center;
+             align-items: flex-start;
+             padding: 40px 20px;
+             background-image: radial-gradient(circle at top right, rgba(0, 198, 255, 0.05) 0%, transparent 40%);
+         }
+
+         .container {
+             width: 100%;
+             max-width: 850px;
+             background: var(--bg-card);
+             border-radius: 20px;
+             padding: 35px;
+             box-shadow: 0 20px 40px rgba(0,0,0,0.6);
+             animation: slideUpFade 0.6s ease-out forwards;
+             border: 1px solid var(--card-border);
+         }
+
+         @keyframes slideUpFade { from { opacity: 0; transform: translateY(30px); } to { opacity: 1; transform: translateY(0); } }
+
+         /* Shimmering Main Title */
+         .shimmer-text {
+             text-align: center;
+             font-size: 2.5rem;
+             font-weight: 800;
+             background: linear-gradient(90deg, #e2e8f0 0%, #ffffff 25%, #00c6ff 50%, #e2e8f0 75%, #e2e8f0 100%);
+             background-size: 200% auto;
+             color: transparent;
+             -webkit-background-clip: text;
+             background-clip: text;
+             animation: shimmer 4s linear infinite;
+             margin-bottom: 0.2rem;
+         }
+         @keyframes shimmer { 0% { background-position: -200% center; } 100% { background-position: 200% center; } }
+
+         .subtitle { text-align: center; color: var(--text-secondary); font-size: 1rem; margin-bottom: 25px; }
+
+         /* Warning Box */
+         .warning-box {
+             background: rgba(0, 198, 255, 0.1);
+             border-left: 4px solid var(--accent);
+             color: var(--text-primary);
+             padding: 12px 15px;
+             border-radius: 8px;
+             margin-bottom: 25px;
+             font-size: 0.9rem;
+             display: flex;
+             align-items: center;
+             gap: 12px;
+         }
+
+         /* Controls Section */
+         .controls-grid {
+             display: grid;
+             grid-template-columns: 1fr 1fr;
+             gap: 20px;
+             margin-bottom: 25px;
+         }
+
+         .file-wrapper {
+             position: relative; height: 160px; border: 2px dashed #444; border-radius: 16px;
+             display: flex; justify-content: center; align-items: center; cursor: pointer;
+             transition: all 0.3s ease; background: var(--glass); overflow: hidden;
+         }
+         .file-wrapper:hover { border-color: var(--accent); background: rgba(0, 198, 255, 0.05); }
+         .file-wrapper input { position: absolute; width: 100%; height: 100%; opacity: 0; cursor: pointer; z-index: 2; }
+
+         .settings-card {
+             background: rgba(0,0,0,0.2);
+             border-radius: 16px;
+             padding: 20px;
+             border: 1px solid var(--card-border);
+             display: flex;
+             flex-direction: column;
+             justify-content: center;
+         }
+
+         select {
+             width: 100%; background: #27272a; border: 1px solid #3f3f46; padding: 14px;
+             border-radius: 12px; color: white; outline: none; margin-top: 10px; font-size: 1rem;
+         }
+         select:focus { border-color: var(--accent); }
+
+         /* Preview Area & Animations */
+         .image-area {
+             width: 100%; height: 350px; background: #09090b; border-radius: 16px;
+             margin-bottom: 25px; display: none; justify-content: center; align-items: center;
+             overflow: hidden; position: relative; border: 1px solid var(--card-border);
+         }
+         .image-area img { max-width: 100%; max-height: 100%; object-fit: contain; z-index: 1;}
+
+         /* Scanner Animation */
+         .scan-line {
+             position: absolute; top: -10%; left: 0; width: 100%; height: 5px;
+             background: var(--accent); box-shadow: 0 0 15px var(--accent), 0 0 30px var(--accent);
+             z-index: 5; opacity: 0.8; display: none; animation: scanMove 2s ease-in-out infinite; filter: blur(1px);
+         }
+         @keyframes scanMove { 0% { top: -10%; opacity: 0.5; } 50% { opacity: 1; } 100% { top: 110%; opacity: 0.5; } }
+
+         /* Loader Overlay */
+         .loader-overlay {
+             position: absolute; top: 0; left: 0; width: 100%; height: 100%;
+             background: rgba(0,0,0,0.65); backdrop-filter: blur(4px);
+             display: none; flex-direction: column; justify-content: center; align-items: center; z-index: 10;
+         }
+         .spinner {
+             width: 50px; height: 50px; border: 4px solid rgba(0, 198, 255, 0.2);
+             border-top: 4px solid var(--accent); border-radius: 50%;
+             animation: spin 1s cubic-bezier(0.68, -0.55, 0.27, 1.55) infinite; margin-bottom: 15px;
+         }
+         @keyframes spin { 100% { transform: rotate(360deg); } }
+
+         /* Buttons */
+         .btn {
+             width: 100%; padding: 16px; background: linear-gradient(135deg, var(--accent) 0%, var(--accent-hover) 100%);
+             color: white; border: none; border-radius: 12px; cursor: pointer; font-weight: 700; font-size: 1rem;
+             transition: all 0.3s ease; box-shadow: 0 4px 15px rgba(0, 114, 255, 0.3);
+         }
+         .btn:hover:not(:disabled) { transform: scale(1.02); box-shadow: 0 8px 25px rgba(0, 198, 255, 0.5); }
+         .btn:disabled { background: #444; color: #888; box-shadow: none; transform: none; cursor: not-allowed;}
+
+         /* Results Tabs */
+         .results-section { display: none; margin-top: 30px; animation: slideUpFade 0.5s ease-out; }
+         .tabs { display: flex; gap: 10px; margin-bottom: 20px; border-bottom: 1px solid var(--card-border); padding-bottom: 10px; overflow-x: auto; }
+         .tab {
+             padding: 10px 20px; cursor: pointer; border-radius: 8px; color: var(--text-secondary);
+             font-weight: 600; transition: all 0.3s ease; white-space: nowrap;
+         }
+         .tab.active { background: rgba(0, 198, 255, 0.1); color: var(--accent); }
+         .tab-content { display: none; }
+         .tab-content.active { display: block; animation: slideUpFade 0.4s ease-out; }
+
+         /* Progress Bar */
+         .progress-wrapper { background: #27272a; border-radius: 20px; overflow: hidden; height: 12px; margin: 10px 0 20px 0; box-shadow: inset 0 2px 4px rgba(0,0,0,0.5); }
+         .progress-fill { height: 100%; background: linear-gradient(90deg, var(--accent), var(--accent-hover)); border-radius: 20px; width: 0%; transition: width 1.5s cubic-bezier(0.22, 1, 0.36, 1); }
+
+         /* Final Prediction Text */
+         .big-text { font-size: 2.5rem; font-weight: 800; background: -webkit-linear-gradient(45deg, #00c6ff, #0072ff); -webkit-background-clip: text; -webkit-text-fill-color: transparent; margin-bottom: 5px; }
+
+         /* Images Grid (Attention Maps) */
+         .img-grid { display: grid; grid-template-columns: repeat(3, 1fr); gap: 15px; }
+         .img-card { background: rgba(0,0,0,0.3); border: 1px solid var(--card-border); border-radius: 12px; padding: 10px; text-align: center; }
+         .img-card img { width: 100%; border-radius: 8px; margin-top: 10px; }
+
+         /* YOLO Grid */
+         .yolo-grid { display: grid; grid-template-columns: 1.5fr 1fr; gap: 20px; }
+         .log-box { background: rgba(0,0,0,0.3); border: 1px solid var(--card-border); border-radius: 12px; padding: 20px; height: 100%; }
+         .detection-item { background: #27272a; padding: 12px; border-radius: 8px; margin-bottom: 10px; border-left: 4px solid var(--accent); box-shadow: 0 2px 4px rgba(0,0,0,0.2); }
+
+         @media (max-width: 768px) {
+             .controls-grid, .img-grid, .yolo-grid { grid-template-columns: 1fr; }
+             .shimmer-text { font-size: 2rem; }
+         }
+     </style>
+ </head>
+ <body>
+
+ <div class="container">
+     <div class="shimmer-text">🚗 Car Damage AI</div>
+     <div class="subtitle">Fusion Intelligence: ResNet + YOLO</div>
+
+     <div class="warning-box">
+         <span style="font-size: 1.2rem;">⏱️</span>
+         <span><b>Note:</b> The first analysis may take up to 3-4 mins while models warm up. Subsequent requests are faster!</span>
+     </div>
+
+     <div class="controls-grid">
+         <div class="file-wrapper">
+             <input type="file" id="fileInput" accept="image/jpeg, image/png, image/jpg">
+             <div style="text-align: center;">
+                 <p style="font-size: 2.5rem; margin-bottom: 5px;">📷</p>
+                 <p style="color:#a1a1aa; font-weight: 500;">Tap or Drag & Drop Vehicle Image</p>
+             </div>
+         </div>
+
+         <div class="settings-card">
+             <h3 style="font-size: 1.1rem; margin-bottom: 5px;">⚙️ Analysis Settings</h3>
+             <p style="font-size: 0.85rem; color: var(--text-secondary);">Select the neural network pipeline.</p>
+             <select id="engineMode">
+                 <option value="fusion">Fusion</option>
+                 <option value="resnet">ResNet</option>
+             </select>
+         </div>
+     </div>
+
+     <div class="image-area" id="previewBox">
+         <img id="displayImage" src="" alt="Car Image">
+         <div class="scan-line" id="scanLine"></div>
+         <div class="loader-overlay" id="loader">
+             <div class="spinner"></div>
+             <p style="color:white; font-weight:600; letter-spacing: 1px; margin-bottom: 5px;">🧠 ANALYZING...</p>
+             <p id="loaderStatusText" style="color:#00c6ff; font-size:0.9rem;">Extracting features...</p>
+         </div>
+     </div>
+
+     <button class="btn" id="analyzeBtn" onclick="analyze()">🚀 Run AI Analysis</button>
+
+     <div class="results-section" id="resultsSection">
+         <div class="tabs">
+             <div class="tab active" onclick="switchResultTab('tab-pred')">📊 Prediction</div>
+             <div class="tab" onclick="switchResultTab('tab-attention')">👀 Attention Maps</div>
+             <div class="tab" onclick="switchResultTab('tab-yolo')">🎯 Localization</div>
+         </div>
+
+         <div id="tab-pred" class="tab-content active">
+             <div class="settings-card">
+                 <div id="finalPredText" class="big-text">--</div>
+                 <div style="font-weight: 600; margin-top: 5px;" id="confText">Confidence Score: 0%</div>
+                 <div class="progress-wrapper">
+                     <div class="progress-fill" id="confBar"></div>
+                 </div>
+                 <h3 style="margin: 15px 0 5px 0; font-size: 1.1rem;">Probability Distribution</h3>
+                 <div id="plotlyChart" style="width:100%; height:300px;"></div>
+             </div>
+         </div>
+
+         <div id="tab-attention" class="tab-content">
+             <div class="img-grid">
+                 <div class="img-card">
+                     <div style="font-weight:600; color:#e2e8f0;">Original Image</div>
+                     <img id="camOriginal" src="" alt="Original Image">
+                 </div>
+                 <div class="img-card">
+                     <div id="camSelectedLabel" style="font-weight:600; color:#e2e8f0;">Selected Grad-CAM</div>
+                     <img id="camSelected" src="" alt="Selected Grad-CAM">
+                 </div>
+             </div>
+         </div>
+
+         <div id="tab-yolo" class="tab-content">
+             <div class="yolo-grid">
+                 <div class="settings-card">
+                     <h3 style="margin-bottom: 10px;">Bounding Boxes</h3>
+                     <img id="yoloImage" src="" alt="YOLO Output" style="width: 100%; border-radius: 8px;">
+                 </div>
+                 <div class="log-box">
+                     <h3 style="margin-bottom: 15px;">Detection Log</h3>
+                     <div id="yoloLogContainer">
+                     </div>
+                 </div>
+             </div>
+         </div>
+     </div>
+
+ </div>
+
+ <script>
+     const API_URL = "http://127.0.0.1:8000";
+     let currentFile = null;
+
+     // DOM Elements
+     const fileInput = document.getElementById('fileInput');
+     const displayImage = document.getElementById('displayImage');
+     const previewBox = document.getElementById('previewBox');
+     const resultsSection = document.getElementById('resultsSection');
+     const loader = document.getElementById('loader');
+     const loaderStatusText = document.getElementById('loaderStatusText');
+     const scanLine = document.getElementById('scanLine');
+     const analyzeBtn = document.getElementById('analyzeBtn');
+
+     fileInput.addEventListener('change', e => {
+         if(e.target.files[0]) {
+             currentFile = e.target.files[0];
+             const reader = new FileReader();
+             reader.onload = x => {
+                 displayImage.src = x.target.result;
+                 previewBox.style.display = 'flex';
+                 resultsSection.style.display = 'none'; // Hide old results
+             };
+             reader.readAsDataURL(currentFile);
+         }
+     });
+
+     // --- BUG FIX IS HERE ---
+     function switchResultTab(tabId) {
+         // 1. Remove active state from all tabs and panels
+         document.querySelectorAll('.tab').forEach(t => t.classList.remove('active'));
+         document.querySelectorAll('.tab-content').forEach(c => c.classList.remove('active'));
+
+         // 2. Find the tab button that corresponds to this panel and make it active
+         const tabButton = document.querySelector(`.tab[onclick*="${tabId}"]`);
+         if(tabButton) {
+             tabButton.classList.add('active');
+         }
+
+         // 3. Make the specific panel active
+         document.getElementById(tabId).classList.add('active');
+
+         // 4. Resize Plotly chart if switching back to its tab to prevent layout squash
+         if(tabId === 'tab-pred') {
+             window.dispatchEvent(new Event('resize'));
+         }
+     }
+
+     // Plotly Chart Helper
+     function drawChart(dataObj, title) {
+         const labels = Object.keys(dataObj);
+         const values = Object.values(dataObj);
+
+         const trace = {
+             x: labels,
+             y: values,
+             type: 'bar',
+             marker: { color: '#00c6ff', line: { color: '#0072ff', width: 1.5 } },
+             opacity: 0.85
+         };
+
+         const layout = {
+             title: title || '',
+             paper_bgcolor: 'rgba(0,0,0,0)',
+             plot_bgcolor: 'rgba(0,0,0,0)',
+             font: { family: 'Inter', color: '#a1a1aa' },
+             margin: { l: 40, r: 10, t: 30, b: 40 },
+             xaxis: { title: 'Classes' },
+             yaxis: { title: 'Probability', range: [0, 1] }
+         };
+
+         Plotly.newPlot('plotlyChart', [trace], layout, {displayModeBar: false, responsive: true});
+     }
+
+     async function analyze() {
+         if(!currentFile) return alert("Please upload an image first.");
+
+         const engineMode = document.getElementById('engineMode').value; // fusion or resnet
+
+         // UI Prep
+         loader.style.display = 'flex';
+         scanLine.style.display = 'block';
+         analyzeBtn.disabled = true;
+         analyzeBtn.innerText = "Processing...";
+         resultsSection.style.display = 'none';
+
+         const formData = new FormData();
+         formData.append('image', currentFile);
+
+         try {
+             loaderStatusText.innerText = "Extracting features...";
+             const predRes = await fetch(`${API_URL}/predict/${engineMode}`, { method: 'POST', body: formData });
+             if (!predRes.ok) throw new Error("Prediction API failed");
+             const predData = await predRes.json();
+
+             loaderStatusText.innerText = "Generating Grad-CAM...";
+             const camForm = new FormData();
+             camForm.append('file', currentFile);
+             const camRes = await fetch(`${API_URL}/predict?mode=${engineMode}`, { method: 'POST', body: camForm });
+             if (!camRes.ok) throw new Error("Grad-CAM API failed");
+             const camData = await camRes.json();
+
+             loaderStatusText.innerText = "Running YOLO detection...";
+             const yoloRes = await fetch(`${API_URL}/predict/yolo`, { method: 'POST', body: camForm });
+             if (!yoloRes.ok) throw new Error("YOLO API failed");
+             const yoloData = await yoloRes.json();
+
+             const highestClass = Object.keys(predData).reduce((a, b) => predData[a] > predData[b] ? a : b);
+             const highestScore = predData[highestClass] || 0;
+             document.getElementById('finalPredText').innerText = highestClass;
+             document.getElementById('confText').innerText = `Confidence Score: ${(highestScore * 100).toFixed(2)}%`;
+             drawChart(predData, `${engineMode.toUpperCase()} Output`);
+             setTimeout(() => { document.getElementById('confBar').style.width = `${(highestScore * 100).toFixed(2)}%`; }, 100);
+
+             document.getElementById('camOriginal').src = `${API_URL}${camData.original_image}`;
+             document.getElementById('camSelected').src = `${API_URL}${camData.selected_viz}`;
+             document.getElementById('camSelectedLabel').innerText = engineMode === 'fusion' ? 'Fusion Grad-CAM' : 'ResNet Grad-CAM';
+
+             document.getElementById('yoloImage').src = `${API_URL}${yoloData.yolo_image}`;
+             const logContainer = document.getElementById('yoloLogContainer');
+             if (!yoloData.detections || yoloData.detections.length === 0) {
+                 logContainer.innerHTML = '<div style="color: #a1a1aa; padding: 10px;">🟢 No damage boxes detected.</div>';
+             } else {
+                 let logHTML = `<div style="color: #ffcc00; margin-bottom: 10px; font-weight:600;">🔴 Found ${yoloData.total_detections} damage region(s).</div>`;
+                 yoloData.detections.forEach((det, idx) => {
+                     logHTML += `<div class="detection-item"><b style="color: #e2e8f0;">Region ${idx + 1}</b><br><span style="color: #a1a1aa; font-size: 0.9em;">${det.label} · ${(det.confidence * 100).toFixed(1)}%</span></div>`;
+                 });
+                 logContainer.innerHTML = logHTML;
+             }
+
+             resultsSection.style.display = 'block';
+             switchResultTab('tab-pred');
+         } catch (error) {
+             alert(`Error connecting to AI server. Details: ${error.message}`);
+             console.error(error);
+         } finally {
+             loader.style.display = 'none';
+             scanLine.style.display = 'none';
+             analyzeBtn.disabled = false;
+             analyzeBtn.innerText = "🚀 Run AI Analysis";
+         }
+     }
+ </script>
+
+ </body>
+ </html>
requirements.txt ADDED
@@ -0,0 +1,12 @@
+ torch
+ torchvision
+ transformers
+ fastapi
+ uvicorn
+ python-dotenv
+ matplotlib
+ opencv-python
+ python-multipart
+ ultralytics
+ plotly
+ pandas
scripts/gradcam.py ADDED
@@ -0,0 +1,100 @@
+ import cv2
+ import numpy as np
+ from PIL import Image
+ import torch
+ import torch.nn.functional as F
+
+ def get_resnet_gradcam(image_path, predictor, output_path):
+     model = predictor.model
+     device = predictor.device
+     model.eval()
+     features, gradients = [], []
+
+     def forward_hook(module, input, output):
+         features.append(output)
+     def backward_hook(module, grad_in, grad_out):
+         gradients.append(grad_out[0])
+
+     target_layer = model.model.layer4[-1]
+     handle_fw = target_layer.register_forward_hook(forward_hook)
+     handle_bw = target_layer.register_full_backward_hook(backward_hook)
+
+     original_img = Image.open(image_path).convert("RGB")
+     input_tensor = predictor.test_transforms(original_img).unsqueeze(0).to(device)
+     model.zero_grad()
+     output = model(input_tensor)
+     pred_class_idx = output.argmax(dim=1).item()
+     score = output[0, pred_class_idx]
+     score.backward()
+
+     handle_fw.remove()
+     handle_bw.remove()
+
+     acts = features[0].cpu().data.numpy()[0]
+     grads = gradients[0].cpu().data.numpy()[0]
+     weights = np.mean(grads, axis=(1, 2))
+     cam = np.zeros(acts.shape[1:], dtype=np.float32)
+     for i, w in enumerate(weights):
+         cam += w * acts[i]
+
+     cam = np.maximum(cam, 0)
+     cam = cv2.resize(cam, (original_img.width, original_img.height))
+     cam = (cam - np.min(cam)) / (np.max(cam) - np.min(cam) + 1e-8)
+     heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
+     original_np = np.array(original_img)
+     overlay = cv2.addWeighted(cv2.cvtColor(original_np, cv2.COLOR_RGB2BGR), 0.6, heatmap, 0.4, 0)
+     cv2.imwrite(output_path, overlay)
+     return True
+
+ def get_fusion_gradcam(image_path, predictor, output_path):
+     model = predictor.model
+     device = predictor.device
+     model.eval()
+     target_layer = model.eff_features[-1]
+     activation = None
+
+     def forward_hook(module, inp, out):
+         nonlocal activation
+         activation = out
+         activation.retain_grad()
+
+     handle = target_layer.register_forward_hook(forward_hook)
+     original_img = Image.open(image_path).convert("RGB")
+     pixel_eff = predictor.eff_normalize(original_img).unsqueeze(0).to(device)
+     inputs_cnx = predictor.convnext_processor(images=original_img, return_tensors="pt")
+     pixel_cnx = inputs_cnx["pixel_values"].to(device)
+
+     if next(model.parameters()).dtype == torch.float16:
+         pixel_eff = pixel_eff.half()
+         pixel_cnx = pixel_cnx.half()
+
+     model.zero_grad()
+     output = model(pixel_eff, pixel_cnx)
+     pred_class_idx = output.argmax(dim=1).item()
+     score = output[0, pred_class_idx]
+     score.backward()
+     handle.remove()
+
+     if activation is None or activation.grad is None:
+         raise RuntimeError("Gradients could not be extracted. Ensure requires_grad=True is properly set.")
+
+     acts = activation[0].detach().float()
+     grads = activation.grad[0].detach().float()
+     weights = grads.mean(dim=(1, 2), keepdim=True)
+     cam = torch.sum(weights * acts, dim=0)
+     cam = F.relu(cam)
+     cam = cam.cpu().numpy()
+
+     if cam.max() > cam.min():
+         cam = (cam - cam.min()) / (cam.max() - cam.min())
+     else:
+         cam = np.zeros_like(cam)
+
+     cam = np.uint8(255 * cam)
+     cam_resized = cv2.resize(cam, (original_img.width, original_img.height), interpolation=cv2.INTER_LINEAR)
+     heatmap = cv2.applyColorMap(cam_resized, cv2.COLORMAP_JET)
+     original_np = np.array(original_img)
+     original_bgr = cv2.cvtColor(original_np, cv2.COLOR_RGB2BGR)
+     overlay = cv2.addWeighted(original_bgr, 0.5, heatmap, 0.6, 0)
+     cv2.imwrite(output_path, overlay)
+     return True
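A short usage sketch for these two helpers, wired the same way `app.py` does it; the input and output file names are hypothetical.

```python
from scripts.model_loader import initialize_models
from scripts.gradcam import get_resnet_gradcam, get_fusion_gradcam

class_map = {0: "Front Breakage", 1: "Front Crushed", 2: "Front Normal",
             3: "Rear Breakage", 4: "Rear Crushed", 5: "Rear Normal"}
resnet_predictor, fusion_predictor = initialize_models(class_map)

# Each call writes a JET-colormap heatmap blended over the original image.
get_resnet_gradcam("car.jpg", resnet_predictor, "car_resnet_cam.jpg")
get_fusion_gradcam("car.jpg", fusion_predictor, "car_fusion_cam.jpg")
```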
scripts/model_loader.py ADDED
@@ -0,0 +1,38 @@
+ import os
+ from pathlib import Path
+ from .prediction_helper import ResnetCarDamagePredictor, FusionCarDamagePredictor
+
+ CHECKPOINT_DIR = Path(__file__).resolve().parents[1] / "checkpoints"
+ MODEL_FILES = {
+     "resnet": "best_resnet_model.pt",
+     "fusion": "best_fusion_model_fp16.pth",
+     "yolo": "damage_detector.pt",
+ }
+
+
+ def get_checkpoint_path(model_key: str) -> Path:
+     if model_key not in MODEL_FILES:
+         raise ValueError(f"Unknown model key: {model_key}")
+
+     path = CHECKPOINT_DIR / MODEL_FILES[model_key]
+     if not path.exists():
+         raise FileNotFoundError(f"Checkpoint not found: {path}")
+     return path
+
+
+ class ModelLoader:
+     def __init__(self):
+         self.base_dir = CHECKPOINT_DIR
+
+     def get_model_path(self, model_key: str) -> Path:
+         return get_checkpoint_path(model_key)
+
+
+ def initialize_models(class_map):
+     resnet_path = get_checkpoint_path("resnet")
+     fusion_path = get_checkpoint_path("fusion")
+
+     resnet_predictor = ResnetCarDamagePredictor(resnet_path, class_map)
+     fusion_predictor = FusionCarDamagePredictor(fusion_path, class_map)
+
+     return resnet_predictor, fusion_predictor
scripts/prediction_helper.py ADDED
@@ -0,0 +1,175 @@
+ import os
+ import torch
+ import torch.nn as nn
+ from torchvision import transforms, models
+ from PIL import Image, UnidentifiedImageError
+ from transformers import ConvNextModel, ConvNextImageProcessor
+
+ class Car_Classifier_Resnet(nn.Module):
+     def __init__(self, num_classes):
+         super().__init__()
+         self.model = models.resnet18(weights="DEFAULT")
+         for param in self.model.parameters():
+             param.requires_grad = False
+         for param in self.model.layer3.parameters():
+             param.requires_grad = True
+         for param in self.model.layer4.parameters():
+             param.requires_grad = True
+         self.model.fc = nn.Sequential(
+             nn.Dropout(0.5),
+             nn.Linear(self.model.fc.in_features, 256),
+             nn.ReLU(),
+             nn.Dropout(0.3),
+             nn.Linear(256, num_classes)
+         )
+
+     def forward(self, x):
+         return self.model(x)
+
+ class ResnetCarDamagePredictor:
+     def __init__(self, checkpoint_path, class_map):
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         self.class_map = class_map
+         self.test_transforms = transforms.Compose([
+             transforms.Resize((128, 128)),
+             transforms.ToTensor(),
+             transforms.Normalize([0.485, 0.456, 0.406],
+                                  [0.229, 0.224, 0.225])
+         ])
+         try:
+             self.model = Car_Classifier_Resnet(num_classes=len(class_map))
+             checkpoint = torch.load(checkpoint_path, map_location=self.device)
+             state_dict = checkpoint.get("model_state_dict", checkpoint)
+             self.model.load_state_dict(state_dict)
+             self.model.to(self.device)
+             self.model.eval()
+         except Exception as e:
+             raise RuntimeError(f"Failed to load ResNet model: {str(e)}")
+
+     def resnet_predict(self, image_input):
+         try:
+             if isinstance(image_input, str):
+                 image = Image.open(image_input).convert("RGB")
+             elif isinstance(image_input, Image.Image):
+                 image = image_input.convert("RGB")
+             else:
+                 raise TypeError("image_input must be a file path or PIL.Image")
+             image = self.test_transforms(image)
+             image = image.unsqueeze(0).to(self.device)
+             with torch.no_grad():
+                 outputs = self.model(image)
+                 probs = torch.nn.functional.softmax(outputs, dim=1)[0]
+             class_probs = {
+                 self.class_map[i]: float(probs[i].item())
+                 for i in range(len(self.class_map))
+             }
+             return dict(sorted(class_probs.items(), key=lambda x: x[1], reverse=True))
+         except UnidentifiedImageError:
+             raise ValueError("Invalid image file provided")
+         except Exception as e:
+             raise RuntimeError(f"ResNet prediction failed: {str(e)}")
+
+ class FusionClassifier(nn.Module):
+     def __init__(self, num_classes, convnext_model_name="facebook/convnext-small-224"):
+         super().__init__()
+         eff = models.efficientnet_v2_s(weights=models.EfficientNet_V2_S_Weights.IMAGENET1K_V1)
+         for param in eff.parameters():
+             param.requires_grad = False
+         for param in eff.features[5].parameters():
+             param.requires_grad = True
+         for param in eff.features[6].parameters():
+             param.requires_grad = True
+         for param in eff.features[7].parameters():
+             param.requires_grad = True
+         self.eff_features = eff.features
+         self.eff_avgpool = eff.avgpool
+         self.eff_out_dim = eff.classifier[1].in_features
+         cnx = ConvNextModel.from_pretrained(convnext_model_name)
+         for param in cnx.parameters():
+             param.requires_grad = False
+         for param in cnx.encoder.stages[2].parameters():
+             param.requires_grad = True
+         for param in cnx.encoder.stages[3].parameters():
+             param.requires_grad = True
+         for param in cnx.layernorm.parameters():
+             param.requires_grad = True
+         self.cnx_backbone = cnx
+         self.cnx_out_dim = 768
+         fused_dim = self.eff_out_dim + self.cnx_out_dim
+         self.fusion_head = nn.Sequential(
+             nn.Dropout(p=0.4),
+             nn.Linear(fused_dim, 512),
+             nn.LayerNorm(512),
+             nn.GELU(),
+             nn.Dropout(p=0.3),
+             nn.Linear(512, 256),
+             nn.LayerNorm(256),
+             nn.GELU(),
+             nn.Dropout(p=0.2),
+             nn.Linear(256, num_classes)
+         )
+
+     def forward(self, pixel_values_eff, pixel_values_cnx):
+         x_eff = self.eff_features(pixel_values_eff)
+         x_eff = self.eff_avgpool(x_eff)
+         x_eff = torch.flatten(x_eff, 1)
+         cnx_out = self.cnx_backbone(pixel_values=pixel_values_cnx, return_dict=True)
+         x_cnx = cnx_out.pooler_output
+         fused = torch.cat([x_eff, x_cnx], dim=1)
+         return self.fusion_head(fused)
+
+ class FusionCarDamagePredictor:
+     def __init__(self, checkpoint_path, class_map, convnext_model_name="facebook/convnext-small-224"):
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         self.class_map = class_map
+         self.eff_normalize = transforms.Compose([
+             transforms.Resize((260, 260)),
+             transforms.ToTensor(),
+             transforms.Normalize([0.485, 0.456, 0.406],
+                                  [0.229, 0.224, 0.225])
+         ])
+         self.convnext_processor = ConvNextImageProcessor.from_pretrained(convnext_model_name)
+         try:
+             self.model = FusionClassifier(
+                 num_classes=len(class_map),
+                 convnext_model_name=convnext_model_name
+             )
+             checkpoint = torch.load(checkpoint_path, map_location=self.device)
+             state_dict = checkpoint.get("model_state_dict", checkpoint)
+             first_tensor = next(iter(state_dict.values()))
+             if first_tensor.dtype == torch.float16:
+                 self.model = self.model.half()
+             self.model.load_state_dict(state_dict)
+             self.model.to(self.device)
+             self.model.eval()
+         except Exception as e:
+             raise RuntimeError(f"Failed to load Fusion model: {str(e)}")
+
+     def predict(self, image_input):
+         try:
+             if isinstance(image_input, str):
+                 image = Image.open(image_input).convert("RGB")
+             elif isinstance(image_input, Image.Image):
+                 image = image_input.convert("RGB")
+             else:
+                 raise TypeError("image_input must be a file path or PIL.Image")
+             pixel_eff = self.eff_normalize(image)
+             pixel_eff = pixel_eff.unsqueeze(0).to(self.device)
+             inputs_cnx = self.convnext_processor(images=image, return_tensors="pt")
+             pixel_cnx = inputs_cnx["pixel_values"].to(self.device)
+             if next(self.model.parameters()).dtype == torch.float16:
+                 pixel_eff = pixel_eff.half()
+                 pixel_cnx = pixel_cnx.half()
+             with torch.no_grad():
+                 logits = self.model(pixel_eff, pixel_cnx)
+                 probs = torch.nn.functional.softmax(logits, dim=1)[0]
+             class_probs = {
+                 self.class_map[i]: float(probs[i].item())
+                 for i in range(len(self.class_map))
+             }
+             return dict(sorted(class_probs.items(), key=lambda x: x[1], reverse=True))
+         except UnidentifiedImageError:
+             raise ValueError("Invalid image file provided")
+         except Exception as e:
+             raise RuntimeError(f"Fusion prediction failed: {str(e)}")
+
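A direct-usage sketch for the fusion predictor (paths hypothetical). The constructor inspects the first tensor of the loaded state dict and casts the model to FP16 automatically when an FP16 checkpoint such as `best_fusion_model_fp16.pth` is supplied:

```python
from scripts.prediction_helper import FusionCarDamagePredictor

class_map = {0: "Front Breakage", 1: "Front Crushed", 2: "Front Normal",
             3: "Rear Breakage", 4: "Rear Crushed", 5: "Rear Normal"}
predictor = FusionCarDamagePredictor("checkpoints/best_fusion_model_fp16.pth", class_map)

probs = predictor.predict("car.jpg")             # accepts a file path or a PIL.Image
top_class, top_prob = next(iter(probs.items()))  # dict is sorted by probability
print(top_class, round(top_prob, 3))
```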
scripts/yolo.py ADDED
@@ -0,0 +1,57 @@
+ import cv2
+ import numpy as np
+ from PIL import Image
+ from ultralytics import YOLO
+ from scripts.model_loader import ModelLoader
+
+ yolo_model = None
+
+ def get_yolo_model():
+     global yolo_model
+     if yolo_model is None:
+         loader = ModelLoader()
+         yolo_path = loader.get_model_path("yolo")
+         yolo_model = YOLO(str(yolo_path))
+     return yolo_model
+
+
+ def get_yolo_damage_boxes(image_path, output_path):
+     try:
+         image = Image.open(image_path).convert("RGB")
+         model = get_yolo_model()
+         results = model.predict(
+             source=image,
+             conf=0.05,
+             imgsz=640,
+             verbose=False
+         )
+
+         result = results[0]
+         boxes = result.boxes
+         detections = []
+
+         if boxes is not None and len(boxes) > 0:
+             for box in boxes:
+                 conf = float(box.conf[0])
+                 cls_id = int(box.cls[0])
+                 label = model.names[cls_id]
+                 x1, y1, x2, y2 = map(int, box.xyxy[0])
+
+                 detections.append({
+                     "label": label,
+                     "confidence": round(conf, 4),
+                     "box": [x1, y1, x2, y2]
+                 })
+
+         # result.plot() returns a BGR array, which is what cv2.imwrite expects,
+         # so it is written directly without flipping the channel order.
+         plotted_bgr = result.plot()
+         cv2.imwrite(output_path, plotted_bgr)
+
+         return {
+             "detections": detections,
+             "total_detections": len(detections),
+             "message": "No damage detected" if len(detections) == 0 else "Detections found"
+         }
+
+     except Exception as e:
+         raise RuntimeError(f"YOLO failed: {str(e)}")
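Finally, a usage sketch for the detector (file names hypothetical); the YOLO weights are loaded once and cached in the module-level `yolo_model`:

```python
from scripts.yolo import get_yolo_damage_boxes

result = get_yolo_damage_boxes("car.jpg", "car_yolo.jpg")  # writes the annotated image
for det in result["detections"]:
    print(det["label"], det["confidence"], det["box"])
print(result["message"])  # "Detections found" or "No damage detected"
```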