robinzixuan committed on
Commit e9790b8
1 parent: b8c7545

Update modeling_opt.py

Files changed (1)
  1. modeling_opt.py +5 -180
modeling_opt.py CHANGED
@@ -104,7 +104,6 @@ class OPTLearnedPositionalEmbedding(nn.Embedding):
 def softmax_n_shifted_zeros(input: torch.Tensor, n: int, dim=-1) -> torch.Tensor:
     """
     $\text(softmax)_n(x_i) = exp(x_i) / (n + \sum_j exp(x_j))$
-
     Note: softmax_n, with fixed input, is _not_ shift-symmetric when n != 0
     """
     # compute the maxes along the last dimension
@@ -126,7 +125,8 @@ def softmax_1(input: torch.Tensor, dim=-1, dtype=torch.float32) -> torch.Tensor:
     """
     $\text(softmax)_n(x_i) = exp(x_i) / (1 + \sum_j exp(x_j))$
     """
-    return softmax_n_shifted_zeros(input, 1, dim=dim)
+    output = softmax_n_shifted_zeros(input, 1, dim=dim)
+    return output if dtype is None else output.type(dtype=dtype)


 class OPTAttention(nn.Module):
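For orientation, the two hunks above concern the softmax_n family: softmax_n(x_i) = exp(x_i) / (n + sum_j exp(x_j)), with softmax_1 as the n = 1 case, and the new return path casting only when a dtype is requested. Below is a minimal standalone sketch of that behaviour; `softmax_n_sketch` and `softmax_1_sketch` are hypothetical names for illustration, not the repository's functions.

```python
import torch


def softmax_n_sketch(x: torch.Tensor, n: int, dim: int = -1) -> torch.Tensor:
    # softmax_n(x_i) = exp(x_i) / (n + sum_j exp(x_j)); equivalent to appending
    # n zero logits before a regular softmax. Shift by the (non-negative) max
    # for numerical stability; the shift cancels out exactly.
    shift = x.max(dim=dim, keepdim=True).values.clamp(min=0)
    numerator = torch.exp(x - shift)
    denominator = numerator.sum(dim=dim, keepdim=True) + n * torch.exp(-shift)
    return numerator / denominator


def softmax_1_sketch(x: torch.Tensor, dim: int = -1, dtype=None) -> torch.Tensor:
    # Mirrors the new softmax_1 body: n = 1, cast only when a dtype is given.
    out = softmax_n_sketch(x, 1, dim=dim)
    return out if dtype is None else out.type(dtype)


logits = torch.tensor([[0.0, 1.0, 2.0]])
print(softmax_1_sketch(logits).sum(dim=-1))        # strictly less than 1
print(torch.softmax(logits, dim=-1).sum(dim=-1))   # exactly 1
```

With softmax_1 the probabilities sum to strictly less than 1, which is the point of the extra term in the denominator.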
@@ -169,182 +169,6 @@ class OPTAttention(nn.Module):
     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
         return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        key_value_states: Optional[torch.Tensor] = None,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        layer_head_mask: Optional[torch.Tensor] = None,
-        output_attentions: bool = False,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        """Input shape: Batch x Time x Channel"""
-
-        # if key_value_states are provided this layer is used as a cross-attention layer
-        # for the decoder
-        is_cross_attention = key_value_states is not None
-
-        bsz, tgt_len, _ = hidden_states.size()
-
-        # get query proj
-        query_states = self.q_proj(hidden_states) * self.scaling
-        # get key, value proj
-        if is_cross_attention and past_key_value is not None:
-            # reuse k,v, cross_attentions
-            key_states = past_key_value[0]
-            value_states = past_key_value[1]
-        elif is_cross_attention:
-            # cross_attentions
-            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
-            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
-        elif past_key_value is not None:
-            # reuse k, v, self_attention
-            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
-            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
-            key_states = torch.cat([past_key_value[0], key_states], dim=2)
-            value_states = torch.cat([past_key_value[1], value_states], dim=2)
-        else:
-            # self_attention
-            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
-            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
-
-        if self.is_decoder:
-            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
-            # Further calls to cross_attention layer can then reuse all cross-attention
-            # key/value_states (first "if" case)
-            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
-            # all previous decoder key/value_states. Further calls to uni-directional self-attention
-            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
-            # if encoder bi-directional self-attention `past_key_value` is always `None`
-            past_key_value = (key_states, value_states)
-
-        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
-        query_states = self._shape(
-            query_states, tgt_len, bsz).view(*proj_shape)
-        key_states = key_states.view(*proj_shape)
-        value_states = value_states.view(*proj_shape)
-
-        src_len = key_states.size(1)
-        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
-
-        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
-            raise ValueError(
-                f"Attention weights should be of size {
-                    (bsz * self.num_heads, tgt_len, src_len)}, but is"
-                f" {attn_weights.size()}"
-            )
-
-        if attention_mask is not None:
-            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
-                raise ValueError(
-                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {
-                        attention_mask.size()}"
-                )
-            attn_weights = attn_weights.view(
-                bsz, self.num_heads, tgt_len, src_len) + attention_mask
-            attn_weights = torch.max(
-                attn_weights, torch.tensor(torch.finfo(
-                    attn_weights.dtype).min, device=attn_weights.device)
-            )
-            attn_weights = attn_weights.view(
-                bsz * self.num_heads, tgt_len, src_len)
-
-        # upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437
-        if attn_weights.dtype == torch.float16:
-            attn_weights = nn.functional.softmax(
-                attn_weights, dim=-1, dtype=torch.float32).to(torch.float16)
-        else:
-            attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-
-        if layer_head_mask is not None:
-            if layer_head_mask.size() != (self.num_heads,):
-                raise ValueError(
-                    f"Head mask for a single layer should be of size {
-                        (self.num_heads,)}, but is"
-                    f" {layer_head_mask.size()}"
-                )
-            attn_weights = layer_head_mask.view(
-                1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
-            attn_weights = attn_weights.view(
-                bsz * self.num_heads, tgt_len, src_len)
-
-        if output_attentions:
-            # this operation is a bit awkward, but it's required to
-            # make sure that attn_weights keeps its gradient.
-            # In order to do so, attn_weights have to be reshaped
-            # twice and have to be reused in the following
-            attn_weights_reshaped = attn_weights.view(
-                bsz, self.num_heads, tgt_len, src_len)
-            attn_weights = attn_weights_reshaped.view(
-                bsz * self.num_heads, tgt_len, src_len)
-        else:
-            attn_weights_reshaped = None
-
-        attn_probs = nn.functional.dropout(
-            attn_weights, p=self.dropout, training=self.training)
-
-        attn_output = torch.bmm(attn_probs, value_states)
-
-        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
-            raise ValueError(
-                f"`attn_output` should be of size {
-                    (bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
-                f" {attn_output.size()}"
-            )
-
-        attn_output = attn_output.view(
-            bsz, self.num_heads, tgt_len, self.head_dim)
-        attn_output = attn_output.transpose(1, 2)
-
-        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
-        # partitioned aross GPUs when using tensor-parallelism.
-        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
-
-        attn_output = self.out_proj(attn_output)
-
-        return attn_output, attn_weights_reshaped, past_key_value
-
-
-class OPTOutEffHop(OPTAttention):
-    """Multi-headed attention from 'Attention Is All You Need' paper"""
-
-    def __init__(
-        self,
-        config: OPTConfig,
-        is_decoder: bool = False,
-        **kwargs,
-    ):
-        super().__init__()
-        self.config = config
-        self.embed_dim = config.hidden_size
-        self.num_heads = config.num_attention_heads
-        self.dropout = config.attention_dropout
-        self.enable_bias = config.enable_bias
-
-        self.head_dim = self.embed_dim // self.num_heads
-        self.is_causal = True
-
-        if (self.head_dim * self.num_heads) != self.embed_dim:
-            raise ValueError(
-                f"embed_dim must be divisible by num_heads (got `embed_dim`: {
-                    self.embed_dim}"
-                f" and `num_heads`: {self.num_heads})."
-            )
-        self.scaling = self.head_dim**-0.5
-        self.is_decoder = is_decoder
-
-        self.k_proj = nn.Linear(
-            self.embed_dim, self.embed_dim, bias=self.enable_bias)
-        self.v_proj = nn.Linear(
-            self.embed_dim, self.embed_dim, bias=self.enable_bias)
-        self.q_proj = nn.Linear(
-            self.embed_dim, self.embed_dim, bias=self.enable_bias)
-        self.out_proj = nn.Linear(
-            self.embed_dim, self.embed_dim, bias=self.enable_bias)
-
-    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
-        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
-
     def forward(
         self,
         hidden_states: torch.Tensor,
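The block removed in this hunk is the stock eager OPTAttention.forward plus the OPTOutEffHop subclass's constructor and _shape helper. Stripped of KV caching, shape checks, and head masking, the removed forward reduces to the sketch below; `eager_attention_core` is a hypothetical helper for orientation, not a drop-in replacement for the real method.

```python
from typing import Optional

import torch
import torch.nn as nn


def eager_attention_core(
    query: torch.Tensor,   # (bsz * num_heads, tgt_len, head_dim), already multiplied by the scaling factor
    key: torch.Tensor,     # (bsz * num_heads, src_len, head_dim)
    value: torch.Tensor,   # (bsz * num_heads, src_len, head_dim)
    attention_mask: Optional[torch.Tensor] = None,  # additive mask, broadcastable to the score shape
    dropout_p: float = 0.0,
    training: bool = False,
) -> torch.Tensor:
    # Raw scores: Q @ K^T (the query is assumed pre-scaled by head_dim ** -0.5, as in the removed code).
    scores = torch.bmm(query, key.transpose(1, 2))
    if attention_mask is not None:
        scores = scores + attention_mask
    # Normalize to attention probabilities, apply dropout, then mix the value vectors.
    probs = nn.functional.softmax(scores, dim=-1)
    probs = nn.functional.dropout(probs, p=dropout_p, training=training)
    return torch.bmm(probs, value)   # (bsz * num_heads, tgt_len, head_dim)
```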
@@ -481,6 +305,8 @@ class OPTOutEffHop(OPTAttention):
         return attn_output, attn_weights_reshaped, past_key_value


+
+
 class OptFlashAttention2(OPTAttention):
     """
     OPT flash attention module. This module inherits from `OPTAttention` as the weights of the module stays untouched.
@@ -705,7 +531,6 @@ class OptFlashAttention2(OPTAttention):
 OPT_ATTENTION_CLASSES = {
     "eager": OPTAttention,
     "flash_attention_2": OptFlashAttention2,
-    "out_eff_hop": OPTOutEffHop,
 }


@@ -714,7 +539,7 @@ class OPTDecoderLayer(nn.Module):
         super().__init__()
         self.embed_dim = config.hidden_size

-        self.self_attn = OPT_ATTENTION_CLASSES[config._attn_implementation](
+        self.self_attn = OPTAttention(
             config=config, is_decoder=True)

         self.do_layer_norm_before = config.do_layer_norm_before
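The final two hunks remove the "out_eff_hop" entry from OPT_ATTENTION_CLASSES and make OPTDecoderLayer instantiate OPTAttention directly instead of looking the class up via config._attn_implementation. The sketch below illustrates that registry-lookup pattern in isolation and what hardcoding changes; `TinyConfig`, `ATTENTION_CLASSES`, and `build_self_attn` are simplified stand-ins, not the transformers API.

```python
from dataclasses import dataclass


class EagerAttention:
    """Stand-in for OPTAttention."""


class FlashAttention2:
    """Stand-in for OptFlashAttention2."""


# Registry keyed the same way the removed lookup was: by an implementation string.
ATTENTION_CLASSES = {
    "eager": EagerAttention,
    "flash_attention_2": FlashAttention2,
}


@dataclass
class TinyConfig:
    _attn_implementation: str = "eager"


def build_self_attn(config: TinyConfig):
    # Registry-based construction, as the decoder layer did before this commit.
    # After the commit the layer skips the lookup and always builds the eager
    # class, so a key like the removed "out_eff_hop" can no longer be selected.
    return ATTENTION_CLASSES[config._attn_implementation]()


print(type(build_self_attn(TinyConfig())).__name__)   # EagerAttention
```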
 