Support `use_flash_attn` in `from_pretrained`

#18
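The change is aimed at letting callers pass `use_flash_attn` through `from_pretrained` so the FlashAttention code path can be toggled at load time. A minimal sketch of that usage, assuming a checkpoint that ships this implementation via `trust_remote_code` (the repository id below is illustrative):

```python
from transformers import AutoModel

# Illustrative checkpoint; any model whose auto_map points at this implementation works the same way.
model = AutoModel.from_pretrained(
    "jinaai/jina-embeddings-v3",
    trust_remote_code=True,
    use_flash_attn=False,  # forwarded into XLMRobertaFlashConfig, e.g. for CPU-only environments
)
```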
README.md CHANGED
@@ -1,114 +1,5 @@
1
- ---
2
- tags:
3
- - transformers
4
- - xlm-roberta
5
- library_name: transformers
6
- license: cc-by-nc-4.0
7
- language:
8
- - multilingual
9
- - af
10
- - am
11
- - ar
12
- - as
13
- - az
14
- - be
15
- - bg
16
- - bn
17
- - br
18
- - bs
19
- - ca
20
- - cs
21
- - cy
22
- - da
23
- - de
24
- - el
25
- - en
26
- - eo
27
- - es
28
- - et
29
- - eu
30
- - fa
31
- - fi
32
- - fr
33
- - fy
34
- - ga
35
- - gd
36
- - gl
37
- - gu
38
- - ha
39
- - he
40
- - hi
41
- - hr
42
- - hu
43
- - hy
44
- - id
45
- - is
46
- - it
47
- - ja
48
- - jv
49
- - ka
50
- - kk
51
- - km
52
- - kn
53
- - ko
54
- - ku
55
- - ky
56
- - la
57
- - lo
58
- - lt
59
- - lv
60
- - mg
61
- - mk
62
- - ml
63
- - mn
64
- - mr
65
- - ms
66
- - my
67
- - ne
68
- - nl
69
- - 'no'
70
- - om
71
- - or
72
- - pa
73
- - pl
74
- - ps
75
- - pt
76
- - ro
77
- - ru
78
- - sa
79
- - sd
80
- - si
81
- - sk
82
- - sl
83
- - so
84
- - sq
85
- - sr
86
- - su
87
- - sv
88
- - sw
89
- - ta
90
- - te
91
- - th
92
- - tl
93
- - tr
94
- - ug
95
- - uk
96
- - ur
97
- - uz
98
- - vi
99
- - xh
100
- - yi
101
- - zh
102
- ---
103
- Core implementation of Jina XLM-RoBERTa
104
 
105
- This implementation is adapted from [XLM-RoBERTa](https://huggingface.co/docs/transformers/en/model_doc/xlm-roberta). In contrast to the original implementation, this model uses rotary position embeddings and supports FlashAttention 2.
106
-
107
- ### Models that use this implementation
108
-
109
- - [jinaai/jina-embeddings-v3](https://huggingface.co/jinaai/jina-embeddings-v3)
110
- - [jinaai/jina-colbert-v2](https://huggingface.co/jinaai/jina-colbert-v2)
111
-
112
- ### Converting weights
113
-
114
- Weights from an [original XLM-RoBERTa model](https://huggingface.co/FacebookAI/xlm-roberta-large) can be converted using the `convert_roberta_weights_to_flash.py` script in the model repository.
 
1
+ # Converting Weights
 
2
 
3
+ ```
4
+ python3 -m "xlm-roberta-flash-implementation".convert_roberta_weights_to_flash --output pytorch_model_xlmr_flash.bin
5
+ ```
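After conversion, the resulting state dict can be loaded into this implementation. A hedged sketch, assuming the repository files (including the `config.json` added below) sit in the current directory and the command above produced `pytorch_model_xlmr_flash.bin`:

```python
import torch
from transformers import AutoConfig, AutoModel

# Resolve the custom config/model classes from the local modeling files via auto_map.
config = AutoConfig.from_pretrained(".", trust_remote_code=True)
model = AutoModel.from_config(config, trust_remote_code=True)

state_dict = torch.load("pytorch_model_xlmr_flash.bin", map_location="cpu")
missing, unexpected = model.load_state_dict(state_dict, strict=False)  # heads/poolers may differ
print(len(missing), len(unexpected))
```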
 
 
 
 
 
 
 
block.py CHANGED
@@ -8,14 +8,15 @@ from typing import Optional
8
 
9
  import torch
10
  import torch.nn as nn
 
11
  from torch import Tensor
12
 
 
13
  from .mha import MHA
14
  from .mlp import Mlp
15
- from .stochastic_depth import StochasticDepth
16
 
17
  try:
18
- from flash_attn.ops.triton.layer_norm import RMSNorm, layer_norm_fn
19
  except ImportError:
20
  layer_norm_fn, RMSNorm = None, None
21
 
@@ -232,9 +233,7 @@ class Block(nn.Module):
232
  is_rms_norm=isinstance(self.norm1, RMSNorm),
233
  )
234
  if not isinstance(self.mlp, nn.Identity):
235
- mlp_out = self.mlp(
236
- hidden_states, adapter_mask=mixer_kwargs.get("adapter_mask")
237
- )
238
  if self.return_residual: # mlp out is actually a pair here
239
  mlp_out, hidden_states = mlp_out
240
  if not self.fused_dropout_add_ln:
 
8
 
9
  import torch
10
  import torch.nn as nn
11
+ import torch.nn.functional as F
12
  from torch import Tensor
13
 
14
+ from .stochastic_depth import StochasticDepth
15
  from .mha import MHA
16
  from .mlp import Mlp
 
17
 
18
  try:
19
+ from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm
20
  except ImportError:
21
  layer_norm_fn, RMSNorm = None, None
22
 
 
233
  is_rms_norm=isinstance(self.norm1, RMSNorm),
234
  )
235
  if not isinstance(self.mlp, nn.Identity):
236
+ mlp_out = self.mlp(hidden_states)
 
 
237
  if self.return_residual: # mlp out is actually a pair here
238
  mlp_out, hidden_states = mlp_out
239
  if not self.fused_dropout_add_ln:
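block.py, like mha.py and mlp.py below, wraps every flash-attn import in try/except so the module still imports when the package is missing; callers are then expected to check for `None` before taking the fused path. A minimal sketch of that pattern in isolation:

```python
try:
    # Fused Triton RMSNorm/layer-norm kernels are only present when flash-attn is installed.
    from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm
except ImportError:
    layer_norm_fn, RMSNorm = None, None


def fused_dropout_add_ln_available() -> bool:
    # Code using fused_dropout_add_ln relies on this kind of check before calling the kernel.
    return layer_norm_fn is not None


print(fused_dropout_add_ln_available())
```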
config.json ADDED
@@ -0,0 +1,31 @@
1
+ {
2
+ "auto_map": {
3
+ "AutoConfig": "configuration_xlm_roberta.XLMRobertaFlashConfig",
4
+ "AutoModel": "modeling_xlm_roberta.XLMRobertaModel",
5
+ "AutoModelForPreTraining": "modeling_xlm_roberta.XLMRobertaForPreTraining",
6
+ "AutoModelForMaskedLM": "modeling_xlm_roberta.XLMRobertaForMaskedLM",
7
+ "AutoModelForSequenceClassification":"modeling_xlm_roberta.XLMRobertaForSequenceClassification"
8
+ },
9
+ "architectures": [
10
+ "XLMRobertaModel"
11
+ ],
12
+ "attention_probs_dropout_prob": 0.1,
13
+ "bos_token_id": 0,
14
+ "eos_token_id": 2,
15
+ "hidden_act": "gelu",
16
+ "hidden_dropout_prob": 0.1,
17
+ "hidden_size": 768,
18
+ "initializer_range": 0.02,
19
+ "intermediate_size": 3072,
20
+ "layer_norm_eps": 1e-05,
21
+ "max_position_embeddings": 8194,
22
+ "num_attention_heads": 12,
23
+ "num_hidden_layers": 12,
24
+ "output_past": true,
25
+ "pad_token_id": 1,
26
+ "position_embedding_type": "absolute",
27
+ "transformers_version": "4.17.0.dev0",
28
+ "type_vocab_size": 1,
29
+ "use_cache": false,
30
+ "vocab_size": 250002
31
+ }
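The `auto_map` section is what ties the `Auto*` classes to the custom code in this repository: with `trust_remote_code=True`, transformers imports the named classes from `configuration_xlm_roberta.py` / `modeling_xlm_roberta.py` instead of its built-in XLM-R implementation. A hedged sketch (the repository id is illustrative):

```python
from transformers import AutoConfig, AutoModel

repo = "jinaai/xlm-roberta-flash-implementation"  # illustrative; any repo shipping this config.json works

# auto_map -> "configuration_xlm_roberta.XLMRobertaFlashConfig" / "modeling_xlm_roberta.XLMRobertaModel"
config = AutoConfig.from_pretrained(repo, trust_remote_code=True)
model = AutoModel.from_config(config, trust_remote_code=True)
print(type(config).__name__, type(model).__name__)
```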
configuration_xlm_roberta.py CHANGED
@@ -1,94 +1,42 @@
1
- from typing import Any, Dict, List, Optional, Union
2
-
3
- import torch
4
  from transformers import PretrainedConfig
5
-
6
 
7
  class XLMRobertaFlashConfig(PretrainedConfig):
8
-
9
- model_type = "xlm-roberta"
10
-
11
  def __init__(
12
- self,
13
- vocab_size: int = 250002,
14
- hidden_size: int = 1024,
15
- num_hidden_layers: int = 24,
16
- num_attention_heads: int = 16,
17
- intermediate_size: int = 4096,
18
- hidden_act: str = "gelu",
19
- hidden_dropout_prob: float = 0.1,
20
- attention_probs_dropout_prob: float = 0.1,
21
- max_position_embeddings: int = 8194,
22
- type_vocab_size: int = 1,
23
- initializer_range: float = 0.02,
24
- layer_norm_eps: float = 1e-05,
25
- pad_token_id: int = 1,
26
- bos_token_id: int = 0,
27
- eos_token_id: int = 2,
28
- position_embedding_type: str = "rotary",
29
- rotary_emb_base: float = 10000.0,
30
- use_cache: bool = True,
31
- use_reentrant: bool = False,
32
- classifier_dropout: Optional[float] = None,
33
- lora_adaptations: Optional[List[str]] = None,
34
- task_instructions: Optional[Dict[str, str]] = None,
35
- lora_rank: int = 4,
36
- lora_dropout_p: float = 0.0,
37
- lora_alpha: int = 1,
38
- lora_main_params_trainable: bool = False,
39
- load_trained_adapters: bool = False,
40
- use_flash_attn: bool = True,
41
- torch_dtype: Optional[Union[str, torch.dtype]] = None,
42
- emb_pooler: Optional[str] = None,
43
- matryoshka_dimensions: Optional[List[int]] = None,
44
- truncate_dim: Optional[int] = None,
45
- **kwargs: Dict[str, Any],
46
  ):
47
- """
48
- Initialize the XLMRobertaFlashConfig configuration.
49
 
50
- Args:
51
- vocab_size (int): Size of the vocabulary.
52
- hidden_size (int): Dimensionality of the encoder layers and the pooler layer.
53
- num_hidden_layers (int): Number of hidden layers in the Transformer encoder.
54
- num_attention_heads (int): Number of attention heads for each attention layer in the Transformer encoder.
55
- intermediate_size (int): Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer.
56
- hidden_act (str): The activation function to use.
57
- hidden_dropout_prob (float): The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
58
- attention_probs_dropout_prob (float): The dropout ratio for the attention probabilities.
59
- max_position_embeddings (int): The maximum length of the position embeddings.
60
- type_vocab_size (int): The vocabulary size of the token type ids.
61
- initializer_range (float): The standard deviation for initializing all weight matrices.
62
- layer_norm_eps (float): The epsilon used by the layer normalization layers.
63
- pad_token_id (int): The ID of the padding token.
64
- bos_token_id (int): The ID of the beginning-of-sequence token.
65
- eos_token_id (int): The ID of the end-of-sequence token.
66
- position_embedding_type (str): Type of position embeddings. Options are 'absolute', 'alibi', or 'rotary'.
67
- rotary_emb_base (float): Base for rotary embeddings.
68
- use_cache (bool): Whether or not the model should return the last key/values attentions (not used by all models).
69
- use_reentrant (bool): Whether or not the model should enable the 'use_reentrant' flag in gradient checkpointing.
70
- classifier_dropout (Optional[float]): The dropout ratio for the classification head.
71
- lora_adaptations (Optional[List[str]]): LoRA adaptations configuration.
72
- lora_prompts (Optional[Dict[str, str]]): LoRA prompts configuration.
73
- lora_rank (int): Rank for LoRA adaptations.
74
- lora_dropout_p (float): Dropout probability for LoRA adaptations.
75
- lora_alpha (int): Alpha parameter for LoRA.
76
- lora_main_params_trainable (bool): Whether to make the main model parameters trainable when using LoRA.
77
- load_trained_adapters (bool): Whether to load trained adapters.
78
- use_flash_attn (bool): Whether to use FlashAttention.
79
- torch_dtype (Optional[Union[str, torch.dtype]]): Data type for the tensors.
80
- emb_pooler (Optional[str]): Pooling layer configuration.
81
- matryoshka_dimensions (Optional[List[int]]): Configuration for matryoshka dimension reduction.
82
- truncate_dim (Optional[int]): Dimension to truncate embeddings to, if any.
83
- **kwargs (Dict[str, Any]): Additional keyword arguments passed to the configuration.
84
- """
85
-
86
- super().__init__(
87
- pad_token_id=pad_token_id,
88
- bos_token_id=bos_token_id,
89
- eos_token_id=eos_token_id,
90
- **kwargs,
91
- )
92
 
93
  self.vocab_size = vocab_size
94
  self.hidden_size = hidden_size
@@ -103,13 +51,10 @@ class XLMRobertaFlashConfig(PretrainedConfig):
103
  self.initializer_range = initializer_range
104
  self.layer_norm_eps = layer_norm_eps
105
  self.position_embedding_type = position_embedding_type
106
- self.rotary_emb_base = rotary_emb_base
107
  self.use_cache = use_cache
108
- self.use_reentrant = use_reentrant
109
  self.classifier_dropout = classifier_dropout
110
  self.load_trained_adapters = load_trained_adapters
111
  self.lora_adaptations = lora_adaptations
112
- self.task_instructions = task_instructions
113
  self.lora_rank = lora_rank
114
  self.lora_dropout_p = lora_dropout_p
115
  self.lora_alpha = lora_alpha
@@ -118,13 +63,7 @@ class XLMRobertaFlashConfig(PretrainedConfig):
118
  self.emb_pooler = emb_pooler
119
  self.matryoshka_dimensions = matryoshka_dimensions
120
  self.truncate_dim = truncate_dim
121
- if (
122
- torch_dtype
123
- and hasattr(torch, torch_dtype)
124
- and type(getattr(torch, torch_dtype)) is torch.dtype
125
- ):
126
  self.torch_dtype = getattr(torch, torch_dtype)
127
  else:
128
  self.torch_dtype = torch_dtype
129
- if not self.use_flash_attn or not torch.cuda.is_available():
130
- self.torch_dtype = torch.float32
 
 
 
 
1
  from transformers import PretrainedConfig
2
+ import torch
3
 
4
  class XLMRobertaFlashConfig(PretrainedConfig):
 
 
 
5
  def __init__(
6
+ self,
7
+ vocab_size=30522,
8
+ hidden_size=768,
9
+ num_hidden_layers=12,
10
+ num_attention_heads=12,
11
+ intermediate_size=3072,
12
+ hidden_act="gelu",
13
+ hidden_dropout_prob=0.1,
14
+ attention_probs_dropout_prob=0.1,
15
+ max_position_embeddings=512,
16
+ type_vocab_size=2,
17
+ initializer_range=0.02,
18
+ layer_norm_eps=1e-12,
19
+ pad_token_id=1,
20
+ bos_token_id=0,
21
+ eos_token_id=2,
22
+ position_embedding_type="absolute",
23
+ use_cache=True,
24
+ classifier_dropout=None,
25
+ lora_adaptations=None,
26
+ lora_rank=4,
27
+ lora_dropout_p=0.0,
28
+ lora_alpha=1,
29
+ lora_main_params_trainable=False,
30
+ load_trained_adapters=False,
31
+ use_flash_attn=True,
32
+ torch_dtype=None,
33
+ emb_pooler=None,
34
+ matryoshka_dimensions=None,
35
+ truncate_dim=None,
36
+ **kwargs,
 
 
 
37
  ):
38
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
 
39
 
40
 
41
  self.vocab_size = vocab_size
42
  self.hidden_size = hidden_size
 
51
  self.initializer_range = initializer_range
52
  self.layer_norm_eps = layer_norm_eps
53
  self.position_embedding_type = position_embedding_type
 
54
  self.use_cache = use_cache
 
55
  self.classifier_dropout = classifier_dropout
56
  self.load_trained_adapters = load_trained_adapters
57
  self.lora_adaptations = lora_adaptations
 
58
  self.lora_rank = lora_rank
59
  self.lora_dropout_p = lora_dropout_p
60
  self.lora_alpha = lora_alpha
 
63
  self.emb_pooler = emb_pooler
64
  self.matryoshka_dimensions = matryoshka_dimensions
65
  self.truncate_dim = truncate_dim
66
+ if torch_dtype and hasattr(torch, torch_dtype) and type(getattr(torch, torch_dtype)) is torch.dtype:
 
 
 
 
67
  self.torch_dtype = getattr(torch, torch_dtype)
68
  else:
69
  self.torch_dtype = torch_dtype
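The `torch_dtype` handling above resolves dtype names such as `"bfloat16"` to real `torch.dtype` objects via `getattr(torch, ...)` and passes anything else through. A standalone sketch of that resolution (an explicit string check is added here for safety; the function name is illustrative):

```python
import torch


def resolve_torch_dtype(torch_dtype):
    # Mirrors the config logic: turn a dtype name into a torch.dtype, otherwise pass it through.
    if (
        isinstance(torch_dtype, str)
        and hasattr(torch, torch_dtype)
        and type(getattr(torch, torch_dtype)) is torch.dtype
    ):
        return getattr(torch, torch_dtype)
    return torch_dtype


print(resolve_torch_dtype("bfloat16"))     # torch.bfloat16
print(resolve_torch_dtype(None))           # None
print(resolve_torch_dtype(torch.float16))  # torch.float16 (already a dtype)
```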
 
 
embedding.py CHANGED
@@ -5,8 +5,10 @@
5
 
6
  import torch
7
  import torch.nn as nn
8
- from transformers.models.xlm_roberta.modeling_xlm_roberta import \
9
- create_position_ids_from_input_ids
 
 
10
 
11
 
12
  class XLMRobertaEmbeddings(nn.Module):
@@ -36,60 +38,25 @@ class XLMRobertaEmbeddings(nn.Module):
36
  max_position_embeddings, embed_dim, **factory_kwargs
37
  )
38
  if self.type_vocab_size > 0:
39
- self.token_type_embeddings = nn.Embedding(
40
- type_vocab_size, embed_dim, **factory_kwargs
41
- )
42
 
43
- def forward(
44
- self, input_ids, position_ids=None, token_type_ids=None, adapter_mask=None
45
- ):
46
  """
47
  input_ids: (batch, seqlen)
48
  position_ids: (batch, seqlen)
49
  token_type_ids: (batch, seqlen)
50
- adapter_mask: (batch, 1)
51
  """
52
  batch_size, seqlen = input_ids.shape
53
- if adapter_mask is not None:
54
- unique_tasks = torch.unique(adapter_mask)
55
- embedding_dtype = next(self.word_embeddings.parameters()).dtype
56
- embeddings = torch.empty(
57
- *input_ids.shape,
58
- self.word_embeddings.embedding_dim,
59
- dtype=embedding_dtype,
60
- device=input_ids.device
61
- )
62
- for task_id in unique_tasks:
63
- task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
64
- task_input_ids = input_ids[task_indices]
65
- task_embeddings = self.word_embeddings(task_input_ids, task_id=task_id)
66
- embeddings[task_indices] = task_embeddings
67
- else:
68
- embeddings = self.word_embeddings(input_ids)
69
  if self.max_position_embeddings > 0:
70
  if position_ids is None:
71
- position_ids = create_position_ids_from_input_ids(
72
- input_ids, padding_idx=self.word_embeddings.padding_idx
73
- ).to(input_ids.device)
74
  position_embeddings = self.position_embeddings(position_ids)
75
  embeddings = embeddings + position_embeddings
76
  if self.type_vocab_size > 0:
77
  if token_type_ids is None:
78
- token_type_ids = torch.zeros(
79
- seqlen, dtype=torch.long, device=input_ids.device
80
- )
81
-
82
- if adapter_mask is not None:
83
- unique_tasks = torch.unique(adapter_mask)
84
- for task_id in unique_tasks:
85
- task_token_type_embeddings = self.token_type_embeddings(
86
- token_type_ids, task_id=task_id
87
- )
88
- task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
89
- embeddings[task_indices] = (
90
- embeddings[task_indices] + task_token_type_embeddings
91
- )
92
- else:
93
- token_type_embeddings = self.token_type_embeddings(token_type_ids)
94
- embeddings = embeddings + token_type_embeddings
95
  return embeddings
 
5
 
6
  import torch
7
  import torch.nn as nn
8
+ from einops import rearrange
9
+ from torch import Tensor
10
+
11
+ from transformers.models.xlm_roberta.modeling_xlm_roberta import create_position_ids_from_input_ids
12
 
13
 
14
  class XLMRobertaEmbeddings(nn.Module):
 
38
  max_position_embeddings, embed_dim, **factory_kwargs
39
  )
40
  if self.type_vocab_size > 0:
41
+ self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim, **factory_kwargs)
 
 
42
 
43
+ def forward(self, input_ids, position_ids=None, token_type_ids=None):
 
 
44
  """
45
  input_ids: (batch, seqlen)
46
  position_ids: (batch, seqlen)
47
  token_type_ids: (batch, seqlen)
 
48
  """
49
  batch_size, seqlen = input_ids.shape
50
+ embeddings = self.word_embeddings(input_ids)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  if self.max_position_embeddings > 0:
52
  if position_ids is None:
53
+ position_ids =create_position_ids_from_input_ids(input_ids, padding_idx=self.word_embeddings.padding_idx).to(input_ids.device)
54
+ # position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
 
55
  position_embeddings = self.position_embeddings(position_ids)
56
  embeddings = embeddings + position_embeddings
57
  if self.type_vocab_size > 0:
58
  if token_type_ids is None:
59
+ token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=input_ids.device)
60
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
61
+ embeddings = embeddings + token_type_embeddings
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  return embeddings
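When `position_ids` is not supplied, the embeddings fall back to `create_position_ids_from_input_ids`, which keeps `padding_idx` for pad tokens and numbers real tokens upward from `padding_idx + 1` (RoBERTa's convention, which is also why `max_position_embeddings` is 8194 rather than 8192). A small illustration:

```python
import torch
from transformers.models.xlm_roberta.modeling_xlm_roberta import create_position_ids_from_input_ids

input_ids = torch.tensor([[0, 581, 3293, 2, 1, 1]])  # 1 is XLM-R's pad token id
position_ids = create_position_ids_from_input_ids(input_ids, padding_idx=1)
print(position_ids)  # tensor([[2, 3, 4, 5, 1, 1]]): pads stay at padding_idx, others count up from it
```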
mha.py CHANGED
@@ -1,6 +1,5 @@
1
  # This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py
2
  # Commit id: 6bbc532388e61185a92e2a563126739967b4c8c5
3
- # Rotary varlen support from https://github.com/Dao-AILab/flash-attention/pull/556
4
 
5
  # Copyright (c) 2023, Tri Dao.
6
 
@@ -12,23 +11,27 @@ import torch.nn as nn
12
  from einops import rearrange, repeat
13
 
14
  try:
15
- from flash_attn import (flash_attn_kvpacked_func,
16
- flash_attn_qkvpacked_func,
17
- flash_attn_varlen_kvpacked_func,
18
- flash_attn_varlen_qkvpacked_func,
19
- flash_attn_with_kvcache)
 
 
20
  except ImportError:
21
  flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None
22
  flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None
23
  flash_attn_with_kvcache = None
24
 
25
  try:
26
- from flash_attn.ops.fused_dense import (ColumnParallelLinear, FusedDense,
27
- RowParallelLinear)
28
  except ImportError:
29
  FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None
30
 
31
- from .rotary import RotaryEmbedding
 
 
 
32
 
33
 
34
  # From https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742
@@ -44,9 +47,7 @@ def get_alibi_slopes(nheads):
44
  closest_power_of_2 = 2 ** math.floor(math.log2(nheads))
45
  return (
46
  get_slopes_power_of_2(closest_power_of_2)
47
- + get_alibi_slopes(2 * closest_power_of_2)[0::2][
48
- : nheads - closest_power_of_2
49
- ]
50
  )
51
 
52
 
@@ -71,9 +72,7 @@ class FlashSelfAttention(nn.Module):
71
  deterministic=False,
72
  ):
73
  super().__init__()
74
- assert (
75
- flash_attn_varlen_qkvpacked_func is not None
76
- ), "FlashAttention is not installed"
77
  assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed"
78
  self.causal = causal
79
  self.softmax_scale = softmax_scale
@@ -153,9 +152,7 @@ class FlashCrossAttention(nn.Module):
153
  deterministic=False,
154
  ):
155
  super().__init__()
156
- assert (
157
- flash_attn_varlen_kvpacked_func is not None
158
- ), "FlashAttention is not installed"
159
  assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
160
  self.causal = causal
161
  self.softmax_scale = softmax_scale
@@ -321,10 +318,7 @@ class CrossAttention(nn.Module):
321
  scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
322
  if key_padding_mask is not None:
323
  padding_mask = torch.full(
324
- (batch_size, seqlen_k),
325
- -10000.0,
326
- dtype=scores.dtype,
327
- device=scores.device,
328
  )
329
  padding_mask.masked_fill_(key_padding_mask, 0.0)
330
  # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
@@ -436,26 +430,20 @@ class MHA(nn.Module):
436
  else:
437
  alibi_slopes = None
438
  if window_size != (-1, -1):
439
- assert (
440
- use_flash_attn
441
- ), "Local (sliding window) attention code path requires flash_attn"
442
 
443
  self.num_heads = num_heads
444
  self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
445
  assert (
446
  self.num_heads % self.num_heads_kv == 0
447
  ), "num_heads must be divisible by num_heads_kv"
448
- assert (
449
- self.embed_dim % num_heads == 0
450
- ), "embed_dim must be divisible by num_heads"
451
  self.head_dim = self.embed_dim // num_heads
452
  qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
453
  kv_dim = 2 * self.head_dim * self.num_heads_kv
454
 
455
  if self.rotary_emb_dim > 0:
456
- assert (
457
- not cross_attn
458
- ), "MHA with rotary embedding does not support cross-attention yet"
459
  assert RotaryEmbedding is not None, "rotary_emb is not installed"
460
  self.rotary_emb = RotaryEmbedding(
461
  self.rotary_emb_dim,
@@ -463,41 +451,29 @@ class MHA(nn.Module):
463
  scale_base=rotary_emb_scale_base,
464
  interleaved=rotary_emb_interleaved,
465
  device=device,
466
- use_flash_attn=use_flash_attn,
467
  )
468
 
469
  if fused_bias_fc and FusedDense is None:
470
  raise ImportError("fused_dense is not installed")
471
-
472
  linear_cls = nn.Linear if not fused_bias_fc else FusedDense
473
  linear_resid_cls = (
474
- LinearResidual
475
- if not fused_bias_fc
476
- else partial(FusedDense, return_residual=True)
477
  )
478
  wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
479
  inner_attn_cls = (
480
- partial(
481
- FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size
482
- )
483
  if use_flash_attn
484
  else SelfAttention
485
  )
486
  inner_cross_attn_cls = (
487
- partial(
488
- FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size
489
- )
490
  if use_flash_attn
491
  else CrossAttention
492
  )
493
  if not self.cross_attn:
494
- self.Wqkv = wqkv_cls(
495
- embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs
496
- )
497
  else:
498
- self.Wq = linear_cls(
499
- embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs
500
- )
501
  self.Wkv = wqkv_cls(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs)
502
  if self.dwconv:
503
  if self.num_heads_kv == self.num_heads:
@@ -508,9 +484,7 @@ class MHA(nn.Module):
508
  self.dwconv_q = nn.Conv1d(
509
  embed_dim, embed_dim, kernel_size=3, padding=2, groups=embed_dim
510
  )
511
- self.dwconv_kv = nn.Conv1d(
512
- kv_dim, kv_dim, kernel_size=3, padding=2, groups=kv_dim
513
- )
514
  self.inner_attn = inner_attn_cls(
515
  causal=causal,
516
  softmax_scale=softmax_scale,
@@ -519,9 +493,7 @@ class MHA(nn.Module):
519
  self.inner_cross_attn = inner_cross_attn_cls(
520
  causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
521
  )
522
- self.out_proj = linear_cls(
523
- embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs
524
- )
525
 
526
  def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
527
  dtype = self.out_proj.weight.dtype if dtype is None else dtype
@@ -539,9 +511,7 @@ class MHA(nn.Module):
539
  def _update_kv_cache(self, kv, inference_params):
540
  """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
541
  assert not self.dwconv, "Generation does not support dwconv yet"
542
- assert (
543
- self.layer_idx is not None
544
- ), "Generation requires layer_idx in the constructor"
545
  return _update_kv_cache(kv, inference_params, self.layer_idx)
546
 
547
  def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
@@ -557,10 +527,7 @@ class MHA(nn.Module):
557
  self.rotary_emb._update_cos_sin_cache(
558
  inference_params.max_seqlen, device=q.device, dtype=q.dtype
559
  )
560
- rotary_cos, rotary_sin = (
561
- self.rotary_emb._cos_cached,
562
- self.rotary_emb._sin_cached,
563
- )
564
  else:
565
  rotary_cos, rotary_sin = None, None
566
  batch = q.shape[0]
@@ -582,9 +549,7 @@ class MHA(nn.Module):
582
  cache_seqlens=cache_seqlens,
583
  softmax_scale=self.inner_cross_attn.softmax_scale,
584
  causal=self.inner_cross_attn.causal,
585
- rotary_interleaved=(
586
- self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False
587
- ),
588
  alibi_slopes=alibi_slopes,
589
  )
590
  return context
@@ -629,7 +594,6 @@ class MHA(nn.Module):
629
  max_seqlen=None,
630
  mixer_subset=None,
631
  inference_params=None,
632
- adapter_mask=None,
633
  **kwargs,
634
  ):
635
  """
@@ -655,6 +619,7 @@ class MHA(nn.Module):
655
  assert key_padding_mask is None
656
  assert self.use_flash_attn
657
  assert not self.dwconv
 
658
  if key_padding_mask is not None:
659
  assert cu_seqlens is None
660
  assert max_seqlen is None
@@ -678,50 +643,19 @@ class MHA(nn.Module):
678
  else inference_params.seqlen_offset
679
  )
680
  )
681
- rotary_max_seqlen = (
682
- inference_params.max_sequence_len
683
- if inference_params is not None
684
- else max_seqlen
685
- )
686
  if not self.cross_attn and self.num_heads_kv == self.num_heads:
687
  assert x_kv is None and mixer_subset is None
688
-
689
- if adapter_mask is not None:
690
- unique_tasks = torch.unique(adapter_mask)
691
- qkv_dtype = next(self.Wqkv.parameters()).dtype
692
- qkv = torch.empty(
693
- *x.shape[:-1],
694
- self.Wqkv.out_features,
695
- dtype=qkv_dtype,
696
- device=x.device,
697
- )
698
- for task_id in unique_tasks:
699
- task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
700
- task_tensor = x[task_indices]
701
- if not self.return_residual:
702
- task_qkv = self.Wqkv(task_tensor, task_id=task_id)
703
- else:
704
- task_qkv, _ = self.Wqkv(
705
- task_tensor, task_id=task_id, residual=True
706
- )
707
- qkv[task_indices] = task_qkv
708
  else:
709
- if not self.return_residual:
710
- qkv = self.Wqkv(x)
711
- else:
712
- if hasattr(self.Wqkv, "parametrizations"):
713
- qkv, x = self.Wqkv(x, residual=True)
714
- else:
715
- qkv, x = self.Wqkv(x)
716
-
717
  if self.dwconv:
718
  qkv = rearrange(
719
- self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2],
720
- "b d s -> b s d",
721
  ).contiguous()
722
- qkv = rearrange(
723
- qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim
724
- )
725
  if (
726
  inference_params is None
727
  or inference_params.seqlen_offset == 0
@@ -730,18 +664,13 @@ class MHA(nn.Module):
730
  ):
731
  if self.rotary_emb_dim > 0:
732
  qkv = self.rotary_emb(
733
- qkv,
734
- seqlen_offset=seqlen_offset,
735
- cu_seqlens=cu_seqlens,
736
- max_seqlen=rotary_max_seqlen,
737
  )
738
  if inference_params is None:
739
  if not self.checkpointing:
740
  context = self.inner_attn(qkv, **kwargs)
741
  else:
742
- context = torch.utils.checkpoint.checkpoint(
743
- self.inner_attn, qkv, **kwargs
744
- )
745
  else:
746
  context = self._update_kvcache_attention(
747
  qkv[:, :, 0], qkv[:, :, 1:], inference_params
@@ -770,17 +699,13 @@ class MHA(nn.Module):
770
  q = qkv[..., : self.num_heads * self.head_dim]
771
  kv = qkv[..., self.num_heads * self.head_dim :]
772
  q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
773
- kv = rearrange(
774
- kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim
775
- )
776
  if self.dwconv:
777
  q = rearrange(
778
- self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2],
779
- "b d s -> b s d",
780
  ).contiguous()
781
  kv = rearrange(
782
- self.dwconv_kv(rearrange(kv, "b s d -> b d s"))[..., :-2],
783
- "b d s -> b s d",
784
  ).contiguous()
785
  if (
786
  inference_params is None
@@ -790,11 +715,7 @@ class MHA(nn.Module):
790
  ):
791
  if self.rotary_emb_dim > 0:
792
  q, kv = self.rotary_emb(
793
- q,
794
- kv,
795
- seqlen_offset=seqlen_offset,
796
- cu_seqlens=cu_seqlens,
797
- max_seqlen=rotary_max_seqlen,
798
  )
799
  if inference_params is None:
800
  if not self.checkpointing:
@@ -806,25 +727,7 @@ class MHA(nn.Module):
806
  else:
807
  context = self._update_kvcache_attention(q, kv, inference_params)
808
  else:
809
- context = self._apply_rotary_update_kvcache_attention(
810
- q, kv, inference_params
811
- )
812
-
813
- inp = rearrange(context, "... h d -> ... (h d)")
814
- if adapter_mask is not None:
815
- unique_tasks = torch.unique(adapter_mask)
816
- out_dtype = next(self.out_proj.parameters()).dtype
817
- out = torch.empty(
818
- *inp.shape[:-1],
819
- self.out_proj.out_features,
820
- dtype=out_dtype,
821
- device=inp.device,
822
- )
823
- for task_id in unique_tasks:
824
- task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
825
- task_tensor = inp[task_indices]
826
- task_out = self.out_proj(task_tensor, task_id=task_id)
827
- out[task_indices] = task_out
828
- else:
829
- out = self.out_proj(inp)
830
  return out if not self.return_residual else (out, x)
 
 
1
  # This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py
2
  # Commit id: 6bbc532388e61185a92e2a563126739967b4c8c5
 
3
 
4
  # Copyright (c) 2023, Tri Dao.
5
 
 
11
  from einops import rearrange, repeat
12
 
13
  try:
14
+ from flash_attn import (
15
+ flash_attn_kvpacked_func,
16
+ flash_attn_qkvpacked_func,
17
+ flash_attn_varlen_kvpacked_func,
18
+ flash_attn_varlen_qkvpacked_func,
19
+ flash_attn_with_kvcache,
20
+ )
21
  except ImportError:
22
  flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None
23
  flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None
24
  flash_attn_with_kvcache = None
25
 
26
  try:
27
+ from flash_attn.ops.fused_dense import ColumnParallelLinear, FusedDense, RowParallelLinear
 
28
  except ImportError:
29
  FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None
30
 
31
+ try:
32
+ from flash_attn.layers.rotary import RotaryEmbedding
33
+ except ImportError:
34
+ RotaryEmbedding = None
35
 
36
 
37
  # From https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742
 
47
  closest_power_of_2 = 2 ** math.floor(math.log2(nheads))
48
  return (
49
  get_slopes_power_of_2(closest_power_of_2)
50
+ + get_alibi_slopes(2 * closest_power_of_2)[0::2][: nheads - closest_power_of_2]
 
 
51
  )
52
 
53
 
 
72
  deterministic=False,
73
  ):
74
  super().__init__()
75
+ assert flash_attn_varlen_qkvpacked_func is not None, "FlashAttention is not installed"
 
 
76
  assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed"
77
  self.causal = causal
78
  self.softmax_scale = softmax_scale
 
152
  deterministic=False,
153
  ):
154
  super().__init__()
155
+ assert flash_attn_varlen_kvpacked_func is not None, "FlashAttention is not installed"
 
 
156
  assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
157
  self.causal = causal
158
  self.softmax_scale = softmax_scale
 
318
  scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
319
  if key_padding_mask is not None:
320
  padding_mask = torch.full(
321
+ (batch_size, seqlen_k), -10000.0, dtype=scores.dtype, device=scores.device
 
 
 
322
  )
323
  padding_mask.masked_fill_(key_padding_mask, 0.0)
324
  # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
 
430
  else:
431
  alibi_slopes = None
432
  if window_size != (-1, -1):
433
+ assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"
 
 
434
 
435
  self.num_heads = num_heads
436
  self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
437
  assert (
438
  self.num_heads % self.num_heads_kv == 0
439
  ), "num_heads must be divisible by num_heads_kv"
440
+ assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
 
 
441
  self.head_dim = self.embed_dim // num_heads
442
  qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
443
  kv_dim = 2 * self.head_dim * self.num_heads_kv
444
 
445
  if self.rotary_emb_dim > 0:
446
+ assert not cross_attn, "MHA with rotary embedding does not support cross-attention yet"
 
 
447
  assert RotaryEmbedding is not None, "rotary_emb is not installed"
448
  self.rotary_emb = RotaryEmbedding(
449
  self.rotary_emb_dim,
 
451
  scale_base=rotary_emb_scale_base,
452
  interleaved=rotary_emb_interleaved,
453
  device=device,
 
454
  )
455
 
456
  if fused_bias_fc and FusedDense is None:
457
  raise ImportError("fused_dense is not installed")
 
458
  linear_cls = nn.Linear if not fused_bias_fc else FusedDense
459
  linear_resid_cls = (
460
+ LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True)
 
 
461
  )
462
  wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
463
  inner_attn_cls = (
464
+ partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size)
 
 
465
  if use_flash_attn
466
  else SelfAttention
467
  )
468
  inner_cross_attn_cls = (
469
+ partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size)
 
 
470
  if use_flash_attn
471
  else CrossAttention
472
  )
473
  if not self.cross_attn:
474
+ self.Wqkv = wqkv_cls(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs)
 
 
475
  else:
476
+ self.Wq = linear_cls(embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs)
 
 
477
  self.Wkv = wqkv_cls(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs)
478
  if self.dwconv:
479
  if self.num_heads_kv == self.num_heads:
 
484
  self.dwconv_q = nn.Conv1d(
485
  embed_dim, embed_dim, kernel_size=3, padding=2, groups=embed_dim
486
  )
487
+ self.dwconv_kv = nn.Conv1d(kv_dim, kv_dim, kernel_size=3, padding=2, groups=kv_dim)
 
 
488
  self.inner_attn = inner_attn_cls(
489
  causal=causal,
490
  softmax_scale=softmax_scale,
 
493
  self.inner_cross_attn = inner_cross_attn_cls(
494
  causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
495
  )
496
+ self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs)
 
 
497
 
498
  def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
499
  dtype = self.out_proj.weight.dtype if dtype is None else dtype
 
511
  def _update_kv_cache(self, kv, inference_params):
512
  """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
513
  assert not self.dwconv, "Generation does not support dwconv yet"
514
+ assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
 
 
515
  return _update_kv_cache(kv, inference_params, self.layer_idx)
516
 
517
  def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
 
527
  self.rotary_emb._update_cos_sin_cache(
528
  inference_params.max_seqlen, device=q.device, dtype=q.dtype
529
  )
530
+ rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
 
 
 
531
  else:
532
  rotary_cos, rotary_sin = None, None
533
  batch = q.shape[0]
 
549
  cache_seqlens=cache_seqlens,
550
  softmax_scale=self.inner_cross_attn.softmax_scale,
551
  causal=self.inner_cross_attn.causal,
552
+ rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
 
 
553
  alibi_slopes=alibi_slopes,
554
  )
555
  return context
 
594
  max_seqlen=None,
595
  mixer_subset=None,
596
  inference_params=None,
 
597
  **kwargs,
598
  ):
599
  """
 
619
  assert key_padding_mask is None
620
  assert self.use_flash_attn
621
  assert not self.dwconv
622
+ assert self.rotary_emb_dim == 0
623
  if key_padding_mask is not None:
624
  assert cu_seqlens is None
625
  assert max_seqlen is None
 
643
  else inference_params.seqlen_offset
644
  )
645
  )
646
+ rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
647
+ batch, seqlen = x.shape[:2]
 
 
 
648
  if not self.cross_attn and self.num_heads_kv == self.num_heads:
649
  assert x_kv is None and mixer_subset is None
650
+ if not self.return_residual:
651
+ qkv = self.Wqkv(x)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
652
  else:
653
+ qkv, x = self.Wqkv(x)
 
 
 
 
 
 
 
654
  if self.dwconv:
655
  qkv = rearrange(
656
+ self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
 
657
  ).contiguous()
658
+ qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
 
 
659
  if (
660
  inference_params is None
661
  or inference_params.seqlen_offset == 0
 
664
  ):
665
  if self.rotary_emb_dim > 0:
666
  qkv = self.rotary_emb(
667
+ qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
 
 
 
668
  )
669
  if inference_params is None:
670
  if not self.checkpointing:
671
  context = self.inner_attn(qkv, **kwargs)
672
  else:
673
+ context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
 
 
674
  else:
675
  context = self._update_kvcache_attention(
676
  qkv[:, :, 0], qkv[:, :, 1:], inference_params
 
699
  q = qkv[..., : self.num_heads * self.head_dim]
700
  kv = qkv[..., self.num_heads * self.head_dim :]
701
  q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
702
+ kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
 
 
703
  if self.dwconv:
704
  q = rearrange(
705
+ self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
 
706
  ).contiguous()
707
  kv = rearrange(
708
+ self.dwconv_kv(rearrange(kv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
 
709
  ).contiguous()
710
  if (
711
  inference_params is None
 
715
  ):
716
  if self.rotary_emb_dim > 0:
717
  q, kv = self.rotary_emb(
718
+ q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
 
 
 
 
719
  )
720
  if inference_params is None:
721
  if not self.checkpointing:
 
727
  else:
728
  context = self._update_kvcache_attention(q, kv, inference_params)
729
  else:
730
+ context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
731
+ out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
732
  return out if not self.return_residual else (out, x)
733
+
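For reference, the ALiBi slope helper that mha.py keeps from the upstream flash-attention/fairseq code is self-contained; when the head count is not a power of two it borrows every other slope from the next power of two, as the recursion above shows. A standalone restatement:

```python
import math


def get_alibi_slopes(nheads):
    # Geometric per-head slopes; see the fairseq link referenced in mha.py.
    def get_slopes_power_of_2(n):
        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
        return [start * start**i for i in range(n)]

    if math.log2(nheads).is_integer():
        return get_slopes_power_of_2(nheads)
    closest_power_of_2 = 2 ** math.floor(math.log2(nheads))
    return (
        get_slopes_power_of_2(closest_power_of_2)
        + get_alibi_slopes(2 * closest_power_of_2)[0::2][: nheads - closest_power_of_2]
    )


print(get_alibi_slopes(12)[:4])  # 12 heads: 8 power-of-two slopes plus 4 interleaved ones
```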
mlp.py CHANGED
@@ -8,14 +8,14 @@ import torch.nn as nn
8
  import torch.nn.functional as F
9
  from torch.distributed import ProcessGroup
10
 
 
11
  try:
12
  from flash_attn.ops.activations import swiglu
13
  except ImportError:
14
  swiglu = None
15
 
16
  try:
17
- from flash_attn.ops.fused_dense import (ColumnParallelLinear,
18
- RowParallelLinear)
19
  except ImportError:
20
  ColumnParallelLinear, RowParallelLinear = None, None
21
 
@@ -41,48 +41,17 @@ class Mlp(nn.Module):
41
  factory_kwargs = {"device": device, "dtype": dtype}
42
  super().__init__()
43
  out_features = out_features if out_features is not None else in_features
44
- hidden_features = (
45
- hidden_features if hidden_features is not None else in_features * 4
46
- )
47
  self.return_residual = return_residual
48
  self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
49
  self.activation = activation
50
- self.fc2 = nn.Linear(
51
- hidden_features, out_features, bias=bias2, **factory_kwargs
52
- )
53
-
54
- def forward(self, x, adapter_mask=None):
55
- if adapter_mask is not None:
56
- unique_tasks = torch.unique(adapter_mask)
57
- fc1_dtype = next(self.fc1.parameters()).dtype
58
- y = torch.empty(
59
- *x.shape[:-1], self.fc1.out_features, dtype=fc1_dtype, device=x.device
60
- )
61
- for task_id in unique_tasks:
62
- task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
63
- task_tensor = x[task_indices]
64
- task_y = self.fc1(task_tensor, task_id=task_id)
65
- y[task_indices] = task_y
66
- else:
67
- y = self.fc1(x)
68
 
 
 
69
  y = self.activation(y)
70
-
71
- if adapter_mask is not None:
72
- unique_tasks = torch.unique(adapter_mask)
73
- fc2_dtype = next(self.fc2.parameters()).dtype
74
- out = torch.empty(
75
- *y.shape[:-1], self.fc2.out_features, dtype=fc2_dtype, device=y.device
76
- )
77
- for task_id in unique_tasks:
78
- task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
79
- task_tensor = y[task_indices]
80
- task_out = self.fc2(task_tensor, task_id=task_id)
81
- out[task_indices] = task_out
82
- else:
83
- out = self.fc2(y)
84
-
85
- return out if not self.return_residual else (out, x)
86
 
87
 
88
  class ParallelMLP(nn.Module):
@@ -104,9 +73,7 @@ class ParallelMLP(nn.Module):
104
  assert ColumnParallelLinear is not None, "Need to install fused_dense"
105
  assert RowParallelLinear is not None, "Need to install fused_dense"
106
  out_features = out_features if out_features is not None else in_features
107
- hidden_features = (
108
- hidden_features if hidden_features is not None else in_features * 4
109
- )
110
  self.fc1 = ColumnParallelLinear(
111
  in_features,
112
  hidden_features,
@@ -152,25 +119,17 @@ class GatedMlp(nn.Module):
152
  hidden_features = (
153
  hidden_features if hidden_features is not None else int(8 * in_features / 3)
154
  )
155
- hidden_features = (
156
- (hidden_features + multiple_of - 1) // multiple_of * multiple_of
157
- )
158
  self.return_residual = return_residual
159
- self.fc1 = nn.Linear(
160
- in_features, 2 * hidden_features, bias=bias1, **factory_kwargs
161
- )
162
  self.activation = activation
163
- self.fc2 = nn.Linear(
164
- hidden_features, out_features, bias=bias2, **factory_kwargs
165
- )
166
 
167
  def forward(self, x):
168
  y = self.fc1(x)
169
  if self.activation == F.sigmoid: # Special case for GLU
170
  y = F.glu(y, dim=-1)
171
- elif (
172
- self.activation == F.silu and swiglu is not None
173
- ): # Special case for SwiGLU
174
  y, gate = y.chunk(2, dim=-1)
175
  y = swiglu(gate, y)
176
  else:
@@ -203,9 +162,7 @@ class ParallelGatedMlp(nn.Module):
203
  hidden_features = (
204
  hidden_features if hidden_features is not None else int(8 * in_features / 3)
205
  )
206
- hidden_features = (
207
- (hidden_features + multiple_of - 1) // multiple_of * multiple_of
208
- )
209
  if ColumnParallelLinear is None or RowParallelLinear is None:
210
  raise ImportError("fused_dense is not installed")
211
  self.fc1 = ColumnParallelLinear(
@@ -234,4 +191,4 @@ class ParallelGatedMlp(nn.Module):
234
  y, gate = y.chunk(2, dim=-1)
235
  y = y * self.activation(gate)
236
  y = self.fc2(y)
237
- return y
 
8
  import torch.nn.functional as F
9
  from torch.distributed import ProcessGroup
10
 
11
+
12
  try:
13
  from flash_attn.ops.activations import swiglu
14
  except ImportError:
15
  swiglu = None
16
 
17
  try:
18
+ from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
 
19
  except ImportError:
20
  ColumnParallelLinear, RowParallelLinear = None, None
21
 
 
41
  factory_kwargs = {"device": device, "dtype": dtype}
42
  super().__init__()
43
  out_features = out_features if out_features is not None else in_features
44
+ hidden_features = hidden_features if hidden_features is not None else in_features * 4
 
 
45
  self.return_residual = return_residual
46
  self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
47
  self.activation = activation
48
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
+ def forward(self, x):
51
+ y = self.fc1(x)
52
  y = self.activation(y)
53
+ y = self.fc2(y)
54
+ return y if not self.return_residual else (y, x)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
 
57
  class ParallelMLP(nn.Module):
 
73
  assert ColumnParallelLinear is not None, "Need to install fused_dense"
74
  assert RowParallelLinear is not None, "Need to install fused_dense"
75
  out_features = out_features if out_features is not None else in_features
76
+ hidden_features = hidden_features if hidden_features is not None else in_features * 4
 
 
77
  self.fc1 = ColumnParallelLinear(
78
  in_features,
79
  hidden_features,
 
119
  hidden_features = (
120
  hidden_features if hidden_features is not None else int(8 * in_features / 3)
121
  )
122
+ hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
 
 
123
  self.return_residual = return_residual
124
+ self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias1, **factory_kwargs)
 
 
125
  self.activation = activation
126
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
 
 
127
 
128
  def forward(self, x):
129
  y = self.fc1(x)
130
  if self.activation == F.sigmoid: # Special case for GLU
131
  y = F.glu(y, dim=-1)
132
+ elif self.activation == F.silu and swiglu is not None: # Special case for SwiGLU
 
 
133
  y, gate = y.chunk(2, dim=-1)
134
  y = swiglu(gate, y)
135
  else:
 
162
  hidden_features = (
163
  hidden_features if hidden_features is not None else int(8 * in_features / 3)
164
  )
165
+ hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
 
 
166
  if ColumnParallelLinear is None or RowParallelLinear is None:
167
  raise ImportError("fused_dense is not installed")
168
  self.fc1 = ColumnParallelLinear(
 
191
  y, gate = y.chunk(2, dim=-1)
192
  y = y * self.activation(gate)
193
  y = self.fc2(y)
194
+ return y
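GatedMlp projects to `2 * hidden_features`, splits the result, and applies a gated activation; when flash-attn's fused `swiglu` op is unavailable, the same math can be written in plain PyTorch (`swiglu(gate, y)` computes `silu(gate) * y`). A minimal equivalence sketch with illustrative sizes:

```python
import torch
import torch.nn.functional as F

in_features, hidden_features = 8, 16
fc1 = torch.nn.Linear(in_features, 2 * hidden_features)
fc2 = torch.nn.Linear(hidden_features, in_features)

x = torch.randn(2, in_features)
y, gate = fc1(x).chunk(2, dim=-1)  # same split as GatedMlp.forward
out = fc2(F.silu(gate) * y)        # plain-PyTorch fallback for swiglu(gate, y)
print(out.shape)                   # torch.Size([2, 8])
```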
modeling_lora.py CHANGED
@@ -1,5 +1,6 @@
1
  import math
2
  import os
 
3
  from functools import partial
4
  from typing import Iterator, List, Optional, Tuple, Union
5
 
@@ -8,15 +9,12 @@ import torch
8
  import torch.nn.utils.parametrize as parametrize
9
  from torch import nn
10
  from torch.nn import Parameter
11
- from torch.nn import functional as F
12
  from transformers import PretrainedConfig
13
 
14
- from .configuration_xlm_roberta import XLMRobertaFlashConfig
15
- from .modeling_xlm_roberta import (
16
- XLMRobertaFlashConfig,
17
- XLMRobertaModel,
18
- XLMRobertaPreTrainedModel,
19
- )
20
 
21
 
22
  def initialized_weights(
@@ -93,19 +91,22 @@ class LoRAParametrization(nn.Module):
93
  torch.ones(self.swap((1, fan_in)), dtype=self.lora_A.dtype),
94
  persistent=False,
95
  )
 
 
96
 
97
  def _dropout(self, A):
98
  # to mimic the original implementation: A @ dropout(x), we do (A * dropout(ones)) @ x
99
  return A * self.lora_dropout(self.lora_dropout_mask)
100
 
101
- def lora_forward(self, X, current_task):
 
102
  return (
103
  X
104
  + torch.matmul(
105
  *self.swap(
106
  (
107
- self.lora_B[current_task],
108
- self.dropout_fn(self.lora_A[current_task]),
109
  )
110
  )
111
  ).view(X.shape)
@@ -113,7 +114,19 @@ class LoRAParametrization(nn.Module):
113
  )
114
 
115
  def forward(self, X):
116
- return X
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
  @classmethod
119
  def from_linear(
@@ -166,15 +179,6 @@ class LoRAParametrization(nn.Module):
166
  dropout_p: float,
167
  alpha: float,
168
  ):
169
- """
170
- Registering LoRA adapters to all embedding and linear layers.
171
- Additionally, we implement a custom forward function for LoRA parametrization.
172
- This function modifies the layer's forward pass to optionally use task-specific
173
- parameters. When a `task_id` is provided, it employs a LoRA parametrization
174
- to modify the original weights according to the specific task. This allows
175
- the layer to adapt dynamically to different tasks at runtime. If no `task_id`
176
- is specified, the layer uses its original weights.
177
- """
178
  if isinstance(layer, nn.Linear):
179
  parametrize.register_parametrization(
180
  layer,
@@ -187,23 +191,6 @@ class LoRAParametrization(nn.Module):
187
  alpha=alpha,
188
  ),
189
  )
190
-
191
- def new_forward(self, input, task_id=None, residual=False):
192
- if task_id is not None:
193
- weights = self.parametrizations.weight[0].lora_forward(
194
- self.weight, current_task=task_id
195
- )
196
- else:
197
- weights = self.weight
198
-
199
- out = F.linear(input, weights, self.bias)
200
-
201
- if residual:
202
- return out, input
203
- return out
204
-
205
- layer.forward = new_forward.__get__(layer, layer.__class__)
206
-
207
  elif isinstance(layer, nn.Embedding):
208
  parametrize.register_parametrization(
209
  layer,
@@ -217,43 +204,22 @@ class LoRAParametrization(nn.Module):
217
  ),
218
  )
219
 
220
- def new_forward(self, input, task_id=None):
221
- if task_id is not None:
222
- weights = self.parametrizations.weight[0].lora_forward(
223
- self.weight, current_task=task_id
224
- )
225
- else:
226
- weights = self.weight
227
-
228
- out = F.embedding(
229
- input,
230
- weights,
231
- self.padding_idx,
232
- self.max_norm,
233
- self.norm_type,
234
- self.scale_grad_by_freq,
235
- self.sparse,
236
- )
237
-
238
- return out
239
-
240
- layer.forward = new_forward.__get__(layer, layer.__class__)
241
 
242
 
243
  class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
244
- """
245
- A wrapper class around the Jina XLM-RoBERTa model that integrates LoRA (Low-Rank Adaptation) adapters.
246
- """
247
-
248
  def __init__(
249
  self,
250
  config: XLMRobertaFlashConfig,
251
- roberta: Optional[XLMRobertaModel] = None,
252
- add_pooling_layer: bool = True,
253
  ):
254
  super().__init__(config)
 
255
  if roberta is None:
256
- self.roberta = XLMRobertaModel(config, add_pooling_layer=add_pooling_layer)
257
  else:
258
  self.roberta = roberta
259
 
@@ -263,19 +229,7 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
263
  or len(self._lora_adaptations) < 1
264
  ):
265
  raise ValueError(
266
- f"`lora_adaptations` must be a list and contain at least one element"
267
- )
268
- self._task_instructions = config.task_instructions
269
- if (
270
- not isinstance(self._task_instructions, dict)
271
- or len(self._task_instructions) != len(self._lora_adaptations)
272
- or not all(
273
- [v in self._lora_adaptations for v in self._task_instructions.keys()]
274
- )
275
- ):
276
- raise ValueError(
277
- f"`task_instructions` must be a dict and contain the same number of elements "
278
- f"as `lora_adaptations` with all keys in `task_instructions` present in `lora_adaptations`."
279
  )
280
  self._adaptation_map = {
281
  name: idx for idx, name in enumerate(self._lora_adaptations)
@@ -290,14 +244,9 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
290
  alpha=self._alpha,
291
  )
292
  self.main_params_trainable = config.lora_main_params_trainable
293
-
294
- @property
295
- def rotary_emb_base(self):
296
- return self.roberta.rotary_emb_base
297
-
298
- @rotary_emb_base.setter
299
- def rotary_emb_base(self, base):
300
- self.roberta.rotary_emb_base = base
301
 
302
  @property
303
  def main_params_trainable(self):
@@ -331,30 +280,16 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
331
  use_safetensors: bool = None,
332
  **kwargs,
333
  ):
334
- for key in list(kwargs.keys()):
335
- if key in config.to_dict():
336
- config.update({key: kwargs.pop(key)})
337
- if config.load_trained_adapters: # checkpoint already contains LoRA adapters
 
338
  return super().from_pretrained(
339
- pretrained_model_name_or_path,
340
- *model_args,
341
- config=config,
342
- cache_dir=cache_dir,
343
- ignore_mismatched_sizes=ignore_mismatched_sizes,
344
- force_download=force_download,
345
- local_files_only=local_files_only,
346
- token=token,
347
- revision=revision,
348
- use_safetensors=use_safetensors,
349
- **kwargs,
350
- )
351
- else: # initializing new adapters
352
- roberta = XLMRobertaModel.from_pretrained(
353
- pretrained_model_name_or_path,
354
- *model_args,
355
- use_flash_attn=config.use_flash_attn,
356
- **kwargs,
357
  )
 
 
358
  return cls(config, roberta=roberta)
359
 
360
  def _register_lora(self, num_adaptations, rank, dropout_p, alpha):
@@ -368,7 +303,39 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
368
  )
369
  )
370
 
371
- def forward(self, *args, **kwargs):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
372
  return self.roberta(*args, **kwargs)
373
 
374
  def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
@@ -387,40 +354,28 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
387
  @torch.inference_mode()
388
  def encode(
389
  self,
390
- sentences: Union[str, List[str]],
391
  *args,
392
- task: Optional[str] = None,
393
  **kwargs,
394
  ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
395
  """
396
- Computes sentence embeddings.
397
- sentences(`str` or `List[str]`):
398
- Sentence or sentences to be encoded
399
- task(`str`, *optional*, defaults to `None`):
400
- Specifies the task for which the encoding is intended. If `task` is not provided,
401
- all LoRA adapters are disabled, and the model reverts to its original,
402
- general-purpose weights.
 
 
403
  """
404
- if task and task not in self._lora_adaptations:
405
- raise ValueError(
406
- f"Unsupported task '{task}'. "
407
- f"Supported tasks are: {', '.join(self.config.lora_adaptations)}."
408
- f"Alternatively, don't pass the `task` argument to disable LoRA."
409
- )
410
- adapter_mask = None
411
- if task:
412
- task_id = self._adaptation_map[task]
413
- num_examples = 1 if isinstance(sentences, str) else len(sentences)
414
- adapter_mask = torch.full(
415
- (num_examples,), task_id, dtype=torch.int32, device=self.device
416
- )
417
- if isinstance(sentences, str):
418
- sentences = self._task_instructions[task] + sentences
419
- else:
420
- sentences = [
421
- self._task_instructions[task] + sentence for sentence in sentences
422
- ]
423
- return self.roberta.encode(
424
- sentences, *args, adapter_mask=adapter_mask, **kwargs
425
- )
426
 
 
 
1
  import math
2
  import os
3
+ import warnings
4
  from functools import partial
5
  from typing import Iterator, List, Optional, Tuple, Union
6
 
 
9
  import torch.nn.utils.parametrize as parametrize
10
  from torch import nn
11
  from torch.nn import Parameter
 
12
  from transformers import PretrainedConfig
13
 
14
+ from .modeling_xlm_roberta import XLMRobertaFlashConfig, XLMRobertaModel, XLMRobertaPreTrainedModel
15
+
16
+
17
+ LORA_NO_UPDATE = '__lora_no_update__'
 
 
18
 
19
 
20
  def initialized_weights(
 
91
  torch.ones(self.swap((1, fan_in)), dtype=self.lora_A.dtype),
92
  persistent=False,
93
  )
94
+ self.forward_fn = lambda x: x
95
+ self.current_task = None
96
 
97
  def _dropout(self, A):
98
  # to mimic the original implementation: A @ dropout(x), we do (A * dropout(ones)) @ x
99
  return A * self.lora_dropout(self.lora_dropout_mask)
100
 
101
+ def lora_forward(self, X):
102
+ assert self.current_task is not None
103
  return (
104
  X
105
  + torch.matmul(
106
  *self.swap(
107
  (
108
+ self.lora_B[self.current_task],
109
+ self.dropout_fn(self.lora_A[self.current_task]),
110
  )
111
  )
112
  ).view(X.shape)
 
114
  )
115
 
116
  def forward(self, X):
117
+ return self.forward_fn(X)
118
+
119
+ @property
120
+ def current_task(self):
121
+ return self._current_task
122
+
123
+ @current_task.setter
124
+ def current_task(self, task: Union[None, int]):
125
+ self._current_task = task
126
+ if task is None:
127
+ self.forward_fn = lambda x: x
128
+ else:
129
+ self.forward_fn = self.lora_forward
130
 
131
  @classmethod
132
  def from_linear(
 
179
  dropout_p: float,
180
  alpha: float,
181
  ):
 
 
 
 
 
 
 
 
 
182
  if isinstance(layer, nn.Linear):
183
  parametrize.register_parametrization(
184
  layer,
 
191
  alpha=alpha,
192
  ),
193
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
  elif isinstance(layer, nn.Embedding):
195
  parametrize.register_parametrization(
196
  layer,
 
204
  ),
205
  )
206
 
207
+ @staticmethod
208
+ def select_task_for_layer(layer: nn.Module, task_idx: Optional[int] = None):
209
+ if isinstance(layer, LoRAParametrization):
210
+ layer.current_task = task_idx
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
 
212
 
213
  class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
 
 
 
 
214
  def __init__(
215
  self,
216
  config: XLMRobertaFlashConfig,
217
+ roberta: Optional[XLMRobertaModel] = None
 
218
  ):
219
  super().__init__(config)
220
+
221
  if roberta is None:
222
+ self.roberta = XLMRobertaModel(config)
223
  else:
224
  self.roberta = roberta
225
 
 
229
  or len(self._lora_adaptations) < 1
230
  ):
231
  raise ValueError(
232
+ f'`lora_adaptations` must be a list and contain at least one element'
 
 
 
 
 
 
 
 
 
 
 
 
233
  )
234
  self._adaptation_map = {
235
  name: idx for idx, name in enumerate(self._lora_adaptations)
 
244
  alpha=self._alpha,
245
  )
246
  self.main_params_trainable = config.lora_main_params_trainable
247
+ self._task_idx = None
248
+ # By default, disable LoRA until it's specified which adapter/task to use
249
+ self.current_task = None
 
 
 
 
 
250
 
251
  @property
252
  def main_params_trainable(self):
 
280
  use_safetensors: bool = None,
281
  **kwargs,
282
  ):
283
+ config = XLMRobertaFlashConfig.from_pretrained(
284
+ pretrained_model_name_or_path, *model_args, **kwargs
285
+ )
286
+
287
+ if config.load_trained_adapters:
288
  return super().from_pretrained(
289
+ pretrained_model_name_or_path, *model_args, **kwargs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
290
  )
291
+ else:
292
+ roberta = XLMRobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
293
  return cls(config, roberta=roberta)
294
 
295
  def _register_lora(self, num_adaptations, rank, dropout_p, alpha):
 
303
  )
304
  )
305
 
306
+ @property
307
+ def current_task(self):
308
+ """Which LoRA is currently selected
309
+ :return: Integer or None (when LoRA is disabled)
310
+ """
311
+ return self._task_idx
312
+
313
+ @current_task.setter
314
+ def current_task(self, task_name: Union[None, str]):
315
+ """Set the LoRA that is to be used.
316
+ The LoRA is specified by `task_name`, which must be one of the configured
317
+ adaptation names (mapped internally to an adapter index). If it is None, no LoRA is used.
318
+ :param task_name: Name of the LoRA adaptation to use, or None to disable LoRA
319
+ :return:
320
+ """
321
+ if task_name and task_name not in self._lora_adaptations:
322
+ raise ValueError(
323
+ f"Unsupported task '{task_name}'. "
324
+ f"Supported tasks are: {', '.join(self.config.lora_adaptations)}."
325
+ f"Alternatively, set `task` to `None` if you want to disable LoRA."
326
+ )
327
+ task_idx = self._adaptation_map[task_name] if task_name else None
328
+ if self._task_idx != task_idx:
329
+ # In this case, we need to update the LoRAs everywhere
330
+ self._task_idx = task_idx
331
+ self.apply(
332
+ partial(LoRAParametrization.select_task_for_layer, task_idx=task_idx)
333
+ )
334
+
335
+ def forward(self, *args, task: Union[str, None] = LORA_NO_UPDATE, **kwargs):
336
+ if task != LORA_NO_UPDATE:
337
+ self.current_task = task
338
+
339
  return self.roberta(*args, **kwargs)
340
 
341
  def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
 
354
  @torch.inference_mode()
355
  def encode(
356
  self,
 
357
  *args,
358
+ task: Union[str, None] = LORA_NO_UPDATE,
359
  **kwargs,
360
  ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
361
  """
362
+ Computes sentence embeddings
363
+
364
+ task(`str`, *optional*, defaults to `LORA_NO_UPDATE`):
365
+ Specifies the task for which the encoding is intended. This parameter controls the
366
+ use of specialized LoRA adapters that are tuned for specific tasks. If `task` is set
367
+ to `LORA_NO_UPDATE`, there will be no update to the current task, retaining the
368
+ existing adapter configuration. If `task` is explicitly set to `None`, all LoRA
369
+ adapters are disabled, and the model reverts to its original, general-purpose weights.
370
+ If `task` is set to a specific LoRA adaptation, that adaptation is activated.
371
  """
372
+ if task != LORA_NO_UPDATE:
373
+ if not task:
374
+ warnings.warn(
375
+ f"Task-specific embeddings are disabled. To enable, specify the `task` "
376
+ f"argument with one of the supported tasks: {', '.join(self.config.lora_adaptations)}",
377
+ category=UserWarning,
378
+ )
379
+ self.current_task = task
 
380
 
381
+ return self.roberta.encode(*args, **kwargs)
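Taken together, a minimal usage sketch of the adapter selection implemented above. The repository id and the adapter name `"retrieval"` are placeholders; any checkpoint that ships this `XLMRobertaLoRA` implementation and lists the adapter in `config.lora_adaptations` would work:

```python
from transformers import AutoModel

# Placeholder repo id; trust_remote_code is required so the classes above are used.
model = AutoModel.from_pretrained("org/xlmr-flash-lora", trust_remote_code=True)

# `task` is forwarded to the current_task setter, which applies
# LoRAParametrization.select_task_for_layer to every parametrized module.
emb = model.encode(["A sample sentence"], task="retrieval")

# task=None disables all LoRA adapters and falls back to the base weights.
emb_base = model.encode(["A sample sentence"], task=None)
```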
modeling_xlm_roberta.py CHANGED
@@ -13,30 +13,39 @@ import re
13
  from collections import OrderedDict
14
  from collections.abc import Sequence
15
  from functools import partial
16
- from typing import List, Optional, Tuple, Union
17
-
18
  import numpy as np
 
19
  import torch
20
  import torch.nn as nn
21
  import torch.nn.functional as F
22
  import torch.utils.checkpoint
23
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
24
- from transformers import AutoTokenizer, PretrainedConfig
25
- from transformers.modeling_outputs import MaskedLMOutput, SequenceClassifierOutput
26
  from transformers.modeling_utils import PreTrainedModel
 
 
 
27
  from transformers.models.bert.modeling_bert import (
28
  BaseModelOutputWithPoolingAndCrossAttentions,
29
  BertForPreTrainingOutput,
30
  )
31
- from transformers.models.xlm_roberta.modeling_xlm_roberta import XLMRobertaLMHead
32
 
33
- from .rotary import RotaryEmbedding
34
- from .block import Block
 
 
 
 
 
 
35
  from .configuration_xlm_roberta import XLMRobertaFlashConfig
 
36
  from .embedding import XLMRobertaEmbeddings
37
  from .mha import MHA
38
  from .mlp import FusedMLP, Mlp
39
- from .xlm_padding import index_first_axis_residual, pad_input, unpad_input
 
40
 
41
  try:
42
  from flash_attn.ops.fused_dense import FusedDense
@@ -64,11 +73,13 @@ logger = logging.getLogger(__name__)
64
 
65
 
66
  def get_use_flash_attn(config: XLMRobertaFlashConfig):
67
- if not getattr(config, "use_flash_attn", False) or not torch.cuda.is_available():
 
 
68
  return False
69
  if importlib.util.find_spec("flash_attn") is None:
70
  logger.warning(
71
- "flash_attn is not installed. Using PyTorch native attention implementation."
72
  )
73
  return False
74
  return True
@@ -80,9 +91,9 @@ def create_mixer_cls(config, cross_attn=False, return_residual=False):
80
  rotary_kwargs = {}
81
  if config.position_embedding_type == "rotary":
82
  rotary_kwargs["rotary_emb_dim"] = getattr(
83
- config, "rotary_emb_dim", config.hidden_size / config.num_attention_heads
84
  )
85
- rotary_kwargs["rotary_emb_base"] = config.rotary_emb_base
86
  rotary_kwargs["rotary_emb_scale_base"] = getattr(
87
  config, "rotary_emb_scale_base", None
88
  )
@@ -98,7 +109,6 @@ def create_mixer_cls(config, cross_attn=False, return_residual=False):
98
  fused_bias_fc=fused_bias_fc,
99
  use_flash_attn=use_flash_attn,
100
  return_residual=return_residual,
101
- use_alibi=config.position_embedding_type == "alibi",
102
  **rotary_kwargs,
103
  )
104
  return mixer_cls
@@ -180,7 +190,6 @@ class XLMRobertaEncoder(nn.Module):
180
  def __init__(self, config: XLMRobertaFlashConfig):
181
  super().__init__()
182
  self.use_flash_attn = get_use_flash_attn(config)
183
- self.use_reentrant = config.use_reentrant
184
  self.layers = nn.ModuleList(
185
  [create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
186
  )
@@ -194,72 +203,46 @@ class XLMRobertaEncoder(nn.Module):
194
  def gradient_checkpointing(self, value):
195
  self._grad_checkpointing = value
196
 
197
- def forward(
198
- self,
199
- hidden_states,
200
- key_padding_mask=None,
201
- subset_mask=None,
202
- adapter_mask=None,
203
- output_hidden_states: Optional[bool] = None,
204
- ):
205
  """If subset_mask is not None, we only want output for the subset of the sequence.
206
  This means that we only compute the last layer output for these tokens.
207
  subset_mask: (batch, seqlen), dtype=torch.bool
208
  """
209
-
210
- all_hidden_states = () if output_hidden_states else None
211
-
212
- if output_hidden_states and subset_mask:
213
- raise ValueError('output_hidden_states is not supported for subset_masks')
214
-
215
  if key_padding_mask is None or not self.use_flash_attn:
216
- mixer_kwargs = {"adapter_mask": adapter_mask}
217
- if key_padding_mask is not None:
218
- mixer_kwargs["key_padding_mask"] = key_padding_mask.bool()
 
 
219
  for layer in self.layers:
220
- if output_hidden_states:
221
- all_hidden_states = all_hidden_states + (hidden_states,)
222
  if self._grad_checkpointing:
223
  hidden_states = torch.utils.checkpoint.checkpoint(
224
  layer,
225
  hidden_states,
226
- use_reentrant=self.use_reentrant,
227
  mixer_kwargs=mixer_kwargs,
228
  )
229
  else:
230
  hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
231
- if output_hidden_states:
232
- all_hidden_states = all_hidden_states + (hidden_states,)
233
  if subset_mask is not None:
234
  hidden_states = hidden_states[subset_mask]
235
  else:
236
  batch, seqlen = hidden_states.shape[:2]
237
- if output_hidden_states:
238
- all_hidden_states = all_hidden_states + (hidden_states,)
239
- hidden_states, indices, cu_seqlens, max_seqlen_in_batch, cu_adapter_mask = (
240
- unpad_input(hidden_states, key_padding_mask, adapter_mask)
241
  )
242
- mixer_kwargs = {
243
- "cu_seqlens": cu_seqlens,
244
- "max_seqlen": max_seqlen_in_batch,
245
- "adapter_mask": cu_adapter_mask,
246
- }
247
-
248
  if subset_mask is None:
249
  for layer in self.layers:
250
  if self._grad_checkpointing:
251
  hidden_states = torch.utils.checkpoint.checkpoint(
252
  layer,
253
  hidden_states,
254
- use_reentrant=self.use_reentrant,
255
  mixer_kwargs=mixer_kwargs,
256
  )
257
  else:
258
  hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
259
- if output_hidden_states:
260
- all_hidden_states = all_hidden_states + (
261
- pad_input(hidden_states, indices, batch, seqlen),
262
- )
263
  hidden_states = pad_input(hidden_states, indices, batch, seqlen)
264
  else:
265
  for layer in self.layers[:-1]:
@@ -267,7 +250,7 @@ class XLMRobertaEncoder(nn.Module):
267
  hidden_states = torch.utils.checkpoint.checkpoint(
268
  layer,
269
  hidden_states,
270
- use_reentrant=self.use_reentrant,
271
  mixer_kwargs=mixer_kwargs,
272
  )
273
  else:
@@ -305,14 +288,14 @@ class XLMRobertaEncoder(nn.Module):
305
  torch.utils.checkpoint.checkpoint(
306
  self.layers[-1],
307
  hidden_states_subset,
308
- use_reentrant=self.use_reentrant,
309
  mixer_kwargs=mixer_kwargs,
310
  )
311
  else:
312
  hidden_states = self.layers[-1](
313
  hidden_states_subset, mixer_kwargs=mixer_kwargs
314
  )
315
- return all_hidden_states if output_hidden_states else hidden_states
316
 
317
 
318
  class XLMRobertaPooler(nn.Module):
@@ -325,28 +308,11 @@ class XLMRobertaPooler(nn.Module):
325
  self.dense = linear_cls(config.hidden_size, config.hidden_size)
326
  self.activation = nn.Tanh()
327
 
328
- def forward(self, hidden_states, pool=True, adapter_mask=None):
329
  # We "pool" the model by simply taking the hidden state corresponding
330
  # to the first token.
331
  first_token_tensor = hidden_states[:, 0] if pool else hidden_states
332
- if adapter_mask is not None:
333
- unique_tasks = torch.unique(adapter_mask)
334
- pool_dtype = next(self.dense.parameters()).dtype
335
- pooled_output = torch.empty(
336
- first_token_tensor.shape[0],
337
- self.dense.out_features,
338
- dtype=pool_dtype,
339
- device=first_token_tensor.device,
340
- )
341
- for task_id in unique_tasks:
342
- task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
343
- task_first_token_tensor = first_token_tensor[task_indices]
344
- task_pooled_output = self.dense(
345
- task_first_token_tensor, task_id=task_id
346
- )
347
- pooled_output[task_indices] = task_pooled_output
348
- else:
349
- pooled_output = self.dense(first_token_tensor)
350
  pooled_output = self.activation(pooled_output)
351
  return pooled_output
352
 
@@ -425,7 +391,6 @@ class XLMRobertaPreTrainedModel(PreTrainedModel):
425
  config_class = XLMRobertaFlashConfig
426
  base_model_prefix = "roberta"
427
  supports_gradient_checkpointing = True
428
- _supports_param_buffer_assignment = False
429
 
430
  def _set_gradient_checkpointing(self, module, value=False):
431
  if isinstance(module, XLMRobertaEncoder):
@@ -437,11 +402,12 @@ class XLMRobertaPreTrainedModel(PreTrainedModel):
437
  *args,
438
  **kwargs,
439
  ):
440
- if not "torch_dtype" in kwargs:
441
- kwargs["torch_dtype"] = "auto"
442
  return super().from_pretrained(*args, **kwargs)
443
 
444
 
 
445
  class XLMRobertaModel(XLMRobertaPreTrainedModel):
446
  def __init__(self, config: XLMRobertaFlashConfig, add_pooling_layer=True):
447
  super().__init__(config)
@@ -459,14 +425,11 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
459
  "gelu_fast",
460
  "gelu_pytorch_tanh",
461
  ]
 
462
  self.embeddings = XLMRobertaEmbeddings(
463
  config.hidden_size,
464
  config.vocab_size,
465
- (
466
- config.max_position_embeddings
467
- if config.position_embedding_type == "absolute"
468
- else -1
469
- ),
470
  config.type_vocab_size,
471
  padding_idx=config.pad_token_id,
472
  )
@@ -476,25 +439,20 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
476
  self.pooler = XLMRobertaPooler(config) if add_pooling_layer else None
477
 
478
  self.apply(partial(_init_weights, initializer_range=config.initializer_range))
479
- self.tokenizer = AutoTokenizer.from_pretrained(
480
- self.name_or_path, trust_remote_code=True
481
- )
482
- self._rotary_emb_base = config.rotary_emb_base
483
 
484
  @torch.inference_mode()
485
  def encode(
486
- self: "XLMRobertaModel",
487
  sentences: Union[str, List[str]],
488
  batch_size: int = 32,
489
  show_progress_bar: Optional[bool] = None,
490
- output_value: str = "sentence_embedding",
491
  convert_to_numpy: bool = True,
492
  convert_to_tensor: bool = False,
493
  device: Optional[torch.device] = None,
494
- normalize_embeddings: bool = True,
495
  truncate_dim: Optional[int] = None,
496
- adapter_mask: Optional[torch.Tensor] = None,
497
- task: Optional[str] = None,
498
  **tokenizer_kwargs,
499
  ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
500
  """
@@ -520,7 +478,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
520
  Overwrites any setting from convert_to_numpy
521
  device(`torch.device`, *optional*, defaults to None):
522
  Which torch.device to use for the computation
523
- normalize_embeddings(`bool`, *optional*, defaults to True):
524
  If set to true, returned vectors will have length 1. In that case, the
525
  faster dot-product (util.dot_score) instead of cosine similarity can
526
  be used.
@@ -533,6 +491,12 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
533
  If convert_to_tensor, a stacked tensor is returned.
534
  If convert_to_numpy, a numpy matrix is returned.
535
  """
 
 
 
 
 
 
536
  is_training = self.training
537
  self.eval()
538
 
@@ -545,12 +509,12 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
545
  if convert_to_tensor:
546
  convert_to_numpy = False
547
 
548
- if output_value != "sentence_embedding":
549
  convert_to_tensor = False
550
  convert_to_numpy = False
551
 
552
  input_was_string = False
553
- if isinstance(sentences, str) or not hasattr(sentences, "__len__"):
554
  sentences = [sentences]
555
  input_was_string = True
556
 
@@ -561,11 +525,11 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
561
  inverse_permutation = np.argsort(permutation)
562
  sentences = [sentences[idx] for idx in permutation]
563
 
564
- tokenizer_kwargs["padding"] = tokenizer_kwargs.get("padding", True)
565
- tokenizer_kwargs["max_length"] = tokenizer_kwargs.get(
566
- "max_length", self.tokenizer.init_kwargs.get("model_max_length", 8192)
567
  )
568
- tokenizer_kwargs["truncation"] = tokenizer_kwargs.get("truncation", True)
569
 
570
  all_embeddings = []
571
 
@@ -583,33 +547,33 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
583
  for i in range_iter:
584
  encoded_input = self.tokenizer(
585
  sentences[i : i + batch_size],
586
- return_tensors="pt",
587
  **tokenizer_kwargs,
588
  ).to(self.device)
589
- lora_arguments = (
590
- {"adapter_mask": adapter_mask[i : i + batch_size]}
591
- if adapter_mask is not None
592
- else {}
593
- )
594
- token_embs = self.forward(**encoded_input, **lora_arguments)[0]
595
 
596
  # Accumulate in fp32 to avoid overflow
597
  token_embs = token_embs.float()
598
 
599
- if output_value == "token_embeddings":
600
  raise NotImplementedError
601
  elif output_value is None:
602
  raise NotImplementedError
603
  else:
604
- if self.config.emb_pooler == "cls":
605
  embeddings = self.cls_pooling(
606
- token_embs, encoded_input["attention_mask"]
607
  )
608
  else:
609
  embeddings = self.mean_pooling(
610
- token_embs, encoded_input["attention_mask"]
611
  )
612
 
 
 
 
 
 
613
  all_embeddings.extend(embeddings)
614
 
615
  all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
@@ -618,16 +582,10 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
618
  if truncate_dim:
619
  all_embeddings = self.truncate_embeddings(all_embeddings, truncate_dim)
620
 
621
- if normalize_embeddings:
622
- all_embeddings = [
623
- torch.nn.functional.normalize(embedding, p=2, dim=0)
624
- for embedding in all_embeddings
625
- ]
626
-
627
  if convert_to_tensor:
628
  all_embeddings = torch.stack(all_embeddings)
629
  elif convert_to_numpy:
630
- all_embeddings = np.asarray([emb.cpu().numpy() for emb in all_embeddings])
631
 
632
  if input_was_string:
633
  all_embeddings = all_embeddings[0]
@@ -635,19 +593,18 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
635
  self.train(is_training)
636
  return all_embeddings
637
 
 
638
  def truncate_embeddings(self, embeddings, truncate_dim):
639
  if not self.config.matryoshka_dimensions:
640
  logger.warning(
641
- "Matryoshka embeddings are not supported, so dimension truncation will not be performed."
642
  )
643
  return embeddings
644
  elif truncate_dim in self.config.matryoshka_dimensions:
645
  return [tensor[:truncate_dim] for tensor in embeddings]
646
  else:
647
- raise ValueError(
648
- f"The provided `truncate_dim` value of {truncate_dim} is not supported. "
649
- f"Supported dimensions are {self.config.matryoshka_dimensions}."
650
- )
651
 
652
  def mean_pooling(
653
  self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
@@ -659,21 +616,12 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
659
  input_mask_expanded.sum(1), min=1e-9
660
  )
661
 
662
- def cls_pooling(self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor):
663
- return token_embeddings[:, 0]
664
 
665
- @property
666
- def rotary_emb_base(self):
667
- return self._rotary_emb_base
668
-
669
- @rotary_emb_base.setter
670
- def rotary_emb_base(self, base):
671
- if not isinstance(base, (int, float)):
672
- raise TypeError("Base must be an integer or float")
673
- logger.info(f"Changing RoPE base value to {base}")
674
- for layer in self.encoder.layers:
675
- layer.mixer.rotary_emb.base = base
676
- self._rotary_emb_base = base
677
 
678
  def forward(
679
  self,
@@ -683,7 +631,6 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
683
  attention_mask=None,
684
  masked_tokens_mask=None,
685
  return_dict=None,
686
- output_hidden_states=None,
687
  **kwargs,
688
  ):
689
  """If masked_tokens_mask is not None (i.e. last_layer_subset == True in XLMForPreTraining),
@@ -691,12 +638,12 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
691
  layer output for these tokens.
692
  masked_tokens_mask: (batch, seqlen), dtype=torch.bool
693
  """
694
- adapter_mask = kwargs.pop("adapter_mask", None)
695
  if kwargs:
696
  for key, value in kwargs.items():
697
  if value is not None:
698
  logger.warning(
699
- "Flash attention implementation does not support kwargs: %s",
700
  key,
701
  )
702
 
@@ -705,10 +652,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
705
  )
706
 
707
  hidden_states = self.embeddings(
708
- input_ids,
709
- position_ids=position_ids,
710
- token_type_ids=token_type_ids,
711
- adapter_mask=adapter_mask,
712
  )
713
  # TD [2022-12:18]: Don't need to force residual in fp32
714
  # BERT puts embedding LayerNorm before embedding dropout.
@@ -732,24 +676,12 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
732
  subset_mask = None
733
 
734
  sequence_output = self.encoder(
735
- hidden_states,
736
- key_padding_mask=attention_mask,
737
- subset_mask=subset_mask,
738
- adapter_mask=adapter_mask,
739
- output_hidden_states=output_hidden_states,
740
  )
741
 
742
- if output_hidden_states:
743
- all_hidden_states = sequence_output
744
- sequence_output = sequence_output[-1]
745
- else:
746
- all_hidden_states = None
747
-
748
  if masked_tokens_mask is None:
749
  pooled_output = (
750
- self.pooler(sequence_output, adapter_mask=adapter_mask)
751
- if self.pooler is not None
752
- else None
753
  )
754
  else:
755
  # TD [2022-03-01]: the indexing here is very tricky.
@@ -763,9 +695,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
763
  pool_input = sequence_output[first_col_mask[subset_mask]]
764
  sequence_output = sequence_output[masked_tokens_mask[subset_mask]]
765
  pooled_output = (
766
- self.pooler(pool_input, pool=False, adapter_mask=adapter_mask)
767
- if self.pooler is not None
768
- else None
769
  )
770
 
771
  if not return_dict:
@@ -774,7 +704,6 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
774
  return BaseModelOutputWithPoolingAndCrossAttentions(
775
  last_hidden_state=sequence_output,
776
  pooler_output=pooled_output,
777
- hidden_states=all_hidden_states,
778
  )
779
 
780
 
@@ -871,6 +800,103 @@ class XLMRobertaForMaskedLM(XLMRobertaPreTrainedModel):
871
  )
872
 
873
 
874
  def remap_state_dict(state_dict, config: PretrainedConfig):
875
  """
876
  Map the state_dict of a Huggingface BERT model to be flash_attn compatible.
@@ -1022,47 +1048,47 @@ def inv_remap_state_dict(state_dict, config: PretrainedConfig):
1022
  if not last_layer_subset or d != (config.num_hidden_layers - 1):
1023
  Wqkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.weight")
1024
  Wqkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.bias")
1025
- state_dict[f"bert.encoder.layers.{d}.attention.self.query.weight"] = (
1026
- Wqkv_weights[: Wqkv_weights.shape[0] // 3, :]
1027
- )
1028
- state_dict[f"bert.encoder.layers.{d}.attention.self.key.weight"] = (
1029
- Wqkv_weights[
1030
- Wqkv_weights.shape[0] // 3 : 2 * Wqkv_weights.shape[0] // 3, :
1031
- ]
1032
- )
1033
- state_dict[f"bert.encoder.layers.{d}.attention.self.value.weight"] = (
1034
- Wqkv_weights[2 * Wqkv_weights.shape[0] // 3 :, :]
1035
- )
1036
- state_dict[f"bert.encoder.layers.{d}.attention.self.query.bias"] = (
1037
- Wqkv_biases[: Wqkv_biases.shape[0] // 3]
1038
- )
1039
- state_dict[f"bert.encoder.layers.{d}.attention.self.key.bias"] = (
1040
- Wqkv_biases[Wqkv_biases.shape[0] // 3 : 2 * Wqkv_biases.shape[0] // 3]
1041
- )
1042
- state_dict[f"bert.encoder.layers.{d}.attention.self.value.bias"] = (
1043
- Wqkv_biases[2 * Wqkv_biases.shape[0] // 3 :]
1044
- )
1045
  else:
1046
  Wq_weight = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.weight")
1047
  Wkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.weight")
1048
  Wq_bias = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.bias")
1049
  Wkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.bias")
1050
- state_dict[f"bert.encoder.layers.{d}.attention.self.query.weight"] = (
1051
- Wq_weight
1052
- )
1053
- state_dict[f"bert.encoder.layers.{d}.attention.self.key.weight"] = (
1054
- Wkv_weights[: Wkv_weights.shape[0] // 2, :]
1055
- )
1056
- state_dict[f"bert.encoder.layers.{d}.attention.self.value.weight"] = (
1057
- Wkv_weights[Wkv_weights.shape[0] // 2 :, :]
1058
- )
1059
  state_dict[f"bert.encoder.layers.{d}.attention.self.query.bias"] = Wq_bias
1060
  state_dict[f"bert.encoder.layers.{d}.attention.self.key.bias"] = Wkv_biases[
1061
  : Wkv_biases.shape[0] // 2
1062
  ]
1063
- state_dict[f"bert.encoder.layers.{d}.attention.self.value.bias"] = (
1064
- Wkv_biases[Wkv_biases.shape[0] // 2 :]
1065
- )
1066
 
1067
  def inv_key_mapping_ln(key):
1068
  key = re.sub(r"bert.emb_ln.", "bert.embeddings.LayerNorm.", key)
 
13
  from collections import OrderedDict
14
  from collections.abc import Sequence
15
  from functools import partial
 
 
16
  import numpy as np
17
+
18
  import torch
19
  import torch.nn as nn
20
  import torch.nn.functional as F
21
  import torch.utils.checkpoint
22
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
23
+ from einops import rearrange
24
+ from transformers import PretrainedConfig
25
  from transformers.modeling_utils import PreTrainedModel
26
+ from transformers.modeling_outputs import MaskedLMOutput, SequenceClassifierOutput
27
+ from transformers.models.xlm_roberta.modeling_xlm_roberta import XLMRobertaLMHead
28
+
29
  from transformers.models.bert.modeling_bert import (
30
  BaseModelOutputWithPoolingAndCrossAttentions,
31
  BertForPreTrainingOutput,
32
  )
 
33
 
34
+ from typing import List, Optional, Tuple, Union
35
+
36
+ from .xlm_padding import (
37
+ index_first_axis,
38
+ index_first_axis_residual,
39
+ pad_input,
40
+ unpad_input,
41
+ )
42
  from .configuration_xlm_roberta import XLMRobertaFlashConfig
43
+ from .block import Block
44
  from .embedding import XLMRobertaEmbeddings
45
  from .mha import MHA
46
  from .mlp import FusedMLP, Mlp
47
+ from .stochastic_depth import StochasticDepth
48
+
49
 
50
  try:
51
  from flash_attn.ops.fused_dense import FusedDense
 
73
 
74
 
75
  def get_use_flash_attn(config: XLMRobertaFlashConfig):
76
+ if not getattr(config, "use_flash_attn", False):
77
+ return False
78
+ if not torch.cuda.is_available():
79
  return False
80
  if importlib.util.find_spec("flash_attn") is None:
81
  logger.warning(
82
+ 'flash_attn is not installed. Using PyTorch native attention implementation.'
83
  )
84
  return False
85
  return True
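Since this guard is the switch the PR exposes, a hedged loading sketch. The repository id is a placeholder, and it assumes `use_flash_attn` is accepted as a config override through `from_pretrained`, which is what the check above consumes:

```python
from transformers import AutoModel

# Force the PyTorch-native attention path even when flash_attn is installed;
# on CPU, or when flash_attn is missing, the guard above falls back automatically.
model = AutoModel.from_pretrained(
    "org/xlmr-flash-checkpoint",
    trust_remote_code=True,
    use_flash_attn=False,
)
```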
 
91
  rotary_kwargs = {}
92
  if config.position_embedding_type == "rotary":
93
  rotary_kwargs["rotary_emb_dim"] = getattr(
94
+ config, "rotary_emb_dim", config.hidden_size
95
  )
96
+ rotary_kwargs["rotary_emb_base"] = getattr(config, "rotary_emb_base", 10000.0)
97
  rotary_kwargs["rotary_emb_scale_base"] = getattr(
98
  config, "rotary_emb_scale_base", None
99
  )
 
109
  fused_bias_fc=fused_bias_fc,
110
  use_flash_attn=use_flash_attn,
111
  return_residual=return_residual,
 
112
  **rotary_kwargs,
113
  )
114
  return mixer_cls
 
190
  def __init__(self, config: XLMRobertaFlashConfig):
191
  super().__init__()
192
  self.use_flash_attn = get_use_flash_attn(config)
 
193
  self.layers = nn.ModuleList(
194
  [create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
195
  )
 
203
  def gradient_checkpointing(self, value):
204
  self._grad_checkpointing = value
205
 
206
+ def forward(self, hidden_states, key_padding_mask=None, subset_mask=None):
 
 
 
 
 
 
 
207
  """If subset_mask is not None, we only want output for the subset of the sequence.
208
  This means that we only compute the last layer output for these tokens.
209
  subset_mask: (batch, seqlen), dtype=torch.bool
210
  """
 
 
 
 
 
 
211
  if key_padding_mask is None or not self.use_flash_attn:
212
+ mixer_kwargs = (
213
+ {"key_padding_mask": key_padding_mask.bool()}
214
+ if key_padding_mask is not None
215
+ else None
216
+ )
217
  for layer in self.layers:
 
 
218
  if self._grad_checkpointing:
219
  hidden_states = torch.utils.checkpoint.checkpoint(
220
  layer,
221
  hidden_states,
222
+ use_reentrant=False,
223
  mixer_kwargs=mixer_kwargs,
224
  )
225
  else:
226
  hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
 
 
227
  if subset_mask is not None:
228
  hidden_states = hidden_states[subset_mask]
229
  else:
230
  batch, seqlen = hidden_states.shape[:2]
231
+ hidden_states, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
232
+ hidden_states, key_padding_mask
 
 
233
  )
234
+ mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch}
 
 
 
 
 
235
  if subset_mask is None:
236
  for layer in self.layers:
237
  if self._grad_checkpointing:
238
  hidden_states = torch.utils.checkpoint.checkpoint(
239
  layer,
240
  hidden_states,
241
+ use_reentrant=False,
242
  mixer_kwargs=mixer_kwargs,
243
  )
244
  else:
245
  hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
 
 
 
 
246
  hidden_states = pad_input(hidden_states, indices, batch, seqlen)
247
  else:
248
  for layer in self.layers[:-1]:
 
250
  hidden_states = torch.utils.checkpoint.checkpoint(
251
  layer,
252
  hidden_states,
253
+ use_reentrant=False,
254
  mixer_kwargs=mixer_kwargs,
255
  )
256
  else:
 
288
  torch.utils.checkpoint.checkpoint(
289
  self.layers[-1],
290
  hidden_states_subset,
291
+ use_reentrant=False,
292
  mixer_kwargs=mixer_kwargs,
293
  )
294
  else:
295
  hidden_states = self.layers[-1](
296
  hidden_states_subset, mixer_kwargs=mixer_kwargs
297
  )
298
+ return hidden_states
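As a side note, the flash-attention branch above packs all non-padded tokens into one long sequence; `unpad_input` describes the batch boundaries with cumulative sequence lengths. A small illustrative sketch of that bookkeeping (not the actual `unpad_input` implementation):

```python
import torch
import torch.nn.functional as F

key_padding_mask = torch.tensor([[1, 1, 1, 0],
                                 [1, 1, 0, 0]])           # sequence lengths 3 and 2
seqlens = key_padding_mask.sum(dim=-1)
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0), (1, 0))  # offsets into the packed tokens
print(cu_seqlens)            # tensor([0, 3, 5])
print(int(seqlens.max()))    # 3 == max_seqlen_in_batch
```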
299
 
300
 
301
  class XLMRobertaPooler(nn.Module):
 
308
  self.dense = linear_cls(config.hidden_size, config.hidden_size)
309
  self.activation = nn.Tanh()
310
 
311
+ def forward(self, hidden_states, pool=True):
312
  # We "pool" the model by simply taking the hidden state corresponding
313
  # to the first token.
314
  first_token_tensor = hidden_states[:, 0] if pool else hidden_states
315
+ pooled_output = self.dense(first_token_tensor)
 
316
  pooled_output = self.activation(pooled_output)
317
  return pooled_output
318
 
 
391
  config_class = XLMRobertaFlashConfig
392
  base_model_prefix = "roberta"
393
  supports_gradient_checkpointing = True
 
394
 
395
  def _set_gradient_checkpointing(self, module, value=False):
396
  if isinstance(module, XLMRobertaEncoder):
 
402
  *args,
403
  **kwargs,
404
  ):
405
+ if not 'torch_dtype' in kwargs:
406
+ kwargs['torch_dtype'] = 'auto'
407
  return super().from_pretrained(*args, **kwargs)
408
 
409
 
410
+
411
  class XLMRobertaModel(XLMRobertaPreTrainedModel):
412
  def __init__(self, config: XLMRobertaFlashConfig, add_pooling_layer=True):
413
  super().__init__(config)
 
425
  "gelu_fast",
426
  "gelu_pytorch_tanh",
427
  ]
428
+
429
  self.embeddings = XLMRobertaEmbeddings(
430
  config.hidden_size,
431
  config.vocab_size,
432
+ config.max_position_embeddings,
 
 
 
 
433
  config.type_vocab_size,
434
  padding_idx=config.pad_token_id,
435
  )
 
439
  self.pooler = XLMRobertaPooler(config) if add_pooling_layer else None
440
 
441
  self.apply(partial(_init_weights, initializer_range=config.initializer_range))
442
+
 
 
 
443
 
444
  @torch.inference_mode()
445
  def encode(
446
+ self: 'XLMRobertaModel',
447
  sentences: Union[str, List[str]],
448
  batch_size: int = 32,
449
  show_progress_bar: Optional[bool] = None,
450
+ output_value: str = 'sentence_embedding',
451
  convert_to_numpy: bool = True,
452
  convert_to_tensor: bool = False,
453
  device: Optional[torch.device] = None,
454
+ normalize_embeddings: bool = False,
455
  truncate_dim: Optional[int] = None,
 
 
456
  **tokenizer_kwargs,
457
  ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
458
  """
 
478
  Overwrites any setting from convert_to_numpy
479
  device(`torch.device`, *optional*, defaults to None):
480
  Which torch.device to use for the computation
481
+ normalize_embeddings(`bool`, *optional*, defaults to False):
482
  If set to true, returned vectors will have length 1. In that case, the
483
  faster dot-product (util.dot_score) instead of cosine similarity can
484
  be used.
 
491
  If convert_to_tensor, a stacked tensor is returned.
492
  If convert_to_numpy, a numpy matrix is returned.
493
  """
494
+ from transformers import AutoTokenizer
495
+
496
+ self.tokenizer = AutoTokenizer.from_pretrained(
497
+ self.name_or_path, trust_remote_code=True
498
+ )
499
+
500
  is_training = self.training
501
  self.eval()
502
 
 
509
  if convert_to_tensor:
510
  convert_to_numpy = False
511
 
512
+ if output_value != 'sentence_embedding':
513
  convert_to_tensor = False
514
  convert_to_numpy = False
515
 
516
  input_was_string = False
517
+ if isinstance(sentences, str) or not hasattr(sentences, '__len__'):
518
  sentences = [sentences]
519
  input_was_string = True
520
 
 
525
  inverse_permutation = np.argsort(permutation)
526
  sentences = [sentences[idx] for idx in permutation]
527
 
528
+ tokenizer_kwargs['padding'] = tokenizer_kwargs.get('padding', True)
529
+ tokenizer_kwargs['max_length'] = tokenizer_kwargs.get(
530
+ 'max_length', self.tokenizer.init_kwargs.get('model_max_length', 8192)
531
  )
532
+ tokenizer_kwargs['truncation'] = tokenizer_kwargs.get('truncation', True)
533
 
534
  all_embeddings = []
535
 
 
547
  for i in range_iter:
548
  encoded_input = self.tokenizer(
549
  sentences[i : i + batch_size],
550
+ return_tensors='pt',
551
  **tokenizer_kwargs,
552
  ).to(self.device)
553
+ token_embs = self.forward(**encoded_input)[0]
 
 
 
 
 
554
 
555
  # Accumulate in fp32 to avoid overflow
556
  token_embs = token_embs.float()
557
 
558
+ if output_value == 'token_embeddings':
559
  raise NotImplementedError
560
  elif output_value is None:
561
  raise NotImplementedError
562
  else:
563
+ if self.config.emb_pooler == 'cls':
564
  embeddings = self.cls_pooling(
565
+ token_embs, encoded_input['attention_mask']
566
  )
567
  else:
568
  embeddings = self.mean_pooling(
569
+ token_embs, encoded_input['attention_mask']
570
  )
571
 
572
+ if normalize_embeddings:
573
+ embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
574
+
575
+ if convert_to_numpy:
576
+ embeddings = embeddings.cpu()
577
  all_embeddings.extend(embeddings)
578
 
579
  all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
 
582
  if truncate_dim:
583
  all_embeddings = self.truncate_embeddings(all_embeddings, truncate_dim)
584
 
 
 
 
 
 
 
585
  if convert_to_tensor:
586
  all_embeddings = torch.stack(all_embeddings)
587
  elif convert_to_numpy:
588
+ all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
589
 
590
  if input_was_string:
591
  all_embeddings = all_embeddings[0]
 
593
  self.train(is_training)
594
  return all_embeddings
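A short usage sketch of this `encode` signature. The repository id is a placeholder, and `truncate_dim=256` assumes 256 is listed in `config.matryoshka_dimensions`:

```python
from transformers import AutoModel

model = AutoModel.from_pretrained("org/xlmr-flash-checkpoint", trust_remote_code=True)

emb = model.encode(
    ["first sentence", "second sentence"],
    batch_size=16,
    convert_to_tensor=True,      # return one stacked torch.Tensor instead of a numpy array
    normalize_embeddings=True,   # unit-length vectors, so dot product equals cosine similarity
    truncate_dim=256,            # only valid if 256 is in config.matryoshka_dimensions
)
```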
595
 
596
+
597
  def truncate_embeddings(self, embeddings, truncate_dim):
598
  if not self.config.matryoshka_dimensions:
599
  logger.warning(
600
+ 'Matryoshka embeddings are not supported, so dimension truncation will not be performed.'
601
  )
602
  return embeddings
603
  elif truncate_dim in self.config.matryoshka_dimensions:
604
  return [tensor[:truncate_dim] for tensor in embeddings]
605
  else:
606
+ raise ValueError(f'The provided `truncate_dim` value of {truncate_dim} is not supported. '
607
+ f'Supported dimensions are {self.config.matryoshka_dimensions}.')
 
 
608
 
609
  def mean_pooling(
610
  self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
 
616
  input_mask_expanded.sum(1), min=1e-9
617
  )
618
 
 
 
619
 
620
+ def cls_pooling(
621
+ self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
622
+ ):
623
+ return token_embeddings[:, 0]
624
+
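As a quick sanity check of the masked mean pooling defined above (a toy tensor, not model output):

```python
import torch

token_embeddings = torch.tensor([[[1.0, 1.0], [3.0, 3.0], [9.0, 9.0]]])  # (batch=1, seq=3, dim=2)
attention_mask = torch.tensor([[1, 1, 0]])                                # last position is padding

mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
mean = torch.sum(token_embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)
print(mean)  # tensor([[2., 2.]]) -- the padded [9., 9.] token does not contribute
```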
 
 
 
 
 
 
 
625
 
626
  def forward(
627
  self,
 
631
  attention_mask=None,
632
  masked_tokens_mask=None,
633
  return_dict=None,
 
634
  **kwargs,
635
  ):
636
  """If masked_tokens_mask is not None (i.e. last_layer_subset == True in XLMForPreTraining),
 
638
  layer output for these tokens.
639
  masked_tokens_mask: (batch, seqlen), dtype=torch.bool
640
  """
641
+
642
  if kwargs:
643
  for key, value in kwargs.items():
644
  if value is not None:
645
  logger.warning(
646
+ 'Flash attention implementation does not support kwargs: %s',
647
  key,
648
  )
649
 
 
652
  )
653
 
654
  hidden_states = self.embeddings(
655
+ input_ids, position_ids=position_ids, token_type_ids=token_type_ids
 
 
 
656
  )
657
  # TD [2022-12:18]: Don't need to force residual in fp32
658
  # BERT puts embedding LayerNorm before embedding dropout.
 
676
  subset_mask = None
677
 
678
  sequence_output = self.encoder(
679
+ hidden_states, key_padding_mask=attention_mask, subset_mask=subset_mask
 
 
 
 
680
  )
681
 
 
 
 
 
 
 
682
  if masked_tokens_mask is None:
683
  pooled_output = (
684
+ self.pooler(sequence_output) if self.pooler is not None else None
 
 
685
  )
686
  else:
687
  # TD [2022-03-01]: the indexing here is very tricky.
 
695
  pool_input = sequence_output[first_col_mask[subset_mask]]
696
  sequence_output = sequence_output[masked_tokens_mask[subset_mask]]
697
  pooled_output = (
698
+ self.pooler(pool_input, pool=False) if self.pooler is not None else None
 
 
699
  )
700
 
701
  if not return_dict:
 
704
  return BaseModelOutputWithPoolingAndCrossAttentions(
705
  last_hidden_state=sequence_output,
706
  pooler_output=pooled_output,
 
707
  )
708
 
709
 
 
800
  )
801
 
802
 
803
+ # class XLMRobertaForPreTraining(XLMRobertaPreTrainedModel):
804
+ # def __init__(self, config: XLMRobertaFlashConfig):
805
+ # super().__init__(config)
806
+ # # If dense_seq_output, we only need to pass the hidden states for the masked out tokens
807
+ # # (around 15%) to the classifier heads.
808
+ # self.dense_seq_output = getattr(config, "dense_seq_output", False)
809
+ # # If last_layer_subset, we only need the compute the last layer for a subset of tokens
810
+ # # (e.g., the tokens we need to compute the masked LM loss and the next-sentence prediction).
811
+ # self.last_layer_subset = getattr(config, "last_layer_subset", False)
812
+ # if self.last_layer_subset:
813
+ # assert self.dense_seq_output, "last_layer_subset requires dense_seq_output"
814
+ # use_xentropy = getattr(config, "use_xentropy", False)
815
+ # if use_xentropy and CrossEntropyLoss is None:
816
+ # raise ImportError("xentropy_cuda is not installed")
817
+ # loss_cls = (
818
+ # nn.CrossEntropyLoss
819
+ # if not use_xentropy
820
+ # else partial(CrossEntropyLoss, inplace_backward=True)
821
+ # )
822
+ #
823
+ # self.xlm = XLMRobertaModel(config)
824
+ # self.cls = XLMRobertaPreTrainingHeads(config)
825
+ # self.mlm_loss = loss_cls(ignore_index=0)
826
+ # self.nsp_loss = loss_cls(ignore_index=-1)
827
+ #
828
+ # # Initialize weights and apply final processing
829
+ # self.apply(partial(_init_weights, initializer_range=config.initializer_range))
830
+ # self.tie_weights()
831
+ #
832
+ # def tie_weights(self):
833
+ # self.cls.predictions.decoder.weight = self.xlm.embeddings.word_embeddings.weight
834
+ #
835
+ # def forward(
836
+ # self,
837
+ # input_ids,
838
+ # position_ids=None,
839
+ # token_type_ids=None,
840
+ # attention_mask=None,
841
+ # labels=None,
842
+ # next_sentence_label=None,
843
+ # ):
844
+ # """
845
+ # If labels are provided, they must be 0 for masked out tokens (as specified in the attention
846
+ # mask).
847
+ # Outputs:
848
+ # if `labels` and `next_sentence_label` are not `None`:
849
+ # Outputs the total_loss which is the sum of the masked language modeling loss and the next
850
+ # sentence classification loss.
851
+ # if `labels` or `next_sentence_label` is `None`:
852
+ # Outputs a tuple comprising
853
+ # - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and
854
+ # - the next sentence classification logits of shape [batch_size, 2].
855
+ #
856
+ # """
857
+ # masked_tokens_mask = labels > 0 if (self.last_layer_subset and labels is not None) else None
858
+ # outputs = self.xlm(
859
+ # input_ids,
860
+ # position_ids=position_ids,
861
+ # token_type_ids=token_type_ids,
862
+ # attention_mask=attention_mask.bool() if attention_mask is not None else None,
863
+ # masked_tokens_mask=masked_tokens_mask,
864
+ # )
865
+ # sequence_output, pooled_output = outputs.last_hidden_state, outputs.pooler_output
866
+ # if self.dense_seq_output and labels is not None:
867
+ # masked_token_idx = torch.nonzero(labels.flatten() > 0, as_tuple=False).flatten()
868
+ # if not self.last_layer_subset:
869
+ # sequence_output = index_first_axis(
870
+ # rearrange(sequence_output, "b s d -> (b s) d"), masked_token_idx
871
+ # )
872
+ # prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
873
+ #
874
+ # total_loss = None
875
+ # if labels is not None and next_sentence_label is not None:
876
+ # if (
877
+ # self.dense_seq_output and labels is not None
878
+ # ): # prediction_scores are already flattened
879
+ # masked_lm_loss = self.mlm_loss(
880
+ # prediction_scores, labels.flatten()[masked_token_idx]
881
+ # )
882
+ # else:
883
+ # masked_lm_loss = self.mlm_loss(
884
+ # rearrange(prediction_scores, "... v -> (...) v"),
885
+ # rearrange(labels, "... -> (...)"),
886
+ # )
887
+ # next_sentence_loss = self.nsp_loss(
888
+ # rearrange(seq_relationship_score, "... t -> (...) t"),
889
+ # rearrange(next_sentence_label, "... -> (...)"),
890
+ # )
891
+ # total_loss = masked_lm_loss.float() + next_sentence_loss.float()
892
+ #
893
+ # return BertForPreTrainingOutput(
894
+ # loss=total_loss,
895
+ # prediction_logits=prediction_scores,
896
+ # seq_relationship_logits=seq_relationship_score,
897
+ # )
898
+
899
+
900
  def remap_state_dict(state_dict, config: PretrainedConfig):
901
  """
902
  Map the state_dict of a Huggingface BERT model to be flash_attn compatible.
 
1048
  if not last_layer_subset or d != (config.num_hidden_layers - 1):
1049
  Wqkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.weight")
1050
  Wqkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.bias")
1051
+ state_dict[
1052
+ f"bert.encoder.layers.{d}.attention.self.query.weight"
1053
+ ] = Wqkv_weights[: Wqkv_weights.shape[0] // 3, :]
1054
+ state_dict[
1055
+ f"bert.encoder.layers.{d}.attention.self.key.weight"
1056
+ ] = Wqkv_weights[
1057
+ Wqkv_weights.shape[0] // 3 : 2 * Wqkv_weights.shape[0] // 3, :
1058
+ ]
1059
+ state_dict[
1060
+ f"bert.encoder.layers.{d}.attention.self.value.weight"
1061
+ ] = Wqkv_weights[2 * Wqkv_weights.shape[0] // 3 :, :]
1062
+ state_dict[
1063
+ f"bert.encoder.layers.{d}.attention.self.query.bias"
1064
+ ] = Wqkv_biases[: Wqkv_biases.shape[0] // 3]
1065
+ state_dict[
1066
+ f"bert.encoder.layers.{d}.attention.self.key.bias"
1067
+ ] = Wqkv_biases[Wqkv_biases.shape[0] // 3 : 2 * Wqkv_biases.shape[0] // 3]
1068
+ state_dict[
1069
+ f"bert.encoder.layers.{d}.attention.self.value.bias"
1070
+ ] = Wqkv_biases[2 * Wqkv_biases.shape[0] // 3 :]
1071
  else:
1072
  Wq_weight = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.weight")
1073
  Wkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.weight")
1074
  Wq_bias = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.bias")
1075
  Wkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.bias")
1076
+ state_dict[
1077
+ f"bert.encoder.layers.{d}.attention.self.query.weight"
1078
+ ] = Wq_weight
1079
+ state_dict[
1080
+ f"bert.encoder.layers.{d}.attention.self.key.weight"
1081
+ ] = Wkv_weights[: Wkv_weights.shape[0] // 2, :]
1082
+ state_dict[
1083
+ f"bert.encoder.layers.{d}.attention.self.value.weight"
1084
+ ] = Wkv_weights[Wkv_weights.shape[0] // 2 :, :]
1085
  state_dict[f"bert.encoder.layers.{d}.attention.self.query.bias"] = Wq_bias
1086
  state_dict[f"bert.encoder.layers.{d}.attention.self.key.bias"] = Wkv_biases[
1087
  : Wkv_biases.shape[0] // 2
1088
  ]
1089
+ state_dict[
1090
+ f"bert.encoder.layers.{d}.attention.self.value.bias"
1091
+ ] = Wkv_biases[Wkv_biases.shape[0] // 2 :]
1092
 
1093
  def inv_key_mapping_ln(key):
1094
  key = re.sub(r"bert.emb_ln.", "bert.embeddings.LayerNorm.", key)
modeling_xlm_roberta_for_glue.py ADDED
@@ -0,0 +1,109 @@
 
1
+ from typing import Optional, Union, Tuple
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
6
+ from transformers.modeling_outputs import SequenceClassifierOutput, QuestionAnsweringModelOutput, TokenClassifierOutput
7
+
8
+ from .modeling_xlm_roberta import XLMRobertaPreTrainedModel, XLMRobertaModel
9
+ from .configuration_xlm_roberta import XLMRobertaFlashConfig
10
+
11
+
12
+ class XLMRobertaForSequenceClassification(XLMRobertaPreTrainedModel):
13
+ def __init__(self, config: XLMRobertaFlashConfig):
14
+ super().__init__(config)
15
+ self.num_labels = config.num_labels
16
+ self.config = config
17
+
18
+ self.roberta = XLMRobertaModel(config)
19
+ classifier_dropout = (
20
+ config.classifier_dropout
21
+ if config.classifier_dropout is not None
22
+ else config.hidden_dropout_prob
23
+ )
24
+ self.dropout = nn.Dropout(classifier_dropout)
25
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
26
+
27
+ # Initialize weights and apply final processing
28
+ self.post_init()
29
+
30
+
31
+ def forward(
32
+ self,
33
+ input_ids: Optional[torch.Tensor] = None,
34
+ attention_mask: Optional[torch.Tensor] = None,
35
+ token_type_ids: Optional[torch.Tensor] = None,
36
+ position_ids: Optional[torch.Tensor] = None,
37
+ head_mask: Optional[torch.Tensor] = None,
38
+ inputs_embeds: Optional[torch.Tensor] = None,
39
+ labels: Optional[torch.Tensor] = None,
40
+ output_attentions: Optional[bool] = None,
41
+ output_hidden_states: Optional[bool] = None,
42
+ return_dict: Optional[bool] = None,
43
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
44
+ r"""
45
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
46
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
47
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
48
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
49
+ """
50
+ return_dict = (
51
+ return_dict if return_dict is not None else self.config.use_return_dict
52
+ )
53
+
54
+ assert head_mask is None
55
+ assert inputs_embeds is None
56
+ assert output_attentions is None
57
+ assert output_hidden_states is None
58
+ assert return_dict
59
+ outputs = self.roberta(
60
+ input_ids,
61
+ attention_mask=attention_mask,
62
+ token_type_ids=token_type_ids,
63
+ position_ids=position_ids,
64
+ head_mask=head_mask,
65
+ inputs_embeds=inputs_embeds,
66
+ output_attentions=output_attentions,
67
+ output_hidden_states=output_hidden_states,
68
+ return_dict=return_dict,
69
+ )
70
+
71
+ pooled_output = outputs[1]
72
+
73
+ pooled_output = self.dropout(pooled_output)
74
+ logits = self.classifier(pooled_output)
75
+
76
+ loss = None
77
+ if labels is not None:
78
+ if self.config.problem_type is None:
79
+ if self.num_labels == 1:
80
+ self.config.problem_type = "regression"
81
+ elif self.num_labels > 1 and (
82
+ labels.dtype == torch.long or labels.dtype == torch.int
83
+ ):
84
+ self.config.problem_type = "single_label_classification"
85
+ else:
86
+ self.config.problem_type = "multi_label_classification"
87
+
88
+ if self.config.problem_type == "regression":
89
+ loss_fct = MSELoss()
90
+ if self.num_labels == 1:
91
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
92
+ else:
93
+ loss = loss_fct(logits, labels)
94
+ elif self.config.problem_type == "single_label_classification":
95
+ loss_fct = CrossEntropyLoss()
96
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
97
+ elif self.config.problem_type == "multi_label_classification":
98
+ loss_fct = BCEWithLogitsLoss()
99
+ loss = loss_fct(logits, labels)
100
+ if not return_dict:
101
+ output = (logits,) + outputs[2:]
102
+ return ((loss,) + output) if loss is not None else output
103
+
104
+ return SequenceClassifierOutput(
105
+ loss=loss,
106
+ logits=logits,
107
+ hidden_states=outputs.hidden_states,
108
+ attentions=outputs.attentions,
109
+ )
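A hedged usage sketch for the classification head added in this file. The repository id is a placeholder, and it assumes the repo's `auto_map` routes `AutoModelForSequenceClassification` to the `XLMRobertaForSequenceClassification` class above:

```python
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

repo = "org/xlmr-flash-checkpoint"  # placeholder id
tokenizer = AutoTokenizer.from_pretrained(repo)
model = AutoModelForSequenceClassification.from_pretrained(
    repo, trust_remote_code=True, num_labels=2
)

inputs = tokenizer("A sentence to classify", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
print(logits.argmax(dim=-1))
```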
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cfa8fa7c7e120199548fe7149512c0adfe58f6bc13ce19f09b895aa25e8af910
3
+ size 1113232188
rotary.py DELETED
@@ -1,659 +0,0 @@
1
- # This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py
2
- # Commit id: 3566596ad867ee415dd3c12616dd50c610176f6c
3
- # Rotary varlen support from https://github.com/Dao-AILab/flash-attention/pull/556
4
-
5
- # Copyright (c) 2023, Tri Dao.
6
-
7
- from typing import Optional, Tuple, Union
8
-
9
- import torch
10
- from einops import rearrange, repeat
11
-
12
- if torch.cuda.is_available():
13
- try:
14
- from flash_attn.ops.triton.rotary import apply_rotary
15
- except ImportError:
16
-
17
- def apply_rotary(*args, **kwargs):
18
- raise RuntimeError(
19
- "FlashAttention is not installed. To proceed with training, please install FlashAttention. "
20
- "For inference, you have two options: either install FlashAttention or disable it by setting use_flash_attn=False when loading the model."
21
- )
22
-
23
-
24
- def rotate_half(x, interleaved=False):
25
- if not interleaved:
26
- x1, x2 = x.chunk(2, dim=-1)
27
- return torch.cat((-x2, x1), dim=-1)
28
- else:
29
- x1, x2 = x[..., ::2], x[..., 1::2]
30
- return rearrange(
31
- torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
32
- )
33
-
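Although this file is being deleted, a one-line check makes the non-interleaved `rotate_half` convention above concrete:

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0, 4.0])   # x1 = [1, 2], x2 = [3, 4]
x1, x2 = x.chunk(2, dim=-1)
print(torch.cat((-x2, x1), dim=-1))       # tensor([-3., -4.,  1.,  2.])
```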
34
-
35
- def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
36
- """
37
- x: (batch_size, seqlen, nheads, headdim)
38
- cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
39
- """
40
- ro_dim = cos.shape[-1] * 2
41
- assert ro_dim <= x.shape[-1]
42
- cos, sin = (
43
- cos[: x.shape[1]],
44
- sin[: x.shape[1]],
45
- )
46
- cos = repeat(
47
- cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
48
- )
49
- sin = repeat(
50
- sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
51
- )
52
- return torch.cat(
53
- [
54
- x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
55
- x[..., ro_dim:],
56
- ],
57
- dim=-1,
58
- )
59
-
60
-
61
- class ApplyRotaryEmb(torch.autograd.Function):
62
- @staticmethod
63
- def forward(
64
- ctx,
65
- x,
66
- cos,
67
- sin,
68
- interleaved=False,
69
- inplace=False,
70
- seqlen_offsets: Union[int, torch.Tensor] = 0,
71
- cu_seqlens: Optional[torch.Tensor] = None,
72
- max_seqlen: Optional[int] = None,
73
- ):
74
- out = apply_rotary(
75
- x,
76
- cos,
77
- sin,
78
- seqlen_offsets=seqlen_offsets,
79
- cu_seqlens=cu_seqlens,
80
- max_seqlen=max_seqlen,
81
- interleaved=interleaved,
82
- inplace=inplace,
83
- )
84
-
85
- if isinstance(seqlen_offsets, int):
86
- ctx.save_for_backward(
87
- cos, sin, cu_seqlens
88
- ) # Can't save int with save_for_backward
89
- ctx.seqlen_offsets = seqlen_offsets
90
- else:
91
- ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)
92
- ctx.seqlen_offsets = None
93
- ctx.interleaved = interleaved
94
- ctx.inplace = inplace
95
- ctx.max_seqlen = max_seqlen
96
- return out if not inplace else x
97
-
98
- @staticmethod
99
- def backward(ctx, do):
100
- seqlen_offsets = ctx.seqlen_offsets
101
- if seqlen_offsets is None:
102
- cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
103
- else:
104
- cos, sin, cu_seqlens = ctx.saved_tensors
105
- # TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with
106
- # "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works.
107
- if not ctx.interleaved and not ctx.inplace:
108
- do = do.clone()
109
-
110
- dx = apply_rotary(
111
- do,
112
- cos,
113
- sin,
114
- seqlen_offsets=seqlen_offsets,
115
- cu_seqlens=cu_seqlens,
116
- max_seqlen=ctx.max_seqlen,
117
- interleaved=ctx.interleaved,
118
- inplace=ctx.inplace,
119
- conjugate=True,
120
- )
121
- return dx, None, None, None, None, None, None, None
122
-
123
-
124
- def apply_rotary_emb(
125
- x,
126
- cos,
127
- sin,
128
- interleaved=False,
129
- inplace=False,
130
- seqlen_offsets: Union[int, torch.Tensor] = 0,
131
- cu_seqlens: Optional[torch.Tensor] = None,
132
- max_seqlen: Optional[int] = None,
133
- ):
134
- """
135
- Arguments:
136
- x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
137
- else (total_seqlen, nheads, headdim)
138
- cos, sin: (seqlen_rotary, rotary_dim / 2)
139
- interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
140
- of 1st half and 2nd half (GPT-NeoX style).
141
- inplace: if True, apply rotary embedding in-place.
142
- seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
143
- Most commonly used in inference when we have KV cache.
144
- cu_seqlens: (batch + 1,) or None
145
- max_seqlen: int
146
- Return:
147
- out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
148
- else (total_seqlen, nheads, headdim)
149
- rotary_dim must be <= headdim
150
- Apply rotary embedding to the first rotary_dim of x.
151
- """
152
- return ApplyRotaryEmb.apply(
153
- x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen
154
- )
155
-
156
-
157
- # For backward compatibility
158
- apply_rotary_emb_func = apply_rotary_emb
159
-
160
-
161
- class ApplyRotaryEmbQKV_(torch.autograd.Function):
162
- @staticmethod
163
- def forward(
164
- ctx,
165
- qkv,
166
- cos,
167
- sin,
168
- cos_k=None,
169
- sin_k=None,
170
- interleaved=False,
171
- seqlen_offsets: Union[int, torch.Tensor] = 0,
172
- cu_seqlens: Optional[torch.Tensor] = None,
173
- max_seqlen: Optional[int] = None,
174
- use_flash_attn: bool = True,
175
- ):
176
- # batch, seqlen, three, nheads, headdim = qkv.shape
177
- assert qkv.shape[-3] == 3
178
- if cos_k is None and sin_k is None and qkv.is_contiguous():
179
-
180
- if use_flash_attn:
181
- # Call 1 kernel instead of 2 kernels
182
- # We need qkv to be contiguous so that when we reshape to combine (3, nheads)
183
- # dimensions, we get the same tensor
184
- qk = rearrange(qkv[..., :2, :, :], "... t h d -> ... (t h) d")
185
- # qk = qkv[:, :, :2].reshape(batch, seqlen, -1, headdim)
186
- apply_rotary(
187
- qk,
188
- cos,
189
- sin,
190
- seqlen_offsets=seqlen_offsets,
191
- interleaved=interleaved,
192
- inplace=True,
193
- cu_seqlens=cu_seqlens,
194
- max_seqlen=max_seqlen,
195
- )
196
- else:
197
- q_rot = apply_rotary_emb_torch(
198
- qkv[:, :, 0],
199
- cos,
200
- sin,
201
- interleaved=interleaved,
202
- )
203
- k_rot = apply_rotary_emb_torch(
204
- qkv[:, :, 1],
205
- cos,
206
- sin,
207
- interleaved=interleaved,
208
- )
209
- qkv = torch.stack((q_rot, k_rot, qkv[:, :, 2]), dim=2)
210
- else:
211
- cos_k = cos if cos_k is None else cos_k
212
- sin_k = sin if sin_k is None else sin_k
213
- q, k = qkv[..., 0, :, :], qkv[..., 1, :, :]
214
- apply_rotary(
215
- q,
216
- cos,
217
- sin,
218
- seqlen_offsets,
219
- interleaved=interleaved,
220
- inplace=True,
221
- cu_seqlens=cu_seqlens,
222
- max_seqlen=max_seqlen,
223
- )
224
- apply_rotary(
225
- k,
226
- cos_k,
227
- sin_k,
228
- seqlen_offsets,
229
- interleaved=interleaved,
230
- inplace=True,
231
- cu_seqlens=cu_seqlens,
232
- max_seqlen=max_seqlen,
233
- )
234
- ctx.save_for_backward(cos, sin, cos_k, sin_k)
235
- if isinstance(seqlen_offsets, int):
236
- ctx.save_for_backward(cos, sin, cos_k, sin_k, cu_seqlens)
237
- ctx.seqlen_offsets = seqlen_offsets
238
- else:
239
- ctx.save_for_backward(cos, sin, cos_k, sin_k, cu_seqlens, seqlen_offsets)
240
- ctx.seqlen_offsets = None
241
- ctx.max_seqlen = max_seqlen
242
- ctx.interleaved = interleaved
243
- return qkv
244
-
245
- @staticmethod
246
- def backward(ctx, dqkv):
247
- seqlen_offsets = ctx.seqlen_offsets
248
- if seqlen_offsets is None:
249
- cos, sin, cos_k, sin_k, cu_seqlens, seqlen_offsets = ctx.saved_tensors
250
- else:
251
- cos, sin, cos_k, sin_k, cu_seqlens = ctx.saved_tensors
252
- if cos_k is None and sin_k is None and dqkv.is_contiguous():
253
- # Call 1 kernel instead of 2 kernels
254
- # We need dqkv to be contiguous so that when we reshape to combine (3, nheads)
255
- # dimensions, we get the same tensor
256
- dqk = rearrange(dqkv[..., :2, :, :], "... t h d -> ... (t h) d")
257
- apply_rotary(
258
- dqk,
259
- cos,
260
- sin,
261
- seqlen_offsets=seqlen_offsets,
262
- interleaved=ctx.interleaved,
263
- inplace=True,
264
- conjugate=True,
265
- cu_seqlens=cu_seqlens,
266
- max_seqlen=ctx.max_seqlen,
267
- )
268
- else:
269
- cos_k = cos if cos_k is None else cos_k
270
- sin_k = sin if sin_k is None else sin_k
271
- dq, dk = dqkv[..., 0, :, :], dqkv[..., 1, :, :]
272
- apply_rotary(
273
- dq,
274
- cos,
275
- sin,
276
- seqlen_offsets,
277
- interleaved=ctx.interleaved,
278
- inplace=True,
279
- conjugate=True,
280
- cu_seqlens=cu_seqlens,
281
- max_seqlen=ctx.max_seqlen,
282
- )
283
- apply_rotary(
284
- dk,
285
- cos_k,
286
- sin_k,
287
- seqlen_offsets,
288
- interleaved=ctx.interleaved,
289
- inplace=True,
290
- conjugate=True,
291
- cu_seqlens=cu_seqlens,
292
- max_seqlen=ctx.max_seqlen,
293
- )
294
- return dqkv, None, None, None, None, None, None, None, None, None
295
-
296
-
297
- def apply_rotary_emb_qkv_(
298
- qkv,
299
- cos,
300
- sin,
301
- cos_k=None,
302
- sin_k=None,
303
- interleaved=False,
304
- seqlen_offsets: Union[int, torch.Tensor] = 0,
305
- cu_seqlens: Optional[torch.Tensor] = None,
306
- max_seqlen: Optional[int] = None,
307
- use_flash_attn=True,
308
- ):
309
- """
310
- Arguments:
311
- qkv: (batch_size, seqlen, 3, nheads, headdim) if cu_seqlens is None
312
- else (total_seqlen, 3, nheads, headdim)
313
- cos, sin: (seqlen, rotary_dim / 2)
314
- cos_k, sin_k: (seqlen, rotary_dim / 2), optional
315
- interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
316
- 1st half and 2nd half (GPT-NeoX style).
317
- seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount.
318
- Most commonly used in inference when we have KV cache.
319
- cu_seqlens: (batch + 1,) or None
320
- max_seqlen: int
321
- Return:
322
- qkv: (batch_size, seqlen, 3, nheads, headdim) if cu_seqlens is None
323
- else (total_seqlen, 3, nheads, headdim)
324
- rotary_dim must be <= headdim
325
- Apply rotary embedding *inplace* to the first rotary_dim of Q and K.
326
- """
327
- return ApplyRotaryEmbQKV_.apply(
328
- qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets, cu_seqlens, max_seqlen, use_flash_attn,
329
- )
330
-
331
-
332
- class ApplyRotaryEmbKV_(torch.autograd.Function):
333
- @staticmethod
334
- def forward(
335
- ctx,
336
- kv,
337
- cos,
338
- sin,
339
- interleaved=False,
340
- seqlen_offsets: Union[int, torch.Tensor] = 0,
341
- cu_seqlens: Optional[torch.Tensor] = None,
342
- max_seqlen: Optional[int] = None,
343
- ):
344
- # batch, seqlen, two, nheads, headdim = kv.shape
345
- assert kv.shape[-3] == 2
346
- k = kv[..., 0, :, :]
347
- apply_rotary(
348
- k,
349
- cos,
350
- sin,
351
- seqlen_offsets=seqlen_offsets,
352
- interleaved=interleaved,
353
- inplace=True,
354
- cu_seqlens=cu_seqlens,
355
- max_seqlen=max_seqlen,
356
- )
357
- if isinstance(seqlen_offsets, int):
358
- ctx.save_for_backward(
359
- cos, sin, cu_seqlens
360
- ) # Can't save int with save_for_backward
361
- ctx.seqlen_offsets = seqlen_offsets
362
- else:
363
- ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)
364
- ctx.seqlen_offsets = None
365
- ctx.max_seqlen = max_seqlen
366
- ctx.interleaved = interleaved
367
- return kv
368
-
369
- @staticmethod
370
- def backward(ctx, dkv):
371
- seqlen_offsets = ctx.seqlen_offsets
372
- if seqlen_offsets is None:
373
- cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
374
- else:
375
- cos, sin, cu_seqlens = ctx.saved_tensors
376
- apply_rotary(
377
- dkv[..., 0, :, :],
378
- cos,
379
- sin,
380
- seqlen_offsets=seqlen_offsets,
381
- interleaved=ctx.interleaved,
382
- inplace=True,
383
- conjugate=True,
384
- cu_seqlens=cu_seqlens,
385
- max_seqlen=ctx.max_seqlen,
386
- )
387
- return dkv, None, None, None, None, None, None
388
-
389
-
390
- apply_rotary_emb_kv_ = ApplyRotaryEmbKV_.apply
391
-
392
-
393
- def apply_rotary_emb_kv_(
394
- kv,
395
- cos,
396
- sin,
397
- interleaved=False,
398
- seqlen_offsets: Union[int, torch.Tensor] = 0,
399
- cu_seqlens: Optional[torch.Tensor] = None,
400
- max_seqlen: Optional[int] = None,
401
- ):
402
- """
403
- Arguments:
404
- kv: (batch_size, seqlen, 2, nheads, headdim) if cu_seqlens is None
405
- else (total_seqlen, 2, nheads, headdim)
406
- cos, sin: (seqlen, rotary_dim / 2)
407
- interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
408
- 1st half and 2nd half (GPT-NeoX style).
409
- seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount.
410
- Most commonly used in inference when we have KV cache.
411
- cu_seqlens: (batch + 1,) or None
412
- max_seqlen: int
413
- Return:
414
- kv: (batch_size, seqlen, 2, nheads, headdim) if cu_seqlens is None
415
- else (total_seqlen, 2, nheads, headdim)
416
- rotary_dim must be <= headdim
417
- Apply rotary embedding *inplace* to the first rotary_dim of K.
418
- """
419
- return ApplyRotaryEmbKV_.apply(
420
- kv, cos, sin, interleaved, seqlen_offsets, cu_seqlens, max_seqlen
421
- )
422
-
423
-
424
- class RotaryEmbedding(torch.nn.Module):
425
- """
426
- The rotary position embeddings from RoFormer_ (Su et al.).
427
- A crucial insight from the method is that the query and keys are
428
- transformed by rotation matrices which depend on the relative positions.
429
-
430
- Other implementations are available in the Rotary Transformer repo_ and in
431
- GPT-NeoX_; GPT-NeoX was an inspiration.
432
-
433
- .. _RoFormer: https://arxiv.org/abs/2104.09864
434
- .. _repo: https://github.com/ZhuiyiTechnology/roformer
435
- .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
436
-
437
- If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
438
- A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
439
- Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
440
- """
441
-
442
- def __init__(
443
- self,
444
- dim: int,
445
- base=10000.0,
446
- interleaved=False,
447
- scale_base=None,
448
- pos_idx_in_fp32=True,
449
- device=None,
450
- use_flash_attn=True,
451
- ):
452
- """
453
- interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
454
- of 1st half and 2nd half (GPT-NeoX style).
455
- pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
456
- otherwise they might be in lower precision.
457
- This option was added because previously (before 2023-07-02), when we construct
458
- the position indices, we use the dtype of self.inv_freq. In most cases this would
459
- be fp32, but if the model is trained in pure bf16 (not mixed precision), then
460
- self.inv_freq would be bf16, and the position indices are also in bf16.
461
- Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
462
- embeddings for some positions will coincide.
463
- To maintain compatibility with models previously trained in pure bf16,
464
- we add this option.
465
- """
466
- super().__init__()
467
- self.dim = dim
468
- self._base = float(base)
469
- self.pos_idx_in_fp32 = pos_idx_in_fp32
470
- self.use_flash_attn = use_flash_attn
471
- # Generate and save the inverse frequency buffer (non trainable)
472
- inv_freq = self._compute_inv_freq(device)
473
- self.register_buffer("inv_freq", inv_freq, persistent=False)
474
- self.interleaved = interleaved
475
- self.scale_base = scale_base
476
- scale = (
477
- (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim)
478
- / (1.4 * dim)
479
- if scale_base is not None
480
- else None
481
- )
482
- self.register_buffer("scale", scale, persistent=False)
483
-
484
- self._seq_len_cached = 0
485
- self._cos_cached = None
486
- self._sin_cached = None
487
- self._cos_k_cached = None
488
- self._sin_k_cached = None
489
-
490
- @property
491
- def base(self):
492
- return self._base
493
-
494
- @base.setter
495
- def base(self, new_base):
496
- new_base = float(new_base)
497
- if new_base > 0:
498
- if self._base != new_base: # only update if the base value has changed
499
- self._base = new_base
500
- self._update_cos_sin_cache(
501
- self._seq_len_cached,
502
- device=self.inv_freq.device,
503
- dtype=self._cos_cached.dtype if self._cos_cached is not None else None,
504
- rotary_base_changed=True,
505
- )
506
- else:
507
- raise ValueError("Rotary base value must be positive")
508
-
509
- def _compute_inv_freq(self, device=None):
510
- return 1.0 / (
511
- self.base
512
- ** (
513
- torch.arange(0, self.dim, 2, device=device, dtype=torch.float32)
514
- / self.dim
515
- )
516
- )
517
-
518
- def _update_cos_sin_cache(
519
- self, seqlen, device=None, dtype=None, rotary_base_changed=False
520
- ):
521
- # Reset the tables if the sequence length has changed,
522
- # if we're on a new device (possibly due to tracing for instance),
523
- # or if we're switching from inference mode to training
524
- # or if the rotary base value was changed
525
- if (
526
- seqlen > self._seq_len_cached
527
- or self._cos_cached is None
528
- or self._cos_cached.device != device
529
- or self._cos_cached.dtype != dtype
530
- or (self.training and self._cos_cached.is_inference())
531
- or rotary_base_changed
532
- ):
533
- self._seq_len_cached = seqlen
534
- # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
535
- # And the output of arange can be quite large, so bf16 would lose a lot of precision.
536
- # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
537
- if rotary_base_changed:
538
- self.inv_freq = self._compute_inv_freq(device=device)
539
- if self.pos_idx_in_fp32:
540
- t = torch.arange(seqlen, device=device, dtype=torch.float32)
541
- # We want fp32 here as well since inv_freq will be multiplied with t, and the output
542
- # will be large. Having it in bf16 will lose a lot of precision and cause the
543
- # cos & sin output to change significantly.
544
- # We want to recompute self.inv_freq if it was not loaded in fp32
545
- if self.inv_freq.dtype != torch.float32:
546
- inv_freq = self._compute_inv_freq(device=device)
547
- else:
548
- inv_freq = self.inv_freq
549
- else:
550
- t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
551
- inv_freq = self.inv_freq
552
-
553
- # Don't do einsum, it converts fp32 to fp16 under AMP
554
- # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
555
- freqs = torch.outer(t, inv_freq)
556
- if self.scale is None:
557
- self._cos_cached = torch.cos(freqs).to(dtype)
558
- self._sin_cached = torch.sin(freqs).to(dtype)
559
- else:
560
- power = (
561
- torch.arange(
562
- seqlen, dtype=self.scale.dtype, device=self.scale.device
563
- )
564
- - seqlen // 2
565
- ) / self.scale_base
566
- scale = self.scale.to(device=power.device) ** rearrange(
567
- power, "s -> s 1"
568
- )
569
- # We want the multiplication by scale to happen in fp32
570
- self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
571
- self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
572
- self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
573
- self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
574
-
575
- def forward(
576
- self,
577
- qkv: torch.Tensor,
578
- kv: Optional[torch.Tensor] = None,
579
- seqlen_offset: Union[int, torch.Tensor] = 0,
580
- cu_seqlens: Optional[torch.Tensor] = None,
581
- max_seqlen: Optional[int] = None,
582
- ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
583
- """
584
- qkv: (batch, seqlen, 3, nheads, headdim) if kv is none,
585
- else it's just q of shape (batch, seqlen, nheads, headdim)
586
- kv: (batch, seqlen, 2, nheads, headdim)
587
- seqlen_offset: (batch_size,) or int. Each sequence in x is shifted by this amount.
588
- Most commonly used in inference when we have KV cache.
589
- If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one
590
- should pass in max_seqlen, which will update the cos / sin cache up to that length.
591
- Apply rotary embedding *inplace* to qkv and / or kv.
592
- """
593
- if cu_seqlens is not None:
594
- assert max_seqlen is not None
595
- seqlen = qkv.shape[1] if max_seqlen is None else max_seqlen
596
- if max_seqlen is not None:
597
- self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
598
- elif isinstance(seqlen_offset, int):
599
- self._update_cos_sin_cache(
600
- seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype
601
- )
602
- if kv is None:
603
- if self.scale is None:
604
- return apply_rotary_emb_qkv_(
605
- qkv,
606
- self._cos_cached,
607
- self._sin_cached,
608
- interleaved=self.interleaved,
609
- seqlen_offsets=seqlen_offset,
610
- cu_seqlens=cu_seqlens,
611
- max_seqlen=max_seqlen,
612
- use_flash_attn=self.use_flash_attn,
613
- )
614
- else:
615
- return apply_rotary_emb_qkv_(
616
- qkv,
617
- self._cos_cached,
618
- self._sin_cached,
619
- self._cos_k_cached,
620
- self._sin_k_cached,
621
- interleaved=self.interleaved,
622
- seqlen_offsets=seqlen_offset,
623
- cu_seqlens=cu_seqlens,
624
- max_seqlen=max_seqlen,
625
- use_flash_attn=self.use_flash_attn,
626
- )
627
- else:
628
- q = qkv
629
- q = apply_rotary_emb_func(
630
- q,
631
- self._cos_cached,
632
- self._sin_cached,
633
- interleaved=self.interleaved,
634
- inplace=True,
635
- seqlen_offsets=seqlen_offset,
636
- cu_seqlens=cu_seqlens,
637
- max_seqlen=max_seqlen,
638
- )
639
- if self.scale is None:
640
- kv = apply_rotary_emb_kv_(
641
- kv,
642
- self._cos_cached,
643
- self._sin_cached,
644
- interleaved=self.interleaved,
645
- seqlen_offsets=seqlen_offset,
646
- cu_seqlens=cu_seqlens,
647
- max_seqlen=max_seqlen,
648
- )
649
- else:
650
- kv = apply_rotary_emb_kv_(
651
- kv,
652
- self._cos_k_cached,
653
- self._sin_k_cached,
654
- interleaved=self.interleaved,
655
- seqlen_offsets=seqlen_offset,
656
- cu_seqlens=cu_seqlens,
657
- max_seqlen=max_seqlen,
658
- )
659
- return q, kv
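The rotary-embedding diff above implements fused rotary position embeddings (with optional XPos scaling, sequence offsets, and varlen `cu_seqlens` packing) on top of flash-attention's Triton kernel. As a reading aid, here is a minimal pure-PyTorch sketch of the underlying non-interleaved (GPT-NeoX-style) rotation applied to the first `rotary_dim` channels; the helper name `rotary_reference` is hypothetical, and the sketch ignores XPos scaling, offsets, and packed sequences, so it illustrates the math rather than the repository's code path.

```python
import torch

def rotary_reference(q: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """q: (batch, seqlen, nheads, headdim); cos, sin: (seqlen, rotary_dim / 2)."""
    rotary_dim = cos.shape[-1] * 2
    q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
    q1, q2 = q_rot.chunk(2, dim=-1)            # 1st half / 2nd half (GPT-NeoX style)
    cos = cos[None, :, None, :]                # broadcast over batch and heads
    sin = sin[None, :, None, :]
    rotated = torch.cat((q1 * cos - q2 * sin, q1 * sin + q2 * cos), dim=-1)
    return torch.cat((rotated, q_pass), dim=-1)

# Toy inputs mirroring the cos/sin cache construction shown in the diff.
batch, seqlen, nheads, headdim = 2, 16, 4, 64
rotary_dim = headdim
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim))
freqs = torch.outer(torch.arange(seqlen, dtype=torch.float32), inv_freq)  # (seqlen, rotary_dim / 2)
q = torch.randn(batch, seqlen, nheads, headdim)
q_rotated = rotary_reference(q, torch.cos(freqs), torch.sin(freqs))
```

The interleaved (GPT-J-style) variant documented above rotates even/odd channel pairs instead of the two halves.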
 
stochastic_depth.py CHANGED
@@ -34,7 +34,7 @@
34
 
35
  import torch
36
  import torch.fx
37
- from torch import Tensor, nn
38
 
39
 
40
  def stochastic_depth(
 
34
 
35
  import torch
36
  import torch.fx
37
+ from torch import nn, Tensor
38
 
39
 
40
  def stochastic_depth(
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "model_max_length": 8194,
3
+ "tokenizer_class": "XLMRobertaTokenizer"
4
+ }
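The new `tokenizer_config.json` caps `model_max_length` at 8194 tokens and selects `XLMRobertaTokenizer`. A minimal usage sketch, assuming the files are hosted in a Hugging Face model repository (the repo id below is a placeholder) and that `transformers` and `sentencepiece` are installed:

```python
from transformers import AutoTokenizer

# Placeholder repo id -- substitute the actual model repository.
tokenizer = AutoTokenizer.from_pretrained("jinaai/xlm-roberta-flash-implementation")
print(tokenizer.model_max_length)  # 8194, from tokenizer_config.json

batch = tokenizer(["Hello world", "Bonjour le monde"], padding=True, return_tensors="pt")
print(batch["input_ids"].shape, batch["attention_mask"].shape)
```

With `padding=True` the tokenizer also returns the `attention_mask` that `unpad_input` below consumes.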
xlm_padding.py CHANGED
@@ -18,9 +18,7 @@ class IndexFirstAxis(torch.autograd.Function):
18
  # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
19
  # return input[indices]
20
  return torch.gather(
21
- rearrange(input, "b ... -> b (...)"),
22
- 0,
23
- repeat(indices, "z -> z d", d=second_dim),
24
  ).reshape(-1, *other_shape)
25
 
26
  @staticmethod
@@ -36,9 +34,7 @@ class IndexFirstAxis(torch.autograd.Function):
36
  )
37
  # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
38
  # grad_input[indices] = grad_output
39
- grad_input.scatter_(
40
- 0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output
41
- )
42
  return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
43
 
44
 
@@ -102,7 +98,7 @@ class IndexFirstAxisResidual(torch.autograd.Function):
102
  index_first_axis_residual = IndexFirstAxisResidual.apply
103
 
104
 
105
- def unpad_input(hidden_states, attention_mask, adapter_mask=None):
106
  """
107
  Arguments:
108
  hidden_states: (batch, seqlen, ...)
@@ -116,16 +112,7 @@ def unpad_input(hidden_states, attention_mask, adapter_mask=None):
116
  seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
117
  indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
118
  max_seqlen_in_batch = seqlens_in_batch.max().item()
119
- cu_seqlens = F.pad(
120
- torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)
121
- )
122
-
123
- cu_adapter_mask = (
124
- torch.repeat_interleave(adapter_mask, cu_seqlens[1:] - cu_seqlens[:-1])
125
- if adapter_mask is not None
126
- else None
127
- )
128
-
129
  # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
130
  # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
131
  # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
@@ -136,7 +123,6 @@ def unpad_input(hidden_states, attention_mask, adapter_mask=None):
136
  indices,
137
  cu_seqlens,
138
  max_seqlen_in_batch,
139
- cu_adapter_mask,
140
  )
141
 
142
 
@@ -194,18 +180,14 @@ def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_leng
194
  """
195
  length = attention_mask_in_length.sum(dim=-1)
196
  seqlen = attention_mask_in_length.size(-1)
197
- attention_mask_2d = torch.arange(
198
- seqlen, device=length.device, dtype=length.dtype
199
- ).expand(len(length), seqlen) < length.unsqueeze(1)
200
- real_indices_idx = torch.nonzero(
201
- attention_mask_in_length.flatten(), as_tuple=False
202
- ).flatten()
203
  seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx]
204
  indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten()
205
  max_seqlen_in_batch = seqlens_in_batch.max().item()
206
- cu_seqlens = F.pad(
207
- torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)
208
- )
209
  # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
210
  # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
211
  # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
@@ -233,4 +215,4 @@ def pad_input(hidden_states, indices, batch, seqlen):
233
  # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
234
  # output[indices] = hidden_states
235
  output = index_put_first_axis(hidden_states, indices, batch * seqlen)
236
- return rearrange(output, "(b s) ... -> b s ...", b=batch)
 
18
  # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
19
  # return input[indices]
20
  return torch.gather(
21
+ rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)
 
 
22
  ).reshape(-1, *other_shape)
23
 
24
  @staticmethod
 
34
  )
35
  # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
36
  # grad_input[indices] = grad_output
37
+ grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output)
 
 
38
  return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
39
 
40
 
 
98
  index_first_axis_residual = IndexFirstAxisResidual.apply
99
 
100
 
101
+ def unpad_input(hidden_states, attention_mask):
102
  """
103
  Arguments:
104
  hidden_states: (batch, seqlen, ...)
 
112
  seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
113
  indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
114
  max_seqlen_in_batch = seqlens_in_batch.max().item()
115
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
 
 
 
 
 
 
 
 
 
116
  # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
117
  # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
118
  # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
 
123
  indices,
124
  cu_seqlens,
125
  max_seqlen_in_batch,
 
126
  )
127
 
128
 
 
180
  """
181
  length = attention_mask_in_length.sum(dim=-1)
182
  seqlen = attention_mask_in_length.size(-1)
183
+ attention_mask_2d = torch.arange(seqlen, device=length.device, dtype=length.dtype).expand(len(length),
184
+ seqlen) < length.unsqueeze(
185
+ 1)
186
+ real_indices_idx = torch.nonzero(attention_mask_in_length.flatten(), as_tuple=False).flatten()
 
 
187
  seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx]
188
  indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten()
189
  max_seqlen_in_batch = seqlens_in_batch.max().item()
190
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
 
 
191
  # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
192
  # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
193
  # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
 
215
  # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
216
  # output[indices] = hidden_states
217
  output = index_put_first_axis(hidden_states, indices, batch * seqlen)
218
+ return rearrange(output, "(b s) ... -> b s ...", b=batch)
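The `xlm_padding.py` changes above drop the `adapter_mask` plumbing from `unpad_input`, but the unpad/pad round trip itself is unchanged: real tokens are packed into a single `(total_tokens, ...)` tensor together with `indices`, `cu_seqlens`, and the batch maximum sequence length, and `pad_input` scatters them back. Below is a self-contained sketch of that data flow using plain boolean indexing; the file itself uses `torch.gather`/`scatter_`-based autograd functions, so this illustrates the semantics rather than the optimized path.

```python
import torch
import torch.nn.functional as F

batch, seqlen, dim = 3, 8, 16
hidden_states = torch.randn(batch, seqlen, dim)
# Per-sequence lengths 8, 5 and 2 -> a standard right-padded attention mask.
attention_mask = (torch.arange(seqlen)[None, :] < torch.tensor([[8], [5], [2]])).int()

# Same bookkeeping as unpad_input: token indices, cumulative and maximum lengths.
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))  # fed to varlen attention
max_seqlen_in_batch = int(seqlens_in_batch.max())

# "unpad": keep only the real tokens, packed into one (total_tokens, dim) tensor.
packed = hidden_states.reshape(batch * seqlen, dim)[indices]

# "pad": scatter the packed tokens back into a zero-initialised (batch, seqlen, dim) tensor.
restored = torch.zeros(batch * seqlen, dim)
restored[indices] = packed
restored = restored.reshape(batch, seqlen, dim)
assert torch.equal(restored, hidden_states * attention_mask[..., None])
```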