from diffusers import DiffusionPipeline, AutoencoderTiny

from compel import Compel
import torch

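# intel_extension_for_pytorch is optional: when it imports successfully,
# IPEX optimizations become available on Intel hardware.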
try:
    import intel_extension_for_pytorch as ipex  # noqa: F401
except ImportError:
    pass

import psutil
from config import Args
from pydantic import BaseModel
from PIL import Image
from typing import Callable, Optional

base_model = "SimianLuo/LCM_Dreamshaper_v7"
WIDTH = 512
HEIGHT = 512


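# Factory wrapper: create_pipeline() loads and tunes the model once and
# returns a predict() closure that is called per request.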
class Pipeline:
    class InputParams(BaseModel):
        seed: int = 2159232
        prompt: str
        guidance_scale: float = 8.0
        strength: float = 0.5
        steps: int = 4
        lcm_steps: int = 50
        width: int = WIDTH
        height: int = HEIGHT

    @staticmethod
    def create_pipeline(
        args: Args, device: torch.device, torch_dtype: torch.dtype
    ) -> Callable[["Pipeline.InputParams"], Optional[Image.Image]]:
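        # Load the LCM checkpoint; skip the safety checker when disabled to
        # save VRAM and per-image latency.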
        if args.safety_checker:
            pipe = DiffusionPipeline.from_pretrained(base_model)
        else:
            pipe = DiffusionPipeline.from_pretrained(base_model, safety_checker=None)
        if args.use_taesd:
            # TAESD is a distilled VAE that decodes latents far faster than
            # the full VAE, at a small quality cost.
            pipe.vae = AutoencoderTiny.from_pretrained(
                "madebyollin/taesd", torch_dtype=torch_dtype, use_safetensors=True
            )

        pipe.set_progress_bar_config(disable=True)
        pipe.to(device=device, dtype=torch_dtype)
        # channels_last memory layout tends to speed up UNet convolutions.
        pipe.unet.to(memory_format=torch.channels_last)

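        # Attention slicing reduces peak memory at some speed cost; enable
        # it only on hosts with less than 64 GB of RAM.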
        if psutil.virtual_memory().total < 64 * 1024**3:
            pipe.enable_attention_slicing()

        if args.torch_compile:
            pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
            pipe.vae = torch.compile(pipe.vae, mode="reduce-overhead", fullgraph=True)
            # One throwaway run pays the compilation cost up front instead
            # of on the first user request.
            pipe(prompt="warmup", num_inference_steps=1, guidance_scale=8.0)
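
        # Compel converts weighted-prompt syntax into embeddings and lifts
        # CLIP's 77-token truncation.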
        compel_proc = Compel(
            tokenizer=pipe.tokenizer,
            text_encoder=pipe.text_encoder,
            truncate_long_prompts=False,
        )

        def predict(params: "Pipeline.InputParams") -> Optional[Image.Image]:
            # Seeding makes identical params reproduce the same image.
            generator = torch.manual_seed(params.seed)
            prompt_embeds = compel_proc(params.prompt)

            # lcm_steps maps to original_inference_steps: the LCM scheduler's
            # base timestep schedule from which `steps` timesteps are drawn.
            results = pipe(
                prompt_embeds=prompt_embeds,
                generator=generator,
                num_inference_steps=params.steps,
                guidance_scale=params.guidance_scale,
                width=params.width,
                height=params.height,
                original_inference_steps=params.lcm_steps,
                output_type="pil",
            )
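
            # nsfw_content_detected is only present when the safety checker
            # ran; a missing field means the check was skipped.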
            nsfw_content_detected = (
                results.nsfw_content_detected[0]
                if "nsfw_content_detected" in results
                else False
            )
            if nsfw_content_detected:
                return None
            return results.images[0]

        return predict
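

# Usage sketch: a minimal example of wiring up the factory. It assumes
# config.Args can be constructed with just the three fields this module
# reads (safety_checker, use_taesd, torch_compile); the real Args may
# require more.
if __name__ == "__main__":
    args = Args(safety_checker=False, use_taesd=True, torch_compile=False)
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    torch_dtype = torch.float16 if use_cuda else torch.float32

    predict = Pipeline.create_pipeline(args, device, torch_dtype)
    image = predict(Pipeline.InputParams(prompt="a watercolor fox"))
    if image is not None:
        image.save("output.png")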