Darinnn commited on
Commit
f0d0283
·
verified ·
1 Parent(s): 0218d79

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +99 -0
app.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import gradio as gr
from transformers import pipeline
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

# Load the clothes-segmentation pipeline once at import time so every request
# reuses the same model instance (the first call downloads the weights).
pipe = pipeline("image-segmentation", model="mattmdjaga/segformer_b2_clothes")
+ # Your predefined label dictionary
13
+ label_dict = {
14
+ 0: "Background",
15
+ 1: "Hat",
16
+ 2: "Hair",
17
+ 3: "Sunglasses",
18
+ 4: "Upper-clothes",
19
+ 5: "Skirt",
20
+ 6: "Pants",
21
+ 7: "Dress",
22
+ 8: "Belt",
23
+ 9: "Left-shoe",
24
+ 10: "Right-shoe",
25
+ 11: "Face",
26
+ 12: "Left-leg",
27
+ 13: "Right-leg",
28
+ 14: "Left-arm",
29
+ 15: "Right-arm",
30
+ 16: "Bag",
31
+ 17: "Scarf",
32
+ }
33
+
34
+ # Function to process the image and generate the segmentation map
35
+
36
+
37
+ # Function to process the image and generate the segmentation map
38
+ def segment_image(image):
39
+ # Perform segmentation
40
+ result = pipe(image)
41
+
42
+ # Initialize an empty array for the segmentation map
43
+ image_width, image_height = result[0]["mask"].size
44
+ segmentation_map = np.zeros((image_height, image_width), dtype=np.uint8)
45
+
46
+ # Combine masks into a single segmentation map
47
+ for entry in result:
48
+ label = entry["label"] # Get the label (e.g., "Hair", "Upper-clothes")
49
+ mask = np.array(entry["mask"]) # Convert PIL Image to NumPy array
50
+
51
+ # Find the index of the label in the original label dictionary
52
+ class_idx = [key for key, value in label_dict.items() if value == label][0]
53
+
54
+ # Assign the correct class index to the segmentation map
55
+ segmentation_map[mask > 0] = class_idx
56
+
57
+
58
+ # Get the unique class indices in the segmentation map
59
+ unique_classes = np.unique(segmentation_map)
60
+ # Print the names of the detected classes
61
+ print("Detected Classes:")
62
+ for class_idx in unique_classes:
63
+ print(f"- {label_dict[class_idx]}")
64
+
65
+ # Create a matplotlib figure and visualize the segmentation map
66
+ plt.figure(figsize=(8, 8))
67
+ plt.imshow(segmentation_map, cmap="tab20") # Visualize using a colormap
68
+ # Get the unique class indices in the segmentation map
69
+ unique_classes = np.unique(segmentation_map)
70
+
71
+
72
+ # Filter the label dictionary to include only detected classes
73
+ filtered_labels = {idx: label_dict[idx] for idx in unique_classes}
74
+
75
+ # Create a dynamic colorbar with only the detected classes
76
+ cbar = plt.colorbar(ticks=unique_classes)
77
+ cbar.ax.set_yticklabels([filtered_labels[i] for i in unique_classes])
78
+ plt.title("Segmented Image with Detected Classes")
79
+ plt.axis("off")
80
+ plt.savefig("segmented_output.png", bbox_inches="tight")
81
+ plt.close()
82
+ return Image.open("segmented_output.png")
83
+
84
+
85
+
86
+
87
+ # Gradio interface
88
+ interface = gr.Interface(
89
+ fn=segment_image,
90
+ inputs=gr.Image(type="pil"), # Input is an image
91
+ outputs=gr.Image(type="pil"), # Output is an image with the colormap
92
+ #examples=["example_image.jpg"], # Use the saved image as an example
93
+ examples=["1.jpg", "2.jpg", "3.jpg"],
94
+ title="Clothes Segmentation with Colormap",
95
+ description="Upload an image, and the segmentation model will produce an output with a colormap applied to the segmented classes."
96
+ )
97
+
98
+ # Launch the app
99
+ interface.launch()