|
import gradio as gr |
|
from transformers import pipeline |
|
from PIL import Image |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
|
|
|
|
# Load the clothes/body-part segmentation model once at import time so that
# every request handled by the app reuses the same pipeline instance.
pipe = pipeline(
    task="image-segmentation",
    model="mattmdjaga/segformer_b2_clothes",
)
|
|
|
|
|
|
|
|
|
# Class index -> human-readable label. The position of each name in the tuple
# is its class index (0 = "Background", ..., 17 = "Scarf").
label_dict = dict(
    enumerate(
        (
            "Background",
            "Hat",
            "Hair",
            "Sunglasses",
            "Upper-clothes",
            "Skirt",
            "Pants",
            "Dress",
            "Belt",
            "Left-shoe",
            "Right-shoe",
            "Face",
            "Left-leg",
            "Right-leg",
            "Left-arm",
            "Right-arm",
            "Bag",
            "Scarf",
        )
    )
)
|
|
|
|
|
|
|
|
|
|
|
def segment_image(image):
    """Segment ``image`` and return a colour-coded class map as a PIL image.

    The module-level ``pipe`` is run on the input, the per-class masks it
    returns are merged into a single 2-D array of class indices (keys of
    ``label_dict``), and that array is rendered with matplotlib together with
    a colour bar labelled with the detected class names. The detected class
    names are also printed to stdout.

    Args:
        image: The input picture. The Gradio interface in this file passes a
            PIL image (``gr.Image(type="pil")``). ``None`` (nothing uploaded)
            is tolerated.

    Returns:
        A PIL image containing the rendered figure, or ``None`` when ``image``
        is ``None``.
    """
    # Function-scope import so the block stays self-contained.
    import io

    # Gradio passes None when the user submits without an image.
    if image is None:
        return None

    result = pipe(image)

    # Build the label -> index lookup once instead of scanning label_dict for
    # every segment.
    index_by_label = {name: idx for idx, name in label_dict.items()}

    # Start from an all-zero ("Background") map. Fall back to the input's own
    # size when the pipeline returned no segments at all.
    if result:
        image_width, image_height = result[0]["mask"].size
    else:
        image_width, image_height = image.size
    segmentation_map = np.zeros((image_height, image_width), dtype=np.uint8)

    for entry in result:
        class_idx = index_by_label.get(entry["label"])
        if class_idx is None:
            # A label missing from label_dict cannot be drawn or named; skip
            # it visibly rather than failing the whole request.
            print(f"Skipping unknown label: {entry['label']}")
            continue
        mask = np.array(entry["mask"])
        # Later entries overwrite earlier ones where masks overlap.
        segmentation_map[mask > 0] = class_idx

    unique_classes = np.unique(segmentation_map)
    detected_labels = [label_dict[int(idx)] for idx in unique_classes]

    print("Detected Classes:")
    for name in detected_labels:
        print(f"- {name}")

    fig, ax = plt.subplots(figsize=(8, 8))
    try:
        # "tab20" has 20 colours; pinning the range to [-0.5, 19.5] maps class
        # index i to colour i, so a class keeps the same colour no matter
        # which other classes happen to be present in the image.
        heatmap = ax.imshow(segmentation_map, cmap="tab20", vmin=-0.5, vmax=19.5)
        cbar = fig.colorbar(heatmap, ax=ax, ticks=unique_classes)
        cbar.ax.set_yticklabels(detected_labels)
        ax.set_title("Segmented Image with Detected Classes")
        ax.axis("off")

        # Render into memory: a fixed file name would be shared (and could be
        # overwritten) by concurrent requests and needs a writable directory.
        buffer = io.BytesIO()
        fig.savefig(buffer, format="png", bbox_inches="tight")
    finally:
        # Always release the figure, even if rendering failed.
        plt.close(fig)

    buffer.seek(0)
    rendered = Image.open(buffer)
    # Image.open is lazy; decode now so the result is fully materialised.
    rendered.load()
    return rendered
|
|
|
|
|
|
|
|
|
|
|
# Web UI: one image in, one rendered segmentation image out.
interface = gr.Interface(
    title="Clothes Segmentation with Colormap",
    description=(
        "Upload an image, and the segmentation model will produce an output "
        "with a colormap applied to the segmented classes."
    ),
    fn=segment_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    # Sample images offered in the UI; paths are relative to the working
    # directory.
    examples=["1.jpg", "2.jpg", "3.jpg"],
)

# Start the Gradio server.
interface.launch()
|
|