gugarosa committed
Commit 1f890f7
Parent(s): d22f35e

fix(phi-1): Checks length of `attention_mask` if it is passed as a direct tensor.

Files changed (1): modeling_mixformer_sequential.py (+3 -3)
modeling_mixformer_sequential.py CHANGED
@@ -35,7 +35,7 @@ from __future__ import annotations
 
 import math
 import copy
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, Optional, Tuple, Union
 from dataclasses import dataclass, field
 
 import torch
@@ -541,8 +541,8 @@ class MHA(nn.Module):
         kv = update_kv_cache(qkv[:, :, 1:], past_key_values, self.layer_idx)
 
         if attention_mask is not None:
-            attention_mask, cu_seqlens, max_seqlen = attention_mask
-            attention_mask = attention_mask.to(qkv.device)
+            attention_mask = attention_mask[0] if isinstance(attention_mask, tuple) else attention_mask
+            attention_mask = attention_mask.bool().to(qkv.device)
 
         attention_kwargs = {"attention_mask": attention_mask}
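
For context, the patched branch now accepts either the legacy tuple packing or a plain tensor, instead of unconditionally unpacking a 3-tuple. Below is a minimal standalone sketch of that normalization; the helper name normalize_attention_mask is hypothetical, and the (mask, cu_seqlens, max_seqlen) tuple layout is an assumption drawn from the removed line:

import torch

def normalize_attention_mask(attention_mask, device):
    # Hypothetical helper mirroring the patched logic: accept either the
    # legacy tuple packing (mask, cu_seqlens, max_seqlen) or a direct tensor.
    if attention_mask is None:
        return None
    if isinstance(attention_mask, tuple):
        attention_mask = attention_mask[0]
    # Cast to bool so downstream attention code always sees a boolean mask.
    return attention_mask.bool().to(device)

# Direct tensor, e.g. the attention_mask produced by a tokenizer.
mask = torch.ones(2, 8, dtype=torch.long)
print(normalize_attention_mask(mask, "cpu").dtype)  # torch.bool

# Legacy tuple packing still works.
cu_seqlens = torch.tensor([0, 8, 16], dtype=torch.int32)
print(normalize_attention_mask((mask, cu_seqlens, 8), "cpu").shape)  # torch.Size([2, 8])

The added .bool() cast mirrors the second new line of the patch and avoids dtype surprises when the mask arrives as an integer tensor rather than a boolean one.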