import torch
from safetensors.torch import load_file, save_file
# Shard files that together make up the checkpoint, merged in this order.
safetensor_files = [
    "model-00001-of-00003.safetensors",
    "model-00002-of-00003.safetensors",
    "model-00003-of-00003.safetensors",
]


def merge_shards(shards):
    """Combine per-shard weight mappings into a single dict.

    Args:
        shards: Iterable of mappings (tensor name -> tensor), one per shard
            file. It is consumed lazily, so a generator keeps only one
            shard loaded at a time on top of the merged result.

    Returns:
        A new dict containing every entry of every shard.

    Raises:
        ValueError: If a tensor name occurs in more than one shard. A plain
            ``dict.update`` would silently keep only the last occurrence
            and produce an incomplete or inconsistent checkpoint.
    """
    merged = {}
    for shard in shards:
        duplicates = merged.keys() & shard.keys()
        if duplicates:
            raise ValueError(
                f"Duplicate tensor names across shards: {sorted(duplicates)}"
            )
        merged.update(shard)
    return merged


def main():
    """Load every shard, merge them and write one pytorch_model.bin file."""
    # Generator expression: each shard is loaded only when merge_shards
    # asks for it, rather than all shards being loaded up front.
    merged_weights = merge_shards(load_file(path) for path in safetensor_files)

    # Save the merged weights into a single pytorch_model.bin file.
    torch.save(merged_weights, "pytorch_model.bin")
    print("Merged weights saved to pytorch_model.bin")


if __name__ == "__main__":
    main()