qlora-support
#2
by muelletm - opened
Make sure the hidden state and the wte weights are on the same device when the model is parallelized across devices.
This should fix the following crash when running qlora:
Traceback (most recent call last):
  File "/code/qlora/qlora.py", line 758, in <module>
    train()
  File "/code/qlora/qlora.py", line 720, in train
    train_result = trainer.train(resume_from_checkpoint=checkpoint_dir)
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 1696, in train
    return inner_training_loop(
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 1973, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 2787, in training_step
    loss = self.compute_loss(model, inputs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 2819, in compute_loss
    outputs = model(**inputs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/peft/peft_model.py", line 686, in forward
    return self.base_model(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/root/.cache/huggingface/modules/transformers_modules/mpt-7b/modeling_mpt.py", line 294, in forward
    logits = F.linear(outputs.last_hidden_state, self.transformer.wte.weight)
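The failing line computes logits from the tied embedding weight, which can live on a different GPU than the final hidden state when the model is sharded (e.g. via accelerate's device_map). A minimal sketch of the fix, assuming a hypothetical helper (`lm_head_logits` and its argument names are illustrative, not the exact modeling_mpt.py code):

```python
import torch
import torch.nn.functional as F

def lm_head_logits(last_hidden_state: torch.Tensor,
                   wte_weight: torch.Tensor) -> torch.Tensor:
    """Compute tied-embedding logits, first moving the wte weight to the
    hidden state's device so F.linear sees tensors on the same device."""
    # .to() is a no-op when the tensors already share a device, so this
    # is safe in the single-GPU case as well.
    weight = wte_weight.to(last_hidden_state.device)
    return F.linear(last_hidden_state, weight)
```

Moving the (small) weight to the hidden state's device, rather than the other way around, avoids shuttling the full activation tensor between GPUs.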
muelletm changed pull request status to open
Hi, I've made the requested changes; please try it now. I will also update the README. 👍
Thanks! (I don't know if you saw the PR attached to this; I guess we can close it now?)
muelletm changed pull request status to closed