import os
import json
import torch


def shard_state_dict(state_dict, output_dir=".", prefix="pytorch_model"):
    """Save every tensor of *state_dict* to its own ``.bin`` shard and build an index.

    One file is written per tensor, named ``{prefix}-00001-of-NNNNN.bin`` (NNNNN is
    the total number of tensors), plus an index file ``{prefix}.bin.index.json``
    mapping each state-dict key to the shard that holds it and recording the total
    byte size of all tensors.

    Args:
        state_dict: mapping of parameter name -> tensor (e.g. ``model.state_dict()``).
        output_dir: directory the shard and index files are written into
            (default: current directory, matching the original script's behavior).
        prefix: base name for the shard and index files.

    Returns:
        The index dict (``{"metadata": {"total_size": ...}, "weight_map": {...}}``).
    """
    index = {"metadata": {"total_size": 0}, "weight_map": {}}
    total_shards = len(state_dict)
    for shard_no, (key, tensor) in enumerate(state_dict.items(), start=1):
        shard_name = f"{prefix}-{shard_no:05d}-of-{total_shards:05d}.bin"
        torch.save({key: tensor}, os.path.join(output_dir, shard_name))
        # weight_map stores the bare file name (no directory), as loaders expect
        # paths relative to the checkpoint directory.
        index["weight_map"][key] = shard_name
        index["metadata"]["total_size"] += tensor.nelement() * tensor.element_size()
    # Convention for sharded checkpoints is "<weights name>.index.json"; the
    # original wrote "pytorch_model.bin.index" (no .json), which loaders that
    # look for the standard index name would miss.
    with open(os.path.join(output_dir, f"{prefix}.bin.index.json"), "w") as f:
        json.dump(index, f, indent=2)
    return index


def main():
    # NOTE(review): SomeLargeModel is an undefined placeholder in the original
    # script — replace with the real model class before running.
    model = SomeLargeModel('/mnt/e/ai_cache/output/wizardcoder_mmlu_2/merged')
    # map_location='cpu' so re-sharding the weights does not require the GPU
    # the checkpoint was saved from.
    model.load_state_dict(torch.load('pytorch_model.bin', map_location='cpu'))
    shard_state_dict(model.state_dict())


if __name__ == "__main__":
    main()