File size: 1,761 Bytes
dd0ab9f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0819ae2
dd0ab9f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import logging
from typing import List, Tuple, Dict

import streamlit as st
import torch
import gc
import numpy as np
from PIL import Image

from transformers import AutoImageProcessor, UperNetForSemanticSegmentation

from palette import ade_palette

# Logger for this module; named after the module so its records can be filtered/configured separately.
LOGGING = logging.getLogger(name=__name__)


def flush():
    """Free memory after heavy work.

    Runs a full garbage-collection pass first, so objects that are no longer
    referenced (including tensors) are released. It then asks PyTorch to return
    cached, unused CUDA allocator blocks.
    """
    for release in (gc.collect, torch.cuda.empty_cache):
        release()

@st.cache_resource(max_entries=5)
def get_segmentation_pipeline() -> Tuple[AutoImageProcessor, UperNetForSemanticSegmentation]:
    """Load the UperNet (ConvNeXt-small) segmentation pipeline.

    The result is memoised by ``st.cache_resource``, so repeated calls reuse the
    already-loaded objects instead of reloading the checkpoint.

    Returns:
        Tuple[AutoImageProcessor, UperNetForSemanticSegmentation]: the image
        pre/post-processor and the segmentation model, both loaded from the
        ``openmmlab/upernet-convnext-small`` checkpoint.
    """
    checkpoint = "openmmlab/upernet-convnext-small"
    processor = AutoImageProcessor.from_pretrained(checkpoint)
    segmentor = UperNetForSemanticSegmentation.from_pretrained(checkpoint)
    return processor, segmentor


@torch.inference_mode()
@torch.autocast('cuda')
def segment_image(image: Image.Image) -> Image.Image:
    """Segment an image and render the label map with the ADE palette.

    Args:
        image (Image.Image): input image.

    Returns:
        Image.Image: an RGB image the same size as ``image``. Each pixel is
        painted with the palette colour of its predicted label. Pixels whose
        label has no palette entry are left black.
    """
    image_processor, image_segmentor = get_segmentation_pipeline()
    pixel_values = image_processor(image, return_tensors="pt").pixel_values
    # Run on whichever device the cached model lives on (CPU unless it was moved).
    pixel_values = pixel_values.to(image_segmentor.device)
    # The inference_mode decorator already disables autograd; no extra no_grad needed.
    outputs = image_segmentor(pixel_values)

    # PIL's size is (width, height); the post-processor wants (height, width).
    seg = image_processor.post_process_semantic_segmentation(
        outputs, target_sizes=[image.size[::-1]])[0]
    # Convert the label map to NumPy once; .cpu() makes this work for CUDA tensors too.
    seg = seg.cpu().numpy() if isinstance(seg, torch.Tensor) else np.asarray(seg)

    # (num_labels, 3) lookup table, cast to uint8 exactly as the colour writes were.
    palette = np.asarray(ade_palette()).reshape(-1, 3).astype(np.uint8)
    color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
    # One vectorized table lookup instead of a full-image comparison per label.
    # Labels outside [0, len(palette)) are skipped, so they stay black.
    valid = (seg >= 0) & (seg < len(palette))
    color_seg[valid] = palette[seg[valid]]
    # An (H, W, 3) uint8 array becomes an RGB image.
    return Image.fromarray(color_seg)