Fooocus-image / modules /default_pipeline.py
IvanAbramov's picture
Upload folder using huggingface_hub
60ae8ae
import modules.core as core
import os
import gc
import torch
import numpy as np
import modules.path
from comfy.model_base import SDXL, SDXLRefiner
from comfy.model_management import soft_empty_cache
from PIL import Image, ImageOps
xl_base: core.StableDiffusionModel = None
xl_base_hash = ''
xl_refiner: core.StableDiffusionModel = None
xl_refiner_hash = ''
xl_base_patched: core.StableDiffusionModel = None
xl_base_patched_hash = ''
def refresh_base_model(name):
global xl_base, xl_base_hash, xl_base_patched, xl_base_patched_hash
if xl_base_hash == str(name):
return
filename = os.path.join(modules.path.modelfile_path, name)
if xl_base is not None:
xl_base.to_meta()
xl_base = None
xl_base = core.load_model(filename)
if not isinstance(xl_base.unet.model, SDXL):
print('Model not supported. Fooocus only support SDXL model as the base model.')
xl_base = None
xl_base_hash = ''
refresh_base_model(modules.path.default_base_model_name)
xl_base_hash = name
xl_base_patched = xl_base
xl_base_patched_hash = ''
return
xl_base_hash = name
xl_base_patched = xl_base
xl_base_patched_hash = ''
print(f'Base model loaded: {xl_base_hash}')
return
def refresh_refiner_model(name):
global xl_refiner, xl_refiner_hash
if xl_refiner_hash == str(name):
return
if name == 'None':
xl_refiner = None
xl_refiner_hash = ''
print(f'Refiner unloaded.')
return
filename = os.path.join(modules.path.modelfile_path, name)
if xl_refiner is not None:
xl_refiner.to_meta()
xl_refiner = None
xl_refiner = core.load_model(filename)
if not isinstance(xl_refiner.unet.model, SDXLRefiner):
print('Model not supported. Fooocus only support SDXL refiner as the refiner.')
xl_refiner = None
xl_refiner_hash = ''
print(f'Refiner unloaded.')
return
xl_refiner_hash = name
print(f'Refiner model loaded: {xl_refiner_hash}')
xl_refiner.vae.first_stage_model.to('meta')
xl_refiner.vae = None
return
def refresh_loras(loras):
global xl_base, xl_base_patched, xl_base_patched_hash
if xl_base_patched_hash == str(loras):
return
model = xl_base
for name, weight in loras:
if name == 'None':
continue
filename = os.path.join(modules.path.lorafile_path, name)
model = core.load_lora(model, filename, strength_model=weight, strength_clip=weight)
xl_base_patched = model
xl_base_patched_hash = str(loras)
print(f'LoRAs loaded: {xl_base_patched_hash}')
return
refresh_base_model(modules.path.default_base_model_name)
refresh_refiner_model(modules.path.default_refiner_model_name)
refresh_loras([(modules.path.default_lora_name, 0.5), ('None', 0.5), ('None', 0.5), ('None', 0.5), ('None', 0.5)])
positive_conditions_cache = None
negative_conditions_cache = None
positive_conditions_refiner_cache = None
negative_conditions_refiner_cache = None
def clean_prompt_cond_caches():
global positive_conditions_cache, negative_conditions_cache, \
positive_conditions_refiner_cache, negative_conditions_refiner_cache
positive_conditions_cache = None
negative_conditions_cache = None
positive_conditions_refiner_cache = None
negative_conditions_refiner_cache = None
return
@torch.no_grad()
def process(positive_prompt, negative_prompt, steps, switch, width, height, image_seed, sampler_name, scheduler, cfg, base_clip_skip, refiner_clip_skip, input_image_path, start_step, denoise, callback):
global positive_conditions_cache, negative_conditions_cache, \
positive_conditions_refiner_cache, negative_conditions_refiner_cache
xl_base_patched.clip.clip_layer(base_clip_skip)
positive_conditions = core.encode_prompt_condition(clip=xl_base_patched.clip, prompt=positive_prompt) if positive_conditions_cache is None else positive_conditions_cache
negative_conditions = core.encode_prompt_condition(clip=xl_base_patched.clip, prompt=negative_prompt) if negative_conditions_cache is None else negative_conditions_cache
positive_conditions_cache = positive_conditions
negative_conditions_cache = negative_conditions
if input_image_path == None:
latent = core.generate_empty_latent(width=width, height=height, batch_size=1)
force_full_denoise = True
denoise = None
else:
with open(input_image_path, 'rb') as image_file:
pil_image = Image.open(image_file)
image = ImageOps.exif_transpose(pil_image)
image_file.close()
image = image.convert("RGB")
image = np.array(image).astype(np.float32) / 255.0
image = torch.from_numpy(image)[None,]
input_image = core.upscale(image)
latent = core.encode_vae(vae=xl_base_patched.vae, pixels=input_image)
force_full_denoise = False
if xl_refiner is not None:
xl_refiner.clip.clip_layer(refiner_clip_skip)
positive_conditions_refiner = core.encode_prompt_condition(clip=xl_refiner.clip, prompt=positive_prompt) if positive_conditions_refiner_cache is None else positive_conditions_refiner_cache
negative_conditions_refiner = core.encode_prompt_condition(clip=xl_refiner.clip, prompt=negative_prompt) if negative_conditions_refiner_cache is None else negative_conditions_refiner_cache
positive_conditions_refiner_cache = positive_conditions_refiner
negative_conditions_refiner_cache = negative_conditions_refiner
sampled_latent = core.ksampler_with_refiner(
model=xl_base_patched.unet,
positive=positive_conditions,
negative=negative_conditions,
refiner=xl_refiner.unet,
refiner_positive=positive_conditions_refiner,
refiner_negative=negative_conditions_refiner,
refiner_switch_step=switch,
latent=latent,
steps=steps, start_step=start_step, last_step=steps,
disable_noise=False, force_full_denoise=force_full_denoise, denoise=denoise,
seed=image_seed,
sampler_name=sampler_name,
scheduler=scheduler,
cfg=cfg,
callback_function=callback
)
else:
sampled_latent = core.ksampler(
model=xl_base_patched.unet,
positive=positive_conditions,
negative=negative_conditions,
latent=latent,
steps=steps, start_step=start_step, last_step=steps,
disable_noise=False, force_full_denoise=force_full_denoise, denoise=denoise,
seed=image_seed,
sampler_name=sampler_name,
scheduler=scheduler,
cfg=cfg,
callback_function=callback
)
decoded_latent = core.decode_vae(vae=xl_base_patched.vae, latent_image=sampled_latent)
images = core.image_to_numpy(decoded_latent)
gc.collect()
soft_empty_cache()
return images