Fine-tuning RuntimeError: tensor size mismatch (4096 vs 5120)
I am trying to fine-tune this model on my dataset and I get a RuntimeError. Has anybody managed to fine-tune it, or does anyone have an idea how to resolve this?
Error message: RuntimeError: The size of tensor a (4096) must match the size of tensor b (5120) at non-singleton dimension 2
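If it helps, the two sizes look like they come straight from the model's dimensions: Mistral-Nemo reports hidden_size = 5120, and 4096 would be num_attention_heads * head_dim (32 * 128), i.e. the output size of q_proj. This is just my reading of the config, not a fix; here is how I checked it:

from transformers import AutoConfig

config = AutoConfig.from_pretrained("cognitivecomputations/dolphin-2.9.3-mistral-nemo-12b")
# hidden_size feeds into q_proj; num_attention_heads * head_dim is what q_proj produces
print(config.hidden_size)                 # 5120, if I read the config right
print(config.num_attention_heads)         # 32
print(getattr(config, "head_dim", None))  # 128, or None if the installed transformers doesn't expose it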
Here are snippets from my code:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments
from peft import LoraConfig
from trl import SFTTrainer

model_name = "cognitivecomputations/dolphin-2.9.3-mistral-nemo-12b"

tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True,
)
tokenizer.add_eos_token = True
tokenizer.add_bos_token = True
tokenizer.add_unk_token = True  # likely a no-op: the tokenizer has no add_unk_token flag
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
tokenizer.truncation_side = "right"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant="llama" in model_name,  # False for this model name
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    use_cache=False,
    quantization_config=bnb_config,
    trust_remote_code=True,
)
target_modules = "all-linear"

peft_config = LoraConfig(
    r=32,
    lora_alpha=64,
    lora_dropout=0.05,
    target_modules=target_modules,
    bias="none",
    task_type="CAUSAL_LM",
)
num_samples = train_dataset.shape[0]
num_devices = 1
per_device_train_batch_size = 2
gradient_accumulation_steps = 1
total_steps_per_epoch = num_samples // (per_device_train_batch_size * num_devices * gradient_accumulation_steps)
save_steps_adjusted = max(1, total_steps_per_epoch // 2)
logging_steps_adjusted = max(1, total_steps_per_epoch // 10)
optim = "paged_adamw_8bit"
learning_rate = 2e-4
training_arguments = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=2,
    per_device_train_batch_size=per_device_train_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    optim=optim,
    save_steps=save_steps_adjusted,
    logging_steps=logging_steps_adjusted,
    learning_rate=learning_rate,
    weight_decay=0.001,
    fp16=True,   # note: training runs in fp16 while bnb_4bit_compute_dtype above is bfloat16
    bf16=False,
    warmup_ratio=0.3,
    lr_scheduler_type="cosine",
    report_to="wandb",
)
trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    peft_config=peft_config,
    max_seq_length=max_seq_length,
    dataset_text_field="prompt",
    tokenizer=tokenizer,
    args=training_arguments,
    packing=False,
)
trainer.train()
And the full error trace:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-29-3435b262f1ae> in <cell line: 1>()
----> 1 trainer.train()
30 frames
/usr/local/lib/python3.10/dist-packages/trl/trainer/sft_trainer.py in train(self, *args, **kwargs)
449 self.model = self._trl_activate_neftune(self.model)
450
--> 451 output = super().train(*args, **kwargs)
452
453 # After training we make sure to retrieve back the original forward pass method
/usr/local/lib/python3.10/dist-packages/transformers/trainer.py in train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
1930 hf_hub_utils.enable_progress_bars()
1931 else:
-> 1932 return inner_training_loop(
1933 args=args,
1934 resume_from_checkpoint=resume_from_checkpoint,
/usr/local/lib/python3.10/dist-packages/transformers/trainer.py in _inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
2266
2267 with self.accelerator.accumulate(model):
-> 2268 tr_loss_step = self.training_step(model, inputs)
2269
2270 if (
/usr/local/lib/python3.10/dist-packages/transformers/trainer.py in training_step(self, model, inputs)
3305
3306 with self.compute_loss_context_manager():
-> 3307 loss = self.compute_loss(model, inputs)
3308
3309 del inputs
/usr/local/lib/python3.10/dist-packages/transformers/trainer.py in compute_loss(self, model, inputs, return_outputs)
3336 else:
3337 labels = None
-> 3338 outputs = model(**inputs)
3339 # Save past state if it exists
3340 # TODO: this needs to be fixed and made cleaner later.
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)
1533
1534 def _call_impl(self, *args, **kwargs):
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1542
1543 try:
/usr/local/lib/python3.10/dist-packages/accelerate/utils/operations.py in forward(*args, **kwargs)
817
818 def forward(*args, **kwargs):
--> 819 return model_forward(*args, **kwargs)
820
821 # To act like a decorator so that it can be popped when doing `extract_model_from_parallel`
/usr/local/lib/python3.10/dist-packages/accelerate/utils/operations.py in __call__(self, *args, **kwargs)
805
806 def __call__(self, *args, **kwargs):
--> 807 return convert_to_fp32(self.model_forward(*args, **kwargs))
808
809 def __getstate__(self):
/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py in decorate_autocast(*args, **kwargs)
14 def decorate_autocast(*args, **kwargs):
15 with autocast_instance:
---> 16 return func(*args, **kwargs)
17
18 decorate_autocast.__script_unsupported = "@autocast() decorator is not supported in script mode" # type: ignore[attr-defined]
/usr/local/lib/python3.10/dist-packages/peft/peft_model.py in forward(self, input_ids, attention_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict, task_ids, **kwargs)
1575 with self._enable_peft_forward_hooks(**kwargs):
1576 kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
-> 1577 return self.base_model(
1578 input_ids=input_ids,
1579 attention_mask=attention_mask,
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)
1533
1534 def _call_impl(self, *args, **kwargs):
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1542
1543 try:
/usr/local/lib/python3.10/dist-packages/peft/tuners/tuners_utils.py in forward(self, *args, **kwargs)
186
187 def forward(self, *args: Any, **kwargs: Any):
--> 188 return self.model.forward(*args, **kwargs)
189
190 def _pre_injection_hook(self, model: nn.Module, config: PeftConfig, adapter_name: str) -> None:
/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py in new_forward(module, *args, **kwargs)
167 output = module._old_forward(*args, **kwargs)
168 else:
--> 169 output = module._old_forward(*args, **kwargs)
170 return module._hf_hook.post_forward(module, output)
171
/usr/local/lib/python3.10/dist-packages/transformers/models/mistral/modeling_mistral.py in forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
1198
1199 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1200 outputs = self.model(
1201 input_ids=input_ids,
1202 attention_mask=attention_mask,
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)
1533
1534 def _call_impl(self, *args, **kwargs):
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1542
1543 try:
/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py in new_forward(module, *args, **kwargs)
167 output = module._old_forward(*args, **kwargs)
168 else:
--> 169 output = module._old_forward(*args, **kwargs)
170 return module._hf_hook.post_forward(module, output)
171
/usr/local/lib/python3.10/dist-packages/transformers/models/mistral/modeling_mistral.py in forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
974 )
975 else:
--> 976 layer_outputs = decoder_layer(
977 hidden_states,
978 attention_mask=causal_mask,
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)
1533
1534 def _call_impl(self, *args, **kwargs):
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1542
1543 try:
/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py in new_forward(module, *args, **kwargs)
167 output = module._old_forward(*args, **kwargs)
168 else:
--> 169 output = module._old_forward(*args, **kwargs)
170 return module._hf_hook.post_forward(module, output)
171
/usr/local/lib/python3.10/dist-packages/transformers/models/mistral/modeling_mistral.py in forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, **kwargs)
716
717 # Self Attention
--> 718 hidden_states, self_attn_weights, present_key_value = self.self_attn(
719 hidden_states=hidden_states,
720 attention_mask=attention_mask,
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)
1533
1534 def _call_impl(self, *args, **kwargs):
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1542
1543 try:
/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py in new_forward(module, *args, **kwargs)
167 output = module._old_forward(*args, **kwargs)
168 else:
--> 169 output = module._old_forward(*args, **kwargs)
170 return module._hf_hook.post_forward(module, output)
171
/usr/local/lib/python3.10/dist-packages/transformers/models/mistral/modeling_mistral.py in forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, **kwargs)
611 bsz, q_len, _ = hidden_states.size()
612
--> 613 query_states = self.q_proj(hidden_states)
614 key_states = self.k_proj(hidden_states)
615 value_states = self.v_proj(hidden_states)
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)
1533
1534 def _call_impl(self, *args, **kwargs):
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1542
1543 try:
/usr/local/lib/python3.10/dist-packages/peft/tuners/lora/bnb.py in forward(self, x, *args, **kwargs)
500 output = output.to(expected_dtype)
501
--> 502 result = result + output
503
504 return result
RuntimeError: The size of tensor a (4096) must match the size of tensor b (5120) at non-singleton dimension 2
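The failing line is result = result + output inside peft's LoRA forward for 4-bit layers, so it looks like the base q_proj output (size 4096) and the LoRA path output (size 5120) disagree. Here is a small inspection sketch I intend to run right after SFTTrainer wraps the model; the attribute names (lora_B, get_base_layer) are my assumption about the current peft layout:

# Compare the base q_proj out_features with the shapes of the injected LoRA B matrices.
for name, module in trainer.model.named_modules():
    if name.endswith("self_attn.q_proj") and hasattr(module, "lora_B"):
        base = module.get_base_layer() if hasattr(module, "get_base_layer") else module
        print(name)
        print("base out_features:", getattr(base, "out_features", None))
        print("lora_B shapes:", {k: tuple(b.weight.shape) for k, b in module.lora_B.items()})
        break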
For some reason, I was able to fine-tune this on a single A100 but not on multiple GPUs. 80 GB was enough for 8k context with a batch size of 1 and 12 gradient-accumulation steps.
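For reference, that working single-A100 run used roughly these settings (reconstructed, so treat the exact values as approximate; everything else was as in the snippet above):

per_device_train_batch_size = 1
gradient_accumulation_steps = 12
max_seq_length = 8192  # 8k context fit in 80 GB with the 4-bit setup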
I am trying to run the above on Google Colab with a Tesla T4 (16 GB VRAM). The model seems to load fine and leaves plenty of room for a batch size of 1, yet the error still pops up. Any thoughts?
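In case a library version mismatch plays a role (just a guess on my side), this is what I run to capture my environment:

import bitsandbytes, peft, torch, transformers, trl

# Print the versions of the libraries that appear in the failing call stack.
for lib in (torch, transformers, peft, trl, bitsandbytes):
    print(lib.__name__, lib.__version__)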
I really don't know; I'm not an expert in these things. It works for me: I quantized it, put it on Ollama, and it runs fine.