remove fixed-size causal mask
modeling_qwen.py  +3 -76
@@ -395,62 +395,6 @@ class QWenAttention(nn.Module):
 
         return attn_output, attn_weights
 
-    def _upcast_and_reordered_attn(
-        self, query, key, value, registered_causal_mask, attention_mask=None, head_mask=None
-    ):
-        bsz, num_heads, q_seq_len, dk = query.size()
-        _, _, k_seq_len, _ = key.size()
-
-        attn_weights = torch.empty(
-            bsz * num_heads,
-            q_seq_len,
-            k_seq_len,
-            dtype=torch.float32,
-            device=query.device,
-        )
-
-        scale_factor = 1.0
-        if self.scale_attn_weights:
-            scale_factor /= float(value.size(-1)) ** 0.5
-
-        with autocast(enabled=False):
-            q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(
-                -1, dk, k_seq_len
-            )
-            attn_weights = torch.baddbmm(
-                attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor
-            )
-            attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
-
-        query_length, key_length = query.size(-2), key.size(-2)
-        causal_mask = registered_causal_mask[
-            :, :, key_length - query_length : key_length, :key_length
-        ]
-        mask_value = torch.finfo(attn_weights.dtype).min
-        mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(
-            attn_weights.device
-        )
-        attn_weights = torch.where(causal_mask, attn_weights, mask_value)
-
-        if attention_mask is not None:
-            attn_weights = attn_weights + attention_mask
-
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-
-        if attn_weights.dtype != torch.float32:
-            raise RuntimeError(
-                "Error with upcasting, attn_weights does not have dtype torch.float32"
-            )
-        attn_weights = attn_weights.type(value.dtype)
-        attn_weights = self.attn_dropout(attn_weights)
-
-        if head_mask is not None:
-            attn_weights = attn_weights * head_mask
-
-        attn_output = torch.matmul(attn_weights, value)
-
-        return attn_output, attn_weights
-
     def _split_heads(self, tensor, num_heads, attn_head_size):
         new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
         tensor = tensor.view(new_shape)
@@ -465,7 +409,6 @@ class QWenAttention(nn.Module):
         self,
         hidden_states: Optional[Tuple[torch.FloatTensor]],
         rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
-        registered_causal_mask: Optional[torch.Tensor] = None,
         layer_past: Optional[Tuple[torch.Tensor]] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
@@ -558,6 +501,9 @@ class QWenAttention(nn.Module):
             q, k, v = query, key, value
             attn_output = self.core_attention_flash(q, k, v, attention_mask=attention_mask)
         else:
+            registered_causal_mask = torch.tril(
+                torch.ones((key.size(1), key.size(1)), dtype=torch.bool, device=key.device)
+            ).view(1, 1, key.size(1), key.size(1))
             query = query.permute(0, 2, 1, 3)
             if not self.use_cache_quantization:
                 key = key.permute(0, 2, 1, 3)
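Note: the mask added above is an ordinary lower-triangular boolean causal mask, sized to the current key length instead of to config.max_position_embeddings. A minimal standalone sketch of how such a mask is built and applied to attention scores (the tensor shapes below are illustrative, not taken from the model):

import torch

# Illustrative shapes only; the real model derives these from its config.
batch, num_heads, seq_len, head_dim = 1, 2, 5, 8
query = torch.randn(batch, num_heads, seq_len, head_dim)
key = torch.randn(batch, num_heads, seq_len, head_dim)

# Lower-triangular boolean mask sized to the current sequence length,
# mirroring the dynamic construction added in this commit.
causal_mask = torch.tril(
    torch.ones((seq_len, seq_len), dtype=torch.bool)
).view(1, 1, seq_len, seq_len)

scores = torch.matmul(query, key.transpose(-1, -2)) / head_dim ** 0.5
# Future positions (mask == False) are pushed to the dtype minimum before
# softmax, so each token attends only to itself and earlier positions.
scores = scores.masked_fill(~causal_mask, torch.finfo(scores.dtype).min)
weights = torch.softmax(scores, dim=-1)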
@@ -650,7 +596,6 @@ class QWenBlock(nn.Module):
         self,
         hidden_states: Optional[Tuple[torch.FloatTensor]],
         rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
-        registered_causal_mask: Optional[torch.Tensor] = None,
         layer_past: Optional[Tuple[torch.Tensor]] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
@@ -664,7 +609,6 @@ class QWenBlock(nn.Module):
         attn_outputs = self.attn(
             layernorm_output,
             rotary_pos_emb_list,
-            registered_causal_mask=registered_causal_mask,
             layer_past=layer_past,
             attention_mask=attention_mask,
             head_mask=head_mask,
@@ -764,21 +708,6 @@ class QWenModel(QWenPreTrainedModel):
 
         self.use_flash_attn = config.use_flash_attn
         self.is_fp32 = not (config.bf16 or config.fp16)
-        if (
-            self.use_flash_attn
-            and flash_attn_unpadded_func is not None
-            and not self.is_fp32
-        ):
-            self.registered_causal_mask = None
-        else:
-            max_positions = config.max_position_embeddings
-            self.register_buffer(
-                "registered_causal_mask",
-                torch.tril(
-                    torch.ones((max_positions, max_positions), dtype=torch.bool)
-                ).view(1, 1, max_positions, max_positions),
-                persistent=False,
-            )
 
         self.h = nn.ModuleList(
             [
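The buffer removed above was allocated once at full config.max_position_embeddings resolution and kept alive for the model's lifetime, whereas the mask is now rebuilt per forward pass at the current key length. A rough sketch of the difference in footprint (the 8192 maximum and 1024 prompt length below are assumed values, not read from the config):

import torch

max_positions = 8192   # assumed max_position_embeddings
seq_len = 1024         # assumed prompt length

# Old approach: persistent lower-triangular buffer at maximum size.
fixed_buffer = torch.tril(
    torch.ones((max_positions, max_positions), dtype=torch.bool)
).view(1, 1, max_positions, max_positions)

# New approach: mask sized to the actual sequence, rebuilt when needed.
dynamic_mask = torch.tril(
    torch.ones((seq_len, seq_len), dtype=torch.bool)
).view(1, 1, seq_len, seq_len)

# torch.bool tensors use one byte per element.
print(fixed_buffer.numel() / 2**20, "MiB")   # 64.0 MiB held for the model's lifetime
print(dynamic_mask.numel() / 2**20, "MiB")   # 1.0 MiB, allocated transiently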
@@ -950,7 +879,6 @@ class QWenModel(QWenPreTrainedModel):
                     create_custom_forward(block),
                     hidden_states,
                     rotary_pos_emb_list,
-                    self.registered_causal_mask,
                     None,
                     attention_mask,
                     head_mask[i],
@@ -962,7 +890,6 @@ class QWenModel(QWenPreTrainedModel):
                     hidden_states,
                     layer_past=layer_past,
                     rotary_pos_emb_list=rotary_pos_emb_list,
-                    registered_causal_mask=self.registered_causal_mask,
                     attention_mask=attention_mask,
                     head_mask=head_mask[i],
                     encoder_hidden_states=encoder_hidden_states,