Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import ibbi | |
| import numpy as np | |
| from PIL import Image, ImageDraw, ImageFont | |
| import matplotlib.pyplot as plt | |
| import io | |
# --- Model Management ---
# Per-task sub-registries: architecture key -> ibbi model name.
# Keys are listed in the order the UI dropdown shows them.
_SINGLE_CLASS_MODELS = dict(
    yolov10="yolov10x_bb_detect_model",
    yolov11="yolov11x_bb_detect_model",
    yolov9="yolov9e_bb_detect_model",
    yolov8="yolov8x_bb_detect_model",
    rtdetr="rtdetrx_bb_detect_model",
)
_MULTI_CLASS_MODELS = dict(
    yolov10="yolov10x_bb_multi_class_detect_model",
    yolov11="yolov11x_bb_multi_class_detect_model",
    yolov9="yolov9e_bb_multi_class_detect_model",
    yolov8="yolov8x_bb_multi_class_detect_model",
    rtdetr="rtdetrx_bb_multi_class_detect_model",
)
_ZERO_SHOT_MODELS = dict(
    grounding_dino="grounding_dino_detect_model",
)

# Task label (exactly as shown in the UI) -> {architecture key -> ibbi model name}.
MODEL_REGISTRY = {
    "Single-Class Detection": _SINGLE_CLASS_MODELS,
    "Multi-Class Detection": _MULTI_CLASS_MODELS,
    "Zero-Shot Detection": _ZERO_SHOT_MODELS,
}
# --- Model Management (no caching) ---
# Models are deliberately NOT cached: a fresh instance is created for every
# request so that state one run leaves on a model cannot affect the next run.
def get_model(task, architecture):
    """Create a fresh ibbi model instance for the selected task and architecture.

    Args:
        task: Task label; expected to be a key of MODEL_REGISTRY.
        architecture: Architecture key within that task. Ignored for
            "Zero-Shot Detection", which always uses "grounding_dino".

    Returns:
        Whatever ``ibbi.create_model(model_name, pretrained=True)`` returns.

    Raises:
        gr.Error: if the task/architecture pair is not registered, or if model
            creation fails. The underlying exception is chained as the cause.
    """
    # Zero-Shot has a single backing model; the (hidden) dropdown value is irrelevant.
    if task == "Zero-Shot Detection":
        architecture = "grounding_dino"

    # Keep the lookup in its own try so a KeyError raised *inside* create_model
    # is not misreported as a registry-lookup failure.
    try:
        model_name = MODEL_REGISTRY[task][architecture]
    except (KeyError, TypeError) as e:  # TypeError: unhashable dropdown value
        raise gr.Error(f"Model lookup failed. Task: '{task}', Arch: '{architecture}'. Error: {e}") from e

    print(f"Loading a fresh model instance: {model_name}")
    try:
        model = ibbi.create_model(model_name, pretrained=True)
    except Exception as e:  # UI boundary: surface any load failure to the user
        raise gr.Error(f"Failed to load model. Please check the model name and your connection. Error: {e}") from e
    print("Model loaded successfully!")
    return model
# --- Visualization and Drawing Functions ---
def _draw_labeled_box(draw, coords, label_text, font, color):
    """Draw one box outline plus a filled label tab holding ``label_text`` in white.

    Args:
        draw: An ``ImageDraw.Draw`` bound to the target image.
        coords: Box as ``[x0, y0, x1, y1]``.
        label_text: Text for the tab.
        font: Font used for the label.
        color: Outline and tab fill colour.

    The tab sits directly above the box's top-left corner. When there is not
    enough room above it (the box touches the image top), the tab is drawn
    just inside the box instead, so the full label stays visible.
    """
    x0, y0 = coords[0], coords[1]
    draw.rectangle(coords, outline=color, width=3)
    # Measure from the same origin draw.text() uses. The height then includes the
    # font's top offset, and the glyphs cannot spill below the filled tab.
    _, _, text_w, text_h = draw.textbbox((0, 0), label_text, font=font)
    label_y = y0 - text_h if y0 >= text_h else y0
    draw.rectangle((x0, label_y, x0 + text_w, label_y + text_h), fill=color)
    draw.text((x0, label_y), label_text, fill="white", font=font)


