robinzixuan committed
Commit ab3b316 • 1 Parent(s): 3590b7f
Upload modeling_opt.py

modeling_opt.py CHANGED (+23 -9)
@@ -17,32 +17,37 @@
 from typing import List, Optional, Tuple, Union

 import torch
+
 import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

-from
-from
-from
+from transformers.activations import ACT2FN
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
+from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
     QuestionAnsweringModelOutput,
     SequenceClassifierOutputWithPast,
 )
-
-from
+
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
+
     is_flash_attn_2_available,
     is_flash_attn_greater_or_equal_2_10,
     logging,
     replace_return_docstrings,
+
 )
 from .configuration_opt import OPTConfig


+
 if is_flash_attn_2_available():
     from flash_attn import flash_attn_func, flash_attn_varlen_func
     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
@@ -128,6 +133,16 @@ def softmax_1(input: torch.Tensor, dim=-1, dtype=torch.float32) -> torch.Tensor:
     output = softmax_n_shifted_zeros(input, 1, dim=dim)
     return output if dtype is None else output.type(dtype=dtype)

+def clipped_softmax(data, dim=1, eta=1.1, gamma=-0.1, **kw):
+    sm_out = torch.nn.functional.softmax(data, dim=dim, **kw)
+    stretched_out = sm_out * (eta - gamma) + gamma
+    return torch.clip(stretched_out, 0, 1)
+
+
+def clipped_softmax1(data, dim=1, eta=1.1, gamma=-0.1, **kw):
+    sm_out = softmax_1(data, dim=dim, **kw)
+    stretched_out = sm_out * (eta - gamma) + gamma
+    return torch.clip(stretched_out, 0, 1)

 class OPTAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
@@ -147,7 +162,7 @@ class OPTAttention(nn.Module):

         self.head_dim = self.embed_dim // self.num_heads
         self.is_causal = True
-
+        self.softmax_fn = clipped_softmax1
         if (self.head_dim * self.num_heads) != self.embed_dim:
             raise ValueError(
                 f"embed_dim must be divisible by num_heads (got `embed_dim`: {
@@ -251,10 +266,10 @@ class OPTAttention(nn.Module):

         # upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437
         if attn_weights.dtype == torch.float16:
-            attn_weights =
+            attn_weights = self.softmax_fn(
                 attn_weights, dim=-1, dtype=torch.float32).to(torch.float16)
         else:
-            attn_weights =
+            attn_weights = self.softmax_fn(attn_weights, dim=-1)

         if layer_head_mask is not None:
             if layer_head_mask.size() != (self.num_heads,):
@@ -306,7 +321,6 @@ class OPTAttention(nn.Module):



-
 class OptFlashAttention2(OPTAttention):
     """
     OPT flash attention module. This module inherits from `OPTAttention` as the weights of the module stays untouched.