Anwar786 commited on
Commit
58066d8
1 Parent(s): ffa1636

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +28 -51
handler.py CHANGED
@@ -1,9 +1,11 @@
1
- from typing import Dict, List, Any
2
  import base64
3
  from PIL import Image
4
  from io import BytesIO
5
  from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
6
  import torch
 
 
7
  import numpy as np
8
  import cv2
9
  import controlnet_hinter
@@ -11,8 +13,7 @@ import controlnet_hinter
11
  # set device
12
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
13
  if device.type != 'cuda':
14
- raise ValueError("Need to run on GPU")
15
-
16
  # set mixed precision dtype
17
  dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] == 8 else torch.float16
18
 
@@ -52,14 +53,12 @@ CONTROLNET_MAPPING = {
52
  }
53
  }
54
 
 
55
  class EndpointHandler():
56
- """
57
- A class to handle endpoint logic.
58
- """
59
  def __init__(self, path=""):
60
  # define default controlnet id and load controlnet
61
- self.control_type = "depth"
62
- self.controlnet = ControlNetModel.from_pretrained(controlnet_mapping[self.control_type]["model_id"], torch_dtype=dtype).to(device)
63
 
64
  # Load StableDiffusionControlNetPipeline
65
  self.stable_diffusion_id = "runwayml/stable-diffusion-v1-5"
@@ -70,30 +69,29 @@ class EndpointHandler():
70
  # Define Generator with seed
71
  self.generator = torch.Generator(device="cpu").manual_seed(3)
72
 
73
- def __call__(self, data: Any) -> None:
74
  """
75
- Process input data and perform inference.
76
-
77
- :param data: A dictionary containing `inputs` and optional `image_path` field.
78
- :return: None
79
  """
80
  prompt = data.pop("inputs", None)
81
- image_path = data.pop("image_path", None)
82
  controlnet_type = data.pop("controlnet_type", None)
83
-
84
- # Check if neither prompt nor image path is provided
85
- if prompt is None and image_path is None:
86
- raise ValueError("Please provide a prompt and either an image path or a base64-encoded image.")
87
-
88
  # Check if a new controlnet is provided
89
  if controlnet_type is not None and controlnet_type != self.control_type:
90
- print(f"Changing controlnet from {self.control_type} to {controlnet_type} using {controlnet_mapping[controlnet_type]['model_id']} model")
91
  self.control_type = controlnet_type
92
- self.controlnet = ControlNetModel.from_pretrained(controlnet_mapping[self.control_type]["model_id"],
93
  torch_dtype=dtype).to(device)
94
  self.pipe.controlnet = self.controlnet
95
 
96
- # hyperparameters
 
97
  num_inference_steps = data.pop("num_inference_steps", 30)
98
  guidance_scale = data.pop("guidance_scale", 7.5)
99
  negative_prompt = data.pop("negative_prompt", None)
@@ -102,14 +100,8 @@ class EndpointHandler():
102
  controlnet_conditioning_scale = data.pop("controlnet_conditioning_scale", 1.0)
103
 
104
  # process image
105
- if image_path is not None:
106
- # Load the image from the specified path
107
- image = Image.open(image_path)
108
- else:
109
- # Decode base64-encoded image
110
- image = self.decode_base64_image(data.pop("image", ""))
111
-
112
- control_image = controlnet_mapping[self.control_type]["hinter"](image)
113
 
114
  # run inference pipeline
115
  out = self.pipe(
@@ -125,28 +117,13 @@ class EndpointHandler():
125
  generator=self.generator
126
  )
127
 
128
- # save the generated image as a JPEG file
129
- output_image = out.images[0]
130
- output_image.save("output.jpg", format="JPEG")
131
-
 
132
  def decode_base64_image(self, image_string):
133
  base64_image = base64.b64decode(image_string)
134
  buffer = BytesIO(base64_image)
135
  image = Image.open(buffer)
136
- return image
137
-
138
- # Example usage
139
- payload = {
140
- "inputs": "Your prompt here",
141
- "image_path": "path/to/your/image.jpg",
142
- "controlnet_type": "depth",
143
- "num_inference_steps": 30,
144
- "guidance_scale": 7.5,
145
- "negative_prompt": None,
146
- "height": None,
147
- "width": None,
148
- "controlnet_conditioning_scale": 1.0,
149
- }
150
-
151
- handler = EndpointHandler()
152
- handler(payload)
 
1
+ from typing import Dict, List, Any
2
  import base64
3
  from PIL import Image
4
  from io import BytesIO
5
  from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
6
  import torch
7
+
8
+
9
  import numpy as np
10
  import cv2
11
  import controlnet_hinter
 
13
  # set device
14
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
15
  if device.type != 'cuda':
16
+ raise ValueError("need to run on GPU")
 
17
  # set mixed precision dtype
18
  dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] == 8 else torch.float16
19
 
 
53
  }
54
  }
55
 
56
+
57
  class EndpointHandler():
 
 
 
58
  def __init__(self, path=""):
59
  # define default controlnet id and load controlnet
60
+ self.control_type = "normal"
61
+ self.controlnet = ControlNetModel.from_pretrained(CONTROLNET_MAPPING[self.control_type]["model_id"],torch_dtype=dtype).to(device)
62
 
63
  # Load StableDiffusionControlNetPipeline
64
  self.stable_diffusion_id = "runwayml/stable-diffusion-v1-5"
 
69
  # Define Generator with seed
70
  self.generator = torch.Generator(device="cpu").manual_seed(3)
71
 
72
+ def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
73
  """
74
+ :param data: A dictionary contains `inputs` and optional `image` field.
75
+ :return: A dictionary with `image` field contains image in base64.
 
 
76
  """
77
  prompt = data.pop("inputs", None)
78
+ image = data.pop("image", None)
79
  controlnet_type = data.pop("controlnet_type", None)
80
+
81
+ # Check if neither prompt nor image is provided
82
+ if prompt is None and image is None:
83
+ return {"error": "Please provide a prompt and base64 encoded image."}
84
+
85
  # Check if a new controlnet is provided
86
  if controlnet_type is not None and controlnet_type != self.control_type:
87
+ print(f"changing controlnet from {self.control_type} to {controlnet_type} using {CONTROLNET_MAPPING[controlnet_type]['model_id']} model")
88
  self.control_type = controlnet_type
89
+ self.controlnet = ControlNetModel.from_pretrained(CONTROLNET_MAPPING[self.control_type]["model_id"],
90
  torch_dtype=dtype).to(device)
91
  self.pipe.controlnet = self.controlnet
92
 
93
+
94
+ # hyperparamters
95
  num_inference_steps = data.pop("num_inference_steps", 30)
96
  guidance_scale = data.pop("guidance_scale", 7.5)
97
  negative_prompt = data.pop("negative_prompt", None)
 
100
  controlnet_conditioning_scale = data.pop("controlnet_conditioning_scale", 1.0)
101
 
102
  # process image
103
+ image = self.decode_base64_image(image)
104
+ control_image = CONTROLNET_MAPPING[self.control_type]["hinter"](image)
 
 
 
 
 
 
105
 
106
  # run inference pipeline
107
  out = self.pipe(
 
117
  generator=self.generator
118
  )
119
 
120
+
121
+ # return first generate PIL image
122
+ return out.images[0]
123
+
124
+ # helper to decode input image
125
  def decode_base64_image(self, image_string):
126
  base64_image = base64.b64decode(image_string)
127
  buffer = BytesIO(base64_image)
128
  image = Image.open(buffer)
129
+ return image