Ming Li commited on
Commit
68a99f3
1 Parent(s): acbc9f8
Files changed (1) hide show
  1. model.py +3 -3
model.py CHANGED
@@ -53,9 +53,9 @@ class Model:
53
  ):
54
  return self.pipe
55
  model_id = CONTROLNET_MODEL_IDS[task_name]
56
- controlnet = ControlNetModel.from_pretrained(model_id, torch_dtype=torch.float16)
57
  pipe = StableDiffusionControlNetPipeline.from_pretrained(
58
- base_model_id, safety_checker=None, controlnet=controlnet, torch_dtype=torch.float16
59
  )
60
  pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
61
  # if self.device.type == "cuda":
@@ -88,7 +88,7 @@ class Model:
88
  torch.cuda.empty_cache()
89
  gc.collect()
90
  model_id = CONTROLNET_MODEL_IDS[task_name]
91
- controlnet = ControlNetModel.from_pretrained(model_id, torch_dtype=torch.float16)
92
  controlnet.to(self.device)
93
  torch.cuda.empty_cache()
94
  gc.collect()
 
53
  ):
54
  return self.pipe
55
  model_id = CONTROLNET_MODEL_IDS[task_name]
56
+ controlnet = ControlNetModel.from_pretrained(model_id, torch_dtype=torch.float32)
57
  pipe = StableDiffusionControlNetPipeline.from_pretrained(
58
+ base_model_id, safety_checker=None, controlnet=controlnet, torch_dtype=torch.float32
59
  )
60
  pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
61
  # if self.device.type == "cuda":
 
88
  torch.cuda.empty_cache()
89
  gc.collect()
90
  model_id = CONTROLNET_MODEL_IDS[task_name]
91
+ controlnet = ControlNetModel.from_pretrained(model_id, torch_dtype=torch.float32)
92
  controlnet.to(self.device)
93
  torch.cuda.empty_cache()
94
  gc.collect()