feat-swiglu
#17
by
bwang0911
- opened
- activation.py +23 -0
- config.json +1 -1
- mlp.py +4 -2
- modeling_xlm_roberta.py +7 -5
- modeling_xlm_roberta_for_glue.py +0 -109
activation.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
import torch.nn.functional as F
|
3 |
+
|
4 |
+
try:
|
5 |
+
from flash_attn.ops.activations import swiglu as flash_swiglu
|
6 |
+
except ImportError:
|
7 |
+
flash_swiglu = None
|
8 |
+
|
9 |
+
if flash_swiglu is None:
|
10 |
+
# PyTorch implementation of SwiGLU
|
11 |
+
class SwiGLU(nn.Module):
|
12 |
+
def forward(self, x):
|
13 |
+
x, gate = x.chunk(2, dim=-1)
|
14 |
+
return F.silu(gate) * x
|
15 |
+
|
16 |
+
def swiglu(x):
|
17 |
+
layer = SwiGLU()
|
18 |
+
return layer(x)
|
19 |
+
|
20 |
+
else:
|
21 |
+
# Use Flash Attention's built-in swiglu
|
22 |
+
def swiglu(x):
|
23 |
+
return flash_swiglu(x)
|
config.json
CHANGED
@@ -12,7 +12,7 @@
|
|
12 |
"attention_probs_dropout_prob": 0.1,
|
13 |
"bos_token_id": 0,
|
14 |
"eos_token_id": 2,
|
15 |
-
"hidden_act": "
|
16 |
"hidden_dropout_prob": 0.1,
|
17 |
"hidden_size": 768,
|
18 |
"initializer_range": 0.02,
|
|
|
12 |
"attention_probs_dropout_prob": 0.1,
|
13 |
"bos_token_id": 0,
|
14 |
"eos_token_id": 2,
|
15 |
+
"hidden_act": "swiglu",
|
16 |
"hidden_dropout_prob": 0.1,
|
17 |
"hidden_size": 768,
|
18 |
"initializer_range": 0.02,
|
mlp.py
CHANGED
@@ -24,6 +24,8 @@ try:
|
|
24 |
except ImportError:
|
25 |
FusedMLP, ParallelFusedMLP = None, None
|
26 |
|
|
|
|
|
27 |
|
28 |
class Mlp(nn.Module):
|
29 |
def __init__(
|
@@ -31,7 +33,7 @@ class Mlp(nn.Module):
|
|
31 |
in_features,
|
32 |
hidden_features=None,
|
33 |
out_features=None,
|
34 |
-
activation=
|
35 |
bias1=True,
|
36 |
bias2=True,
|
37 |
return_residual=False,
|
@@ -60,7 +62,7 @@ class ParallelMLP(nn.Module):
|
|
60 |
in_features,
|
61 |
hidden_features=None,
|
62 |
out_features=None,
|
63 |
-
activation=
|
64 |
process_group: ProcessGroup = None,
|
65 |
sequence_parallel=True,
|
66 |
bias1=True,
|
|
|
24 |
except ImportError:
|
25 |
FusedMLP, ParallelFusedMLP = None, None
|
26 |
|
27 |
+
from .activation import swiglu
|
28 |
+
|
29 |
|
30 |
class Mlp(nn.Module):
|
31 |
def __init__(
|
|
|
33 |
in_features,
|
34 |
hidden_features=None,
|
35 |
out_features=None,
|
36 |
+
activation=swiglu,
|
37 |
bias1=True,
|
38 |
bias2=True,
|
39 |
return_residual=False,
|
|
|
62 |
in_features,
|
63 |
hidden_features=None,
|
64 |
out_features=None,
|
65 |
+
activation=swiglu,
|
66 |
process_group: ProcessGroup = None,
|
67 |
sequence_parallel=True,
|
68 |
bias1=True,
|
modeling_xlm_roberta.py
CHANGED
@@ -45,6 +45,7 @@ from .embedding import XLMRobertaEmbeddings
|
|
45 |
from .mha import MHA
|
46 |
from .mlp import FusedMLP, Mlp
|
47 |
from .stochastic_depth import StochasticDepth
|
|
|
48 |
|
49 |
|
50 |
try:
|
@@ -118,19 +119,19 @@ def create_mlp_cls(config, layer_idx=None, return_residual=False):
|
|
118 |
inner_dim = config.intermediate_size
|
119 |
fused_mlp = getattr(config, "fused_mlp", False)
|
120 |
if fused_mlp:
|
121 |
-
assert config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"], (
|
122 |
"fused_mlp only " "supports approximate gelu"
|
123 |
)
|
124 |
if not fused_mlp:
|
125 |
approximate = (
|
126 |
"tanh"
|
127 |
-
if config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
|
128 |
else "none"
|
129 |
)
|
130 |
mlp_cls = partial(
|
131 |
Mlp,
|
132 |
hidden_features=inner_dim,
|
133 |
-
activation=
|
134 |
return_residual=return_residual,
|
135 |
)
|
136 |
else:
|
@@ -330,10 +331,10 @@ class XLMRobertaPredictionHeadTransform(nn.Module):
|
|
330 |
self.dense = linear_cls(config.hidden_size, config.hidden_size)
|
331 |
approximate = (
|
332 |
"tanh"
|
333 |
-
if config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
|
334 |
else "none"
|
335 |
)
|
336 |
-
self.transform_act_fn =
|
337 |
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
338 |
|
339 |
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
@@ -424,6 +425,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
424 |
"gelu_new",
|
425 |
"gelu_fast",
|
426 |
"gelu_pytorch_tanh",
|
|
|
427 |
]
|
428 |
|
429 |
self.embeddings = XLMRobertaEmbeddings(
|
|
|
45 |
from .mha import MHA
|
46 |
from .mlp import FusedMLP, Mlp
|
47 |
from .stochastic_depth import StochasticDepth
|
48 |
+
from .activation import swiglu
|
49 |
|
50 |
|
51 |
try:
|
|
|
119 |
inner_dim = config.intermediate_size
|
120 |
fused_mlp = getattr(config, "fused_mlp", False)
|
121 |
if fused_mlp:
|
122 |
+
assert config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh", "swiglu"], (
|
123 |
"fused_mlp only " "supports approximate gelu"
|
124 |
)
|
125 |
if not fused_mlp:
|
126 |
approximate = (
|
127 |
"tanh"
|
128 |
+
if config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh", "swiglu"]
|
129 |
else "none"
|
130 |
)
|
131 |
mlp_cls = partial(
|
132 |
Mlp,
|
133 |
hidden_features=inner_dim,
|
134 |
+
activation=swiglu,
|
135 |
return_residual=return_residual,
|
136 |
)
|
137 |
else:
|
|
|
331 |
self.dense = linear_cls(config.hidden_size, config.hidden_size)
|
332 |
approximate = (
|
333 |
"tanh"
|
334 |
+
if config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh", "swiglu"]
|
335 |
else "none"
|
336 |
)
|
337 |
+
self.transform_act_fn = swiglu
|
338 |
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
339 |
|
340 |
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
|
|
425 |
"gelu_new",
|
426 |
"gelu_fast",
|
427 |
"gelu_pytorch_tanh",
|
428 |
+
"swiglu",
|
429 |
]
|
430 |
|
431 |
self.embeddings = XLMRobertaEmbeddings(
|
modeling_xlm_roberta_for_glue.py
DELETED
@@ -1,109 +0,0 @@
|
|
1 |
-
from typing import Optional, Union, Tuple
|
2 |
-
|
3 |
-
import torch
|
4 |
-
from torch import nn
|
5 |
-
from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
|
6 |
-
from transformers.modeling_outputs import SequenceClassifierOutput, QuestionAnsweringModelOutput, TokenClassifierOutput
|
7 |
-
|
8 |
-
from .modeling_xlm_roberta import XLMRobertaPreTrainedModel, XLMRobertaModel
|
9 |
-
from .configuration_xlm_roberta import XLMRobertaFlashConfig
|
10 |
-
|
11 |
-
|
12 |
-
class XLMRobertaForSequenceClassification(XLMRobertaPreTrainedModel):
|
13 |
-
def __init__(self, config: XLMRobertaFlashConfig):
|
14 |
-
super().__init__(config)
|
15 |
-
self.num_labels = config.num_labels
|
16 |
-
self.config = config
|
17 |
-
|
18 |
-
self.roberta = XLMRobertaModel(config)
|
19 |
-
classifier_dropout = (
|
20 |
-
config.classifier_dropout
|
21 |
-
if config.classifier_dropout is not None
|
22 |
-
else config.hidden_dropout_prob
|
23 |
-
)
|
24 |
-
self.dropout = nn.Dropout(classifier_dropout)
|
25 |
-
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
26 |
-
|
27 |
-
# Initialize weights and apply final processing
|
28 |
-
self.post_init()
|
29 |
-
|
30 |
-
|
31 |
-
def forward(
|
32 |
-
self,
|
33 |
-
input_ids: Optional[torch.Tensor] = None,
|
34 |
-
attention_mask: Optional[torch.Tensor] = None,
|
35 |
-
token_type_ids: Optional[torch.Tensor] = None,
|
36 |
-
position_ids: Optional[torch.Tensor] = None,
|
37 |
-
head_mask: Optional[torch.Tensor] = None,
|
38 |
-
inputs_embeds: Optional[torch.Tensor] = None,
|
39 |
-
labels: Optional[torch.Tensor] = None,
|
40 |
-
output_attentions: Optional[bool] = None,
|
41 |
-
output_hidden_states: Optional[bool] = None,
|
42 |
-
return_dict: Optional[bool] = None,
|
43 |
-
) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
|
44 |
-
r"""
|
45 |
-
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
46 |
-
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
47 |
-
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
48 |
-
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
49 |
-
"""
|
50 |
-
return_dict = (
|
51 |
-
return_dict if return_dict is not None else self.config.use_return_dict
|
52 |
-
)
|
53 |
-
|
54 |
-
assert head_mask is None
|
55 |
-
assert inputs_embeds is None
|
56 |
-
assert output_attentions is None
|
57 |
-
assert output_hidden_states is None
|
58 |
-
assert return_dict
|
59 |
-
outputs = self.roberta(
|
60 |
-
input_ids,
|
61 |
-
attention_mask=attention_mask,
|
62 |
-
token_type_ids=token_type_ids,
|
63 |
-
position_ids=position_ids,
|
64 |
-
head_mask=head_mask,
|
65 |
-
inputs_embeds=inputs_embeds,
|
66 |
-
output_attentions=output_attentions,
|
67 |
-
output_hidden_states=output_hidden_states,
|
68 |
-
return_dict=return_dict,
|
69 |
-
)
|
70 |
-
|
71 |
-
pooled_output = outputs[1]
|
72 |
-
|
73 |
-
pooled_output = self.dropout(pooled_output)
|
74 |
-
logits = self.classifier(pooled_output)
|
75 |
-
|
76 |
-
loss = None
|
77 |
-
if labels is not None:
|
78 |
-
if self.config.problem_type is None:
|
79 |
-
if self.num_labels == 1:
|
80 |
-
self.config.problem_type = "regression"
|
81 |
-
elif self.num_labels > 1 and (
|
82 |
-
labels.dtype == torch.long or labels.dtype == torch.int
|
83 |
-
):
|
84 |
-
self.config.problem_type = "single_label_classification"
|
85 |
-
else:
|
86 |
-
self.config.problem_type = "multi_label_classification"
|
87 |
-
|
88 |
-
if self.config.problem_type == "regression":
|
89 |
-
loss_fct = MSELoss()
|
90 |
-
if self.num_labels == 1:
|
91 |
-
loss = loss_fct(logits.squeeze(), labels.squeeze())
|
92 |
-
else:
|
93 |
-
loss = loss_fct(logits, labels)
|
94 |
-
elif self.config.problem_type == "single_label_classification":
|
95 |
-
loss_fct = CrossEntropyLoss()
|
96 |
-
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
97 |
-
elif self.config.problem_type == "multi_label_classification":
|
98 |
-
loss_fct = BCEWithLogitsLoss()
|
99 |
-
loss = loss_fct(logits, labels)
|
100 |
-
if not return_dict:
|
101 |
-
output = (logits,) + outputs[2:]
|
102 |
-
return ((loss,) + output) if loss is not None else output
|
103 |
-
|
104 |
-
return SequenceClassifierOutput(
|
105 |
-
loss=loss,
|
106 |
-
logits=logits,
|
107 |
-
hidden_states=outputs.hidden_states,
|
108 |
-
attentions=outputs.attentions,
|
109 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|