Change `wte` to use shared embedding

#43 opened by bcui19 (Mosaic ML, Inc.)

daking changed pull request status to merged
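
For context: a shared (tied) embedding reuses the token-embedding weight both to embed input ids and to "unembed" hidden states back into vocabulary logits, which is why `modeling_mpt.py` can call `wte` on `last_hidden_state`. A minimal sketch of such a module, assuming a boolean `unembed` flag like the one passed positionally in the traceback below (names here are illustrative, not the exact MPT implementation):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SharedEmbedding(nn.Embedding):
    """Embedding whose weight doubles as the output projection (weight tying)."""

    def forward(self, input: torch.Tensor, unembed: bool = False) -> torch.Tensor:
        if unembed:
            # Project hidden states to logits: (..., d_model) @ weight.T -> (..., vocab_size)
            return F.linear(input, self.weight)
        # Normal embedding lookup for token ids
        return super().forward(input)
```

With a module like this, `wte(hidden_states, True)` returns logits, while `wte(input_ids)` still performs the usual lookup.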

This change leads to the following error with torch==0.2.1:

"/home/.cache/huggingface/modules/transformers_modules/mosaicml/mpt-7b-instruct/_____/modeling_mpt.py", line 271, in forward
    logits = self.transformer.wte(outputs.last_hidden_state.to(self.transformer.wte.weight.device), True)
  File "/home/.venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1538, in _call_impl
    result = forward_call(*args, **kwargs)
TypeError: forward() takes 2 positional arguments but 3 were given
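
The `TypeError` suggests a mismatch: the updated `modeling_mpt.py` passes the unembed flag as a second positional argument, but the `wte` module actually constructed still behaves like a plain `nn.Embedding`, whose `forward` accepts only the input tensor. A minimal reproduction of that failure mode, with hypothetical sizes:

```python
import torch
import torch.nn as nn

wte = nn.Embedding(50432, 4096)   # plain embedding: forward(input) only
hidden = torch.randn(1, 8, 4096)

# Passing a second positional argument, as the updated modeling code does, fails
# before forward() even runs, because the call signature does not match:
wte(hidden, True)  # TypeError: forward() takes 2 positional arguments but 3 were given
```

If the failure comes from a stale cached copy under `~/.cache/huggingface/modules/transformers_modules/`, refreshing that cache so the embedding class and the modeling code come from the same revision may remove the mismatch.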
