Spaces:
Runtime error
Runtime error
import gradio as gr | |
import glob | |
import torch | |
import pickle | |
from PIL import Image, ImageDraw | |
import numpy as np | |
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation | |
from scipy.ndimage import center_of_mass | |
def combine_ims(im1, im2, val=128):
    """Blend ``im1`` over ``im2`` through a uniform greyscale mask.

    Args:
        im1: Image that dominates as ``val`` approaches 255.
        im2: Image that dominates as ``val`` approaches 0; same size as ``im1``.
        val: Constant value (0-255) of the "L" mask fed to ``Image.composite``.

    Returns:
        The PIL image produced by ``Image.composite``.
    """
    return Image.composite(im1, im2, Image.new("L", im1.size, val))
def get_class_centers(segmentation_mask, class_dict):
    """Return the centre of mass of every class that occurs in a label mask.

    Args:
        segmentation_mask: 2-D map of 0-based class labels. Either an object
            with a ``.numpy()`` method (e.g. the CPU tensor returned by
            ``post_process_semantic_segmentation``) or an array-like.
        class_dict: Mapping whose keys are matched against ``label + 1``;
            only the keys are used.

    Returns:
        dict mapping each ``class_dict`` key present in the mask to
        ``[row, col]``, both truncated to ``int``. Keys keep ``class_dict``
        order; classes absent from the mask are omitted.
    """
    if hasattr(segmentation_mask, "numpy"):
        segmentation_mask = segmentation_mask.numpy()
    # Labels are 0-based; shift them so they line up with class_dict's keys.
    labels = np.asarray(segmentation_mask) + 1
    # Only visit classes that actually occur. Scanning the full image once
    # per class_dict entry wasted a pass (and raised a NaN RuntimeWarning)
    # for every absent class.
    present = set(np.unique(labels).tolist())
    class_centers = {}
    for class_index in class_dict:
        if class_index not in present:
            continue
        center = center_of_mass((labels == class_index).astype(int))
        class_centers[class_index] = [int(coord) for coord in center]
    return class_centers
def visualize_mask(predicted_semantic_map, class_ids, class_colors):
    """Paint a per-pixel label map as an RGB PIL image.

    A pixel with label ``L`` is coloured ``class_colors[class_ids[L]]``.

    Args:
        predicted_semantic_map: 2-D map of integer labels with ``.shape`` and
            ``.numpy()`` (e.g. a CPU torch tensor).
        class_ids: Array indexed by label, giving a row index into
            ``class_colors``.
        class_colors: Array of colour triples, one per row.

    Returns:
        PIL image built from an ``(h, w, 3)`` uint8 array.
    """
    h, w = predicted_semantic_map.shape
    # Keep labels at a native integer width. Copying them into a uint8
    # buffer silently wrapped any label >= 256 onto the wrong colour.
    labels = predicted_semantic_map.numpy().astype(np.intp, copy=False).ravel()
    # NOTE(review): if class_ids holds 1-based ids and class_colors has
    # exactly one row per id, the highest label indexes one past the end of
    # class_colors. Confirm against ade20k_classes.pickle.
    colors = class_colors[class_ids[labels]]
    output = colors.reshape(h, w, 3).astype(np.uint8)
    return Image.fromarray(output)
def get_out_image(image, predicted_semantic_map):
    """Overlay the colourised segmentation on ``image`` and label each class.

    Reads the module-level ``class_dict``, ``class_ids``, ``class_colors`` and
    ``class_names`` loaded from ``ade20k_classes.pickle``.

    Args:
        image: PIL image the segmentation was predicted for.
        predicted_semantic_map: 2-D tensor of 0-based class labels with the
            same height/width as ``image``.

    Returns:
        ``(overlay, legend)``:
        - ``overlay`` is the blended image with each class name drawn at its
          centre of mass.
        - ``legend`` has one line per class, formatted as
          ``"● [#rrggbb] <class name>"``.
    """
    class_centers = get_class_centers(predicted_semantic_map, class_dict)
    mask = visualize_mask(predicted_semantic_map, class_ids, class_colors)
    image_mask = combine_ims(image, mask, val=128)
    draw = ImageDraw.Draw(image_mask)
    symbol = "●"  # You can choose any symbol you like
    extracted_tags = []
    for class_id, (y, x) in class_centers.items():
        label = class_id - 1  # centres are keyed by label + 1
        class_name = str(class_names[label])
        # Look the colour up exactly as visualize_mask paints the pixels
        # (class_colors[class_ids[label]]), so the legend swatch matches the
        # overlay. Indexing class_colors[label] directly gives a different
        # colour whenever class_ids[label] != label.
        color = class_colors[class_ids[label]]
        color_hex = "#{:02x}{:02x}{:02x}".format(*color)  # RGB -> hex
        extracted_tags.append(f"{symbol} [{color_hex}] {class_name}")
        draw.text((x, y), class_name, fill='black')
    # Joining all tags into a single string, each tag on a new line
    return image_mask, "\n".join(extracted_tags)
def gradio_process(image):
    """Gradio callback: segment ``image`` and return ``(overlay, legend)``.

    Args:
        image: PIL image from the input component, or ``None`` when the user
            submits without an image.

    Returns:
        The ``(image, text)`` pair produced by ``get_out_image``. Returns
        ``(None, "")`` when ``image`` is ``None`` instead of failing inside
        the processor.
    """
    if image is None:
        return None, ""
    # Normalise to 3-channel RGB so RGBA, greyscale or palette inputs reach
    # the processor and the overlay as an ordinary RGB image.
    image = image.convert("RGB")
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # target_sizes expects (height, width); PIL's .size is (width, height).
    predicted_semantic_map = processor.post_process_semantic_segmentation(
        outputs, target_sizes=[image.size[::-1]]
    )[0]
    return get_out_image(image, predicted_semantic_map)
# Class metadata: three parallel sequences (names, ids, colours), each turned
# into a NumPy array. The pickle is presumably bundled with the app. Never
# point this at untrusted data, because unpickling can execute arbitrary code.
with open('ade20k_classes.pickle', 'rb') as f:
    class_names, class_ids, class_colors = map(np.array, pickle.load(f))
# Lookup from class id to its human-readable name.
class_dict = dict(zip(class_ids, class_names))
# One checkpoint supplies both the image processor and the Mask2Former weights.
_CHECKPOINT = "facebook/mask2former-swin-large-ade-semantic"
device = torch.device("cpu")
processor = AutoImageProcessor.from_pretrained(_CHECKPOINT)
model = Mask2FormerForUniversalSegmentation.from_pretrained(_CHECKPOINT)
model = model.to(device)
model.eval()  # evaluation mode (affects layers such as dropout)
# Web UI. gr.inputs / gr.outputs were deprecated in Gradio 3 and removed in
# Gradio 4, so the top-level gr.Image / gr.Textbox components are used instead.
# These also exist in Gradio 3.x.
demo = gr.Interface(
    gradio_process,
    inputs=gr.Image(type="pil"),
    outputs=[gr.Image(type="pil"), gr.Textbox()],
    title="Semantic Segmentation",
    # Sorted so the example gallery has a stable order; None when the
    # examples folder is empty or missing.
    examples=sorted(glob.glob('./examples/*.jpg')) or None,
    allow_flagging="never",
)
demo.launch()