File size: 2,725 Bytes
dd2ab5c
1
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "302934307671667531413257853548643485645",
   "metadata": {},
   "source": [
    "# Gradio Demo: image_segmentation\n",
    "### Image segmentation using MaskFormer. Takes in an input image and returns a segmented image in which each predicted class is drawn in its own random colour.\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "272996653310673477252411125948039410165",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install -q gradio transformers torch scipy numpy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "288918539441861185822528903084949547379",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Downloading files from the demo repo\n",
    "import os\n",
    "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/image_segmentation/example_2.png"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44380577570523278879349135829904343037",
   "metadata": {},
   "outputs": [],
   "source": [
    "import gradio as gr\n",
    "import torch\n",
    "import random\n",
    "import numpy as np\n",
    "from transformers import MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation\n",
    "\n",
    "# Everything runs on CPU so the demo works without a GPU; model inputs are moved to the same device below.\n",
    "device = torch.device(\"cpu\")\n",
    "model = MaskFormerForInstanceSegmentation.from_pretrained(\"facebook/maskformer-swin-tiny-ade\").to(device)\n",
    "model.eval()  # inference mode (disables dropout)\n",
    "preprocessor = MaskFormerFeatureExtractor.from_pretrained(\"facebook/maskformer-swin-tiny-ade\")\n",
    "\n",
    "def visualize_instance_seg_mask(mask):\n",
    "    \"\"\"Colour every distinct label in a 2-D label map with a random RGB colour.\n",
    "\n",
    "    Args:\n",
    "        mask: 2-D integer array of shape (H, W); each value is a label id.\n",
    "\n",
    "    Returns:\n",
    "        Float array of shape (H, W, 3) with values in [0, 1].\n",
    "    \"\"\"\n",
    "    # `inverse` holds, for every pixel, the index of that pixel's label in `labels`.\n",
    "    labels, inverse = np.unique(mask, return_inverse=True)\n",
    "    # One random colour per label; all three channels span the full 0-255 range.\n",
    "    # reshape(-1, 3) keeps the palette 2-D even when there are no labels.\n",
    "    palette = np.array(\n",
    "        [[random.randint(0, 255) for _ in range(3)] for _ in labels],\n",
    "        dtype=float,\n",
    "    ).reshape(-1, 3)\n",
    "    # Vectorised palette lookup instead of a per-pixel Python loop.\n",
    "    image = palette[inverse.reshape(mask.shape)]\n",
    "    return image / 255\n",
    "\n",
    "def query_image(img):\n",
    "    \"\"\"Segment `img` with MaskFormer and return a colourised label map.\n",
    "\n",
    "    Args:\n",
    "        img: input image; assumed to be a numpy array (H, W, C) as delivered by gr.Image.\n",
    "\n",
    "    Returns:\n",
    "        Float array of shape (H, W, 3) in [0, 1] produced by visualize_instance_seg_mask.\n",
    "    \"\"\"\n",
    "    # Post-processing resizes predictions back to the original image size.\n",
    "    target_size = (img.shape[0], img.shape[1])\n",
    "    # Inputs must live on the same device as the model.\n",
    "    inputs = preprocessor(images=img, return_tensors=\"pt\").to(device)\n",
    "    with torch.no_grad():\n",
    "        outputs = model(**inputs)\n",
    "    # Post-processing is done on CPU tensors.\n",
    "    outputs.class_queries_logits = outputs.class_queries_logits.cpu()\n",
    "    outputs.masks_queries_logits = outputs.masks_queries_logits.cpu()\n",
    "    # [0] selects the single image in the batch; argmax over dim 0 picks one label per pixel.\n",
    "    results = preprocessor.post_process_segmentation(outputs=outputs, target_size=target_size)[0].cpu().detach()\n",
    "    results = torch.argmax(results, dim=0).numpy()\n",
    "    return visualize_instance_seg_mask(results)\n",
    "\n",
    "demo = gr.Interface(\n",
    "    query_image,\n",
    "    inputs=[gr.Image()],\n",
    "    outputs=\"image\",\n",
    "    title=\"MaskFormer Demo\",\n",
    "    examples=[[\"example_2.png\"]]\n",
    ")\n",
    "\n",
    "demo.launch()"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 5
}