hysts HF staff commited on
Commit
dd9a968
1 Parent(s): e189cea

Use mixed precision

Browse files
Files changed (1) hide show
  1. model.py +10 -3
model.py CHANGED
@@ -57,9 +57,13 @@ class Model:
57
  if base_model_id == self.base_model_id and task_name == self.task_name:
58
  return self.pipe
59
  model_id = CONTROLNET_MODEL_IDS[task_name]
60
- controlnet = ControlNetModel.from_pretrained(model_id)
 
61
  pipe = StableDiffusionControlNetPipeline.from_pretrained(
62
- base_model_id, safety_checker=None, controlnet=controlnet)
 
 
 
63
  pipe.scheduler = UniPCMultistepScheduler.from_config(
64
  pipe.scheduler.config)
65
  pipe.enable_xformers_memory_efficient_attention()
@@ -89,7 +93,9 @@ class Model:
89
  torch.cuda.empty_cache()
90
  gc.collect()
91
  model_id = CONTROLNET_MODEL_IDS[task_name]
92
- controlnet = ControlNetModel.from_pretrained(model_id).to(self.device)
 
 
93
  torch.cuda.empty_cache()
94
  gc.collect()
95
  self.pipe.controlnet = controlnet
@@ -102,6 +108,7 @@ class Model:
102
  prompt = f'{prompt}, {additional_prompt}'
103
  return prompt
104
 
 
105
  def run_pipe(
106
  self,
107
  prompt: str,
 
57
  if base_model_id == self.base_model_id and task_name == self.task_name:
58
  return self.pipe
59
  model_id = CONTROLNET_MODEL_IDS[task_name]
60
+ controlnet = ControlNetModel.from_pretrained(model_id,
61
+ torch_dtype=torch.float16)
62
  pipe = StableDiffusionControlNetPipeline.from_pretrained(
63
+ base_model_id,
64
+ safety_checker=None,
65
+ controlnet=controlnet,
66
+ torch_dtype=torch.float16)
67
  pipe.scheduler = UniPCMultistepScheduler.from_config(
68
  pipe.scheduler.config)
69
  pipe.enable_xformers_memory_efficient_attention()
 
93
  torch.cuda.empty_cache()
94
  gc.collect()
95
  model_id = CONTROLNET_MODEL_IDS[task_name]
96
+ controlnet = ControlNetModel.from_pretrained(model_id,
97
+ torch_dtype=torch.float16)
98
+ controlnet.to(self.device)
99
  torch.cuda.empty_cache()
100
  gc.collect()
101
  self.pipe.controlnet = controlnet
 
108
  prompt = f'{prompt}, {additional_prompt}'
109
  return prompt
110
 
111
+ @torch.autocast('cuda')
112
  def run_pipe(
113
  self,
114
  prompt: str,