robinzixuan committed
Commit ab3b316
Parent: 3590b7f

Upload modeling_opt.py

Files changed (1):
  1. modeling_opt.py +23 -9
modeling_opt.py CHANGED
@@ -17,32 +17,37 @@
 from typing import List, Optional, Tuple, Union
 
 import torch
+
 import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
-from ...activations import ACT2FN
-from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
-from ...modeling_outputs import (
+from transformers.activations import ACT2FN
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
+from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
     QuestionAnsweringModelOutput,
     SequenceClassifierOutputWithPast,
 )
-from ...modeling_utils import PreTrainedModel
-from ...utils import (
+
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
+
     is_flash_attn_2_available,
     is_flash_attn_greater_or_equal_2_10,
     logging,
     replace_return_docstrings,
+
 )
 from .configuration_opt import OPTConfig
 
 
+
 if is_flash_attn_2_available():
     from flash_attn import flash_attn_func, flash_attn_varlen_func
     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
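The import changes above swap the relative `...` imports, which only resolve when the file sits inside the `transformers.models.opt` package, for absolute `transformers.*` imports, so the file can also be loaded on its own, for example as custom model code from the Hub with `trust_remote_code=True`.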
@@ -128,6 +133,16 @@ def softmax_1(input: torch.Tensor, dim=-1, dtype=torch.float32) -> torch.Tensor:
     output = softmax_n_shifted_zeros(input, 1, dim=dim)
     return output if dtype is None else output.type(dtype=dtype)
 
+def clipped_softmax(data, dim=1, eta=1.1, gamma=-0.1, **kw):
+    sm_out = torch.nn.functional.softmax(data, dim=dim, **kw)
+    stretched_out = sm_out * (eta - gamma) + gamma
+    return torch.clip(stretched_out, 0, 1)
+
+
+def clipped_softmax1(data, dim=1, eta=1.1, gamma=-0.1, **kw):
+    sm_out = softmax_1(data, dim=dim, **kw)
+    stretched_out = sm_out * (eta - gamma) + gamma
+    return torch.clip(stretched_out, 0, 1)
 
 class OPTAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
@@ -147,7 +162,7 @@ class OPTAttention(nn.Module):
 
         self.head_dim = self.embed_dim // self.num_heads
         self.is_causal = True
-
+        self.softmax_fn = clipped_softmax1
         if (self.head_dim * self.num_heads) != self.embed_dim:
             raise ValueError(
                 f"embed_dim must be divisible by num_heads (got `embed_dim`: {
@@ -251,10 +266,10 @@ class OPTAttention(nn.Module):
 
         # upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437
         if attn_weights.dtype == torch.float16:
-            attn_weights = softmax_1(
+            attn_weights = self.softmax_fn(
                 attn_weights, dim=-1, dtype=torch.float32).to(torch.float16)
         else:
-            attn_weights = softmax_1(attn_weights, dim=-1)
+            attn_weights = self.softmax_fn(attn_weights, dim=-1)
 
         if layer_head_mask is not None:
             if layer_head_mask.size() != (self.num_heads,):
@@ -306,7 +321,6 @@ class OPTAttention(nn.Module):
 
 
 
-
 class OptFlashAttention2(OPTAttention):
     """
     OPT flash attention module. This module inherits from `OPTAttention` as the weights of the module stays untouched.
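Because the attention probabilities are now computed through `self.softmax_fn` (set to `clipped_softmax1` in `OPTAttention.__init__`) rather than by calling `softmax_1` directly, the scoring function can be swapped per attention module after the model is built. A hypothetical sketch, assuming the Hub repository that hosts this file (not named in the commit) wires it in via `trust_remote_code`; the repo id and the eta/gamma values are placeholders for illustration:

from functools import partial

from transformers import AutoModelForCausalLM

# Placeholder repo id; the commit does not say which Hub repository this file belongs to.
model = AutoModelForCausalLM.from_pretrained(
    "your-org/opt-clipped-softmax", trust_remote_code=True
)

for module in model.modules():
    # The attention class lives in the remote file, so match it by attribute rather than by import.
    if hasattr(module, "softmax_fn"):
        # Tighten the clipping range; eta/gamma here are illustrative values, not from the commit.
        module.softmax_fn = partial(module.softmax_fn, eta=1.05, gamma=-0.05)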
 