import os
from typing import Any, Dict, Union
from PIL import Image
import torch
from diffusers import FluxPipeline
from huggingface_inference_toolkit.logging import logger
from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe
from torchao.quantization import autoquant
import time
import gc

# Allow TF32 ("high" precision) for float32 matrix multiplications.
# This trades a small amount of precision for throughput on NVIDIA GPUs with
# Ampere architecture (e.g., A100, RTX 30 series) or newer.
torch.set_float32_matmul_precision("high")

import torch._dynamo
torch._dynamo.config.suppress_errors = False  # surface torch.compile errors instead of silently falling back to eager

class EndpointHandler:
    def __init__(self, path=""):
        self.pipeline = FluxPipeline.from_pretrained(
            "NoMoreCopyrightOrg/flux-dev",
            torch_dtype=torch.bfloat16,
        ).to("cuda")
        self.pipeline.enable_vae_slicing()
        self.pipeline.enable_vae_tiling()
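        # Fuse the Q/K/V projections into single matmuls and switch to
        # channels_last, which tends to help conv-heavy modules like the VAE.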
        self.pipeline.transformer.fuse_qkv_projections()
        self.pipeline.vae.fuse_qkv_projections()
        self.pipeline.transformer.to(memory_format=torch.channels_last)
        self.pipeline.vae.to(memory_format=torch.channels_last)
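        # ParaAttention's first-block cache: when the residual change after the
        # first transformer block stays below the threshold, the remaining
        # blocks are skipped for that step and cached outputs are reused.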
        apply_cache_on_pipe(self.pipeline, residual_diff_threshold=0.12)
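        # Compile with autotuning; CUDA graphs stay off, since they require
        # static control flow and the cache skips blocks dynamically.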
        self.pipeline.transformer = torch.compile(
            self.pipeline.transformer, mode="max-autotune-no-cudagraphs",
        )
        self.pipeline.vae = torch.compile(
            self.pipeline.vae, mode="max-autotune-no-cudagraphs",
        )
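        # torchao autoquant profiles candidate quantization layouts per layer
        # on real inputs and keeps the fastest; error_on_unseen=False tolerates
        # layers that never received calibration inputs.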
        self.pipeline.transformer = autoquant(self.pipeline.transformer, error_on_unseen=False)
        self.pipeline.vae = autoquant(self.pipeline.vae, error_on_unseen=False)
        
        # Release temporary allocations from model setup before warming up.
        gc.collect()
        torch.cuda.empty_cache()

        start_time = time.time()
        logger.info("Warming up pipeline (first call triggers compilation and autoquant calibration)")
        self.pipeline("Hello world!")  # warm-up run; may take several minutes
        end_time = time.time()
        logger.info(f"Warm-up took {end_time - start_time:.2f} seconds")

    def __call__(self, data: Dict[str, Any]) -> Union[Image.Image, None]:
        logger.info(f"Received incoming request with {data=}")
        try:
            if "inputs" in data and isinstance(data["inputs"], str):
                prompt = data.pop("inputs")
            elif "prompt" in data and isinstance(data["prompt"], str):
                prompt = data.pop("prompt")
            else:
                raise ValueError(
                    "Provided input body must contain either the key `inputs` or `prompt` with the"
                    " prompt to use for the image generation, and it needs to be a non-empty string."
                )

            parameters = data.pop("parameters", {})

            num_inference_steps = parameters.get("num_inference_steps", 28)
            width = parameters.get("width", 1024)
            height = parameters.get("height", 1024)
            # Accept both the diffusers name `guidance_scale` and the shorter `guidance`.
            guidance_scale = parameters.get("guidance_scale", parameters.get("guidance", 3.5))

            # The seed cannot be passed directly; it must be wrapped in a generator.
            # A dedicated Generator avoids mutating torch's global RNG state.
            seed = parameters.get("seed", 0)
            generator = torch.Generator("cpu").manual_seed(seed)
            start_time = time.time()
            result = self.pipeline(  # type: ignore
                prompt,
                height=height,
                width=width,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                generator=generator,
            ).images[0]
            end_time = time.time()
            logger.info(f"Generation took {end_time - start_time:.2f} seconds")

            return result
        except Exception as e:
            logger.error(f"Image generation failed: {e}")
            return None
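
# A minimal local smoke test; a sketch, not part of the deployed handler. It
# assumes a CUDA GPU with enough memory for FLUX, and "output.png" is just an
# illustrative file name.
if __name__ == "__main__":
    handler = EndpointHandler()
    image = handler({
        "inputs": "A watercolor fox in a snowy forest",
        "parameters": {"seed": 42, "num_inference_steps": 28},
    })
    if image is not None:
        image.save("output.png")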