import io
import json
import logging
import math  # Euclidean distance for proximity filtering

import gradio as gr  # Used only by the disabled Gradio demo below
import numpy as np   # Used only by the disabled Gradio demo below
from PIL import Image, ImageDraw, ImageFont  # ImageDraw/ImageFont used only by the disabled demo
from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Any, Dict, List, Optional

# Shared utilities
from detection_utils import PREDEFINED_CLASSES, run_yoloworld_detection, expand_synonyms
from speech_utils import process_audio

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Legacy model registries; referenced only by the disabled Gradio demo below.
profile_models = {}
profile_class_maps = {}  # Per-profile class-name mapping
dynamic_model = None

# Create FastAPI app
app = FastAPI()

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# --- Proximity Filtering Configuration ---
PROXIMITY_THRESHOLD = 50.0  # Max centre-to-centre distance, in pixels, for two detections to merge; tune per input resolution
PREFERRED_LABEL_FOR_FILTERING = "auto rickshaw"  # Label kept preferentially when a cluster contains it

# --- Helper for Proximity Filtering ---
def _calculate_distance(center1: List[float], center2: List[float]) -> float:
    """Euclidean distance between two [x, y] centre points."""
    if len(center1) != 2 or len(center2) != 2:
        # Malformed centre data; treat as infinitely far apart so it never clusters.
        return float('inf')
    return math.hypot(center1[0] - center2[0], center1[1] - center2[1])

def filter_close_detections(
    detections: List[Dict[str, Any]],
    distance_threshold: float,
    preferred_label: str = "auto rickshaw"
) -> List[Dict[str, Any]]:
    """Collapse detections whose centres lie within `distance_threshold` pixels of
    one another (single-link clustering via BFS), keeping one detection per cluster:
    the highest-confidence instance of `preferred_label` if present, otherwise the
    highest-confidence detection overall."""
    if not detections:
        return []

    num_detections = len(detections)
    processed_mask = [False] * num_detections
    final_filtered_list = []

    for i in range(num_detections):
        if processed_mask[i]:
            continue

        # Start a new cluster: BFS over detections whose centres fall within
        # the threshold, seeded with the current unprocessed detection.
        current_cluster_indices = []
        queue = [i]
        processed_mask[i] = True

        head = 0
        while head < len(queue):
            current_idx = queue[head]
            head += 1
            current_cluster_indices.append(current_idx) # Add to current cluster

            # Check against all other detections
            for j in range(num_detections):
                if not processed_mask[j]:
                    # Validate 'centre' data before measuring distance.
                    center1 = detections[current_idx].get('centre')
                    center2 = detections[j].get('centre')

                    if not (isinstance(center1, list) and len(center1) == 2 and
                            isinstance(center2, list) and len(center2) == 2):
                        logger.warning(f"Skipping proximity check: missing or malformed 'centre' data for detection indices {current_idx}, {j}")
                        continue
                    
                    dist = _calculate_distance(center1, center2)
                    if dist < distance_threshold:
                        processed_mask[j] = True
                        queue.append(j) # Add to queue to explore its neighbors
        
        # A complete cluster is now in current_cluster_indices; select the best
        # detection from it based on label preference and confidence.
        if not current_cluster_indices:  # Defensive; the BFS always adds at least one index
            continue

        cluster_detections = [detections[k] for k in current_cluster_indices]
        
        preferred_detections_in_cluster = [
            d for d in cluster_detections if d.get('label_en', '').lower() == preferred_label.lower()
        ]

        if preferred_detections_in_cluster:
            # Preferred label present: keep its highest-confidence instance.
            chosen_detection = max(preferred_detections_in_cluster, key=lambda d: d.get('confidence', 0.0))
        else:
            # Otherwise keep the highest-confidence detection in the cluster
            # (cluster_detections is non-empty; current_cluster_indices was checked above).
            chosen_detection = max(cluster_detections, key=lambda d: d.get('confidence', 0.0))
        
        if chosen_detection:
            final_filtered_list.append(chosen_detection)

    return final_filtered_list
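
# Illustrative sketch of the preference rule (hypothetical data): two detections
# whose centres are ~9 px apart form a single cluster, and the preferred label
# wins even at lower confidence:
#
#   dets = [
#       {"label_en": "car",           "confidence": 0.90, "centre": [100.0, 100.0]},
#       {"label_en": "auto rickshaw", "confidence": 0.60, "centre": [105.0, 108.0]},
#   ]
#   filter_close_detections(dets, PROXIMITY_THRESHOLD)
#   # -> keeps only the "auto rickshaw" detection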

# --- Pydantic Models ---
class DetectionResponse(BaseModel):
    objects: List[Dict[str, Any]]
    count: int
    profile_used: str
    classes_used: List[str] # Return the actual list used for detection
    status: str = "success"
    message: Optional[str] = None
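
# Example response body (sketch; the per-object fields mirror what detection_utils
# emits, e.g. 'label_en', 'confidence', 'centre'):
#   {
#     "objects": [{"label_en": "auto rickshaw", "confidence": 0.61, "centre": [105.0, 108.0]}],
#     "count": 1,
#     "profile_used": "casual",
#     "classes_used": ["auto rickshaw", "car"],
#     "status": "success",
#     "message": null
#   }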


# --- API Endpoints ---

# Consolidated speech endpoint: transcription + translation in one call
@app.post("/api/speech")
async def handle_speech(
    audio: UploadFile = File(...),
    lang1: str = Form(...),
    lang2: str = Form(...)
):
    """
    Receives an audio file, transcribes it using Whisper (detecting between lang1 and lang2),
    and translates the result to the other language using googletrans.
    """
    try:
        logger.info(f"Received speech processing request. Lang1: {lang1}, Lang2: {lang2}, File: {audio.filename}")
        # Read audio file content
        audio_bytes = await audio.read()
        if not audio_bytes:
            raise HTTPException(status_code=400, detail="Received empty audio file.")

        # Process using the utility function
        result = await process_audio(audio_bytes, lang1, lang2)

        if result is None:
            raise HTTPException(status_code=500, detail="Failed to process audio.")
        if "error" in result: # Handle specific errors returned by process_audio
             raise HTTPException(status_code=400, detail=result["error"])


        logger.info(f"Speech processing successful. Detected: {result.get('detected_language')}")
        return result

    except HTTPException as http_exc:
        # Re-raise HTTPExceptions directly
        raise http_exc
    except Exception as e:
        logger.error(f"Error in /api/speech endpoint: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(e)}")
    finally:
        # Ensure the uploaded file stream is closed
        if audio:
            await audio.close()
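
# Example request (sketch; host/port and the audio file are illustrative):
#   curl -X POST http://localhost:7860/api/speech \
#        -F "audio=@sample.wav" -F "lang1=en" -F "lang2=hi"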


# Object detection endpoint (YOLO-World, see detection_utils)
@app.post("/api/detect_objects", response_model=DetectionResponse)
async def detect_objects_yolo_world(
    request: Request, # Keep for logging if needed
    image: UploadFile = File(...), # Revert back to File(...)
    profile: str = Form("casual"),
    extra_words: Optional[str] = Form(None),
    confidence: float = Form(0.2, ge=0.0, le=1.0), # Lowered default confidence
    iou: float = Form(0.50, ge=0.0, le=1.0) # Lowered default IoU (NMS threshold)
):
    profile_lower = profile.lower()
    if profile_lower not in PREDEFINED_CLASSES:
        logger.warning(f"Invalid profile '{profile}' received, defaulting to 'casual'.")
        profile_lower = "casual" # Default to casual

    # --- Image Loading ---
    try:
        logger.info("Reading image bytes...")
        image_bytes = await image.read()
        if not image_bytes:
            raise ValueError("Received empty image file.")
        img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        logger.info(f"Image loaded: {img.width}x{img.height}")
    except Exception as e:
        logger.error(f"Image reading/loading error: {e}", exc_info=True)
        raise HTTPException(status_code=400, detail=f"Invalid or unreadable image file: {e}")
    finally:
        if image: await image.close()

    detections_raw = []
    final_class_list_for_response = []
    try:
        # Build class list from profile and extra_words
        initial_classes = set(PREDEFINED_CLASSES[profile_lower])
        if extra_words and extra_words.strip():
            try:
                # Accept either a JSON list ('["cat", "dog"]') or a comma-separated string.
                words = json.loads(extra_words) if extra_words.strip().startswith('[') else extra_words.split(',')
                initial_classes |= set(str(w).lower().strip() for w in words if w and str(w).strip())
            except Exception as e:
                logger.warning(f"Could not parse extra_words '{extra_words}', ignoring. Error: {e}")
        # Expand with synonyms
        expanded_classes = set(expand_synonyms(list(initial_classes)))
        final_class_list_for_response = sorted(list(expanded_classes))

        # Run YOLO-World detection (model loading and configuration live in detection_utils)
        detections_raw = run_yoloworld_detection(
            img,
            expanded_classes,
            confidence_threshold=confidence,
            iou_threshold=iou,
            profile=profile_lower
        )

        # Collapse near-duplicate detections by centre proximity
        filtered_detections = filter_close_detections(
            detections_raw,
            PROXIMITY_THRESHOLD,
            preferred_label=PREFERRED_LABEL_FOR_FILTERING
        )

        logger.info(f"Detection complete. Found {len(detections_raw)} raw objects, {len(filtered_detections)} after proximity filtering.")
        return DetectionResponse(
            objects=filtered_detections,
            count=len(filtered_detections),
            profile_used=profile_lower,
            classes_used=final_class_list_for_response,
            status="success"
        )
    except Exception as e:
        logger.error(f"Error during detection: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Internal server error during detection: {e}")
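
# Example request (sketch; the image path and extra words are placeholders):
#   curl -X POST http://localhost:7860/api/detect_objects \
#        -F "image=@street.jpg" -F "profile=casual" \
#        -F "extra_words=rickshaw,scooter" -F "confidence=0.2" -F "iou=0.5"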


# --- Gradio UI (disabled) ---
# The demo below is kept for reference; it is not currently mounted on the app.

# def detect_objects_ui(image_pil: Image.Image, profile: str, confidence: float, iou: float):
#     """Gradio function for YOLO-World object detection."""
#     # Create a placeholder image for errors or no input
#     placeholder_img = Image.new('RGB', (640, 480), color = (150, 150, 150))
#     draw_placeholder = ImageDraw.Draw(placeholder_img)

#     if image_pil is None:
#         draw_placeholder.text((10, 10), "Please upload an image.", fill=(255,255,255))
#         return placeholder_img, "Please upload an image."

#     # Check if the correct model structure is available
#     profile_lower = profile.lower()
#     if profile_lower not in profile_models:
#         error_msg = f"Error: Model for profile '{profile_lower}' not loaded."
#         logger.error(f"UI requested profile '{profile_lower}' but model not loaded.")
#         # Return original image with error drawn on it
#         try:
#             error_img_out = image_pil.copy()
#             draw_error = ImageDraw.Draw(error_img_out)
#             draw_error.text((10, 10), error_msg, fill="red", font=ImageFont.load_default())
#             return error_img_out, error_msg
#         except Exception: # Fallback if drawing on input fails
#              draw_placeholder.text((10, 10), error_msg, fill="red")
#              return placeholder_img, error_msg


#     model_to_use = profile_models[profile_lower]
#     name_map_to_use = profile_class_maps[profile_lower]

#     try:
#         # Ensure image is PIL Image and in RGB
#         if not isinstance(image_pil, Image.Image):
#              if isinstance(image_pil, np.ndarray):
#                  image_pil = Image.fromarray(image_pil).convert("RGB")
#              else:
#                  error_msg = "Error: Invalid image input type."
#                  draw_placeholder.text((10, 10), error_msg, fill="red")
#                  return placeholder_img, error_msg
#         else:
#             image_pil = image_pil.convert("RGB")

#         # Run detection using the pre-configured model
#         logger.info(f"Running YOLO-World detection (UI) with profile: {profile_lower}, confidence: {confidence}, iou: {iou}")
#         results = model_to_use.predict(image_pil, conf=confidence, iou=iou, verbose=False)

#         # Process results using the helper and the stored map
#         original_w, original_h = image_pil.width, image_pil.height
#         if results and results[0] and results[0].orig_shape:
#              original_h, original_w = results[0].orig_shape[:2]
#         detections = process_prediction_results(
#             results, original_w, original_h, name_map_to_use
#         )

#         # Draw boxes on a copy of the image for Gradio output
#         output_image = image_pil.copy()
#         draw = ImageDraw.Draw(output_image)
#         try:
#             font = ImageFont.truetype("arial.ttf", 15)
#         except IOError:
#             font = ImageFont.load_default()

#         labels = []
#         if not detections:
#             labels.append("No objects detected.")
#         else:
#             for det in detections:
#                 box = det['box']
#                 label = f"{det['class_name']}: {det['confidence']:.2f}"
#                 labels.append(label)
#                 color = "red"
#                 draw.rectangle(
#                     [(box['x1'], box['y1']), (box['x2'], box['y2'])],
#                     outline=color, width=3
#                 )
#                 text_position = (box['x1'], box['y1'] - 15 if box['y1'] > 15 else box['y1'])
#                 # Use textbbox for better background calculation
#                 try:
#                     text_bbox = draw.textbbox(text_position, label, font=font)
#                     # Adjust background size slightly
#                     bg_coords = (text_bbox[0]-1, text_bbox[1]-1, text_bbox[2]+1, text_bbox[3]+1)
#                     draw.rectangle(bg_coords, fill=color)
#                     draw.text(text_position, label, fill="white", font=font)
#                 except AttributeError: # Fallback for older Pillow versions without textbbox
#                     draw.text(text_position, label, fill=color, font=font)


#         logger.info(f"UI Detection Results: {labels}")
#         return output_image, "\n".join(labels)

#     except Exception as e:
#         error_msg = f"Error: {str(e)}"
#         logger.error(f"Error in detect_objects_ui: {e}", exc_info=True)
#         # Return original image with error message drawn
#         try:
#             error_img_out = image_pil.copy()
#             draw_error = ImageDraw.Draw(error_img_out)
#             draw_error.text((10, 10), error_msg, fill="red", font=ImageFont.load_default())
#             return error_img_out, error_msg
#         except Exception: # Fallback if drawing on input fails
#              draw_placeholder.text((10, 10), error_msg, fill="red")
#              return placeholder_img, error_msg


# # --- Create Gradio Interface ---
# # Add theme and descriptions
# theme = gr.themes.Soft() # Example theme

# with gr.Blocks(title="IPD-Lingual API", theme=theme) as demo:
#     gr.Markdown("# IPD-Lingual: Speech & Vision API")
#     gr.Markdown("An API providing speech transcription/translation and object detection capabilities.")

#     with gr.Tab("Home / About"):
#         gr.Markdown("## Welcome!")
#         gr.Markdown(
#             """
#             This application provides two main functionalities accessible via API endpoints and a demonstration UI:

#             1.  **Speech Processing (`/api/speech`):**
#                 *   Accepts an audio file and two language codes (e.g., 'en', 'es').
#                 *   Uses **OpenAI Whisper (base model)** to transcribe the audio, automatically detecting which of the two provided languages is spoken.
#                 *   Uses the **googletrans library** (unofficial Google Translate API) to translate the transcribed text into the *other* provided language.
#                 *   Returns the detected language, original transcription, and translation.

#             2.  **Object Detection (`/api/detect_objects`):**
#                 *   Accepts an image file, a detection profile (e.g., 'casual', 'vehicles'), optional extra object names, confidence threshold, and IoU threshold.
#                 *   Uses **YOLO-World (yolov8l-worldv2.pt)**, a powerful zero-shot object detection model from Ultralytics.
#                 *   It can detect objects based on predefined profiles or dynamically based on user-provided text prompts (extra words).
#                 *   Returns a list of detected objects with their bounding boxes, class names, and confidence scores.

#             Use the tabs above to try out the object detection functionality or see the API endpoint details below.
#             *(Note: The speech processing functionality is currently only available via the API endpoint).*
#             """
#         )
#         gr.Markdown("---")
#         gr.Markdown("### API Endpoint Summary")
#         gr.Markdown("- **POST `/api/speech`**: Transcribe and Translate audio.\n  - **Type**: `multipart/form-data`\n  - **Fields**: `audio` (file), `lang1` (string), `lang2` (string)")
#         gr.Markdown("- **POST `/api/detect_objects`**: Detect objects using YOLO-World.\n  - **Type**: `multipart/form-data`\n  - **Fields**: `image` (file), `profile` (string), `extra_words` (string, optional, comma-separated or JSON list), `confidence` (float, optional), `iou` (float, optional)")


#     # Keep the "Object Detection" Tab
#     with gr.Tab("Object Detection Demo"):
#         gr.Markdown("## Detect Objects in Image (using YOLO-World)")
#         gr.Markdown("Upload an image and select a detection profile. The model will identify objects belonging to that profile.")
#         with gr.Row():
#             with gr.Column(scale=1): # Input column slightly smaller
#                 image_input = gr.Image(type="pil", label="Upload Image")
#                 profile_select = gr.Dropdown(
#                     choices=sorted(list(PREDEFINED_CLASSES.keys())),
#                     value="casual",
#                     label="Detection Profile"
#                 )
#                 confidence_slider = gr.Slider(
#                     minimum=0.001, maximum=1.0, value=0.01, step=0.001,
#                     label="Confidence Threshold"
#                 )
#                 iou_slider = gr.Slider(
#                     minimum=0.01, maximum=1.0, value=0.2, step=0.01,
#                     label="IoU Threshold (NMS)"
#                 )
#                 detect_btn = gr.Button("Detect Objects", variant="primary") # Make button primary
#             with gr.Column(scale=2): # Output column larger
#                 image_output = gr.Image(label="Detection Result", interactive=False) # Output not interactive
#                 labels_output = gr.Textbox(label="Detected Objects", lines=10, interactive=False)

#         # Ensure the click event is correctly wired
#         detect_btn.click(
#             fn=detect_objects_ui,
#             inputs=[image_input, profile_select, confidence_slider, iou_slider],
#             outputs=[image_output, labels_output]
#         )

# # Mount both FastAPI and Gradio
# # Ensure the Gradio app uses the FastAPI instance `app`
# app = gr.mount_gradio_app(app, demo, path="/")

# # ... (rest of the file remains the same) ...

# if __name__ == "__main__":
#     import uvicorn
#     # Check if YOLO models initialized before starting server
#     # Update check to use the new model variables
#     if not profile_models or dynamic_model is None:
#         logger.error(f"CRITICAL: One or more YOLO-World models ({MODEL_NAME}) failed to initialize. API endpoint /api/detect_objects might not work correctly.")
#         # Decide if you want to exit or run with degraded functionality
#         # exit(1) # Optional: exit if model loading fails
#     else:
#         logger.info("All required YOLO models initialized successfully.")

#     print("Starting Uvicorn server on http://0.0.0.0:7860")
#     uvicorn.run(app, host="0.0.0.0", port=7860)

if __name__ == "__main__":
    import uvicorn
    # Model loading happens in detection_utils, which logs its own errors if the
    # YOLO-World models fail to initialise; here we only sanity-check that the
    # detection profiles are configured.
    if not PREDEFINED_CLASSES:
        logger.error("CRITICAL: PREDEFINED_CLASSES is empty; detection profiles are not configured in detection_utils.")
    else:
        logger.info("Detection profiles configured; YOLO-World model loading is handled in detection_utils.")

    print("Starting Uvicorn server on http://0.0.0.0:7860")
    uvicorn.run(app, host="0.0.0.0", port=7860)