import time
from typing import Any, Dict

import torch
from diffusers import FluxPipeline
from huggingface_inference_toolkit.logging import logger
from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe
from PIL import Image


class EndpointHandler:
    def __init__(self, path=""):
        self.pipe = FluxPipeline.from_pretrained(
            "NoMoreCopyrightOrg/flux-dev",
            torch_dtype=torch.bfloat16,
        ).to("cuda")
        apply_cache_on_pipe(self.pipe, residual_diff_threshold=0.12)
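        # Compile the heavy modules; "max-autotune-no-cudagraphs" autotunes
        # kernels without CUDA graphs, so the first requests are slow while
        # compilation warms up.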
        self.pipe.transformer = torch.compile(
            self.pipe.transformer, mode="max-autotune-no-cudagraphs",
        )
        self.pipe.vae = torch.compile(
            self.pipe.vae, mode="max-autotune-no-cudagraphs",
        )

    def __call__(self, data: Dict[str, Any]) -> Image.Image:
        logger.info(f"Received incoming request with {data=}")

        if "inputs" in data and isinstance(data["inputs"], str):
            prompt = data.pop("inputs")
        elif "prompt" in data and isinstance(data["prompt"], str):
            prompt = data.pop("prompt")
        else:
            raise ValueError(
                "Provided input body must contain either the key `inputs` or `prompt` with the"
                " prompt to use for the image generation, and it needs to be a non-empty string."
            )

        parameters = data.pop("parameters", {})

        num_inference_steps = parameters.get("num_inference_steps", 28)
        width = parameters.get("width", 1024)
        height = parameters.get("height", 1024)
        guidance_scale = parameters.get("guidance_scale", 3.5)

        # The seed cannot be passed to the pipeline directly; it has to be
        # wrapped in a torch.Generator.
        seed = parameters.get("seed", 0)
        generator = torch.manual_seed(seed)
        start_time = time.time()
        result = self.pipe(  # type: ignore
            prompt,
            height=height,
            width=width,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
        ).images[0]
        end_time = time.time()
        time_taken = end_time - start_time
        logger.info(f"Time taken: {time_taken:.2f} seconds")
        return result
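

# Minimal local smoke test (a sketch, not part of the serving contract): it
# assumes a CUDA GPU with enough memory for FLUX in bfloat16 and access to the
# NoMoreCopyrightOrg/flux-dev weights. The request body mirrors what Inference
# Endpoints passes to EndpointHandler.__call__.
if __name__ == "__main__":
    handler = EndpointHandler()
    image = handler(
        {
            "inputs": "a watercolor painting of a lighthouse at dawn",
            "parameters": {"num_inference_steps": 28, "seed": 42},
        }
    )
    image.save("sample.png")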