Fine-tuning Script: QLoRA w/ Flash Attn fails

#41
by RonanMcGovern - opened

RuntimeError Traceback (most recent call last)
Cell In[8], line 1
----> 1 trainer.train()

File /usr/local/lib/python3.10/dist-packages/transformers/trainer.py:1875, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
1873 hf_hub_utils.enable_progress_bars()
1874 else:
-> 1875 return inner_training_loop(
1876 args=args,
1877 resume_from_checkpoint=resume_from_checkpoint,
1878 trial=trial,
1879 ignore_keys_for_eval=ignore_keys_for_eval,
1880 )

File /usr/local/lib/python3.10/dist-packages/transformers/trainer.py:2206, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
2203 self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
2205 with self.accelerator.accumulate(model):
-> 2206 tr_loss_step = self.training_step(model, inputs)
2208 if (
2209 args.logging_nan_inf_filter
2210 and not is_torch_xla_available()
2211 and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
2212 ):
2213 # if loss is nan or inf simply add the average of previous logged losses
2214 tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)

File /usr/local/lib/python3.10/dist-packages/transformers/trainer.py:3184, in Trainer.training_step(self, model, inputs)
3181 return loss_mb.reduce_mean().detach().to(self.args.device)
3183 with self.compute_loss_context_manager():
-> 3184 loss = self.compute_loss(model, inputs)
3186 if self.args.n_gpu > 1:
3187 loss = loss.mean() # mean() to average on multi-gpu parallel training

File /usr/local/lib/python3.10/dist-packages/transformers/trainer.py:3207, in Trainer.compute_loss(self, model, inputs, return_outputs)
3205 else:
3206 labels = None
-> 3207 outputs = model(**inputs)
3208 # Save past state if it exists
3209 # TODO: this needs to be fixed and made cleaner later.
3210 if self.args.past_index >= 0:

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None

File /usr/local/lib/python3.10/dist-packages/accelerate/utils/operations.py:825, in convert_outputs_to_fp32.<locals>.forward(*args, **kwargs)
824 def forward(*args, **kwargs):
--> 825 return model_forward(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/accelerate/utils/operations.py:813, in ConvertOutputsToFp32.__call__(self, *args, **kwargs)
812 def __call__(self, *args, **kwargs):
--> 813 return convert_to_fp32(self.model_forward(*args, **kwargs))

File /usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py:16, in autocast_decorator.<locals>.decorate_autocast(*args, **kwargs)
13 @functools.wraps(func)
14 def decorate_autocast(*args, **kwargs):
15 with autocast_instance:
---> 16 return func(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
164 output = module._old_forward(*args, **kwargs)
165 else:
--> 166 output = module._old_forward(*args, **kwargs)
167 return module._hf_hook.post_forward(module, output)

File /usr/local/lib/python3.10/dist-packages/transformers/models/idefics2/modeling_idefics2.py:1829, in Idefics2ForConditionalGeneration.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, pixel_values, pixel_attention_mask, image_hidden_states, labels, use_cache, output_attentions, output_hidden_states, return_dict)
1826 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1828 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1829 outputs = self.model(
1830 input_ids=input_ids,
1831 attention_mask=attention_mask,
1832 position_ids=position_ids,
1833 past_key_values=past_key_values,
1834 inputs_embeds=inputs_embeds,
1835 pixel_values=pixel_values,
1836 pixel_attention_mask=pixel_attention_mask,
1837 image_hidden_states=image_hidden_states,
1838 use_cache=use_cache,
1839 output_attentions=output_attentions,
1840 output_hidden_states=output_hidden_states,
1841 return_dict=return_dict,
1842 )
1844 hidden_states = outputs[0]
1845 logits = self.lm_head(hidden_states)

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None

File /usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
164 output = module._old_forward(*args, **kwargs)
165 else:
--> 166 output = module._old_forward(*args, **kwargs)
167 return module._hf_hook.post_forward(module, output)

File /usr/local/lib/python3.10/dist-packages/transformers/models/idefics2/modeling_idefics2.py:1649, in Idefics2Model.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, pixel_values, pixel_attention_mask, image_hidden_states, use_cache, output_attentions, output_hidden_states, return_dict)
1643 image_hidden_states = self.vision_model(
1644 pixel_values=pixel_values,
1645 patch_attention_mask=patch_attention_mask,
1646 ).last_hidden_state
1648 # Modality projection & resampling
-> 1649 image_hidden_states = self.connector(
1650 image_hidden_states, attention_mask=patch_attention_mask.view(pixel_values.size(0), -1)
1651 )
1653 elif image_hidden_states is not None:
1654 image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None

File /usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
164 output = module._old_forward(*args, **kwargs)
165 else:
--> 166 output = module._old_forward(*args, **kwargs)
167 return module._hf_hook.post_forward(module, output)

File /usr/local/lib/python3.10/dist-packages/transformers/models/idefics2/modeling_idefics2.py:1317, in Idefics2Connector.forward(self, image_hidden_states, attention_mask)
1315 def forward(self, image_hidden_states, attention_mask):
1316 image_hidden_states = self.modality_projection(image_hidden_states)
-> 1317 image_hidden_states = self.perceiver_resampler(context=image_hidden_states, attention_mask=attention_mask)
1318 return image_hidden_states

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None

File /usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
164 output = module._old_forward(*args, **kwargs)
165 else:
--> 166 output = module._old_forward(*args, **kwargs)
167 return module._hf_hook.post_forward(module, output)

File /usr/local/lib/python3.10/dist-packages/transformers/models/idefics2/modeling_idefics2.py:1287, in Idefics2PerceiverResampler.forward(self, context, attention_mask)
1285 compressed_context = latents
1286 for perceiver_layer in self.layers:
-> 1287 layer_outputs = perceiver_layer(
1288 compressed_context,
1289 context,
1290 attention_mask=attention_mask,
1291 position_ids=None,
1292 past_key_value=None,
1293 output_attentions=False,
1294 use_cache=False,
1295 )
1297 compressed_context = layer_outputs[0]
1299 compressed_context = self.norm(compressed_context)

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None

File /usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
164 output = module._old_forward(*args, **kwargs)
165 else:
--> 166 output = module._old_forward(*args, **kwargs)
167 return module._hf_hook.post_forward(module, output)

File /usr/local/lib/python3.10/dist-packages/transformers/models/idefics2/modeling_idefics2.py:1220, in Idefics2PerceiverLayer.forward(self, latents, context, attention_mask, position_ids, past_key_value, output_attentions, use_cache, **kwargs)
1217 latents = self.input_latents_norm(latents)
1218 context = self.input_context_norm(context)
-> 1220 latents, self_attn_weights, present_key_value = self.self_attn(
1221 latents=latents,
1222 context=context,
1223 attention_mask=attention_mask,
1224 )
1225 latents = residual + latents
1226 residual = latents

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None

File /usr/local/lib/python3.10/dist-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
164 output = module._old_forward(*args, **kwargs)
165 else:
--> 166 output = module._old_forward(*args, **kwargs)
167 return module._hf_hook.post_forward(module, output)

File /usr/local/lib/python3.10/dist-packages/transformers/models/idefics2/modeling_idefics2.py:1004, in Idefics2PerceiverFlashAttention2.forward(self, latents, context, attention_mask, position_ids, past_key_value, output_attentions, use_cache, **kwargs)
1001 key_states = key_states.transpose(1, 2)
1002 value_states = value_states.transpose(1, 2)
-> 1004 attn_output = self._flash_attention_forward(
1005 query_states,
1006 key_states,
1007 value_states,
1008 attention_mask,
1009 q_len,
1010 dropout=dropout_rate,
1011 use_sliding_windows=False,
1012 )
1014 attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous()
1015 attn_output = self.o_proj(attn_output)

File /usr/local/lib/python3.10/dist-packages/transformers/models/idefics2/modeling_idefics2.py:1071, in Idefics2PerceiverFlashAttention2._flash_attention_forward(self, query_states, key_states, value_states, attention_mask, query_length, dropout, softmax_scale, use_sliding_windows)
1068 max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
1070 if not use_sliding_windows:
-> 1071 attn_output_unpad = flash_attn_varlen_func(
1072 query_states,
1073 key_states,
1074 value_states,
1075 cu_seqlens_q=cu_seqlens_q,
1076 cu_seqlens_k=cu_seqlens_k,
1077 max_seqlen_q=max_seqlen_in_batch_q,
1078 max_seqlen_k=max_seqlen_in_batch_k,
1079 dropout_p=dropout,
1080 softmax_scale=softmax_scale,
1081 causal=causal,
1082 )
1083 else:
1084 attn_output_unpad = flash_attn_varlen_func(
1085 query_states,
1086 key_states,
(...)
1095 window_size=(self.config.sliding_window, self.config.sliding_window),
1096 )

File /usr/local/lib/python3.10/dist-packages/flash_attn/flash_attn_interface.py:1066, in flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal, window_size, alibi_slopes, deterministic, return_attn_probs, block_table)
995 def flash_attn_varlen_func(
996 q,
997 k,
(...)
1010 block_table=None,
1011 ):
1012 """dropout_p should be set to 0.0 during evaluation
1013 Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
1014 than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
(...)
1064 pattern (negative means that location was dropped, nonnegative means it was kept).
1065 """
-> 1066 return FlashAttnVarlenFunc.apply(
1067 q,
1068 k,
1069 v,
1070 cu_seqlens_q,
1071 cu_seqlens_k,
1072 max_seqlen_q,
1073 max_seqlen_k,
1074 dropout_p,
1075 softmax_scale,
1076 causal,
1077 window_size,
1078 alibi_slopes,
1079 deterministic,
1080 return_attn_probs,
1081 block_table,
1082 )

File /usr/local/lib/python3.10/dist-packages/torch/autograd/function.py:539, in Function.apply(cls, *args, **kwargs)
536 if not torch._C._are_functorch_transforms_active():
537 # See NOTE: [functorch vjp and autograd interaction]
538 args = _functorch.utils.unwrap_dead_wrappers(args)
--> 539 return super().apply(*args, **kwargs) # type: ignore[misc]
541 if cls.setup_context == _SingleLevelFunction.setup_context:
542 raise RuntimeError(
543 "In order to use an autograd.Function with functorch transforms "
544 "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
545 "staticmethod. For more details, please see "
546 "https://pytorch.org/docs/master/notes/extending.func.html"
547 )

File /usr/local/lib/python3.10/dist-packages/flash_attn/flash_attn_interface.py:581, in FlashAttnVarlenFunc.forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal, window_size, alibi_slopes, deterministic, return_softmax, block_table)
579 if softmax_scale is None:
580 softmax_scale = q.shape[-1] ** (-0.5)
--> 581 out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
582 q,
583 k,
584 v,
585 cu_seqlens_q,
586 cu_seqlens_k,
587 max_seqlen_q,
588 max_seqlen_k,
589 dropout_p,
590 softmax_scale,
591 causal=causal,
592 window_size=window_size,
593 alibi_slopes=alibi_slopes,
594 return_softmax=return_softmax and dropout_p > 0,
595 block_table=block_table,
596 )
597 ctx.save_for_backward(
598 q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
599 )
600 ctx.dropout_p = dropout_p

File /usr/local/lib/python3.10/dist-packages/flash_attn/flash_attn_interface.py:86, in _flash_attn_varlen_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax, block_table)
84 maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
85 q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
---> 86 out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
87 q,
88 k,
89 v,
90 None,
91 cu_seqlens_q,
92 cu_seqlens_k,
93 None,
94 block_table,
95 alibi_slopes,
96 max_seqlen_q,
97 max_seqlen_k,
98 dropout_p,
99 softmax_scale,
100 False,
101 causal,
102 window_size[0],
103 window_size[1],
104 return_softmax,
105 None,
106 )
107 # if out.isnan().any() or softmax_lse.isnan().any():
108 # breakpoint()
109 return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state

RuntimeError: query and key must have the same dtype

The error above is raised when running the training step of the QLoRA + Flash Attention 2 fine-tuning script.
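
For reference, here is a minimal sketch of the kind of setup that triggers this: Idefics2 loaded in 4-bit with Flash Attention 2 and wrapped for QLoRA. The model ID, LoRA targets, and hyperparameters below are illustrative assumptions, not the exact script used here.

```python
# Hedged reproduction sketch (illustrative values, not the original script):
# Idefics2 in 4-bit (QLoRA) with Flash Attention 2 enabled.
import torch
from transformers import AutoProcessor, BitsAndBytesConfig, Idefics2ForConditionalGeneration
from peft import LoraConfig, get_peft_model

model_id = "HuggingFaceM4/idefics2-8b"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

processor = AutoProcessor.from_pretrained(model_id)
model = Idefics2ForConditionalGeneration.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # the combination that fails here
)

lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)

# Training this model with Trainer raises
# "RuntimeError: query and key must have the same dtype"
# inside the perceiver resampler's Flash Attention call.
```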
RonanMcGovern changed discussion title from "Fine-tuning script error" to "Fine-tuning Script: QLoRA w/ Flash Attn fails"

This seems similar to this Transformers issue:
https://github.com/huggingface/transformers/issues/30019
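
The error message itself says the query and key tensors reach flash_attn_varlen_func in different dtypes, and the failing call sits inside the Idefics2 connector/perceiver path. A hedged diagnostic sketch to check for a mixed-precision connector (assuming `model` is the 4-bit model from the sketch above, before PEFT wrapping, and using the attribute path visible in the traceback):

```python
# Hedged diagnostic: list parameter dtypes inside the Idefics2 connector to spot a
# float32 / bfloat16 mix feeding the perceiver's Flash Attention call.
# Assumes `model` is an Idefics2ForConditionalGeneration loaded in 4-bit as above,
# before PEFT wrapping (wrapping changes the attribute path).
for name, param in model.model.connector.named_parameters():
    print(name, param.dtype)
```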

Thanks, yes, that's the issue. I'll subscribe to it on GitHub and comment back here once it's resolved.
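
Until the upstream issue is fixed, a possible interim workaround (a sketch under the assumption that the Flash Attention 2 path in the perceiver is what hits the dtype mismatch) is to keep the QLoRA setup but load the model without Flash Attention 2:

```python
# Hedged workaround sketch: keep the 4-bit QLoRA setup but avoid the Flash Attention 2
# code path that raises "query and key must have the same dtype".
# Model ID and dtypes are illustrative assumptions.
import torch
from transformers import BitsAndBytesConfig, Idefics2ForConditionalGeneration

model = Idefics2ForConditionalGeneration.from_pretrained(
    "HuggingFaceM4/idefics2-8b",
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,  # keep compute dtype consistent with torch_dtype
    ),
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",  # or "sdpa"; "flash_attention_2" hits the dtype error above
)
```

This trades some speed and memory for a forward pass that does not require query/key dtypes to match the Flash Attention kernel's constraints.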
