williamberman commited on
Commit
947a0fa
1 Parent(s): 9dd235f

default device placement

Browse files
Files changed (1) hide show
  1. sdxl_models.py +6 -6
sdxl_models.py CHANGED
@@ -20,7 +20,7 @@ class ModelUtils:
20
  return next(self.parameters()).device
21
 
22
  @classmethod
23
- def load(cls, load_from: str, device, overrides: Optional[Union[str, List[str]]] = None):
24
  import load_state_dict_patch
25
 
26
  load_from = [load_from]
@@ -221,15 +221,15 @@ class SDXLVae(nn.Module, ModelUtils):
221
  return x_pred
222
 
223
  @classmethod
224
- def load_fp32(cls, device=None, overrides=None):
225
  return cls.load("./weights/sdxl_vae.safetensors", device=device, overrides=overrides)
226
 
227
  @classmethod
228
- def load_fp16(cls, device=None, overrides=None):
229
  return cls.load("./weights/sdxl_vae.fp16.safetensors", device=device, overrides=overrides)
230
 
231
  @classmethod
232
- def load_fp16_fix(cls, device=None, overrides=None):
233
  return cls.load("./weights/sdxl_vae_fp16_fix.safetensors", device=device, overrides=overrides)
234
 
235
 
@@ -449,11 +449,11 @@ class SDXLUNet(nn.Module, ModelUtils):
449
  return eps_hat
450
 
451
  @classmethod
452
- def load_fp32(cls, device=None, overrides=None):
453
  return cls.load("./weights/sdxl_unet.safetensors", device=device, overrides=overrides)
454
 
455
  @classmethod
456
- def load_fp16(cls, device=None, overrides=None):
457
  return cls.load("./weights/sdxl_unet.fp16.safetensors", device=device, overrides=overrides)
458
 
459
 
 
20
  return next(self.parameters()).device
21
 
22
  @classmethod
23
+ def load(cls, load_from: str, device='cpu', overrides: Optional[Union[str, List[str]]] = None):
24
  import load_state_dict_patch
25
 
26
  load_from = [load_from]
 
221
  return x_pred
222
 
223
  @classmethod
224
+ def load_fp32(cls, device='cpu', overrides=None):
225
  return cls.load("./weights/sdxl_vae.safetensors", device=device, overrides=overrides)
226
 
227
  @classmethod
228
+ def load_fp16(cls, device='cpu', overrides=None):
229
  return cls.load("./weights/sdxl_vae.fp16.safetensors", device=device, overrides=overrides)
230
 
231
  @classmethod
232
+ def load_fp16_fix(cls, device='cpu', overrides=None):
233
  return cls.load("./weights/sdxl_vae_fp16_fix.safetensors", device=device, overrides=overrides)
234
 
235
 
 
449
  return eps_hat
450
 
451
  @classmethod
452
+ def load_fp32(cls, device='cpu', overrides=None):
453
  return cls.load("./weights/sdxl_unet.safetensors", device=device, overrides=overrides)
454
 
455
  @classmethod
456
+ def load_fp16(cls, device='cpu', overrides=None):
457
  return cls.load("./weights/sdxl_unet.fp16.safetensors", device=device, overrides=overrides)
458
 
459