chaojiemao committed · Commit 1828f85
Parent(s): 5b0cd30
Update ace_inference.py
Files changed: ace_inference.py (+11 -10)

ace_inference.py CHANGED
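All ten removals and ten of the eleven additions below are the same one-token change, repeated at every stage of the pipeline: the bare name `use_dynamic_model` becomes the instance attribute `self.use_dynamic_model`. The eleventh addition is a `print(kwargs)` at the top of the method, apparently left in as a debugging aid. A short note on why the rename matters follows the diff.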
@@ -330,6 +330,7 @@ class ACEInference(DiffusionInference):
                  history_io=None,
                  tar_index=0,
                  **kwargs):
+        print(kwargs)
         input_image, input_mask = image, mask
         g = torch.Generator(device=we.device_id)
         seed = seed if seed >= 0 else random.randint(0, 2**32 - 1)
@@ -396,9 +397,9 @@ class ACEInference(DiffusionInference):
         if use_ace and (not is_txt_image or refiner_scale <= 0):
             ctx, null_ctx = {}, {}
             # Get Noise Shape
-            if use_dynamic_model: self.dynamic_load(self.first_stage_model, 'first_stage_model')
+            if self.use_dynamic_model: self.dynamic_load(self.first_stage_model, 'first_stage_model')
             x = self.encode_first_stage(image)
-            if use_dynamic_model: self.dynamic_unload(self.first_stage_model,
+            if self.use_dynamic_model: self.dynamic_unload(self.first_stage_model,
                                                       'first_stage_model',
                                                       skip_loaded=True)
             noise = [
@@ -414,7 +415,7 @@ class ACEInference(DiffusionInference):
             ctx['x_mask'] = null_ctx['x_mask'] = cond_mask
 
             # Encode Prompt
-            if use_dynamic_model: self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
+            if self.use_dynamic_model: self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
             function_name, dtype = self.get_function_info(self.cond_stage_model)
             cont, cont_mask = getattr(get_model(self.cond_stage_model),
                                       function_name)(prompt)
@@ -424,14 +425,14 @@ class ACEInference(DiffusionInference):
                                                 function_name)(n_prompt)
             null_cont, null_cont_mask = self.cond_stage_embeddings(
                 prompt, edit_image, null_cont, null_cont_mask)
-            if use_dynamic_model: self.dynamic_unload(self.cond_stage_model,
+            if self.use_dynamic_model: self.dynamic_unload(self.cond_stage_model,
                                                       'cond_stage_model',
                                                       skip_loaded=False)
             ctx['crossattn'] = cont
             null_ctx['crossattn'] = null_cont
 
             # Encode Edit Images
-            if use_dynamic_model: self.dynamic_load(self.first_stage_model, 'first_stage_model')
+            if self.use_dynamic_model: self.dynamic_load(self.first_stage_model, 'first_stage_model')
             edit_image = [to_device(i, strict=False) for i in edit_image]
             edit_image_mask = [to_device(i, strict=False) for i in edit_image_mask]
             e_img, e_mask = [], []
@@ -442,14 +443,14 @@ class ACEInference(DiffusionInference):
                     m = [None] * len(u)
                 e_img.append(self.encode_first_stage(u, **kwargs))
                 e_mask.append([self.interpolate_func(i) for i in m])
-            if use_dynamic_model: self.dynamic_unload(self.first_stage_model,
+            if self.use_dynamic_model: self.dynamic_unload(self.first_stage_model,
                                                       'first_stage_model',
                                                       skip_loaded=True)
             null_ctx['edit'] = ctx['edit'] = e_img
             null_ctx['edit_mask'] = ctx['edit_mask'] = e_mask
 
             # Diffusion Process
-            if use_dynamic_model: self.dynamic_load(self.diffusion_model, 'diffusion_model')
+            if self.use_dynamic_model: self.dynamic_load(self.diffusion_model, 'diffusion_model')
             function_name, dtype = self.get_function_info(self.diffusion_model)
             with torch.autocast('cuda',
                                 enabled=dtype in ('float16', 'bfloat16'),
@@ -490,15 +491,15 @@ class ACEInference(DiffusionInference):
                 guide_rescale=guide_rescale,
                 return_intermediate=None,
                 **kwargs)
-            if use_dynamic_model: self.dynamic_unload(self.diffusion_model,
+            if self.use_dynamic_model: self.dynamic_unload(self.diffusion_model,
                                                       'diffusion_model',
                                                       skip_loaded=False)
 
             # Decode to Pixel Space
-            if use_dynamic_model: self.dynamic_load(self.first_stage_model, 'first_stage_model')
+            if self.use_dynamic_model: self.dynamic_load(self.first_stage_model, 'first_stage_model')
             samples = unpack_tensor_into_imagelist(latent, x_shapes)
             x_samples = self.decode_first_stage(samples)
-            if use_dynamic_model: self.dynamic_unload(self.first_stage_model,
+            if self.use_dynamic_model: self.dynamic_unload(self.first_stage_model,
                                                       'first_stage_model',
                                                       skip_loaded=False)
             x_samples = [x.squeeze(0) for x in x_samples]
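Two notes on the change, for readers skimming the diff. First, the rename is not cosmetic: inside a method body, a bare name like `use_dynamic_model` is resolved against local and global scope, never against the instance, so each of the old guards would raise `NameError` (or silently read an unrelated global) instead of consulting the flag stored on the `ACEInference` object. A minimal reproduction, using a stand-in class rather than the real `ACEInference`:

class Demo:
    def __init__(self):
        self.use_dynamic_model = True  # flag lives on the instance

    def broken(self):
        # Bare name: Python looks this up as a local/global, not an attribute.
        if use_dynamic_model:  # NameError when called
            return 'loaded'

    def fixed(self):
        # Attribute access, as in the commit.
        if self.use_dynamic_model:
            return 'loaded'

print(Demo().fixed())  # -> loaded
# Demo().broken()      # NameError: name 'use_dynamic_model' is not defined

Second, the guards themselves gate a swap-to-CPU scheme: each stage (autoencoder `first_stage_model`, text encoder `cond_stage_model`, diffusion backbone `diffusion_model`) is loaded onto the accelerator just before it runs and unloaded right after, trading transfer time for lower peak VRAM. The real `dynamic_load`/`dynamic_unload` come from the `DiffusionInference` base class and also take a `skip_loaded` argument; the toy version below only sketches the idea, and its class name and signatures are assumptions:

import torch

class ToyOffloader:
    """Sketch of the load/compute/unload pattern the guards enable."""

    def __init__(self, use_dynamic_model=True, device='cuda'):
        self.use_dynamic_model = use_dynamic_model
        self.device = device

    def dynamic_load(self, module: torch.nn.Module):
        # Move the submodule onto the accelerator only when its stage runs.
        if self.use_dynamic_model:
            module.to(self.device)

    def dynamic_unload(self, module: torch.nn.Module):
        # Park it back on the CPU and return freed blocks to the allocator.
        if self.use_dynamic_model:
            module.to('cpu')
            if torch.cuda.is_available():
                torch.cuda.empty_cache()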
|