|
import functools

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
from torchvision import transforms
|
|
|
|
|
|
|
|
|
|
|
# Checkpoint file for each supported SAM backbone. Paths are relative, so the
# .pth files are expected in the current working directory.
_SAM_CHECKPOINTS = {
    "vit_h": "sam_vit_h_4b8939.pth",
    "vit_b": "sam_vit_b_01ec64.pth",
    "vit_l": "sam_vit_l_0b3195.pth",
}


@functools.lru_cache(maxsize=len(_SAM_CHECKPOINTS))
def _load_mask_generator(model_type):
    """Build a SAM automatic mask generator for *model_type*, once per backbone.

    Loading a checkpoint is expensive, so the generator is cached and reused
    across requests instead of being rebuilt on every call. Callers must
    validate *model_type* first; the cache is sized for the known backbones.
    """
    sam = sam_model_registry[model_type](checkpoint=_SAM_CHECKPOINTS[model_type])
    sam.to(device="cuda" if torch.cuda.is_available() else "cpu")
    return SamAutomaticMaskGenerator(sam)


def _to_rgb_array(image):
    """Return *image* (PIL image or uint8 array-like) as a writable HxWx3 uint8 RGB array."""
    if not isinstance(image, Image.Image):
        # Round-trip through PIL so grayscale / RGBA arrays are normalized too.
        image = Image.fromarray(np.asarray(image, dtype=np.uint8))
    return np.array(image.convert("RGB"))


def _overlay_masks(image_rgb, masks, alpha=0.5):
    """Blend one distinct colour per mask into *image_rgb* and return a uint8 array.

    Pixels covered by no mask keep their original value: the colour canvas
    starts as a copy of the image, so blending there is a no-op.
    """
    composite = image_rgb.copy()
    # Paint large masks first so smaller masks nested inside them stay visible.
    ordered = sorted(masks, key=lambda m: m["area"], reverse=True)
    colors = plt.cm.rainbow(np.linspace(0.0, 1.0, max(len(ordered), 1)))
    for mask_data, color in zip(ordered, colors):
        # Colormap entries are RGBA floats in [0, 1]; keep RGB, scale to 0-255.
        composite[mask_data["segmentation"]] = (color[:3] * 255).astype(np.uint8)
    blended = (1.0 - alpha) * image_rgb.astype(np.float32) + alpha * composite.astype(np.float32)
    return np.clip(np.rint(blended), 0, 255).astype(np.uint8)


def get_masks(model_type, image):
    """Segment *image* with SAM and return it with every mask colour-overlaid.

    Args:
        model_type: SAM backbone name; one of "vit_h", "vit_b", "vit_l".
        image: The input picture, either a PIL image (what a Gradio
            ``type="pil"`` input delivers) or a uint8 array-like image.

    Returns:
        An HxWx3 uint8 numpy array. Each mask is blended 50/50 with its own
        colour, and pixels not covered by any mask are unchanged.

    Raises:
        ValueError: if *model_type* is not a supported backbone or *image* is None.
    """
    # The isinstance check also yields a clear error (instead of a TypeError
    # on hashing) if a caller passes the arguments in the wrong order.
    if not isinstance(model_type, str) or model_type not in _SAM_CHECKPOINTS:
        raise ValueError(
            f"Unknown model type {model_type!r}; expected one of {sorted(_SAM_CHECKPOINTS)}"
        )
    if image is None:
        raise ValueError("No image was provided.")
    image_rgb = _to_rgb_array(image)
    masks = _load_mask_generator(model_type).generate(image_rgb)
    return _overlay_masks(image_rgb, masks)
|
|
|
|
|
|
|
# Gradio calls ``fn`` with input values positionally, in the order of
# ``inputs`` (image first, then model type). get_masks takes
# (model_type, image), so the lambda adapts one order to the other while
# keeping the image upload first in the UI.
iface = gr.Interface(
    fn=lambda image, model_type: get_masks(model_type, image),
    inputs=[
        gr.Image(type="pil"),
        # Default to ViT-B, the lightest backbone, so a request never arrives
        # without a model selected.
        gr.Dropdown(["vit_h", "vit_b", "vit_l"], value="vit_b", label="Model Type"),
    ],
    # get_masks returns an HxWx3 uint8 numpy array.
    outputs=gr.Image(type="numpy"),
    title="SAM Model Segmentation and Classification",
    description="Upload an image, select a model type, and receive the segmented and classified parts.",
)

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()