Matthijs Hollemans committed
Commit f1cff84 • Parent: 6a36cd0

make noice

Files changed (7)
  1. .gitattributes +4 -0
  2. README.md +1 -1
  3. app.py +76 -10
  4. cat-3.jpg +3 -0
  5. construction-site.jpg +3 -0
  6. dog-cat.jpg +3 -0
  7. football-match.jpg +3 -0
.gitattributes CHANGED
@@ -25,3 +25,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zstandard filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+cat-3.jpg filter=lfs diff=lfs merge=lfs -text
+construction-site.jpg filter=lfs diff=lfs merge=lfs -text
+dog-cat.jpg filter=lfs diff=lfs merge=lfs -text
+football-match.jpg filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -2,7 +2,7 @@
 title: MobileViT Deeplab Demo
 emoji: 🐕
 colorFrom: black
-colorTo: black
+colorTo: blue
 sdk: gradio
 sdk_version: 3.0.24
 app_file: app.py
app.py CHANGED
@@ -5,11 +5,11 @@ from PIL import Image
 import torch
 from transformers import MobileViTFeatureExtractor, MobileViTForSemanticSegmentation
 
+
 model_checkpoint = "apple/deeplabv3-mobilevit-small"
-feature_extractor = MobileViTFeatureExtractor.from_pretrained(model_checkpoint) #, do_center_crop=False, size=(512, 512))
+feature_extractor = MobileViTFeatureExtractor.from_pretrained(model_checkpoint)
 model = MobileViTForSemanticSegmentation.from_pretrained(model_checkpoint).eval()
 
-
 palette = np.array(
     [
         [ 0, 0, 0], [192, 0, 0], [ 0, 192, 0], [192, 192, 0],
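Note: the trailing comment deleted in this hunk hints that the center crop is configurable when loading the feature extractor. A minimal sketch of that override, using only the keyword arguments named in the removed comment (not verified against this transformers version):

feature_extractor = MobileViTFeatureExtractor.from_pretrained(
    model_checkpoint,
    do_center_crop=False,   # kwargs taken from the comment removed above
    size=(512, 512),        # feed the resized image without cropping
)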
@@ -21,6 +21,69 @@ palette = np.array(
     ],
     dtype=np.uint8)
 
+labels = [
+    "background",
+    "aeroplane",
+    "bicycle",
+    "bird",
+    "boat",
+    "bottle",
+    "bus",
+    "car",
+    "cat",
+    "chair",
+    "cow",
+    "diningtable",
+    "dog",
+    "horse",
+    "motorbike",
+    "person",
+    "pottedplant",
+    "sheep",
+    "sofa",
+    "train",
+    "tvmonitor",
+]
+
+# Draw the labels. Light colors use black text, dark colors use white text.
+inverted = [ 0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20 ]
+labels_colored = []
+for i in range(len(labels)):
+    r, g, b = palette[i]
+    label = labels[i]
+    color = "white" if i in inverted else "black"
+    text = "<span style='background-color: rgb(%d, %d, %d); color: %s; padding: 2px 4px;'>%s</span>" % (r, g, b, color, label)
+    labels_colored.append(text)
+labels_text = ", ".join(labels_colored)
+
+title = "Semantic Segmentation with MobileViT and DeepLabV3"
+
+description = """
+The input image is resized and center cropped to 512×512 pixels. The segmentation output is 32×32 pixels.<br>
+This model has been trained on <a href="http://host.robots.ox.ac.uk/pascal/VOC/">Pascal VOC</a>.
+The classes are:
+""" + labels_text + "</p>"
+
+article = """
+<div style='margin:20px auto;'>
+
+<p>Sources:</p>
+
+<p>📜 <a href="https://arxiv.org/abs/2110.02178">MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer</a></p>
+
+<p>🏋️ Original pretrained weights from <a href="https://github.com/apple/ml-cvnets">this GitHub repo</a></p>
+
+<p>🐙 Example images from <a href="https://huggingface.co/datasets/mishig/sample_images">this dataset</a></p>
+
+</div>
+"""
+
+examples = [
+    ["cat-3.jpg"],
+    ["construction-site.jpg"],
+    ["dog-cat.jpg"],
+    ["football-match.jpg"],
+]
 
 
 def predict(image):
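Note: the hand-picked `inverted` list above decides which swatches get white text. A hedged alternative sketch (not part of this commit) would derive the text color from each swatch's luma instead of maintaining the index list by hand:

def text_color(r, g, b):
    # Rec. 601 luma: dark swatches get white text, light ones black.
    return "white" if 0.299 * r + 0.587 * g + 0.114 * b < 128 else "black"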
@@ -35,7 +98,7 @@ def predict(image):
     # Class predictions for each pixel.
     classes = outputs.logits.argmax(1).squeeze().numpy().astype(np.uint8)
 
-    # Super slow method but it works
+    # Super slow method but it works... should probably improve this.
     colored = np.zeros((classes.shape[0], classes.shape[1], 3), dtype=np.uint8)
     for y in range(classes.shape[0]):
        for x in range(classes.shape[1]):
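Note: the loop flagged "Super slow method" fills `colored` one pixel at a time. Since `palette` is a (21, 3) uint8 array and `classes` is a (H, W) array of class ids, NumPy fancy indexing performs the same id-to-RGB mapping in one step; a sketch of that replacement (not what this commit does):

colored = palette[classes]   # shape (H, W, 3), dtype uint8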
@@ -43,26 +106,29 @@ def predict(image):
 
     # Resize predictions to input size (not original size).
     colored = Image.fromarray(colored)
-    colored = colored.resize((resized.shape[1], resized.shape[0]), resample=Image.NEAREST)
+    colored = colored.resize((resized.shape[1], resized.shape[0]), resample=Image.Resampling.NEAREST)
 
     # Keep everything that is not background.
     mask = (classes != 0) * 255
     mask = Image.fromarray(mask.astype(np.uint8)).convert("RGB")
-    mask = mask.resize((resized.shape[1], resized.shape[0]), resample=Image.NEAREST)
+    mask = mask.resize((resized.shape[1], resized.shape[0]), resample=Image.Resampling.NEAREST)
 
     # Blend with the input image.
     resized = Image.fromarray(resized)
     highlighted = Image.blend(resized, mask, 0.4)
 
+    #colored = colored.resize((256, 256), resample=Image.Resampling.BICUBIC)
+    #highlighted = highlighted.resize((256, 256), resample=Image.Resampling.BICUBIC)
+
     return colored, highlighted
 
 
 gr.Interface(
     fn=predict,
     inputs=gr.inputs.Image(label="Upload image"),
-    outputs=[gr.outputs.Image(label="Classes"), gr.outputs.Image(label="Highlighted")],
-    title="Semantic Segmentation with MobileViT and DeepLabV3",
+    outputs=[gr.outputs.Image(label="Classes"), gr.outputs.Image(label="Overlay")],
+    title=title,
+    description=description,
+    article=article,
+    examples=examples,
 ).launch()
-
-
-# TODO: combo box with some example images
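Note on the `Image.Resampling` change in this hunk: Pillow 9.1 moved the resampling constants into the `Image.Resampling` enum, and the bare names (`Image.NEAREST`, `Image.BICUBIC`, ...) were removed in Pillow 10. If older Pillow versions still needed to be supported, a small shim would work (a sketch, not part of this commit):

try:
    NEAREST = Image.Resampling.NEAREST
except AttributeError:   # Pillow < 9.1
    NEAREST = Image.NEAREST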
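Note: per the new description, the model predicts at 32×32 and the app upscales the colored result with NEAREST resampling. An alternative sketch (not what this commit does) would upsample the logits to the 512×512 input before the argmax, giving smoother class boundaries; `outputs` and `np` are the names used in app.py:

import torch.nn.functional as F

upsampled = F.interpolate(outputs.logits, size=(512, 512), mode="bilinear", align_corners=False)
classes = upsampled.argmax(1).squeeze().numpy().astype(np.uint8)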
cat-3.jpg ADDED

Git LFS Details

  • SHA256: ca00d7f8f53f03232185c70418d875bc98adfeb7d42238d7f01e2926ecafb3b2
  • Pointer size: 132 Bytes
  • Size of remote file: 1.67 MB
construction-site.jpg ADDED

Git LFS Details

  • SHA256: 2dec9e542ab0b1ac51894535014bd06c8392eb4da41ffcfa326b1188b6ce8762
  • Pointer size: 130 Bytes
  • Size of remote file: 92.7 kB
dog-cat.jpg ADDED

Git LFS Details

  • SHA256: e952088e64c1cf3f270137ba38648d2218e138e6c872dd0c8f80497d247d0536
  • Pointer size: 130 Bytes
  • Size of remote file: 99.5 kB
football-match.jpg ADDED

Git LFS Details

  • SHA256: d65b6f72943d5e2d4f7e5e4dedfb93aea0fbbda140ae7c3ee772124b579e07c4
  • Pointer size: 130 Bytes
  • Size of remote file: 55.6 kB