Commit c304fb7 by Matthijs Hollemans (parent: 8239775)

segmentation demo

Files changed (3):
  1. README.md +2 -1
  2. app.py +43 -7
  3. requirements.txt +1 -1
README.md CHANGED
@@ -1,12 +1,13 @@
 ---
 title: MobileViT Deeplab Demo
-emoji: 🚀
+emoji: 🐕
 colorFrom: red
 colorTo: pink
 sdk: gradio
 sdk_version: 3.0.24
 app_file: app.py
 pinned: false
+license: apache-2.0
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
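Two Space-card settings change here: the emoji shown next to the Space in listings, and a new license tag. Both keys are documented in the configuration reference linked in the README itself.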
app.py CHANGED
@@ -1,15 +1,51 @@
+import numpy as np
 import gradio as gr
-from transformers import pipeline
+from PIL import Image
+
+import torch
+from transformers import MobileViTFeatureExtractor, MobileViTForSemanticSegmentation
 
-pipeline = pipeline(task="image-classification", model="apple/mobilevit-small")
+model_checkpoint = "apple/deeplabv3-mobilevit-small"
+feature_extractor = MobileViTFeatureExtractor.from_pretrained(model_checkpoint, do_center_crop=False, size=(512, 512))
+model = MobileViTForSemanticSegmentation.from_pretrained(model_checkpoint).eval()
 
+
+# From https://gist.github.com/kaixin96/457cc3d3be699f1f5b2fd4cdb638d4b4
+palette = np.array([
+    [  0,   0,   0], [128,   0,   0], [  0, 128,   0], [128, 128,   0], [  0,   0, 128],
+    [128,   0, 128], [  0, 128, 128], [128, 128, 128], [ 64,   0,   0], [192,   0,   0],
+    [ 64, 128,   0], [192, 128,   0], [ 64,   0, 128], [192,   0, 128], [ 64, 128, 128],
+    [192, 128, 128], [  0,  64,   0], [128,  64,   0], [  0, 192,   0], [128, 192,   0],
+    [  0,  64, 128]], dtype=np.uint8)
+
+
 def predict(image):
-    predictions = pipeline(image)
-    return {p["label"]: p["score"] for p in predictions}
+    with torch.no_grad():
+        inputs = feature_extractor(image, return_tensors="pt")
+        outputs = model(**inputs)
+
+    classes = outputs.logits.argmax(1).squeeze().numpy().astype(np.uint8)
+
+    # Super slow method but it works
+    colored = np.zeros((classes.shape[0], classes.shape[1], 3), dtype=np.uint8)
+    for y in range(classes.shape[0]):
+        for x in range(classes.shape[1]):
+            colored[y, x] = palette[classes[y, x]]
+
+    # TODO: overlay mask on image?
+
+    out_image = Image.fromarray(colored)
+    out_image = out_image.resize((image.shape[1], image.shape[0]), resample=Image.NEAREST)
+    return out_image
+
 
 gr.Interface(
     fn=predict,
-    inputs=gr.inputs.Image(label="Upload image", type="filepath"),
-    outputs=gr.outputs.Label(num_top_classes=5),
-    title="This is a title",
+    inputs=gr.inputs.Image(label="Upload image"),
+    outputs=gr.outputs.Image(),
+    title="Semantic Segmentation with MobileViT and DeepLabV3",
 ).launch()
+
+
+# TODO: combo box with some example images
+# TODO: combo box with classes to show on the output, if none then do argmax
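The commit's own "Super slow method but it works" comment invites a faster version: since palette is a (21, 3) NumPy array and classes holds integer class ids, the per-pixel loop collapses into a single fancy-indexing lookup. The sketch below is an assumption, not part of the commit: colorize and overlay are hypothetical helper names, and the alpha-blend is just one way to answer the "TODO: overlay mask on image?" note.

import numpy as np
from PIL import Image

def colorize(classes: np.ndarray, palette: np.ndarray) -> np.ndarray:
    # Fancy indexing maps every class id to its RGB colour in one shot:
    # an (H, W) int array indexing a (21, 3) palette yields (H, W, 3).
    return palette[classes]

def overlay(image: np.ndarray, mask_rgb: np.ndarray, alpha: float = 0.5) -> Image.Image:
    # Hypothetical take on the overlay TODO: upscale the mask to the
    # input resolution, then alpha-blend it over the original photo.
    base = Image.fromarray(image).convert("RGB")
    mask = Image.fromarray(mask_rgb).resize(base.size, resample=Image.NEAREST)
    return Image.blend(base, mask, alpha)

With these, the zeros-plus-double-loop block in predict reduces to colored = colorize(classes, palette), and the resize-and-return lines could become return overlay(image, colored).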
requirements.txt CHANGED
@@ -1,2 +1,2 @@
-transformers
+git+https://github.com/huggingface/transformers.git
 torch
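Swapping the PyPI transformers package for an install from the GitHub main branch suggests the MobileViT segmentation classes used in app.py had not yet appeared in a stable release at the time of this commit; presumably the dependency can return to a plain transformers pin once they ship in a release.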