import os
import tempfile

import cv2
import numpy as np
import onnxruntime as ort
import streamlit as st
import torch  # noqa: F401  # not referenced directly in this file; kept from the original imports
from PIL import Image
from ultralytics import YOLO

# Class ids passed to the vehicle detector (original comment: cars, motorcycles, bus, trucks).
VEHICLE_CLASS_IDS = [2, 3, 5, 7]
VEHICLE_CONF_THRESHOLD = 0.5
PLATE_CONF_THRESHOLD = 0.5
# (width, height) the plate crop is resized to before being fed to the ONNX model.
OCR_INPUT_SIZE = (640, 640)
# Height in pixels of the enlarged plate crop drawn above each vehicle.
PLATE_DISPLAY_HEIGHT = 400
# Fallback when the container reports no usable frame rate.
DEFAULT_FPS = 30.0


# Load models
@st.cache_resource
def load_models():
    """Load the plate detector, the vehicle detector and the ONNX OCR session.

    Cached by Streamlit so the models are created once per server process.

    Returns:
        (license_plate_detector, vehicle_detector, ort_session)
    """
    license_plate_detector = YOLO('license_plate_detector.pt')
    vehicle_detector = YOLO('yolov8n.pt')
    ort_session = ort.InferenceSession("model.onnx")
    return license_plate_detector, vehicle_detector, ort_session


