radames's picture
new base
cb92d2b
from diffusers import DiffusionPipeline, AutoencoderTiny
from compel import Compel
import torch
try:
import intel_extension_for_pytorch as ipex # type: ignore
except:
pass
import psutil
from config import Args
from pydantic import BaseModel
from PIL import Image
from typing import Callable
base_model = "SimianLuo/LCM_Dreamshaper_v7"
taesd_model = "madebyollin/taesd"
class Pipeline:
class InputParams(BaseModel):
seed: int = 2159232
prompt: str = ""
guidance_scale: float = 8.0
strength: float = 0.5
steps: int = 4
width: int = 512
height: int = 512
def __init__(self, args: Args, device: torch.device, torch_dtype: torch.dtype):
if args.safety_checker:
self.pipe = DiffusionPipeline.from_pretrained(base_model)
else:
self.pipe = DiffusionPipeline.from_pretrained(
base_model, safety_checker=None
)
if args.use_taesd:
self.pipe.vae = AutoencoderTiny.from_pretrained(
taesd_model, torch_dtype=torch_dtype, use_safetensors=True
)
self.pipe.set_progress_bar_config(disable=True)
self.pipe.to(device=device, dtype=torch_dtype)
self.pipe.unet.to(memory_format=torch.channels_last)
# check if computer has less than 64GB of RAM using sys or os
if psutil.virtual_memory().total < 64 * 1024**3:
self.pipe.enable_attention_slicing()
if args.torch_compile:
self.pipe.unet = torch.compile(
self.pipe.unet, mode="reduce-overhead", fullgraph=True
)
self.pipe.vae = torch.compile(
self.pipe.vae, mode="reduce-overhead", fullgraph=True
)
self.pipe(prompt="warmup", num_inference_steps=1, guidance_scale=8.0)
self.compel_proc = Compel(
tokenizer=self.pipe.tokenizer,
text_encoder=self.pipe.text_encoder,
truncate_long_prompts=False,
)
def predict(self, params: "Pipeline.InputParams") -> Image.Image:
generator = torch.manual_seed(params.seed)
prompt_embeds = self.compel_proc(params.prompt)
results = self.pipe(
prompt_embeds=prompt_embeds,
generator=generator,
num_inference_steps=params.steps,
guidance_scale=params.guidance_scale,
width=params.width,
height=params.height,
output_type="pil",
)
nsfw_content_detected = (
results.nsfw_content_detected[0]
if "nsfw_content_detected" in results
else False
)
if nsfw_content_detected:
return None
return results.images[0]