charleselena commited on
Commit
bacee68
1 Parent(s): 212c508

remove dtype

Browse files
Files changed (1) hide show
  1. handler.py +10 -4
handler.py CHANGED
@@ -76,17 +76,23 @@ class EndpointHandler():
76
  # #safety_checker = SafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
77
  # safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
78
 
 
 
 
 
 
 
 
 
79
  self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
80
  self.stable_diffusion_id,
81
  controlnet=self.controlnet,
82
- torch_dtype=dtype,
83
  safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
84
- )
85
 
86
 
87
  # Define Generator with seed
88
- # self.generator = torch.Generator(device="cpu").manual_seed(3)
89
- self.generator = torch.Generator(device="cuda").manual_seed(3)
90
 
91
  def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
92
  """
 
76
  # #safety_checker = SafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
77
  # safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
78
 
79
+ # self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
80
+ # self.stable_diffusion_id,
81
+ # controlnet=self.controlnet,
82
+ # torch_dtype=dtype,
83
+ # safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
84
+ # ).to(device)
85
+
86
+
87
  self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
88
  self.stable_diffusion_id,
89
  controlnet=self.controlnet,
 
90
  safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
91
+ ).to(device)
92
 
93
 
94
  # Define Generator with seed
95
+ self.generator = torch.Generator(device="cpu").manual_seed(3)
 
96
 
97
  def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
98
  """