def draw_border(img, top_left, bottom_right, color=(0, 255, 0), thickness=10,
                line_length_x=200, line_length_y=200):
    """Draw corner brackets (not a full rectangle) around a box, in place.

    Args:
        img: Image array drawn on in place.
        top_left: (x1, y1) integer pixel coordinates.
        bottom_right: (x2, y2) integer pixel coordinates.
        color: Line colour in the image's channel order.
        thickness: Line thickness in pixels.
        line_length_x: Length of each horizontal bracket arm.
        line_length_y: Length of each vertical bracket arm.

    Returns:
        The same ``img`` array.
    """
    x1, y1 = top_left
    x2, y2 = bottom_right
    # Clamp arm lengths to half the box so brackets from opposite corners
    # never overshoot past the box on small detections.
    line_length_x = max(0, min(line_length_x, (x2 - x1) // 2))
    line_length_y = max(0, min(line_length_y, (y2 - y1) // 2))

    # Each corner with the direction its arms extend: (dx, dy).
    corners = (
        ((x1, y1), 1, 1),    # top-left: right and down
        ((x1, y2), 1, -1),   # bottom-left: right and up
        ((x2, y1), -1, 1),   # top-right: left and down
        ((x2, y2), -1, -1),  # bottom-right: left and up
    )
    for (cx, cy), dx, dy in corners:
        cv2.line(img, (cx, cy), (cx, cy + dy * line_length_y), color, thickness)
        cv2.line(img, (cx, cy), (cx + dx * line_length_x, cy), color, thickness)
    return img


def _clip_box(x1, y1, x2, y2, width, height):
    """Truncate a float box to ints and clip it to [0, width] x [0, height]."""
    return (
        min(max(int(x1), 0), width),
        min(max(int(y1), 0), height),
        min(max(int(x2), 0), width),
        min(max(int(y2), 0), height),
    )


def _paste_clipped(dst, src, top, left):
    """Copy ``src`` into ``dst`` with its top-left at (top, left), dropping any part outside ``dst``."""
    dst_h, dst_w = dst.shape[:2]
    src_h, src_w = src.shape[:2]
    y0, x0 = max(top, 0), max(left, 0)
    y1, x1 = min(top + src_h, dst_h), min(left + src_w, dst_w)
    if y1 <= y0 or x1 <= x0:
        return  # nothing visible
    dst[y0:y1, x0:x1] = src[y0 - top:y1 - top, x0 - left:x1 - left]


def _decode_plate_text(ocr_outputs):
    """Convert raw ONNX outputs into a plate string.

    TODO: placeholder carried over from the original code. The model's output
    format is not known here, so a fixed string is returned. Replace with real
    decoding for your ONNX model.
    """
    return "ABC123"


def _run_ocr(ort_session, license_crop):
    """Preprocess a 3-channel plate crop, run the ONNX model and return the decoded text."""
    resized = cv2.resize(license_crop, OCR_INPUT_SIZE)
    # HWC -> CHW, scale to [0, 1], then add a leading batch axis.
    blob = np.transpose(resized, (2, 0, 1)).astype(np.float32) / 255.0
    blob = np.expand_dims(blob, axis=0)
    inputs = {ort_session.get_inputs()[0].name: blob}
    outputs = ort_session.run(None, inputs)
    return _decode_plate_text(outputs)


def _draw_plate_overlay(frame, license_crop, license_number, vx1, vx2, vy1):
    """Draw an enlarged plate crop and its text on a white banner above a vehicle box, in place."""
    crop_h, crop_w = license_crop.shape[:2]
    # Keep aspect ratio at a fixed display height; never let the width collapse to 0.
    display_w = max(1, int(crop_w * PLATE_DISPLAY_HEIGHT / crop_h))
    display = cv2.resize(license_crop, (display_w, PLATE_DISPLAY_HEIGHT))
    h_crop, w_crop = display.shape[:2]

    center_x = (vx1 + vx2) // 2
    left = center_x - w_crop // 2

    # Plate crop sits 100 px above the vehicle; clipped so vehicles near the
    # frame edges no longer raise a shape-mismatch error.
    _paste_clipped(frame, display, vy1 - h_crop - 100, left)

    # White banner above the crop for the text (cv2 clips drawing primitives itself).
    cv2.rectangle(frame, (left, vy1 - h_crop - 400), (left + w_crop, vy1 - h_crop - 100),
                  (255, 255, 255), -1)

    (text_width, text_height), _ = cv2.getTextSize(
        license_number, cv2.FONT_HERSHEY_SIMPLEX, 4.3, 17)
    cv2.putText(frame, license_number,
                (int(center_x - text_width / 2), int(vy1 - h_crop - 250 + text_height / 2)),
                cv2.FONT_HERSHEY_SIMPLEX, 4.3, (0, 0, 0), 17)


def process_frame(frame, license_plate_detector, vehicle_detector, ort_session):
    """Detect vehicles and plates in ``frame``, annotate it in place and return it.

    Callers in this file pass BGR frames (OpenCV video frames, or a PIL image
    converted RGB->BGR).

    Args:
        frame: Image array; annotated in place.
        license_plate_detector: Model run on each vehicle crop.
        vehicle_detector: Model run on the whole frame.
        ort_session: ONNX Runtime session used for plate OCR.

    Returns:
        The same ``frame`` array with borders, plate boxes and plate overlays drawn.
    """
    # Detect and crop from an untouched copy so annotations drawn on `frame`
    # (vehicle borders, plate boxes, overlays of earlier vehicles) never leak
    # into the detector or OCR inputs.
    clean = frame.copy()
    frame_h, frame_w = frame.shape[:2]

    vehicle_results = vehicle_detector(clean, classes=VEHICLE_CLASS_IDS)
    for x1, y1, x2, y2, score, _class_id in vehicle_results[0].boxes.data.tolist():
        if score <= VEHICLE_CONF_THRESHOLD:
            continue
        vx1, vy1, vx2, vy2 = _clip_box(x1, y1, x2, y2, frame_w, frame_h)
        if vx2 <= vx1 or vy2 <= vy1:
            continue  # box lies entirely outside the frame

        draw_border(frame, (vx1, vy1), (vx2, vy2), color=(0, 255, 0), thickness=25,
                    line_length_x=200, line_length_y=200)

        vehicle_crop = clean[vy1:vy2, vx1:vx2]
        license_results = license_plate_detector(vehicle_crop)
        for lp_x1, lp_y1, lp_x2, lp_y2, lp_score, _ in license_results[0].boxes.data.tolist():
            if lp_score <= PLATE_CONF_THRESHOLD:
                continue
            # Plate coordinates are relative to the vehicle crop; shift by the
            # integer crop origin that was actually used for slicing.
            px1, py1, px2, py2 = _clip_box(vx1 + lp_x1, vy1 + lp_y1,
                                           vx1 + lp_x2, vy1 + lp_y2, frame_w, frame_h)

            # Crop from the clean copy, so the red box drawn next is not part of the OCR input.
            license_crop = clean[py1:py2, px1:px2]
            cv2.rectangle(frame, (px1, py1), (px2, py2), (0, 0, 255), 12)
            if license_crop.size == 0:
                continue

            try:
                license_number = _run_ocr(ort_session, license_crop)
            except Exception as e:
                st.error(f"Error in OCR processing: {str(e)}")
                continue

            try:
                _draw_plate_overlay(frame, license_crop, license_number, vx1, vx2, vy1)
            except Exception as e:
                st.error(f"Error displaying results: {str(e)}")

    return frame


def process_video(video_path, license_plate_detector, vehicle_detector, ort_session):
    """Annotate every frame of a video and write the result to a temporary MP4.

    Args:
        video_path: Path of the input video.
        license_plate_detector, vehicle_detector, ort_session: Forwarded to ``process_frame``.

    Returns:
        Path of the annotated video (a ``delete=False`` temp file; the caller owns it).

    Raises:
        IOError: If the input cannot be opened or the output writer cannot be created.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        raise IOError(f"Could not open video file: {video_path}")

    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # Keep fractional rates (e.g. 29.97) and fall back when the container reports 0/NaN.
    fps = cap.get(cv2.CAP_PROP_FPS)
    if not fps > 0:
        fps = DEFAULT_FPS
    # Frame count is only an estimate for many containers; it may be 0 or too small.
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    # Only the name is needed; close the handle right away so it does not leak.
    with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as temp_file:
        output_path = temp_file.name

    # NOTE(review): browsers often cannot play 'mp4v'-encoded MP4 in st.video;
    # an H.264 fourcc (e.g. 'avc1') may be needed if the OpenCV build supports it — confirm.
    out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
    if not out.isOpened():
        cap.release()
        out.release()
        raise IOError(f"Could not create output video: {output_path}")

    progress_bar = st.progress(0)
    frame_count = 0
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            out.write(process_frame(frame, license_plate_detector, vehicle_detector, ort_session))
            frame_count += 1
            if total_frames > 0:
                # st.progress rejects values above 1.0, which an under-reported frame count would produce.
                progress_bar.progress(min(frame_count / total_frames, 1.0))
    finally:
        cap.release()
        out.release()
        progress_bar.empty()

    return output_path


# Streamlit UI
st.title("Advanced Vehicle and License Plate Detection")

try:
    license_plate_detector, vehicle_detector, ort_session = load_models()
except Exception as e:
    # Only model-loading failures get this message; processing errors are reported separately below.
    st.error(f"Error loading models: {str(e)}")
    st.stop()

uploaded_file = st.file_uploader("Choose an image or video file", type=["jpg", "jpeg", "png", "mp4"])

if uploaded_file is not None:
    file_type = (uploaded_file.type or "").split('/')[0]

    if file_type == "image":
        # Normalise palette/greyscale/RGBA uploads to 3-channel RGB so the
        # RGB->BGR conversion below is always valid.
        image = Image.open(uploaded_file).convert("RGB")
        st.image(image, caption="Uploaded Image", use_column_width=True)

        if st.button("Detect"):
            with st.spinner("Processing image..."):
                try:
                    # Convert PIL Image to CV2 format
                    image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
                    processed_image = process_frame(image_cv, license_plate_detector,
                                                    vehicle_detector, ort_session)
                    processed_image = cv2.cvtColor(processed_image, cv2.COLOR_BGR2RGB)
                except Exception as e:
                    st.error(f"Error processing image: {str(e)}")
                else:
                    st.image(processed_image, caption="Processed Image", use_column_width=True)

    elif file_type == "video":
        # Keep the upload's extension so OpenCV/FFmpeg can identify the container.
        suffix = os.path.splitext(uploaded_file.name)[1] or ".mp4"
        # Close the temp file before it is read by name, so every byte is flushed to disk.
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tfile:
            tfile.write(uploaded_file.getvalue())
            input_path = tfile.name
        st.video(input_path)

        if st.button("Detect"):
            with st.spinner("Processing video..."):
                try:
                    processed_video = process_video(input_path, license_plate_detector,
                                                    vehicle_detector, ort_session)
                except Exception as e:
                    st.error(f"Error processing video: {str(e)}")
                else:
                    st.video(processed_video)