# Spaces: Runtime error
from datetime import datetime

import cv2
import torch
from PIL import Image
from transformers import DetrImageProcessor, DetrForObjectDetection, TrOCRProcessor, VisionEncoderDecoderModel
# Ensure all required libraries are installed.
# timm is not used directly in this file; it is imported only to fail early
# with an actionable message if it is missing.
try:
    import timm  # Required by DETR
except ImportError as exc:
    # Chain explicitly so the traceback shows the original import failure as the cause.
    raise ImportError("The 'timm' library is required but not installed. Install it using 'pip install timm'.") from exc
# Load the DETR model for object detection (license plate detection).
# Both objects are module-level globals used by detect_license_plate().
try:
    detr_processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
    detr_model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
except Exception as e:
    # Re-raise with context, keeping the underlying error attached as the cause.
    raise RuntimeError(f"Error initializing DETR model: {e}") from e
# Load the TrOCR model for OCR (license plate text recognition).
# Both objects are module-level globals used by recognize_text().
# NOTE(review): this is the "handwritten" checkpoint; confirm it is the
# intended variant for reading license plates.
try:
    trocr_processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
    trocr_model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
except Exception as e:
    # Re-raise with context, keeping the underlying error attached as the cause.
    raise RuntimeError(f"Error initializing TrOCR model: {e}") from e
def detect_license_plate(frame):
    """
    Detect objects in a video frame using DETR and return their pixel boxes.

    Args:
        frame: BGR image array (height x width x channels), as returned by
            ``cv2.VideoCapture.read``.

    Returns:
        A list of NumPy arrays ``[x_min, y_min, x_max, y_max]`` in pixel
        coordinates of ``frame``, clamped to the frame bounds, one per
        detection whose best class probability exceeds 0.9. The list is
        empty when nothing passes the threshold.

    NOTE(review): no class filter is applied, so every confident detection
    is returned, not only license plates.
    """
    frame_height, frame_width = frame.shape[:2]

    # Convert the frame to a PIL image
    pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

    # Preprocess the image for DETR and run inference without autograd bookkeeping
    inputs = detr_processor(images=pil_image, return_tensors="pt")
    with torch.no_grad():
        outputs = detr_model(**inputs)

    # Class probabilities for the single image in the batch; the last class
    # index is dropped before taking the per-query maximum.
    probas = outputs.logits.softmax(-1)[0, :, :-1]
    keep = probas.max(-1).values > 0.9  # Confidence threshold

    # pred_boxes carries a leading batch dimension, so select image 0 before
    # applying the per-query mask (masking the batch axis directly raises).
    kept_boxes = outputs.pred_boxes[0, keep].detach().cpu()

    # Per the Transformers documentation, DETR emits normalized
    # (center_x, center_y, width, height); convert to pixel corner coordinates.
    center_x, center_y, box_w, box_h = kept_boxes.unbind(-1)
    x_min = ((center_x - box_w / 2) * frame_width).clamp(0, frame_width)
    y_min = ((center_y - box_h / 2) * frame_height).clamp(0, frame_height)
    x_max = ((center_x + box_w / 2) * frame_width).clamp(0, frame_width)
    y_max = ((center_y + box_h / 2) * frame_height).clamp(0, frame_height)
    pixel_boxes = torch.stack([x_min, y_min, x_max, y_max], dim=-1)

    return [box.numpy() for box in pixel_boxes]
def recognize_text(plate_image):
    """
    Recognize text from a license plate image using TrOCR.

    Args:
        plate_image: BGR image array holding the cropped plate region.

    Returns:
        The decoded text with surrounding whitespace stripped, or an empty
        string when ``plate_image`` is ``None`` or contains no pixels.
    """
    # An empty crop would make cv2.cvtColor raise; report "no text" instead so
    # callers can keep processing the remaining detections.
    if plate_image is None or plate_image.size == 0:
        return ""

    # Convert the license plate image to a PIL image
    pil_image = Image.fromarray(cv2.cvtColor(plate_image, cv2.COLOR_BGR2RGB))

    # Preprocess the image for TrOCR
    pixel_values = trocr_processor(images=pil_image, return_tensors="pt").pixel_values
    generated_ids = trocr_model.generate(pixel_values)

    # A single image was passed in, so the first decoded string is the result.
    text = trocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return text.strip()
def process_video(video_path, frame_skip=5, display=True):
    """
    Process a video to detect license plates and log entry/exit times.

    Args:
        video_path: Path to a video file, or a capture device index
            (e.g. 0 for a webcam), passed straight to ``cv2.VideoCapture``.
        frame_skip: Only frames whose running count is a multiple of this
            value are analysed; the rest are skipped to save time.
        display: When True (the default) each analysed frame is shown in a
            window and pressing 'q' stops processing. Pass False in headless
            environments where no GUI window can be opened.

    Returns:
        A dict mapping each recognized plate string to
        ``{'entry_time': str, 'exit_time': str | None}``. Timestamps are
        formatted ``'%Y-%m-%d %H:%M:%S'`` and come from the wall clock at
        processing time, not from the position inside the video.
        ``exit_time`` stays ``None`` for plates seen only once.
    """
    cap = cv2.VideoCapture(video_path)
    vehicle_data = {}
    frame_count = 0

    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            frame_count += 1
            if frame_count % frame_skip != 0:
                continue  # Skip frames to optimize processing time

            frame_height, frame_width = frame.shape[:2]

            # Detect license plates
            for box in detect_license_plate(frame):
                x_min, y_min, x_max, y_max = map(int, box)

                # Clamp to the frame: negative indices would wrap around when
                # slicing, and out-of-range ones would silently shrink the crop.
                x_min, x_max = max(0, x_min), min(frame_width, x_max)
                y_min, y_max = max(0, y_min), min(frame_height, y_max)
                if x_max <= x_min or y_max <= y_min:
                    continue  # Degenerate box, nothing to crop

                license_plate_image = frame[y_min:y_max, x_min:x_max]

                # Recognize text from the license plate
                license_plate = recognize_text(license_plate_image)
                if license_plate:
                    current_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
                    if license_plate not in vehicle_data:
                        # Vehicle entering
                        vehicle_data[license_plate] = {'entry_time': current_time, 'exit_time': None}
                        print(f"Vehicle {license_plate} entered at {current_time}")
                    else:
                        # Update exit time: the latest sighting wins
                        vehicle_data[license_plate]['exit_time'] = current_time

                # Draw bounding box and license plate text
                cv2.rectangle(frame, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)
                cv2.putText(frame, license_plate, (x_min, y_min-10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)

            if display:
                # Display the frame
                cv2.imshow('Vehicle Detection', frame)
                # Break on 'q' key press
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
    finally:
        # Release the capture even if detection or OCR raised mid-stream.
        cap.release()
        if display:
            cv2.destroyAllWindows()

    # Print vehicle data
    print("\nVehicle Data:")
    for plate, times in vehicle_data.items():
        print(f"License Plate: {plate}, Entry Time: {times['entry_time']}, Exit Time: {times['exit_time']}")

    return vehicle_data
# Script entry point: runs the pipeline only when executed directly, not on import.
if __name__ == "__main__":
    # Replace 'road_video.mp4' with the path to your video file or use 0 for webcam
    process_video("road_video.mp4", frame_skip=5)