saassa commited on
Commit
39e9009
1 Parent(s): fe2c8f7

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +29 -23
handler.py CHANGED
@@ -1,9 +1,12 @@
 
1
  from typing import Dict, List, Any
2
  import base64
3
  from PIL import Image
4
  from io import BytesIO
5
- from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
6
  import torch
 
 
7
 
8
 
9
  import numpy as np
@@ -53,21 +56,22 @@ CONTROLNET_MAPPING = {
53
  }
54
  }
55
 
56
-
57
  class EndpointHandler():
58
  def __init__(self, path=""):
59
  # define default controlnet id and load controlnet
60
  self.control_type = "normal"
61
  self.controlnet = ControlNetModel.from_pretrained(CONTROLNET_MAPPING[self.control_type]["model_id"],torch_dtype=dtype).to(device)
62
-
63
- # Load StableDiffusionControlNetPipeline
64
- self.stable_diffusion_id = "runwayml/stable-diffusion-v1-5"
65
- self.pipe = StableDiffusionControlNetPipeline.from_pretrained(self.stable_diffusion_id,
66
- controlnet=self.controlnet,
67
- torch_dtype=dtype,
68
  safety_checker=None).to(device)
 
69
  # Define Generator with seed
70
- self.generator = torch.Generator(device="cpu").manual_seed(3)
71
 
72
  def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
73
  """
@@ -77,11 +81,11 @@ class EndpointHandler():
77
  prompt = data.pop("inputs", None)
78
  image = data.pop("image", None)
79
  controlnet_type = data.pop("controlnet_type", None)
80
-
81
  # Check if neither prompt nor image is provided
82
  if prompt is None and image is None:
83
  return {"error": "Please provide a prompt and base64 encoded image."}
84
-
85
  # Check if a new controlnet is provided
86
  if controlnet_type is not None and controlnet_type != self.control_type:
87
  print(f"changing controlnet from {self.control_type} to {controlnet_type} using {CONTROLNET_MAPPING[controlnet_type]['model_id']} model")
@@ -89,41 +93,43 @@ class EndpointHandler():
89
  self.controlnet = ControlNetModel.from_pretrained(CONTROLNET_MAPPING[self.control_type]["model_id"],
90
  torch_dtype=dtype).to(device)
91
  self.pipe.controlnet = self.controlnet
92
-
93
-
94
  # hyperparamters
95
- num_inference_steps = data.pop("num_inference_steps", 30)
96
- guidance_scale = data.pop("guidance_scale", 7.5)
 
97
  negative_prompt = data.pop("negative_prompt", None)
98
  height = data.pop("height", None)
99
  width = data.pop("width", None)
100
  controlnet_conditioning_scale = data.pop("controlnet_conditioning_scale", 1.0)
101
-
102
  # process image
103
  image = self.decode_base64_image(image)
104
  control_image = CONTROLNET_MAPPING[self.control_type]["hinter"](image)
105
-
106
  # run inference pipeline
107
  out = self.pipe(
108
- prompt=prompt,
109
  negative_prompt=negative_prompt,
110
  image=control_image,
111
- num_inference_steps=num_inference_steps,
112
  guidance_scale=guidance_scale,
113
  num_images_per_prompt=1,
114
  height=height,
115
  width=width,
116
  controlnet_conditioning_scale=controlnet_conditioning_scale,
117
- generator=self.generator
 
118
  )
119
 
120
-
121
  # return first generate PIL image
122
  return out.images[0]
123
-
124
  # helper to decode input image
125
  def decode_base64_image(self, image_string):
126
  base64_image = base64.b64decode(image_string)
127
  buffer = BytesIO(base64_image)
128
  image = Image.open(buffer)
129
- return image
 
1
+ %%writefile handler.py
2
  from typing import Dict, List, Any
3
  import base64
4
  from PIL import Image
5
  from io import BytesIO
6
+ from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, AutoencoderKL, StableDiffusionXLControlNetPipeline
7
  import torch
8
+ from diffusers.utils import load_image
9
+
10
 
11
 
12
  import numpy as np
 
56
  }
57
  }
58
 
59
+
60
  class EndpointHandler():
61
  def __init__(self, path=""):
62
  # define default controlnet id and load controlnet
63
  self.control_type = "normal"
64
  self.controlnet = ControlNetModel.from_pretrained(CONTROLNET_MAPPING[self.control_type]["model_id"],torch_dtype=dtype).to(device)
65
+
66
+ # Load StableDiffusionControlNetPipeline
67
+ self.stable_diffusion_id = "stablediffusionapi/disney-pixar-cartoon"
68
+ self.pipe = StableDiffusionControlNetPipeline.from_pretrained(self.stable_diffusion_id,
69
+ controlnet=self.controlnet,
70
+ torch_dtype=dtype,
71
  safety_checker=None).to(device)
72
+
73
  # Define Generator with seed
74
+ # COMMENTED self.generator = torch.Generator(device="cpu").manual_seed(3)
75
 
76
  def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
77
  """
 
81
  prompt = data.pop("inputs", None)
82
  image = data.pop("image", None)
83
  controlnet_type = data.pop("controlnet_type", None)
84
+
85
  # Check if neither prompt nor image is provided
86
  if prompt is None and image is None:
87
  return {"error": "Please provide a prompt and base64 encoded image."}
88
+
89
  # Check if a new controlnet is provided
90
  if controlnet_type is not None and controlnet_type != self.control_type:
91
  print(f"changing controlnet from {self.control_type} to {controlnet_type} using {CONTROLNET_MAPPING[controlnet_type]['model_id']} model")
 
93
  self.controlnet = ControlNetModel.from_pretrained(CONTROLNET_MAPPING[self.control_type]["model_id"],
94
  torch_dtype=dtype).to(device)
95
  self.pipe.controlnet = self.controlnet
96
+
97
+
98
  # hyperparamters
99
+ negatice_prompt = data.pop("negative_prompt", None)
100
+ num_inference_steps = data.pop("num_inference_steps", 150)
101
+ guidance_scale = data.pop("guidance_scale", 5)
102
  negative_prompt = data.pop("negative_prompt", None)
103
  height = data.pop("height", None)
104
  width = data.pop("width", None)
105
  controlnet_conditioning_scale = data.pop("controlnet_conditioning_scale", 1.0)
106
+
107
  # process image
108
  image = self.decode_base64_image(image)
109
  control_image = CONTROLNET_MAPPING[self.control_type]["hinter"](image)
110
+
111
  # run inference pipeline
112
  out = self.pipe(
113
+ prompt=prompt,
114
  negative_prompt=negative_prompt,
115
  image=control_image,
116
+ num_inference_steps=num_inference_steps,
117
  guidance_scale=guidance_scale,
118
  num_images_per_prompt=1,
119
  height=height,
120
  width=width,
121
  controlnet_conditioning_scale=controlnet_conditioning_scale,
122
+ guess_mode=True,
123
+
124
  )
125
 
126
+ #generator=self.generator COMMENTED from self.pipe
127
  # return first generate PIL image
128
  return out.images[0]
129
+
130
  # helper to decode input image
131
  def decode_base64_image(self, image_string):
132
  base64_image = base64.b64decode(image_string)
133
  buffer = BytesIO(base64_image)
134
  image = Image.open(buffer)
135
+ return image