Commit f1cff84 by Matthijs Hollemans
Parent(s): 6a36cd0
Commit message: make noice
Files changed:
- .gitattributes  +4 -0
- README.md  +1 -1
- app.py  +76 -10
- cat-3.jpg  +3 -0
- construction-site.jpg  +3 -0
- dog-cat.jpg  +3 -0
- football-match.jpg  +3 -0
.gitattributes CHANGED

@@ -25,3 +25,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zstandard filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+cat-3.jpg filter=lfs diff=lfs merge=lfs -text
+construction-site.jpg filter=lfs diff=lfs merge=lfs -text
+dog-cat.jpg filter=lfs diff=lfs merge=lfs -text
+football-match.jpg filter=lfs diff=lfs merge=lfs -text
README.md CHANGED

@@ -2,7 +2,7 @@
 title: MobileViT Deeplab Demo
 emoji: π
 colorFrom: black
-colorTo:
+colorTo: blue
 sdk: gradio
 sdk_version: 3.0.24
 app_file: app.py
app.py CHANGED

@@ -5,11 +5,11 @@ from PIL import Image
 import torch
 from transformers import MobileViTFeatureExtractor, MobileViTForSemanticSegmentation
 
+
 model_checkpoint = "apple/deeplabv3-mobilevit-small"
-feature_extractor = MobileViTFeatureExtractor.from_pretrained(model_checkpoint)
+feature_extractor = MobileViTFeatureExtractor.from_pretrained(model_checkpoint)
 model = MobileViTForSemanticSegmentation.from_pretrained(model_checkpoint).eval()
 
-
 palette = np.array(
     [
         [ 0, 0, 0], [192, 0, 0], [ 0, 192, 0], [192, 192, 0],
@@ -21,6 +21,69 @@ palette = np.array(
     ],
     dtype=np.uint8)
 
+labels = [
+    "background",
+    "aeroplane",
+    "bicycle",
+    "bird",
+    "boat",
+    "bottle",
+    "bus",
+    "car",
+    "cat",
+    "chair",
+    "cow",
+    "diningtable",
+    "dog",
+    "horse",
+    "motorbike",
+    "person",
+    "pottedplant",
+    "sheep",
+    "sofa",
+    "train",
+    "tvmonitor",
+]
+
+# Draw the labels. Light colors use black text, dark colors use white text.
+inverted = [ 0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20 ]
+labels_colored = []
+for i in range(len(labels)):
+    r, g, b = palette[i]
+    label = labels[i]
+    color = "white" if i in inverted else "black"
+    text = "<span style='background-color: rgb(%d, %d, %d); color: %s; padding: 2px 4px;'>%s</span>" % (r, g, b, color, label)
+    labels_colored.append(text)
+labels_text = ", ".join(labels_colored)
+
+title = "Semantic Segmentation with MobileViT and DeepLabV3"
+
+description = """
+The input image is resized and center cropped to 512×512 pixels. The segmentation output is 32×32 pixels.<br>
+This model has been trained on <a href="http://host.robots.ox.ac.uk/pascal/VOC/">Pascal VOC</a>.
+The classes are:
+""" + labels_text + "</p>"
+
+article = """
+<div style='margin:20px auto;'>
+
+<p>Sources:<p>
+
+<p><a href="https://arxiv.org/abs/2110.02178">MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer</a></p>
+
+<p>Original pretrained weights from <a href="https://github.com/apple/ml-cvnets">this GitHub repo</a></p>
+
+<p>Example images from <a href="https://huggingface.co/datasets/mishig/sample_images">this dataset</a><p>
+
+</div>
+"""
+
+examples = [
+    ["cat-3.jpg"],
+    ["construction-site.jpg"],
+    ["dog-cat.jpg"],
+    ["football-match.jpg"],
+]
 
 
 def predict(image):
@@ -35,7 +98,7 @@ def predict(image):
     # Class predictions for each pixel.
     classes = outputs.logits.argmax(1).squeeze().numpy().astype(np.uint8)
 
-    # Super slow method but it works
+    # Super slow method but it works... should probably improve this.
     colored = np.zeros((classes.shape[0], classes.shape[1], 3), dtype=np.uint8)
     for y in range(classes.shape[0]):
         for x in range(classes.shape[1]):
@@ -43,26 +106,29 @@ def predict(image):
 
     # Resize predictions to input size (not original size).
     colored = Image.fromarray(colored)
-    colored = colored.resize((resized.shape[1], resized.shape[0]), resample=Image.NEAREST)
+    colored = colored.resize((resized.shape[1], resized.shape[0]), resample=Image.Resampling.NEAREST)
 
     # Keep everything that is not background.
     mask = (classes != 0) * 255
     mask = Image.fromarray(mask.astype(np.uint8)).convert("RGB")
-    mask = mask.resize((resized.shape[1], resized.shape[0]), resample=Image.NEAREST)
+    mask = mask.resize((resized.shape[1], resized.shape[0]), resample=Image.Resampling.NEAREST)
 
     # Blend with the input image.
     resized = Image.fromarray(resized)
     highlighted = Image.blend(resized, mask, 0.4)
 
+    #colored = colored.resize((256, 256), resample=Image.Resampling.BICUBIC)
+    #highlighted = highlighted.resize((256, 256), resample=Image.Resampling.BICUBIC)
+
    return colored, highlighted
 
 
 gr.Interface(
     fn=predict,
     inputs=gr.inputs.Image(label="Upload image"),
-    outputs=[gr.outputs.Image(label="Classes"), gr.outputs.Image(label="
-    title=
+    outputs=[gr.outputs.Image(label="Classes"), gr.outputs.Image(label="Overlay")],
+    title=title,
+    description=description,
+    article=article,
+    examples=examples,
 ).launch()
-
-
-# TODO: combo box with some example images
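On the Image.NEAREST to Image.Resampling.NEAREST change: Pillow moved the resampling constants into the Image.Resampling enum in version 9.1, and later releases drop the old module-level names. A small compatibility sketch, not part of the commit, for code that has to run on both old and new Pillow:

from PIL import Image

# Prefer the Image.Resampling enum (Pillow >= 9.1); fall back to the legacy
# module-level constant on older Pillow releases.
try:
    NEAREST = Image.Resampling.NEAREST
except AttributeError:
    NEAREST = Image.NEAREST

# e.g. colored = colored.resize((width, height), resample=NEAREST)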
cat-3.jpg  ADDED (Git LFS)
construction-site.jpg  ADDED (Git LFS)
dog-cat.jpg  ADDED (Git LFS)
football-match.jpg  ADDED (Git LFS)
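The four added images feed the `examples` list in app.py. For context, a minimal sketch of the inference path the new description refers to (resize and center crop to 512×512, coarse segmentation logits); the exact shapes are taken from the added description text, not verified here:

import torch
from PIL import Image
from transformers import MobileViTFeatureExtractor, MobileViTForSemanticSegmentation

checkpoint = "apple/deeplabv3-mobilevit-small"
feature_extractor = MobileViTFeatureExtractor.from_pretrained(checkpoint)
model = MobileViTForSemanticSegmentation.from_pretrained(checkpoint).eval()

image = Image.open("cat-3.jpg")                                  # one of the added example images
inputs = feature_extractor(images=image, return_tensors="pt")    # resize + center crop
with torch.no_grad():
    outputs = model(**inputs)

# logits: (batch, num_classes, height, width); argmax over the class axis
# gives the per-pixel class map that predict() colorizes in app.py.
classes = outputs.logits.argmax(1).squeeze().numpy()
print(outputs.logits.shape, classes.shape)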