汐知 commited on
Commit
d8c7468
1 Parent(s): 240d951
Files changed (4) hide show
  1. .DS_Store +0 -0
  2. app.py +22 -12
  3. configs/demo.yaml +1 -0
  4. iseg/coarse_mask_refine_util.py +285 -0
.DS_Store ADDED
Binary file (6.15 kB). View file
 
app.py CHANGED
@@ -1,15 +1,11 @@
1
  import os
2
  import sys
3
- #sys.path.append('.')
4
- #os.system("pip install gradio==3.50.2")
5
  import cv2
6
  import einops
7
  import numpy as np
8
  import torch
9
  import random
10
  import gradio as gr
11
- #print(gr.__version__)
12
-
13
  import albumentations as A
14
  from PIL import Image
15
  import torchvision.transforms as T
@@ -20,6 +16,7 @@ from omegaconf import OmegaConf
20
  from cldm.hack import disable_verbosity, enable_sliced_attention
21
  from huggingface_hub import snapshot_download
22
 
 
23
  snapshot_download(repo_id="xichenhku/AnyDoor_models", local_dir="./AnyDoor_models")
24
 
25
 
@@ -35,8 +32,7 @@ if save_memory:
35
  config = OmegaConf.load('./configs/demo.yaml')
36
  model_ckpt = config.pretrained_model
37
  model_config = config.config_file
38
-
39
-
40
 
41
 
42
  model = create_model(model_config ).cpu()
@@ -44,6 +40,13 @@ model.load_state_dict(load_state_dict(model_ckpt, location='cuda'))
44
  model = model.cuda()
45
  ddim_sampler = DDIMSampler(model)
46
 
 
 
 
 
 
 
 
47
 
48
  def crop_back( pred, tar_image, extra_sizes, tar_box_yyxx_crop):
49
  H1, W1, H2, W2 = extra_sizes
@@ -222,6 +225,13 @@ ref_list.sort()
222
  image_list=[os.path.join(image_dir,file) for file in os.listdir(image_dir) if '.jpg' in file or '.png' in file or '.jpeg' in file]
223
  image_list.sort()
224
 
 
 
 
 
 
 
 
225
  def mask_image(image, mask):
226
  blanc = np.ones_like(image) * 255
227
  mask = np.stack([mask,mask,mask],-1) / 255
@@ -242,6 +252,11 @@ def run_local(base,
242
  ref_mask = np.asarray(ref_mask)
243
  ref_mask = np.where(ref_mask > 128, 1, 0).astype(np.uint8)
244
 
 
 
 
 
 
245
  processed_item = process_pairs(ref_image.copy(), ref_mask.copy(), image.copy(), mask.copy(), max_ratio = 0.8)
246
  masked_ref = (processed_item['ref']*255)
247
 
@@ -254,15 +269,13 @@ def run_local(base,
254
  masked_ref = cv2.resize(masked_ref.astype(np.uint8), (512,512))
255
  return [synthesis]
256
 
257
-
258
-
259
  with gr.Blocks() as demo:
260
  with gr.Column():
261
  gr.Markdown("# Play with AnyDoor to Teleport your Target Objects! ")
262
  with gr.Row():
263
  baseline_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", columns=1, height=768)
264
  with gr.Accordion("Advanced Option", open=True):
265
- num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
266
  strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
267
  ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=30, step=1)
268
  scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=5.0, step=0.1)
@@ -270,9 +283,6 @@ with gr.Blocks() as demo:
270
  gr.Markdown(" Higher guidance-scale makes higher fidelity, while lower guidance-scale leads to more harmonized blending.")
271
 
272
 
273
-
274
-
275
-
276
  gr.Markdown("# Upload / Select Images for the Background (left) and Reference Object (right)")
277
  gr.Markdown("### Your could draw coarse masks on the background to indicate the desired location and shape.")
278
  gr.Markdown("### <u>Do not forget</u> to annotate the target object on the reference image.")
 
1
  import os
2
  import sys
 
 
3
  import cv2
4
  import einops
5
  import numpy as np
6
  import torch
7
  import random
8
  import gradio as gr
 
 
9
  import albumentations as A
10
  from PIL import Image
11
  import torchvision.transforms as T
 
16
  from cldm.hack import disable_verbosity, enable_sliced_attention
