Spaces:
Runtime error
Runtime error
cocktailpeanut
commited on
Commit
•
ca3008d
1
Parent(s):
6704916
update
Browse files- src/pix2pix_turbo.py +2 -1
src/pix2pix_turbo.py
CHANGED
@@ -109,7 +109,8 @@ class Pix2Pix_Turbo(torch.nn.Module):
|
|
109 |
_sd_unet = unet.state_dict()
|
110 |
for k in sd["state_dict_unet"]: _sd_unet[k] = sd["state_dict_unet"][k]
|
111 |
unet.load_state_dict(_sd_unet)
|
112 |
-
|
|
|
113 |
_sd_vae = vae.state_dict()
|
114 |
for k in sd["state_dict_vae"]: _sd_vae[k] = sd["state_dict_vae"][k]
|
115 |
vae.load_state_dict(_sd_vae)
|
|
|
109 |
_sd_unet = unet.state_dict()
|
110 |
for k in sd["state_dict_unet"]: _sd_unet[k] = sd["state_dict_unet"][k]
|
111 |
unet.load_state_dict(_sd_unet)
|
112 |
+
if device == "cuda":
|
113 |
+
unet.enable_xformers_memory_efficient_attention()
|
114 |
_sd_vae = vae.state_dict()
|
115 |
for k in sd["state_dict_vae"]: _sd_vae[k] = sd["state_dict_vae"][k]
|
116 |
vae.load_state_dict(_sd_vae)
|