lakxs committed on
Commit
26b8a9d
1 Parent(s): b247a09

Upload 10 files

Browse files
Files changed (10) hide show
  1. .gitattributes +3 -11
  2. README.md +5 -5
  3. app.py +133 -0
  4. labels.txt +18 -0
  5. person-1.jpg +0 -0
  6. person-2.jpg +0 -0
  7. person-3.jpg +0 -0
  8. person-4.jpg +0 -0
  9. person-5.jpg +0 -0
  10. requirements.txt +6 -0
.gitattributes CHANGED
@@ -1,35 +1,27 @@
1
  *.7z filter=lfs diff=lfs merge=lfs -text
2
  *.arrow filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
 
4
  *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
  *.ftz filter=lfs diff=lfs merge=lfs -text
7
  *.gz filter=lfs diff=lfs merge=lfs -text
8
  *.h5 filter=lfs diff=lfs merge=lfs -text
9
  *.joblib filter=lfs diff=lfs merge=lfs -text
10
  *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
  *.model filter=lfs diff=lfs merge=lfs -text
13
  *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
  *.onnx filter=lfs diff=lfs merge=lfs -text
17
  *.ot filter=lfs diff=lfs merge=lfs -text
18
  *.parquet filter=lfs diff=lfs merge=lfs -text
19
  *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
  *.pt filter=lfs diff=lfs merge=lfs -text
23
  *.pth filter=lfs diff=lfs merge=lfs -text
24
  *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
  *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
  *.tflite filter=lfs diff=lfs merge=lfs -text
30
  *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
  *.xz filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
  *.7z filter=lfs diff=lfs merge=lfs -text
2
  *.arrow filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
5
  *.bz2 filter=lfs diff=lfs merge=lfs -text
 
6
  *.ftz filter=lfs diff=lfs merge=lfs -text
7
  *.gz filter=lfs diff=lfs merge=lfs -text
8
  *.h5 filter=lfs diff=lfs merge=lfs -text
9
  *.joblib filter=lfs diff=lfs merge=lfs -text
10
  *.lfs.* filter=lfs diff=lfs merge=lfs -text
 
11
  *.model filter=lfs diff=lfs merge=lfs -text
12
  *.msgpack filter=lfs diff=lfs merge=lfs -text
 
 
13
  *.onnx filter=lfs diff=lfs merge=lfs -text
14
  *.ot filter=lfs diff=lfs merge=lfs -text
15
  *.parquet filter=lfs diff=lfs merge=lfs -text
16
  *.pb filter=lfs diff=lfs merge=lfs -text
 
 
17
  *.pt filter=lfs diff=lfs merge=lfs -text
18
  *.pth filter=lfs diff=lfs merge=lfs -text
19
  *.rar filter=lfs diff=lfs merge=lfs -text
20
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 
21
  *.tar.* filter=lfs diff=lfs merge=lfs -text
 
22
  *.tflite filter=lfs diff=lfs merge=lfs -text
23
  *.tgz filter=lfs diff=lfs merge=lfs -text
 
24
  *.xz filter=lfs diff=lfs merge=lfs -text
25
  *.zip filter=lfs diff=lfs merge=lfs -text
26
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
27
  *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,10 +1,10 @@
1
  ---
2
- title: Mymy
3
- emoji: 📊
4
- colorFrom: yellow
5
- colorTo: green
6
  sdk: gradio
7
- sdk_version: 4.3.0
8
  app_file: app.py
9
  pinned: false
10
  ---
 
1
  ---
2
+ title: segformer-b5-finetuned-cityscapes-1024-1024
3
+ emoji: 💻
4
+ colorFrom: blue
5
+ colorTo: red
6
  sdk: gradio
7
+ sdk_version: 3.44.4
8
  app_file: app.py
9
  pinned: false
10
  ---
app.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
2
+ # from PIL import Image
3
+ # import requests
4
+ #
5
+ # feature_extractor = SegformerFeatureExtractor.from_pretrained("nvidia/segformer-b5-finetuned-cityscapes-1024-1024")
6
+ # model = SegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b5-finetuned-cityscapes-1024-1024")
7
+ #
8
+ # url = "http://images.cocodataset.org/val2017/000000039769.jpg"
9
+ # image = Image.open(requests.get(url, stream=True).raw)
10
+ #
11
+ # inputs = feature_extractor(images=image, return_tensors="pt")
12
+ # outputs = model(**inputs)
13
+ # logits = outputs.logits # shape (batch_size, num_labels, height/4, width/4)
14
+ import gradio as gr
15
+
16
+ from matplotlib import gridspec
17
+ import matplotlib.pyplot as plt
18
+ import numpy as np
19
+ from PIL import Image
20
+ import tensorflow as tf
21
+ from transformers import SegformerFeatureExtractor, TFSegformerForSemanticSegmentation
22
+
23
# Pull the pre-trained SegFormer-B5 checkpoint (fine-tuned on Cityscapes at
# 1024x1024) from the Hugging Face hub: the feature extractor handles image
# pre-processing, the TF model produces per-pixel class logits.
_CHECKPOINT = "nvidia/segformer-b5-finetuned-cityscapes-1024-1024"
feature_extractor = SegformerFeatureExtractor.from_pretrained(_CHECKPOINT)
model = TFSegformerForSemanticSegmentation.from_pretrained(_CHECKPOINT)
29
+
30
def ade_palette():
    """Return the fixed 18-colour RGB palette used to paint class ids.

    NOTE(review): despite the name this is not the official ADE20K palette —
    it is a hand-picked list of 18 colours. The checkpoint loaded above is a
    Cityscapes model (19 classes) while labels.txt holds 18 human-parsing
    labels; confirm which label set is actually intended.
    """
    swatches = (
        (204, 87, 92),
        (112, 185, 212),
        (45, 189, 106),
        (234, 123, 67),
        (78, 56, 123),
        (210, 32, 89),
        (90, 180, 56),
        (155, 102, 200),
        (33, 147, 176),
        (255, 183, 76),
        (67, 123, 89),
        (190, 60, 45),
        (134, 112, 200),
        (56, 45, 189),
        (200, 56, 123),
        (87, 92, 204),
        (120, 56, 123),
        (45, 78, 123),
    )
    return [list(rgb) for rgb in swatches]
