# ocr/app.py
import cv2
import torch
from transformers import DetrImageProcessor, DetrForObjectDetection, TrOCRProcessor, VisionEncoderDecoderModel
from PIL import Image
from datetime import datetime
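
# Dependency note (inferred from the imports above; versions are not pinned here):
#   pip install opencv-python torch transformers pillow timm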

# Ensure all required libraries are installed (timm is needed by DETR's ResNet backbone)
try:
    import timm  # noqa: F401
except ImportError:
    raise ImportError("The 'timm' library is required but not installed. Install it with 'pip install timm'.")

# Load the DETR model for object detection (used here to propose plate regions).
# Note: the off-the-shelf "facebook/detr-resnet-50" checkpoint is trained on COCO,
# which has no dedicated license-plate class, so detections are generic objects
# (typically cars/trucks); a plate-specific fine-tuned detector would improve results.
try:
    detr_processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
    detr_model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
    detr_model.eval()
except Exception as e:
    raise RuntimeError(f"Error initializing DETR model: {e}")

# Load the TrOCR model for OCR (license plate text recognition)
try:
    trocr_processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
    trocr_model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
    trocr_model.eval()
except Exception as e:
    raise RuntimeError(f"Error initializing TrOCR model: {e}")
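# Note: "microsoft/trocr-base-handwritten" targets handwritten text. For printed
# license plates, the "microsoft/trocr-base-printed" checkpoint may be a better fit;
# swapping the model name above is a suggestion, not something tested here.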


def detect_license_plate(frame):
    """
    Detect candidate license plate regions in a BGR video frame using DETR.

    Returns a list of (x_min, y_min, x_max, y_max) boxes in pixel coordinates.
    """
    # Convert the OpenCV BGR frame to a PIL RGB image
    pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

    # Preprocess the image and run DETR without tracking gradients
    inputs = detr_processor(images=pil_image, return_tensors="pt")
    with torch.no_grad():
        outputs = detr_model(**inputs)

    # Post-process: convert DETR's normalized (cx, cy, w, h) predictions into
    # absolute (x_min, y_min, x_max, y_max) pixel boxes at a 0.9 confidence threshold
    target_sizes = torch.tensor([pil_image.size[::-1]])  # (height, width)
    results = detr_processor.post_process_object_detection(
        outputs, threshold=0.9, target_sizes=target_sizes
    )[0]

    detected_boxes = [box.detach().cpu().numpy() for box in results["boxes"]]
    return detected_boxes
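
# Quick standalone check of the detector (hedged sketch; "sample_frame.jpg" is a
# hypothetical image path used only for illustration):
#
#   frame = cv2.imread("sample_frame.jpg")
#   for x_min, y_min, x_max, y_max in detect_license_plate(frame):
#       print(f"Candidate region: ({x_min:.0f}, {y_min:.0f}) -> ({x_max:.0f}, {y_max:.0f})")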


def recognize_text(plate_image):
    """
    Recognize text from a cropped license plate image (BGR) using TrOCR.
    """
    # Convert the cropped BGR region to a PIL RGB image
    pil_image = Image.fromarray(cv2.cvtColor(plate_image, cv2.COLOR_BGR2RGB))

    # Preprocess the crop, generate token IDs, and decode them into a string
    pixel_values = trocr_processor(images=pil_image, return_tensors="pt").pixel_values
    with torch.no_grad():
        generated_ids = trocr_model.generate(pixel_values)
    text = trocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return text.strip()


def process_video(video_path, frame_skip=5):
    """
    Process a video to detect license plates and log entry/exit times.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise RuntimeError(f"Could not open video source: {video_path}")

    vehicle_data = {}
    frame_count = 0

    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break

        frame_count += 1
        if frame_count % frame_skip != 0:
            continue  # Skip frames to reduce processing time

        # Detect candidate plate regions
        detected_boxes = detect_license_plate(frame)

        for box in detected_boxes:
            x_min, y_min, x_max, y_max = map(int, box)

            # Clamp the box to the frame bounds and skip degenerate crops
            frame_height, frame_width = frame.shape[:2]
            x_min, y_min = max(0, x_min), max(0, y_min)
            x_max, y_max = min(frame_width, x_max), min(frame_height, y_max)
            if x_max <= x_min or y_max <= y_min:
                continue

            license_plate_image = frame[y_min:y_max, x_min:x_max]

            # Recognize text from the cropped region
            license_plate = recognize_text(license_plate_image)
            if license_plate:
                current_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
                if license_plate not in vehicle_data:
                    # First sighting: treat as the vehicle entering
                    vehicle_data[license_plate] = {'entry_time': current_time, 'exit_time': None}
                    print(f"Vehicle {license_plate} entered at {current_time}")
                else:
                    # Subsequent sightings: update the exit time
                    vehicle_data[license_plate]['exit_time'] = current_time

                # Draw the bounding box and recognized text on the frame
                cv2.rectangle(frame, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)
                cv2.putText(frame, license_plate, (x_min, y_min - 10),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)

        # Display the annotated frame (optional; remove for headless environments)
        cv2.imshow('Vehicle Detection', frame)

        # Break on 'q' key press
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    cap.release()
    cv2.destroyAllWindows()

    # Print the collected vehicle data
    print("\nVehicle Data:")
    for plate, times in vehicle_data.items():
        print(f"License Plate: {plate}, Entry Time: {times['entry_time']}, Exit Time: {times['exit_time']}")

    return vehicle_data


if __name__ == "__main__":
    # Replace 'road_video.mp4' with the path to your video file, or pass 0 to use a webcam
    process_video("road_video.mp4", frame_skip=5)
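
# Optional usage sketch (hedged; "vehicle_log.csv" is a hypothetical output path):
# since process_video returns the vehicle_data dict, the entry/exit log can be
# persisted to CSV roughly as follows.
#
#   import csv
#   data = process_video("road_video.mp4", frame_skip=5)
#   with open("vehicle_log.csv", "w", newline="") as f:
#       writer = csv.writer(f)
#       writer.writerow(["license_plate", "entry_time", "exit_time"])
#       for plate, times in data.items():
#           writer.writerow([plate, times["entry_time"], times["exit_time"]])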