williamberman
commited on
Commit
•
947a0fa
1
Parent(s):
9dd235f
default device placement
Browse files- sdxl_models.py +6 -6
sdxl_models.py
CHANGED
@@ -20,7 +20,7 @@ class ModelUtils:
|
|
20 |
return next(self.parameters()).device
|
21 |
|
22 |
@classmethod
|
23 |
-
def load(cls, load_from: str, device, overrides: Optional[Union[str, List[str]]] = None):
|
24 |
import load_state_dict_patch
|
25 |
|
26 |
load_from = [load_from]
|
@@ -221,15 +221,15 @@ class SDXLVae(nn.Module, ModelUtils):
|
|
221 |
return x_pred
|
222 |
|
223 |
@classmethod
|
224 |
-
def load_fp32(cls, device=
|
225 |
return cls.load("./weights/sdxl_vae.safetensors", device=device, overrides=overrides)
|
226 |
|
227 |
@classmethod
|
228 |
-
def load_fp16(cls, device=
|
229 |
return cls.load("./weights/sdxl_vae.fp16.safetensors", device=device, overrides=overrides)
|
230 |
|
231 |
@classmethod
|
232 |
-
def load_fp16_fix(cls, device=
|
233 |
return cls.load("./weights/sdxl_vae_fp16_fix.safetensors", device=device, overrides=overrides)
|
234 |
|
235 |
|
@@ -449,11 +449,11 @@ class SDXLUNet(nn.Module, ModelUtils):
|
|
449 |
return eps_hat
|
450 |
|
451 |
@classmethod
|
452 |
-
def load_fp32(cls, device=
|
453 |
return cls.load("./weights/sdxl_unet.safetensors", device=device, overrides=overrides)
|
454 |
|
455 |
@classmethod
|
456 |
-
def load_fp16(cls, device=
|
457 |
return cls.load("./weights/sdxl_unet.fp16.safetensors", device=device, overrides=overrides)
|
458 |
|
459 |
|
|
|
20 |
return next(self.parameters()).device
|
21 |
|
22 |
@classmethod
|
23 |
+
def load(cls, load_from: str, device='cpu', overrides: Optional[Union[str, List[str]]] = None):
|
24 |
import load_state_dict_patch
|
25 |
|
26 |
load_from = [load_from]
|
|
|
221 |
return x_pred
|
222 |
|
223 |
@classmethod
|
224 |
+
def load_fp32(cls, device='cpu', overrides=None):
|
225 |
return cls.load("./weights/sdxl_vae.safetensors", device=device, overrides=overrides)
|
226 |
|
227 |
@classmethod
|
228 |
+
def load_fp16(cls, device='cpu', overrides=None):
|
229 |
return cls.load("./weights/sdxl_vae.fp16.safetensors", device=device, overrides=overrides)
|
230 |
|
231 |
@classmethod
|
232 |
+
def load_fp16_fix(cls, device='cpu', overrides=None):
|
233 |
return cls.load("./weights/sdxl_vae_fp16_fix.safetensors", device=device, overrides=overrides)
|
234 |
|
235 |
|
|
|
449 |
return eps_hat
|
450 |
|
451 |
@classmethod
|
452 |
+
def load_fp32(cls, device='cpu', overrides=None):
|
453 |
return cls.load("./weights/sdxl_unet.safetensors", device=device, overrides=overrides)
|
454 |
|
455 |
@classmethod
|
456 |
+
def load_fp16(cls, device='cpu', overrides=None):
|
457 |
return cls.load("./weights/sdxl_unet.fp16.safetensors", device=device, overrides=overrides)
|
458 |
|
459 |
|