"""Time two ways of getting a safetensors checkpoint onto GPU 0.

Usage:
    python <script> 1    # ask safetensors to load with device=0
    python <script> 0    # load with the default device, then move each
                         # tensor to "cuda:0" with ``Tensor.to``

The elapsed time of the selected strategy is printed in seconds.
"""

import sys
import time

from safetensors.torch import load_file as safe_load_file

# Checkpoint used for the measurement; shared by both strategies so they
# are guaranteed to read the same file.
CHECKPOINT_PATH = "/home/patrick_huggingface_co/stable-diffusion-v1-4/unet/diffusion_pytorch_model.safetensors"

# Fail with a usage hint instead of a bare IndexError when the flag is missing.
if len(sys.argv) < 2:
    sys.exit(
        f"usage: {sys.argv[0]} <0|1>  "
        "(1 = load directly on GPU, 0 = load first, then move to GPU)"
    )

# Any non-zero integer selects the direct-to-GPU strategy.
direct_on_gpu = bool(int(sys.argv[1]))

# perf_counter is monotonic, so the measured interval cannot be skewed by
# system clock adjustments the way time.time() differences can.
start_time = time.perf_counter()

if direct_on_gpu:
    checkpoint = safe_load_file(CHECKPOINT_PATH, device=0)
    label = "Directly on GPU"
else:
    checkpoint = safe_load_file(CHECKPOINT_PATH)
    # The move to the GPU is part of what is being timed for this strategy.
    checkpoint = {k: v.to("cuda:0") for k, v in checkpoint.items()}
    label = "On CPU"

print(label, time.perf_counter() - start_time)