import os
import torch
import json
# load the merged model (SomeLargeModel stands in for your model class)
model = SomeLargeModel('/mnt/e/ai_cache/output/wizardcoder_mmlu_2/merged')
model.load_state_dict(torch.load('pytorch_model.bin', map_location='cpu'))  # load on CPU to avoid exhausting GPU memory
# save each tensor to its own shard file and record the key -> file mapping in the index
state_dict = model.state_dict()
index = {"metadata": {"total_size": 0}, "weight_map": {}}
i = 1
total_files = len(state_dict.keys())
for key, tensor in state_dict.items():
    chunk_file = f'pytorch_model-{str(i).zfill(5)}-of-{str(total_files).zfill(5)}.bin'
    torch.save({key: tensor}, chunk_file)
    index["weight_map"][key] = chunk_file
    # total_size is tracked in bytes (element count * bytes per element)
    index["metadata"]["total_size"] += tensor.nelement() * tensor.element_size()
    i += 1
# save the index (transformers expects the sharded index to be named pytorch_model.bin.index.json)
with open('pytorch_model.bin.index.json', 'w') as f:
    json.dump(index, f)
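
# A minimal sketch of reading the shards back, assuming the index and shard
# files written above sit in the current working directory: rebuild the full
# state_dict by loading every shard listed in the weight_map, then load it
# into the model. Variable names here (loaded_index, reassembled) are
# illustrative only.
with open('pytorch_model.bin.index.json') as f:
    loaded_index = json.load(f)
reassembled = {}
for shard_file in set(loaded_index["weight_map"].values()):
    reassembled.update(torch.load(shard_file, map_location='cpu'))
model.load_state_dict(reassembled)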