import cv2
from cv2 import dnn
import numpy as np
import pytesseract
import requests
import base64
import onnxruntime
import os
from io import BytesIO
from typing import Any
from PIL import Image
from langchain_core.tools import tool as langchain_tool
from smolagents.tools import Tool, tool
def pre_processing(image: str, input_size=(416, 416)) -> tuple:
    """
    Pre-process a base64-encoded image for a YOLO model.
    Args:
        image: The image to process, as a base64-encoded string
        input_size: The (width, height) to which the image should be resized
    Returns:
        tuple: (processed_image, original_shape)
    """
    try:
        # Decode base64 image
        image_data = base64.b64decode(image)
        np_image = np.frombuffer(image_data, np.uint8)
        img = cv2.imdecode(np_image, cv2.IMREAD_COLOR)
        if img is None:
            raise ValueError("Failed to decode image")
        # Store original shape for post-processing
        original_shape = img.shape[:2]  # (height, width)
        # Ensure input_size is valid
        if not isinstance(input_size, tuple) or len(input_size) != 2:
            input_size = (416, 416)
        # Resize the image (cv2.resize takes (width, height))
        img = cv2.resize(img, input_size, interpolation=cv2.INTER_LINEAR)
        # Ensure image has 3 channels
        if len(img.shape) == 2:  # Grayscale
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        elif img.shape[2] == 4:  # RGBA
            img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGR)
        # Convert BGR to RGB and normalize
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # More reliable than array slicing
        img = img.astype(np.float32) / 255.0  # Normalize to [0, 1]
        # Convert to NCHW format (batch, channels, height, width)
        img = np.transpose(img, (2, 0, 1))  # HWC to CHW
        img = np.expand_dims(img, axis=0)  # Add batch dimension
        # Verify the final shape against the requested input size
        expected_shape = (1, 3, input_size[1], input_size[0])
        if img.shape != expected_shape:
            print(f"Warning: Final shape is {img.shape}, expected {expected_shape}")
            img = np.reshape(img, expected_shape)
        return img, original_shape
    except Exception as e:
        raise ValueError(f"Error in pre_processing: {str(e)}")
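# A minimal usage sketch for pre_processing, assuming a local "sample.jpg"
# (a hypothetical file name used only for illustration). Defined as a function
# so nothing runs at import time.
def _demo_pre_processing(path: str = "sample.jpg") -> None:
    with open(path, "rb") as f:
        b64 = base64.b64encode(f.read()).decode("utf-8")
    blob, (height, width) = pre_processing(b64)
    print(f"model input: {blob.shape}, original size: {height}x{width}")  # (1, 3, 416, 416)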
def post_processing(onnx_output, classes, original_shape, conf_threshold=0.5, nms_threshold=0.4) -> list:
    """
    Post-process the output of the YOLO model.
    Args:
        onnx_output: The raw output from the ONNX model; each detection row is
            expected as [cx, cy, w, h, objectness, class_score_0, ...] with
            coordinates normalized to [0, 1]
        classes: List of class names
        original_shape: Original (height, width) of the image
        conf_threshold: Confidence threshold for filtering detections
        nms_threshold: Non-max suppression threshold
    Returns:
        List of detected objects as (label, confidence, [x, y, w, h]) tuples
    """
    class_ids = []
    confidences = []
    boxes = []
    for detection in onnx_output[0]:
        scores = detection[5:]
        class_id = np.argmax(scores)
        confidence = scores[class_id]
        if confidence > conf_threshold:
            # Scale normalized center/size back to original pixel coordinates
            center_x = int(detection[0] * original_shape[1])
            center_y = int(detection[1] * original_shape[0])
            w = int(detection[2] * original_shape[1])
            h = int(detection[3] * original_shape[0])
            x = int(center_x - w / 2)
            y = int(center_y - h / 2)
            boxes.append([x, y, w, h])
            confidences.append(float(confidence))
            class_ids.append(class_id)
    # Apply non-max suppression; flatten the result because NMSBoxes returns
    # nested indices in older OpenCV versions and a flat array in newer ones
    indices = dnn.NMSBoxes(boxes, confidences, conf_threshold, nms_threshold)
    detected_objects = []
    for i in np.array(indices).flatten():
        box = boxes[i]
        label = str(classes[class_ids[i]])
        detected_objects.append((label, confidences[i], box))
    return detected_objects
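# A hedged sketch of post_processing on a synthetic detection row. The row
# layout [cx, cy, w, h, objectness, class scores...] and all values below are
# illustrative, not real model output.
def _demo_post_processing() -> None:
    row = np.array([0.5, 0.5, 0.2, 0.4, 0.9, 0.95, 0.02, 0.03], dtype=np.float32)
    fake_output = [np.expand_dims(row, axis=0)]  # one image, one detection
    detections = post_processing(fake_output, ["person", "car", "dog"], (480, 640))
    print(detections)  # [('person', ~0.95, [256, 144, 128, 192])]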
def extract_images_from_video(video_path: str) -> list:
    """
    Extract images (frames) from a video.
    Args:
        video_path: The path to the video file
    Returns:
        A list of images (frames) as numpy arrays
    """
    cap = cv2.VideoCapture(video_path)
    images = []
    while cap.isOpened():
        ret, image = cap.read()
        if not ret:
            break
        images.append(image)
    cap.release()
    return images
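# A hedged sketch for extract_images_from_video; "sample.mp4" is a hypothetical
# path used only for illustration.
def _demo_extract_frames(path: str = "sample.mp4") -> None:
    frames = extract_images_from_video(path)
    print(f"Extracted {len(frames)} frames")
    if frames:
        print(f"First frame shape (H, W, C): {frames[0].shape}")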
def get_image_from_file_path(file_path: str) -> str:
    """
    Load an image from a file path and convert it to a base64 string.
    Args:
        file_path: The path to the file
    Returns:
        The image as a base64-encoded JPG string
    """
    def _encode_image(path: str) -> str:
        img = cv2.imread(path)
        if img is None:
            raise FileNotFoundError(f"Could not read image at {path}")
        # Re-encode as JPG and base64-encode the bytes
        _, buffer_data = cv2.imencode('.jpg', img)
        return base64.b64encode(buffer_data.tobytes()).decode('utf-8')

    try:
        return _encode_image(file_path)
    except Exception as e:
        print(f"First attempt failed: {str(e)}")
        # Retry with the path resolved relative to this file's directory
        current_file_dir = os.path.dirname(os.path.abspath(__file__))
        adjusted_path = os.path.join(current_file_dir, file_path)
        try:
            return _encode_image(adjusted_path)
        except Exception as e2:
            print(f"Second attempt failed: {str(e2)}")
            # List directory contents to help debug
            try:
                validation_dir = os.path.join(current_file_dir, "validation")
                if os.path.exists(validation_dir):
                    print(f"Contents of validation directory: {os.listdir(validation_dir)}")
            except Exception as e3:
                print(f"Failed to list directory contents: {str(e3)}")
            raise FileNotFoundError(f"Could not read image at {file_path} or {adjusted_path}")
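# A hedged sketch for get_image_from_file_path; the result is a padded base64
# string, which is the format the tools below expect as input. "sample.jpg" is
# a hypothetical path.
def _demo_get_image(path: str = "sample.jpg") -> None:
    b64 = get_image_from_file_path(path)
    print(f"base64 length: {len(b64)} (multiple of 4: {len(b64) % 4 == 0})")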
def get_video_from_file_path(file_path: str) -> str:
    """
    Load a video from a file path and convert it to a base64 string.
    Note: the result is the base64 encoding of the video's frames concatenated
    as JPG bytes, not of a playable video container.
    Args:
        file_path: The path to the file
    Returns:
        The video frames as a base64-encoded string
    """
    def _encode_video(path: str) -> str:
        cap = cv2.VideoCapture(path)
        if not cap.isOpened():
            raise FileNotFoundError(f"Could not read video at {path}")
        # Convert each frame to JPG and store it in memory
        images = []
        while cap.isOpened():
            ret, image = cap.read()
            if not ret:
                break
            _, buffer = cv2.imencode('.jpg', image)
            images.append(buffer.tobytes())
        cap.release()
        # Combine all frames into a single buffer and encode to base64
        with BytesIO() as out:
            for image_data in images:
                out.write(image_data)
            return base64.b64encode(out.getvalue()).decode('utf-8')

    try:
        video_base64 = _encode_video(file_path)
    except Exception:
        # Retry with the path resolved relative to this file's directory
        current_file_dir = os.path.dirname(os.path.abspath(__file__))
        video_base64 = _encode_video(os.path.join(current_file_dir, file_path))
    return video_base64
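# A hedged sketch for get_video_from_file_path. Note the caveat above: the
# returned string decodes to concatenated JPG frames, not a playable file.
# "sample.mp4" is a hypothetical path.
def _demo_get_video(path: str = "sample.mp4") -> None:
    b64 = get_video_from_file_path(path)
    print(f"Encoded {len(b64)} base64 characters of concatenated JPG frames")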
def image_processing(image: str, brightness: float = 1.0, contrast: float = 1.0) -> str:
    """
    Adjust the brightness and contrast of an image.
    Args:
        image: The image to process, as a base64-encoded string
        brightness: Additive brightness offset (beta in cv2.convertScaleAbs)
        contrast: Multiplicative contrast factor (alpha in cv2.convertScaleAbs)
    Returns:
        The processed image as a base64-encoded JPG string
    """
    image_data = base64.b64decode(image)
    np_image = np.frombuffer(image_data, np.uint8)
    img = cv2.imdecode(np_image, cv2.IMREAD_COLOR)
    # Adjust brightness and contrast: output = img * contrast + brightness
    img = cv2.convertScaleAbs(img, alpha=contrast, beta=brightness)
    _, buffer = cv2.imencode('.jpg', img)
    processed_image = base64.b64encode(buffer).decode('utf-8')
    return processed_image
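# A hedged sketch for image_processing: a contrast factor of 1.2 and a
# brightness offset of 30 (both illustrative values) mildly enhance an image.
def _demo_image_processing(b64_image: str) -> str:
    return image_processing(b64_image, brightness=30.0, contrast=1.2)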
onnx_path = "vlm_assets/yolov3-8.onnx"
class ObjectDetectionTool(Tool):
    name = "object_detection"
    description = """
    Detect objects in a list of images.
    Input Requirements:
    - Input must be a list of images, where each image is a base64-encoded string
    - Each base64 string must be properly padded (length must be a multiple of 4)
    - Images will be resized to 416x416 pixels during processing
    - Images should be in RGB or BGR format (3 channels)
    - Supported image formats: JPG, PNG
    Processing:
    - Images are automatically resized to 416x416
    - Images are normalized to [0,1] range
    - Model expects input shape: [1, 3, 416, 416] (batch, channels, height, width)
    Output:
    - Returns a list of detected objects for each image
    - Each detection includes: (label, confidence, bounding_box)
    - Bounding boxes are in format: [x, y, width, height]
    - Confidence threshold: 0.5
    - NMS threshold: 0.4
    Example input format:
    ["base64_encoded_image1", "base64_encoded_image2"]
    Example output format:
    [
        [("person", 0.95, [100, 200, 50, 100]), ("car", 0.88, [300, 400, 80, 60])],  # detections for image1
        [("dog", 0.92, [150, 250, 40, 80])]  # detections for image2
    ]
    """
    inputs = {
        "images": {
            "type": "any",
            "description": "List of base64-encoded images. Each image must be a valid base64 string with proper padding (length multiple of 4). Images will be resized to 416x416."
        }
    }
    output_type = "any"
    def setup(self):
        try:
            # Load the ONNX model
            self.onnx_path = onnx_path
            self.onnx_model = onnxruntime.InferenceSession(self.onnx_path)
            # Get model input details
            self.input_name = self.onnx_model.get_inputs()[0].name
            self.input_shape = self.onnx_model.get_inputs()[0].shape
            print(f"Model input shape: {self.input_shape}")
            # Load the COCO class labels
            self.classes = [
                'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat',
                'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat',
                'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack',
                'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
                'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
                'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
                'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
                'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
                'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book',
                'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
            ]
        except Exception as e:
            raise RuntimeError(f"Error in setup: {str(e)}")
    def forward(self, images: Any) -> Any:
        try:
            if not isinstance(images, list):
                images = [images]  # Convert a single image to a list
            detected_objects = []
            for image in images:
                try:
                    # Preprocess the image
                    img, original_shape = pre_processing(image)
                    # Verify the input shape and convert to NCHW if needed
                    if len(img.shape) != 4:  # Should be NCHW
                        raise ValueError(f"Invalid input shape: {img.shape}, expected NCHW format")
                    if img.shape[1] != 3:  # Should have 3 channels
                        # If channels are last, transpose to NCHW
                        if img.shape[3] == 3:
                            img = np.transpose(img, (0, 3, 1, 2))
                        else:
                            raise ValueError(f"Invalid number of channels: {img.shape[1]}, expected 3")
                    # Verify the final shape
                    if img.shape != (1, 3, 416, 416):
                        print(f"Warning: Reshaping input from {img.shape} to (1, 3, 416, 416)")
                        img = np.reshape(img, (1, 3, 416, 416))
                    # Run inference
                    onnx_input = {self.input_name: img}
                    onnx_output = self.onnx_model.run(None, onnx_input)
                    # Handle shape mismatch by transposing if needed
                    if len(onnx_output[0].shape) == 4:  # If in NCHW format
                        if onnx_output[0].shape[1] == 255:  # If channels first
                            onnx_output = [onnx_output[0].transpose(0, 2, 3, 1)]  # Convert to NHWC
                    # Post-process the output
                    objects = post_processing(onnx_output, self.classes, original_shape)
                    detected_objects.append(objects)
                except Exception as e:
                    print(f"Error processing image: {str(e)}")
                    detected_objects.append([])  # Add an empty list for the failed image
            return detected_objects
        except Exception as e:
            raise RuntimeError(f"Error in forward pass: {str(e)}")
class OCRTool(Tool):
    name = "ocr_scan"
    description = """
    Scan an image for text using OCR (Optical Character Recognition).
    Input Requirements:
    - Input must be a list of images, where each image is a base64-encoded string
    - Each base64 string must be properly padded (length must be a multiple of 4)
    - Images should be in RGB or BGR format (3 channels)
    - Supported image formats: JPG, PNG
    - For best results:
      * Text should be clear and well-lit
      * Image should have good contrast
      * Text should be properly oriented
      * Avoid blurry or distorted images
    Processing:
    - Uses Tesseract OCR engine
    - Automatically handles text orientation
    - Supports multiple languages (default: English)
    - Processes each image independently
    Output:
    - Returns a list of text strings, one for each input image
    - Empty string is returned if no text is detected
    - Text is returned in the order it appears in the image
    - Line breaks are preserved in the output
    Example input format:
    ["base64_encoded_image1", "base64_encoded_image2"]
    Example output format:
    [
        "This is text from image 1\nSecond line of text",  # text from image1
        "Text from image 2"  # text from image2
    ]
    """
    inputs = {
        "images": {
            "type": "any",
            "description": "List of base64-encoded images. Each image must be a valid base64 string with proper padding (length multiple of 4). Images should be clear and well-lit for best OCR results."
        }
    }
    output_type = "any"
    def forward(self, images: Any) -> Any:
        if not isinstance(images, list):
            images = [images]  # Convert a single image to a list
        scanned_text = []
        for image in images:
            image_data = base64.b64decode(image)
            img = Image.open(BytesIO(image_data))
            scanned_text.append(pytesseract.image_to_string(img))
        return scanned_text
ocr_scan_tool = OCRTool()
object_detection_tool = ObjectDetectionTool()
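# A hedged end-to-end sketch. It assumes "sample.jpg" exists (a hypothetical
# path) and that the ONNX weights are present at vlm_assets/yolov3-8.onnx.
# The tool instances are called directly rather than via .forward so that
# smolagents can run setup() lazily on the first call.
def _demo_tools(path: str = "sample.jpg") -> None:
    b64 = get_image_from_file_path(path)
    print(object_detection_tool([b64]))  # [[(label, confidence, [x, y, w, h]), ...]]
    print(ocr_scan_tool([b64]))          # ["recognized text", ...]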
# Test 3