import gc
import random
from typing import List, Optional

import torch
import numpy as np
from pydantic import BaseModel
from PIL import Image
from diffusers import (
    FluxPipeline,
    FluxImg2ImgPipeline,
    FluxInpaintPipeline,
    FluxControlNetPipeline,
    StableDiffusionXLPipeline,
    StableDiffusionXLImg2ImgPipeline,
    StableDiffusionXLInpaintPipeline,
    StableDiffusionXLControlNetPipeline,
    StableDiffusionXLControlNetImg2ImgPipeline,
    StableDiffusionXLControlNetInpaintPipeline,
    AutoPipelineForText2Image,
    AutoPipelineForImage2Image,
    AutoPipelineForInpainting,
)
from diffusers.schedulers import *
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from controlnet_aux.processor import Processor
from photomaker import (
    PhotoMakerStableDiffusionXLPipeline,
    PhotoMakerStableDiffusionXLControlNetPipeline,
    analyze_faces,
)
from sd_embed.embedding_funcs import get_weighted_text_embeddings_sdxl, get_weighted_text_embeddings_flux1

from .init_sys import device, models, refiner, safety_checker, feature_extractor, controlnet_models, face_detector

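# Pydantic request schemas: ControlNetReq describes optional ControlNet inputs,
# SDReq is a text-to-image request, SDImg2ImgReq adds an init image and strength,
# and SDInpaintReq adds a mask image on top of that.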
class ControlNetReq(BaseModel):
    controlnets: List[str]
    control_images: List[Image.Image]
    controlnet_conditioning_scale: List[float]

    class Config:
        arbitrary_types_allowed = True

class SDReq(BaseModel):
    model: str = "black-forest-labs/FLUX.1-dev"
    prompt: str = ""
    negative_prompt: Optional[str] = ""
    fast_generation: Optional[bool] = True
    loras: Optional[list] = []
    embeddings: Optional[list] = []
    resize_mode: Optional[str] = "resize_and_fill"
    scheduler: Optional[str] = "euler_fl"
    height: int = 1024
    width: int = 1024
    num_images_per_prompt: int = 1
    num_inference_steps: int = 8
    guidance_scale: float = 3.5
    clip_skip: Optional[int] = None  # forwarded to SDXL pipelines in gen_img
    seed: Optional[int] = 0
    refiner: bool = False
    vae: bool = True
    controlnet_config: Optional[ControlNetReq] = None
    photomaker_images: Optional[List[Image.Image]] = None

    class Config:
        arbitrary_types_allowed = True

class SDImg2ImgReq(SDReq):
    image: Image.Image
    strength: float = 1.0

    class Config:
        arbitrary_types_allowed = True

class SDInpaintReq(SDImg2ImgReq):
    mask_image: Image.Image

    class Config:
        arbitrary_types_allowed = True

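# Map the requested controlnet layer names to the preloaded ControlNet models
# (from init_sys.controlnet_models) and to their layer indices ("control modes").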
def get_controlnet(controlnet_config: ControlNetReq):
    control_mode = []
    controlnet = []

    for m in controlnet_models:
        for c in controlnet_config.controlnets:
            if c in m["layers"]:
                control_mode.append(m["layers"].index(c))
                controlnet.append(m["controlnet"])

    return controlnet, control_mode

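# Build the concrete pipeline for a request: look up the preloaded base pipeline
# by repo_id, then wrap it with the matching Auto* pipeline (text2img / img2img /
# inpainting) or, when reference photos are supplied, with a PhotoMaker pipeline.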
def get_pipe(request: SDReq | SDImg2ImgReq | SDInpaintReq):
    for m in models:
        if m["repo_id"] == request.model:
            pipeline = m['pipeline']
    controlnet, control_mode = get_controlnet(request.controlnet_config) if request.controlnet_config else (None, None)

    pipe_args = {
        "pipeline": pipeline,
        "control_mode": control_mode,
    }
    if request.controlnet_config:
        pipe_args["controlnet"] = controlnet

    if not request.photomaker_images:
        # Check the most specific request type first (SDInpaintReq and SDImg2ImgReq
        # are subclasses of SDReq, so SDReq would otherwise match everything).
        if isinstance(request, SDInpaintReq):
            pipe_args['pipeline'] = AutoPipelineForInpainting.from_pipe(**pipe_args)
        elif isinstance(request, SDImg2ImgReq):
            pipe_args['pipeline'] = AutoPipelineForImage2Image.from_pipe(**pipe_args)
        elif isinstance(request, SDReq):
            pipe_args['pipeline'] = AutoPipelineForText2Image.from_pipe(**pipe_args)
        else:
            raise ValueError(f"Unknown request type: {type(request)}")
    elif request.photomaker_images:
        if request.controlnet_config:
            pipe_args['pipeline'] = PhotoMakerStableDiffusionXLControlNetPipeline.from_pipe(**pipe_args)
        else:
            pipe_args['pipeline'] = PhotoMakerStableDiffusionXLPipeline.from_pipe(**pipe_args)
    else:
        raise ValueError(f"Invalid request type: {type(request)}")

    return pipe_args

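# Map a scheduler name from the request to a diffusers scheduler instance built
# from the current pipeline's scheduler config.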
def load_scheduler(pipeline, scheduler):
    schedulers = {
        "dpmpp_2m": (DPMSolverMultistepScheduler, {}),
        "dpmpp_2m_k": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True}),
        "dpmpp_2m_sde": (DPMSolverMultistepScheduler, {"algorithm_type": "sde-dpmsolver++"}),
        "dpmpp_2m_sde_k": (DPMSolverMultistepScheduler, {"algorithm_type": "sde-dpmsolver++", "use_karras_sigmas": True}),
        "dpmpp_sde": (DPMSolverSinglestepScheduler, {}),
        "dpmpp_sde_k": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": True}),
        "dpm2": (KDPM2DiscreteScheduler, {}),
        "dpm2_k": (KDPM2DiscreteScheduler, {"use_karras_sigmas": True}),
        "dpm2_a": (KDPM2AncestralDiscreteScheduler, {}),
        "dpm2_a_k": (KDPM2AncestralDiscreteScheduler, {"use_karras_sigmas": True}),
        "euler": (EulerDiscreteScheduler, {}),
        "euler_a": (EulerAncestralDiscreteScheduler, {}),
        "heun": (HeunDiscreteScheduler, {}),
        "lms": (LMSDiscreteScheduler, {}),
        "lms_k": (LMSDiscreteScheduler, {"use_karras_sigmas": True}),
        "deis": (DEISMultistepScheduler, {}),
        "unipc": (UniPCMultistepScheduler, {}),
        "fm_euler": (FlowMatchEulerDiscreteScheduler, {}),
        # "euler_fl" is the default scheduler name in SDReq; it is assumed to be an
        # alias for the flow-match Euler scheduler used by FLUX.
        "euler_fl": (FlowMatchEulerDiscreteScheduler, {}),
    }
    scheduler_class, kwargs = schedulers.get(scheduler, (None, {}))

    if scheduler_class is not None:
        scheduler = scheduler_class.from_config(pipeline.scheduler.config, **kwargs)
    else:
        raise ValueError(f"Unknown scheduler: {scheduler}")

    return scheduler

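# Load user LoRAs and, when fast_generation is set, a ByteDance Hyper-SD LoRA
# (FLUX 8-step or SDXL 2-step) so few-step sampling works with the default step count.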
def load_loras(pipeline, loras, fast_generation):
    for i, lora in enumerate(loras):
        pipeline.load_lora_weights(lora['repo_id'], adapter_name=f"lora_{i}")
    adapter_names = [f"lora_{i}" for i in range(len(loras))]
    adapter_weights = [lora['weight'] for lora in loras]

    if fast_generation:
        hyper_lora = hf_hub_download(
            "ByteDance/Hyper-SD",
            "Hyper-FLUX.1-dev-8steps-lora.safetensors" if isinstance(pipeline, FluxPipeline) else "Hyper-SDXL-2steps-lora.safetensors"
        )
        hyper_weight = 0.125 if isinstance(pipeline, FluxPipeline) else 1.0
        pipeline.load_lora_weights(hyper_lora, adapter_name="hyper_lora")
        adapter_names.append("hyper_lora")
        adapter_weights.append(hyper_weight)

    pipeline.set_adapters(adapter_names, adapter_weights)

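# Load SDXL textual-inversion embeddings into both text encoders under the
# requested trigger token.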
def load_xl_embeddings(pipeline, embeddings):
    for embedding in embeddings:
        # NOTE: hf_hub_download also needs the filename of the embedding file inside
        # the repo; the entries in `embeddings` are assumed to provide it.
        state_dict = load_file(hf_hub_download(embedding['repo_id']))
        pipeline.load_textual_inversion(state_dict['clip_g'], token=embedding['token'], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2)
        pipeline.load_textual_inversion(state_dict["clip_l"], token=embedding['token'], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)

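# Resize helper shared by control images, PhotoMaker reference images, and
# init/mask images.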
def resize_images(images: List[Image.Image], height: int, width: int, resize_mode: str):
    resized_images = []
    for image in images:
        if resize_mode == "resize_only":
            image = image.resize((width, height))
        elif resize_mode == "crop_and_resize":
            image = image.crop((0, 0, width, height))
        elif resize_mode == "resize_and_fill":
            image = image.resize((width, height), Image.Resampling.LANCZOS)
        resized_images.append(image)

    return resized_images

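# Preprocess control images with the matching controlnet_aux processor
# (canny / depth / openpose / scribble) after resizing them to the target size.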
def get_controlnet_images(controlnets: List[str], control_images: List[Image.Image], height: int, width: int, resize_mode: str):
    response_images = []
    control_images = resize_images(control_images, height, width, resize_mode)
    for controlnet, image in zip(controlnets, control_images):
        if controlnet in ("canny", "canny_xs", "canny_fl"):
            processor = Processor('canny')
        elif controlnet in ("depth", "depth_xs", "depth_fl"):
            processor = Processor('depth_midas')
        elif controlnet in ("pose", "pose_fl"):
            processor = Processor('openpose_full')
        elif controlnet == "scribble":
            processor = Processor('scribble')
        else:
            raise ValueError(f"Invalid Controlnet: {controlnet}")

        response_images.append(processor(image, to_pil=True))

    return response_images

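# Run the CLIP-based safety checker over generated images and return the
# per-image NSFW flags.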
def check_image_safety(images: List[Image.Image]):
    safety_checker_input = feature_extractor(images, return_tensors="pt").to("cuda")
    has_nsfw_concepts = safety_checker(
        images=images,
        clip_input=safety_checker_input.pixel_values.to("cuda"),
    )

    return has_nsfw_concepts[1]

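# Compute weighted prompt embeddings via sd_embed so long/weighted prompts work:
# FLUX pipelines get only (prompt_embeds, pooled); SDXL also gets the negatives.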
def get_prompt_attention(pipeline, prompt, negative_prompt):
    if isinstance(pipeline, (FluxPipeline, FluxImg2ImgPipeline, FluxInpaintPipeline, FluxControlNetPipeline)):
        prompt_embeds, pooled_prompt_embeds = get_weighted_text_embeddings_flux1(pipeline, prompt)
        return prompt_embeds, None, pooled_prompt_embeds, None
    elif isinstance(pipeline, StableDiffusionXLPipeline):
        prompt_embeds, prompt_neg_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = get_weighted_text_embeddings_sdxl(pipeline, prompt, negative_prompt)
        return prompt_embeds, prompt_neg_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
    else:
        raise ValueError(f"Invalid pipeline type: {type(pipeline)}")

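# Prepare PhotoMaker inputs: keep the resized reference images and extract one
# face embedding per image using the face_detector from init_sys.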
def get_photomaker_images(photomaker_images: List[Image.Image], height: int, width: int, resize_mode: str):
    image_input_ids = []
    image_id_embeds = []
    photomaker_images = resize_images(photomaker_images, height, width, resize_mode)

    for image in photomaker_images:
        image_input_ids.append(image)
        # Convert the PIL image (RGB) to a BGR numpy array for the face detector.
        img = np.array(image)[:, :, ::-1]
        faces = analyze_faces(face_detector, img)
        if len(faces) > 0:
            image_id_embeds.append(torch.from_numpy(faces[0]['embedding']))
        else:
            raise ValueError("No face detected in the image")

    return image_input_ids, image_id_embeds

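# Undo per-request state (LoRAs, textual inversions) and release GPU memory.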
def cleanup(pipeline, loras=None, embeddings=None):
    if loras:
        pipeline.disable_lora()
        pipeline.unload_lora_weights()
    if embeddings:
        pipeline.unload_textual_inversion()
    gc.collect()
    torch.cuda.empty_cache()

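# End-to-end generation for a single request: build the pipeline, apply the
# scheduler/LoRAs/embeddings, assemble the call kwargs, run the pipeline
# (optionally followed by the SDXL refiner), and always clean up afterwards.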
def gen_img(
    request: SDReq | SDImg2ImgReq | SDInpaintReq
):
    pipeline_args = get_pipe(request)
    pipeline = pipeline_args['pipeline']
    try:
        pipeline.scheduler = load_scheduler(pipeline, request.scheduler)

        load_loras(pipeline, request.loras, request.fast_generation)
        load_xl_embeddings(pipeline, request.embeddings)

        control_images = get_controlnet_images(request.controlnet_config.controlnets, request.controlnet_config.control_images, request.height, request.width, request.resize_mode) if request.controlnet_config else None
        photomaker_images, photomaker_id_embeds = get_photomaker_images(request.photomaker_images, request.height, request.width, request.resize_mode) if request.photomaker_images else (None, None)

        positive_prompt_embeds, negative_prompt_embeds, positive_prompt_pooled, negative_prompt_pooled = get_prompt_attention(pipeline, request.prompt, request.negative_prompt)

        args = {
            'prompt_embeds': positive_prompt_embeds,
            'pooled_prompt_embeds': positive_prompt_pooled,
            'height': request.height,
            'width': request.width,
            'num_images_per_prompt': request.num_images_per_prompt,
            'num_inference_steps': request.num_inference_steps,
            'guidance_scale': request.guidance_scale,
            'generator': [
                torch.Generator(device=device).manual_seed(request.seed + i)
                if request.seed not in (None, 0, -1)
                else torch.Generator(device=device).manual_seed(random.randint(0, 2**32 - 1))
                for i in range(request.num_images_per_prompt)
            ],
        }

        if isinstance(pipeline, (StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline,
                                 StableDiffusionXLControlNetPipeline, StableDiffusionXLControlNetImg2ImgPipeline, StableDiffusionXLControlNetInpaintPipeline)):
            args['clip_skip'] = request.clip_skip
            args['negative_prompt_embeds'] = negative_prompt_embeds
            args['negative_pooled_prompt_embeds'] = negative_prompt_pooled

        if isinstance(pipeline, FluxControlNetPipeline) and request.controlnet_config:
            args['control_mode'] = pipeline_args['control_mode']
            args['control_image'] = control_images
            args['controlnet_conditioning_scale'] = request.controlnet_config.controlnet_conditioning_scale

        if not isinstance(pipeline, FluxControlNetPipeline) and request.controlnet_config:
            args['controlnet_conditioning_scale'] = request.controlnet_config.controlnet_conditioning_scale

            # Non-Flux ControlNet pipelines take the conditioning images as `image`
            # for text-to-image and as `control_image` for img2img/inpainting.
            if isinstance(request, (SDImg2ImgReq, SDInpaintReq)):
                args['control_image'] = control_images
            elif isinstance(request, SDReq):
                args['image'] = control_images

        if request.photomaker_images and isinstance(pipeline, (PhotoMakerStableDiffusionXLPipeline, PhotoMakerStableDiffusionXLControlNetPipeline)):
            args['input_id_images'] = photomaker_images
            # `id_embeds` is assumed to be the PhotoMaker keyword for precomputed face embeddings.
            args['id_embeds'] = photomaker_id_embeds
            args['start_merge_step'] = 10

        # Check the most specific request type first (SDInpaintReq subclasses SDImg2ImgReq).
        if isinstance(request, SDInpaintReq):
            args['image'] = resize_images([request.image], request.height, request.width, request.resize_mode)
            args['mask_image'] = resize_images([request.mask_image], request.height, request.width, request.resize_mode)
            args['strength'] = request.strength
        elif isinstance(request, SDImg2ImgReq):
            args['image'] = resize_images([request.image], request.height, request.width, request.resize_mode)
            args['strength'] = request.strength

        images = pipeline(**args).images

        if request.refiner:
            images = refiner(
                prompt=request.prompt,
                num_inference_steps=40,
                denoising_start=0.7,
                image=images
            ).images

        cleanup(pipeline, request.loras, request.embeddings)

        return images
    except Exception as e:
        cleanup(pipeline, request.loras, request.embeddings)
        raise ValueError(f"Error generating image: {e}") from e
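
# Example usage (illustrative only): the prompt below is a placeholder and the
# model must be one of the repo_ids preloaded in init_sys.models. Kept as a
# comment because this module uses relative imports and is not meant to be run
# directly.
#
#   request = SDReq(
#       model="black-forest-labs/FLUX.1-dev",
#       prompt="a photo of an astronaut riding a horse",
#       num_inference_steps=8,
#       guidance_scale=3.5,
#   )
#   images = gen_img(request)
#   images[0].save("output.png")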