| import gradio as gr |
| from transformers import pipeline |
| from PIL import Image |
| import numpy as np |
| import matplotlib.pyplot as plt |
|
|
| |
# Shared Hugging Face image-segmentation pipeline, created once at import time
# and reused for every call to segment_image. Judging by the model id it is a
# SegFormer-B2 checkpoint tuned for clothes parsing (confirm on the model card).
pipe = pipeline(
    task="image-segmentation",
    model="mattmdjaga/segformer_b2_clothes",
)
|
|
|
|
|
|
| |
# Class index -> human-readable class name. The index is the position in the
# list below (0 = "Background" ... 17 = "Scarf"). Names are matched by exact
# string against the labels the segmentation pipeline emits, so they must be
# spelled identically.
label_dict = dict(enumerate([
    "Background",
    "Hat",
    "Hair",
    "Sunglasses",
    "Upper-clothes",
    "Skirt",
    "Pants",
    "Dress",
    "Belt",
    "Left-shoe",
    "Right-shoe",
    "Face",
    "Left-leg",
    "Right-leg",
    "Left-arm",
    "Right-arm",
    "Bag",
    "Scarf",
]))
|
|
| |
|
|
|
|
| |
def segment_image(image):
    """Segment clothing regions in *image* and render them as a labelled colour map.

    Runs the module-level ``pipe`` image-segmentation pipeline and merges the
    per-label masks it returns into a single class-index map, using the indices
    from ``label_dict``. It then prints the detected class names and renders the
    map with a colourbar naming each detected class.

    Args:
        image: The input image as a ``PIL.Image.Image`` (the Gradio input is
            declared with ``type="pil"``), or ``None`` when the user submits
            without uploading anything.

    Returns:
        A ``PIL.Image.Image`` containing the rendered figure.

    Raises:
        gr.Error: If no image was supplied.
    """
    # Local imports keep this block self-contained; both ship with libraries
    # the file already depends on (stdlib and matplotlib).
    import io
    from matplotlib.figure import Figure

    if image is None:
        raise gr.Error("Please upload an image.")

    result = pipe(image)

    # Reverse lookup (label name -> class index), built once per call instead
    # of scanning label_dict for every mask.
    name_to_idx = {name: idx for idx, name in label_dict.items()}

    # Masks share the input's size. Fall back to the input image itself when
    # the pipeline returned no segments. PIL sizes are (width, height).
    reference = result[0]["mask"] if result else image
    image_width, image_height = reference.size
    segmentation_map = np.zeros((image_height, image_width), dtype=np.uint8)

    for entry in result:
        label = entry["label"]
        class_idx = name_to_idx.get(label)
        if class_idx is None:
            # Unknown label: skip it rather than crash with an IndexError.
            print(f"Skipping unknown label: {label}")
            continue
        mask = np.array(entry["mask"])
        segmentation_map[mask > 0] = class_idx

    unique_classes = np.unique(segmentation_map)

    print("Detected Classes:")
    for class_idx in unique_classes:
        print(f"- {label_dict[class_idx]}")

    # Build the figure without pyplot's global figure registry. Nothing leaks
    # if rendering raises, and concurrent calls do not share "current figure"
    # state.
    fig = Figure(figsize=(8, 8))
    ax = fig.add_subplot()

    # Pin the colour scale to the colormap's discrete slots so that class k
    # always gets colour k, and each colourbar band is centred on its integer
    # class index. Autoscaling to the data range made colours shift between
    # images and degenerated when only a single class was present.
    cmap = plt.get_cmap("tab20")
    im = ax.imshow(segmentation_map, cmap=cmap, vmin=-0.5, vmax=cmap.N - 0.5)

    cbar = fig.colorbar(im, ax=ax, ticks=unique_classes)
    cbar.ax.set_yticklabels([label_dict[i] for i in unique_classes])
    ax.set_title("Segmented Image with Detected Classes")
    ax.axis("off")

    # Render to memory instead of a fixed "segmented_output.png" on disk.
    # Concurrent calls could overwrite a shared file, and Image.open would
    # otherwise keep reading it lazily through an open handle.
    buffer = io.BytesIO()
    fig.savefig(buffer, format="png", bbox_inches="tight")
    buffer.seek(0)
    rendered = Image.open(buffer)
    rendered.load()  # decode now, while the buffer is guaranteed alive
    return rendered
|
|
|
|
|
|
|
|
| |
# Gradio UI: one PIL image in, the rendered segmentation figure (PIL image) out.
interface = gr.Interface(
    fn=segment_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    # Example images offered in the UI; these files must be present for the
    # examples to load (presumably next to this script - confirm).
    examples=["1.jpg", "2.jpg", "3.jpg"],
    title="Clothes Segmentation with Colormap",
    description="Upload an image, and the segmentation model will produce an output with a colormap applied to the segmented classes."
)

# Start the server only when run as a script (e.g. `python app.py`), so that
# importing this module does not launch a blocking web server as a side effect.
if __name__ == "__main__":
    interface.launch()
|
|