import gc
import time
from typing import Any, Dict, Union

import torch
import torch._dynamo
from diffusers import FluxPipeline
from huggingface_inference_toolkit.logging import logger
from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe
from PIL import Image
from torchao.quantization import autoquant

# Trade a little float32 matmul precision for speed (uses TF32 where available).
# This benefits NVIDIA GPUs with Ampere architecture (e.g., A100, RTX 30 series) or newer.
torch.set_float32_matmul_precision("high")

# Surface torch.compile errors instead of silently falling back to eager mode (debugging aid).
torch._dynamo.config.suppress_errors = False


class EndpointHandler:
    def __init__(self, path=""):
        # Load the Flux pipeline in bfloat16 and move it onto the GPU.
        self.pipe = FluxPipeline.from_pretrained(
            "NoMoreCopyrightOrg/flux-dev",
            torch_dtype=torch.bfloat16,
        ).to("cuda")
        # FluxPipeline has no pipeline-level enable_vae_slicing()/enable_vae_tiling();
        # the equivalent switches live on the VAE itself and reduce decode memory.
        self.pipe.vae.enable_slicing()
        self.pipe.vae.enable_tiling()
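        # Fuse the separate Q/K/V projections into single matmuls so attention
        # launches fewer, larger GPU kernels.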
        self.pipe.transformer.fuse_qkv_projections()
        self.pipe.vae.fuse_qkv_projections()
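        # channels_last gives the convolution-heavy VAE a more kernel-friendly
        # memory layout.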
        self.pipe.transformer.to(memory_format=torch.channels_last)
        self.pipe.vae.to(memory_format=torch.channels_last)
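        # ParaAttention first-block cache: when the first transformer block's output
        # changes less than residual_diff_threshold between steps, the cached result
        # of the remaining blocks is reused and most of the step is skipped.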
        apply_cache_on_pipe(self.pipe, residual_diff_threshold=0.12)
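        # Compile both heavy modules; "max-autotune-no-cudagraphs" benchmarks kernel
        # choices at compile time without capturing CUDA graphs.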
        self.pipe.transformer = torch.compile(
            self.pipe.transformer, mode="max-autotune-no-cudagraphs",
        )
        self.pipe.vae = torch.compile(
            self.pipe.vae, mode="max-autotune-no-cudagraphs",
        )
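        # torchao autoquant benchmarks candidate quantization schemes per layer on
        # the first forward passes and keeps the fastest; error_on_unseen=False keeps
        # it from erroring on layers or shapes it has not profiled.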
        self.pipe.transformer = autoquant(self.pipe.transformer, error_on_unseen=False)
        self.pipe.vae = autoquant(self.pipe.vae, error_on_unseen=False)
        gc.collect()
        torch.cuda.empty_cache()
        # Warm-up run: the first call triggers torch.compile and autoquant
        # calibration, so it is much slower than subsequent requests.
        start_time = time.time()
        logger.info("Starting pipeline warm-up")
        self.pipe("Hello world!")
        logger.info(f"Warm-up took {time.time() - start_time:.2f} seconds")
        # Number of requests currently being processed (see the `get_queue` prompt below).
        self.record = 0

    def __call__(self, data: Dict[str, Any]) -> Union[Image.Image, int, None]:
        try:
            logger.info(f"Received incoming request with {data=}")
            if "inputs" in data and isinstance(data["inputs"], str) and data["inputs"]:
                prompt = data.pop("inputs")
            elif "prompt" in data and isinstance(data["prompt"], str) and data["prompt"]:
                prompt = data.pop("prompt")
            else:
                raise ValueError(
                    "Provided input body must contain either the key `inputs` or `prompt` with"
                    " the prompt to use for the image generation, and it needs to be a"
                    " non-empty string."
                )
            # Special prompt that reports how many requests are currently in flight.
            if prompt == "get_queue":
                return self.record
            parameters = data.pop("parameters", {})
            num_inference_steps = parameters.get("num_inference_steps", 28)
            width = parameters.get("width", 1024)
            height = parameters.get("height", 1024)
            # Accept both `guidance` and the standard diffusers key `guidance_scale`.
            guidance_scale = parameters.get("guidance", parameters.get("guidance_scale", 3.5))
            # The seed cannot be passed to the pipeline as-is; it has to be wrapped
            # in a torch.Generator.
            seed = parameters.get("seed", 0)
            generator = torch.Generator("cpu").manual_seed(seed)
            self.record += 1
            start_time = time.time()
            try:
                result = self.pipe(  # type: ignore
                    prompt,
                    height=height,
                    width=width,
                    guidance_scale=guidance_scale,
                    num_inference_steps=num_inference_steps,
                    generator=generator,
                ).images[0]
            finally:
                # Decrement even if generation fails so the in-flight counter stays accurate.
                self.record -= 1
            logger.info(f"Generation took {time.time() - start_time:.2f} seconds")
            return result
        except Exception as e:
            logger.error(f"Image generation failed: {e}")
            return None
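
# Local smoke test (hypothetical usage; in production the huggingface-inference-toolkit
# discovers and calls EndpointHandler automatically):
#
#   handler = EndpointHandler()
#   image = handler({"inputs": "a watercolor fox", "parameters": {"seed": 42}})
#   if image is not None:
#       image.save("output.png")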