ChongMou committed
Commit: 1e3fd43
Parent: f0ae51e

Update demo/model.py
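
This revision reformats demo/model.py toward PEP 8 (long signatures and list literals wrapped, operators spaced) and changes how base-model checkpoints are swapped inside the process_* methods: torch.load now uses map_location="cuda" and the explicit .cpu()/.cuda() round-trip around load_state_dict is commented out. A device-agnostic sketch of the new loading pattern follows the diff.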

Files changed (1): demo/model.py (+97 -85)
demo/model.py CHANGED
@@ -81,6 +81,7 @@ def imshow_keypoints(img,
 
     return img
 
+
 def load_model_from_config(config, ckpt, verbose=False):
     print(f"Loading model from {ckpt}")
     pl_sd = torch.load(ckpt, map_location="cpu")
@@ -97,6 +98,7 @@ def load_model_from_config(config, ckpt, verbose=False):
     model.eval()
     return model
 
+
 class Model_all:
     def __init__(self, device='cpu'):
         # common part
@@ -108,18 +110,20 @@ class Model_all:
         self.sampler = PLMSSampler(self.base_model)
 
         # sketch part
-        self.model_sketch = Adapter(channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False).to(device)
+        self.model_sketch = Adapter(channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True,
+                                    use_conv=False).to(device)
         self.model_sketch.load_state_dict(torch.load("models/t2iadapter_sketch_sd14v1.pth", map_location=device))
         self.model_edge = pidinet()
         ckp = torch.load('models/table5_pidinet.pth', map_location='cpu')['state_dict']
-        self.model_edge.load_state_dict({k.replace('module.',''):v for k, v in ckp.items()})
+        self.model_edge.load_state_dict({k.replace('module.', ''): v for k, v in ckp.items()})
         self.model_edge.to(device)
 
         # keypose part
-        self.model_pose = Adapter(cin=int(3*64), channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False).to(device)
+        self.model_pose = Adapter(cin=int(3 * 64), channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True,
+                                  use_conv=False).to(device)
         self.model_pose.load_state_dict(torch.load("models/t2iadapter_keypose_sd14v1.pth", map_location=device))
         ## mmpose
-        det_config = 'models/faster_rcnn_r50_fpn_coco.py'
+        det_config = 'models/faster_rcnn_r50_fpn_coco.py'
         det_checkpoint = 'models/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'
         pose_config = 'models/hrnet_w48_coco_256x192.py'
         pose_checkpoint = 'models/hrnet_w48_coco_256x192-b9e0b3ab_20200708.pth'
@@ -131,50 +135,56 @@ class Model_all:
         pose_config_mmcv = mmcv.Config.fromfile(pose_config)
         self.pose_model = init_pose_model(pose_config_mmcv, pose_checkpoint, device=device)
         ## color
-        self.skeleton = [[15, 13], [13, 11], [16, 14], [14, 12], [11, 12], [5, 11], [6, 12], [5, 6], [5, 7], [6, 8], [7, 9], [8, 10],
-                         [1, 2], [0, 1], [0, 2], [1, 3], [2, 4], [3, 5], [4, 6]]
-        self.pose_kpt_color = [[51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255], [0, 255, 0],
-                               [255, 128, 0], [0, 255, 0], [255, 128, 0], [0, 255, 0], [255, 128, 0], [0, 255, 0], [255, 128, 0],
-                               [0, 255, 0], [255, 128, 0], [0, 255, 0], [255, 128, 0]]
+        self.skeleton = [[15, 13], [13, 11], [16, 14], [14, 12], [11, 12], [5, 11], [6, 12], [5, 6], [5, 7], [6, 8],
+                         [7, 9], [8, 10],
+                         [1, 2], [0, 1], [0, 2], [1, 3], [2, 4], [3, 5], [4, 6]]
+        self.pose_kpt_color = [[51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255],
+                               [0, 255, 0],
+                               [255, 128, 0], [0, 255, 0], [255, 128, 0], [0, 255, 0], [255, 128, 0], [0, 255, 0],
+                               [255, 128, 0],
+                               [0, 255, 0], [255, 128, 0], [0, 255, 0], [255, 128, 0]]
         self.pose_link_color = [[0, 255, 0], [0, 255, 0], [255, 128, 0], [255, 128, 0],
-                                [51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255], [0, 255, 0], [255, 128, 0],
-                                [0, 255, 0], [255, 128, 0], [51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255],
-                                [51, 153, 255], [51, 153, 255], [51, 153, 255]]
+                                [51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255], [0, 255, 0],
+                                [255, 128, 0],
+                                [0, 255, 0], [255, 128, 0], [51, 153, 255], [51, 153, 255], [51, 153, 255],
+                                [51, 153, 255],
+                                [51, 153, 255], [51, 153, 255], [51, 153, 255]]
 
     @torch.no_grad()
-    def process_sketch(self, input_img, type_in, color_back, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model):
+    def process_sketch(self, input_img, type_in, color_back, prompt, neg_prompt, pos_prompt, fix_sample, scale,
+                       con_strength, base_model):
         if self.current_base != base_model:
             ckpt = os.path.join("models", base_model)
-            pl_sd = torch.load(ckpt, map_location="cpu")
+            pl_sd = torch.load(ckpt, map_location="cuda")
             if "state_dict" in pl_sd:
                 sd = pl_sd["state_dict"]
             else:
                 sd = pl_sd
-            self.base_model = self.base_model.cpu()
+            # self.base_model = self.base_model.cpu()
             self.base_model.load_state_dict(sd, strict=False)
-            self.base_model = self.base_model.cuda()
+            # self.base_model = self.base_model.cuda()
             self.current_base = base_model
             # del sd
             # del pl_sd
-        con_strength = int((1-con_strength)*50)
+        con_strength = int((1 - con_strength) * 50)
         if fix_sample == 'True':
             seed_everything(42)
-        im = cv2.resize(input_img,(512,512))
+        im = cv2.resize(input_img, (512, 512))
 
         if type_in == 'Sketch':
             if color_back == 'White':
-                im = 255-im
+                im = 255 - im
             im_edge = im.copy()
-            im = img2tensor(im)[0].unsqueeze(0).unsqueeze(0)/255.
-            im = im>0.5
+            im = img2tensor(im)[0].unsqueeze(0).unsqueeze(0) / 255.
+            im = im > 0.5
             im = im.float()
         elif type_in == 'Image':
-            im = img2tensor(im).unsqueeze(0)/255.
+            im = img2tensor(im).unsqueeze(0) / 255.
            im = self.model_edge(im.to(self.device))[-1]
-            im = im>0.5
+            im = im > 0.5
             im = im.float()
             im_edge = tensor2img(im)
-
+
         # # save gpu memory
         # self.base_model.model = self.base_model.model.cpu()
         # self.model_sketch = self.model_sketch.cuda()
@@ -182,11 +192,11 @@ class Model_all:
         # self.base_model.cond_stage_model = self.base_model.cond_stage_model.cuda()
 
         # extract condition features
-        c = self.base_model.get_learned_conditioning([prompt+', '+pos_prompt])
+        c = self.base_model.get_learned_conditioning([prompt + ', ' + pos_prompt])
         nc = self.base_model.get_learned_conditioning([neg_prompt])
         features_adapter = self.model_sketch(im.to(self.device))
         shape = [4, 64, 64]
-
+
         # # save gpu memory
         # self.model_sketch = self.model_sketch.cpu()
         # self.base_model.cond_stage_model = self.base_model.cond_stage_model.cpu()
@@ -194,17 +204,17 @@ class Model_all:
 
         # sampling
         samples_ddim, _ = self.sampler.sample(S=50,
-                                              conditioning=c,
-                                              batch_size=1,
-                                              shape=shape,
-                                              verbose=False,
-                                              unconditional_guidance_scale=scale,
-                                              unconditional_conditioning=nc,
-                                              eta=0.0,
-                                              x_T=None,
-                                              features_adapter1=features_adapter,
-                                              mode = 'sketch',
-                                              con_strength = con_strength)
+                                              conditioning=c,
+                                              batch_size=1,
+                                              shape=shape,
+                                              verbose=False,
+                                              unconditional_guidance_scale=scale,
+                                              unconditional_conditioning=nc,
+                                              eta=0.0,
+                                              x_T=None,
+                                              features_adapter1=features_adapter,
+                                              mode='sketch',
+                                              con_strength=con_strength)
         # # save gpu memory
         # self.base_model.first_stage_model = self.base_model.first_stage_model.cuda()
 
@@ -212,7 +222,7 @@ class Model_all:
         x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
         x_samples_ddim = x_samples_ddim.to('cpu')
         x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0]
-        x_samples_ddim = 255.*x_samples_ddim
+        x_samples_ddim = 255. * x_samples_ddim
         x_samples_ddim = x_samples_ddim.astype(np.uint8)
 
         return [im_edge, x_samples_ddim]
@@ -221,16 +231,16 @@ class Model_all:
     def process_draw(self, input_img, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model):
         if self.current_base != base_model:
             ckpt = os.path.join("models", base_model)
-            pl_sd = torch.load(ckpt, map_location="cpu")
+            pl_sd = torch.load(ckpt, map_location="cuda")
             if "state_dict" in pl_sd:
                 sd = pl_sd["state_dict"]
             else:
                 sd = pl_sd
-            self.base_model = self.base_model.cpu()
+            # self.base_model = self.base_model.cpu()
             self.base_model.load_state_dict(sd, strict=False)
-            self.base_model = self.base_model.cuda()
+            # self.base_model = self.base_model.cuda()
             self.current_base = base_model
-        con_strength = int((1-con_strength)*50)
+        con_strength = int((1 - con_strength) * 50)
         if fix_sample == 'True':
             seed_everything(42)
         input_img = input_img['mask']
@@ -238,12 +248,12 @@ class Model_all:
         a = input_img[:, :, 3:4].astype(np.float32) / 255.0
         im = c * a + 255.0 * (1.0 - a)
         im = im.clip(0, 255).astype(np.uint8)
-        im = cv2.resize(im,(512,512))
+        im = cv2.resize(im, (512, 512))
 
         # im = 255-im
         im_edge = im.copy()
-        im = img2tensor(im)[0].unsqueeze(0).unsqueeze(0)/255.
-        im = im>0.5
+        im = img2tensor(im)[0].unsqueeze(0).unsqueeze(0) / 255.
+        im = im > 0.5
         im = im.float()
 
         # # save gpu memory
@@ -251,9 +261,9 @@ class Model_all:
         # self.model_sketch = self.model_sketch.cuda()
         # self.base_model.first_stage_model = self.base_model.first_stage_model.cpu()
         # self.base_model.cond_stage_model = self.base_model.cond_stage_model.cuda()
-
+
         # extract condition features
-        c = self.base_model.get_learned_conditioning([prompt+', '+pos_prompt])
+        c = self.base_model.get_learned_conditioning([prompt + ', ' + pos_prompt])
         nc = self.base_model.get_learned_conditioning([neg_prompt])
         features_adapter = self.model_sketch(im.to(self.device))
         shape = [4, 64, 64]
@@ -265,18 +275,18 @@ class Model_all:
 
         # sampling
         samples_ddim, _ = self.sampler.sample(S=50,
-                                              conditioning=c,
-                                              batch_size=1,
-                                              shape=shape,
-                                              verbose=False,
-                                              unconditional_guidance_scale=scale,
-                                              unconditional_conditioning=nc,
-                                              eta=0.0,
-                                              x_T=None,
-                                              features_adapter1=features_adapter,
-                                              mode = 'sketch',
-                                              con_strength = con_strength)
-
+                                              conditioning=c,
+                                              batch_size=1,
+                                              shape=shape,
+                                              verbose=False,
+                                              unconditional_guidance_scale=scale,
+                                              unconditional_conditioning=nc,
+                                              eta=0.0,
+                                              x_T=None,
+                                              features_adapter1=features_adapter,
+                                              mode='sketch',
+                                              con_strength=con_strength)
+
         # # save gpu memory
         # self.base_model.first_stage_model = self.base_model.first_stage_model.cuda()
 
@@ -284,35 +294,36 @@ class Model_all:
         x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
         x_samples_ddim = x_samples_ddim.to('cpu')
         x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0]
-        x_samples_ddim = 255.*x_samples_ddim
+        x_samples_ddim = 255. * x_samples_ddim
         x_samples_ddim = x_samples_ddim.astype(np.uint8)
 
         return [im_edge, x_samples_ddim]
 
     @torch.no_grad()
-    def process_keypose(self, input_img, type_in, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model):
+    def process_keypose(self, input_img, type_in, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength,
+                        base_model):
         if self.current_base != base_model:
             ckpt = os.path.join("models", base_model)
-            pl_sd = torch.load(ckpt, map_location="cpu")
+            pl_sd = torch.load(ckpt, map_location="cuda")
             if "state_dict" in pl_sd:
                 sd = pl_sd["state_dict"]
             else:
                 sd = pl_sd
-            self.base_model = self.base_model.cpu()
+            # self.base_model = self.base_model.cpu()
            self.base_model.load_state_dict(sd, strict=False)
-            self.base_model = self.base_model.cuda()
+            # self.base_model = self.base_model.cuda()
            self.current_base = base_model
-        con_strength = int((1-con_strength)*50)
+        con_strength = int((1 - con_strength) * 50)
         if fix_sample == 'True':
             seed_everything(42)
-        im = cv2.resize(input_img,(512,512))
+        im = cv2.resize(input_img, (512, 512))
 
         if type_in == 'Keypose':
             im_pose = im.copy()
-            im = img2tensor(im).unsqueeze(0)/255.
+            im = img2tensor(im).unsqueeze(0) / 255.
         elif type_in == 'Image':
             image = im.copy()
-            im = img2tensor(im).unsqueeze(0)/255.
+            im = img2tensor(im).unsqueeze(0) / 255.
            mmdet_results = inference_detector(self.det_model, image)
            # keep the person class bounding boxes.
            person_results = process_mmdet_results(mmdet_results, self.det_cat_id)
@@ -343,8 +354,8 @@ class Model_all:
             pose_link_color=self.pose_link_color,
             radius=2,
             thickness=2)
-        im_pose = cv2.resize(im_pose,(512,512))
-
+        im_pose = cv2.resize(im_pose, (512, 512))
+
         # # save gpu memory
         # self.base_model.model = self.base_model.model.cpu()
         # self.model_pose = self.model_pose.cuda()
@@ -352,9 +363,9 @@ class Model_all:
         # self.base_model.cond_stage_model = self.base_model.cond_stage_model.cuda()
 
         # extract condition features
-        c = self.base_model.get_learned_conditioning([prompt+', '+pos_prompt])
+        c = self.base_model.get_learned_conditioning([prompt + ', ' + pos_prompt])
         nc = self.base_model.get_learned_conditioning([neg_prompt])
-        pose = img2tensor(im_pose, bgr2rgb=True, float32=True)/255.
+        pose = img2tensor(im_pose, bgr2rgb=True, float32=True) / 255.
         pose = pose.unsqueeze(0)
         features_adapter = self.model_pose(pose.to(self.device))
 
@@ -367,17 +378,17 @@ class Model_all:
 
         # sampling
         samples_ddim, _ = self.sampler.sample(S=50,
-                                              conditioning=c,
-                                              batch_size=1,
-                                              shape=shape,
-                                              verbose=False,
-                                              unconditional_guidance_scale=scale,
-                                              unconditional_conditioning=nc,
-                                              eta=0.0,
-                                              x_T=None,
-                                              features_adapter1=features_adapter,
-                                              mode = 'sketch',
-                                              con_strength = con_strength)
+                                              conditioning=c,
+                                              batch_size=1,
+                                              shape=shape,
+                                              verbose=False,
+                                              unconditional_guidance_scale=scale,
+                                              unconditional_conditioning=nc,
+                                              eta=0.0,
+                                              x_T=None,
+                                              features_adapter1=features_adapter,
+                                              mode='sketch',
+                                              con_strength=con_strength)
 
         # # save gpu memory
         # self.base_model.first_stage_model = self.base_model.first_stage_model.cuda()
@@ -386,10 +397,11 @@ class Model_all:
         x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
         x_samples_ddim = x_samples_ddim.to('cpu')
         x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0]
-        x_samples_ddim = 255.*x_samples_ddim
+        x_samples_ddim = 255. * x_samples_ddim
         x_samples_ddim = x_samples_ddim.astype(np.uint8)
 
-        return [im_pose[:,:,::-1].astype(np.uint8), x_samples_ddim]
+        return [im_pose[:, :, ::-1].astype(np.uint8), x_samples_ddim]
+
 
 if __name__ == '__main__':
     model = Model_all('cpu')
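
Note that the checkpoint-swap blocks above now hard-code map_location="cuda" in all three process_* methods, while __init__ and the __main__ guard (Model_all('cpu')) still accept a CPU device. A minimal device-agnostic sketch of the same loading pattern is below; the helper name load_base_checkpoint and the model_dir argument are hypothetical, not part of this commit:

# Hypothetical sketch, not part of this commit: the checkpoint-swap logic
# repeated in process_sketch/process_draw/process_keypose, but mapping the
# weights to the model's own device so CPU-only machines keep working.
import os

import torch


def load_base_checkpoint(base_model, base_name, device, model_dir="models"):
    """Load Stable Diffusion base weights into `base_model` in place."""
    ckpt = os.path.join(model_dir, base_name)
    # map_location accepts a device string such as 'cpu' or 'cuda'
    pl_sd = torch.load(ckpt, map_location=device)
    # checkpoints may wrap the weights in a "state_dict" key
    sd = pl_sd["state_dict"] if "state_dict" in pl_sd else pl_sd
    base_model.load_state_dict(sd, strict=False)
    return base_model

A caller would keep the bookkeeping from the diff: self.base_model = load_base_checkpoint(self.base_model, base_model, self.device), followed by self.current_base = base_model.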