Spaces: Runtime error
import gradio as gr
import torch
import numpy as np
from PIL import Image
import cv2
from segment_anything import sam_model_registry, SamPredictor
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
import supervision as sv
import os
import urllib.request
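
# Rough mapping from these imports to pip packages (a sketch, not a pinned
# requirements.txt): gradio, torch, numpy, pillow, opencv-python, transformers,
# supervision, and segment-anything (Meta's SAM repo,
# https://github.com/facebookresearch/segment-anything).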

# Download SAM checkpoint if not exists
SAM_CHECKPOINT = "sam_vit_h_4b8939.pth"
SAM_CHECKPOINT_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"

if not os.path.exists(SAM_CHECKPOINT):
    print("Downloading SAM checkpoint...")
    urllib.request.urlretrieve(SAM_CHECKPOINT_URL, SAM_CHECKPOINT)
    print("SAM checkpoint downloaded!")

# Initialize models
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load Grounding DINO from Hugging Face
grounding_dino_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-tiny")
grounding_dino_model = AutoModelForZeroShotObjectDetection.from_pretrained(
    "IDEA-Research/grounding-dino-tiny"
).to(device)

# Load SAM
sam = sam_model_registry["vit_h"](checkpoint=SAM_CHECKPOINT)
sam.to(device=device)
sam_predictor = SamPredictor(sam)
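
# Note: "vit_h" is the largest SAM backbone (checkpoint ~2.4 GB); sam_model_registry
# also provides lighter "vit_l" and "vit_b" variants with their own checkpoints.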

def process_image(image, text_prompt, box_threshold, text_threshold, quality):
    """
    Process an image with Grounded SAM: Grounding DINO detects boxes for the
    text prompt, then SAM segments each detected box.
    """
    try:
        # Resize based on quality setting
        if quality == "Low":
            max_size = 800
        elif quality == "Medium":
            max_size = 1024
        else:  # High
            max_size = 1920

        # Resize image if needed (gr.Image(type="numpy") provides an RGB array)
        h, w = image.shape[:2]
        if max(h, w) > max_size:
            scale = max_size / max(h, w)
            new_h, new_w = int(h * scale), int(w * scale)
            image = cv2.resize(image, (new_w, new_h))

        # Convert to PIL Image for Grounding DINO (the array is already RGB)
        pil_image = Image.fromarray(image)

        # Grounding DINO expects lowercase queries that end with a period
        text_query = text_prompt.lower().strip()
        if not text_query.endswith("."):
            text_query += "."

        # Grounding DINO inference
        inputs = grounding_dino_processor(
            images=pil_image, text=text_query, return_tensors="pt"
        ).to(device)
        with torch.no_grad():
            outputs = grounding_dino_model(**inputs)

        # Post-process results
        results = grounding_dino_processor.post_process_grounded_object_detection(
            outputs,
            inputs.input_ids,
            box_threshold=box_threshold,
            text_threshold=text_threshold,
            target_sizes=[pil_image.size[::-1]]
        )[0]
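
        # For reference: `results` is a dict from the transformers post-processor
        # containing "scores" (confidence tensor), "labels" (the matched text
        # phrases) and "boxes" (an xyxy pixel-coordinate tensor).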
        # Extract boxes and labels
        boxes = results["boxes"].cpu().numpy()
        labels = results["labels"]

        if len(boxes) == 0:
            return image, "No objects detected. Try adjusting the thresholds or text prompt."

        # Boxes are already in xyxy pixel format, which is what SAM expects
        boxes_xyxy = boxes

        # SAM inference (set_image expects an RGB array)
        sam_predictor.set_image(image)
        masks = []
        for box in boxes_xyxy:
            mask, _, _ = sam_predictor.predict(
                box=box,
                multimask_output=False
            )
            masks.append(mask[0])

        # Visualize results
        result_image = image.copy()

        # Draw masks (blend a random color into the masked pixels)
        for mask in masks:
            color = np.random.randint(0, 255, 3).tolist()
            result_image[mask] = (result_image[mask] * 0.5 + np.array(color) * 0.5).astype(np.uint8)

        # Draw boxes and labels
        for box, label in zip(boxes_xyxy, labels):
            x1, y1, x2, y2 = map(int, box)
            color = np.random.randint(0, 255, 3).tolist()
            cv2.rectangle(result_image, (x1, y1), (x2, y2), color, 2)
            cv2.putText(result_image, label, (x1, y1 - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

        metadata = f"✅ Detected {len(boxes)} objects: {', '.join(labels)}"
        return result_image, metadata

    except Exception as e:
        return image, f"❌ Error: {str(e)}"
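
# Example (sketch only): calling process_image directly for a quick local test,
# using the same example path referenced in gr.Examples below:
#
#   img = np.array(Image.open("examples/fish1.jpg").convert("RGB"))
#   annotated, info = process_image(img, "fish", 0.35, 0.25, "Medium")
#   print(info)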

# Gradio Interface
with gr.Blocks(title="Grounded SAM") as demo:
    gr.Markdown("# 🎯 Grounded SAM - Object Detection & Segmentation")
    gr.Markdown("Upload an image and describe what you want to detect (e.g., 'fish', 'all fish', 'person').")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image", type="numpy")
            text_prompt = gr.Textbox(
                label="Text Prompt",
                placeholder="e.g., 'fish', 'person', 'car'",
                value="fish"
            )

            with gr.Accordion("Advanced Settings", open=False):
                box_threshold = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.35, step=0.05,
                    label="Box Threshold (detection confidence)"
                )
                text_threshold = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.25, step=0.05,
                    label="Text Threshold (text matching confidence)"
                )
                quality = gr.Radio(
                    choices=["Low", "Medium", "High"],
                    value="Medium",
                    label="Processing Quality"
                )

            submit_btn = gr.Button("🚀 Process Image", variant="primary")

        with gr.Column():
            output_image = gr.Image(label="Output with Masks & Boxes", type="numpy")
            output_metadata = gr.Textbox(label="Detection Metadata", lines=3)

    submit_btn.click(
        fn=process_image,
        inputs=[input_image, text_prompt, box_threshold, text_threshold, quality],
        outputs=[output_image, output_metadata]
    )

    gr.Examples(
        examples=[
            ["examples/fish1.jpg", "fish", 0.35, 0.25, "Medium"],
            ["examples/fish2.jpg", "all fish", 0.35, 0.25, "Medium"],
        ],
        inputs=[input_image, text_prompt, box_threshold, text_threshold, quality],
    )

demo.launch()