Spaces:
Running
on
Zero
Running
on
Zero
chaojiemao
committed on
Commit
•
5b0cd30
1
Parent(s):
62720f9
Update ace_inference.py
Browse files: ace_inference.py +10 -10
ace_inference.py
CHANGED
@@ -396,9 +396,9 @@ class ACEInference(DiffusionInference):
|
|
396 |
if use_ace and (not is_txt_image or refiner_scale <= 0):
|
397 |
ctx, null_ctx = {}, {}
|
398 |
# Get Noise Shape
|
399 |
-
self.dynamic_load(self.first_stage_model, 'first_stage_model')
|
400 |
x = self.encode_first_stage(image)
|
401 |
-
self.dynamic_unload(self.first_stage_model,
|
402 |
'first_stage_model',
|
403 |
skip_loaded=True)
|
404 |
noise = [
|
@@ -414,7 +414,7 @@ class ACEInference(DiffusionInference):
|
|
414 |
ctx['x_mask'] = null_ctx['x_mask'] = cond_mask
|
415 |
|
416 |
# Encode Prompt
|
417 |
-
self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
|
418 |
function_name, dtype = self.get_function_info(self.cond_stage_model)
|
419 |
cont, cont_mask = getattr(get_model(self.cond_stage_model),
|
420 |
function_name)(prompt)
|
@@ -424,14 +424,14 @@ class ACEInference(DiffusionInference):
|
|
424 |
function_name)(n_prompt)
|
425 |
null_cont, null_cont_mask = self.cond_stage_embeddings(
|
426 |
prompt, edit_image, null_cont, null_cont_mask)
|
427 |
-
self.dynamic_unload(self.cond_stage_model,
|
428 |
'cond_stage_model',
|
429 |
skip_loaded=False)
|
430 |
ctx['crossattn'] = cont
|
431 |
null_ctx['crossattn'] = null_cont
|
432 |
|
433 |
# Encode Edit Images
|
434 |
-
self.dynamic_load(self.first_stage_model, 'first_stage_model')
|
435 |
edit_image = [to_device(i, strict=False) for i in edit_image]
|
436 |
edit_image_mask = [to_device(i, strict=False) for i in edit_image_mask]
|
437 |
e_img, e_mask = [], []
|
@@ -442,14 +442,14 @@ class ACEInference(DiffusionInference):
|
|
442 |
m = [None] * len(u)
|
443 |
e_img.append(self.encode_first_stage(u, **kwargs))
|
444 |
e_mask.append([self.interpolate_func(i) for i in m])
|
445 |
-
self.dynamic_unload(self.first_stage_model,
|
446 |
'first_stage_model',
|
447 |
skip_loaded=True)
|
448 |
null_ctx['edit'] = ctx['edit'] = e_img
|
449 |
null_ctx['edit_mask'] = ctx['edit_mask'] = e_mask
|
450 |
|
451 |
# Diffusion Process
|
452 |
-
self.dynamic_load(self.diffusion_model, 'diffusion_model')
|
453 |
function_name, dtype = self.get_function_info(self.diffusion_model)
|
454 |
with torch.autocast('cuda',
|
455 |
enabled=dtype in ('float16', 'bfloat16'),
|
@@ -490,15 +490,15 @@ class ACEInference(DiffusionInference):
|
|
490 |
guide_rescale=guide_rescale,
|
491 |
return_intermediate=None,
|
492 |
**kwargs)
|
493 |
-
self.dynamic_unload(self.diffusion_model,
|
494 |
'diffusion_model',
|
495 |
skip_loaded=False)
|
496 |
|
497 |
# Decode to Pixel Space
|
498 |
-
self.dynamic_load(self.first_stage_model, 'first_stage_model')
|
499 |
samples = unpack_tensor_into_imagelist(latent, x_shapes)
|
500 |
x_samples = self.decode_first_stage(samples)
|
501 |
-
self.dynamic_unload(self.first_stage_model,
|
502 |
'first_stage_model',
|
503 |
skip_loaded=False)
|
504 |
x_samples = [x.squeeze(0) for x in x_samples]
|
|
|
396 |
if use_ace and (not is_txt_image or refiner_scale <= 0):
|
397 |
ctx, null_ctx = {}, {}
|
398 |
# Get Noise Shape
|
399 |
+
if use_dynamic_model: self.dynamic_load(self.first_stage_model, 'first_stage_model')
|
400 |
x = self.encode_first_stage(image)
|
401 |
+
if use_dynamic_model: self.dynamic_unload(self.first_stage_model,
|
402 |
'first_stage_model',
|
403 |
skip_loaded=True)
|
404 |
noise = [
|
|
|
414 |
ctx['x_mask'] = null_ctx['x_mask'] = cond_mask
|
415 |
|
416 |
# Encode Prompt
|
417 |
+
if use_dynamic_model: self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
|
418 |
function_name, dtype = self.get_function_info(self.cond_stage_model)
|
419 |
cont, cont_mask = getattr(get_model(self.cond_stage_model),
|
420 |
function_name)(prompt)
|
|
|
424 |
function_name)(n_prompt)
|
425 |
null_cont, null_cont_mask = self.cond_stage_embeddings(
|
426 |
prompt, edit_image, null_cont, null_cont_mask)
|
427 |
+
if use_dynamic_model: self.dynamic_unload(self.cond_stage_model,
|
428 |
'cond_stage_model',
|
429 |
skip_loaded=False)
|
430 |
ctx['crossattn'] = cont
|
431 |
null_ctx['crossattn'] = null_cont
|
432 |
|
433 |
# Encode Edit Images
|
434 |
+
if use_dynamic_model: self.dynamic_load(self.first_stage_model, 'first_stage_model')
|
435 |
edit_image = [to_device(i, strict=False) for i in edit_image]
|
436 |
edit_image_mask = [to_device(i, strict=False) for i in edit_image_mask]
|
437 |
e_img, e_mask = [], []
|
|
|
442 |
m = [None] * len(u)
|
443 |
e_img.append(self.encode_first_stage(u, **kwargs))
|
444 |
e_mask.append([self.interpolate_func(i) for i in m])
|
445 |
+
if use_dynamic_model: self.dynamic_unload(self.first_stage_model,
|
446 |
'first_stage_model',
|
447 |
skip_loaded=True)
|
448 |
null_ctx['edit'] = ctx['edit'] = e_img
|
449 |
null_ctx['edit_mask'] = ctx['edit_mask'] = e_mask
|
450 |
|
451 |
# Diffusion Process
|
452 |
+
if use_dynamic_model: self.dynamic_load(self.diffusion_model, 'diffusion_model')
|
453 |
function_name, dtype = self.get_function_info(self.diffusion_model)
|
454 |
with torch.autocast('cuda',
|
455 |
enabled=dtype in ('float16', 'bfloat16'),
|
|
|
490 |
guide_rescale=guide_rescale,
|
491 |
return_intermediate=None,
|
492 |
**kwargs)
|
493 |
+
if use_dynamic_model: self.dynamic_unload(self.diffusion_model,
|
494 |
'diffusion_model',
|
495 |
skip_loaded=False)
|
496 |
|
497 |
# Decode to Pixel Space
|
498 |
+
if use_dynamic_model: self.dynamic_load(self.first_stage_model, 'first_stage_model')
|
499 |
samples = unpack_tensor_into_imagelist(latent, x_shapes)
|
500 |
x_samples = self.decode_first_stage(samples)
|
501 |
+
if use_dynamic_model: self.dynamic_unload(self.first_stage_model,
|
502 |
'first_stage_model',
|
503 |
skip_loaded=False)
|
504 |
x_samples = [x.squeeze(0) for x in x_samples]
|