"""Semantic-segmentation helpers built on UperNet (ADE20K label set).

Loads a session-cached Hugging Face UperNet pipeline and converts an
input image into a color-coded segmentation map using the ADE20K palette.
"""
import logging
from typing import List, Tuple, Dict
import streamlit as st
import torch
import gc
import numpy as np
from PIL import Image
from transformers import AutoImageProcessor, UperNetForSemanticSegmentation
from palette import ade_palette

LOGGING = logging.getLogger(__name__)


def flush():
    """Release unreferenced Python objects and cached CUDA allocator memory."""
    gc.collect()
    # Guard so CPU-only hosts never touch the CUDA runtime; on CUDA hosts
    # this returns cached blocks to the driver.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


@st.experimental_singleton(max_entries=5)
def get_segmentation_pipeline() -> Tuple[AutoImageProcessor, UperNetForSemanticSegmentation]:
    """Load the segmentation pipeline once per Streamlit session.

    Returns:
        Tuple[AutoImageProcessor, UperNetForSemanticSegmentation]:
            the image pre-processor and the segmentation model.
    """
    image_processor = AutoImageProcessor.from_pretrained(
        "openmmlab/upernet-convnext-small")
    image_segmentor = UperNetForSemanticSegmentation.from_pretrained(
        "openmmlab/upernet-convnext-small")
    return image_processor, image_segmentor


@torch.inference_mode()
@torch.autocast('cuda')
def segment_image(image: Image.Image) -> Image.Image:
    """Segment ``image`` and return an RGB color-coded label map.

    Args:
        image (Image.Image): input PIL image.

    Returns:
        Image.Image: segmentation map the same size as the input, with
        each predicted ADE20K class painted in its palette color.
    """
    image_processor, image_segmentor = get_segmentation_pipeline()
    pixel_values = image_processor(image, return_tensors="pt").pixel_values
    # @torch.inference_mode() already disables autograd; the nested
    # torch.no_grad() the original used here was redundant.
    outputs = image_segmentor(pixel_values)
    # PIL .size is (width, height); post-processing expects (height, width).
    seg = image_processor.post_process_semantic_segmentation(
        outputs, target_sizes=[image.size[::-1]])[0]
    # Vectorized label->color lookup: one fancy-index pass instead of the
    # original per-label boolean-mask loop over all 150 palette entries.
    # (Label ids are assumed to lie within the palette range, which holds
    # for this 150-class ADE20K checkpoint.)
    palette = np.array(ade_palette(), dtype=np.uint8)
    color_seg = palette[np.asarray(seg)]
    seg_image = Image.fromarray(color_seg).convert('RGB')
    return seg_image