Anwar786 committed
Commit: f36a3e1
Parent: 4836f4c

Update handler.py

Files changed (1):
handler.py +85 -74
handler.py CHANGED
@@ -1,116 +1,127 @@
- from typing import List, Dict, Any
  import base64
  from PIL import Image
  from io import BytesIO
  from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
  import torch
  import controlnet_hinter

  # set device
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  if device.type != 'cuda':
-     raise ValueError("Need to run on GPU")
  # set mixed precision dtype
  dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] == 8 else torch.float16

- # controlnet mapping for depth controlnet
  CONTROLNET_MAPPING = {
      "depth": {
          "model_id": "lllyasviel/sd-controlnet-depth",
          "hinter": controlnet_hinter.hint_depth
      }
  }

  class EndpointHandler():
      def __init__(self, path=""):
          # define default controlnet id and load controlnet
          self.control_type = "depth"
-         self.controlnet = ControlNetModel.from_pretrained(CONTROLNET_MAPPING[self.control_type]["model_id"], torch_dtype=dtype).to(device)
-
-         # Load StableDiffusionControlNetPipeline
          self.stable_diffusion_id = "runwayml/stable-diffusion-v1-5"
-         self.pipe = StableDiffusionControlNetPipeline.from_pretrained(self.stable_diffusion_id,
-                                                                       controlnet=self.controlnet,
                                                                        torch_dtype=dtype,
                                                                        safety_checker=None).to(device)
          # Define Generator with seed
          self.generator = torch.Generator(device="cpu").manual_seed(3)

-     def __call__(self, data: Any) -> Dict[str, str]:
-         # Example JSON payload for testing
-         example_payload = {
-             "prompt": "a beautiful landscape",
-             "negative_prompt": "blur",
-             "width": 1024,
-             "height": 1024,
-             "steps": 30,
-             "cfg_scale": 7,
-             "alwayson_scripts": {
-                 "controlnet": {
-                     "args": [
-                         {
-                             "enabled": True,
-                             "input_image": "image in base64",
-                             "model": "control_sd15_depth [fef5e48e]",
-                             "control_mode": "Balanced"
-                         }
-                     ]
-                 }
-             }
-         }
-
-         # Extract parameters from the payload
-         prompt = data.get("prompt", None)
-         negative_prompt = data.get("negative_prompt", None)
-         width = data.get("width", None)
-         height = data.get("height", None)
-         num_inference_steps = data.get("steps", 30)
-         guidance_scale = data.get("cfg_scale", 7)

-         # Extract controlnet configuration from payload
-         controlnet_config = data.get("alwayson_scripts", {}).get("controlnet", {}).get("args", [{}])[0]
-
-         # Run stable diffusion process
          out = self.pipe(
-             prompt=prompt,
              negative_prompt=negative_prompt,
-             num_inference_steps=num_inference_steps,
              guidance_scale=guidance_scale,
              num_images_per_prompt=1,
              height=height,
              width=width,
-             controlnet_conditioning_scale=1.0,
-             generator=self.generator,
          )

-         # Get the generated image
-         generated_image = out.images[0]
-
-         # Process with controlnet if enabled
-         if controlnet_config.get("enabled", False):
-             input_image_base64 = controlnet_config.get("input_image", "")
-             input_image = self.decode_base64_image(input_image_base64)
-             controlnet_model = controlnet_config.get("model", "")
-             controlnet_control_mode = controlnet_config.get("control_mode", "")
-
-             processed_image = self.process_with_controlnet(generated_image, input_image, controlnet_model, controlnet_control_mode)
-         else:
-             processed_image = generated_image
-
-         # Return the final processed image as base64
-         return {"image": self.encode_base64_image(processed_image)}
-
-     def process_with_controlnet(self, generated_image, input_image, model, control_mode):
-         # Simulated controlnet processing (replace with actual implementation)
-         # Here, we're just using the input_image as-is. Replace this with your controlnet logic.
-         return input_image
-
-     def encode_base64_image(self, image):
-         # Encode the PIL Image to base64
-         buffer = BytesIO()
-         image.save(buffer, format="PNG")
-         return base64.b64encode(buffer.getvalue()).decode("utf-8")
-
      def decode_base64_image(self, image_string):
          base64_image = base64.b64decode(image_string)
          buffer = BytesIO(base64_image)
 
+ from typing import Dict, List, Any
  import base64
  from PIL import Image
  from io import BytesIO
  from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
  import torch
+
+
+ import numpy as np
+ import cv2
  import controlnet_hinter

  # set device
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  if device.type != 'cuda':
+     raise ValueError("need to run on GPU")
  # set mixed precision dtype
  dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] == 8 else torch.float16

+ # controlnet mapping for controlnet id and control hinter
  CONTROLNET_MAPPING = {
+     "canny_edge": {
+         "model_id": "lllyasviel/sd-controlnet-canny",
+         "hinter": controlnet_hinter.hint_canny
+     },
+     "pose": {
+         "model_id": "lllyasviel/sd-controlnet-openpose",
+         "hinter": controlnet_hinter.hint_openpose
+     },
      "depth": {
          "model_id": "lllyasviel/sd-controlnet-depth",
          "hinter": controlnet_hinter.hint_depth
+     },
+     "scribble": {
+         "model_id": "lllyasviel/sd-controlnet-scribble",
+         "hinter": controlnet_hinter.hint_scribble,
+     },
+     "segmentation": {
+         "model_id": "lllyasviel/sd-controlnet-seg",
+         "hinter": controlnet_hinter.hint_segmentation,
+     },
+     "normal": {
+         "model_id": "lllyasviel/sd-controlnet-normal",
+         "hinter": controlnet_hinter.hint_normal,
+     },
+     "hed": {
+         "model_id": "lllyasviel/sd-controlnet-hed",
+         "hinter": controlnet_hinter.hint_hed,
+     },
+     "hough": {
+         "model_id": "lllyasviel/sd-controlnet-mlsd",
+         "hinter": controlnet_hinter.hint_hough,
      }
  }

+
  class EndpointHandler():
      def __init__(self, path=""):
          # define default controlnet id and load controlnet
          self.control_type = "depth"
+         self.controlnet = ControlNetModel.from_pretrained(CONTROLNET_MAPPING[self.control_type]["model_id"],torch_dtype=dtype).to(device)
+
+         # Load StableDiffusionControlNetPipeline
          self.stable_diffusion_id = "runwayml/stable-diffusion-v1-5"
+         self.pipe = StableDiffusionControlNetPipeline.from_pretrained(self.stable_diffusion_id,
+                                                                       controlnet=self.controlnet,
                                                                        torch_dtype=dtype,
                                                                        safety_checker=None).to(device)
          # Define Generator with seed
          self.generator = torch.Generator(device="cpu").manual_seed(3)

+     def __call__(self, data: Any) -> Image.Image:
+         """
+         :param data: A dictionary containing the `inputs` prompt and optional `image` and `controlnet_type` fields.
+         :return: The first generated image as a PIL `Image`.
+         """
+         prompt = data.pop("inputs", None)
+         image = data.pop("image", None)
+         controlnet_type = data.pop("controlnet_type", None)

+         # Check if neither prompt nor image is provided
+         if prompt is None and image is None:
+             return {"error": "Please provide a prompt and base64 encoded image."}
+
+         # Check if a new controlnet is provided
+         if controlnet_type is not None and controlnet_type != self.control_type:
+             print(f"changing controlnet from {self.control_type} to {controlnet_type} using {CONTROLNET_MAPPING[controlnet_type]['model_id']} model")
+             self.control_type = controlnet_type
+             self.controlnet = ControlNetModel.from_pretrained(CONTROLNET_MAPPING[self.control_type]["model_id"],
+                                                               torch_dtype=dtype).to(device)
+             self.pipe.controlnet = self.controlnet
+
+
+         # hyperparameters
+         num_inference_steps = data.pop("num_inference_steps", 30)
+         guidance_scale = data.pop("guidance_scale", 7.5)
+         negative_prompt = data.pop("negative_prompt", None)
+         height = data.pop("height", None)
+         width = data.pop("width", None)
+         controlnet_conditioning_scale = data.pop("controlnet_conditioning_scale", 1.0)
+
+         # process image
+         image = self.decode_base64_image(image)
+         control_image = CONTROLNET_MAPPING[self.control_type]["hinter"](image)
+
+         # run inference pipeline
          out = self.pipe(
+             prompt=prompt,
              negative_prompt=negative_prompt,
+             image=control_image,
+             num_inference_steps=num_inference_steps,
              guidance_scale=guidance_scale,
              num_images_per_prompt=1,
              height=height,
              width=width,
+             controlnet_conditioning_scale=controlnet_conditioning_scale,
+             generator=self.generator
          )

+
+         # return the first generated PIL image
+         return out.images[0]
+
+     # helper to decode input image
      def decode_base64_image(self, image_string):
          base64_image = base64.b64decode(image_string)
          buffer = BytesIO(base64_image)
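
For reference, the updated `__call__` now returns the first generated PIL image directly (rather than a base64 payload as before), so the caller handles any base64 round-tripping of the result. A minimal local smoke test might look like the sketch below. It is an illustration only, not part of the commit: it assumes a CUDA GPU, network access to download the model weights, and a local `input.png`; the file names, prompt, and chosen `controlnet_type` are arbitrary.

import base64

from handler import EndpointHandler

# base64-encode a conditioning image, as __call__ expects in the "image" field
with open("input.png", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

handler = EndpointHandler()
result = handler({
    "inputs": "a beautiful landscape",  # text prompt, popped as "inputs"
    "image": image_b64,                 # base64 conditioning image
    "controlnet_type": "canny_edge",    # any key of CONTROLNET_MAPPING
    "num_inference_steps": 30,
    "guidance_scale": 7.5,
})

# __call__ returns out.images[0], a PIL image, so it can be saved directly
result.save("output.png")

Passing a `controlnet_type` other than the default `depth` exercises the reload branch in `__call__`, which fetches the corresponding ControlNet weights and swaps `self.pipe.controlnet` before running inference.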