Set more precise shapes for the attention weights and outputs

#1
by ivanzhouyq - opened

This PR sets more precise shapes for the attention weights, biases, and outputs.

The original implementation assumes that the embedding size is a multiple of the number of senses. When the hyperparameters are chosen so that the embedding size is not a multiple of the number of senses, the encoded tensor ends up with a different number of parameters before and after reshaping. I found this mismatch when testing a larger model with an embedding size of 1280 and 48 senses, and received the following error:

- sense_weight_net.c_attn.bias: found shape torch.Size([2496]) in the checkpoint and torch.Size([2560]) in the model instantiated
- sense_weight_net.c_attn.weight: found shape torch.Size([2496, 1280]) in the checkpoint and torch.Size([2560, 1280]) in the model instantiated
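
For concreteness, here is a minimal sketch of the arithmetic behind the reported shapes; the exact derivation (including the factor of 2) is my assumption, but the numbers match the error above:

```python
n_embd, n_senses = 1280, 48

# The per-sense dimension rounds down when n_embd is not divisible by n_senses.
sense_dim = n_embd // n_senses               # 26, since 1280 / 48 is not an integer

# Size found in the checkpoint: derived from the rounded per-sense dimension.
checkpoint_dim = 2 * n_senses * sense_dim    # 2 * 48 * 26 = 2496

# Size the instantiated model expects when it assumes divisibility.
instantiated_dim = 2 * n_embd                # 2 * 1280 = 2560

assert (checkpoint_dim, instantiated_dim) == (2496, 2560)
```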

This PR addresses the issue. Of course, a better solution would be to recommend or enforce that the embedding size be a multiple of the number of senses.
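
A minimal sketch of the kind of change this PR makes, using a plain `nn.Linear` as a stand-in for `c_attn` (the names `n_embd`, `n_senses`, and `sense_dim` are illustrative, not the repository's actual identifiers):

```python
import torch
import torch.nn as nn

n_embd, n_senses = 1280, 48
sense_dim = n_embd // n_senses  # 26; rounds down when not divisible

# Before: projection sized by n_embd alone, assuming divisibility.
# c_attn = nn.Linear(n_embd, 2 * n_embd)              # out = 2560

# After: projection sized by the exact amount the reshape consumes.
c_attn = nn.Linear(n_embd, 2 * n_senses * sense_dim)  # out = 2496

x = torch.randn(4, 16, n_embd)                    # (batch, seq, n_embd)
qk = c_attn(x)
q, k = qk.split(n_senses * sense_dim, dim=-1)
q = q.view(4, 16, n_senses, sense_dim)            # reshape now matches exactly
k = k.view(4, 16, n_senses, sense_dim)
```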

johnhew changed pull request status to merged
