charleselena commited on
Commit
c32757f
1 Parent(s): a3fab79

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +28 -5
handler.py CHANGED
@@ -3,6 +3,10 @@ import base64
3
  from PIL import Image
4
  from io import BytesIO
5
  from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
 
 
 
 
6
  import torch
7
 
8
 
@@ -16,7 +20,7 @@ if device.type != 'cuda':
16
  raise ValueError("need to run on GPU")
17
  # set mixed precision dtype
18
  dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] == 8 else torch.float16
19
-
20
  # controlnet mapping for controlnet id and control hinter
21
  CONTROLNET_MAPPING = {
22
  "canny_edge": {
@@ -59,16 +63,35 @@ class EndpointHandler():
59
  # define default controlnet id and load controlnet
60
  self.control_type = "depth"
61
  self.controlnet = ControlNetModel.from_pretrained(CONTROLNET_MAPPING[self.control_type]["model_id"],torch_dtype=dtype).to(device)
 
 
 
62
 
63
  # Load StableDiffusionControlNetPipeline
64
  #self.stable_diffusion_id = "runwayml/stable-diffusion-v1-5"
65
  self.stable_diffusion_id = "Lykon/dreamshaper-8"
66
- self.pipe = StableDiffusionControlNetPipeline.from_pretrained(self.stable_diffusion_id,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  controlnet=self.controlnet,
68
  torch_dtype=dtype,
69
- safety_checker=None).to(device)
70
  # Define Generator with seed
71
- self.generator = torch.Generator(device="cpu").manual_seed(3)
72
 
73
  def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
74
  """
@@ -104,7 +127,7 @@ class EndpointHandler():
104
  # process image
105
  image = self.decode_base64_image(image)
106
  #control_image = CONTROLNET_MAPPING[self.control_type]["hinter"](image)
107
-
108
  # run inference pipeline
109
  out = self.pipe(
110
  prompt=prompt,
 
3
  from PIL import Image
4
  from io import BytesIO
5
  from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
6
+ #from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, StableDiffusionSafetyChecker
7
+ # import Safety Checker
8
+ from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
9
+
10
  import torch
11
 
12
 
 
20
  raise ValueError("need to run on GPU")
21
  # set mixed precision dtype
22
  dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] == 8 else torch.float16
23
+
24
  # controlnet mapping for controlnet id and control hinter
25
  CONTROLNET_MAPPING = {
26
  "canny_edge": {
 
63
  # define default controlnet id and load controlnet
64
  self.control_type = "depth"
65
  self.controlnet = ControlNetModel.from_pretrained(CONTROLNET_MAPPING[self.control_type]["model_id"],torch_dtype=dtype).to(device)
66
+
67
+ #processor = AutoProcessor.from_pretrained("CompVis/stable-diffusion-safety-checker")
68
+
69
 
70
  # Load StableDiffusionControlNetPipeline
71
  #self.stable_diffusion_id = "runwayml/stable-diffusion-v1-5"
72
  self.stable_diffusion_id = "Lykon/dreamshaper-8"
73
+ # self.pipe = StableDiffusionControlNetPipeline.from_pretrained(self.stable_diffusion_id,
74
+ # controlnet=self.controlnet,
75
+ # torch_dtype=dtype,
76
+ # #safety_checker=None).to(device)
77
+ # #processor = AutoProcessor.from_pretrained("CompVis/stable-diffusion-safety-checker")
78
+ # #safety_checker = SafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
79
+ # safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
80
+
81
+ # self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
82
+ # self.stable_diffusion_id,
83
+ # controlnet=self.controlnet,
84
+ # torch_dtype=dtype,
85
+ # safety_checker = SafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
86
+ # ).to(device)
87
+
88
+
89
+ self.pipe = StableDiffusionControlNetPipeline.from_pretrained(self.stable_diffusion_id,
90
  controlnet=self.controlnet,
91
  torch_dtype=dtype,
92
+ safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker", torch_dtype=torch.float16)).to("cuda")
93
  # Define Generator with seed
94
+ self.generator = torch.Generator(device=device.type).manual_seed(3)
95
 
96
  def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
97
  """
 
127
  # process image
128
  image = self.decode_base64_image(image)
129
  #control_image = CONTROLNET_MAPPING[self.control_type]["hinter"](image)
130
+
131
  # run inference pipeline
132
  out = self.pipe(
133
  prompt=prompt,