"""Gradio demo that runs MaskFormer on an uploaded image and shows a colorized label map."""

import random

import gradio as gr
import numpy as np
import torch
from transformers import MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation

# Use GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-swin-tiny-ade").to(device)
model.eval()
preprocessor = MaskFormerFeatureExtractor.from_pretrained("facebook/maskformer-swin-tiny-ade")


def visualize_instance_seg_mask(mask):
    """Turn a 2-D label map into a randomly colored RGB image.

    Each distinct value in ``mask`` is assigned one random color (drawn with
    the ``random`` module, so colors differ between calls unless it is seeded).

    Args:
        mask: 2-D NumPy array of labels, one per pixel.

    Returns:
        Float array of shape ``mask.shape + (3,)`` with channel values in
        ``[0, 1]``.
    """
    # `inverse` holds, for every pixel, the index of its label in `labels`.
    labels, inverse = np.unique(mask, return_inverse=True)

    # One RGB triple per label. All three channels span the full 0-255 range;
    # the red channel was previously drawn from (0, 1), which made it
    # effectively zero after the division by 255 below.
    palette = np.array(
        [[random.randint(0, 255) for _ in range(3)] for _ in labels],
        dtype=np.float64,
    ).reshape(-1, 3)

    # Vectorized lookup replaces the per-pixel Python loop. The reshape keeps
    # this independent of whether NumPy returns `inverse` flat or mask-shaped.
    return palette[inverse.reshape(mask.shape)] / 255


def query_image(img):
    """Segment ``img`` with MaskFormer and return a colorized label map.

    Args:
        img: Image array whose first two dimensions are height and width
            (presumably the NumPy array Gradio passes from ``gr.Image()``),
            or ``None`` when no image was submitted.

    Returns:
        The float RGB visualization produced by
        ``visualize_instance_seg_mask``, or ``None`` if ``img`` is ``None``.
    """
    if img is None:
        # Nothing to segment; avoid an AttributeError on `img.shape`.
        return None

    target_size = (img.shape[0], img.shape[1])

    # The model lives on `device`, so its inputs must be moved there as well;
    # otherwise the forward pass fails with a device mismatch on CUDA.
    inputs = preprocessor(images=img, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)

    # Post-processing is done on CPU.
    outputs.class_queries_logits = outputs.class_queries_logits.cpu()
    outputs.masks_queries_logits = outputs.masks_queries_logits.cpu()
    results = preprocessor.post_process_segmentation(outputs=outputs, target_size=target_size)[0].cpu().detach()

    # Collapse the first dimension to a single label index per pixel.
    results = torch.argmax(results, dim=0).numpy()
    return visualize_instance_seg_mask(results)


description = """ Gradio demo for MaskFormer, introduced in Per-Pixel Classification is Not All You Need for Semantic Segmentation . \n\n"MaskFormer is a unified framework for panoptic, instance and semantic segmentation, trained across four popular datasets (ADE20K, Cityscapes, COCO, Mapillary Vistas). """

demo = gr.Interface(
    query_image,
    inputs=[gr.Image()],
    outputs="image",
    title="MaskFormer Demo",
    description=description,
    examples=["assets/test_image_35.png", "assets/test_image_82.png"],
)

demo.launch()