activation.py ADDED
@@ -0,0 +1,23 @@
+import torch.nn as nn
+import torch.nn.functional as F
+
+try:
+    from flash_attn.ops.activations import swiglu as flash_swiglu
+except ImportError:
+    flash_swiglu = None
+
+if flash_swiglu is None:
+    # PyTorch implementation of SwiGLU
+    class SwiGLU(nn.Module):
+        def forward(self, x):
+            x, gate = x.chunk(2, dim=-1)
+            return F.silu(gate) * x
+
+    def swiglu(x):
+        layer = SwiGLU()
+        return layer(x)
+
+else:
+    # Use Flash Attention's built-in swiglu
+    def swiglu(x):
+        return flash_swiglu(x)
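For reference, a minimal sketch of what the pure-PyTorch fallback above computes (the tensor and its shape are illustrative, not part of the diff): the input is split in half along the last dimension and one half is gated by SiLU of the other, so the output feature dimension is half the input's.

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 4, 8)  # hypothetical (batch, seq, features) input

# Same computation as the SwiGLU fallback in activation.py.
h, gate = x.chunk(2, dim=-1)   # two (2, 4, 4) halves
out = F.silu(gate) * h         # gated output, shape (2, 4, 4)

print(out.shape)  # torch.Size([2, 4, 4]) -- feature dim is halved
```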
config.json CHANGED
@@ -12,7 +12,7 @@
   "attention_probs_dropout_prob": 0.1,
   "bos_token_id": 0,
   "eos_token_id": 2,
-  "hidden_act": "gelu",
+  "hidden_act": "swiglu",
   "hidden_dropout_prob": 0.1,
   "hidden_size": 768,
   "initializer_range": 0.02,
mlp.py CHANGED
@@ -24,6 +24,8 @@ try:
 except ImportError:
     FusedMLP, ParallelFusedMLP = None, None
 
+from .activation import swiglu
+
 
 class Mlp(nn.Module):
     def __init__(
@@ -31,7 +33,7 @@ class Mlp(nn.Module):
         in_features,
         hidden_features=None,
         out_features=None,
-        activation=F.gelu,
+        activation=swiglu,
         bias1=True,
         bias2=True,
         return_residual=False,
@@ -60,7 +62,7 @@ class ParallelMLP(nn.Module):
         in_features,
         hidden_features=None,
         out_features=None,
-        activation=F.gelu,
+        activation=swiglu,
         process_group: ProcessGroup = None,
         sequence_parallel=True,
         bias1=True,
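One practical consequence of making `swiglu` the default activation: it chunks the last dimension in two, so the activation output has `hidden_features // 2` features and the second projection has to be sized for that. A minimal sketch of the shape bookkeeping, assuming the usual fc1 -> activation -> fc2 layout (the class below is an illustrative stand-in, not the actual `Mlp`):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def swiglu(x):
    # Same computation as the fallback in activation.py.
    x, gate = x.chunk(2, dim=-1)
    return F.silu(gate) * x


class TinyGatedMlp(nn.Module):
    """Illustrative fc1 -> swiglu -> fc2 block (hypothetical, for shapes only)."""

    def __init__(self, in_features, hidden_features):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        # swiglu halves the feature dim, so fc2 consumes hidden_features // 2.
        self.fc2 = nn.Linear(hidden_features // 2, in_features)

    def forward(self, x):
        return self.fc2(swiglu(self.fc1(x)))


y = TinyGatedMlp(768, 3072)(torch.randn(1, 4, 768))
print(y.shape)  # torch.Size([1, 4, 768])
```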
modeling_xlm_roberta.py CHANGED
@@ -45,6 +45,7 @@ from .embedding import XLMRobertaEmbeddings
 from .mha import MHA
 from .mlp import FusedMLP, Mlp
 from .stochastic_depth import StochasticDepth
+from .activation import swiglu
 
 
 try:
@@ -118,19 +119,19 @@ def create_mlp_cls(config, layer_idx=None, return_residual=False):
     inner_dim = config.intermediate_size
     fused_mlp = getattr(config, "fused_mlp", False)
     if fused_mlp:
-        assert config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"], (
+        assert config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh", "swiglu"], (
             "fused_mlp only " "supports approximate gelu"
         )
     if not fused_mlp:
         approximate = (
             "tanh"
-            if config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
+            if config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh", "swiglu"]
             else "none"
         )
         mlp_cls = partial(
             Mlp,
             hidden_features=inner_dim,
-            activation=partial(F.gelu, approximate=approximate),
+            activation=swiglu,
             return_residual=return_residual,
         )
     else:
@@ -330,10 +331,10 @@ class XLMRobertaPredictionHeadTransform(nn.Module):
         self.dense = linear_cls(config.hidden_size, config.hidden_size)
         approximate = (
             "tanh"
-            if config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
+            if config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh", "swiglu"]
             else "none"
         )
-        self.transform_act_fn = nn.GELU(approximate=approximate)
+        self.transform_act_fn = swiglu
         self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -424,6 +425,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
             "gelu_new",
             "gelu_fast",
             "gelu_pytorch_tanh",
+            "swiglu",
         ]
 
         self.embeddings = XLMRobertaEmbeddings(
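One thing worth flagging on the `XLMRobertaPredictionHeadTransform` hunk: `swiglu` halves the last dimension, while the unchanged context line still builds `self.layer_norm` for `config.hidden_size` features, so this path would hit a shape mismatch if it is exercised. A minimal sketch with standalone layers (not the actual class) makes that visible:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def swiglu(x):
    x, gate = x.chunk(2, dim=-1)
    return F.silu(gate) * x


hidden_size = 768  # matches config.json above
dense = nn.Linear(hidden_size, hidden_size)
layer_norm = nn.LayerNorm(hidden_size)

h = swiglu(dense(torch.randn(1, 4, hidden_size)))
print(h.shape)  # torch.Size([1, 4, 384]) -- half of hidden_size
# layer_norm(h) would raise a RuntimeError here, since the norm expects 768
# features; if this head is used, the dense layer or the norm needs to
# account for the halving.
```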
modeling_xlm_roberta_for_glue.py DELETED
@@ -1,109 +0,0 @@
-from typing import Optional, Union, Tuple
-
-import torch
-from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
-from transformers.modeling_outputs import SequenceClassifierOutput, QuestionAnsweringModelOutput, TokenClassifierOutput
-
-from .modeling_xlm_roberta import XLMRobertaPreTrainedModel, XLMRobertaModel
-from .configuration_xlm_roberta import XLMRobertaFlashConfig
-
-
-class XLMRobertaForSequenceClassification(XLMRobertaPreTrainedModel):
-    def __init__(self, config: XLMRobertaFlashConfig):
-        super().__init__(config)
-        self.num_labels = config.num_labels
-        self.config = config
-
-        self.roberta = XLMRobertaModel(config)
-        classifier_dropout = (
-            config.classifier_dropout
-            if config.classifier_dropout is not None
-            else config.hidden_dropout_prob
-        )
-        self.dropout = nn.Dropout(classifier_dropout)
-        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
-
-        # Initialize weights and apply final processing
-        self.post_init()
-
-
-    def forward(
-        self,
-        input_ids: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        token_type_ids: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.Tensor] = None,
-        head_mask: Optional[torch.Tensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        labels: Optional[torch.Tensor] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
-        r"""
-        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
-            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
-            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
-            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
-        """
-        return_dict = (
-            return_dict if return_dict is not None else self.config.use_return_dict
-        )
-
-        assert head_mask is None
-        assert inputs_embeds is None
-        assert output_attentions is None
-        assert output_hidden_states is None
-        assert return_dict
-        outputs = self.roberta(
-            input_ids,
-            attention_mask=attention_mask,
-            token_type_ids=token_type_ids,
-            position_ids=position_ids,
-            head_mask=head_mask,
-            inputs_embeds=inputs_embeds,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-
-        pooled_output = outputs[1]
-
-        pooled_output = self.dropout(pooled_output)
-        logits = self.classifier(pooled_output)
-
-        loss = None
-        if labels is not None:
-            if self.config.problem_type is None:
-                if self.num_labels == 1:
-                    self.config.problem_type = "regression"
-                elif self.num_labels > 1 and (
-                    labels.dtype == torch.long or labels.dtype == torch.int
-                ):
-                    self.config.problem_type = "single_label_classification"
-                else:
-                    self.config.problem_type = "multi_label_classification"
-
-            if self.config.problem_type == "regression":
-                loss_fct = MSELoss()
-                if self.num_labels == 1:
-                    loss = loss_fct(logits.squeeze(), labels.squeeze())
-                else:
-                    loss = loss_fct(logits, labels)
-            elif self.config.problem_type == "single_label_classification":
-                loss_fct = CrossEntropyLoss()
-                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
-            elif self.config.problem_type == "multi_label_classification":
-                loss_fct = BCEWithLogitsLoss()
-                loss = loss_fct(logits, labels)
-        if not return_dict:
-            output = (logits,) + outputs[2:]
-            return ((loss,) + output) if loss is not None else output
-
-        return SequenceClassifierOutput(
-            loss=loss,
-            logits=logits,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
-        )
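For context on what the removed head did: its loss branch dispatched to MSE, cross-entropy, or BCE-with-logits depending on `problem_type` and the label dtype. A standalone sketch of the three branches with illustrative shapes (not tied to the deleted class):

```python
import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

num_labels = 3
logits = torch.randn(4, num_labels)  # illustrative batch of 4

# single_label_classification: integer class ids -> cross-entropy
ce = CrossEntropyLoss()(logits.view(-1, num_labels), torch.tensor([0, 2, 1, 1]))

# multi_label_classification: float multi-hot targets -> BCE with logits
bce = BCEWithLogitsLoss()(logits, torch.randint(0, 2, (4, num_labels)).float())

# regression (num_labels == 1): mean-squared error on squeezed tensors
mse = MSELoss()(torch.randn(4, 1).squeeze(), torch.randn(4))

print(ce.item(), bce.item(), mse.item())
```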