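"""Diffusers pipeline helpers for text-to-image, image-to-image, and inpainting.

get_pipe() resolves the requested model into the matching AutoPipeline and
wires up the optional ControlNet, VAE, scheduler, LoRA, and embedding
settings; gen_img() assembles the call arguments and runs generation,
with an optional refiner pass at the end.
"""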
import random
import gradio as gr
import torch
from diffusers import (
    AutoPipelineForText2Image,
    AutoPipelineForImage2Image,
    AutoPipelineForInpainting,
)
from huggingface_hub import hf_hub_download
from diffusers.schedulers import *  # scheduler classes are resolved by name in get_scheduler
from sd_embed.embedding_funcs import get_weighted_text_embeddings_flux1

# get_control_mode is assumed to live alongside the other helpers; gen_img() calls it below
from .common_helpers import ControlNetReq, BaseReq, BaseImg2ImgReq, BaseInpaintReq, cleanup, get_control_mode, get_controlnet_images, resize_images
from modules.pipelines.sdxl_pipelines import device, models, sdxl_vae, controlnets
from modules.pipelines.common_pipelines import refiner


def get_pipe(request: BaseReq | BaseImg2ImgReq | BaseInpaintReq):
    def get_scheduler(pipeline, scheduler: str):
        ...
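        # (Body elided in the source. A typical implementation resolves the
        # scheduler name against the classes brought in by the wildcard
        # import above, e.g. SomeScheduler.from_config(pipeline.scheduler.config).)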
    for m in models:
        if m['repo_id'] == request.model:
            pipe_args = {
                "pipeline": m['pipeline'],
            }

            # Set ControlNet config
            if request.controlnet_config:
                pipe_args["controlnet"] = [controlnets]
            # Choose pipeline mode: check the most specific request type first,
            # since the inpaint/img2img requests appear to extend the base request
            if isinstance(request, BaseInpaintReq):
                pipe_args['pipeline'] = AutoPipelineForInpainting.from_pipe(**pipe_args)
            elif isinstance(request, BaseImg2ImgReq):
                pipe_args['pipeline'] = AutoPipelineForImage2Image.from_pipe(**pipe_args)
            elif isinstance(request, BaseReq):
                pipe_args['pipeline'] = AutoPipelineForText2Image.from_pipe(**pipe_args)
            # Use the custom SDXL VAE, or drop the pipeline's VAE entirely
            if request.vae:
                pipe_args["pipeline"].vae = sdxl_vae
            else:
                pipe_args["pipeline"].vae = None

            # Set scheduler
            pipe_args["pipeline"].scheduler = get_scheduler(pipe_args["pipeline"], request.scheduler)
            # Set LoRAs; collect adapter names/weights so they can be activated together
            adapter_names = []
            adapter_weights = []
            if request.loras:
                for i, lora in enumerate(request.loras):
                    pipe_args["pipeline"].load_lora_weights(lora['repo_id'], adapter_name=f"lora_{i}")
                adapter_names = [f"lora_{i}" for i in range(len(request.loras))]
                adapter_weights = [lora['weight'] for lora in request.loras]

            if request.fast_generation:
                # Hyper-SD acceleration LoRA, applied at a low weight
                hyper_lora = hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors")
                hyper_weight = 0.125
                pipe_args["pipeline"].load_lora_weights(hyper_lora, adapter_name="hyper_lora")
                adapter_names.append("hyper_lora")
                adapter_weights.append(hyper_weight)

            if adapter_names:
                pipe_args["pipeline"].set_adapters(adapter_names, adapter_weights)
            # Set embeddings
            if request.embeddings:
                ...
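                # (Body elided in the source. Embeddings are typically loaded
                # with pipeline.load_textual_inversion(...).)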

            return pipe_args


def get_prompt_attention(pipeline, prompt):
    # Weighted prompt embeddings; sd_embed parses (token:weight) emphasis syntax
    return get_weighted_text_embeddings_flux1(pipeline, prompt)


# Generation entrypoint
def gen_img(request: BaseReq | BaseImg2ImgReq | BaseInpaintReq):
    pipe_args = get_pipe(request)
    pipeline = pipe_args["pipeline"]
    try:
        positive_prompt_embeds, positive_prompt_pooled = get_prompt_attention(pipeline, request.prompt)

        # Common args
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        args = {
            'prompt_embeds': positive_prompt_embeds,
            'pooled_prompt_embeds': positive_prompt_pooled,
            'height': request.height,
            'width': request.width,
            'num_images_per_prompt': request.num_images_per_prompt,
            'num_inference_steps': request.num_inference_steps,
            'clip_skip': request.clip_skip,
            'guidance_scale': request.guidance_scale,
            # One generator per image: offset the request seed when one is given,
            # otherwise (None/0/-1) draw a fresh random seed
            'generator': [
                torch.Generator(device=device).manual_seed(request.seed + i)
                if request.seed not in (None, 0, -1)
                else torch.Generator(device=device).manual_seed(random.randint(0, 2**32 - 1))
                for i in range(request.num_images_per_prompt)
            ],
        }

        if request.controlnet_config:
            args['control_mode'] = get_control_mode(request.controlnet_config)
            args['control_images'] = get_controlnet_images(request.controlnet_config, request.height, request.width, request.resize_mode)
            args['controlnet_conditioning_scale'] = request.controlnet_config.controlnet_conditioning_scale

        if isinstance(request, (BaseImg2ImgReq, BaseInpaintReq)):
            args['image'] = resize_images([request.image], request.height, request.width, request.resize_mode)[0]
            args['strength'] = request.strength

        if isinstance(request, BaseInpaintReq):
            args['mask_image'] = resize_images([request.mask_image], request.height, request.width, request.resize_mode)[0]

        # Generate
        images = pipeline(**args).images

        # Optional refiner pass over the generated images
        if request.refiner:
            images = refiner(image=images, prompt=request.prompt, num_inference_steps=40, denoising_start=0.7).images

        return images
    except Exception as e:
        raise gr.Error(f"Error: {e}")
    finally:
        # Unload LoRAs and free memory on both success and failure
        cleanup(pipeline, request.loras)
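

# Example usage (hypothetical values; the request classes come from common_helpers):
#
#     request = BaseReq(
#         model=models[0]['repo_id'],  # any repo_id registered in `models`
#         prompt="a watercolor fox in a snowy forest",
#     )
#     images = gen_img(request)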