by maveriq

I was wondering what the most efficient way is to load a partial BLOOM model (num_layers < 70). One naive way would be to set the number of layers in the config and then load the state dict with strict=False, but this would still go through the keys of all 70 layers.

Can there be a more efficient solution?

Hi @maveriq
Thank you for your question!
I would do as you suggested, but in addition manually recreate the pytorch_model.bin.index.json file so that it maps only the parameters you want to keep, and rename the sharded files accordingly. I think this should work.

So, say you want to load only the first layer. Your new .index.json would look like:

    {
      "metadata": {
        "total_size": 352494542848
      },
      "weight_map": {
        "h.0.input_layernorm.bias": "pytorch_model_00002-of-0002.bin",
        "h.0.input_layernorm.weight": "pytorch_model_00002-of-0002.bin",
        "h.0.mlp.dense_4h_to_h.bias": "pytorch_model_00002-of-0002.bin",
        "h.0.mlp.dense_4h_to_h.weight": "pytorch_model_00002-of-0002.bin",
        "h.0.mlp.dense_h_to_4h.bias": "pytorch_model_00002-of-0002.bin",
        "h.0.mlp.dense_h_to_4h.weight": "pytorch_model_00002-of-0002.bin",
        "h.0.post_attention_layernorm.bias": "pytorch_model_00002-of-0002.bin",
        "h.0.post_attention_layernorm.weight": "pytorch_model_00002-of-0002.bin",
        "h.0.self_attention.dense.bias": "pytorch_model_00002-of-0002.bin",
        "h.0.self_attention.dense.weight": "pytorch_model_00002-of-0002.bin",
        "h.0.self_attention.query_key_value.bias": "pytorch_model_00002-of-0002.bin",
        "h.0.self_attention.query_key_value.weight": "pytorch_model_00002-of-0002.bin",
        "word_embeddings.weight": "pytorch_model_00001-of-0002.bin",
        "word_embeddings_layernorm.bias": "pytorch_model_00001-of-0002.bin",
        "word_embeddings_layernorm.weight": "pytorch_model_00001-of-0002.bin"
      }
    }

And you'll have to manually rename the pytorch_model_00002-of-00072.bin file to pytorch_model_00002-of-0002.bin, and likewise pytorch_model_00001-of-00072.bin to pytorch_model_00001-of-0002.bin.
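Working out which shard files a trimmed index still references, and what to rename them to, can be scripted. The sketch below is a minimal illustration under stated assumptions: `plan_renames` is a hypothetical helper (not a transformers API), its input is a weight_map that still uses the original shard names, and it renumbers the surviving shards in their original order using the same five-digit pattern as the original files. Any consistent naming should work, since the loader opens whatever file names the index's weight_map lists.

```python
import re

# Matches original shard names such as "pytorch_model_00002-of-00072.bin"
SHARD_RE = re.compile(r"pytorch_model_(\d+)-of-(\d+)\.bin")


def plan_renames(weight_map):
    """Hypothetical helper: renumber the shards a trimmed weight_map still uses.

    Returns (new_weight_map, renames), where renames maps each original
    shard file name to its new name. Assumes names follow SHARD_RE.
    """
    # Distinct shard files still referenced, ordered by original shard number
    used = sorted(
        set(weight_map.values()),
        key=lambda name: int(SHARD_RE.fullmatch(name).group(1)),
    )
    total = len(used)
    renames = {
        old: f"pytorch_model_{i:05d}-of-{total:05d}.bin"
        for i, old in enumerate(used, start=1)
    }
    # Point every kept parameter at its shard's new name
    new_weight_map = {key: renames[shard] for key, shard in weight_map.items()}
    return new_weight_map, renames
```

Each `(old, new)` pair in `renames` could then be applied on disk, for example with `os.rename`, ideally in a copy of the clone, and `new_weight_map` written back into the index.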

Thank you for your quick reply. Is the file renaming necessary, or can I just point to the correct files using their original shard names?

I think you are right. Try it as you suggested and let me know!

Also, don't forget to add the final layer norm (ln_f) entries, which I forgot to include above!

    "ln_f.bias": "pytorch_model_00072-of-00072.bin",
    "ln_f.weight": "pytorch_model_00072-of-00072.bin",

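Since entries like ln_f are easy to leave out, it may be worth checking an edited index before loading. The sketch below assumes that the parameter names listed in this thread (12 per block, plus the embedding and ln_f entries) are the complete set for the base BLOOM model; `missing_keys` is a hypothetical helper, not a transformers API.

```python
# Per-block parameter suffixes, as listed in the index excerpt above
BLOCK_SUFFIXES = [
    "input_layernorm.bias", "input_layernorm.weight",
    "mlp.dense_4h_to_h.bias", "mlp.dense_4h_to_h.weight",
    "mlp.dense_h_to_4h.bias", "mlp.dense_h_to_4h.weight",
    "post_attention_layernorm.bias", "post_attention_layernorm.weight",
    "self_attention.dense.bias", "self_attention.dense.weight",
    "self_attention.query_key_value.bias",
    "self_attention.query_key_value.weight",
]

# Parameters outside the transformer blocks (embeddings and final layer norm)
GLOBAL_KEYS = [
    "word_embeddings.weight",
    "word_embeddings_layernorm.bias",
    "word_embeddings_layernorm.weight",
    "ln_f.bias",
    "ln_f.weight",
]


def missing_keys(weight_map, n_layer):
    """Return the expected parameter names absent from weight_map, sorted."""
    expected = set(GLOBAL_KEYS)
    for i in range(n_layer):
        expected.update(f"h.{i}.{suffix}" for suffix in BLOCK_SUFFIXES)
    return sorted(expected - set(weight_map))
```

For the one-layer index shown earlier, before the ln_f lines are added, `missing_keys(weight_map, 1)` returns `['ln_f.bias', 'ln_f.weight']`.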
Hi. I wrote this function that builds an index containing only the layers you pass in as a list.

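    import json
    from pathlib import Path

    bloom_dir = Path('/mounts/data/huggingface/bloom/')

    def get_json(layers):
        # Original index: maps every parameter name to its shard file
        with open(bloom_dir / 'pytorch_model.bin.index.json', 'rt') as f:
            index = json.load(f)

        load_dict = {}
        load_dict['metadata'] = {'total_size': 352494542848}
        load_dict['weight_map'] = {}

        for k, v in index['weight_map'].items():
            if k.startswith('h.'):
                # Block parameters are named "h.<layer>.<param>"
                l = int(k.split('.')[1])
                if l in layers:
                    load_dict['weight_map'][k] = v
            else:
                # Keep the embeddings and the final layer norm (ln_f)
                load_dict['weight_map'][k] = v

        return load_dict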
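The function only builds the index; loading from a separate directory also needs a config.json there and access to the shard files the index names. A minimal sketch under these assumptions: the clone's config.json stores the layer count under `n_layer`, the kept layers are exactly 0 to n_layer - 1 (the reduced model's blocks are numbered from 0, so other selections would not match its parameter names), and symlinks are an acceptable way to expose the original shards. `write_partial_checkpoint` is a hypothetical helper.

```python
import json
import os
from pathlib import Path


def write_partial_checkpoint(src_dir, dst_dir, load_dict, n_layer):
    """Hypothetical helper: prepare dst_dir for from_pretrained.

    Writes the trimmed index, a config.json with the reduced layer count,
    and a symlink for every shard file the trimmed index references.
    """
    src_dir, dst_dir = Path(src_dir), Path(dst_dir)
    dst_dir.mkdir(parents=True, exist_ok=True)

    # Trimmed index: only the kept parameters and their shard files
    with open(dst_dir / "pytorch_model.bin.index.json", "wt") as f:
        json.dump(load_dict, f, indent=2)

    # Copy the config, shrinking the layer count (assumed key: "n_layer")
    with open(src_dir / "config.json", "rt") as f:
        config = json.load(f)
    config["n_layer"] = n_layer
    with open(dst_dir / "config.json", "wt") as f:
        json.dump(config, f, indent=2)

    # Link (rather than copy) each referenced shard to avoid duplicating data
    for shard in sorted(set(load_dict["weight_map"].values())):
        link = dst_dir / shard
        if not link.exists():
            os.symlink((src_dir / shard).resolve(), link)
```

With the function above, a two-layer checkpoint might be prepared as `write_partial_checkpoint(bloom_dir, './tmp/', get_json(range(2)), 2)`.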

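After saving the returned dict as ./tmp/pytorch_model.bin.index.json, next to a config.json whose n_layer matches the number of kept layers and the shard files the index refers to, the model can then be loaded as:

    from transformers import AutoModel

    model = AutoModel.from_pretrained('./tmp/')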

The bloom_dir is the path where I cloned the repo using 'git clone https://huggingface.co/bigscience/bloom'

Hope it helps anyone who wants to do the same efficiently.

Nice! I'll close this discussion as it seems to be resolved. Feel free to re-open if you feel that the snippet doesn't solve the discussion.

