chaojiemao commited on
Commit
47b17eb
·
verified ·
1 Parent(s): c599725

Update ace_inference.py

Browse files
Files changed (1) hide show
  1. ace_inference.py +1 -0
ace_inference.py CHANGED
@@ -147,6 +147,7 @@ class ACEInference(DiffusionInference):
147
  self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
148
  if self.ref_cond_stage_model is not None: self.dynamic_load(self.ref_cond_stage_model, 'ref_cond_stage_model')
149
  self.dynamic_load(self.diffusion_model, 'diffusion_model')
 
150
 
151
  def upscale_resize(self, image, interpolation=T.InterpolationMode.BILINEAR):
152
  c, H, W = image.shape
 
147
  self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
148
  if self.ref_cond_stage_model is not None: self.dynamic_load(self.ref_cond_stage_model, 'ref_cond_stage_model')
149
  self.dynamic_load(self.diffusion_model, 'diffusion_model')
150
+ self.diffusion_model["model"].to(torch.bfloat16)
151
 
152
  def upscale_resize(self, image, interpolation=T.InterpolationMode.BILINEAR):
153
  c, H, W = image.shape