def draw_yolo_predictions(image, results, font, color="red"):
    """Return an RGB copy of ``image`` with YOLO-style detections drawn on it.

    Args:
        image: PIL image that the detections refer to (not modified).
        results: Sequence whose first element exposes ``.boxes`` (each box with
            ``xyxy``, ``conf`` and ``cls``) and a ``.names`` id -> label mapping,
            as read below. Empty or box-less results yield an unannotated copy.
        font: Font for the labels.
        color: Box and label-tab colour.

    Returns:
        A new PIL image in RGB mode.
    """
    # convert() always returns a new image; RGB guarantees named colours render
    # as colours even if the input is grayscale or palette-based.
    img_copy = image.convert("RGB")
    if not results or not results[0].boxes:
        return img_copy

    draw = ImageDraw.Draw(img_copy)
    first_result = results[0]
    class_names = first_result.names
    for box in first_result.boxes:
        if box.cls.numel() == 0 or box.conf.numel() == 0:
            continue  # skip malformed entries without class/confidence
        coords = box.xyxy[0].tolist()
        score = box.conf[0].item()
        class_id = int(box.cls[0].item())
        label_text = f"{class_names.get(class_id, f'Unknown-{class_id}')}: {score:.2f}"
        _draw_labeled_box(draw, coords, label_text, font, color)
    return img_copy


def draw_dino_predictions(image, results, font, color="green"):
    """Return an RGB copy of ``image`` with Grounding DINO detections drawn on it.

    Args:
        image: PIL image that the detections refer to (not modified).
        results: Mapping with parallel "boxes", "scores" and "text_labels"
            sequences. Missing keys are treated as empty; an empty or falsy
            mapping yields an unannotated copy.
        font: Font for the labels.
        color: Box and label-tab colour.

    Returns:
        A new PIL image in RGB mode.
    """
    img_copy = image.convert("RGB")
    if not results:
        return img_copy

    draw = ImageDraw.Draw(img_copy)
    detections = zip(
        results.get("boxes", []),
        results.get("scores", []),
        results.get("text_labels", []),
    )
    for box, score, label in detections:
        coords = box.tolist()
        # float() accepts 0-d tensors, numpy scalars and plain numbers alike.
        label_text = f"{label}: {float(score):.2f}"
        _draw_labeled_box(draw, coords, label_text, font, color)
    return img_copy
def _embedding_to_matrix(embedding):
    """Convert an embedding to a 2-D float numpy array for ``imshow``, or return None.

    Accepted inputs are tensors (anything with ``.cpu()``, handled via
    ``detach().cpu()``) and numpy arrays. Anything else, and empty inputs,
    yield None.

    Shape handling:
        * 0-D -> (1, 1)
        * 1-D (D,) -> (1, D)
        * 2-D kept as-is
        * N-D (N > 2) -> all leading dimensions flattened into rows,
          keeping the last dimension as columns
    """
    if embedding is None:
        return None
    if hasattr(embedding, "cpu"):
        # .float() first: numpy has no bfloat16 dtype, so .numpy() would fail on bf16.
        matrix = embedding.detach().cpu().float().numpy()
    elif isinstance(embedding, np.ndarray):
        matrix = embedding.astype(float)
    else:
        return None
    if matrix.size == 0:
        return None
    if matrix.ndim == 0:
        return matrix.reshape(1, 1)
    # imshow only accepts 2-D arrays (or RGB/RGBA 3-D), so stack leading dims as rows.
    return matrix.reshape(-1, matrix.shape[-1])


