import json | |
from typing import Dict | |
from safetensors.torch import load_file, save_file | |
from huggingface_hub import split_torch_state_dict_into_shards | |
import torch | |
import os | |
def save_state_dict(state_dict: Dict[str, torch.Tensor], save_directory: str):
    """Write ``state_dict`` into ``save_directory`` as safetensors shard files.

    The tensor names are partitioned by ``split_torch_state_dict_into_shards``
    using the filename pattern ``consolidated{suffix}.safetensors``. Each
    group is written with ``save_file``. When the split reports that it is
    sharded, an index file ``consolidated.safetensors.index.json`` holding the
    split's metadata and the tensor-name -> filename map is written alongside.

    Args:
        state_dict: Mapping from tensor name to tensor.
        save_directory: Directory that receives the files. It is created if
            it does not exist; an empty string means the current directory.
    """
    # Create the target directory up front so the first write does not fail
    # on a fresh output path. An empty path is skipped because
    # os.makedirs("") raises, while os.path.join("", name) is still valid.
    if save_directory:
        os.makedirs(save_directory, exist_ok=True)

    state_dict_split = split_torch_state_dict_into_shards(
        state_dict,
        filename_pattern='consolidated{suffix}.safetensors',
    )

    for filename, tensors in state_dict_split.filename_to_tensors.items():
        # Collect only the tensors assigned to this shard file.
        shard = {tensor: state_dict[tensor] for tensor in tensors}
        print("Saving", save_directory, filename)
        save_file(shard, os.path.join(save_directory, filename))

    if state_dict_split.is_sharded:
        # The index is only written when there is more than one file to map.
        index = {
            "metadata": state_dict_split.metadata,
            "weight_map": state_dict_split.tensor_to_filename,
        }
        index_path = os.path.join(save_directory, "consolidated.safetensors.index.json")
        with open(index_path, "w", encoding="utf-8") as f:
            json.dump(index, f, indent=2)
# Script body: read the single consolidated checkpoint from the current
# working directory and write it back out through save_state_dict.
# NOTE(review): input and output share the same directory; presumably the
# input is large enough to be split into differently named shard files, so
# the source file is not overwritten — confirm.
big_file = "consolidated.safetensors"
loaded = load_file(big_file)
save_state_dict(loaded, save_directory=".")