{
 "cells": [
  {
   "cell_type": "markdown",
   "id": 302934307671667531413257853548643485645,
   "metadata": {},
   "source": [
    "# Gradio Demo: image_segmentation\n",
    "### Image segmentation using MaskFormer. Takes in an input image and returns a segmented image.\n",
    " "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": 272996653310673477252411125948039410165,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install -q gradio transformers torch scipy numpy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": 288918539441861185822528903084949547379,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Downloading files from the demo repo\n",
    "import os\n",
    "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/image_segmentation/example_2.png"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": 44380577570523278879349135829904343037,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gradio as gr\n",
    "import torch\n",
    "import random\n",
    "import numpy as np\n",
    "from transformers import MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation\n",
    "\n",
    "# NOTE(review): recent transformers releases appear to deprecate\n",
    "# MaskFormerFeatureExtractor and post_process_segmentation in favour of\n",
    "# MaskFormerImageProcessor and post_process_semantic_segmentation - confirm\n",
    "# against the installed version before migrating.\n",
    "device = torch.device(\"cpu\")\n",
    "model = MaskFormerForInstanceSegmentation.from_pretrained(\"facebook/maskformer-swin-tiny-ade\").to(device)\n",
    "model.eval()\n",
    "preprocessor = MaskFormerFeatureExtractor.from_pretrained(\"facebook/maskformer-swin-tiny-ade\")\n",
    "\n",
    "def visualize_instance_seg_mask(mask):\n",
    "    \"\"\"Colour a 2-D label mask with one random RGB colour per distinct label.\n",
    "\n",
    "    Args:\n",
    "        mask: 2-D array of labels, shape (height, width).\n",
    "\n",
    "    Returns:\n",
    "        Float array of shape (height, width, 3) with values in [0, 1]. Colours\n",
    "        are drawn with `random`, so they differ from call to call.\n",
    "    \"\"\"\n",
    "    mask = np.asarray(mask)\n",
    "    # `inverse` holds, for every pixel, the index of its label in `labels`.\n",
    "    labels, inverse = np.unique(mask, return_inverse=True)\n",
    "    # One random colour per label, every channel in 0-255 (the result is\n",
    "    # scaled by 255 below). The reshape keeps a (0, 3) palette for an empty mask.\n",
    "    palette = np.array(\n",
    "        [[random.randint(0, 255) for _ in range(3)] for _ in labels],\n",
    "        dtype=np.float64,\n",
    "    ).reshape(-1, 3)\n",
    "    # Older NumPy versions return `inverse` flattened, so restore the mask's\n",
    "    # shape before looking up the colour of all pixels in one indexing step.\n",
    "    image = palette[inverse.reshape(mask.shape)]\n",
    "    return image / 255\n",
    "\n",
    "def query_image(img):\n",
    "    \"\"\"Run the segmentation model on an image and return a colour-coded mask.\n",
    "\n",
    "    Args:\n",
    "        img: Image array whose first two dimensions are height and width.\n",
    "\n",
    "    Returns:\n",
    "        The array produced by `visualize_instance_seg_mask` for the predicted\n",
    "        label map (float RGB values in [0, 1]).\n",
    "\n",
    "    Raises:\n",
    "        gr.Error: If no image was provided.\n",
    "    \"\"\"\n",
    "    if img is None:\n",
    "        # Gradio passes None when the form is submitted without an image;\n",
    "        # show a readable message instead of failing on `img.shape`.\n",
    "        raise gr.Error(\"Please provide an image to segment.\")\n",
    "    target_size = (img.shape[0], img.shape[1])\n",
    "    # Keep the inputs on the same device as the model.\n",
    "    inputs = preprocessor(images=img, return_tensors=\"pt\").to(device)\n",
    "    with torch.no_grad():\n",
    "        outputs = model(**inputs)\n",
    "    # Move the logits to the CPU before post-processing.\n",
    "    outputs.class_queries_logits = outputs.class_queries_logits.cpu()\n",
    "    outputs.masks_queries_logits = outputs.masks_queries_logits.cpu()\n",
    "    segmentation = preprocessor.post_process_segmentation(outputs=outputs, target_size=target_size)[0].cpu().detach()\n",
    "    # argmax over the first axis (presumably the class axis - TODO confirm)\n",
    "    # leaves one label per pixel.\n",
    "    label_map = torch.argmax(segmentation, dim=0).numpy()\n",
    "    return visualize_instance_seg_mask(label_map)\n",
    "\n",
    "demo = gr.Interface(\n",
    "    query_image,\n",
    "    inputs=[gr.Image()],\n",
    "    outputs=\"image\",\n",
    "    title=\"MaskFormer Demo\",\n",
    "    examples=[[\"example_2.png\"]]\n",
    ")\n",
    "\n",
    "demo.launch()"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 5
}