File size: 3,867 Bytes
ded27db
e2ff788
ded27db
e2ff788
ded27db
e2ff788
7370338
ded27db
e11ec5f
e3876cb
e11ec5f
ded27db
 
b7f9967
ded27db
 
e2ff788
 
 
3ee48e8
ec7b022
7370338
3ee48e8
 
 
 
7370338
 
 
 
 
e2ff788
7370338
 
e2ff788
3ee48e8
 
 
 
e2ff788
 
 
e11ec5f
e2ff788
 
 
 
 
 
 
ded27db
 
e2ff788
 
 
ded27db
 
 
c7f3439
e2ff788
 
 
ded27db
e2ff788
 
c7f3439
 
e2ff788
3ee48e8
 
 
c7f3439
ded27db
 
 
e2ff788
 
 
c7f3439
 
 
 
e2ff788
c7f3439
e2ff788
 
c7f3439
 
3ee48e8
 
ded27db
c7f3439
 
 
e11ec5f
c7f3439
e11ec5f
e2ff788
 
e11ec5f
c7f3439
e11ec5f
c7f3439
e2ff788
c7f3439
ded27db
e11ec5f
0cec0cd
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import logging
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from ultralytics import YOLO
import cv2
import numpy as np
import io
from PIL import Image
import base64
import os
from io import BytesIO

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s"
)

app = FastAPI()

# Path to the pre-downloaded YOLO weights.
YOLO_MODEL_PATH = "/tmp/models/yolov9s.pt"

# Fail fast at import time with a clear message if the weights file is missing,
# rather than surfacing an opaque error from inside the YOLO loader.
if not os.path.exists(YOLO_MODEL_PATH):
    raise FileNotFoundError(f"YOLO model not found at {YOLO_MODEL_PATH}")

# Load the YOLOv9 model once at import time; every request reuses this instance.
try:
    model = YOLO(YOLO_MODEL_PATH)
except Exception as e:
    # Chain the original exception so the root-cause traceback is preserved.
    raise RuntimeError(f"Failed to load YOLO model: {str(e)}") from e

# Class IDs counted as vehicles. This presumably assumes the 80-class COCO label
# set (2=car, 3=motorcycle, 5=bus, 7=truck). Confirm against `model.names`.
# Adjust as necessary for your use case.
vehicle_classes = [2, 3, 5, 7]

@app.get("/")
async def root():
    """Liveness endpoint: report service status and the name of the served model."""
    payload = dict(status="OK", model="YOLOv9s")
    return payload

def _classify_congestion(vehicle_count):
    """Map a vehicle count to a ``(congestion_level, flow_rate)`` pair.

    More than 20 vehicles is "High", more than 10 is "Medium", otherwise "Low".
    The flow rate is "Smooth" only for "Low" congestion and "Heavy" otherwise.
    """
    if vehicle_count > 20:
        congestion_level = "High"
    elif vehicle_count > 10:
        congestion_level = "Medium"
    else:
        congestion_level = "Low"
    flow_rate = "Smooth" if congestion_level == "Low" else "Heavy"
    return congestion_level, flow_rate


@app.post("/analyze_traffic/")
async def analyze_traffic(file: UploadFile = File(...)):
    """
    Analyze the traffic image using YOLOv9 and return the results along with a processed image.

    The uploaded image is run through the module-level ``model``. Detections whose
    class ID is in ``vehicle_classes`` are counted and drawn as green boxes.

    Returns:
        JSONResponse with ``vehicle_count``, ``congestion_level``, ``flow_rate``
        and ``processed_image_base64`` (the annotated image as a base64 JPEG).

    Raises:
        HTTPException: 500 with detail "Image processing error." on OpenCV
            failures, or "Error analyzing traffic: ..." on any other failure.
    """
    try:
        # Decode the upload with PIL, which always yields RGB after convert().
        image_bytes = await file.read()
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        image_rgb = np.array(image)
        # Ultralytics and OpenCV both interpret raw numpy images as BGR, so convert
        # once here. Otherwise detection sees swapped channels and the encoded
        # JPEG comes out with red and blue exchanged.
        image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)

        logging.info("Running YOLO detection on the uploaded image...")
        results = model(image_bgr)
        detections = results[0]  # Single image in, so take the first result.
        # Lazy %-formatting: the repr is only built when DEBUG is enabled.
        logging.debug("Raw YOLO results: %s", results)

        # Keep only vehicle-class detections, as integer (x1, y1, x2, y2) boxes.
        vehicle_boxes = []
        for det in detections.boxes:
            cls_id = int(det.cls[0]) if hasattr(det, 'cls') else None
            if cls_id in vehicle_classes:
                box = det.xyxy[0]
                vehicle_boxes.append((int(box[0]), int(box[1]), int(box[2]), int(box[3])))
        vehicle_count = len(vehicle_boxes)

        logging.info("Vehicle count: %d", vehicle_count)
        logging.info("Vehicle bounding boxes: %s", vehicle_boxes)

        congestion_level, flow_rate = _classify_congestion(vehicle_count)

        # Draw bounding boxes on the BGR image that will be encoded.
        for (x1, y1, x2, y2) in vehicle_boxes:
            cv2.rectangle(image_bgr, (x1, y1), (x2, y2), (0, 255, 0), thickness=2)

        # imencode reports failure via its flag rather than raising, so check it.
        ok, buffer = cv2.imencode('.jpg', image_bgr)
        if not ok:
            raise HTTPException(status_code=500, detail="Image processing error.")
        processed_image_base64 = base64.b64encode(buffer).decode('utf-8')

        return JSONResponse(content={
            "vehicle_count": vehicle_count,
            "congestion_level": congestion_level,
            "flow_rate": flow_rate,
            "processed_image_base64": processed_image_base64
        })

    except HTTPException:
        # Already a client-facing error; don't re-wrap it below.
        raise
    except cv2.error as cv_error:
        # Must precede `except Exception` (cv2.error is a subclass) or it is unreachable.
        logging.exception("OpenCV error: %s", cv_error)
        raise HTTPException(status_code=500, detail="Image processing error.") from cv_error
    except Exception as e:
        # logging.exception records the traceback as well as the message.
        logging.exception("Error analyzing traffic: %s", e)
        raise HTTPException(status_code=500, detail=f"Error analyzing traffic: {str(e)}") from e