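"""Image generation handlers for the images tab.

Builds diffusers pipelines (SDXL and Flux variants) from incoming requests,
configures ControlNets, VAEs, schedulers, LoRAs, and textual-inversion
embeddings, then runs generation with an optional refiner pass and cleanup.
"""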
import gc
import random
from typing import List

import gradio as gr
import torch
from PIL import Image
from controlnet_aux.processor import Processor
from safetensors.torch import load_file
from diffusers import (
    AutoPipelineForText2Image,
    AutoPipelineForImage2Image,
    AutoPipelineForInpainting,
    FluxPipeline,
    FluxImg2ImgPipeline,
    FluxInpaintPipeline,
    FluxControlNetPipeline,
    StableDiffusionXLPipeline,
    StableDiffusionXLImg2ImgPipeline,
    StableDiffusionXLInpaintPipeline,
    StableDiffusionXLControlNetPipeline,
    StableDiffusionXLControlNetImg2ImgPipeline,
    StableDiffusionXLControlNetInpaintPipeline,
)
from sd_embed.embedding_funcs import get_weighted_text_embeddings_flux1, get_weighted_text_embeddings_sdxl
from huggingface_hub import hf_hub_download
from diffusers.schedulers import *

from .models import *
from .load_models import device, models, flux_vae, sdxl_vae, refiner, controlnets
sd_pipes = (
    StableDiffusionXLPipeline,
    StableDiffusionXLImg2ImgPipeline,
    StableDiffusionXLInpaintPipeline,
    StableDiffusionXLControlNetPipeline,
    StableDiffusionXLControlNetImg2ImgPipeline,
    StableDiffusionXLControlNetInpaintPipeline,
)
flux_pipes = (FluxPipeline, FluxImg2ImgPipeline, FluxInpaintPipeline, FluxControlNetPipeline)
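# SDXL and Flux pipelines take different conditioning inputs (SDXL uses
# negative prompt embeddings and clip_skip; Flux does not), so the helpers
# below branch on these two tuples via isinstance.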
def get_pipe(request: BaseReq | BaseImg2ImgReq | BaseInpaintReq):
    for model in models:
        if model['repo_id'] == request.model:
            pipe_args = {
                "pipeline": model['pipeline'],
            }

            # Set ControlNet config. The field is spelled `controlnets` where
            # it is read elsewhere in this module, so that name is used here.
            if request.controlnet_config:
                pipe_args["controlnet"] = []
                if model['loader'] in ('sdxl', 'flux'):
                    for controlnet in controlnets:
                        if any(layer in controlnet['layers'] for layer in request.controlnet_config.controlnets):
                            pipe_args["controlnet"].append(controlnet['controlnet'])
                elif model['loader'] == 'flux-multi':
                    controlnet = next((c for c in controlnets if c['loader'] == 'flux-multi'), None)
                    if controlnet is not None:
                        # control_mode = list of indices of the requested layers
                        pipe_args['control_mode'] = [controlnet['layers'].index(layer) for layer in request.controlnet_config.controlnets]
                        pipe_args['controlnet'].append(controlnet['controlnet'])

            # Choose pipeline mode from the request type
            if not request.custom_addons:
                if isinstance(request, BaseInpaintReq):
                    pipe_args['pipeline'] = AutoPipelineForInpainting.from_pipe(**pipe_args)
                elif isinstance(request, BaseImg2ImgReq):
                    pipe_args['pipeline'] = AutoPipelineForImage2Image.from_pipe(**pipe_args)
                elif isinstance(request, BaseReq):
                    pipe_args['pipeline'] = AutoPipelineForText2Image.from_pipe(**pipe_args)
            else:
                # Custom addons supply their own pipeline; return early so the
                # configuration below is never applied to a None pipeline.
                pipe_args['pipeline'] = None
                return pipe_args

            # Enable or disable the VAE (Flux pipelines always keep theirs)
            if request.vae:
                pipe_args["pipeline"].vae = sdxl_vae if model['loader'] == 'sdxl' else flux_vae
            else:
                pipe_args["pipeline"].vae = None if model['loader'] == 'sdxl' else flux_vae

            # Set scheduler
            pipe_args["pipeline"].scheduler = load_scheduler(pipe_args["pipeline"], request.scheduler)

            # Set LoRAs
            if request.loras:
                for i, lora in enumerate(request.loras):
                    pipe_args["pipeline"].load_lora_weights(lora['repo_id'], adapter_name=f"lora_{i}")
                adapter_names = [f"lora_{i}" for i in range(len(request.loras))]
                adapter_weights = [lora['weight'] for lora in request.loras]

                if request.fast_generation:
                    hyper_lora = hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors") if model['loader'] == 'flux' \
                        else hf_hub_download("ByteDance/Hyper-SD", "Hyper-SDXL-8steps-lora.safetensors")
                    hyper_weight = 0.125 if model['loader'] == 'flux' else 1.0
                    pipe_args["pipeline"].load_lora_weights(hyper_lora, adapter_name="hyper_lora")
                    # Activate the Hyper-SD LoRA together with the user LoRAs;
                    # a second set_adapters call would deactivate it otherwise.
                    adapter_names.append("hyper_lora")
                    adapter_weights.append(hyper_weight)

                pipe_args["pipeline"].set_adapters(adapter_names, adapter_weights)

            # Set textual-inversion embeddings (SDXL only)
            if request.embeddings and model['loader'] == 'sdxl':
                for embedding in request.embeddings:
                    # hf_hub_download requires a filename; an embedding 'filename'
                    # entry is assumed here alongside 'repo_id' and 'token'.
                    state_dict = load_file(hf_hub_download(embedding['repo_id'], embedding['filename']))
                    pipe_args["pipeline"].load_textual_inversion(state_dict['clip_g'], token=embedding['token'], text_encoder=pipe_args["pipeline"].text_encoder_2, tokenizer=pipe_args["pipeline"].tokenizer_2)
                    pipe_args["pipeline"].load_textual_inversion(state_dict["clip_l"], token=embedding['token'], text_encoder=pipe_args["pipeline"].text_encoder, tokenizer=pipe_args["pipeline"].tokenizer)

            return pipe_args
def load_scheduler(pipeline, scheduler):
    schedulers = {
        "dpmpp_2m": (DPMSolverMultistepScheduler, {}),
        "dpmpp_2m_k": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True}),
        "dpmpp_2m_sde": (DPMSolverMultistepScheduler, {"algorithm_type": "sde-dpmsolver++"}),
        "dpmpp_2m_sde_k": (DPMSolverMultistepScheduler, {"algorithm_type": "sde-dpmsolver++", "use_karras_sigmas": True}),
        "dpmpp_sde": (DPMSolverSinglestepScheduler, {}),
        "dpmpp_sde_k": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": True}),
        "dpm2": (KDPM2DiscreteScheduler, {}),
        "dpm2_k": (KDPM2DiscreteScheduler, {"use_karras_sigmas": True}),
        "dpm2_a": (KDPM2AncestralDiscreteScheduler, {}),
        "dpm2_a_k": (KDPM2AncestralDiscreteScheduler, {"use_karras_sigmas": True}),
        "euler": (EulerDiscreteScheduler, {}),
        "euler_a": (EulerAncestralDiscreteScheduler, {}),
        "heun": (HeunDiscreteScheduler, {}),
        "lms": (LMSDiscreteScheduler, {}),
        "lms_k": (LMSDiscreteScheduler, {"use_karras_sigmas": True}),
        "deis": (DEISMultistepScheduler, {}),
        "unipc": (UniPCMultistepScheduler, {}),
        "fm_euler": (FlowMatchEulerDiscreteScheduler, {}),
    }
    scheduler_class, kwargs = schedulers.get(scheduler, (None, {}))
    if scheduler_class is None:
        raise ValueError(f"Unknown scheduler: {scheduler}")
    return scheduler_class.from_config(pipeline.scheduler.config, **kwargs)
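# Example: swapping the sampler on an existing pipeline while keeping its
# configured timestep settings (from_config reuses pipeline.scheduler.config):
#
#   pipe.scheduler = load_scheduler(pipe, "euler_a")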
def resize_images(images: List[Image.Image], height: int, width: int, resize_mode: str):
    # Collect into a new list: rebinding the loop variable alone would
    # leave the input list unchanged.
    resized = []
    for image in images:
        if resize_mode == "resize_only":
            image = image.resize((width, height))
        elif resize_mode == "crop_and_resize":
            image = image.crop((0, 0, width, height))
        elif resize_mode == "resize_and_fill":
            image = image.resize((width, height), Image.Resampling.LANCZOS)
        resized.append(image)
    return resized
def get_controlnet_images(controlnets: List[str], control_images: List[Image.Image], height: int, width: int, resize_mode: str):
    response_images = []
    control_images = resize_images(control_images, height, width, resize_mode)
    for controlnet, image in zip(controlnets, control_images):
        if controlnet == "canny":
            processor = Processor('canny')
        elif controlnet == "depth":
            processor = Processor('depth_midas')
        elif controlnet == "pose":
            processor = Processor('openpose_full')
        elif controlnet == "scribble":
            # controlnet_aux has no plain 'scribble' processor id;
            # 'scribble_hed' is the closest match.
            processor = Processor('scribble_hed')
        else:
            raise ValueError(f"Invalid ControlNet: {controlnet}")
        response_images.append(processor(image, to_pil=True))
    return response_images
def get_control_mode(controlnet_config: ControlNetReq):
    control_mode = []
    for controlnet in controlnets:
        if controlnet['loader'] == 'flux-multi':
            layers = controlnet['layers']
            for c in controlnet_config.controlnets:
                if c in layers:
                    control_mode.append(layers.index(c))
    return control_mode
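# For the flux-multi ControlNet, `control_mode` is the index of each requested
# layer within that model's layer list; the pipeline uses these indices to
# select the conditioning type applied to each control image.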
# def check_image_safety(images: List[Image.Image]):
#     safety_checker_input = feature_extractor(images, return_tensors="pt").to("cuda")
#     has_nsfw_concepts = safety_checker(
#         images=[images],
#         clip_input=safety_checker_input.pixel_values.to("cuda"),
#     )
#     return has_nsfw_concepts[1]


# def get_prompt_attention(pipeline, prompt, negative_prompt):
#     if isinstance(pipeline, flux_pipes):
#         prompt_embeds, pooled_prompt_embeds = get_weighted_text_embeddings_flux1(pipeline, prompt, device=device)
#         return prompt_embeds, None, pooled_prompt_embeds, None
#     elif isinstance(pipeline, sd_pipes):
#         prompt_embeds, prompt_neg_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = get_weighted_text_embeddings_sdxl(pipeline, prompt, negative_prompt, device=device)
#         return prompt_embeds, prompt_neg_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
def cleanup(pipeline, loras=None, embeddings=None):
    if loras:
        # pipeline.disable_lora()
        pipeline.unload_lora_weights()
    if embeddings:
        pipeline.unload_textual_inversion()
    gc.collect()
    torch.cuda.empty_cache()
# Gen Function
def gen_img(request: BaseReq | BaseImg2ImgReq | BaseInpaintReq, progress=gr.Progress(track_tqdm=True)):
    progress(0.1, "Loading Pipeline")
    pipeline_args = get_pipe(request)
    pipeline = pipeline_args["pipeline"]
    try:
        progress(0.3, "Getting Prompt Embeddings")
        # Get prompt embeddings (Flux takes no negative conditioning)
        if isinstance(pipeline, flux_pipes):
            positive_prompt_embeds, positive_prompt_pooled = get_weighted_text_embeddings_flux1(pipeline, request.prompt)
        elif isinstance(pipeline, sd_pipes):
            positive_prompt_embeds, negative_prompt_embeds, positive_prompt_pooled, negative_prompt_pooled = get_weighted_text_embeddings_sdxl(pipeline, request.prompt, request.negative_prompt)

        progress(0.5, "Configuring Pipeline")
        # Common args; use the request seed unless it is unset (None, 0, or -1),
        # in which case draw a fresh random seed for each image.
        args = {
            'prompt_embeds': positive_prompt_embeds,
            'pooled_prompt_embeds': positive_prompt_pooled,
            'height': request.height,
            'width': request.width,
            'num_images_per_prompt': request.num_images_per_prompt,
            'num_inference_steps': request.num_inference_steps,
            'guidance_scale': request.guidance_scale,
            'generator': [
                torch.Generator().manual_seed(request.seed + i)
                if request.seed not in (None, 0, -1)
                else torch.Generator().manual_seed(random.randint(0, 2**32 - 1))
                for i in range(request.num_images_per_prompt)
            ],
        }

        if isinstance(pipeline, sd_pipes):
            args['clip_skip'] = request.clip_skip
            args['negative_prompt_embeds'] = negative_prompt_embeds
            args['negative_pooled_prompt_embeds'] = negative_prompt_pooled

        if request.controlnet_config:
            args['control_image'] = get_controlnet_images(request.controlnet_config.controlnets, request.controlnet_config.control_images, request.height, request.width, request.resize_mode)
            args['controlnet_conditioning_scale'] = request.controlnet_config.controlnet_conditioning_scale

            if isinstance(pipeline, flux_pipes):
                args['control_mode'] = get_control_mode(request.controlnet_config)

        if isinstance(request, (BaseImg2ImgReq, BaseInpaintReq)):
            args['image'] = resize_images([request.image], request.height, request.width, request.resize_mode)[0]
            args['strength'] = request.strength

        if isinstance(request, BaseInpaintReq):
            args['mask_image'] = resize_images([request.mask_image], request.height, request.width, request.resize_mode)[0]

        # Generate
        progress(0.9, "Generating Images")
        gr.Info(f"Request {type(request)}: {str(request.__dict__)}", duration=60)
        images = pipeline(**args).images

        # Refiner
        if request.refiner:
            images = refiner(image=images, prompt=request.prompt, num_inference_steps=40, denoising_start=0.7).images

        progress(1.0, "Cleaning Up")
        cleanup(pipeline, request.loras, request.embeddings)

        return images
    except Exception as e:
        cleanup(pipeline, request.loras, request.embeddings)
        raise gr.Error(f"Error: {e}")
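

# Illustrative usage (not part of the tab wiring): a minimal text-to-image
# call, assuming BaseReq exposes the fields referenced above. The actual
# field set is defined in .models, so treat this as a sketch only.
#
#   req = BaseReq(
#       model="some/repo-id",          # must match a repo_id in `models`
#       prompt="a watercolor fox",
#       height=1024, width=1024,
#       num_images_per_prompt=1,
#       num_inference_steps=30,
#       guidance_scale=7.0,
#       seed=42,
#       scheduler="euler_a",
#   )
#   images = gen_img(req)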