# Page-header residue captured with this listing: "Spaces: Runtime error"
import functools
import gc
import logging
from typing import Dict, List, Tuple

import numpy as np
import streamlit as st
import torch
from PIL import Image
from transformers import AutoImageProcessor, UperNetForSemanticSegmentation

from palette import ade_palette
# Module-level logger, named after this module via __name__.
LOGGING = logging.getLogger(__name__)
def flush():
    """Run a garbage-collection pass, then empty torch's CUDA cache."""
    # Order matters: the Python collector runs first, the CUDA cache is
    # emptied second, exactly as two consecutive calls would do.
    for release in (gc.collect, torch.cuda.empty_cache):
        release()
@functools.lru_cache(maxsize=1)
def get_segmentation_pipeline() -> Tuple[AutoImageProcessor, UperNetForSemanticSegmentation]:
    """Method to load the segmentation pipeline

    The processor and model are built on the first call only; later calls
    return the same two objects, so callers that invoke this once per image
    no longer re-instantiate the model every time.

    Returns:
        Tuple[AutoImageProcessor, UperNetForSemanticSegmentation]: segmentation pipeline
    """
    # Single source of truth so processor and model cannot drift apart.
    checkpoint = "openmmlab/upernet-convnext-small"
    image_processor = AutoImageProcessor.from_pretrained(checkpoint)
    image_segmentor = UperNetForSemanticSegmentation.from_pretrained(checkpoint)
    return image_processor, image_segmentor
def segment_image(image: Image.Image) -> Image.Image:
    """Method to segment image

    Args:
        image (Image.Image): input image; any PIL mode is accepted, it is
            converted to RGB before inference.

    Returns:
        Image.Image: RGB image with the same width/height as the input, where
            each pixel is painted with the ``ade_palette()`` colour of its
            predicted label (labels without a palette entry stay black).
    """
    image_processor, image_segmentor = get_segmentation_pipeline()

    # Normalise to 3-channel RGB up front so RGBA, greyscale or palette-mode
    # inputs are fed with the channel count RGB preprocessing expects.
    if image.mode != "RGB":
        image = image.convert("RGB")

    pixel_values = image_processor(image, return_tensors="pt").pixel_values
    with torch.no_grad():
        outputs = image_segmentor(pixel_values)

    # PIL reports size as (width, height); target_sizes wants (height, width).
    seg = image_processor.post_process_semantic_segmentation(
        outputs, target_sizes=[image.size[::-1]])[0]
    labels = seg.detach().cpu().numpy()

    palette = np.array(ade_palette())
    color_seg = np.zeros((labels.shape[0], labels.shape[1], 3), dtype=np.uint8)
    # One palette lookup instead of one full-image mask per class. Labels
    # beyond the palette are skipped so they remain black, as before.
    known = labels < len(palette)
    color_seg[known] = palette[labels[known]]

    # An (H, W, 3) uint8 array is already decoded as an RGB image.
    return Image.fromarray(color_seg)