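"""Custom Inference Endpoints handler serving a FLUX text-to-image pipeline,
optimized with first-block caching (ParaAttention), torch.compile, and
torchao autoquant."""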
import os
from typing import Any, Dict, Union
from PIL import Image
import torch
from diffusers import FluxPipeline
from huggingface_inference_toolkit.logging import logger
from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe
from torchao.quantization import autoquant
import time
import gc

# Use "high" float32 matmul precision (TF32 on NVIDIA Ampere GPUs or newer,
# e.g. A100 / RTX 30 series), which speeds up float32 matrix multiplications
# at a small precision cost.
torch.set_float32_matmul_precision("high")

import torch._dynamo
torch._dynamo.config.suppress_errors = False  # surface compilation errors instead of silently falling back to eager

class EndpointHandler:
    def __init__(self, path=""):
        self.pipe = FluxPipeline.from_pretrained(
            "NoMoreCopyrightOrg/flux-dev",
            torch_dtype=torch.bfloat16,
        ).to("cuda")
        # Reduce VAE memory usage by decoding in slices and tiles.
        self.pipe.enable_vae_slicing()
        self.pipe.enable_vae_tiling()
        # Fuse the Q/K/V projections so attention uses a single fused matmul.
        self.pipe.transformer.fuse_qkv_projections()
        self.pipe.vae.fuse_qkv_projections()
        # channels_last memory layout is generally faster on recent GPUs.
        self.pipe.transformer.to(memory_format=torch.channels_last)
        self.pipe.vae.to(memory_format=torch.channels_last)
        # First-block cache (ParaAttention): reuse cached transformer outputs
        # when the residual difference between steps stays below the threshold.
        apply_cache_on_pipe(self.pipe, residual_diff_threshold=0.12)
        # Compile the heavy modules with max autotuning (CUDA graphs disabled).
        self.pipe.transformer = torch.compile(
            self.pipe.transformer, mode="max-autotune-no-cudagraphs",
        )
        self.pipe.vae = torch.compile(
            self.pipe.vae, mode="max-autotune-no-cudagraphs",
        )
        # torchao autoquant: pick a quantization scheme per layer at runtime.
        self.pipe.transformer = autoquant(self.pipe.transformer, error_on_unseen=False)
        self.pipe.vae = autoquant(self.pipe.vae, error_on_unseen=False)
        
        gc.collect()
        torch.cuda.empty_cache()

        start_time = time.time()
        print("Starting pipeline warm-up")
        self.pipe("Hello world!")  # warm-up run to trigger compilation and autoquant calibration
        end_time = time.time()
        print(f"Warm-up took {end_time - start_time:.2f} seconds")
        # Number of requests currently being processed (simple in-flight counter).
        self.record = 0

    def __call__(self, data: Dict[str, Any]) -> Union[Image.Image, int, None]:
        try:
            logger.info(f"Received incoming request with {data=}")

            if "inputs" in data and isinstance(data["inputs"], str):
                prompt = data.pop("inputs")
            elif "prompt" in data and isinstance(data["prompt"], str):
                prompt = data.pop("prompt")
            else:
                raise ValueError(
                    "Provided input body must contain either the key `inputs` or `prompt` with the"
                    " prompt to use for the image generation, and it needs to be a non-empty string."
                )
            # Special sentinel prompt: report the number of in-flight requests.
            if prompt == "get_queue":
                return self.record
            parameters = data.pop("parameters", {})

            num_inference_steps = parameters.get("num_inference_steps", 28)
            width = parameters.get("width", 1024)
            height = parameters.get("height", 1024)
            # Note: the guidance scale is expected under the key `guidance`, not `guidance_scale`.
            guidance_scale = parameters.get("guidance", 3.5)

            # The seed cannot be passed directly; it has to be wrapped in a torch generator.
            seed = parameters.get("seed", 0)
            generator = torch.manual_seed(seed)
            self.record += 1
            start_time = time.time()
            result = self.pipe(  # type: ignore
                prompt,
                height=height,
                width=width,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                generator=generator,
            ).images[0]
            end_time = time.time()
            print(f"Generation took {end_time - start_time:.2f} seconds")
            self.record -= 1

            return result
        except Exception as e:
            logger.error(f"Image generation failed: {e}")
            return None
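
# ---------------------------------------------------------------------------
# Usage sketch (assumption: not part of the original handler, shown only to
# illustrate the payload shape the handler expects). Requires a CUDA GPU and
# the same dependencies as the handler itself.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    handler = EndpointHandler()
    image = handler({
        "inputs": "A watercolor painting of a lighthouse at dusk",
        "parameters": {
            "num_inference_steps": 28,
            "width": 1024,
            "height": 1024,
            "guidance": 3.5,
            "seed": 42,
        },
    })
    if image is not None:
        image.save("output.png")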