johnnv committed
Commit 964b7b7
1 Parent(s): 4fb48aa

update app

Files changed (3)
  1. README.md +4 -4
  2. app.py +126 -14
  3. requirements.txt +2 -0
README.md CHANGED
@@ -2,19 +2,19 @@
 app_file: app.py
 colorFrom: red
 colorTo: green
-datasets:
+datasets:
 - lapix/CCAgT
 emoji: 💻
 license: mit
-models:
+models:
 - lapix/segformer-b3-finetuned-ccagt-400-300
 pinned: true
 sdk: gradio
 sdk_version: "3.3.1"
-tags:
+tags:
 - vision
 - image-segmentation
-task_ids:
+task_ids:
 - semantic-segmentation
 title: "SegFormer B3 CCAgT"
 ---
app.py CHANGED
@@ -1,13 +1,21 @@
+import cv2
 import gradio as gr
+import matplotlib
+import matplotlib.pyplot as plt
 import numpy as np
 import torch
+from CCAgT_utils.categories import CategoriesInfos
+from CCAgT_utils.slice import __create_xy_slice
 from CCAgT_utils.types.mask import Mask
+from CCAgT_utils.visualization import plot
 from PIL import Image
 from torch import nn
 from transformers import SegformerFeatureExtractor
 from transformers import SegformerForSemanticSegmentation
+from transformers.modeling_outputs import SemanticSegmenterOutput
 
 
+matplotlib.use('Agg')
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 model_hub_name = 'lapix/segformer-b3-finetuned-ccagt-400-300'
@@ -15,37 +23,130 @@ model_hub_name = 'lapix/segformer-b3-finetuned-ccagt-400-300'
 model = SegformerForSemanticSegmentation.from_pretrained(
     model_hub_name,
 ).to(device)
+model.eval()
+
 feature_extractor = SegformerFeatureExtractor.from_pretrained(
     model_hub_name,
 )
 
 
-def query_image(image):
-    image = np.array(image)
-    img = Image.fromarray(image)
-
-    pixel_values = feature_extractor(
+def segment(
+    image: Image.Image,
+) -> SemanticSegmenterOutput:
+    inputs = feature_extractor(
         image,
         return_tensors='pt',
     ).to(device)
 
-    with torch.no_grad():
-        outputs = model(pixel_values=pixel_values)
+    outputs = model(**inputs)
+
+    return outputs
 
-    logits = outputs.logits
+
+def post_processing(
+    outputs: SemanticSegmenterOutput,
+    target_size: tuple[int, int],
+) -> np.ndarray:
+    logits = outputs.logits.cpu()
 
     upsampled_logits = nn.functional.interpolate(
         logits,
-        size=img.size[::-1],  # (height, width)
+        size=target_size,
         mode='bilinear',
         align_corners=False,
     )
 
     segmentation_mask = upsampled_logits.argmax(dim=1)[0]
 
-    results = Mask(segmentation_mask).colorized() / 255
+    return np.array(segmentation_mask)
+
+
+def colorize(
+    mask: Mask,
+) -> np.ndarray:
+    return mask.colorized(CategoriesInfos()) / 255
+
+
+def check_and_resize(
+    image: np.ndarray,
+) -> np.ndarray:
+
+    if image.shape[0] > 1200 or image.shape[1] > 1600:
+        r = 1600.0 / image.shape[1]
+        dim = (1600, int(image.shape[0] * r))
+        return cv2.resize(image, dim, interpolation=cv2.INTER_AREA)
+
+    return image
+
+
+def process_big_images(
+    image: Image.Image,
+) -> Mask:
+    '''Process and post-processing for images bigger than 400x300'''
+    img = check_and_resize(np.asarray(image))
+    mask = np.zeros(shape=(img.shape[0], img.shape[1]), dtype=np.uint8)
+
+    for bbox in __create_xy_slice(image.size[1], image.size[0], 300, 400):
+        part = cv2.copyMakeBorder(
+            img,
+            bbox.y_init,
+            bbox.y_end,
+            bbox.x_init,
+            bbox.x_end,
+            cv2.BORDER_REFLECT,
+        )
+        target_size = (part.shape[0], part.shape[1])
+
+        outputs = segment(Image.fromarray(part))
+        msk = post_processing(outputs, target_size)
+
+        mask[bbox.slice_y, bbox.slice_x] = msk[bbox.slice_y, bbox.slice_x]
 
-    return results
+    return Mask(mask)
+
+
+def image_with_mask(
+    image: Image.Image,
+    mask: Mask,
+) -> plt.Figure:
+    fig = plt.figure(dpi=600)
+
+    plt.imshow(image)
+    plt.imshow(
+        mask.categorical,
+        cmap=mask.cmap(CategoriesInfos()),
+        vmax=max(mask.unique_ids),
+        vmin=min(mask.unique_ids),
+        interpolation='nearest',
+        alpha=0.4,
+    )
+    plt.axis('off')
+
+    return fig
+
+
+def categories_map(
+    mask: Mask,
+) -> plt.Figure:
+    fig = plt.figure(dpi=600)
+
+    handles = plot.create_handles(
+        CategoriesInfos(), selected_categories=mask.unique_ids,
+    )
+    plt.legend(handles=handles, fontsize=24, loc='center')
+    plt.axis('off')
+
+    return fig
+
+
+def main(image):
+    img = Image.fromarray(image)
+
+    mask = process_big_images(img)
+    mask_colorized = colorize(mask)
+    fig = image_with_mask(img, mask)
+
+    return categories_map(mask), mask_colorized, fig
 
 
 title = 'SegFormer (b3) - CCAgT dataset'
@@ -59,15 +160,26 @@ images with resolution of 400x300. The model was available at HF hub at
 examples = [
     [f'https://hf.co/{model_hub_name}/resolve/main/sampleA.png'],
     [f'https://hf.co/{model_hub_name}/resolve/main/sampleB.png'],
+] + [
+    [f'https://datasets-server.huggingface.co/assets/lapix/CCAgT/--/semantic_segmentation/test/{x}/image/image.jpg']
+    for x in {3, 10, 12, 18, 35, 78, 89}
 ]
 
+
 demo = gr.Interface(
-    query_image,
+    main,
     inputs=[gr.Image()],
-    outputs='image',
+    outputs=[
+        gr.Plot(label='Categories map'),
+        gr.Image(label='Mask'),
+        gr.Plot(label='Image with mask'),
+    ],
     title=title,
     description=description,
     examples=examples,
+    allow_flagging='never',
+    cache_examples=False,
 )
 
-demo.launch()
+if __name__ == '__main__':
+    demo.launch()
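
The heart of this change is process_big_images: instead of segmenting arbitrary inputs at full size, the app now slices the image into tiles matching the 400x300 training resolution (via CCAgT_utils' __create_xy_slice), segments each tile independently, and stitches the per-tile masks back into one full-size mask. A minimal, self-contained sketch of that sliding-window pattern, where predict_tile is a hypothetical stand-in for the segment + post_processing pair and plain border clamping replaces the app's cv2.copyMakeBorder reflection:

from typing import Callable

import numpy as np


def tiled_segmentation(
    image: np.ndarray,                                  # (H, W, 3) RGB array
    predict_tile: Callable[[np.ndarray], np.ndarray],   # tile -> (h, w) class-id mask
    tile_h: int = 300,
    tile_w: int = 400,
) -> np.ndarray:
    """Segment a large image tile by tile and stitch the per-tile
    class-id masks into one full-resolution mask."""
    height, width = image.shape[:2]
    full_mask = np.zeros((height, width), dtype=np.uint8)

    for y in range(0, height, tile_h):
        for x in range(0, width, tile_w):
            # Clamp the window at the image borders; the committed code
            # instead pads border tiles with cv2.BORDER_REFLECT.
            y_end = min(y + tile_h, height)
            x_end = min(x + tile_w, width)
            tile = image[y:y_end, x:x_end]
            # predict_tile must return a mask matching the tile's shape,
            # e.g. segment() followed by post_processing().
            full_mask[y:y_end, x:x_end] = predict_tile(tile)

    return full_mask

The stitching only works if predict_tile returns a mask with the same height and width as its input tile; in the app that holds because post_processing upsamples the logits back to the tile size before taking the argmax.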
requirements.txt CHANGED
@@ -1,4 +1,6 @@
 CCAgT-utils
+matplotlib
 numpy
+opencv-python
 torch
 transformers