import json
import os
from typing import Dict

import torch
from huggingface_hub import split_torch_state_dict_into_shards
from safetensors.torch import load_file, save_file


def _save_file_atomically(shard: Dict[str, torch.Tensor], path: str) -> None:
    """Write *shard* to *path* via a temporary file plus ``os.replace``.

    Writing straight to *path* truncates any existing file first. Here the
    output name can equal the file the tensors were loaded from (non-sharded
    case), and ``load_file`` may hand back tensors that still reference that
    file through a memory map (NOTE(review): depends on the safetensors
    version — confirm). Replacing the directory entry instead of truncating
    in place keeps the source readable until the new file is complete. It
    also avoids leaving a half-written shard behind on failure.
    """
    tmp_path = path + ".tmp"
    try:
        save_file(shard, tmp_path)
        os.replace(tmp_path, path)
    except BaseException:
        # Do not leave a stray partial temp file around; re-raise the error.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


def save_state_dict(state_dict: Dict[str, torch.Tensor], save_directory: str) -> None:
    """Save *state_dict* to *save_directory* as one or more safetensors files.

    The tensors are split with ``split_torch_state_dict_into_shards`` using
    the filename pattern ``consolidated{suffix}.safetensors``. Each resulting
    file is written with ``save_file``. When the split produced more than one
    shard, an index ``consolidated.safetensors.index.json`` is also written.
    It contains the split's ``metadata`` and a ``weight_map`` from tensor name
    to shard filename.

    Args:
        state_dict: Mapping of tensor name to tensor.
        save_directory: Output directory; created if it does not exist.
    """
    os.makedirs(save_directory, exist_ok=True)

    state_dict_split = split_torch_state_dict_into_shards(
        state_dict, filename_pattern="consolidated{suffix}.safetensors"
    )
    for filename, tensors in state_dict_split.filename_to_tensors.items():
        shard = {tensor: state_dict[tensor] for tensor in tensors}
        print("Saving", save_directory, filename)
        _save_file_atomically(shard, os.path.join(save_directory, filename))

    if state_dict_split.is_sharded:
        index = {
            "metadata": state_dict_split.metadata,
            "weight_map": state_dict_split.tensor_to_filename,
        }
        index_path = os.path.join(save_directory, "consolidated.safetensors.index.json")
        with open(index_path, "w", encoding="utf-8") as f:
            json.dump(index, f, indent=2)


if __name__ == "__main__":
    # Re-shard a single consolidated checkpoint in the current directory.
    big_file = "consolidated.safetensors"
    loaded = load_file(big_file)
    save_state_dict(loaded, save_directory=".")