17
  from huggingface_hub import snapshot_download
18
 
19
+
20
  snapshot_download(repo_id="xichenhku/AnyDoor_models", local_dir="./AnyDoor_models")
21
 
22
 
 
32
  config = OmegaConf.load('./configs/demo.yaml')
33
  model_ckpt = config.pretrained_model
34
  model_config = config.config_file
35
+ use_interactive_seg = config.config_file
 
36
 
37
 
38
  model = create_model(model_config ).cpu()
 
40
  model = model.cuda()
41
  ddim_sampler = DDIMSampler(model)
42
 
43
+ if use_interactive_seg:
44
+ from iseg.coarse_mask_refine_util import BaselineModel
45
+ model_path = './iseg/coarse_mask_refine.pth'
46
+ iseg_model = BaselineModel().eval()
47
+ weights = torch.load(model_path , map_location='cpu')['state_dict']
48
+ iseg_model.load_state_dict(weights, strict= True)
49
+
50
 
51
  def crop_back( pred, tar_image, extra_sizes, tar_box_yyxx_crop):
52
  H1, W1, H2, W2 = extra_sizes
 
225
  image_list=[os.path.join(image_dir,file) for file in os.listdir(image_dir) if '.jpg' in file or '.png' in file or '.jpeg' in file]
226
  image_list.sort()
227
 
228
+ def process_image_mask(image_np, mask_np):
229
+ img = torch.from_numpy(image_np.transpose((2, 0, 1)))
230
+ img_ten = img.float().div(255).unsqueeze(0)
231
+ mask_ten = torch.from_numpy(mask_np).float().unsqueeze(0).unsqueeze(0)
232
+ return img_ten, mask_ten
233
+
234
+
235
  def mask_image(image, mask):
236
  blanc = np.ones_like(image) * 255
237
  mask = np.stack([mask,mask,mask],-1) / 255
 
252
  ref_mask = np.asarray(ref_mask)
253
  ref_mask = np.where(ref_mask > 128, 1, 0).astype(np.uint8)
254
 
255
+ # refine the user annotated coarse mask
256
+ if use_interactive_seg:
257
+ img_ten, mask_ten = process_image_mask(ref_image, ref_mask)
258
+ ref_mask = iseg_model(img_ten, mask_ten)['instances'][0,0].detach().numpy() > 0.5
259
+
260
  processed_item = process_pairs(ref_image.copy(), ref_mask.copy(), image.copy(), mask.copy(), max_ratio = 0.8)
261
  masked_ref = (processed_item['ref']*255)
262
 
 
269
  masked_ref = cv2.resize(masked_ref.astype(np.uint8), (512,512))
270
  return [synthesis]
271
 
 
 
272
  with gr.Blocks() as demo:
273
  with gr.Column():
274
  gr.Markdown("# Play with AnyDoor to Teleport your Target Objects! ")
275
  with gr.Row():
276
  baseline_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", columns=1, height=768)
277
  with gr.Accordion("Advanced Option", open=True):
278
+ #num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
279
  strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
280
  ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=30, step=1)
281
  scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=5.0, step=0.1)
 
283
  gr.Markdown(" Higher guidance-scale makes higher fidelity, while lower guidance-scale leads to more harmonized blending.")
284
 
285
 
 
 
 
286
  gr.Markdown("# Upload / Select Images for the Background (left) and Reference Object (right)")
287
  gr.Markdown("### Your could draw coarse masks on the background to indicate the desired location and shape.")
288
  gr.Markdown("### <u>Do not forget</u> to annotate the target object on the reference image.")
configs/demo.yaml CHANGED
@@ -1,3 +1,4 @@
1
  pretrained_model: ./AnyDoor_models/general_v0.1/general_v0.1.ckpt
2
  config_file: configs/anydoor.yaml
3
  save_memory: False
 
 
1
  pretrained_model: ./AnyDoor_models/general_v0.1/general_v0.1.ckpt
2
  config_file: configs/anydoor.yaml
3
  save_memory: False
