chaojiemao commited on
Commit
5b0cd30
1 Parent(s): 62720f9

Update ace_inference.py

Browse files
Files changed (1) hide show
  1. ace_inference.py +10 -10
ace_inference.py CHANGED
@@ -396,9 +396,9 @@ class ACEInference(DiffusionInference):
396
  if use_ace and (not is_txt_image or refiner_scale <= 0):
397
  ctx, null_ctx = {}, {}
398
  # Get Noise Shape
399
- self.dynamic_load(self.first_stage_model, 'first_stage_model')
400
  x = self.encode_first_stage(image)
401
- self.dynamic_unload(self.first_stage_model,
402
  'first_stage_model',
403
  skip_loaded=True)
404
  noise = [
@@ -414,7 +414,7 @@ class ACEInference(DiffusionInference):
414
  ctx['x_mask'] = null_ctx['x_mask'] = cond_mask
415
 
416
  # Encode Prompt
417
- self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
418
  function_name, dtype = self.get_function_info(self.cond_stage_model)
419
  cont, cont_mask = getattr(get_model(self.cond_stage_model),
420
  function_name)(prompt)
@@ -424,14 +424,14 @@ class ACEInference(DiffusionInference):
424
  function_name)(n_prompt)
425
  null_cont, null_cont_mask = self.cond_stage_embeddings(
426
  prompt, edit_image, null_cont, null_cont_mask)
427
- self.dynamic_unload(self.cond_stage_model,
428
  'cond_stage_model',
429
  skip_loaded=False)
430
  ctx['crossattn'] = cont
431
  null_ctx['crossattn'] = null_cont
432
 
433
  # Encode Edit Images
434
- self.dynamic_load(self.first_stage_model, 'first_stage_model')
435
  edit_image = [to_device(i, strict=False) for i in edit_image]
436
  edit_image_mask = [to_device(i, strict=False) for i in edit_image_mask]
437
  e_img, e_mask = [], []
@@ -442,14 +442,14 @@ class ACEInference(DiffusionInference):
442
  m = [None] * len(u)
443
  e_img.append(self.encode_first_stage(u, **kwargs))
444
  e_mask.append([self.interpolate_func(i) for i in m])
445
- self.dynamic_unload(self.first_stage_model,
446
  'first_stage_model',
447
  skip_loaded=True)
448
  null_ctx['edit'] = ctx['edit'] = e_img
449
  null_ctx['edit_mask'] = ctx['edit_mask'] = e_mask
450
 
451
  # Diffusion Process
452
- self.dynamic_load(self.diffusion_model, 'diffusion_model')
453
  function_name, dtype = self.get_function_info(self.diffusion_model)
454
  with torch.autocast('cuda',
455
  enabled=dtype in ('float16', 'bfloat16'),
@@ -490,15 +490,15 @@ class ACEInference(DiffusionInference):
490
  guide_rescale=guide_rescale,
491
  return_intermediate=None,
492
  **kwargs)
493
- self.dynamic_unload(self.diffusion_model,
494
  'diffusion_model',
495
  skip_loaded=False)
496
 
497
  # Decode to Pixel Space
498
- self.dynamic_load(self.first_stage_model, 'first_stage_model')
499
  samples = unpack_tensor_into_imagelist(latent, x_shapes)
500
  x_samples = self.decode_first_stage(samples)
501
- self.dynamic_unload(self.first_stage_model,
502
  'first_stage_model',
503
  skip_loaded=False)
504
  x_samples = [x.squeeze(0) for x in x_samples]
 
396
  if use_ace and (not is_txt_image or refiner_scale <= 0):
397
  ctx, null_ctx = {}, {}
398
  # Get Noise Shape
399
+ if use_dynamic_model: self.dynamic_load(self.first_stage_model, 'first_stage_model')
400
  x = self.encode_first_stage(image)
401
+ if use_dynamic_model: self.dynamic_unload(self.first_stage_model,
402
  'first_stage_model',
403
  skip_loaded=True)
404
  noise = [
 
414
  ctx['x_mask'] = null_ctx['x_mask'] = cond_mask
415
 
416
  # Encode Prompt
417
+ if use_dynamic_model: self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
418
  function_name, dtype = self.get_function_info(self.cond_stage_model)
419
  cont, cont_mask = getattr(get_model(self.cond_stage_model),
420
  function_name)(prompt)
 
424
  function_name)(n_prompt)
425
  null_cont, null_cont_mask = self.cond_stage_embeddings(
426
  prompt, edit_image, null_cont, null_cont_mask)
427
+ if use_dynamic_model: self.dynamic_unload(self.cond_stage_model,
428
  'cond_stage_model',
429
  skip_loaded=False)
430
  ctx['crossattn'] = cont
431
  null_ctx['crossattn'] = null_cont
432
 
433
  # Encode Edit Images
434
+ if use_dynamic_model: self.dynamic_load(self.first_stage_model, 'first_stage_model')
435
  edit_image = [to_device(i, strict=False) for i in edit_image]
436
  edit_image_mask = [to_device(i, strict=False) for i in edit_image_mask]
437
  e_img, e_mask = [], []
 
442
  m = [None] * len(u)
443
  e_img.append(self.encode_first_stage(u, **kwargs))
444
  e_mask.append([self.interpolate_func(i) for i in m])
445
+ if use_dynamic_model: self.dynamic_unload(self.first_stage_model,
446
  'first_stage_model',
447
  skip_loaded=True)
448
  null_ctx['edit'] = ctx['edit'] = e_img
449
  null_ctx['edit_mask'] = ctx['edit_mask'] = e_mask
450
 
451
  # Diffusion Process
452
+ if use_dynamic_model: self.dynamic_load(self.diffusion_model, 'diffusion_model')
453
  function_name, dtype = self.get_function_info(self.diffusion_model)
454
  with torch.autocast('cuda',
455
  enabled=dtype in ('float16', 'bfloat16'),
 
490
  guide_rescale=guide_rescale,
491
  return_intermediate=None,
492
  **kwargs)
493
+ if use_dynamic_model: self.dynamic_unload(self.diffusion_model,
494
  'diffusion_model',
495
  skip_loaded=False)
496
 
497
  # Decode to Pixel Space
498
+ if use_dynamic_model: self.dynamic_load(self.first_stage_model, 'first_stage_model')
499
  samples = unpack_tensor_into_imagelist(latent, x_shapes)
500
  x_samples = self.decode_first_stage(samples)
501
+ if use_dynamic_model: self.dynamic_unload(self.first_stage_model,
502
  'first_stage_model',
503
  skip_loaded=False)
504
  x_samples = [x.squeeze(0) for x in x_samples]