radames's picture
new base
cb92d2b
raw
history blame
3.05 kB
from diffusers import DiffusionPipeline, AutoencoderTiny
from latent_consistency_controlnet import LatentConsistencyModelPipeline_controlnet
from compel import Compel
import torch
try:
import intel_extension_for_pytorch as ipex # type: ignore
except:
pass
import psutil
from config import Args
from pydantic import BaseModel
from PIL import Image
from typing import Callable
base_model = "SimianLuo/LCM_Dreamshaper_v7"
WIDTH = 512
HEIGHT = 512
class Pipeline:
class InputParams(BaseModel):
seed: int = 2159232
prompt: str
guidance_scale: float = 8.0
strength: float = 0.5
steps: int = 4
lcm_steps: int = 50
width: int = WIDTH
height: int = HEIGHT
@staticmethod
def create_pipeline(
args: Args, device: torch.device, torch_dtype: torch.dtype
) -> Callable[["Pipeline.InputParams"], Image.Image]:
if args.safety_checker:
pipe = DiffusionPipeline.from_pretrained(base_model)
else:
pipe = DiffusionPipeline.from_pretrained(base_model, safety_checker=None)
if args.use_taesd:
pipe.vae = AutoencoderTiny.from_pretrained(
"madebyollin/taesd", torch_dtype=torch_dtype, use_safetensors=True
)
pipe.set_progress_bar_config(disable=True)
pipe.to(device=device, dtype=torch_dtype)
pipe.unet.to(memory_format=torch.channels_last)
# check if computer has less than 64GB of RAM using sys or os
if psutil.virtual_memory().total < 64 * 1024**3:
pipe.enable_attention_slicing()
if args.torch_compile:
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
pipe.vae = torch.compile(pipe.vae, mode="reduce-overhead", fullgraph=True)
pipe(prompt="warmup", num_inference_steps=1, guidance_scale=8.0)
compel_proc = Compel(
tokenizer=pipe.tokenizer,
text_encoder=pipe.text_encoder,
truncate_long_prompts=False,
)
def predict(params: "Pipeline.InputParams") -> Image.Image:
generator = torch.manual_seed(params.seed)
prompt_embeds = compel_proc(params.prompt)
# Can be set to 1~50 steps. LCM support fast inference even <= 4 steps. Recommend: 1~8 steps.
results = pipe(
prompt_embeds=prompt_embeds,
generator=generator,
num_inference_steps=params.steps,
guidance_scale=params.guidance_scale,
width=params.width,
height=params.height,
original_inference_steps=params.lcm_steps,
output_type="pil",
)
nsfw_content_detected = (
results.nsfw_content_detected[0]
if "nsfw_content_detected" in results
else False
)
if nsfw_content_detected:
return None
return results.images[0]
return predict