"""Gradio demo that overlays Mask R-CNN instance-segmentation masks on an image."""

from typing import Optional

import gradio as gr
import numpy as np
import torch
import torchvision
import torchvision.transforms.functional as F
from gradio import components
from PIL import Image
from torchvision.models.detection import (
    MaskRCNN_ResNet50_FPN_Weights,
    maskrcnn_resnet50_fpn,
)
from torchvision.utils import draw_segmentation_masks

# Per-pixel probability above which a pixel is considered part of an instance.
MASK_PROBA_THRESHOLD = 0.5

# Load the pretrained model once at import time; every request reuses it.
weights = MaskRCNN_ResNet50_FPN_Weights.DEFAULT
transforms = weights.transforms()
model = maskrcnn_resnet50_fpn(weights=weights, progress=False)
model = model.eval()


def segment_and_show(image: Optional[np.ndarray]) -> Optional[np.ndarray]:
    """Run Mask R-CNN on *image* and return it with the instance masks drawn.

    Args:
        image: The picture to segment, as an array that
            ``PIL.Image.fromarray`` accepts (the Gradio image input is
            expected to deliver a uint8 H x W x 3 array -- confirm if the
            component configuration changes). ``None`` means no image was
            supplied.

    Returns:
        A uint8 H x W x 3 array holding the input with every predicted
        instance mask blended on top, or ``None`` when *image* is ``None``.
    """
    if image is None:
        # The input component was empty; there is nothing to segment.
        return None

    # Normalise to three channels so grayscale / RGBA inputs work with both
    # the model and the mask-drawing helper.
    pil_image = Image.fromarray(image).convert("RGB")

    # uint8 C x H x W copy of the picture: the canvas the masks are drawn on.
    canvas = F.pil_to_tensor(pil_image)

    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        output = model([transforms(pil_image)])[0]

    # ``output["masks"]`` is N x 1 x H x W; threshold it into boolean masks
    # and drop the singleton channel axis.
    masks = (output["masks"] > MASK_PROBA_THRESHOLD).squeeze(1)

    if masks.shape[0] == 0:
        # Nothing was detected; hand back the picture untouched.
        return np.array(pil_image)

    image_with_segmasks = draw_segmentation_masks(canvas, masks, alpha=0.7)
    return np.array(F.to_pil_image(image_with_segmasks))


default_image = Image.open("demo.jpeg")

iface = gr.Interface(
    fn=segment_and_show,
    inputs=components.Image(value=default_image, sources=["upload", "clipboard"]),
    outputs=components.Image(type="pil"),
    title="Urban Autonomy Instance Segmentation Demo",
    description="Upload an image or use the default to see the instance segmentation model in action.",
)

if __name__ == "__main__":
    iface.launch()