File size: 791 Bytes
a8048d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import os
import torch
import json


def shard_checkpoint(state_dict, output_dir='.'):
    """Split *state_dict* into one shard file per tensor plus a JSON index.

    Each tensor is written to ``pytorch_model-XXXXX-of-YYYYY.bin`` inside
    *output_dir*, and an index mapping every parameter name to its shard
    (the Hugging Face sharded-checkpoint layout) is written to
    ``pytorch_model.bin.index.json``.

    Args:
        state_dict: mapping of parameter name -> tensor.
        output_dir: directory to write the shards and index into
            (defaults to the current directory).

    Returns:
        The index dict: ``{"metadata": {"total_size": bytes},
        "weight_map": {param_name: shard_filename}}``.
    """
    index = {"metadata": {"total_size": 0}, "weight_map": {}}
    total_files = len(state_dict)

    # enumerate gives the 1-based shard counter; one shard per tensor.
    for i, (key, tensor) in enumerate(state_dict.items(), start=1):
        shard_file = f'pytorch_model-{i:05d}-of-{total_files:05d}.bin'
        torch.save({key: tensor}, os.path.join(output_dir, shard_file))
        index["weight_map"][key] = shard_file
        # Size in bytes = element count * bytes per element.
        index["metadata"]["total_size"] += tensor.nelement() * tensor.element_size()

    # Hugging Face loaders look for "pytorch_model.bin.index.json"
    # (the original wrote "pytorch_model.bin.index", which loaders ignore).
    with open(os.path.join(output_dir, 'pytorch_model.bin.index.json'), 'w') as f:
        json.dump(index, f, indent=2)

    return index


if __name__ == "__main__":
    # Load the monolithic checkpoint directly as a state dict — no need to
    # instantiate the model (the original referenced an undefined
    # SomeLargeModel class). map_location keeps everything on CPU so this
    # works on machines without the original GPUs.
    checkpoint_path = '/mnt/e/ai_cache/output/wizardcoder_mmlu_2/merged'
    state_dict = torch.load(
        os.path.join(checkpoint_path, 'pytorch_model.bin'),
        map_location='cpu',
    )
    shard_checkpoint(state_dict, checkpoint_path)