|
import json |
|
from typing import Dict |
|
|
|
from safetensors.torch import load_file, save_file |
|
from huggingface_hub import split_torch_state_dict_into_shards |
|
import torch |
|
import os |
|
|
|
def save_state_dict(state_dict: Dict[str, torch.Tensor], save_directory: str):
    """Write ``state_dict`` to ``save_directory`` as one or more safetensors files.

    The tensors are partitioned with ``split_torch_state_dict_into_shards``
    using the filename pattern ``consolidated{suffix}.safetensors``. No shard
    size is passed here, so the library's default limit applies. Each shard is
    written with ``save_file``; when the split reports ``is_sharded``, a
    ``consolidated.safetensors.index.json`` file mapping every tensor name to
    its shard file is written next to the shards.

    Args:
        state_dict: Mapping from tensor name to tensor.
        save_directory: Directory that receives the shard files (and the index
            file, if any). It is created when it does not exist yet.
    """
    # An empty string means "current directory" for os.path.join below, but
    # os.makedirs("") raises, so only create a directory that is actually named.
    if save_directory:
        os.makedirs(save_directory, exist_ok=True)

    state_dict_split = split_torch_state_dict_into_shards(
        state_dict,
        filename_pattern='consolidated{suffix}.safetensors',
    )

    for filename, tensors in state_dict_split.filename_to_tensors.items():
        # The split only carries tensor names; fetch the tensors themselves
        # from the original mapping.
        shard = {tensor: state_dict[tensor] for tensor in tensors}
        print("Saving", save_directory, filename)
        save_file(shard, os.path.join(save_directory, filename))

    # The index is only needed when more than one file was produced; it is
    # written once, after all shards are on disk.
    if state_dict_split.is_sharded:
        index = {
            "metadata": state_dict_split.metadata,
            "weight_map": state_dict_split.tensor_to_filename,
        }
        index_path = os.path.join(save_directory, "consolidated.safetensors.index.json")
        # Explicit encoding keeps the JSON independent of the platform locale.
        with open(index_path, "w", encoding="utf-8") as f:
            json.dump(index, f, indent=2)
|
|
|
# Checkpoint to re-save; the path is resolved against the current working
# directory.
big_file = 'consolidated.safetensors'

# Load the checkpoint's tensors so they can be handed to save_state_dict.
loaded = load_file(big_file)

# Write the tensors back out into the current directory.
# NOTE(review): save_state_dict names its output 'consolidated{suffix}.safetensors';
# if the split produces a single file with an empty suffix, the output path
# would be the same as big_file -- confirm that overwriting the input is intended.
save_state_dict(loaded, save_directory='.')
|
|