Spaces:
Runtime error
Runtime error
Update merged_files3.py
Browse files- merged_files3.py +1 -1
merged_files3.py
CHANGED
@@ -495,7 +495,7 @@ dtype = torch.float16 # Use float16 consistently for all models
|
|
495 |
sd_offset = sf.load_file(model_path) # Use device variable
|
496 |
sd_origin = unet.state_dict()
|
497 |
keys = sd_origin.keys()
|
498 |
-
|
499 |
del sd_offset, sd_origin, sd_merged, keys
|
500 |
|
501 |
|
|
|
495 |
sd_offset = sf.load_file(model_path) # Use device variable
|
496 |
sd_origin = unet.state_dict()
|
497 |
keys = sd_origin.keys()
|
498 |
+
sd_merged = {k: v.to(device) for k, v in sd_offset.items()} # Move each tensor to GPUunet.load_state_dict(sd_merged, strict=True)
|
499 |
del sd_offset, sd_origin, sd_merged, keys
|
500 |
|
501 |
|