import random

import gradio as gr
import torch
from diffusers import (
    AutoPipelineForText2Image,
    AutoPipelineForImage2Image,
    AutoPipelineForInpainting,
)
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from huggingface_hub import hf_hub_download
from sd_embed.embedding_funcs import get_weighted_text_embeddings_flux1

from .common_helpers import (
    ControlNetReq,
    BaseReq,
    BaseImg2ImgReq,
    BaseInpaintReq,
    cleanup,
    get_controlnet_images,
    resize_images,
)
from modules.pipelines.flux_pipelines import device, models, flux_vae, controlnet
from modules.pipelines.common_pipelines import refiner


def get_control_mode(controlnet_config: ControlNetReq):
    # The Flux union ControlNet identifies each conditioning type by its
    # index in this fixed layer order.
    layers = ["canny", "tile", "depth", "blur", "pose", "gray", "low_quality"]
    control_mode = []

    for c in controlnet_config.controlnets:
        if c in layers:
            control_mode.append(layers.index(c))

    return control_mode


def get_pipe(request: BaseReq | BaseImg2ImgReq | BaseInpaintReq):
    for m in models:
        if m["repo_id"] != request.model:
            continue

        pipe_args = {
            "pipeline": m["pipeline"],
        }

        # Set ControlNet config
        if request.controlnet_config:
            pipe_args["control_mode"] = get_control_mode(request.controlnet_config)
            pipe_args["controlnet"] = [controlnet]

        # Choose pipeline mode (check the subclasses before the base class)
        if isinstance(request, BaseInpaintReq):
            pipe_args["pipeline"] = AutoPipelineForInpainting.from_pipe(**pipe_args)
        elif isinstance(request, BaseImg2ImgReq):
            pipe_args["pipeline"] = AutoPipelineForImage2Image.from_pipe(**pipe_args)
        elif isinstance(request, BaseReq):
            pipe_args["pipeline"] = AutoPipelineForText2Image.from_pipe(**pipe_args)

        # Enable or disable the VAE
        pipe_args["pipeline"].vae = flux_vae if request.vae else None

        # Set scheduler
        pipe_args["pipeline"].scheduler = FlowMatchEulerDiscreteScheduler.from_config(
            pipe_args["pipeline"].scheduler.config
        )

        # Set LoRAs; the Hyper-SD "fast generation" LoRA is appended even when
        # no user LoRAs were requested
        adapter_names = []
        adapter_weights = []

        if request.loras:
            for i, lora in enumerate(request.loras):
                pipe_args["pipeline"].load_lora_weights(lora["repo_id"], adapter_name=f"lora_{i}")
                adapter_names.append(f"lora_{i}")
                adapter_weights.append(lora["weight"])

        if request.fast_generation:
            hyper_lora = hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors")
            hyper_weight = 0.125
            pipe_args["pipeline"].load_lora_weights(hyper_lora, adapter_name="hyper_lora")
            adapter_names.append("hyper_lora")
            adapter_weights.append(hyper_weight)

        if adapter_names:
            pipe_args["pipeline"].set_adapters(adapter_names, adapter_weights)

        return pipe_args


def get_prompt_attention(pipeline, prompt):
    # sd_embed computes attention-weighted embeddings, which also lifts the
    # usual token-length limit on the prompt.
    return get_weighted_text_embeddings_flux1(pipeline, prompt)
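
# Illustrative sketch (hypothetical prompt; `pipeline` is any Flux pipeline
# returned by get_pipe): sd_embed parses `(word:weight)`-style syntax, so
# callers can emphasise tokens directly in the request prompt:
#
#   embeds, pooled = get_prompt_attention(pipeline, "a (red:1.3) bicycle")
#   # `embeds` / `pooled` feed `prompt_embeds` / `pooled_prompt_embeds` below.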

# Gen Function
def gen_img(request: BaseReq | BaseImg2ImgReq | BaseInpaintReq):
    pipe_args = get_pipe(request)
    pipeline = pipe_args["pipeline"]

    try:
        positive_prompt_embeds, positive_prompt_pooled = get_prompt_attention(pipeline, request.prompt)

        # Common args
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        args = {
            "prompt_embeds": positive_prompt_embeds,
            "pooled_prompt_embeds": positive_prompt_pooled,
            "height": request.height,
            "width": request.width,
            "num_images_per_prompt": request.num_images_per_prompt,
            "num_inference_steps": request.num_inference_steps,
            "guidance_scale": request.guidance_scale,
            # Honour an explicit seed (offset per image for distinct results);
            # treat None/0/-1 as "randomise each image".
            "generator": [
                torch.Generator(device=device).manual_seed(request.seed + i)
                if request.seed not in (None, 0, -1)
                else torch.Generator(device=device).manual_seed(random.randint(0, 2**32 - 1))
                for i in range(request.num_images_per_prompt)
            ],
        }

        if request.controlnet_config:
            args["control_mode"] = get_control_mode(request.controlnet_config)
            args["control_images"] = get_controlnet_images(request.controlnet_config, request.height, request.width, request.resize_mode)
            args["controlnet_conditioning_scale"] = request.controlnet_config.controlnet_conditioning_scale

        if isinstance(request, (BaseImg2ImgReq, BaseInpaintReq)):
            args["image"] = resize_images([request.image], request.height, request.width, request.resize_mode)[0]
            args["strength"] = request.strength

        if isinstance(request, BaseInpaintReq):
            args["mask_image"] = resize_images([request.mask_image], request.height, request.width, request.resize_mode)[0]

        # Generate
        images = pipeline(**args).images

        # Refiner
        if request.refiner:
            images = refiner(image=images, prompt=request.prompt, num_inference_steps=40, denoising_start=0.7).images

        return images
    except Exception as e:
        raise gr.Error(f"Error: {e}")
    finally:
        # Runs on success and failure alike.
        cleanup(pipeline, request.loras)
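
# Minimal usage sketch, assuming `models` contains FLUX.1-dev and that the
# remaining BaseReq fields carry defaults (field values here are hypothetical):
#
#   request = BaseReq(
#       model="black-forest-labs/FLUX.1-dev",  # must match a repo_id in `models`
#       prompt="a watercolor fox in a snowy forest",
#       height=1024,
#       width=1024,
#       num_images_per_prompt=1,
#       num_inference_steps=28,
#       guidance_scale=3.5,
#       seed=42,  # any of None/0/-1 selects a random seed per image
#       vae=True,
#       fast_generation=False,
#   )
#   images = gen_img(request)  # list of PIL images; cleanup runs on exit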