PyTorch model architecture doubt

#19
by JacopoBandoni - opened

"self.transformer.wte" is an embedding layer but it is also used as a language modelling head in the MTPForCausalLM class by passing the "self.transformer.wte.weight" matrix to the F.linear function in the forward function.

Am I missing something? Shouldn't the language modelling head be learned separately?

It's very often exactly the same weights, just like in GPT2, Bloom, ...

Hi @JacopoBandoni, for MPT we use weight tying, which shares the word embedding weights with the final LM head. This is the default in most HF causal language models; you can see the codepath here: https://github.com/huggingface/transformers/blob/130e15429116689c9d747be2cdd8c4be7bb7e2bd/src/transformers/modeling_utils.py#L1245-L1264
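For reference, here is a minimal sketch of the conventional tied setup (the class, dimensions, and module names other than wte are hypothetical, not the actual MPT or transformers code): a separate lm_head module is created and its weight is pointed at the embedding weight, which is roughly what the tie_weights codepath linked above arranges.

```python
import torch
import torch.nn as nn

class TiedLMHeadSketch(nn.Module):
    """Conventional weight tying: a separate lm_head module whose weight
    is the same Parameter object as the input embedding weight."""
    def __init__(self, vocab_size: int = 50257, d_model: int = 768):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, d_model)              # input embedding
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False) # output projection
        self.lm_head.weight = self.wte.weight                     # tie: share one Parameter

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states: (batch, seq_len, d_model) -> logits: (batch, seq_len, vocab_size)
        return self.lm_head(hidden_states)
```

With this setup, `model.lm_head.weight is model.wte.weight` returns True, so gradients from both the embedding lookup and the output projection accumulate into the same tensor.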

For MPT, to make the model easier to handle for meta initialization and FSDP, we use self.transformer.wte.weight directly as the LM head rather than creating a separate nn.Linear module and tying the weights.
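In code, that amounts to something like the following simplified sketch (class and dimension names are hypothetical; only the wte / F.linear pattern mirrors the thread): there is no lm_head module at all, and the final projection reuses the embedding matrix directly.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class NoLMHeadSketch(nn.Module):
    """MPT-style projection: no separate lm_head; the embedding weight
    is passed straight to F.linear in forward."""
    def __init__(self, vocab_size: int = 50432, d_model: int = 768):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, d_model)  # the only copy of the weights

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # (batch, seq_len, d_model) @ wte.weight.T -> (batch, seq_len, vocab_size)
        return F.linear(hidden_states, self.wte.weight)
```

Because there is only one Parameter involved, meta-device initialization and FSDP sharding never have to reconcile two modules that point at the same underlying tensor.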

abhi-mosaic changed discussion status to closed