52
+
53
# Human-readable class names, one per line in labels.txt (index = class id).
labels_list = []

# Fix: the original used `line[:-1]`, which strips the last *character* of
# the final line when the file lacks a trailing newline. rstrip("\n") removes
# only the newline itself. Encoding is pinned so the read is platform-stable.
with open("labels.txt", "r", encoding="utf-8") as fp:
    for line in fp:
        labels_list.append(line.rstrip("\n"))

# Colour lookup table: row i is the RGB colour painted for class id i.
colormap = np.asarray(ade_palette())
60
+
61
def label_to_color_image(label):
    """Translate a 2-D array of class ids into an RGB image.

    Each id indexes a row of the module-level ``colormap``.

    Args:
        label: 2-D integer array of class ids.

    Returns:
        Array of shape ``label.shape + (3,)`` with RGB values.

    Raises:
        ValueError: if ``label`` is not 2-D or contains an id outside the
            palette.
    """
    if label.ndim != 2:
        raise ValueError("Expect 2-D input label")
    highest = np.max(label)
    if highest >= len(colormap):
        raise ValueError("label value too large.")
    return colormap[label]
68
+
69
def draw_plot(pred_img, seg):
    """Render the blended segmentation image next to a colour legend.

    Args:
        pred_img: HxWx3 uint8 array — the photo blended with class colours.
        seg: 2-D tensor of predicted class ids (must expose ``.numpy()``).

    Returns:
        The matplotlib Figure containing both panels.
    """
    fig = plt.figure(figsize=(20, 15))
    layout = gridspec.GridSpec(1, 2, width_ratios=[6, 1])

    # Left panel: the segmented picture itself, no axes.
    plt.subplot(layout[0])
    plt.imshow(pred_img)
    plt.axis('off')

    # Right panel: one colour swatch per class that actually appears,
    # labelled with its name from labels.txt.
    names = np.asarray(labels_list)
    id_column = np.arange(len(names)).reshape(len(names), 1)
    swatches = label_to_color_image(id_column)
    present = np.unique(seg.numpy().astype("uint8"))

    legend = plt.subplot(layout[1])
    plt.imshow(swatches[present].astype(np.uint8), interpolation="nearest")
    legend.yaxis.tick_right()
    plt.yticks(range(len(present)), names[present])
    plt.xticks([], [])
    legend.tick_params(width=0.0, labelsize=25)
    return fig
89
+
90
def sepia(input_img):
    """Segment an input image and return a matplotlib figure of the overlay.

    NOTE(review): the name ``sepia`` looks like a leftover from a template —
    the function performs semantic segmentation, not a sepia filter. The name
    is kept because the Gradio Interface below binds to it.

    Args:
        input_img: HxWx3 uint8 array as delivered by ``gr.Image()``.

    Returns:
        matplotlib Figure: 50/50 blend of photo and class colours plus legend.
    """
    pil_img = Image.fromarray(input_img)

    # Run the model; SegFormer returns channels-first logits at reduced
    # resolution.
    batch = feature_extractor(images=pil_img, return_tensors="tf")
    predictions = model(**batch)

    # channels-first -> channels-last, then upsample back to the input size.
    # PIL's .size is (width, height), hence the [::-1] to get (height, width).
    scores = tf.transpose(predictions.logits, [0, 2, 3, 1])
    scores = tf.image.resize(
        scores, pil_img.size[::-1]
    )  # We reverse the shape of `image` because `image.size` returns width and height.
    seg = tf.math.argmax(scores, axis=-1)[0]

    # Paint every predicted class with its palette colour.
    overlay = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
    seg_np = seg.numpy()
    for class_id, rgb in enumerate(colormap):
        overlay[seg_np == class_id, :] = rgb

    # Show image + mask: equal-weight blend of photo and class colours.
    blended = np.array(pil_img) * 0.5 + overlay * 0.5
    blended = blended.astype(np.uint8)

    return draw_plot(blended, seg)
115
+
116
+
117
# Gradio UI: upload an image (or pick a bundled example) and get back the
# segmentation-overlay plot produced by `sepia`.
demo = gr.Interface(fn=sepia,
                    inputs=gr.Image(),  # gr.Image no longer accepts 'shape'
                    outputs=['plot'],
                    examples=[
                        "person-1.jpg",
                        "person-2.jpg",
                        "person-3.jpg",
                        "person-4.jpg",
                        "person-5.jpg"
                    ],
                    allow_flagging='never')

if __name__ == "__main__":
    # Fix: the original never called launch(), so running `python app.py`
    # built the interface but never started the web server.
    demo.launch()
133
+
labels.txt ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Background
2
+ Hat
3
+ Hair
4
+ Sunglasses
5
+ Upper-clothes
6
+ Skirt
7
+ Pants
8
+ Dress
9
+ Belt
10
+ Left-shoe
11
+ Right-shoe
12
+ Face
13
+ Left-leg
14
+ Right-leg
15
+ Left-arm
16
+ Right-arm
17
+ Bag
18
+ Scarf
person-1.jpg ADDED
person-2.jpg ADDED
person-3.jpg ADDED
person-4.jpg ADDED
person-5.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch
2
+ transformers
3
+ tensorflow
4
+ numpy
5
+ Pillow
6
+ matplotlib