Loading partial model #34

by maveriq - opened

I was wondering what's the most efficient way to load a partial model (num_layers < 70) for BLOOM. One naive way would be to specify the number of layers in the config and then load the state dict with strict=False, but this would still go through all the keys of the 70 layers.

Can there be a more efficient solution?

Hi @maveriq
Thank you for your question!
I would go as you suggested, but in addition manually recreate the pytorch_model.bin.index.json file to reflect the new mapping, and rename the sharded files accordingly. I think this should work.

So, let's say you want to load only the first layer; your new .index.json would look like:

{
  "metadata": {
    "total_size": 352494542848
  },
  "weight_map": {
    "h.0.input_layernorm.bias": "pytorch_model_00002-of-0002.bin",
    "h.0.input_layernorm.weight": "pytorch_model_00002-of-0002.bin",
    "h.0.mlp.dense_4h_to_h.bias": "pytorch_model_00002-of-0002.bin",
    "h.0.mlp.dense_4h_to_h.weight": "pytorch_model_00002-of-0002.bin",
    "h.0.mlp.dense_h_to_4h.bias": "pytorch_model_00002-of-0002.bin",
    "h.0.mlp.dense_h_to_4h.weight": "pytorch_model_00002-of-0002.bin",
    "h.0.post_attention_layernorm.bias": "pytorch_model_00002-of-0002.bin",
    "h.0.post_attention_layernorm.weight": "pytorch_model_00002-of-0002.bin",
    "h.0.self_attention.dense.bias": "pytorch_model_00002-of-0002.bin",
    "h.0.self_attention.dense.weight": "pytorch_model_00002-of-0002.bin",
    "h.0.self_attention.query_key_value.bias": "pytorch_model_00002-of-0002.bin",
    "h.0.self_attention.query_key_value.weight": "pytorch_model_00002-of-0002.bin",
    "word_embeddings.weight": "pytorch_model_00001-of-0002.bin",
    "word_embeddings_layernorm.bias": "pytorch_model_00001-of-0002.bin",
    "word_embeddings_layernorm.weight": "pytorch_model_00001-of-0002.bin"
  }
}

And you'll have to manually rename the pytorch_model_00002-of-0072.bin file to pytorch_model_00002-of-0002.bin.

Thank you for your quick reply. Is the file renaming necessary? Or can I just point to the correct files using their original names?

I think that you are right, you can try as you suggested and let me know!
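If the index keeps the original shard names, then presumably only the shards it still references need to be present on disk. A minimal sketch of how one might list them, using a toy weight map (the entries and the helper names `filter_weight_map` and `required_shards` are illustrative assumptions, not part of the BLOOM repo or of transformers):

```python
import re

# Toy stand-in for the "weight_map" section of pytorch_model.bin.index.json.
# Shard names follow the pattern seen in this thread but are illustrative only.
toy_weight_map = {
    "word_embeddings.weight": "pytorch_model_00001-of-00072.bin",
    "word_embeddings_layernorm.weight": "pytorch_model_00001-of-00072.bin",
    "h.0.input_layernorm.weight": "pytorch_model_00002-of-00072.bin",
    "h.1.input_layernorm.weight": "pytorch_model_00003-of-00072.bin",
    "h.69.input_layernorm.weight": "pytorch_model_00071-of-00072.bin",
    "ln_f.weight": "pytorch_model_00072-of-00072.bin",
}

# Transformer block parameters are named "h.<layer index>.<parameter>".
_BLOCK_KEY = re.compile(r"^h\.(\d+)\.")


def filter_weight_map(weight_map, layers):
    """Keep every non-block entry plus the blocks listed in `layers`.

    Shard file names are left untouched, so no renaming is involved.
    """
    wanted = set(layers)
    kept = {}
    for name, shard in weight_map.items():
        match = _BLOCK_KEY.match(name)
        if match is None or int(match.group(1)) in wanted:
            kept[name] = shard
    return kept


def required_shards(weight_map):
    """Return the sorted list of distinct shard files a weight map refers to."""
    return sorted(set(weight_map.values()))


partial = filter_weight_map(toy_weight_map, layers=[0])
print(required_shards(partial))
```

With the toy map above, only the shards holding the embeddings, layer 0, and the final layer norm are listed; the shards for layers 1 and 69 are not needed.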

Also, do not forget to add the following entries, which I forgot to include above!

    "ln_f.bias": "pytorch_model_00072-of-00072.bin",
    "ln_f.weight": "pytorch_model_00072-of-00072.bin",

Hi. So I made this function that loads only the layers that you pass in, as a list.

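import json
from pathlib import Path

bloom_dir = Path('/mounts/data/huggingface/bloom/')

def get_json(layers):
    # Read the full index, which maps every parameter name to its shard file.
    with open(bloom_dir/'pytorch_model.bin.index.json', 'rt') as f:
        index = json.load(f)

    load_dict = {}
    load_dict['metadata'] = {'total_size': 352494542848}
    load_dict['weight_map'] = {}

    for k, v in index['weight_map'].items():
        if k.startswith('h.'):
            # Transformer block keys look like 'h.<layer>.<param>'.
            l = int(k.split('.')[1])
            if l in layers:
                # Assumed body (the posted snippet is cut off here):
                # keep the entry, pointing at its original shard file.
                load_dict['weight_map'][k] = v
        else:
            # Assumed: embedding and final layer norm entries are always kept,
            # as discussed earlier in the thread.
            load_dict['weight_map'][k] = v

    return load_dict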

The model can then be loaded as:

from transformers import AutoModel
model = AutoModel.from_pretrained('./tmp/')

The bloom_dir is the path where I cloned the repo using 'git clone https://huggingface.co/bigscience/bloom'

Hope it helps anyone who wants to do the same efficiently.
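The thread does not show how './tmp/' gets populated before the from_pretrained call. A minimal sketch of one way it might be done, assuming a dictionary shaped like the one returned by get_json, and assuming the layer count lives under the "n_layer" key of config.json (worth checking against your copy of the config). The function name `write_partial_checkpoint` and its parameters are hypothetical:

```python
import json
import os
from pathlib import Path


def write_partial_checkpoint(src_dir, dst_dir, load_dict, num_layers):
    """Fill dst_dir with a filtered index, a trimmed config, and shard links.

    Returns the sorted list of shard file names that were linked.
    """
    src_dir, dst_dir = Path(src_dir), Path(dst_dir)
    dst_dir.mkdir(parents=True, exist_ok=True)

    # 1. Write the filtered index under the standard index file name.
    with open(dst_dir / "pytorch_model.bin.index.json", "w") as f:
        json.dump(load_dict, f, indent=2)

    # 2. Copy the config with a reduced layer count.
    #    ASSUMPTION: the layer count is stored under "n_layer".
    with open(src_dir / "config.json") as f:
        config = json.load(f)
    config["n_layer"] = num_layers
    with open(dst_dir / "config.json", "w") as f:
        json.dump(config, f, indent=2)

    # 3. Symlink only the shards the index refers to, rather than copying them.
    shards = sorted(set(load_dict["weight_map"].values()))
    for shard in shards:
        link = dst_dir / shard
        if not (link.is_symlink() or link.exists()):
            os.symlink((src_dir / shard).resolve(), link)
    return shards
```

Under these assumptions, usage would look like `write_partial_checkpoint(bloom_dir, './tmp/', get_json([0, 1]), num_layers=2)`. This presumably only lines up with the model's parameter names when the selected layers form a contiguous range starting at 0, since a config with N layers expects keys h.0 through h.N-1. Symlinks may need extra privileges on Windows; copying the shards is the fallback.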


Nice! I'll close this discussion as it seems to be resolved. Feel free to re-open if you feel that the snippet doesn't solve the discussion.

TimeRobber changed discussion status to closed
