Missing LM Head

#7
by gsarti - opened
BigScience Workshop org

If I try to load the checkpoint for `bigscience/bloom-6b3` using PyTorch, I see only the modules of `BloomModel`, but not the `lm_head` that should also be there to instantiate a `BloomForCausalLM`:

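!git clone https://huggingface.co/bigscience/bloom-6b3
import torch

# Load the raw state dict on CPU and list the stored parameter names
checkpoint = torch.load("bloom-6b3/pytorch_model.bin", map_location="cpu")
for param_name, param in checkpoint.items():
    print(param_name)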

Output:

word_embeddings.weight
word_embeddings_layernorm.weight
word_embeddings_layernorm.bias
h.0.input_layernorm.weight
h.0.input_layernorm.bias
h.0.self_attention.query_key_value.weight
h.0.self_attention.query_key_value.bias
h.0.self_attention.dense.weight
h.0.self_attention.dense.bias
h.0.post_attention_layernorm.weight
h.0.post_attention_layernorm.bias
h.0.mlp.dense_h_to_4h.weight
h.0.mlp.dense_h_to_4h.bias
h.0.mlp.dense_4h_to_h.weight
h.0.mlp.dense_4h_to_h.bias
... # Omitted for brevity: the same parameters for layers h.1 through h.28
h.29.input_layernorm.weight
h.29.input_layernorm.bias
h.29.self_attention.query_key_value.weight
h.29.self_attention.query_key_value.bias
h.29.self_attention.dense.weight
h.29.self_attention.dense.bias
h.29.post_attention_layernorm.weight
h.29.post_attention_layernorm.bias
h.29.mlp.dense_h_to_4h.weight
h.29.mlp.dense_h_to_4h.bias
h.29.mlp.dense_4h_to_h.weight
h.29.mlp.dense_4h_to_h.bias
ln_f.weight
ln_f.bias

How can I load the model with a trained causal language modeling head?

BigScience Workshop org

Hi! For BLOOM models, I think the LM head weights are tied to the word embedding weights (the head projects hidden states with the transpose of the embedding matrix), so they are not stored separately in the checkpoint. `BloomForCausalLM` takes care of that automatically ;)
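To illustrate what "tied" means here, below is a toy sketch in plain PyTorch (not BLOOM's actual code; the sizes are arbitrary). An `nn.Linear` head and an `nn.Embedding` both store a `(vocab, hidden)` weight, so the head can simply reuse the embedding's parameter, and its output equals `hidden @ E^T`:

```python
import torch
import torch.nn as nn

# Toy sizes chosen for illustration only (not BLOOM's real dimensions)
vocab_size, hidden_size = 10, 4

embeddings = nn.Embedding(vocab_size, hidden_size)        # weight shape: (vocab, hidden)
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)  # weight shape: (vocab, hidden)

# Tie the head to the embeddings: both modules now share one Parameter
lm_head.weight = embeddings.weight

hidden = torch.randn(2, hidden_size)  # two hidden states
logits = lm_head(hidden)              # shape: (2, vocab)

# The head computes hidden @ E^T, i.e. it uses the transpose of the embedding matrix
assert torch.allclose(logits, hidden @ embeddings.weight.T)
print(logits.shape)  # torch.Size([2, 10])
```

Because the two modules share a single tensor, a checkpoint only needs to store it once, which is consistent with the missing `lm_head` key in the listing above.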

BigScience Workshop org

Thanks for the info! The class itself may handle it, but this makes it quite painful to load the checkpoint with Accelerate's `load_checkpoint_and_dispatch`. I think the only alternative at the moment is to write a custom loop that maps the checkpoint's parameter names to the ones expected by the class, right?
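A minimal sketch of such a loop, assuming that (1) `BloomForCausalLM` expects the base-model parameters under a `transformer.` prefix plus a separate `lm_head.weight`, and (2) the head is tied to `word_embeddings.weight`, so reusing that tensor is correct. The helper name `remap_bloom_state_dict` is hypothetical, and values are passed through untouched, so it works on tensors or placeholders alike:

```python
def remap_bloom_state_dict(state_dict, prefix="transformer."):
    """Hypothetical helper: rename bare BloomModel keys to BloomForCausalLM keys.

    Assumes base-model parameters live under `prefix` in the causal LM class and
    that the LM head is tied to the word embeddings (so it reuses that tensor).
    """
    remapped = {}
    for name, tensor in state_dict.items():
        # Keep keys that already carry the prefix, or belong to the head, as-is
        if name.startswith(prefix) or name.startswith("lm_head."):
            remapped[name] = tensor
        else:
            remapped[prefix + name] = tensor
    # Tied LM head: if it is missing, point it at the embedding matrix
    embed_key = prefix + "word_embeddings.weight"
    if "lm_head.weight" not in remapped and embed_key in remapped:
        remapped["lm_head.weight"] = remapped[embed_key]
    return remapped


# Usage with placeholder strings standing in for tensors
ckpt = {
    "word_embeddings.weight": "E",
    "h.0.input_layernorm.weight": "g",
    "ln_f.bias": "b",
}
remapped_ckpt = remap_bloom_state_dict(ckpt)
print(sorted(remapped_ckpt))
```

Whether saving the remapped dict (e.g. with `torch.save`) and passing that file to `load_checkpoint_and_dispatch` fully resolves the issue is untested here; `from_pretrained` in `transformers` presumably does this prefix handling and weight tying internally, which is why the plain class loads fine.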

@gsarti I'm facing the same problem. Were you able to solve it using `load_checkpoint_and_dispatch`?
