import os
import gc
from typing import Optional

import numpy as np
import torch
from PIL import Image, ImageOps

import modules.core as core
import modules.path
from comfy.model_base import SDXL, SDXLRefiner
from comfy.model_management import soft_empty_cache

xl_base: Optional[core.StableDiffusionModel] = None
xl_base_hash = ''

xl_refiner: Optional[core.StableDiffusionModel] = None
xl_refiner_hash = ''

xl_base_patched: Optional[core.StableDiffusionModel] = None
xl_base_patched_hash = ''


def refresh_base_model(name):
    global xl_base, xl_base_hash, xl_base_patched, xl_base_patched_hash

    if xl_base_hash == str(name):
        return

    filename = os.path.join(modules.path.modelfile_path, name)

    if xl_base is not None:
        xl_base.to_meta()
        xl_base = None

    xl_base = core.load_model(filename)
    if not isinstance(xl_base.unet.model, SDXL):
        print('Model not supported. Fooocus only supports SDXL models as the base model.')
        xl_base = None
        xl_base_hash = ''
        # Fall back to the default base model, but record the requested name as
        # the hash so repeated calls with the same unsupported file are no-ops.
        refresh_base_model(modules.path.default_base_model_name)
        xl_base_hash = name
        xl_base_patched = xl_base
        xl_base_patched_hash = ''
        return

    xl_base_hash = name
    xl_base_patched = xl_base
    xl_base_patched_hash = ''
    print(f'Base model loaded: {xl_base_hash}')
    return


def refresh_refiner_model(name):
    global xl_refiner, xl_refiner_hash

    if xl_refiner_hash == str(name):
        return

    if name == 'None':
        xl_refiner = None
        xl_refiner_hash = ''
        print('Refiner unloaded.')
        return

    filename = os.path.join(modules.path.modelfile_path, name)

    if xl_refiner is not None:
        xl_refiner.to_meta()
        xl_refiner = None

    xl_refiner = core.load_model(filename)
    if not isinstance(xl_refiner.unet.model, SDXLRefiner):
        print('Model not supported. Fooocus only supports the SDXL refiner as the refiner.')
        xl_refiner = None
        xl_refiner_hash = ''
        print('Refiner unloaded.')
        return

    xl_refiner_hash = name
    print(f'Refiner model loaded: {xl_refiner_hash}')

    # The refiner's VAE is never used (the base model's VAE decodes the final
    # latent), so move it to the meta device and drop the reference.
    xl_refiner.vae.first_stage_model.to('meta')
    xl_refiner.vae = None
    return


def refresh_loras(loras):
    global xl_base, xl_base_patched, xl_base_patched_hash

    if xl_base_patched_hash == str(loras):
        return

    # `loras` is a list of (filename, weight) pairs; patches are stacked onto
    # the base model in order. 'None' entries are placeholders and are skipped.
    model = xl_base
    for name, weight in loras:
        if name == 'None':
            continue

        filename = os.path.join(modules.path.lorafile_path, name)
        model = core.load_lora(model, filename, strength_model=weight, strength_clip=weight)

    xl_base_patched = model
    xl_base_patched_hash = str(loras)
    print(f'LoRAs loaded: {xl_base_patched_hash}')
    return


refresh_base_model(modules.path.default_base_model_name)
refresh_refiner_model(modules.path.default_refiner_model_name)
refresh_loras([(modules.path.default_lora_name, 0.5), ('None', 0.5), ('None', 0.5), ('None', 0.5), ('None', 0.5)])
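
# Example (sketch): models can be swapped at runtime by calling the refresh
# functions again. File names below are illustrative; any SDXL checkpoint or
# LoRA present in the configured model folders works. Each function is a no-op
# when called with the name it already has loaded.
#
#   refresh_base_model('sd_xl_base_1.0.safetensors')
#   refresh_refiner_model('None')  # 'None' unloads the refiner entirely
#   refresh_loras([('my_style_lora.safetensors', 0.8), ('None', 0.5)])
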
# Encoded prompt conditions are cached between generations so that unchanged
# prompts are not re-encoded on every call to process().
positive_conditions_cache = None
negative_conditions_cache = None
positive_conditions_refiner_cache = None
negative_conditions_refiner_cache = None


def clean_prompt_cond_caches():
    global positive_conditions_cache, negative_conditions_cache, \
        positive_conditions_refiner_cache, negative_conditions_refiner_cache
    positive_conditions_cache = None
    negative_conditions_cache = None
    positive_conditions_refiner_cache = None
    negative_conditions_refiner_cache = None
    return

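
# Example (sketch): a caller would typically reset the caches at the start of
# each new task so that edited prompts get re-encoded:
#
#   clean_prompt_cond_caches()
#   images = process(...)
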
@torch.no_grad()
def process(positive_prompt, negative_prompt, steps, switch, width, height, image_seed,
            sampler_name, scheduler, cfg, base_clip_skip, refiner_clip_skip,
            input_image_path, start_step, denoise, callback):
    global positive_conditions_cache, negative_conditions_cache, \
        positive_conditions_refiner_cache, negative_conditions_refiner_cache

    xl_base_patched.clip.clip_layer(base_clip_skip)

    positive_conditions = core.encode_prompt_condition(
        clip=xl_base_patched.clip, prompt=positive_prompt
    ) if positive_conditions_cache is None else positive_conditions_cache
    negative_conditions = core.encode_prompt_condition(
        clip=xl_base_patched.clip, prompt=negative_prompt
    ) if negative_conditions_cache is None else negative_conditions_cache

    positive_conditions_cache = positive_conditions
    negative_conditions_cache = negative_conditions
    if input_image_path is None:
        # Text-to-image: start from an empty latent and denoise fully.
        latent = core.generate_empty_latent(width=width, height=height, batch_size=1)
        force_full_denoise = True
        denoise = None
    else:
        # Image-to-image: load the input image, honour its EXIF orientation,
        # and force a full pixel load before the file handle closes.
        with open(input_image_path, 'rb') as image_file:
            pil_image = Image.open(image_file)
            image = ImageOps.exif_transpose(pil_image)
            image = image.convert("RGB")
        image = np.array(image).astype(np.float32) / 255.0
        image = torch.from_numpy(image)[None,]
        input_image = core.upscale(image)
        latent = core.encode_vae(vae=xl_base_patched.vae, pixels=input_image)
        force_full_denoise = False
    if xl_refiner is not None:
        xl_refiner.clip.clip_layer(refiner_clip_skip)

        positive_conditions_refiner = core.encode_prompt_condition(
            clip=xl_refiner.clip, prompt=positive_prompt
        ) if positive_conditions_refiner_cache is None else positive_conditions_refiner_cache
        negative_conditions_refiner = core.encode_prompt_condition(
            clip=xl_refiner.clip, prompt=negative_prompt
        ) if negative_conditions_refiner_cache is None else negative_conditions_refiner_cache

        positive_conditions_refiner_cache = positive_conditions_refiner
        negative_conditions_refiner_cache = negative_conditions_refiner

        # Sample with the base model, then hand off to the refiner at `switch`.
        sampled_latent = core.ksampler_with_refiner(
            model=xl_base_patched.unet,
            positive=positive_conditions,
            negative=negative_conditions,
            refiner=xl_refiner.unet,
            refiner_positive=positive_conditions_refiner,
            refiner_negative=negative_conditions_refiner,
            refiner_switch_step=switch,
            latent=latent,
            steps=steps, start_step=start_step, last_step=steps,
            disable_noise=False, force_full_denoise=force_full_denoise, denoise=denoise,
            seed=image_seed,
            sampler_name=sampler_name,
            scheduler=scheduler,
            cfg=cfg,
            callback_function=callback
        )
    else:
        sampled_latent = core.ksampler(
            model=xl_base_patched.unet,
            positive=positive_conditions,
            negative=negative_conditions,
            latent=latent,
            steps=steps, start_step=start_step, last_step=steps,
            disable_noise=False, force_full_denoise=force_full_denoise, denoise=denoise,
            seed=image_seed,
            sampler_name=sampler_name,
            scheduler=scheduler,
            cfg=cfg,
            callback_function=callback
        )
    decoded_latent = core.decode_vae(vae=xl_base_patched.vae, latent_image=sampled_latent)
    images = core.image_to_numpy(decoded_latent)

    # Drop transient references and release cached GPU memory between runs.
    gc.collect()
    soft_empty_cache()
    return images
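
# Minimal usage sketch (assumptions: the default models above were resolvable
# at import time, and the sampler/scheduler names exist in the backend; all
# parameter values are illustrative, not project defaults).
if __name__ == '__main__':
    images = process(
        positive_prompt='a cinematic photo of a lighthouse at dusk',
        negative_prompt='low quality, blurry, watermark',
        steps=30,
        switch=20,                        # step at which the refiner takes over
        width=1024, height=1024,
        image_seed=12345,
        sampler_name='dpmpp_2m_sde_gpu',  # illustrative; must match the backend
        scheduler='karras',
        cfg=7.0,
        base_clip_skip=-2, refiner_clip_skip=-2,
        input_image_path=None,            # None -> text-to-image from an empty latent
        start_step=0,
        denoise=1.0,                      # ignored for text-to-image (reset to None)
        callback=None,                    # progress callback; signature is defined by modules.core
    )
    print(f'{len(images)} image(s) generated')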