alvarobartt HF staff committed on
Commit
253e509
·
verified ·
1 Parent(s): 419e49a

Upload handler.py

Browse files
Files changed (1) hide show
  1. handler.py +61 -0
handler.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Any, Dict
3
+
4
+ from diffusers import DiffusionPipeline # type: ignore
5
+ from PIL.Image import Image
6
+ import torch
7
+
8
+ from huggingface_inference_toolkit.logging import logger
9
+
10
+
11
class EndpointHandler:
    """Inference Endpoints handler serving FLUX.1-dev with a LoRA adapter.

    Loads the gated `black-forest-labs/FLUX.1-dev` base pipeline once at
    startup and applies the LoRA weights found in `model_dir` on top of it.
    """

    def __init__(self, model_dir: str, **kwargs: Any) -> None:  # type: ignore
        """Load the FLUX.1-dev pipeline, attach the LoRA adapter, move to GPU.

        Args:
            model_dir: Local directory containing the LoRA adapter weights.
            **kwargs: Unused; accepted for toolkit call compatibility.

        Raises:
            ValueError: If the `HF_TOKEN` environment variable is not set
                (the base model is gated and requires authentication).
        """
        if os.getenv("HF_TOKEN") is None:
            raise ValueError(
                "Since `black-forest-labs/FLUX.1-dev` is a gated model, you will need to provide a valid "
                "`HF_TOKEN` as an environment variable for the handler to work properly."
            )

        self.pipeline = DiffusionPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            torch_dtype=torch.bfloat16,
            token=os.getenv("HF_TOKEN"),
        )
        self.pipeline.load_lora_weights(model_dir)
        self.pipeline.to("cuda")

    def __call__(self, data: Dict[str, Any]) -> Image:
        """Generate one image from the prompt in the request payload.

        Args:
            data: Request body. Must contain a non-empty string under either
                `inputs` or `prompt`, and may contain a `parameters` dict with
                any of: `num_inference_steps` (default 30), `width` (1024),
                `height` (768), `guidance_scale` (3.5), `lora_scale` (1.0),
                and `seed` (0).

        Returns:
            The first generated `PIL.Image.Image`.

        Raises:
            ValueError: If neither `inputs` nor `prompt` holds a non-empty
                string.
        """
        logger.info(f"Received incoming request with {data=}")

        # Accept both the Inference API convention (`inputs`) and the
        # diffusers convention (`prompt`). BUGFIX: also reject empty strings,
        # so behavior matches the "non-empty string" promise in the error
        # message below (the original accepted "" silently).
        if isinstance(data.get("inputs"), str) and data["inputs"]:
            prompt = data.pop("inputs")
        elif isinstance(data.get("prompt"), str) and data["prompt"]:
            prompt = data.pop("prompt")
        else:
            raise ValueError(
                "Provided input body must contain either the key `inputs` or `prompt` with the"
                " prompt to use for the image generation, and it needs to be a non-empty string."
            )

        parameters = data.pop("parameters", {})

        num_inference_steps = parameters.get("num_inference_steps", 30)
        width = parameters.get("width", 1024)
        height = parameters.get("height", 768)
        guidance_scale = parameters.get("guidance_scale", 3.5)
        lora_scale = parameters.get("lora_scale", 1.0)

        # seed generator (seed cannot be provided as is but via a generator)
        seed = parameters.get("seed", 0)
        generator = torch.manual_seed(seed)

        return self.pipeline(  # type: ignore
            prompt,
            height=height,
            width=width,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            # BUGFIX: `FluxPipeline.__call__` has no `lora_scale` parameter,
            # so passing it directly raised a TypeError at inference time.
            # The LoRA scale must be forwarded through the attention kwargs.
            joint_attention_kwargs={"scale": lora_scale},
            generator=generator,
        ).images[0]