4
+ use_interactive_seg: True
iseg/coarse_mask_refine_util.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """MobileNet and MobileNetV2."""
2
+ '''
3
+ Code adopted from https://github.com/LikeLy-Journey/SegmenTron/blob/master/segmentron/models/backbones/mobilenet.py
4
+ '''
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ # ============ Basic Blocks ============
10
+
11
+ class _ConvBNReLU(nn.Module):
12
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
13
+ dilation=1, groups=1, relu6=False, norm_layer=nn.BatchNorm2d):
14
+ super(_ConvBNReLU, self).__init__()
15
+ self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=False)
16
+ self.bn = norm_layer(out_channels)
17
+ self.relu = nn.ReLU6(True) if relu6 else nn.ReLU(True)
18
+
19
+ def forward(self, x):
20
+ x = self.conv(x)
21
+ x = self.bn(x)
22
+ x = self.relu(x)
23
+ return x
24
+
25
+ class _DepthwiseConv(nn.Module):
26
+ """conv_dw in MobileNet"""
27
+
28
+ def __init__(self, in_channels, out_channels, stride, norm_layer=nn.BatchNorm2d, **kwargs):
29
+ super(_DepthwiseConv, self).__init__()
30
+ self.conv = nn.Sequential(
31
+ _ConvBNReLU(in_channels, in_channels, 3, stride, 1, groups=in_channels, norm_layer=norm_layer),
32
+ _ConvBNReLU(in_channels, out_channels, 1, norm_layer=norm_layer))
33
+
34
+ def forward(self, x):
35
+ return self.conv(x)
36
+
37
+
38
+ class InvertedResidual(nn.Module):
39
+ def __init__(self, in_channels, out_channels, stride, expand_ratio, dilation=1, norm_layer=nn.BatchNorm2d):
40
+ super(InvertedResidual, self).__init__()
41
+ assert stride in [1, 2]
42
+ self.use_res_connect = stride == 1 and in_channels == out_channels
43
+
44
+ layers = list()
45
+ inter_channels = int(round(in_channels * expand_ratio))
46
+ if expand_ratio != 1:
47
+ # pw
48
+ layers.append(_ConvBNReLU(in_channels, inter_channels, 1, relu6=True, norm_layer=norm_layer))
49
+ layers.extend([
50
+ # dw
51
+ _ConvBNReLU(inter_channels, inter_channels, 3, stride, dilation, dilation,
52
+ groups=inter_channels, relu6=True, norm_layer=norm_layer),
53
+ # pw-linear
54
+ nn.Conv2d(inter_channels, out_channels, 1, bias=False),
55
+ norm_layer(out_channels)])
56
+ self.conv = nn.Sequential(*layers)
57
+
58
+ def forward(self, x):
59
+ if self.use_res_connect:
60
+ return x + self.conv(x)
61
+ else:
62
+ return self.conv(x)
63
+
64
+
65
+ # ============ Backbone ============
66
+
67
+ class MobileNetV2(nn.Module):
68
+ def __init__(self, num_classes=1000, norm_layer=nn.BatchNorm2d):
69
+ super(MobileNetV2, self).__init__()
70
+ output_stride = 8
71
+ self.multiplier = 1
72
+ if output_stride == 32:
73
+ dilations = [1, 1]
74
+ elif output_stride == 16:
75
+ dilations = [1, 2]
76
+ elif output_stride == 8:
77
+ dilations = [2, 4]
78
+ else:
79
+ raise NotImplementedError
80
+ inverted_residual_setting = [
81
+ # t, c, n, s
82
+ [1, 16, 1, 1],
83
+ [6, 24, 2, 2],
84
+ [6, 32, 3, 2],
85
+ [6, 64, 4, 2],
86
+ [6, 96, 3, 1],
87
+ [6, 160, 3, 2],
88
+ [6, 320, 1, 1]]
89
+ # building first layer
90
+ input_channels = int(32 * self.multiplier) if self.multiplier > 1.0 else 32
91
+ # last_channels = int(1280 * multiplier) if multiplier > 1.0 else 1280
92
+ self.conv1 = _ConvBNReLU(3, input_channels, 3, 2, 1, relu6=True, norm_layer=norm_layer)
93
+
94
+ # building inverted residual blocks
95
+ self.planes = input_channels
96
+ self.block1 = self._make_layer(InvertedResidual, self.planes, inverted_residual_setting[0:1],
97
+ norm_layer=norm_layer)
98
+ self.block2 = self._make_layer(InvertedResidual, self.planes, inverted_residual_setting[1:2],
99
+ norm_layer=norm_layer)
100
+ self.block3 = self._make_layer(InvertedResidual, self.planes, inverted_residual_setting[2:3],
101
+ norm_layer=norm_layer)
102
+ self.block4 = self._make_layer(InvertedResidual, self.planes, inverted_residual_setting[3:5],
103
+ dilations[0], norm_layer=norm_layer)
104
+ self.block5 = self._make_layer(InvertedResidual, self.planes, inverted_residual_setting[5:],
105
+ dilations[1], norm_layer=norm_layer)
106
+ self.last_inp_channels = self.planes
107
+
108
+ # building last several layers
109
+ # features = list()
110
+ # features.append(_ConvBNReLU(input_channels, last_channels, 1, relu6=True, norm_layer=norm_layer))
111
+ # features.append(nn.AdaptiveAvgPool2d(1))
112
+ # self.features = nn.Sequential(*features)
113
+ #
114
+ # self.classifier = nn.Sequential(
115
+ # nn.Dropout2d(0.2),
116
+ # nn.Linear(last_channels, num_classes))
117
+
118
+ # weight initialization
119
+ for m in self.modules():
120
+ if isinstance(m, nn.Conv2d):
121
+ nn.init.kaiming_normal_(m.weight, mode='fan_out')
122
+ if m.bias is not None:
123
+ nn.init.zeros_(m.bias)
124
+ elif isinstance(m, nn.BatchNorm2d):
125
+ nn.init.ones_(m.weight)
126
+ nn.init.zeros_(m.bias)
127
+ elif isinstance(m, nn.Linear):
128
+ nn.init.normal_(m.weight, 0, 0.01)
129
+ if m.bias is not None:
130
+ nn.init.zeros_(m.bias)
131
+
132
+ def _make_layer(self, block, planes, inverted_residual_setting, dilation=1, norm_layer=nn.BatchNorm2d):
133
+ features = list()
134
+ for t, c, n, s in inverted_residual_setting:
135
+ out_channels = int(c * self.multiplier)
136
+ stride = s if dilation == 1 else 1
137
+ features.append(block(planes, out_channels, stride, t, dilation, norm_layer))
138
+ planes = out_channels
139
+ for i in range(n - 1):
140
+ features.append(block(planes, out_channels, 1, t, norm_layer=norm_layer))
141
+ planes = out_channels
142
+ self.planes = planes
143
+ return nn.Sequential(*features)
144
+
145
+ def forward(self, x, side_feature):
146
+ x = self.conv1(x)
147
+ x = x + side_feature
148
+ x = self.block1(x)
149
+ c1 = self.block2(x)
150
+ c2 = self.block3(c1)
151
+ c3 = self.block4(c2)
152
+ c4 = self.block5(c3)
153
+ # x = self.features(x)
154
+ # x = self.classifier(x.view(x.size(0), x.size(1)))
155
+ return c1, c2, c3, c4
156
+
157
+ def mobilenet_v2(norm_layer=nn.BatchNorm2d):
158
+ return MobileNetV2(norm_layer=norm_layer)
159
+
160
+
161
+
162
+ # ============ Segmentor ============
163
+
164
+ class LRASPP(nn.Module):
165
+ """Lite R-ASPP"""
166
+
167
+ def __init__(self, in_channels, out_channels, norm_layer=nn.BatchNorm2d, **kwargs):
168
+ super(LRASPP, self).__init__()
169
+ self.b0 = nn.Sequential(
170
+ nn.Conv2d(in_channels, out_channels, 1, bias=False),
171
+ norm_layer(out_channels),
172
+ nn.ReLU(True)
173
+ )
174
+ self.b1 = nn.Sequential(
175
+ nn.AdaptiveAvgPool2d((2,2)),
176
+ nn.Conv2d(in_channels, out_channels, 1, bias=False),
177
+ nn.Sigmoid(),
178
+ )
179
+
180
+ def forward(self, x):
181
+ size = x.size()[2:]
182
+ feat1 = self.b0(x)
183
+ feat2 = self.b1(x)
184
+ feat2 = F.interpolate(feat2, size, mode='bilinear', align_corners=True)
185
+ x = feat1 * feat2
186
+ return x
187
+
188
+
189
+
190
+ class MobileSeg(nn.Module):
191
+ def __init__(self, nclass=1, **kwargs):
192
+ super(MobileSeg, self).__init__()
193
+ self.backbone = mobilenet_v2()
194
+ self.lraspp = LRASPP(320,128)
195
+ self.fusion_conv1 = nn.Conv2d(128,16,1,1,0)
196
+ self.fusion_conv2 = nn.Conv2d(24,16,1,1,0)
197
+ self.head = nn.Conv2d(16,nclass,1,1,0)
198
+ self.aux_head = nn.Conv2d(16,nclass,1,1,0)
199
+
200
+ def forward(self, x, side_feature):
201
+ x4, _, _, x8 = self.backbone(x, side_feature)
202
+ x8 = self.lraspp(x8)
203
+ x8 = F.interpolate(x8, x4.size()[2:], mode='bilinear', align_corners=True)
204
+ x8 = self.fusion_conv1(x8)
205
+ pred_aux = self.aux_head(x8)
206
+
207
+ x4 = self.fusion_conv2(x4)
208
+ x = x4 + x8
209
+ pred = self.head(x)
210
+ return pred, pred_aux, x
211
+
212
+ def load_pretrained_weights(self, path_to_weights= ' '):
213
+ backbone_state_dict = self.backbone.state_dict()
214
+ pretrained_state_dict = torch.load(path_to_weights, map_location='cpu')
215
+ ckpt_keys = set(pretrained_state_dict.keys())
216
+ own_keys = set(backbone_state_dict.keys())
217
+ missing_keys = own_keys - ckpt_keys
218
+ unexpected_keys = ckpt_keys - own_keys
219
+ print('Loading Mobilnet V2')
220
+ print('Missing Keys: ', missing_keys)
221
+ print('Unexpected Keys: ', unexpected_keys)
222
+ backbone_state_dict.update(pretrained_state_dict)
223
+ self.backbone.load_state_dict(backbone_state_dict, strict= False)
224
+
225
+
226
+
227
+
228
+ class ScaleLayer(nn.Module):
229
+ def __init__(self, init_value=1.0, lr_mult=1):
230
+ super().__init__()
231
+ self.lr_mult = lr_mult
232
+ self.scale = nn.Parameter(
233
+ torch.full((1,), init_value / lr_mult, dtype=torch.float32)
234
+ )
235
+
236
+ def forward(self, x):
237
+ scale = torch.abs(self.scale * self.lr_mult)
238
+ return x * scale
239
+
240
+
241
+ # ============ Interactive Segmentor ============
242
+
243
+ class BaselineModel(nn.Module):
244
+ def __init__(self, backbone_lr_mult=0.1,
245
+ norm_layer=nn.BatchNorm2d, **kwargs):
246
+ super().__init__()
247
+ self.feature_extractor = MobileSeg()
248
+ side_feature_ch = 32
249
+ mt_layers = [
250
+ nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=2, padding=1),
251
+ nn.LeakyReLU(negative_slope=0.2),
252
+ nn.Conv2d(in_channels=16, out_channels=side_feature_ch, kernel_size=3, stride=1, padding=1),
253
+ ScaleLayer(init_value=0.05, lr_mult=1)
254
+ ]
255
+ self.maps_transform = nn.Sequential(*mt_layers)
256
+
257
+
258
+ def backbone_forward(self, image, coord_features=None):
259
+ mask, mask_aux, feature = self.feature_extractor(image, coord_features)
260
+ return {'instances': mask, 'instances_aux':mask_aux, 'feature': feature}
261
+
262
+
263
+ def prepare_input(self, image):
264
+ prev_mask = torch.zeros_like(image)[:,:1,:,:]
265
+ return image, prev_mask
266
+
267
+ def forward(self, image, coarse_mask):
268
+ image, prev_mask = self.prepare_input(image)
269
+ coord_features = torch.cat((prev_mask, coarse_mask, coarse_mask * 0.0), dim=1)
270
+ click_map = coord_features[:,1:,:,:]
271
+
272
+ coord_features = self.maps_transform(coord_features)
273
+ outputs = self.backbone_forward(image, coord_features)
274
+
275
+ pred = nn.functional.interpolate(
276
+ outputs['instances'],
277
+ size=image.size()[2:],
278
+ mode='bilinear', align_corners=True
279
+ )
280
+
281
+ outputs['instances'] = torch.sigmoid(pred)
282
+ return outputs
283
+
284
+
285
+