from fastapi import APIRouter, Request, UploadFile, File, Form, HTTPException
from fastapi.responses import HTMLResponse, FileResponse, JSONResponse
from fastapi.templating import Jinja2Templates
from starlette.background import BackgroundTask
import shutil
import os
import uuid
from pathlib import Path
from typing import Optional
import json
import base64
from ultralytics import YOLO
import cv2
import numpy as np
from ..utils.llm_client import GroqAnalyzer

# Templates directory
TEMPLATES_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "templates")
templates = Jinja2Templates(directory=TEMPLATES_DIR)

router = APIRouter()

UPLOAD_DIR = os.path.join("/tmp", "uploads")
RESULTS_DIR = os.path.join("/tmp", "results")
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(RESULTS_DIR, exist_ok=True)

ALLOWED_EXTENSIONS = {"jpg", "jpeg", "png", "tiff", "tif"}

# Model paths
# DAMAGE_MODEL_PATH = os.path.join("/tmp", "models", "damage", "weights", "weights", "best.pt")  # Commented for now
PARTS_MODEL_PATH = os.path.join("/tmp", "models", "parts", "weights", "weights", "best.pt")

# Class names for parts
PARTS_CLASS_NAMES = ['headlamp', 'front_bumper', 'hood', 'door', 'rear_bumper']

# Initialize GroqAnalyzer
groq_analyzer = GroqAnalyzer()
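
# upload_image below passes the saved upload path to groq_analyzer.analyze_damage()
# and renders whatever it returns in the template, so GroqAnalyzer only needs to
# return something the template can display (its exact return type is defined in
# ../utils/llm_client and is not shown here).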

# Helper: Run YOLO inference and return results
def run_yolo_inference(model_path, image_path, task='segment'):
    model = YOLO(model_path)
    results = model.predict(source=image_path, imgsz=640, conf=0.25, save=False, task=task)
    return results[0]
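
# Optional helper (a sketch, not wired into run_yolo_inference above): reloading
# the YOLO weights from disk on every request is slow, so a cached loader can
# reuse the model across calls. Assumes the weights file does not change at runtime.
from functools import lru_cache

@lru_cache(maxsize=2)
def load_yolo_model(model_path: str) -> YOLO:
    """Load a YOLO model once per weights path and reuse it on later calls."""
    return YOLO(model_path)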

# Helper: Draw masks and confidence on image
def draw_masks_and_conf(image_path, yolo_result, class_names=None):
    img = cv2.imread(image_path)
    overlay = img.copy()
    out_img = img.copy()
    colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255), (0, 255, 255)]
    for i, box in enumerate(yolo_result.boxes):
        conf = float(box.conf[0])
        cls = int(box.cls[0])
        color = colors[cls % len(colors)]
        # Draw bbox
        x1, y1, x2, y2 = map(int, box.xyxy[0])
        cv2.rectangle(overlay, (x1, y1), (x2, y2), color, 2)
        label = f"{class_names[cls] if class_names else 'damage'}: {conf:.2f}"
        cv2.putText(overlay, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)
        # Draw mask if available
        if hasattr(yolo_result, 'masks') and yolo_result.masks is not None:
            # The raw mask covers the whole frame at model resolution, so resize it
            # to the original image size and crop the bbox region before colorizing.
            mask = yolo_result.masks.data[i].cpu().numpy()
            mask = (mask * 255).astype(np.uint8)
            mask = cv2.resize(mask, (img.shape[1], img.shape[0]))[y1:y2, x1:x2]
            roi = overlay[y1:y2, x1:x2]
            colored_mask = np.zeros_like(roi)
            colored_mask[mask > 127] = color
            overlay[y1:y2, x1:x2] = cv2.addWeighted(roi, 0.5, colored_mask, 0.5, 0)
    # Blend the annotated overlay with the original image
    out_img = cv2.addWeighted(overlay, 0.7, img, 0.3, 0)
    return out_img
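
# draw_masks_and_conf returns a BGR numpy array, so the result can be written
# straight to disk, e.g.:
#   annotated = draw_masks_and_conf("car.jpg", result, class_names=PARTS_CLASS_NAMES)
#   cv2.imwrite("annotated.png", annotated)
# ("car.jpg" here is just an illustrative path.)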

# Helper: Generate JSON output
def generate_json_output(filename, damage_result, parts_result):
    # Damage severity: use max confidence
    if damage_result is not None and hasattr(damage_result, 'boxes'):
        severity_score = float(max([float(box.conf[0]) for box in damage_result.boxes], default=0))
        damage_regions = []
        for box in damage_result.boxes:
            x1, y1, x2, y2 = map(float, box.xyxy[0])
            conf = float(box.conf[0])
            damage_regions.append({"bbox": [x1, y1, x2, y2], "confidence": conf})
    else:
        severity_score = 0
        damage_regions = []
    # Parts
    parts = []
    for i, box in enumerate(parts_result.boxes):
        x1, y1, x2, y2 = map(float, box.xyxy[0])
        conf = float(box.conf[0])
        cls = int(box.cls[0])
        # Damage %: use mask area / bbox area if available.
        # Note: the mask is at model resolution while the bbox is in original-image
        # pixels, so this ratio is only a rough coverage estimate.
        damage_percentage = None
        if hasattr(parts_result, 'masks') and parts_result.masks is not None:
            mask = parts_result.masks.data[i].cpu().numpy()
            mask_area = np.sum(mask > 0.5)
            bbox_area = (x2 - x1) * (y2 - y1)
            damage_percentage = float(mask_area / bbox_area) if bbox_area > 0 else None
        parts.append({
            "part": PARTS_CLASS_NAMES[cls] if cls < len(PARTS_CLASS_NAMES) else str(cls),
            "damaged": True,
            "confidence": conf,
            "damage_percentage": damage_percentage,
            "bbox": [x1, y1, x2, y2]
        })
    # Optionally, add base64 masks
    # (not implemented here for brevity)
    return {
        "filename": filename,
        "damage": {
            "severity_score": severity_score,
            "regions": damage_regions
        },
        "parts": parts,
        "cost_estimate": None
    }
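
# Example of the shape produced above (illustrative values, not real output):
# {
#   "filename": "car.jpg",
#   "damage": {"severity_score": 0, "regions": []},
#   "parts": [
#     {"part": "hood", "damaged": true, "confidence": 0.87,
#      "damage_percentage": 0.42, "bbox": [120.0, 80.0, 540.0, 360.0]}
#   ],
#   "cost_estimate": null
# }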

# Dummy login credentials
def check_login(username: str, password: str) -> bool:
    return username == "demo" and password == "demo123"
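
# Note: these hard-coded demo credentials are only for local testing; there is
# no hashing or session handling here, so this is not a real auth layer.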

def home(request: Request):
    return templates.TemplateResponse("index.html", {"request": request, "result": None})

def login(request: Request, username: str = Form(...), password: str = Form(...)):
    if check_login(username, password):
        return templates.TemplateResponse("index.html", {"request": request, "result": None, "user": username})
    return templates.TemplateResponse("login.html", {"request": request, "error": "Invalid credentials"})

def login_page(request: Request):
    return templates.TemplateResponse("login.html", {"request": request})
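
# Upload flow: validate the file extension, save the upload under a session id,
# run parts segmentation, write the mask overlay and JSON artifacts to RESULTS_DIR,
# call the Groq analyzer on the original photo, then schedule a delayed cleanup.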
async def upload_image(request: Request, file: UploadFile = File(...)):
    try:
        ext = file.filename.split(".")[-1].lower()
        print(f"[DEBUG] Uploaded file extension: {ext}")
        if ext not in ALLOWED_EXTENSIONS:
            print(f"[DEBUG] Unsupported file type: {ext}")
            return templates.TemplateResponse("index.html", {"request": request, "error": "Unsupported file type."})
        # Save uploaded file
        session_id = str(uuid.uuid4())
        upload_filename = f"{session_id}_{file.filename}"
        upload_path = os.path.join(UPLOAD_DIR, upload_filename)
        print(f"[DEBUG] Saving uploaded file to: {upload_path}")
        with open(upload_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)
        print("[DEBUG] File saved. Running inference...")
        warning = None
        try:
            damage_result = None  # Damage model not used for now
            parts_result = run_yolo_inference(PARTS_MODEL_PATH, upload_path)
            print(f"[DEBUG] YOLO inference result: {parts_result}")
            parts_img = None
            json_output = None
            parts_img_url = None
            json_url = None
            if hasattr(parts_result, 'boxes') and len(parts_result.boxes) > 0:
                print(f"[DEBUG] Detected {len(parts_result.boxes)} parts.")
                parts_img = draw_masks_and_conf(upload_path, parts_result, class_names=PARTS_CLASS_NAMES)
                parts_img_filename = f"{session_id}_parts.png"
                parts_img_path = os.path.join(RESULTS_DIR, parts_img_filename)
                cv2.imwrite(parts_img_path, parts_img)
                print(f"[DEBUG] Parts image saved to: {parts_img_path}")
                parts_img_url = f"/download/result/{parts_img_filename}"
                json_output = generate_json_output(file.filename, damage_result, parts_result)
                json_filename = f"{session_id}_result.json"
                json_path = os.path.join(RESULTS_DIR, json_filename)
                with open(json_path, "w") as jf:
                    json.dump(json_output, jf, indent=2)
                print(f"[DEBUG] JSON output saved to: {json_path}")
                json_url = f"/download/result/{json_filename}"
            else:
                warning = "No parts detected in the image."
                print("[DEBUG] No parts detected.")
            llm_analysis = groq_analyzer.analyze_damage(upload_path)
            print(f"[DEBUG] LLM analysis output: {llm_analysis}")
            result = {
                "filename": file.filename,
                "parts_image": parts_img_url,
                "json": json_output,
                "json_download": json_url,
                "llm_analysis": llm_analysis,
                "warning": warning
            }
            print("[DEBUG] Result dict:", result)
        except Exception as e:
            result = {
                "filename": file.filename,
                "error": f"Inference failed: {str(e)}",
                "parts_image": None,
                "json": None,
                "json_download": None,
                "llm_analysis": None,
                "warning": None
            }
            print("[ERROR] Inference failed:", e)

        # Remove the upload and its result artifacts after a grace period so the
        # download links keep working briefly.
        import threading
        import time

        def delayed_cleanup():
            time.sleep(300)  # 5 minutes
            try:
                os.remove(upload_path)
                print(f"[DEBUG] Cleaned up upload: {upload_path}")
            except Exception as ce:
                print(f"[DEBUG] Cleanup error (upload): {ce}")
            for suffix in ["_parts.png", "_result.json"]:
                try:
                    os.remove(os.path.join(RESULTS_DIR, f"{session_id}{suffix}"))
                    print(f"[DEBUG] Cleaned up result: {os.path.join(RESULTS_DIR, f'{session_id}{suffix}')}")
                except Exception as ce:
                    print(f"[DEBUG] Cleanup error (result): {ce}")

        threading.Thread(target=delayed_cleanup, daemon=True).start()
        return templates.TemplateResponse(
            "index.html",
            {
                "request": request,
                "result": result,
                "original_image": f"/download/upload/{upload_filename}"
            }
        )
    except Exception as e:
        print(f"[ERROR] Error processing image: {str(e)}")
        return templates.TemplateResponse(
            "index.html",
            {"request": request, "error": f"Error processing image: {str(e)}"}
        )

# --- Serve files from /tmp/uploads and /tmp/results ---
def download_uploaded_file(filename: str):
    file_path = os.path.join(UPLOAD_DIR, filename)
    if not os.path.exists(file_path):
        return JSONResponse(status_code=404, content={"error": "File not found"})
    return FileResponse(file_path, filename=filename)

def download_result_file(filename: str):
    file_path = os.path.join(RESULTS_DIR, filename)
    if not os.path.exists(file_path):
        return JSONResponse(status_code=404, content={"error": "File not found"})
    return FileResponse(file_path, filename=filename)
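
# Route wiring is not shown in this file, but the handlers above build
# /download/upload/{filename} and /download/result/{filename} URLs, so a minimal,
# assumed registration for those two endpoints (using the router defined above)
# could look like this; the page routes (home, login, upload) are presumably
# registered elsewhere and their paths are not visible here.
router.get("/download/upload/{filename}")(download_uploaded_file)
router.get("/download/result/{filename}")(download_result_file)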