import random

import torch
from torchvision import transforms
from torchvision.models.detection import (
    maskrcnn_resnet50_fpn,
    MaskRCNN_ResNet50_FPN_Weights,
)
from PIL import Image
import gradio as gr

# Load the pretrained Mask R-CNN model. The weights enum replaces the
# deprecated pretrained=True flag (torchvision >= 0.13).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
weights = MaskRCNN_ResNet50_FPN_Weights.DEFAULT
model = maskrcnn_resnet50_fpn(weights=weights).to(device).eval()
# Function to preprocess the image
def preprocess_image(image_path):
    # Open and convert to RGB
    image = Image.open(image_path).convert("RGB")
    # ToTensor scales pixels to [0, 1]; the detection model normalizes internally
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    # Add batch dimension and send to device
    return transform(image).unsqueeze(0).to(device), image
# Run object detection
def detect_objects(image_path, threshold=0.5):
    image_tensor, image_pil = preprocess_image(image_path)
    with torch.no_grad():
        outputs = model(image_tensor)[0]  # Predictions for the single image
    # Extract data from the model output
    masks = outputs["masks"]    # Soft object masks, shape [N, 1, H, W]
    labels = outputs["labels"]  # COCO label IDs
    scores = outputs["scores"]  # Confidence scores
    filtered_masks = []
    for i in range(len(masks)):
        # Only keep objects with high confidence
        if scores[i] >= threshold:
            # Binarize the soft mask at 0.5, then scale to 0/255 for PIL
            mask = (masks[i, 0] > 0.5).mul(255).byte().cpu().numpy()
            filtered_masks.append((mask, labels[i].item(), scores[i].item()))
    return filtered_masks, image_pil
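
# Optional helper, an addition rather than part of the original app: the numeric
# label IDs returned above are COCO indices, and a readable name can be looked
# up in the metadata that ships with the torchvision weights.
def label_to_name(label_id):
    # weights.meta["categories"] is the COCO category list for these weights
    return weights.meta["categories"][label_id]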
# Apply color masks to detected objects
def apply_instance_masks(image_path, threshold=0.5):
    masks, image = detect_objects(image_path, threshold)
    # Convert to RGBA to support transparency
    img = image.convert("RGBA")
    # Create a transparent overlay layer
    overlay = Image.new("RGBA", img.size, (0, 0, 0, 0))
    # Store a unique color for each object category
    color_map = {}
    for mask, label, score in masks:
        if label not in color_map:
            # Assign a random semi-transparent color to this object category
            color_map[label] = (
                random.randint(50, 255),
                random.randint(50, 255),
                random.randint(50, 255),
                150,
            )
        mask_pil = Image.fromarray(mask, mode="L")  # Binary mask as a grayscale image
        colored_mask = Image.new("RGBA", mask_pil.size, color_map[label])  # Solid color layer
        overlay.paste(colored_mask, (0, 0), mask_pil)  # Paste the color through the mask
    # Combine the original image with the overlay
    result_image = Image.alpha_composite(img, overlay)
    return result_image.convert("RGB")  # Convert back to RGB mode
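
# Quick local check (the file names here are placeholders, not part of the app):
# apply_instance_masks("example.jpg", threshold=0.6).save("masked.png")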
with gr.Blocks() as demo:
    gr.Markdown("## Object Detection with Mask R-CNN")
    gr.Markdown("This demo applies instance segmentation to an image using Mask R-CNN.")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image", type="filepath")
            threshold = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, label="Confidence Threshold")
            detect_button = gr.Button("Detect Objects")
        with gr.Column():
            output_image = gr.Image(label="Output Image with Masks")
    # Pass the slider value through so the confidence threshold actually takes effect
    detect_button.click(
        fn=lambda img_path, thresh: apply_instance_masks(img_path, thresh) if img_path else None,
        inputs=[input_image, threshold],
        outputs=output_image,
    )

demo.launch()
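
# Note: a bare launch() is what Hugging Face Spaces expects; when running the
# script locally, demo.launch(share=True) additionally creates a temporary
# public link for sharing the demo.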