def visualize_embedding(embedding):
    """Render a feature embedding as a heatmap and return it as a PIL image.

    Args:
        embedding: Tensor or numpy array of any rank (see ``_embedding_to_matrix``),
            or None.

    Returns:
        A PIL PNG image of the heatmap, or None when the input is None, empty,
        or not a tensor/array.
    """
    matrix = _embedding_to_matrix(embedding)
    if matrix is None:
        return None

    # Use an explicit Figure instead of pyplot. pyplot keeps global state that is
    # not thread-safe, and web-server handlers may run on worker threads.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(10, 2))
    ax = fig.subplots()
    ax.imshow(matrix, cmap='viridis', aspect='auto')
    ax.set_title("Feature Embedding Visualization")
    ax.set_xlabel("Feature Dimension")
    ax.set_yticks([])
    fig.tight_layout()

    buf = io.BytesIO()
    fig.savefig(buf, format='png')
    buf.seek(0)
    rendered = Image.open(buf)
    rendered.load()  # decode eagerly so the result does not depend on buf later
    return rendered
# --- Main Processing Function ---
def _load_label_font(size):
    """Return a font of roughly ``size`` pixels for drawing detection labels.

    Tries Arial first and falls back to Pillow's bundled default font.
    """
    try:
        return ImageFont.truetype("arial.ttf", size)
    except OSError:
        pass  # Arial may not be installed (common on Linux hosts)
    try:
        return ImageFont.load_default(size=size)
    except TypeError:
        # Older Pillow releases' load_default() takes no size argument.
        return ImageFont.load_default()


def comprehensive_analysis(image, task, architecture, text_prompt, box_threshold, text_threshold):
    """Run detection and feature extraction on one image and build the four UI outputs.

    Args:
        image: Uploaded PIL image, or None.
        task: "Single-Class Detection", "Multi-Class Detection" or any other
            value, which is treated as zero-shot.
        architecture: Architecture key from the dropdown (ignored for zero-shot).
        text_prompt: Prompt for zero-shot detection.
        box_threshold: Box threshold forwarded to the zero-shot predictor.
        text_threshold: Text threshold forwarded to the zero-shot predictor.

    Returns:
        Tuple ``(annotated_image, model_info, classes_info, embedding_plot)``.

    Raises:
        gr.Error: if no image was given, if a zero-shot run has an empty prompt,
            or if ``get_model`` cannot provide a model.
    """
    if image is None:
        raise gr.Error("Please upload an image first!")

    is_zero_shot = task not in ["Single-Class Detection", "Multi-Class Detection"]
    # Validate cheap inputs *before* paying for a (fresh, uncached) model load.
    if is_zero_shot and not (text_prompt and text_prompt.strip()):
        raise gr.Error("Please provide a text prompt for Zero-Shot Detection.")

    # Scale label text with the image so labels stay readable on large photos.
    font = _load_label_font(max(15, int(image.width * 0.04)))

    # A fresh model per request avoids state leaking between runs.
    model = get_model(task, architecture)

    if is_zero_shot:
        results = model.predict(
            image,
            text_prompt=text_prompt,
            box_threshold=box_threshold,
            text_threshold=text_threshold,
        )
        annotated_image = draw_dino_predictions(image, results, font=font)
        features = model.extract_features(image, text_prompt=text_prompt)
        model_info = f"Architecture: GROUNDING_DINO\nTask: {task}\nDevice: {model.device}\nHF Model ID: {model.model.config._name_or_path}"
        classes_info = f"Prompt: '{text_prompt}'"
    else:
        results = model.predict(image)
        annotated_image = draw_yolo_predictions(image, results, font=font)
        features = model.extract_features(image)
        model_info = f"Architecture: {architecture.upper()}\nTask: {task}\nDevice: {model.device}"
        classes_info = f"Classes: {model.get_classes()}"

    # Some models return a dict of outputs; only its hidden state is plotted.
    if isinstance(features, dict):
        features = features.get('last_hidden_state')
    embedding_plot = visualize_embedding(features)

    return annotated_image, model_info, classes_info, embedding_plot
