Adapter commited on
Commit
a056b0b
·
1 Parent(s): af3233a
Files changed (5) hide show
  1. app.py +4 -3
  2. demo/demos.py +25 -0
  3. demo/model.py +102 -1
  4. requirements.txt +2 -1
  5. seger.py +283 -0
app.py CHANGED
@@ -8,7 +8,7 @@ os.system('mim install mmcv-full==1.7.0')
8
 
9
  from demo.model import Model_all
10
  import gradio as gr
11
- from demo.demos import create_demo_keypose, create_demo_sketch, create_demo_draw
12
  import torch
13
  import subprocess
14
  import shlex
@@ -22,6 +22,7 @@ urls = {
22
  urls_mmpose = [
23
  'https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth',
24
  'https://download.openmmlab.com/mmpose/top_down/hrnet/hrnet_w48_coco_256x192-b9e0b3ab_20200708.pth',
 
25
  ]
26
  if os.path.exists('models') == False:
27
  os.mkdir('models')
@@ -69,7 +70,7 @@ with gr.Blocks(css='style.css') as demo:
69
  create_demo_sketch(model.process_sketch)
70
  with gr.TabItem('Draw'):
71
  create_demo_draw(model.process_draw)
 
 
72
 
73
- # demo.queue(api_open=False).launch(server_name='0.0.0.0')
74
- # demo.queue(show_api=False, enable_queue=False).launch(server_name='0.0.0.0')
75
  demo.queue().launch(debug=True, server_name='0.0.0.0')
 
8
 
9
  from demo.model import Model_all
10
  import gradio as gr
11
+ from demo.demos import create_demo_keypose, create_demo_sketch, create_demo_draw, create_demo_seg
12
  import torch
13
  import subprocess
14
  import shlex
 
22
  urls_mmpose = [
23
  'https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth',
24
  'https://download.openmmlab.com/mmpose/top_down/hrnet/hrnet_w48_coco_256x192-b9e0b3ab_20200708.pth',
25
+ 'https://github.com/kazuto1011/deeplab-pytorch/releases/download/v1.0/deeplabv2_resnet101_msc-cocostuff164k-100000.pth'
26
  ]
27
  if os.path.exists('models') == False:
28
  os.mkdir('models')
 
70
  create_demo_sketch(model.process_sketch)
71
  with gr.TabItem('Draw'):
72
  create_demo_draw(model.process_draw)
73
+ with gr.TabItem('Segmentation'):
74
+ create_demo_seg(model.process_seg)
75
 
 
 
76
  demo.queue().launch(debug=True, server_name='0.0.0.0')
demo/demos.py CHANGED
@@ -70,6 +70,31 @@ def create_demo_sketch(process):
70
  run_button.click(fn=process, inputs=ips, outputs=[result])
71
  return demo
72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  def create_demo_draw(process):
74
  with gr.Blocks() as demo:
75
  with gr.Row():
 
70
  run_button.click(fn=process, inputs=ips, outputs=[result])
71
  return demo
72
 
73
+ def create_demo_seg(process):
74
+ with gr.Blocks() as demo:
75
+ with gr.Row():
76
+ gr.Markdown('## T2I-Adapter (Segmentation)')
77
+ with gr.Row():
78
+ with gr.Column():
79
+ input_img = gr.Image(source='upload', type="numpy")
80
+ prompt = gr.Textbox(label="Prompt")
81
+ neg_prompt = gr.Textbox(label="Negative Prompt",
82
+ value='ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face')
83
+ pos_prompt = gr.Textbox(label="Positive Prompt",
84
+ value = 'crafted, elegant, meticulous, magnificent, maximum details, extremely hyper aesthetic, intricately detailed')
85
+ with gr.Row():
86
+ type_in = gr.inputs.Radio(['Segmentation', 'Image'], type="value", default='Image', label='You can input an image or a segmentation. If you choose to input a segmentation, it must correspond to the coco-stuff')
87
+ run_button = gr.Button(label="Run")
88
+ con_strength = gr.Slider(label="Controling Strength (The guidance strength of the segmentation to the result)", minimum=0, maximum=1, value=0.4, step=0.1)
89
+ scale = gr.Slider(label="Guidance Scale (Classifier free guidance)", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
90
+ fix_sample = gr.inputs.Radio(['True', 'False'], type="value", default='False', label='Fix Sampling\n (Fix the random seed)')
91
+ base_model = gr.inputs.Radio(['sd-v1-4.ckpt', 'anything-v4.0-pruned.ckpt'], type="value", default='sd-v1-4.ckpt', label='The base model you want to use')
92
+ with gr.Column():
93
+ result = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
94
+ ips = [input_img, type_in, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model]
95
+ run_button.click(fn=process, inputs=ips, outputs=[result])
96
+ return demo
97
+
98
  def create_demo_draw(process):
99
  with gr.Blocks() as demo:
100
  with gr.Row():
demo/model.py CHANGED
@@ -13,7 +13,30 @@ from mmpose.apis import (inference_top_down_pose_model, init_pose_model, process
13
  import os
14
  import cv2
15
  import numpy as np
16
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  def imshow_keypoints(img,
19
  pose_result,
@@ -118,6 +141,13 @@ class Model_all:
118
  self.model_edge.load_state_dict({k.replace('module.', ''): v for k, v in ckp.items()})
119
  self.model_edge.to(device)
120
 
 
 
 
 
 
 
 
121
  # keypose part
122
  self.model_pose = Adapter(cin=int(3 * 64), channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True,
123
  use_conv=False).to(device)
@@ -218,6 +248,77 @@ class Model_all:
218
 
219
  return [im_edge, x_samples_ddim]
220
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
  @torch.no_grad()
222
  def process_draw(self, input_img, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model):
223
  if self.current_base != base_model:
 
13
  import os
14
  import cv2
15
  import numpy as np
16
+ from seger import seger, Colorize
17
+ import torch.nn.functional as F
18
+
19
+ def preprocessing(image, device):
20
+ # Resize
21
+ scale = 640 / max(image.shape[:2])
22
+ image = cv2.resize(image, dsize=None, fx=scale, fy=scale)
23
+ raw_image = image.astype(np.uint8)
24
+
25
+ # Subtract mean values
26
+ image = image.astype(np.float32)
27
+ image -= np.array(
28
+ [
29
+ float(104.008),
30
+ float(116.669),
31
+ float(122.675),
32
+ ]
33
+ )
34
+
35
+ # Convert to torch.Tensor and add "batch" axis
36
+ image = torch.from_numpy(image.transpose(2, 0, 1)).float().unsqueeze(0)
37
+ image = image.to(device)
38
+
39
+ return image, raw_image
40
 
41
  def imshow_keypoints(img,
42
  pose_result,
 
141
  self.model_edge.load_state_dict({k.replace('module.', ''): v for k, v in ckp.items()})
142
  self.model_edge.to(device)
143
 
144
+ # segmentation part
145
+ self.model_seger = seger().to(device)
146
+ self.model_seger.eval()
147
+ self.coler = Colorize(n=182)
148
+ self.model_seg = Adapter(cin=int(3*64), channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False).to(device)
149
+ self.model_seg.load_state_dict(torch.load("models/t2iadapter_seg_sd14v1.pth", map_location=device))
150
+
151
  # keypose part
152
  self.model_pose = Adapter(cin=int(3 * 64), channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True,
153
  use_conv=False).to(device)
 
248
 
249
  return [im_edge, x_samples_ddim]
250
 
251
+ @torch.no_grad()
252
+ def process_seg(self, input_img, type_in, prompt, neg_prompt, pos_prompt, fix_sample, scale,
253
+ con_strength, base_model):
254
+ if self.current_base != base_model:
255
+ ckpt = os.path.join("models", base_model)
256
+ pl_sd = torch.load(ckpt, map_location="cuda")
257
+ if "state_dict" in pl_sd:
258
+ sd = pl_sd["state_dict"]
259
+ else:
260
+ sd = pl_sd
261
+ self.base_model.load_state_dict(sd, strict=False)
262
+ self.current_base = base_model
263
+ if 'anything' in base_model.lower():
264
+ self.load_vae()
265
+
266
+ con_strength = int((1 - con_strength) * 50)
267
+ if fix_sample == 'True':
268
+ seed_everything(42)
269
+ im = cv2.resize(input_img, (512, 512))
270
+
271
+ if type_in == 'Segmentation':
272
+ im_seg = im.copy()
273
+ im = img2tensor(im).unsqueeze(0) / 255.
274
+ labelmap = im.float()
275
+ elif type_in == 'Image':
276
+ im, _ = preprocessing(im, self.device)
277
+ _, _, H, W = im.shape
278
+
279
+ # Image -> Probability map
280
+ logits = self.model_seger(im)
281
+ logits = F.interpolate(logits, size=(H, W), mode="bilinear", align_corners=False)
282
+ probs = F.softmax(logits, dim=1)[0]
283
+ probs = probs.cpu().data.numpy()
284
+ labelmap = np.argmax(probs, axis=0)
285
+
286
+ labelmap = self.coler(labelmap)
287
+ labelmap = np.transpose(labelmap, (1,2,0))
288
+ labelmap = cv2.resize(labelmap, (512, 512))
289
+ labelmap = img2tensor(labelmap, bgr2rgb=False, float32=True)/255.
290
+ im_seg = tensor2img(labelmap)
291
+ labelmap = labelmap.unsqueeze(0)
292
+
293
+ # extract condition features
294
+ c = self.base_model.get_learned_conditioning([prompt + ', ' + pos_prompt])
295
+ nc = self.base_model.get_learned_conditioning([neg_prompt])
296
+ features_adapter = self.model_seg(labelmap.to(self.device))
297
+ shape = [4, 64, 64]
298
+
299
+ # sampling
300
+ samples_ddim, _ = self.sampler.sample(S=50,
301
+ conditioning=c,
302
+ batch_size=1,
303
+ shape=shape,
304
+ verbose=False,
305
+ unconditional_guidance_scale=scale,
306
+ unconditional_conditioning=nc,
307
+ eta=0.0,
308
+ x_T=None,
309
+ features_adapter1=features_adapter,
310
+ mode='sketch',
311
+ con_strength=con_strength)
312
+
313
+ x_samples_ddim = self.base_model.decode_first_stage(samples_ddim)
314
+ x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
315
+ x_samples_ddim = x_samples_ddim.to('cpu')
316
+ x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0]
317
+ x_samples_ddim = 255. * x_samples_ddim
318
+ x_samples_ddim = x_samples_ddim.astype(np.uint8)
319
+
320
+ return [im_seg, x_samples_ddim]
321
+
322
  @torch.no_grad()
323
  def process_draw(self, input_img, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model):
324
  if self.current_base != base_model:
requirements.txt CHANGED
@@ -15,4 +15,5 @@ kornia==0.6.8
15
  openmim
16
  mmpose
17
  mmdet
18
- psutil
 
 
15
  openmim
16
  mmpose
17
  mmdet
18
+ psutil
19
+ blobfile
seger.py ADDED
@@ -0,0 +1,283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import cv2
6
+ from basicsr.utils import img2tensor, tensor2img
7
+
8
+ _BATCH_NORM = nn.BatchNorm2d
9
+ _BOTTLENECK_EXPANSION = 4
10
+
11
+ import blobfile as bf
12
+
13
+ def _list_image_files_recursively(data_dir):
14
+ results = []
15
+ for entry in sorted(bf.listdir(data_dir)):
16
+ full_path = bf.join(data_dir, entry)
17
+ ext = entry.split(".")[-1]
18
+ if "." in entry and ext.lower() in ["jpg", "jpeg", "png", "gif"]:
19
+ results.append(full_path)
20
+ elif bf.isdir(full_path):
21
+ results.extend(_list_image_files_recursively(full_path))
22
+ return results
23
+
24
+ def uint82bin(n, count=8):
25
+ """returns the binary of integer n, count refers to amount of bits"""
26
+ return ''.join([str((n >> y) & 1) for y in range(count - 1, -1, -1)])
27
+
28
+
29
+ def labelcolormap(N):
30
+ if N == 35: # cityscape
31
+ cmap = np.array([(0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0), (111, 74, 0), (81, 0, 81),
32
+ (128, 64, 128), (244, 35, 232), (250, 170, 160), (230, 150, 140), (70, 70, 70), (102, 102, 156), (190, 153, 153),
33
+ (180, 165, 180), (150, 100, 100), (150, 120, 90), (153, 153, 153), (153, 153, 153), (250, 170, 30), (220, 220, 0),
34
+ (107, 142, 35), (152, 251, 152), (70, 130, 180), (220, 20, 60), (255, 0, 0), (0, 0, 142), (0, 0, 70),
35
+ (0, 60, 100), (0, 0, 90), (0, 0, 110), (0, 80, 100), (0, 0, 230), (119, 11, 32), (0, 0, 142)],
36
+ dtype=np.uint8)
37
+ else:
38
+ cmap = np.zeros((N, 3), dtype=np.uint8)
39
+ for i in range(N):
40
+ r, g, b = 0, 0, 0
41
+ id = i + 1 # let's give 0 a color
42
+ for j in range(7):
43
+ str_id = uint82bin(id)
44
+ r = r ^ (np.uint8(str_id[-1]) << (7 - j))
45
+ g = g ^ (np.uint8(str_id[-2]) << (7 - j))
46
+ b = b ^ (np.uint8(str_id[-3]) << (7 - j))
47
+ id = id >> 3
48
+ cmap[i, 0] = r
49
+ cmap[i, 1] = g
50
+ cmap[i, 2] = b
51
+
52
+ return cmap
53
+
54
+
55
+ class Colorize(object):
56
+ def __init__(self, n=182):
57
+ self.cmap = labelcolormap(n)
58
+
59
+ def __call__(self, gray_image):
60
+ size = gray_image.shape
61
+ color_image = np.zeros((3, size[0], size[1]))
62
+
63
+ for label in range(0, len(self.cmap)):
64
+ mask = (label == gray_image )
65
+ color_image[0][mask] = self.cmap[label][0]
66
+ color_image[1][mask] = self.cmap[label][1]
67
+ color_image[2][mask] = self.cmap[label][2]
68
+
69
+ return color_image
70
+
71
+ class _ConvBnReLU(nn.Sequential):
72
+ """
73
+ Cascade of 2D convolution, batch norm, and ReLU.
74
+ """
75
+
76
+ BATCH_NORM = _BATCH_NORM
77
+
78
+ def __init__(
79
+ self, in_ch, out_ch, kernel_size, stride, padding, dilation, relu=True
80
+ ):
81
+ super(_ConvBnReLU, self).__init__()
82
+ self.add_module(
83
+ "conv",
84
+ nn.Conv2d(
85
+ in_ch, out_ch, kernel_size, stride, padding, dilation, bias=False
86
+ ),
87
+ )
88
+ self.add_module("bn", _BATCH_NORM(out_ch, eps=1e-5, momentum=1 - 0.999))
89
+
90
+ if relu:
91
+ self.add_module("relu", nn.ReLU())
92
+
93
+ class _Bottleneck(nn.Module):
94
+ """
95
+ Bottleneck block of MSRA ResNet.
96
+ """
97
+
98
+ def __init__(self, in_ch, out_ch, stride, dilation, downsample):
99
+ super(_Bottleneck, self).__init__()
100
+ mid_ch = out_ch // _BOTTLENECK_EXPANSION
101
+ self.reduce = _ConvBnReLU(in_ch, mid_ch, 1, stride, 0, 1, True)
102
+ self.conv3x3 = _ConvBnReLU(mid_ch, mid_ch, 3, 1, dilation, dilation, True)
103
+ self.increase = _ConvBnReLU(mid_ch, out_ch, 1, 1, 0, 1, False)
104
+ self.shortcut = (
105
+ _ConvBnReLU(in_ch, out_ch, 1, stride, 0, 1, False)
106
+ if downsample
107
+ else nn.Identity()
108
+ )
109
+
110
+ def forward(self, x):
111
+ h = self.reduce(x)
112
+ h = self.conv3x3(h)
113
+ h = self.increase(h)
114
+ h += self.shortcut(x)
115
+ return F.relu(h)
116
+
117
+ class _ResLayer(nn.Sequential):
118
+ """
119
+ Residual layer with multi grids
120
+ """
121
+
122
+ def __init__(self, n_layers, in_ch, out_ch, stride, dilation, multi_grids=None):
123
+ super(_ResLayer, self).__init__()
124
+
125
+ if multi_grids is None:
126
+ multi_grids = [1 for _ in range(n_layers)]
127
+ else:
128
+ assert n_layers == len(multi_grids)
129
+
130
+ # Downsampling is only in the first block
131
+ for i in range(n_layers):
132
+ self.add_module(
133
+ "block{}".format(i + 1),
134
+ _Bottleneck(
135
+ in_ch=(in_ch if i == 0 else out_ch),
136
+ out_ch=out_ch,
137
+ stride=(stride if i == 0 else 1),
138
+ dilation=dilation * multi_grids[i],
139
+ downsample=(True if i == 0 else False),
140
+ ),
141
+ )
142
+
143
+ class _Stem(nn.Sequential):
144
+ """
145
+ The 1st conv layer.
146
+ Note that the max pooling is different from both MSRA and FAIR ResNet.
147
+ """
148
+
149
+ def __init__(self, out_ch):
150
+ super(_Stem, self).__init__()
151
+ self.add_module("conv1", _ConvBnReLU(3, out_ch, 7, 2, 3, 1))
152
+ self.add_module("pool", nn.MaxPool2d(3, 2, 1, ceil_mode=True))
153
+
154
+ class _ASPP(nn.Module):
155
+ """
156
+ Atrous spatial pyramid pooling (ASPP)
157
+ """
158
+
159
+ def __init__(self, in_ch, out_ch, rates):
160
+ super(_ASPP, self).__init__()
161
+ for i, rate in enumerate(rates):
162
+ self.add_module(
163
+ "c{}".format(i),
164
+ nn.Conv2d(in_ch, out_ch, 3, 1, padding=rate, dilation=rate, bias=True),
165
+ )
166
+
167
+ for m in self.children():
168
+ nn.init.normal_(m.weight, mean=0, std=0.01)
169
+ nn.init.constant_(m.bias, 0)
170
+
171
+ def forward(self, x):
172
+ return sum([stage(x) for stage in self.children()])
173
+
174
+ class MSC(nn.Module):
175
+ """
176
+ Multi-scale inputs
177
+ """
178
+
179
+ def __init__(self, base, scales=None):
180
+ super(MSC, self).__init__()
181
+ self.base = base
182
+ if scales:
183
+ self.scales = scales
184
+ else:
185
+ self.scales = [0.5, 0.75]
186
+
187
+ def forward(self, x):
188
+ # Original
189
+ logits = self.base(x)
190
+ _, _, H, W = logits.shape
191
+ interp = lambda l: F.interpolate(
192
+ l, size=(H, W), mode="bilinear", align_corners=False
193
+ )
194
+
195
+ # Scaled
196
+ logits_pyramid = []
197
+ for p in self.scales:
198
+ h = F.interpolate(x, scale_factor=p, mode="bilinear", align_corners=False)
199
+ logits_pyramid.append(self.base(h))
200
+
201
+ # Pixel-wise max
202
+ logits_all = [logits] + [interp(l) for l in logits_pyramid]
203
+ logits_max = torch.max(torch.stack(logits_all), dim=0)[0]
204
+
205
+ return logits_max
206
+
207
+ class DeepLabV2(nn.Sequential):
208
+ """
209
+ DeepLab v2: Dilated ResNet + ASPP
210
+ Output stride is fixed at 8
211
+ """
212
+
213
+ def __init__(self, n_classes=182, n_blocks=[3, 4, 23, 3], atrous_rates=[6, 12, 18, 24]):
214
+ super(DeepLabV2, self).__init__()
215
+ ch = [64 * 2 ** p for p in range(6)]
216
+ self.add_module("layer1", _Stem(ch[0]))
217
+ self.add_module("layer2", _ResLayer(n_blocks[0], ch[0], ch[2], 1, 1))
218
+ self.add_module("layer3", _ResLayer(n_blocks[1], ch[2], ch[3], 2, 1))
219
+ self.add_module("layer4", _ResLayer(n_blocks[2], ch[3], ch[4], 1, 2))
220
+ self.add_module("layer5", _ResLayer(n_blocks[3], ch[4], ch[5], 1, 4))
221
+ self.add_module("aspp", _ASPP(ch[5], n_classes, atrous_rates))
222
+
223
+ def freeze_bn(self):
224
+ for m in self.modules():
225
+ if isinstance(m, _ConvBnReLU.BATCH_NORM):
226
+ m.eval()
227
+
228
+ def preprocessing(image, device):
229
+ # Resize
230
+ scale = 640 / max(image.shape[:2])
231
+ image = cv2.resize(image, dsize=None, fx=scale, fy=scale)
232
+ raw_image = image.astype(np.uint8)
233
+
234
+ # Subtract mean values
235
+ image = image.astype(np.float32)
236
+ image -= np.array(
237
+ [
238
+ float(104.008),
239
+ float(116.669),
240
+ float(122.675),
241
+ ]
242
+ )
243
+
244
+ # Convert to torch.Tensor and add "batch" axis
245
+ image = torch.from_numpy(image.transpose(2, 0, 1)).float().unsqueeze(0)
246
+ image = image.to(device)
247
+
248
+ return image, raw_image
249
+
250
+ # Model setup
251
+ def seger():
252
+ model = MSC(
253
+ base=DeepLabV2(
254
+ n_classes=182, n_blocks=[3, 4, 23, 3], atrous_rates=[6, 12, 18, 24]
255
+ ),
256
+ scales=[0.5, 0.75],
257
+ )
258
+ state_dict = torch.load('models/deeplabv2_resnet101_msc-cocostuff164k-100000.pth')
259
+ model.load_state_dict(state_dict) # to skip ASPP
260
+
261
+ return model
262
+
263
+ if __name__ == '__main__':
264
+ device = 'cuda'
265
+ model = seger()
266
+ model.to(device)
267
+ model.eval()
268
+ with torch.no_grad():
269
+ im = cv2.imread('/group/30042/chongmou/ft_local/Diffusion/baselines/SPADE/datasets/coco_stuff/val_img/000000000785.jpg', cv2.IMREAD_COLOR)
270
+ im, raw_im = preprocessing(im, 'cuda')
271
+ _, _, H, W = im.shape
272
+
273
+ # Image -> Probability map
274
+ logits = model(im)
275
+ logits = F.interpolate(logits, size=(H, W), mode="bilinear", align_corners=False)
276
+ probs = F.softmax(logits, dim=1)[0]
277
+ probs = probs.cpu().data.numpy()
278
+ labelmap = np.argmax(probs, axis=0)
279
+ print(labelmap.shape, np.max(labelmap), np.min(labelmap))
280
+ cv2.imwrite('mask.png', labelmap)
281
+
282
+
283
+