edge-12 / src /pipeline.py
agentbot's picture
Initial commit with folder contents
e5b7414 verified
import torch
from pathlib import Path
from PIL.Image import Image
from diffusers import StableDiffusionXLPipeline, DDIMScheduler
from pipelines.models import TextToImageRequest
from torch import Generator
from cache_diffusion import cachify
from trt_pipeline.deploy import load_unet_trt
from loss import SchedulerWrapper
import numpy as np
# Module-level CUDA random generator, seeded with the fixed value 69 at import time.
# NOTE(review): neither load_pipeline() nor infer() in this file reads this
# global (infer() assigns its own local of the same name) — confirm whether it
# is still needed before removing it.
generator = Generator(torch.device("cuda"))
generator.manual_seed(69)
# Name fragments that disqualify an entry from the filter below.
# NOTE(review): presumably these are UNet sub-module names as seen by
# cachify.prepare — confirm against the cache_diffusion package.
_EXCLUDED_NAME_FRAGMENTS = ("down_blocks.2", "down_blocks.3", "up_blocks.2")


def _passes_name_filter(name):
    """Return True when *name* contains none of the excluded fragments."""
    return all(fragment not in name for fragment in _EXCLUDED_NAME_FRAGMENTS)


def _is_selected_step(step):
    """Return True for odd *step* values that are at least 10."""
    is_odd = step % 2 != 0
    return is_odd and step >= 10


# Single-entry configuration list handed to ``cachify.prepare`` by both
# ``load_pipeline`` and ``infer``.
SDXL_DEFAULT_CONFIG = [
    {
        "wildcard_or_filter_func": _passes_name_filter,
        "select_cache_step_func": _is_selected_step,
    },
]
def load_pipeline() -> StableDiffusionXLPipeline:
    """Load the SDXL pipeline from local files, configure it, and warm it up.

    In order, this function:

    1. Loads ``models/newdream-sdxl-20`` in float16 from local safetensors
       files and moves the pipeline to CUDA.
    2. Calls ``load_unet_trt`` on the pipeline's UNet with the ``./engine``
       directory and a batch size of 1.
    3. Prepares and enables ``cachify`` with ``SDXL_DEFAULT_CONFIG``.
    4. Replaces the scheduler with a ``SchedulerWrapper`` around a
       ``DDIMScheduler`` built from the previous scheduler's config.
    5. Runs four 14-step generations inside ``cachify.infer``.
    6. Disables ``cachify`` and calls ``prepare_loss()`` on the wrapped
       scheduler.

    Returns:
        The configured pipeline.
    """
    pipeline = StableDiffusionXLPipeline.from_pretrained(
        "models/newdream-sdxl-20",
        torch_dtype=torch.float16,
        use_safetensors=True,
        local_files_only=True,
    )
    pipeline = pipeline.to("cuda")

    engine_dir = Path("./engine")
    load_unet_trt(pipeline.unet, engine_path=engine_dir, batch_size=1)

    cachify.prepare(pipeline, SDXL_DEFAULT_CONFIG)
    cachify.enable(pipeline)

    # Build the DDIM scheduler from the config of the scheduler that shipped
    # with the checkpoint, then wrap it with the project's SchedulerWrapper.
    ddim_scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
    pipeline.scheduler = SchedulerWrapper(ddim_scheduler)

    # Warm-up generations; their outputs are discarded.
    with cachify.infer(pipeline):
        for _warmup_round in range(4):
            pipeline(prompt="a photo of table", num_inference_steps=14)

    cachify.disable(pipeline)
    # NOTE(review): prepare_loss() comes from the project's ``loss`` module;
    # presumably it relies on the warm-up runs above — confirm there.
    pipeline.scheduler.prepare_loss()
    return pipeline
def infer(request: TextToImageRequest, pipeline: StableDiffusionXLPipeline) -> Image:
    """Generate a single image for *request* using *pipeline*.

    Args:
        request: Supplies ``prompt``, ``negative_prompt``, ``width``,
            ``height`` and ``seed``. When ``seed`` is ``None`` no generator
            is passed to the pipeline.
        pipeline: The pipeline to run; ``cachify`` is prepared and enabled on
            it with ``SDXL_DEFAULT_CONFIG`` before every generation.

    Returns:
        The first generated image after it has been passed through
        ``pixel_filter``.
    """
    seed = request.seed
    # Seed a generator on the pipeline's device only when a seed was supplied.
    rng = None if seed is None else Generator(pipeline.device).manual_seed(seed)

    cachify.prepare(pipeline, SDXL_DEFAULT_CONFIG)
    cachify.enable(pipeline)

    with cachify.infer(pipeline) as cached_pipe:
        output = cached_pipe(
            prompt=request.prompt,
            negative_prompt=request.negative_prompt,
            width=request.width,
            height=request.height,
            generator=rng,
            num_inference_steps=14,
        )
        image = output.images[0]

    # NOTE(review): pixel_filter is neither defined nor imported in the part
    # of this file that is visible here — confirm it exists, otherwise this
    # line raises NameError.
    return pixel_filter(image)