Texttra commited on
Commit
f6f3cdf
·
verified ·
1 Parent(s): 81b2c86

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +16 -18
handler.py CHANGED
@@ -7,41 +7,39 @@ import base64
7
 
8
  class EndpointHandler:
9
  def __init__(self, path: str = ""):
10
- # Load base FLUX pipeline
11
  self.pipe = DiffusionPipeline.from_pretrained(
12
  "black-forest-labs/FLUX.1-dev",
13
  torch_dtype=torch.float16,
14
  use_auth_token=True
15
  )
16
 
17
- # Load your LoRA weights from the repo
18
  self.pipe.load_lora_weights(path, weight_name="c1t3_v1.safetensors")
19
 
20
- # Move to GPU if available
21
  if torch.cuda.is_available():
22
  self.pipe.to("cuda")
23
  else:
24
  self.pipe.to("cpu")
25
 
26
- # Optional: enable memory optimization
27
  self.pipe.enable_model_cpu_offload()
28
 
29
- # Initialize Compel (prompt parser for FLUX)
30
- self.compel = Compel(tokenizer=self.pipe.tokenizer, text_encoder=self.pipe.text_encoder)
 
 
31
 
32
- def __call__(self, data: Dict[str, str]) -> Dict:
33
- inputs = data.get("inputs", {})
34
- prompt = inputs.get("prompt", "")
35
- if not prompt:
36
- return {"error": "No prompt provided."}
37
 
38
- print(f"Received prompt: {prompt}") # helpful logging
39
 
40
- conditioning = self.compel(prompt)
41
- image = self.pipe(prompt_embeds=conditioning).images[0]
42
 
43
- buffer = BytesIO()
44
- image.save(buffer, format="PNG")
45
- base64_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
46
 
47
- return {"image": base64_image}
 
7
 
8
  class EndpointHandler:
9
  def __init__(self, path: str = ""):
 
10
  self.pipe = DiffusionPipeline.from_pretrained(
11
  "black-forest-labs/FLUX.1-dev",
12
  torch_dtype=torch.float16,
13
  use_auth_token=True
14
  )
15
 
 
16
  self.pipe.load_lora_weights(path, weight_name="c1t3_v1.safetensors")
17
 
 
18
  if torch.cuda.is_available():
19
  self.pipe.to("cuda")
20
  else:
21
  self.pipe.to("cpu")
22
 
 
23
  self.pipe.enable_model_cpu_offload()
24
 
25
+ self.compel = Compel(
26
+ tokenizer=self.pipe.tokenizer,
27
+ text_encoder=self.pipe.text_encoder
28
+ )
29
 
30
+ def __call__(self, data: Dict[str, Dict[str, str]]) -> Dict:
31
+ inputs = data.get("inputs", {})
32
+ prompt = inputs.get("prompt", "")
33
+ if not prompt:
34
+ return {"error": "No prompt provided."}
35
 
36
+ print(f"Received prompt: {prompt}")
37
 
38
+ conditioning = self.compel(prompt)
39
+ image = self.pipe(prompt_embeds=conditioning).images[0]
40
 
41
+ buffer = BytesIO()
42
+ image.save(buffer, format="PNG")
43
+ base64_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
44
 
45
+ return {"image": base64_image}