# tools/safetensors_bench.py
# Last commit: ea151b9 ("add") by patrickvonplaten; 660 bytes as originally published.
#!/usr/bin/env python3
import sys
import time

import torch
from safetensors.torch import load_file as safe_load_file
# Default checkpoint benchmarked when no path is given on the command line.
CHECKPOINT_PATH = (
    "/home/patrick_huggingface_co/stable-diffusion-v1-4/unet/"
    "diffusion_pytorch_model.safetensors"
)
# Single spelling of the target device, shared by both load paths.
DEVICE = "cuda:0"


def parse_args(argv):
    """Parse ``argv`` into ``(direct_on_gpu, checkpoint_path)``.

    ``argv[1]`` is an integer flag: non-zero loads the checkpoint straight onto
    the GPU, zero loads it on the CPU and then copies it over. ``argv[2]``, if
    present, overrides ``CHECKPOINT_PATH``.

    Raises:
        SystemExit: with a usage message if the flag is missing or is not an
            integer.
    """
    usage = "usage: safetensors_bench.py <0|1> [path/to/model.safetensors]"
    if len(argv) < 2:
        raise SystemExit(usage)
    try:
        direct_on_gpu = bool(int(argv[1]))
    except ValueError:
        raise SystemExit(usage) from None
    path = argv[2] if len(argv) > 2 else CHECKPOINT_PATH
    return direct_on_gpu, path


def _load(path, direct_on_gpu):
    """Load the safetensors file at ``path`` with every tensor placed on ``DEVICE``."""
    if direct_on_gpu:
        # Let safetensors place the tensors on the GPU itself.
        return safe_load_file(path, device=DEVICE)
    # Load into host memory first, then copy each tensor to the GPU.
    state_dict = safe_load_file(path)
    return {name: tensor.to(DEVICE) for name, tensor in state_dict.items()}


def main(argv=None):
    """Time loading a safetensors checkpoint onto ``DEVICE`` and print the result.

    Args:
        argv: command-line arguments in ``sys.argv`` form; defaults to
            ``sys.argv``. See ``parse_args`` for the accepted values.

    Returns:
        The loaded state dict, so interactive callers can still inspect it.
    """
    direct_on_gpu, path = parse_args(sys.argv if argv is None else argv)

    # Create the CUDA context before the clock starts so its one-off start-up
    # cost is not charged to the load path being measured.
    torch.ones(1, device=DEVICE)
    torch.cuda.synchronize(DEVICE)

    start = time.perf_counter()
    state_dict = _load(path, direct_on_gpu)
    # Copies from pageable host memory may return before the DMA to the device
    # has finished; wait for the device so the timing covers the full transfer.
    torch.cuda.synchronize(DEVICE)
    elapsed = time.perf_counter() - start

    print("Directly on GPU" if direct_on_gpu else "On CPU", elapsed)
    return state_dict


if __name__ == "__main__":
    main()