krtk00 commited on
Commit
e3d77b5
·
verified ·
1 Parent(s): 2e28b6e

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +10 -23
handler.py CHANGED
@@ -1,19 +1,21 @@
1
- from typing import Dict, Any
2
  from diffusers import AutoPipelineForText2Image
3
  import torch
4
- from PIL import Image
5
- import base64
6
- from io import BytesIO
7
 
8
  class EndpointHandler:
9
  def __init__(self, path: str = ""):
10
  """
11
  Initialize the handler, loading the model and LoRA weights.
12
- The path parameter is provided by Hugging Face Inference Endpoints to point to the model directory.
13
  """
14
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 
 
 
 
15
  self.pipeline = AutoPipelineForText2Image.from_pretrained(
16
- 'black-forest-labs/FLUX.1-dev',
 
17
  torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
18
  ).to(self.device)
19
 
@@ -21,25 +23,10 @@ class EndpointHandler:
21
  lora_weights_path = 'krtk00/pan_crd_lora_v2'
22
  self.pipeline.load_lora_weights(lora_weights_path, weight_name='lora.safetensors')
23
 
24
- def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
25
- """
26
- This method will be called on every request. The input is expected to be a dictionary
27
- with a key "inputs" containing the text prompt.
28
- """
29
- # Preprocess input
30
  prompt = data.get("inputs", None)
31
  if not prompt:
32
  raise ValueError("No prompt provided in the input")
33
-
34
- # Run inference
35
  with torch.no_grad():
36
  images = self.pipeline(prompt).images
37
-
38
- # Postprocess output: Convert image to base64
39
- pil_image = images[0] # Assuming one image is generated
40
- buffered = BytesIO()
41
- pil_image.save(buffered, format="PNG")
42
- img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
43
-
44
- # Return result
45
- return {"image": img_str}
 
1
+ import os
2
  from diffusers import AutoPipelineForText2Image
3
  import torch
 
 
 
4
 
5
  class EndpointHandler:
6
  def __init__(self, path: str = ""):
7
  """
8
  Initialize the handler, loading the model and LoRA weights.
 
9
  """
10
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
11
+
12
+ # Retrieve the Hugging Face token from environment variable
13
+ hf_token = os.getenv("HF_TOKEN") # Ensure HF_TOKEN is set in environment
14
+
15
+ # Load the model using the token
16
  self.pipeline = AutoPipelineForText2Image.from_pretrained(
17
+ 'black-forest-labs/FLUX.1-dev',
18
+ use_auth_token=hf_token,
19
  torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
20
  ).to(self.device)
21
 
 
23
  lora_weights_path = 'krtk00/pan_crd_lora_v2'
24
  self.pipeline.load_lora_weights(lora_weights_path, weight_name='lora.safetensors')
25
 
26
+ def __call__(self, data):
 
 
 
 
 
27
  prompt = data.get("inputs", None)
28
  if not prompt:
29
  raise ValueError("No prompt provided in the input")
 
 
30
  with torch.no_grad():
31
  images = self.pipeline(prompt).images
32
+ return {"image": images[0]}