import gradio as gr
from transformers import pipeline
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import os

# Load the segmentation pipeline
pipe = pipeline("image-segmentation", model="mattmdjaga/segformer_b2_clothes")
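# The pipeline returns a list of dicts, one per detected class, each with a
# "label" string and a binary PIL "mask" the same size as the input image.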
# Predefined label dictionary (class index -> class name) for this model
label_dict = {
    0: "Background",
    1: "Hat",
    2: "Hair",
    3: "Sunglasses",
    4: "Upper-clothes",
    5: "Skirt",
    6: "Pants",
    7: "Dress",
    8: "Belt",
    9: "Left-shoe",
    10: "Right-shoe",
    11: "Face",
    12: "Left-leg",
    13: "Right-leg",
    14: "Left-arm",
    15: "Right-arm",
    16: "Bag",
    17: "Scarf",
}
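# Reverse mapping (label name -> class index). This helper is not part of the
# original snippet; it just avoids rescanning label_dict for every mask below.
label_to_idx = {name: idx for idx, name in label_dict.items()}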
# Function to process the image and generate the segmentation map
def segment_image(image):
    # Perform segmentation
    result = pipe(image)

    # Initialize an empty array for the segmentation map
    image_width, image_height = result[0]["mask"].size
    segmentation_map = np.zeros((image_height, image_width), dtype=np.uint8)

    # Combine the per-class masks into a single segmentation map
    for entry in result:
        label = entry["label"]          # e.g. "Hair", "Upper-clothes"
        mask = np.array(entry["mask"])  # convert the PIL mask to a NumPy array
        # Map the label name back to its class index
        class_idx = label_to_idx[label]
        # Assign the class index wherever this mask is active
        segmentation_map[mask > 0] = class_idx
    # Get the unique class indices in the segmentation map
    unique_classes = np.unique(segmentation_map)

    # Print the names of the detected classes
    print("Detected Classes:")
    for class_idx in unique_classes:
        print(f"- {label_dict[class_idx]}")
    # Create a matplotlib figure and visualize the segmentation map
    plt.figure(figsize=(8, 8))
    plt.imshow(segmentation_map, cmap="tab20")  # visualize using a colormap

    # Filter the label dictionary to the detected classes and build a
    # colorbar that shows only those entries
    filtered_labels = {idx: label_dict[idx] for idx in unique_classes}
    cbar = plt.colorbar(ticks=unique_classes)
    cbar.ax.set_yticklabels([filtered_labels[i] for i in unique_classes])

    plt.title("Segmented Image with Detected Classes")
    plt.axis("off")
    plt.savefig("segmented_output.png", bbox_inches="tight")
    plt.close()
    return Image.open("segmented_output.png")
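# Quick local sanity check (hypothetical file name; uncomment to run outside Gradio):
# segment_image(Image.open("example_image.jpg")).save("debug_output.png")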
# Gradio interface
interface = gr.Interface(
    fn=segment_image,
    inputs=gr.Image(type="pil"),   # input is a PIL image
    outputs=gr.Image(type="pil"),  # output is the colormapped segmentation
    # Only list example images that actually exist, so a missing file
    # does not crash the app at startup
    examples=[f for f in ["1.jpg", "2.jpg", "3.jpg"] if os.path.exists(f)] or None,
    title="Clothes Segmentation with Colormap",
    description="Upload an image, and the segmentation model will produce an output with a colormap applied to the segmented classes.",
)
# Launch the app
interface.launch()
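# Note: on Hugging Face Spaces, app.py runs at startup and launch() picks up the
# host/port automatically; share=True is only needed for a public link when
# running locally.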