import random

import gradio as gr
import numpy as np
import torch
from transformers import MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation

# Use GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = MaskFormerForInstanceSegmentation.from_pretrained(
    "facebook/maskformer-swin-tiny-ade"
).to(device)
model.eval()
preprocessor = MaskFormerFeatureExtractor.from_pretrained("facebook/maskformer-swin-tiny-ade")


def visualize_instance_seg_mask(mask):
    # Build an RGB image in which every segment label gets its own random color
    image = np.zeros((mask.shape[0], mask.shape[1], 3))
    labels = np.unique(mask)
    label2color = {
        label: (
            random.randint(0, 255),
            random.randint(0, 255),
            random.randint(0, 255),
        )
        for label in labels
    }
    for i in range(image.shape[0]):
        for j in range(image.shape[1]):
            image[i, j, :] = label2color[mask[i, j]]
    # Scale to [0, 1] so Gradio can display the array as an image
    image = image / 255
    return image


def query_image(img):
    target_size = (img.shape[0], img.shape[1])
    # Move the preprocessed inputs to the same device as the model
    inputs = preprocessor(images=img, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    outputs.class_queries_logits = outputs.class_queries_logits.cpu()
    outputs.masks_queries_logits = outputs.masks_queries_logits.cpu()
    # Resize the query masks back to the input resolution, then take the
    # highest-scoring label for every pixel
    results = preprocessor.post_process_segmentation(
        outputs=outputs, target_size=target_size
    )[0].cpu().detach()
    results = torch.argmax(results, dim=0).numpy()
    return visualize_instance_seg_mask(results)


description = """
Gradio demo for MaskFormer, introduced in the paper
"Per-Pixel Classification is Not All You Need for Semantic Segmentation".
\n\nMask2Former, a follow-up built on the MaskFormer meta-architecture, achieves state-of-the-art
results on panoptic, instance and semantic segmentation across four popular datasets
(ADE20K, Cityscapes, COCO, Mapillary Vistas). You can use MaskFormer for semantic,
instance (illustrated in this demo) and panoptic segmentation.
"""

demo = gr.Interface(
    query_image,
    inputs=[gr.Image()],
    outputs="image",
    title="MaskFormer Demo",
    description=description,
    examples=["assets/test_image_35.png", "assets/test_image_82.png"],
)

demo.launch()
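
# --- Optional local sanity check (a sketch, not part of the original demo) ---
# Assumes Pillow is installed and that one of the bundled example images exists at
# "assets/test_image_35.png". Uncomment to run query_image() once without the Gradio UI;
# note that demo.launch() above blocks, so run this before it or from a separate script.
# from PIL import Image
# sample = np.array(Image.open("assets/test_image_35.png").convert("RGB"))
# seg = query_image(sample)  # float RGB array with values in [0, 1]
# Image.fromarray((seg * 255).astype(np.uint8)).save("segmentation_preview.png")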