Loading partial model #34
I was wondering what the most efficient way is to load a partial model (num_layers < 70) for BLOOM. One naive way would be to specify the number of layers in the config and then load the state dict with strict=False, but this would still iterate over the keys of all 70 layers.
Can there be a more efficient solution?
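For context, a rough sketch of the naive approach I mean (the single pytorch_model.bin below is illustrative only; the real repo is sharded across 72 files):

from transformers import AutoConfig, AutoModel
import torch

config = AutoConfig.from_pretrained('bigscience/bloom')
config.n_layer = 4  # keep only the first 4 of the 70 layers
model = AutoModel.from_config(config)  # randomly initialized, 4 layers

# strict=False silently skips checkpoint keys with no matching parameter,
# but torch.load still has to deserialize all 70 layers' tensors first
state_dict = torch.load('pytorch_model.bin')  # hypothetical single-file checkpoint
model.load_state_dict(state_dict, strict=False)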
Hi @maveriq
Thank you for your question!!
I would go as suggested, but in addition to that, manually recreate the pytorch_model.bin.index.json file to adjust the new mapping and rename the sharded files accordingly. I think this should work.
So if, let's say, you want to load the first layer only, your new .index.json would look like:
{
  "metadata": {
    "total_size": 352494542848
  },
  "weight_map": {
    "h.0.input_layernorm.bias": "pytorch_model_00002-of-00002.bin",
    "h.0.input_layernorm.weight": "pytorch_model_00002-of-00002.bin",
    "h.0.mlp.dense_4h_to_h.bias": "pytorch_model_00002-of-00002.bin",
    "h.0.mlp.dense_4h_to_h.weight": "pytorch_model_00002-of-00002.bin",
    "h.0.mlp.dense_h_to_4h.bias": "pytorch_model_00002-of-00002.bin",
    "h.0.mlp.dense_h_to_4h.weight": "pytorch_model_00002-of-00002.bin",
    "h.0.post_attention_layernorm.bias": "pytorch_model_00002-of-00002.bin",
    "h.0.post_attention_layernorm.weight": "pytorch_model_00002-of-00002.bin",
    "h.0.self_attention.dense.bias": "pytorch_model_00002-of-00002.bin",
    "h.0.self_attention.dense.weight": "pytorch_model_00002-of-00002.bin",
    "h.0.self_attention.query_key_value.bias": "pytorch_model_00002-of-00002.bin",
    "h.0.self_attention.query_key_value.weight": "pytorch_model_00002-of-00002.bin",
    "word_embeddings.weight": "pytorch_model_00001-of-00002.bin",
    "word_embeddings_layernorm.bias": "pytorch_model_00001-of-00002.bin",
    "word_embeddings_layernorm.weight": "pytorch_model_00001-of-00002.bin"
  }
}
And you'll have to manually rename the pytorch_model_00002-of-00072.bin file to pytorch_model_00002-of-00002.bin.
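A quick sketch of that rename step (paths illustrative; the new index above references both renamed shards):

from pathlib import Path

bloom_dir = Path('/path/to/bloom')  # wherever the shards were downloaded
(bloom_dir / 'pytorch_model_00001-of-00072.bin').rename(bloom_dir / 'pytorch_model_00001-of-00002.bin')
(bloom_dir / 'pytorch_model_00002-of-00072.bin').rename(bloom_dir / 'pytorch_model_00002-of-00002.bin')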
Thank you for your quick reply. Is the file renaming necessary, or can I just point to the correct files with the original indices?
I think that you are right; you can try it as you suggested and let me know!
Also, do not forget to add these entries for the final layer norm, which I forgot to include above:
"ln_f.bias": "pytorch_model_00072-of-00072.bin",
"ln_f.weight": "pytorch_model_00072-of-00072.bin",
Hi. So I made this function that builds an index for only the layers that you pass in as a list.

import json
from pathlib import Path

bloom_dir = Path('/mounts/data/huggingface/bloom/')

def get_json(layers):
    index = json.load(open(bloom_dir / 'pytorch_model.bin.index.json', 'rt'))
    load_dict = {}
    load_dict['metadata'] = {'total_size': 352494542848}
    load_dict['weight_map'] = {}
    for k, v in index['weight_map'].items():
        if k.startswith('h'):
            # Transformer block weight: keep it only if its layer is requested.
            l = int(k.split('.')[1])
            if l in layers:
                load_dict['weight_map'][k] = str(bloom_dir / v)
        else:
            # Embeddings and final layer norm are always kept.
            load_dict['weight_map'][k] = str(bloom_dir / v)
    return load_dict
The model can then be loaded as:

json.dump(get_json([4]), open('./tmp/pytorch_model.bin.index.json', 'wt'))
model = AutoModel.from_pretrained('./tmp/')

bloom_dir is the path where I cloned the repo using git clone https://huggingface.co/bigscience/bloom.
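Note that AutoModel.from_pretrained('./tmp/') presumably also needs a config.json in ./tmp/ (this step isn't spelled out above). A minimal sketch, copying BLOOM's config and adjusting n_layer so that every kept h.<i> key has a matching parameter:

config = json.load(open(bloom_dir / 'config.json'))
config['n_layer'] = 5  # illustrative: get_json([4]) keeps h.4.*, so at least 5 layers are needed
json.dump(config, open('./tmp/config.json', 'wt'))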
Hope it helps anyone who wants to do the same efficiently.
Nice! I'll close this discussion as it seems to be resolved. Feel free to re-open if you feel that the snippet doesn't resolve your issue.