yuanzhoulvpi committed
Commit af36ef2
1 parent: b3ca596

Add flash attention


Add flash attention to the `BaichuanAttention` class.

Files changed (1)
  1. modeling_baichuan.py +28 -12
modeling_baichuan.py CHANGED
@@ -138,20 +138,36 @@ class BaichuanAttention(torch.nn.Module):
 
         past_key_value = (key_states, value_states) if use_cache else None
 
-        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-
-        if attention_mask is not None:
-            if q_len == 1: # inference with cache
-                if len(attention_mask.size()) == 4:
-                    attention_mask = attention_mask[:, :, -1:, :]
-                else:
-                    attention_mask = attention_mask[:, -1:, :]
-            attn_weights = attn_weights + attention_mask
-            attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
+        pytorch_major_version = int(torch.__version__.split('.')[0])
+        if pytorch_major_version >= 2:
+            if attention_mask is not None:
+                if q_len == 1: # inference with cache
+                    if len(attention_mask.size()) == 4:
+                        attention_mask = attention_mask[:, :, -1:, :]
+                    else:
+                        attention_mask = attention_mask[:, -1:, :]
 
-        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
+            attn_output = torch.nn.functional.scaled_dot_product_attention(query_states,
+                                                                            key_states,
+                                                                            value_states,
+                                                                            dropout_p=0.0,
+                                                                            attn_mask=attention_mask)
+        else:
 
-        attn_output = torch.matmul(attn_weights, value_states)
+            attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+            if attention_mask is not None:
+                if q_len == 1: # inference with cache
+                    if len(attention_mask.size()) == 4:
+                        attention_mask = attention_mask[:, :, -1:, :]
+                    else:
+                        attention_mask = attention_mask[:, -1:, :]
+                attn_weights = attn_weights + attention_mask
+                attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
+
+            attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
+
+            attn_output = torch.matmul(attn_weights, value_states)
 
         attn_output = attn_output.transpose(1, 2)
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
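
For context, the two branches in the diff compute the same attention up to numerical precision: `torch.nn.functional.scaled_dot_product_attention` fuses the matmul / mask / softmax / matmul sequence and can dispatch to a flash-attention kernel when one is available (typically CUDA with fp16/bf16 inputs). Below is a minimal standalone sketch of that equivalence; the shapes and the causal mask construction are illustrative assumptions, not values taken from the Baichuan code.

import math
import torch

# Illustrative shapes: (batch, num_heads, seq_len, head_dim); not the Baichuan defaults.
bsz, num_heads, q_len, head_dim = 2, 4, 16, 32
query_states = torch.randn(bsz, num_heads, q_len, head_dim)
key_states = torch.randn(bsz, num_heads, q_len, head_dim)
value_states = torch.randn(bsz, num_heads, q_len, head_dim)

# Additive causal mask: 0 on allowed positions, a large negative value elsewhere,
# mirroring the 4-d attention_mask the modeling code passes in.
attention_mask = torch.triu(
    torch.full((q_len, q_len), torch.finfo(torch.float32).min), diagonal=1
)[None, None, :, :]

# Manual path (the pre-PyTorch-2 branch of the diff).
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(head_dim)
attn_weights = attn_weights + attention_mask
attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
manual_output = torch.matmul(attn_weights, value_states)

# Fused path (the PyTorch >= 2 branch of the diff).
sdpa_output = torch.nn.functional.scaled_dot_product_attention(
    query_states, key_states, value_states, attn_mask=attention_mask, dropout_p=0.0
)

print(torch.allclose(manual_output, sdpa_output, atol=1e-5))  # expected: True

On CPU the fused call falls back to the math backend, so the check above demonstrates only numerical equivalence; the memory and speed benefits show up when the flash or memory-efficient kernels are selected on GPU.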