import os
import io
import json # Add json import
import gradio as gr
import numpy as np
from PIL import Image, ImageDraw, ImageFont # Add imports for drawing
import requests
from fastapi import FastAPI, Form, UploadFile, HTTPException, File # Import Form, UploadFile, HTTPException, File
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Optional, List, Dict, Any
import traceback # For detailed error logging
import logging # Add logging import
from fastapi import Request # Import Request
import tempfile # Use tempfile for safer temporary file handling
import math # Added for distance calculation
# Import from utility files
from detection_utils import PREDEFINED_CLASSES, run_yoloworld_detection, expand_synonyms
# Import the new speech processing function
from speech_utils import process_audio
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
profile_models = {}
profile_class_maps = {} # Store the name mapping for each profile model
dynamic_model = None
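# Note: the globals above are only referenced by the commented-out Gradio demo and the
# old startup check further below; the live endpoints rely on detection_utils instead.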
# Create FastAPI app
app = FastAPI()
# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# --- Proximity Filtering Configuration ---
PROXIMITY_THRESHOLD = 50.0 # Pixels. Adjust as needed.
PREFERRED_LABEL_FOR_FILTERING = "auto rickshaw"
# --- Helper for Proximity Filtering ---
def _calculate_distance(center1: List[float], center2: List[float]) -> float:
    if len(center1) != 2 or len(center2) != 2:
        # Should not happen with valid detection data, but good to be safe
        return float('inf')
    return math.sqrt((center1[0] - center2[0])**2 + (center1[1] - center2[1])**2)
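# Worked example (illustrative): _calculate_distance([0.0, 0.0], [3.0, 4.0]) returns 5.0,
# i.e. the Euclidean distance in pixels between two detection centres.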
def filter_close_detections(
    detections: List[Dict[str, Any]],
    distance_threshold: float,
    preferred_label: str = "auto rickshaw"
) -> List[Dict[str, Any]]:
    if not detections:
        return []
    num_detections = len(detections)
    processed_mask = [False] * num_detections
    final_filtered_list = []
    for i in range(num_detections):
        if processed_mask[i]:
            continue
        # Start a new cluster
        current_cluster_indices = []
        # Use a list as a queue for BFS
        queue = []
        # Seed the queue with the current unprocessed detection
        queue.append(i)
        processed_mask[i] = True
        head = 0
        while head < len(queue):
            current_idx = queue[head]
            head += 1
            current_cluster_indices.append(current_idx)  # Add to current cluster
            # Check against all other detections
            for j in range(num_detections):
                if not processed_mask[j]:
                    # Ensure 'centre' key exists and is valid before calculating distance
                    center1 = detections[current_idx].get('centre')
                    center2 = detections[j].get('centre')
                    if not (isinstance(center1, list) and len(center1) == 2 and
                            isinstance(center2, list) and len(center2) == 2):
                        # Log warning or skip if center data is missing/malformed
                        # logger.warning(f"Skipping proximity check due to missing/malformed center data for detection indices {current_idx}, {j}")
                        continue
                    dist = _calculate_distance(center1, center2)
                    if dist < distance_threshold:
                        processed_mask[j] = True
                        queue.append(j)  # Add to queue to explore its neighbors
        # We have a cluster (all indices in current_cluster_indices).
        # Now select the best one from it based on preference and confidence.
        if not current_cluster_indices:  # Should not happen if loop started
            continue
        cluster_detections = [detections[k] for k in current_cluster_indices]
        preferred_detections_in_cluster = [
            d for d in cluster_detections if d.get('label_en', '').lower() == preferred_label.lower()
        ]
        chosen_detection = None
        if preferred_detections_in_cluster:
            # If preferred label is present, pick the one with highest confidence among them
            chosen_detection = max(preferred_detections_in_cluster, key=lambda d: d.get('confidence', 0.0))
        else:
            # Otherwise, pick the one with highest confidence from the whole cluster
            if cluster_detections:
                chosen_detection = max(cluster_detections, key=lambda d: d.get('confidence', 0.0))
        if chosen_detection:
            final_filtered_list.append(chosen_detection)
    return final_filtered_list
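# Usage sketch for the proximity filter (the detection dicts below are illustrative; real
# ones come from run_yoloworld_detection). Detections whose centres lie closer than the
# distance threshold collapse into one, keeping the preferred label when present,
# otherwise the highest-confidence detection in the cluster:
#   dets = [
#       {"label_en": "car", "confidence": 0.9, "centre": [100.0, 100.0]},
#       {"label_en": "auto rickshaw", "confidence": 0.6, "centre": [120.0, 110.0]},
#   ]
#   filter_close_detections(dets, 50.0)  # -> keeps only the "auto rickshaw" entry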
# --- Pydantic Models ---
class DetectionResponse(BaseModel):
    objects: List[Dict[str, Any]]
    count: int
    profile_used: str
    classes_used: List[str]  # Return the actual list used for detection
    status: str = "success"
    message: Optional[str] = None
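# Example response shape for /api/detect_objects (values illustrative; per-object keys
# such as "label_en", "confidence" and "centre" are produced by detection_utils):
# {
#   "objects": [{"label_en": "auto rickshaw", "confidence": 0.73, "centre": [412.5, 298.0]}],
#   "count": 1,
#   "profile_used": "casual",
#   "classes_used": ["auto rickshaw"],
#   "status": "success",
#   "message": null
# }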
# --- API Endpoints ---
# Add the new consolidated speech endpoint
@app.post("/api/speech")
async def handle_speech(
    audio: UploadFile = File(...),
    lang1: str = Form(...),
    lang2: str = Form(...)
):
    """
    Receives an audio file, transcribes it using Whisper (detecting between lang1 and lang2),
    and translates the result to the other language using googletrans.
    """
    try:
        logger.info(f"Received speech processing request. Lang1: {lang1}, Lang2: {lang2}, File: {audio.filename}")
        # Read audio file content
        audio_bytes = await audio.read()
        if not audio_bytes:
            raise HTTPException(status_code=400, detail="Received empty audio file.")
        # Process using the utility function
        result = await process_audio(audio_bytes, lang1, lang2)
        if result is None:
            raise HTTPException(status_code=500, detail="Failed to process audio.")
        if "error" in result:  # Handle specific errors returned by process_audio
            raise HTTPException(status_code=400, detail=result["error"])
        logger.info(f"Speech processing successful. Detected: {result.get('detected_language')}")
        return result
    except HTTPException as http_exc:
        # Re-raise HTTPExceptions directly
        raise http_exc
    except Exception as e:
        logger.error(f"Error in /api/speech endpoint: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(e)}")
    finally:
        # Ensure the uploaded file stream is closed
        if audio:
            await audio.close()
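# Example client call (illustrative; assumes the server is running locally on port 7860
# and that sample.wav exists):
#   with open("sample.wav", "rb") as f:
#       resp = requests.post(
#           "http://localhost:7860/api/speech",
#           files={"audio": f},
#           data={"lang1": "en", "lang2": "es"},
#       )
#   resp.json()  # detected language, transcription and translation from speech_utils.process_audio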
# Keep /api/detect_objects endpoint
@app.post("/api/detect_objects", response_model=DetectionResponse)
async def detect_objects_yolo_world(
    request: Request,  # Keep for logging if needed
    image: UploadFile = File(...),  # Revert back to File(...)
    profile: str = Form("casual"),
    extra_words: Optional[str] = Form(None),
    confidence: float = Form(0.2, ge=0.0, le=1.0),  # Lowered default confidence
    iou: float = Form(0.50, ge=0.0, le=1.0)  # Lowered default IoU (NMS threshold)
):
    profile_lower = profile.lower()
    if profile_lower not in PREDEFINED_CLASSES:
        logger.warning(f"Invalid profile '{profile}' received, defaulting to 'casual'.")
        profile_lower = "casual"  # Default to casual
    # --- Image Loading ---
    try:
        logger.info("Reading image bytes...")
        image_bytes = await image.read()
        if not image_bytes:
            raise ValueError("Received empty image file.")
        img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        logger.info(f"Image loaded: {img.width}x{img.height}")
    except Exception as e:
        logger.error(f"Image reading/loading error: {e}", exc_info=True)
        raise HTTPException(status_code=400, detail=f"Invalid or unreadable image file: {e}")
    finally:
        if image:
            await image.close()
    detections_raw = []
    final_class_list_for_response = []
    try:
        # Build class list from profile and extra_words
        initial_classes = set(PREDEFINED_CLASSES[profile_lower])
        if extra_words and extra_words.strip():
            try:
                words = json.loads(extra_words) if extra_words.strip().startswith('[') else extra_words.split(',')
                initial_classes |= set(str(w).lower().strip() for w in words if w and str(w).strip())
            except Exception as e:
                logger.warning(f"Could not parse extra_words '{extra_words}', ignoring. Error: {e}")
        # Expand with synonyms
        expanded_classes = set(expand_synonyms(list(initial_classes)))
        final_class_list_for_response = sorted(list(expanded_classes))
        # Run Roboflow detection
        detections_raw = run_yoloworld_detection(
            img,
            expanded_classes,
            confidence_threshold=confidence,
            iou_threshold=iou,
            profile=profile_lower
        )
        # Filter detections by proximity
        filtered_detections = filter_close_detections(
            detections_raw,
            PROXIMITY_THRESHOLD,
            preferred_label=PREFERRED_LABEL_FOR_FILTERING
        )
        logger.info(f"Detection complete. Found {len(detections_raw)} raw objects, {len(filtered_detections)} after proximity filtering.")
        return DetectionResponse(
            objects=filtered_detections,
            count=len(filtered_detections),
            profile_used=profile_lower,
            classes_used=final_class_list_for_response,
            status="success"
        )
    except Exception as e:
        logger.error(f"Error during detection: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Internal server error during detection: {e}")
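# Example client call (illustrative; assumes the server is running locally on port 7860
# and that street.jpg exists). extra_words may be comma-separated or a JSON list:
#   with open("street.jpg", "rb") as f:
#       resp = requests.post(
#           "http://localhost:7860/api/detect_objects",
#           files={"image": f},
#           data={"profile": "casual", "extra_words": "tuk tuk, cart",
#                 "confidence": "0.25", "iou": "0.5"},
#       )
#   resp.json()  # matches the DetectionResponse schema defined above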
# --- Gradio UI Functions ---
# # Keep detect_objects_ui function
# def detect_objects_ui(image_pil: Image.Image, profile: str, confidence: float, iou: float): # Add iou parameter
# """Gradio function for YOLO-World object detection."""
# # Create a placeholder image for errors or no input
# placeholder_img = Image.new('RGB', (640, 480), color = (150, 150, 150))
# draw_placeholder = ImageDraw.Draw(placeholder_img)
# if image_pil is None:
# draw_placeholder.text((10, 10), "Please upload an image.", fill=(255,255,255))
# return placeholder_img, "Please upload an image."
# # Check if the correct model structure is available
# profile_lower = profile.lower()
# if profile_lower not in profile_models:
# error_msg = f"Error: Model for profile '{profile_lower}' not loaded."
# logger.error(f"UI requested profile '{profile_lower}' but model not loaded.")
# # Return original image with error drawn on it
# try:
# error_img_out = image_pil.copy()
# draw_error = ImageDraw.Draw(error_img_out)
# draw_error.text((10, 10), error_msg, fill="red", font=ImageFont.load_default())
# return error_img_out, error_msg
# except Exception: # Fallback if drawing on input fails
# draw_placeholder.text((10, 10), error_msg, fill="red")
# return placeholder_img, error_msg
# model_to_use = profile_models[profile_lower]
# name_map_to_use = profile_class_maps[profile_lower]
# try:
# # Ensure image is PIL Image and in RGB
# if not isinstance(image_pil, Image.Image):
# if isinstance(image_pil, np.ndarray):
# image_pil = Image.fromarray(image_pil).convert("RGB")
# else:
# error_msg = "Error: Invalid image input type."
# draw_placeholder.text((10, 10), error_msg, fill="red")
# return placeholder_img, error_msg
# else:
# image_pil = image_pil.convert("RGB")
# # Run detection using the pre-configured model
# logger.info(f"Running YOLO-World detection (UI) with profile: {profile_lower}, confidence: {confidence}, iou: {iou}")
# results = model_to_use.predict(image_pil, conf=confidence, iou=iou, verbose=False)
# # Process results using the helper and the stored map
# original_w, original_h = image_pil.width, image_pil.height
# if results and results[0] and results[0].orig_shape:
# original_h, original_w = results[0].orig_shape[:2]
# detections = process_prediction_results(
# results, original_w, original_h, name_map_to_use
# )
# # Draw boxes on a copy of the image for Gradio output
# output_image = image_pil.copy()
# draw = ImageDraw.Draw(output_image)
# try:
# font = ImageFont.truetype("arial.ttf", 15)
# except IOError:
# font = ImageFont.load_default()
# labels = []
# if not detections:
# labels.append("No objects detected.")
# else:
# for det in detections:
# box = det['box']
# label = f"{det['class_name']}: {det['confidence']:.2f}"
# labels.append(label)
# color = "red"
# draw.rectangle(
# [(box['x1'], box['y1']), (box['x2'], box['y2'])],
# outline=color, width=3
# )
# text_position = (box['x1'], box['y1'] - 15 if box['y1'] > 15 else box['y1'])
# # Use textbbox for better background calculation
# try:
# text_bbox = draw.textbbox(text_position, label, font=font)
# # Adjust background size slightly
# bg_coords = (text_bbox[0]-1, text_bbox[1]-1, text_bbox[2]+1, text_bbox[3]+1)
# draw.rectangle(bg_coords, fill=color)
# draw.text(text_position, label, fill="white", font=font)
# except AttributeError: # Fallback for older Pillow versions without textbbox
# draw.text(text_position, label, fill=color, font=font)
# logger.info(f"UI Detection Results: {labels}")
# return output_image, "\n".join(labels)
# except Exception as e:
# error_msg = f"Error: {str(e)}"
# logger.error(f"Error in detect_objects_ui: {e}", exc_info=True)
# # Return original image with error message drawn
# try:
# error_img_out = image_pil.copy()
# draw_error = ImageDraw.Draw(error_img_out)
# draw_error.text((10, 10), error_msg, fill="red", font=ImageFont.load_default())
# return error_img_out, error_msg
# except Exception: # Fallback if drawing on input fails
# draw_placeholder.text((10, 10), error_msg, fill="red")
# return placeholder_img, error_msg
# # --- Create Gradio Interface ---
# # Add theme and descriptions
# theme = gr.themes.Soft() # Example theme
# with gr.Blocks(title="IPD-Lingual API", theme=theme) as demo:
# gr.Markdown("# IPD-Lingual: Speech & Vision API")
# gr.Markdown("An API providing speech transcription/translation and object detection capabilities.")
# with gr.Tab("Home / About"):
# gr.Markdown("## Welcome!")
# gr.Markdown(
# """
# This application provides two main functionalities accessible via API endpoints and a demonstration UI:
# 1. **Speech Processing (`/api/speech`):**
# * Accepts an audio file and two language codes (e.g., 'en', 'es').
# * Uses **OpenAI Whisper (base model)** to transcribe the audio, automatically detecting which of the two provided languages is spoken.
# * Uses the **googletrans library** (unofficial Google Translate API) to translate the transcribed text into the *other* provided language.
# * Returns the detected language, original transcription, and translation.
# 2. **Object Detection (`/api/detect_objects`):**
# * Accepts an image file, a detection profile (e.g., 'casual', 'vehicles'), optional extra object names, confidence threshold, and IoU threshold.
# * Uses **YOLO-World (yolov8l-worldv2.pt)**, a powerful zero-shot object detection model from Ultralytics.
# * It can detect objects based on predefined profiles or dynamically based on user-provided text prompts (extra words).
# * Returns a list of detected objects with their bounding boxes, class names, and confidence scores.
# Use the tabs above to try out the object detection functionality or see the API endpoint details below.
# *(Note: The speech processing functionality is currently only available via the API endpoint).*
# """
# )
# gr.Markdown("---")
# gr.Markdown("### API Endpoint Summary")
# gr.Markdown("- **POST `/api/speech`**: Transcribe and Translate audio.\n - **Type**: `multipart/form-data`\n - **Fields**: `audio` (file), `lang1` (string), `lang2` (string)")
# gr.Markdown("- **POST `/api/detect_objects`**: Detect objects using YOLO-World.\n - **Type**: `multipart/form-data`\n - **Fields**: `image` (file), `profile` (string), `extra_words` (string, optional, comma-separated or JSON list), `confidence` (float, optional), `iou` (float, optional)")
# # Keep the "Object Detection" Tab
# with gr.Tab("Object Detection Demo"):
# gr.Markdown("## Detect Objects in Image (using YOLO-World)")
# gr.Markdown("Upload an image and select a detection profile. The model will identify objects belonging to that profile.")
# with gr.Row():
# with gr.Column(scale=1): # Input column slightly smaller
# image_input = gr.Image(type="pil", label="Upload Image")
# profile_select = gr.Dropdown(
# choices=sorted(list(PREDEFINED_CLASSES.keys())),
# value="casual",
# label="Detection Profile"
# )
# confidence_slider = gr.Slider(
# minimum=0.001, maximum=1.0, value=0.01, step=0.001,
# label="Confidence Threshold"
# )
# iou_slider = gr.Slider(
# minimum=0.01, maximum=1.0, value=0.2, step=0.01,
# label="IoU Threshold (NMS)"
# )
# detect_btn = gr.Button("Detect Objects", variant="primary") # Make button primary
# with gr.Column(scale=2): # Output column larger
# image_output = gr.Image(label="Detection Result", interactive=False) # Output not interactive
# labels_output = gr.Textbox(label="Detected Objects", lines=10, interactive=False)
# # Ensure the click event is correctly wired
# detect_btn.click(
# fn=detect_objects_ui,
# inputs=[image_input, profile_select, confidence_slider, iou_slider],
# outputs=[image_output, labels_output]
# )
# # Mount both FastAPI and Gradio
# # Ensure the Gradio app uses the FastAPI instance `app`
# app = gr.mount_gradio_app(app, demo, path="/")
# # ... (rest of the file remains the same) ...
# if __name__ == "__main__":
# import uvicorn
# # Check if YOLO models initialized before starting server
# # Update check to use the new model variables
# if not profile_models or dynamic_model is None:
# logger.error(f"CRITICAL: One or more YOLO-World models ({MODEL_NAME}) failed to initialize. API endpoint /api/detect_objects might not work correctly.")
# # Decide if you want to exit or run with degraded functionality
# # exit(1) # Optional: exit if model loading fails
# else:
# logger.info("All required YOLO models initialized successfully.")
# print("Starting Uvicorn server on http://0.0.0.0:7860")
# uvicorn.run(app, host="0.0.0.0", port=7860)
if __name__ == "__main__":
    import uvicorn
    # YOLO-World model loading happens inside detection_utils, which logs its own errors
    # if a model fails to initialize. A more robust readiness check would call into
    # detection_utils directly; here we only verify that detection profiles are configured.
    if not PREDEFINED_CLASSES:  # Basic sanity check; actual model loading is in detection_utils
        logger.error("CRITICAL: PREDEFINED_CLASSES is empty. YOLO-World models might not be configured in detection_utils.")
    else:
        logger.info("Detection profiles are configured. YOLO-World model loading is handled in detection_utils.")
    print("Starting Uvicorn server on http://0.0.0.0:7860")
    uvicorn.run(app, host="0.0.0.0", port=7860)
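# Alternative launch via the uvicorn CLI (assuming this module is saved as app.py):
#   uvicorn app:app --host 0.0.0.0 --port 7860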