# --- Gradio UI ---
def update_ui_for_task(task):
    """Return component updates that fit the input panel to the chosen task.

    Detection tasks show the architecture dropdown, reset to its first
    registered entry, and hide and clear the zero-shot prompt controls.
    Any other task (zero-shot) hides the dropdown and reveals the prompt
    textbox and both threshold sliders.
    """
    choices = list(MODEL_REGISTRY[task].keys())
    is_zero_shot = task not in ["Single-Class Detection", "Multi-Class Detection"]

    if is_zero_shot:
        arch_update = gr.update(choices=choices, value=choices[0], visible=False)
        prompt_update = gr.update(visible=True)
    else:
        arch_update = gr.update(choices=choices, value=choices[0], visible=True, interactive=True)
        prompt_update = gr.update(visible=False, value="")

    return {
        arch_dropdown: arch_update,
        prompt_textbox: prompt_update,
        box_threshold_slider: gr.update(visible=is_zero_shot),
        text_threshold_slider: gr.update(visible=is_zero_shot),
    }
# Build the Gradio interface: an inputs column, a results column, and example images.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# IBBI - Intelligent Bark Beetle Identifier")
    gr.Markdown("An all-in-one interface to analyze images using the `ibbi` library. Upload an image, select a task and model, and view the complete analysis.")

    # Derive the initial task/architecture widgets from MODEL_REGISTRY so they
    # cannot drift from it. Like update_ui_for_task, default to the first
    # registered architecture of the default task.
    default_task = "Single-Class Detection"
    default_arch_choices = list(MODEL_REGISTRY[default_task].keys())

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 1. Inputs")
            image_input = gr.Image(type="pil", label="Upload Image")
            task_selector = gr.Radio(
                choices=list(MODEL_REGISTRY.keys()),
                value=default_task,
                label="Choose Task",
            )
            arch_dropdown = gr.Dropdown(
                choices=default_arch_choices,
                value=default_arch_choices[0],
                label="Choose Model Architecture",
            )
            # Zero-shot-only controls start hidden; update_ui_for_task reveals them.
            prompt_textbox = gr.Textbox(
                label="Enter Text Prompt (for Zero-Shot)",
                placeholder="e.g., insect . circle . metal ball",
                visible=False,
            )
            box_threshold_slider = gr.Slider(
                minimum=0.05, maximum=1.0, value=0.25, step=0.05,
                label="Box Threshold (Zero-Shot)",
                info="Lower to detect more objects, even with low confidence.",
                visible=False,
            )
            text_threshold_slider = gr.Slider(
                minimum=0.05, maximum=1.0, value=0.25, step=0.05,
                label="Text Threshold (Zero-Shot)",
                info="Lower to allow more labels to match detected objects.",
                visible=False,
            )
            analyze_btn = gr.Button("Analyze Image", variant="primary")

        with gr.Column(scale=2):
            gr.Markdown("### 2. Analysis Results")
            output_image = gr.Image(label="Annotated Image")
            with gr.Accordion("Details", open=True):
                model_details_output = gr.Textbox(label="Model Details", lines=4)
                classes_output = gr.Textbox(label="Classes / Prompt")
                embedding_output = gr.Image(label="Feature Embedding Visualization")

    # --- Event Handlers ---
    task_selector.change(
        fn=update_ui_for_task,
        inputs=task_selector,
        outputs=[arch_dropdown, prompt_textbox, box_threshold_slider, text_threshold_slider],
    )
    analyze_btn.click(
        fn=comprehensive_analysis,
        inputs=[image_input, task_selector, arch_dropdown, prompt_textbox, box_threshold_slider, text_threshold_slider],
        outputs=[output_image, model_details_output, classes_output, embedding_output],
    )

    gr.Markdown("---")
    gr.Markdown("### 3. Or Start with an Example Image")
    # Each example row supplies the single image_input value.
    example_list = [
        ["example_images/example1.jpg"],
        ["example_images/example2.jpg"],
        ["example_images/example3.jpg"],
        ["example_images/example4.jpg"],
        ["example_images/example5.jpg"],
    ]
    gr.Examples(
        examples=example_list,
        inputs=image_input,
        label="Select an image to load it",
    )
if __name__ == "__main__":
    # show_error surfaces raised exceptions (including gr.Error) in the browser;
    # debug keeps the launch call blocking.
    demo.launch(
        share=True,
        inline=True,
        debug=True,
        show_error=True,
    )