fix-use-cpu-on-gpu-machine

#9
README.md CHANGED
@@ -1,114 +1,5 @@
1
- ---
2
- tags:
3
- - transformers
4
- - xlm-roberta
5
- library_name: transformers
6
- license: cc-by-nc-4.0
7
- language:
8
- - multilingual
9
- - af
10
- - am
11
- - ar
12
- - as
13
- - az
14
- - be
15
- - bg
16
- - bn
17
- - br
18
- - bs
19
- - ca
20
- - cs
21
- - cy
22
- - da
23
- - de
24
- - el
25
- - en
26
- - eo
27
- - es
28
- - et
29
- - eu
30
- - fa
31
- - fi
32
- - fr
33
- - fy
34
- - ga
35
- - gd
36
- - gl
37
- - gu
38
- - ha
39
- - he
40
- - hi
41
- - hr
42
- - hu
43
- - hy
44
- - id
45
- - is
46
- - it
47
- - ja
48
- - jv
49
- - ka
50
- - kk
51
- - km
52
- - kn
53
- - ko
54
- - ku
55
- - ky
56
- - la
57
- - lo
58
- - lt
59
- - lv
60
- - mg
61
- - mk
62
- - ml
63
- - mn
64
- - mr
65
- - ms
66
- - my
67
- - ne
68
- - nl
69
- - 'no'
70
- - om
71
- - or
72
- - pa
73
- - pl
74
- - ps
75
- - pt
76
- - ro
77
- - ru
78
- - sa
79
- - sd
80
- - si
81
- - sk
82
- - sl
83
- - so
84
- - sq
85
- - sr
86
- - su
87
- - sv
88
- - sw
89
- - ta
90
- - te
91
- - th
92
- - tl
93
- - tr
94
- - ug
95
- - uk
96
- - ur
97
- - uz
98
- - vi
99
- - xh
100
- - yi
101
- - zh
102
- ---
103
- Core implementation of Jina XLM-RoBERTa
104
 
105
- This implementation is adapted from [XLM-Roberta](https://huggingface.co/docs/transformers/en/model_doc/xlm-roberta). In contrast to the original implementation, this model uses Rotary positional encodings and supports flash-attention 2.
106
-
107
- ### Models that use this implementation
108
-
109
- - [jinaai/jina-embeddings-v3](https://huggingface.co/jinaai/jina-embeddings-v3)
110
- - [jinaai/jina-colbert-v2](https://huggingface.co/jinaai/jina-colbert-v2)
111
-
112
- ### Converting weights
113
-
114
- Weights from an [original XLMRoberta model](https://huggingface.co/FacebookAI/xlm-roberta-large) can be converted using the `convert_roberta_weights_to_flash.py` script in the model repository.
 
1
+ # Converting Weights
 
 
 
2
 
3
+ ```
4
+ python3 -m "xlm-roberta-flash-implementation".convert_roberta_weights_to_flash --output pytorch_model_xlmr_flash.bin
5
+ ```
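After conversion, the resulting weights can be loaded through the custom classes shipped in this repository. A minimal sketch, assuming the output file from the command above and an illustrative repo id (both are assumptions, not part of this change):

```python
# Sketch only: load converted weights with the custom XLM-RoBERTa flash classes.
# The repo id and the checkpoint filename below are illustrative assumptions.
import torch
from transformers import AutoConfig, AutoModel

config = AutoConfig.from_pretrained(
    "jinaai/xlm-roberta-flash-implementation", trust_remote_code=True
)
model = AutoModel.from_config(config, trust_remote_code=True)

# Weights written by convert_roberta_weights_to_flash above.
state_dict = torch.load("pytorch_model_xlmr_flash.bin", map_location="cpu")
model.load_state_dict(state_dict, strict=False)
model.eval()
```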
 
 
 
block.py CHANGED
@@ -8,14 +8,15 @@ from typing import Optional
8
 
9
  import torch
10
  import torch.nn as nn
 
11
  from torch import Tensor
12
 
 
13
  from .mha import MHA
14
  from .mlp import Mlp
15
- from .stochastic_depth import StochasticDepth
16
 
17
  try:
18
- from flash_attn.ops.triton.layer_norm import RMSNorm, layer_norm_fn
19
  except ImportError:
20
  layer_norm_fn, RMSNorm = None, None
21
 
@@ -232,9 +233,7 @@ class Block(nn.Module):
232
  is_rms_norm=isinstance(self.norm1, RMSNorm),
233
  )
234
  if not isinstance(self.mlp, nn.Identity):
235
- mlp_out = self.mlp(
236
- hidden_states, adapter_mask=mixer_kwargs.get("adapter_mask")
237
- )
238
  if self.return_residual: # mlp out is actually a pair here
239
  mlp_out, hidden_states = mlp_out
240
  if not self.fused_dropout_add_ln:
 
8
 
9
  import torch
10
  import torch.nn as nn
11
+ import torch.nn.functional as F
12
  from torch import Tensor
13
 
14
+ from .stochastic_depth import StochasticDepth
15
  from .mha import MHA
16
  from .mlp import Mlp
 
17
 
18
  try:
19
+ from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm
20
  except ImportError:
21
  layer_norm_fn, RMSNorm = None, None
22
 
 
233
  is_rms_norm=isinstance(self.norm1, RMSNorm),
234
  )
235
  if not isinstance(self.mlp, nn.Identity):
236
+ mlp_out = self.mlp(hidden_states)
 
 
237
  if self.return_residual: # mlp out is actually a pair here
238
  mlp_out, hidden_states = mlp_out
239
  if not self.fused_dropout_add_ln:
config.json ADDED
@@ -0,0 +1,31 @@
 
 
1
+ {
2
+ "auto_map": {
3
+ "AutoConfig": "configuration_xlm_roberta.XLMRobertaFlashConfig",
4
+ "AutoModel": "modeling_xlm_roberta.XLMRobertaModel",
5
+ "AutoModelForPreTraining": "modeling_xlm_roberta.XLMRobertaForPreTraining",
6
+ "AutoModelForMaskedLM": "modeling_xlm_roberta.XLMRobertaForMaskedLM",
7
+ "AutoModelForSequenceClassification": "modeling_xlm_roberta.XLMRobertaForSequenceClassification"
8
+ },
9
+ "architectures": [
10
+ "XLMRobertaModel"
11
+ ],
12
+ "attention_probs_dropout_prob": 0.1,
13
+ "bos_token_id": 0,
14
+ "eos_token_id": 2,
15
+ "hidden_act": "gelu",
16
+ "hidden_dropout_prob": 0.1,
17
+ "hidden_size": 768,
18
+ "initializer_range": 0.02,
19
+ "intermediate_size": 3072,
20
+ "layer_norm_eps": 1e-05,
21
+ "max_position_embeddings": 514,
22
+ "num_attention_heads": 12,
23
+ "num_hidden_layers": 12,
24
+ "output_past": true,
25
+ "pad_token_id": 1,
26
+ "position_embedding_type": "absolute",
27
+ "transformers_version": "4.17.0.dev0",
28
+ "type_vocab_size": 1,
29
+ "use_cache": false,
30
+ "vocab_size": 250002
31
+ }
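The `auto_map` entries above are what let `transformers` resolve the `Auto*` classes to the custom modules in this repository when `trust_remote_code=True` is passed. In the spirit of this PR (running on CPU even on a GPU machine), a hedged sketch with an illustrative repo id; `use_flash_attn` is the flag defined in `configuration_xlm_roberta.py`:

```python
# Sketch under assumptions: keep the model on CPU by disabling the flash-attention path.
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "jinaai/xlm-roberta-flash-implementation",  # illustrative repo id
    trust_remote_code=True,
    use_flash_attn=False,  # fall back to the pure-PyTorch attention implementation
)
model = model.to("cpu").eval()
```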
configuration_xlm_roberta.py CHANGED
@@ -1,94 +1,36 @@
1
- from typing import Any, Dict, List, Optional, Union
2
-
3
- import torch
4
  from transformers import PretrainedConfig
5
-
6
 
7
  class XLMRobertaFlashConfig(PretrainedConfig):
8
-
9
- model_type = "xlm-roberta"
10
-
11
  def __init__(
12
- self,
13
- vocab_size: int = 250002,
14
- hidden_size: int = 1024,
15
- num_hidden_layers: int = 24,
16
- num_attention_heads: int = 16,
17
- intermediate_size: int = 4096,
18
- hidden_act: str = "gelu",
19
- hidden_dropout_prob: float = 0.1,
20
- attention_probs_dropout_prob: float = 0.1,
21
- max_position_embeddings: int = 8194,
22
- type_vocab_size: int = 1,
23
- initializer_range: float = 0.02,
24
- layer_norm_eps: float = 1e-05,
25
- pad_token_id: int = 1,
26
- bos_token_id: int = 0,
27
- eos_token_id: int = 2,
28
- position_embedding_type: str = "rotary",
29
- rotary_emb_base: float = 10000.0,
30
- use_cache: bool = True,
31
- use_reentrant: bool = False,
32
- classifier_dropout: Optional[float] = None,
33
- lora_adaptations: Optional[List[str]] = None,
34
- task_instructions: Optional[Dict[str, str]] = None,
35
- lora_rank: int = 4,
36
- lora_dropout_p: float = 0.0,
37
- lora_alpha: int = 1,
38
- lora_main_params_trainable: bool = False,
39
- load_trained_adapters: bool = False,
40
- use_flash_attn: bool = True,
41
- torch_dtype: Optional[Union[str, torch.dtype]] = None,
42
- emb_pooler: Optional[str] = None,
43
- matryoshka_dimensions: Optional[List[int]] = None,
44
- truncate_dim: Optional[int] = None,
45
- **kwargs: Dict[str, Any],
46
  ):
47
- """
48
- Initialize the XLMRobertaFlashConfig configuration.
49
 
50
- Args:
51
- vocab_size (int): Size of the vocabulary.
52
- hidden_size (int): Dimensionality of the encoder layers and the pooler layer.
53
- num_hidden_layers (int): Number of hidden layers in the Transformer encoder.
54
- num_attention_heads (int): Number of attention heads for each attention layer in the Transformer encoder.
55
- intermediate_size (int): Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer.
56
- hidden_act (str): The activation function to use.
57
- hidden_dropout_prob (float): The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
58
- attention_probs_dropout_prob (float): The dropout ratio for the attention probabilities.
59
- max_position_embeddings (int): The maximum length of the position embeddings.
60
- type_vocab_size (int): The vocabulary size of the token type ids.
61
- initializer_range (float): The standard deviation for initializing all weight matrices.
62
- layer_norm_eps (float): The epsilon used by the layer normalization layers.
63
- pad_token_id (int): The ID of the padding token.
64
- bos_token_id (int): The ID of the beginning-of-sequence token.
65
- eos_token_id (int): The ID of the end-of-sequence token.
66
- position_embedding_type (str): Type of position embeddings. Options are 'absolute', 'alibi', or 'rotary'.
67
- rotary_emb_base (float): Base for rotary embeddings.
68
- use_cache (bool): Whether or not the model should return the last key/values attentions (not used by all models).
69
- use_reentrant (bool): Whether or not the model should enable the 'use_reentrant' flag in gradient checkpointing.
70
- classifier_dropout (Optional[float]): The dropout ratio for the classification head.
71
- lora_adaptations (Optional[List[str]]): LoRA adaptations configuration.
72
- lora_prompts (Optional[Dict[str, str]]): LoRA prompts configuration.
73
- lora_rank (int): Rank for LoRA adaptations.
74
- lora_dropout_p (float): Dropout probability for LoRA adaptations.
75
- lora_alpha (int): Alpha parameter for LoRA.
76
- lora_main_params_trainable (bool): Whether to make the main model parameters trainable when using LoRA.
77
- load_trained_adapters (bool): Whether to load trained adapters.
78
- use_flash_attn (bool): Whether to use FlashAttention.
79
- torch_dtype (Optional[Union[str, torch.dtype]]): Data type for the tensors.
80
- emb_pooler (Optional[str]): Pooling layer configuration.
81
- matryoshka_dimensions (Optional[List[int]]): Configuration for matryoshka dimension reduction.
82
- truncate_dim (Optional[int]): Dimension to truncate embeddings to, if any.
83
- **kwargs (Dict[str, Any]): Additional keyword arguments passed to the configuration.
84
- """
85
-
86
- super().__init__(
87
- pad_token_id=pad_token_id,
88
- bos_token_id=bos_token_id,
89
- eos_token_id=eos_token_id,
90
- **kwargs,
91
- )
92
 
93
  self.vocab_size = vocab_size
94
  self.hidden_size = hidden_size
@@ -103,28 +45,13 @@ class XLMRobertaFlashConfig(PretrainedConfig):
103
  self.initializer_range = initializer_range
104
  self.layer_norm_eps = layer_norm_eps
105
  self.position_embedding_type = position_embedding_type
106
- self.rotary_emb_base = rotary_emb_base
107
  self.use_cache = use_cache
108
- self.use_reentrant = use_reentrant
109
  self.classifier_dropout = classifier_dropout
 
110
  self.load_trained_adapters = load_trained_adapters
111
- self.lora_adaptations = lora_adaptations
112
- self.task_instructions = task_instructions
113
- self.lora_rank = lora_rank
114
- self.lora_dropout_p = lora_dropout_p
115
- self.lora_alpha = lora_alpha
116
- self.lora_main_params_trainable = lora_main_params_trainable
117
  self.use_flash_attn = use_flash_attn
118
  self.emb_pooler = emb_pooler
119
- self.matryoshka_dimensions = matryoshka_dimensions
120
- self.truncate_dim = truncate_dim
121
- if (
122
- torch_dtype
123
- and hasattr(torch, torch_dtype)
124
- and type(getattr(torch, torch_dtype)) is torch.dtype
125
- ):
126
  self.torch_dtype = getattr(torch, torch_dtype)
127
  else:
128
  self.torch_dtype = torch_dtype
129
- if not self.use_flash_attn or not torch.cuda.is_available():
130
- self.torch_dtype = torch.float32
 
 
 
 
1
  from transformers import PretrainedConfig
2
+ import torch
3
 
4
  class XLMRobertaFlashConfig(PretrainedConfig):
 
 
 
5
  def __init__(
6
+ self,
7
+ vocab_size=30522,
8
+ hidden_size=768,
9
+ num_hidden_layers=12,
10
+ num_attention_heads=12,
11
+ intermediate_size=3072,
12
+ hidden_act="gelu",
13
+ hidden_dropout_prob=0.1,
14
+ attention_probs_dropout_prob=0.1,
15
+ max_position_embeddings=512,
16
+ type_vocab_size=2,
17
+ initializer_range=0.02,
18
+ layer_norm_eps=1e-12,
19
+ pad_token_id=1,
20
+ bos_token_id=0,
21
+ eos_token_id=2,
22
+ position_embedding_type="absolute",
23
+ use_cache=True,
24
+ classifier_dropout=None,
25
+ num_loras=1,
26
+ load_trained_adapters=False,
27
+ use_flash_attn=True,
28
+ torch_dtype=None,
29
+ emb_pooler=None,
30
+ **kwargs,
 
 
 
31
  ):
32
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
 
33
 
 
 
34
 
35
  self.vocab_size = vocab_size
36
  self.hidden_size = hidden_size
 
45
  self.initializer_range = initializer_range
46
  self.layer_norm_eps = layer_norm_eps
47
  self.position_embedding_type = position_embedding_type
 
48
  self.use_cache = use_cache
 
49
  self.classifier_dropout = classifier_dropout
50
+ self.num_loras = num_loras
51
  self.load_trained_adapters = load_trained_adapters
 
 
52
  self.use_flash_attn = use_flash_attn
53
  self.emb_pooler = emb_pooler
54
+ if torch_dtype and hasattr(torch, torch_dtype) and type(getattr(torch, torch_dtype)) is torch.dtype:
 
 
55
  self.torch_dtype = getattr(torch, torch_dtype)
56
  else:
57
  self.torch_dtype = torch_dtype
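For reference, the `torch_dtype` handling above resolves a dtype given as a string (e.g. `"bfloat16"`) to the corresponding `torch.dtype`; anything else is stored as-is. A small illustrative check, assuming the module is importable locally (an assumption for this sketch):

```python
# Illustrative only: dtype strings are resolved via getattr(torch, ...).
import torch

from configuration_xlm_roberta import XLMRobertaFlashConfig  # assumed local import

cfg = XLMRobertaFlashConfig(torch_dtype="bfloat16")
assert cfg.torch_dtype is torch.bfloat16  # the string was resolved to a torch.dtype
```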
 
 
embedding.py CHANGED
@@ -5,8 +5,10 @@
5
 
6
  import torch
7
  import torch.nn as nn
8
- from transformers.models.xlm_roberta.modeling_xlm_roberta import \
9
- create_position_ids_from_input_ids
 
 
10
 
11
 
12
  class XLMRobertaEmbeddings(nn.Module):
@@ -36,60 +38,25 @@ class XLMRobertaEmbeddings(nn.Module):
36
  max_position_embeddings, embed_dim, **factory_kwargs
37
  )
38
  if self.type_vocab_size > 0:
39
- self.token_type_embeddings = nn.Embedding(
40
- type_vocab_size, embed_dim, **factory_kwargs
41
- )
42
 
43
- def forward(
44
- self, input_ids, position_ids=None, token_type_ids=None, adapter_mask=None
45
- ):
46
  """
47
  input_ids: (batch, seqlen)
48
  position_ids: (batch, seqlen)
49
  token_type_ids: (batch, seqlen)
50
- adapter_mask: (batch, 1)
51
  """
52
  batch_size, seqlen = input_ids.shape
53
- if adapter_mask is not None:
54
- unique_tasks = torch.unique(adapter_mask)
55
- embedding_dtype = next(self.word_embeddings.parameters()).dtype
56
- embeddings = torch.empty(
57
- *input_ids.shape,
58
- self.word_embeddings.embedding_dim,
59
- dtype=embedding_dtype,
60
- device=input_ids.device
61
- )
62
- for task_id in unique_tasks:
63
- task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
64
- task_input_ids = input_ids[task_indices]
65
- task_embeddings = self.word_embeddings(task_input_ids, task_id=task_id)
66
- embeddings[task_indices] = task_embeddings
67
- else:
68
- embeddings = self.word_embeddings(input_ids)
69
  if self.max_position_embeddings > 0:
70
  if position_ids is None:
71
- position_ids = create_position_ids_from_input_ids(
72
- input_ids, padding_idx=self.word_embeddings.padding_idx
73
- ).to(input_ids.device)
74
  position_embeddings = self.position_embeddings(position_ids)
75
  embeddings = embeddings + position_embeddings
76
  if self.type_vocab_size > 0:
77
  if token_type_ids is None:
78
- token_type_ids = torch.zeros(
79
- seqlen, dtype=torch.long, device=input_ids.device
80
- )
81
-
82
- if adapter_mask is not None:
83
- unique_tasks = torch.unique(adapter_mask)
84
- for task_id in unique_tasks:
85
- task_token_type_embeddings = self.token_type_embeddings(
86
- token_type_ids, task_id=task_id
87
- )
88
- task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
89
- embeddings[task_indices] = (
90
- embeddings[task_indices] + task_token_type_embeddings
91
- )
92
- else:
93
- token_type_embeddings = self.token_type_embeddings(token_type_ids)
94
- embeddings = embeddings + token_type_embeddings
95
  return embeddings
 
5
 
6
  import torch
7
  import torch.nn as nn
8
+ from einops import rearrange
9
+ from torch import Tensor
10
+
11
+ from transformers.models.xlm_roberta.modeling_xlm_roberta import create_position_ids_from_input_ids
12
 
13
 
14
  class XLMRobertaEmbeddings(nn.Module):
 
38
  max_position_embeddings, embed_dim, **factory_kwargs
39
  )
40
  if self.type_vocab_size > 0:
41
+ self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim, **factory_kwargs)
 
 
42
 
43
+ def forward(self, input_ids, position_ids=None, token_type_ids=None):
 
 
44
  """
45
  input_ids: (batch, seqlen)
46
  position_ids: (batch, seqlen)
47
  token_type_ids: (batch, seqlen)
 
48
  """
49
  batch_size, seqlen = input_ids.shape
50
+ embeddings = self.word_embeddings(input_ids)
 
 
51
  if self.max_position_embeddings > 0:
52
  if position_ids is None:
53
+ position_ids = create_position_ids_from_input_ids(input_ids, padding_idx=self.word_embeddings.padding_idx).to(input_ids.device)
54
+ # position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
 
55
  position_embeddings = self.position_embeddings(position_ids)
56
  embeddings = embeddings + position_embeddings
57
  if self.type_vocab_size > 0:
58
  if token_type_ids is None:
59
+ token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=input_ids.device)
60
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
61
+ embeddings = embeddings + token_type_embeddings
 
 
62
  return embeddings
mha.py CHANGED
@@ -1,6 +1,5 @@
1
  # This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py
2
  # Commit id: 6bbc532388e61185a92e2a563126739967b4c8c5
3
- # Rotary varlen support from https://github.com/Dao-AILab/flash-attention/pull/556
4
 
5
  # Copyright (c) 2023, Tri Dao.
6
 
@@ -12,23 +11,27 @@ import torch.nn as nn
12
  from einops import rearrange, repeat
13
 
14
  try:
15
- from flash_attn import (flash_attn_kvpacked_func,
16
- flash_attn_qkvpacked_func,
17
- flash_attn_varlen_kvpacked_func,
18
- flash_attn_varlen_qkvpacked_func,
19
- flash_attn_with_kvcache)
 
 
20
  except ImportError:
21
  flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None
22
  flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None
23
  flash_attn_with_kvcache = None
24
 
25
  try:
26
- from flash_attn.ops.fused_dense import (ColumnParallelLinear, FusedDense,
27
- RowParallelLinear)
28
  except ImportError:
29
  FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None
30
 
31
- from .rotary import RotaryEmbedding
 
 
 
32
 
33
 
34
  # From https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742
@@ -44,9 +47,7 @@ def get_alibi_slopes(nheads):
44
  closest_power_of_2 = 2 ** math.floor(math.log2(nheads))
45
  return (
46
  get_slopes_power_of_2(closest_power_of_2)
47
- + get_alibi_slopes(2 * closest_power_of_2)[0::2][
48
- : nheads - closest_power_of_2
49
- ]
50
  )
51
 
52
 
@@ -71,9 +72,7 @@ class FlashSelfAttention(nn.Module):
71
  deterministic=False,
72
  ):
73
  super().__init__()
74
- assert (
75
- flash_attn_varlen_qkvpacked_func is not None
76
- ), "FlashAttention is not installed"
77
  assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed"
78
  self.causal = causal
79
  self.softmax_scale = softmax_scale
@@ -153,9 +152,7 @@ class FlashCrossAttention(nn.Module):
153
  deterministic=False,
154
  ):
155
  super().__init__()
156
- assert (
157
- flash_attn_varlen_kvpacked_func is not None
158
- ), "FlashAttention is not installed"
159
  assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
160
  self.causal = causal
161
  self.softmax_scale = softmax_scale
@@ -321,10 +318,7 @@ class CrossAttention(nn.Module):
321
  scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
322
  if key_padding_mask is not None:
323
  padding_mask = torch.full(
324
- (batch_size, seqlen_k),
325
- -10000.0,
326
- dtype=scores.dtype,
327
- device=scores.device,
328
  )
329
  padding_mask.masked_fill_(key_padding_mask, 0.0)
330
  # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
@@ -436,26 +430,20 @@ class MHA(nn.Module):
436
  else:
437
  alibi_slopes = None
438
  if window_size != (-1, -1):
439
- assert (
440
- use_flash_attn
441
- ), "Local (sliding window) attention code path requires flash_attn"
442
 
443
  self.num_heads = num_heads
444
  self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
445
  assert (
446
  self.num_heads % self.num_heads_kv == 0
447
  ), "num_heads must be divisible by num_heads_kv"
448
- assert (
449
- self.embed_dim % num_heads == 0
450
- ), "embed_dim must be divisible by num_heads"
451
  self.head_dim = self.embed_dim // num_heads
452
  qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
453
  kv_dim = 2 * self.head_dim * self.num_heads_kv
454
 
455
  if self.rotary_emb_dim > 0:
456
- assert (
457
- not cross_attn
458
- ), "MHA with rotary embedding does not support cross-attention yet"
459
  assert RotaryEmbedding is not None, "rotary_emb is not installed"
460
  self.rotary_emb = RotaryEmbedding(
461
  self.rotary_emb_dim,
@@ -463,41 +451,29 @@ class MHA(nn.Module):
463
  scale_base=rotary_emb_scale_base,
464
  interleaved=rotary_emb_interleaved,
465
  device=device,
466
- use_flash_attn=use_flash_attn,
467
  )
468
 
469
  if fused_bias_fc and FusedDense is None:
470
  raise ImportError("fused_dense is not installed")
471
-
472
  linear_cls = nn.Linear if not fused_bias_fc else FusedDense
473
  linear_resid_cls = (
474
- LinearResidual
475
- if not fused_bias_fc
476
- else partial(FusedDense, return_residual=True)
477
  )
478
  wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
479
  inner_attn_cls = (
480
- partial(
481
- FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size
482
- )
483
  if use_flash_attn
484
  else SelfAttention
485
  )
486
  inner_cross_attn_cls = (
487
- partial(
488
- FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size
489
- )
490
  if use_flash_attn
491
  else CrossAttention
492
  )
493
  if not self.cross_attn:
494
- self.Wqkv = wqkv_cls(
495
- embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs
496
- )
497
  else:
498
- self.Wq = linear_cls(
499
- embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs
500
- )
501
  self.Wkv = wqkv_cls(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs)
502
  if self.dwconv:
503
  if self.num_heads_kv == self.num_heads:
@@ -508,9 +484,7 @@ class MHA(nn.Module):
508
  self.dwconv_q = nn.Conv1d(
509
  embed_dim, embed_dim, kernel_size=3, padding=2, groups=embed_dim
510
  )
511
- self.dwconv_kv = nn.Conv1d(
512
- kv_dim, kv_dim, kernel_size=3, padding=2, groups=kv_dim
513
- )
514
  self.inner_attn = inner_attn_cls(
515
  causal=causal,
516
  softmax_scale=softmax_scale,
@@ -519,9 +493,7 @@ class MHA(nn.Module):
519
  self.inner_cross_attn = inner_cross_attn_cls(
520
  causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
521
  )
522
- self.out_proj = linear_cls(
523
- embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs
524
- )
525
 
526
  def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
527
  dtype = self.out_proj.weight.dtype if dtype is None else dtype
@@ -539,9 +511,7 @@ class MHA(nn.Module):
539
  def _update_kv_cache(self, kv, inference_params):
540
  """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
541
  assert not self.dwconv, "Generation does not support dwconv yet"
542
- assert (
543
- self.layer_idx is not None
544
- ), "Generation requires layer_idx in the constructor"
545
  return _update_kv_cache(kv, inference_params, self.layer_idx)
546
 
547
  def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
@@ -557,10 +527,7 @@ class MHA(nn.Module):
557
  self.rotary_emb._update_cos_sin_cache(
558
  inference_params.max_seqlen, device=q.device, dtype=q.dtype
559
  )
560
- rotary_cos, rotary_sin = (
561
- self.rotary_emb._cos_cached,
562
- self.rotary_emb._sin_cached,
563
- )
564
  else:
565
  rotary_cos, rotary_sin = None, None
566
  batch = q.shape[0]
@@ -582,9 +549,7 @@ class MHA(nn.Module):
582
  cache_seqlens=cache_seqlens,
583
  softmax_scale=self.inner_cross_attn.softmax_scale,
584
  causal=self.inner_cross_attn.causal,
585
- rotary_interleaved=(
586
- self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False
587
- ),
588
  alibi_slopes=alibi_slopes,
589
  )
590
  return context
@@ -629,7 +594,6 @@ class MHA(nn.Module):
629
  max_seqlen=None,
630
  mixer_subset=None,
631
  inference_params=None,
632
- adapter_mask=None,
633
  **kwargs,
634
  ):
635
  """
@@ -655,6 +619,7 @@ class MHA(nn.Module):
655
  assert key_padding_mask is None
656
  assert self.use_flash_attn
657
  assert not self.dwconv
 
658
  if key_padding_mask is not None:
659
  assert cu_seqlens is None
660
  assert max_seqlen is None
@@ -678,50 +643,19 @@ class MHA(nn.Module):
678
  else inference_params.seqlen_offset
679
  )
680
  )
681
- rotary_max_seqlen = (
682
- inference_params.max_sequence_len
683
- if inference_params is not None
684
- else max_seqlen
685
- )
686
  if not self.cross_attn and self.num_heads_kv == self.num_heads:
687
  assert x_kv is None and mixer_subset is None
688
-
689
- if adapter_mask is not None:
690
- unique_tasks = torch.unique(adapter_mask)
691
- qkv_dtype = next(self.Wqkv.parameters()).dtype
692
- qkv = torch.empty(
693
- *x.shape[:-1],
694
- self.Wqkv.out_features,
695
- dtype=qkv_dtype,
696
- device=x.device,
697
- )
698
- for task_id in unique_tasks:
699
- task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
700
- task_tensor = x[task_indices]
701
- if not self.return_residual:
702
- task_qkv = self.Wqkv(task_tensor, task_id=task_id)
703
- else:
704
- task_qkv, _ = self.Wqkv(
705
- task_tensor, task_id=task_id, residual=True
706
- )
707
- qkv[task_indices] = task_qkv
708
  else:
709
- if not self.return_residual:
710
- qkv = self.Wqkv(x)
711
- else:
712
- if hasattr(self.Wqkv, "parametrizations"):
713
- qkv, x = self.Wqkv(x, residual=True)
714
- else:
715
- qkv, x = self.Wqkv(x)
716
-
717
  if self.dwconv:
718
  qkv = rearrange(
719
- self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2],
720
- "b d s -> b s d",
721
  ).contiguous()
722
- qkv = rearrange(
723
- qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim
724
- )
725
  if (
726
  inference_params is None
727
  or inference_params.seqlen_offset == 0
@@ -730,18 +664,13 @@ class MHA(nn.Module):
730
  ):
731
  if self.rotary_emb_dim > 0:
732
  qkv = self.rotary_emb(
733
- qkv,
734
- seqlen_offset=seqlen_offset,
735
- cu_seqlens=cu_seqlens,
736
- max_seqlen=rotary_max_seqlen,
737
  )
738
  if inference_params is None:
739
  if not self.checkpointing:
740
  context = self.inner_attn(qkv, **kwargs)
741
  else:
742
- context = torch.utils.checkpoint.checkpoint(
743
- self.inner_attn, qkv, **kwargs
744
- )
745
  else:
746
  context = self._update_kvcache_attention(
747
  qkv[:, :, 0], qkv[:, :, 1:], inference_params
@@ -770,17 +699,13 @@ class MHA(nn.Module):
770
  q = qkv[..., : self.num_heads * self.head_dim]
771
  kv = qkv[..., self.num_heads * self.head_dim :]
772
  q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
773
- kv = rearrange(
774
- kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim
775
- )
776
  if self.dwconv:
777
  q = rearrange(
778
- self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2],
779
- "b d s -> b s d",
780
  ).contiguous()
781
  kv = rearrange(
782
- self.dwconv_kv(rearrange(kv, "b s d -> b d s"))[..., :-2],
783
- "b d s -> b s d",
784
  ).contiguous()
785
  if (
786
  inference_params is None
@@ -790,11 +715,7 @@ class MHA(nn.Module):
790
  ):
791
  if self.rotary_emb_dim > 0:
792
  q, kv = self.rotary_emb(
793
- q,
794
- kv,
795
- seqlen_offset=seqlen_offset,
796
- cu_seqlens=cu_seqlens,
797
- max_seqlen=rotary_max_seqlen,
798
  )
799
  if inference_params is None:
800
  if not self.checkpointing:
@@ -806,25 +727,7 @@ class MHA(nn.Module):
806
  else:
807
  context = self._update_kvcache_attention(q, kv, inference_params)
808
  else:
809
- context = self._apply_rotary_update_kvcache_attention(
810
- q, kv, inference_params
811
- )
812
-
813
- inp = rearrange(context, "... h d -> ... (h d)")
814
- if adapter_mask is not None:
815
- unique_tasks = torch.unique(adapter_mask)
816
- out_dtype = next(self.out_proj.parameters()).dtype
817
- out = torch.empty(
818
- *inp.shape[:-1],
819
- self.out_proj.out_features,
820
- dtype=out_dtype,
821
- device=inp.device,
822
- )
823
- for task_id in unique_tasks:
824
- task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
825
- task_tensor = inp[task_indices]
826
- task_out = self.out_proj(task_tensor, task_id=task_id)
827
- out[task_indices] = task_out
828
- else:
829
- out = self.out_proj(inp)
830
  return out if not self.return_residual else (out, x)
 
 
1
  # This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py
2
  # Commit id: 6bbc532388e61185a92e2a563126739967b4c8c5
 
3
 
4
  # Copyright (c) 2023, Tri Dao.
5
 
 
11
  from einops import rearrange, repeat
12
 
13
  try:
14
+ from flash_attn import (
15
+ flash_attn_kvpacked_func,
16
+ flash_attn_qkvpacked_func,
17
+ flash_attn_varlen_kvpacked_func,
18
+ flash_attn_varlen_qkvpacked_func,
19
+ flash_attn_with_kvcache,
20
+ )
21
  except ImportError:
22
  flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None
23
  flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None
24
  flash_attn_with_kvcache = None
25
 
26
  try:
27
+ from flash_attn.ops.fused_dense import ColumnParallelLinear, FusedDense, RowParallelLinear
 
28
  except ImportError:
29
  FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None
30
 
31
+ try:
32
+ from flash_attn.layers.rotary import RotaryEmbedding
33
+ except ImportError:
34
+ RotaryEmbedding = None
35
 
36
 
37
  # From https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742
 
47
  closest_power_of_2 = 2 ** math.floor(math.log2(nheads))
48
  return (
49
  get_slopes_power_of_2(closest_power_of_2)
50
+ + get_alibi_slopes(2 * closest_power_of_2)[0::2][: nheads - closest_power_of_2]
 
 
51
  )
52
 
53
 
 
72
  deterministic=False,
73
  ):
74
  super().__init__()
75
+ assert flash_attn_varlen_qkvpacked_func is not None, "FlashAttention is not installed"
 
 
76
  assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed"
77
  self.causal = causal
78
  self.softmax_scale = softmax_scale
 
152
  deterministic=False,
153
  ):
154
  super().__init__()
155
+ assert flash_attn_varlen_kvpacked_func is not None, "FlashAttention is not installed"
 
 
156
  assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
157
  self.causal = causal
158
  self.softmax_scale = softmax_scale
 
318
  scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
319
  if key_padding_mask is not None:
320
  padding_mask = torch.full(
321
+ (batch_size, seqlen_k), -10000.0, dtype=scores.dtype, device=scores.device
 
 
 
322
  )
323
  padding_mask.masked_fill_(key_padding_mask, 0.0)
324
  # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
 
430
  else:
431
  alibi_slopes = None
432
  if window_size != (-1, -1):
433
+ assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"
 
 
434
 
435
  self.num_heads = num_heads
436
  self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
437
  assert (
438
  self.num_heads % self.num_heads_kv == 0
439
  ), "num_heads must be divisible by num_heads_kv"
440
+ assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
 
 
441
  self.head_dim = self.embed_dim // num_heads
442
  qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
443
  kv_dim = 2 * self.head_dim * self.num_heads_kv
444
 
445
  if self.rotary_emb_dim > 0:
446
+ assert not cross_attn, "MHA with rotary embedding does not support cross-attention yet"
 
 
447
  assert RotaryEmbedding is not None, "rotary_emb is not installed"
448
  self.rotary_emb = RotaryEmbedding(
449
  self.rotary_emb_dim,
 
451
  scale_base=rotary_emb_scale_base,
452
  interleaved=rotary_emb_interleaved,
453
  device=device,
 
454
  )
455
 
456
  if fused_bias_fc and FusedDense is None:
457
  raise ImportError("fused_dense is not installed")
 
458
  linear_cls = nn.Linear if not fused_bias_fc else FusedDense
459
  linear_resid_cls = (
460
+ LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True)
 
 
461
  )
462
  wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
463
  inner_attn_cls = (
464
+ partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size)
 
 
465
  if use_flash_attn
466
  else SelfAttention
467
  )
468
  inner_cross_attn_cls = (
469
+ partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size)
 
 
470
  if use_flash_attn
471
  else CrossAttention
472
  )
473
  if not self.cross_attn:
474
+ self.Wqkv = wqkv_cls(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs)
 
 
475
  else:
476
+ self.Wq = linear_cls(embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs)
 
 
477
  self.Wkv = wqkv_cls(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs)
478
  if self.dwconv:
479
  if self.num_heads_kv == self.num_heads:
 
484
  self.dwconv_q = nn.Conv1d(
485
  embed_dim, embed_dim, kernel_size=3, padding=2, groups=embed_dim
486
  )
487
+ self.dwconv_kv = nn.Conv1d(kv_dim, kv_dim, kernel_size=3, padding=2, groups=kv_dim)
 
 
488
  self.inner_attn = inner_attn_cls(
489
  causal=causal,
490
  softmax_scale=softmax_scale,
 
493
  self.inner_cross_attn = inner_cross_attn_cls(
494
  causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
495
  )
496
+ self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs)
 
 
497
 
498
  def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
499
  dtype = self.out_proj.weight.dtype if dtype is None else dtype
 
511
  def _update_kv_cache(self, kv, inference_params):
512
  """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
513
  assert not self.dwconv, "Generation does not support dwconv yet"
514
+ assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
 
 
515
  return _update_kv_cache(kv, inference_params, self.layer_idx)
516
 
517
  def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
 
527
  self.rotary_emb._update_cos_sin_cache(
528
  inference_params.max_seqlen, device=q.device, dtype=q.dtype
529
  )
530
+ rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
 
 
 
531
  else:
532
  rotary_cos, rotary_sin = None, None
533
  batch = q.shape[0]
 
549
  cache_seqlens=cache_seqlens,
550
  softmax_scale=self.inner_cross_attn.softmax_scale,
551
  causal=self.inner_cross_attn.causal,
552
+ rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
 
 
553
  alibi_slopes=alibi_slopes,
554
  )
555
  return context
 
594
  max_seqlen=None,
595
  mixer_subset=None,
596
  inference_params=None,
 
597
  **kwargs,
598
  ):
599
  """
 
619
  assert key_padding_mask is None
620
  assert self.use_flash_attn
621
  assert not self.dwconv
622
+ assert self.rotary_emb_dim == 0
623
  if key_padding_mask is not None:
624
  assert cu_seqlens is None
625
  assert max_seqlen is None
 
643
  else inference_params.seqlen_offset
644
  )
645
  )
646
+ rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
647
+ batch, seqlen = x.shape[:2]
 
 
 
648
  if not self.cross_attn and self.num_heads_kv == self.num_heads:
649
  assert x_kv is None and mixer_subset is None
650
+ if not self.return_residual:
651
+ qkv = self.Wqkv(x)
 
 
652
  else:
653
+ qkv, x = self.Wqkv(x)
 
 
654
  if self.dwconv:
655
  qkv = rearrange(
656
+ self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
 
657
  ).contiguous()
658
+ qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
 
 
659
  if (
660
  inference_params is None
661
  or inference_params.seqlen_offset == 0
 
664
  ):
665
  if self.rotary_emb_dim > 0:
666
  qkv = self.rotary_emb(
667
+ qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
 
 
 
668
  )
669
  if inference_params is None:
670
  if not self.checkpointing:
671
  context = self.inner_attn(qkv, **kwargs)
672
  else:
673
+ context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
 
 
674
  else:
675
  context = self._update_kvcache_attention(
676
  qkv[:, :, 0], qkv[:, :, 1:], inference_params
 
699
  q = qkv[..., : self.num_heads * self.head_dim]
700
  kv = qkv[..., self.num_heads * self.head_dim :]
701
  q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
702
+ kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
 
 
703
  if self.dwconv:
704
  q = rearrange(
705
+ self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
 
706
  ).contiguous()
707
  kv = rearrange(
708
+ self.dwconv_kv(rearrange(kv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
 
709
  ).contiguous()
710
  if (
711
  inference_params is None
 
715
  ):
716
  if self.rotary_emb_dim > 0:
717
  q, kv = self.rotary_emb(
718
+ q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
 
 
 
 
719
  )
720
  if inference_params is None:
721
  if not self.checkpointing:
 
727
  else:
728
  context = self._update_kvcache_attention(q, kv, inference_params)
729
  else:
730
+ context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
731
+ out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
 
 
732
  return out if not self.return_residual else (out, x)
733
+
mlp.py CHANGED
@@ -8,14 +8,14 @@ import torch.nn as nn
8
  import torch.nn.functional as F
9
  from torch.distributed import ProcessGroup
10
 
 
11
  try:
12
  from flash_attn.ops.activations import swiglu
13
  except ImportError:
14
  swiglu = None
15
 
16
  try:
17
- from flash_attn.ops.fused_dense import (ColumnParallelLinear,
18
- RowParallelLinear)
19
  except ImportError:
20
  ColumnParallelLinear, RowParallelLinear = None, None
21
 
@@ -41,48 +41,17 @@ class Mlp(nn.Module):
41
  factory_kwargs = {"device": device, "dtype": dtype}
42
  super().__init__()
43
  out_features = out_features if out_features is not None else in_features
44
- hidden_features = (
45
- hidden_features if hidden_features is not None else in_features * 4
46
- )
47
  self.return_residual = return_residual
48
  self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
49
  self.activation = activation
50
- self.fc2 = nn.Linear(
51
- hidden_features, out_features, bias=bias2, **factory_kwargs
52
- )
53
-
54
- def forward(self, x, adapter_mask=None):
55
- if adapter_mask is not None:
56
- unique_tasks = torch.unique(adapter_mask)
57
- fc1_dtype = next(self.fc1.parameters()).dtype
58
- y = torch.empty(
59
- *x.shape[:-1], self.fc1.out_features, dtype=fc1_dtype, device=x.device
60
- )
61
- for task_id in unique_tasks:
62
- task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
63
- task_tensor = x[task_indices]
64
- task_y = self.fc1(task_tensor, task_id=task_id)
65
- y[task_indices] = task_y
66
- else:
67
- y = self.fc1(x)
68
 
 
 
69
  y = self.activation(y)
70
-
71
- if adapter_mask is not None:
72
- unique_tasks = torch.unique(adapter_mask)
73
- fc2_dtype = next(self.fc2.parameters()).dtype
74
- out = torch.empty(
75
- *y.shape[:-1], self.fc2.out_features, dtype=fc2_dtype, device=y.device
76
- )
77
- for task_id in unique_tasks:
78
- task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
79
- task_tensor = y[task_indices]
80
- task_out = self.fc2(task_tensor, task_id=task_id)
81
- out[task_indices] = task_out
82
- else:
83
- out = self.fc2(y)
84
-
85
- return out if not self.return_residual else (out, x)
86
 
87
 
88
  class ParallelMLP(nn.Module):
@@ -104,9 +73,7 @@ class ParallelMLP(nn.Module):
104
  assert ColumnParallelLinear is not None, "Need to install fused_dense"
105
  assert RowParallelLinear is not None, "Need to install fused_dense"
106
  out_features = out_features if out_features is not None else in_features
107
- hidden_features = (
108
- hidden_features if hidden_features is not None else in_features * 4
109
- )
110
  self.fc1 = ColumnParallelLinear(
111
  in_features,
112
  hidden_features,
@@ -152,25 +119,17 @@ class GatedMlp(nn.Module):
152
  hidden_features = (
153
  hidden_features if hidden_features is not None else int(8 * in_features / 3)
154
  )
155
- hidden_features = (
156
- (hidden_features + multiple_of - 1) // multiple_of * multiple_of
157
- )
158
  self.return_residual = return_residual
159
- self.fc1 = nn.Linear(
160
- in_features, 2 * hidden_features, bias=bias1, **factory_kwargs
161
- )
162
  self.activation = activation
163
- self.fc2 = nn.Linear(
164
- hidden_features, out_features, bias=bias2, **factory_kwargs
165
- )
166
 
167
  def forward(self, x):
168
  y = self.fc1(x)
169
  if self.activation == F.sigmoid: # Special case for GLU
170
  y = F.glu(y, dim=-1)
171
- elif (
172
- self.activation == F.silu and swiglu is not None
173
- ): # Special case for SwiGLU
174
  y, gate = y.chunk(2, dim=-1)
175
  y = swiglu(gate, y)
176
  else:
@@ -203,9 +162,7 @@ class ParallelGatedMlp(nn.Module):
203
  hidden_features = (
204
  hidden_features if hidden_features is not None else int(8 * in_features / 3)
205
  )
206
- hidden_features = (
207
- (hidden_features + multiple_of - 1) // multiple_of * multiple_of
208
- )
209
  if ColumnParallelLinear is None or RowParallelLinear is None:
210
  raise ImportError("fused_dense is not installed")
211
  self.fc1 = ColumnParallelLinear(
@@ -234,4 +191,4 @@ class ParallelGatedMlp(nn.Module):
234
  y, gate = y.chunk(2, dim=-1)
235
  y = y * self.activation(gate)
236
  y = self.fc2(y)
237
- return y
 
8
  import torch.nn.functional as F
9
  from torch.distributed import ProcessGroup
10
 
11
+
12
  try:
13
  from flash_attn.ops.activations import swiglu
14
  except ImportError:
15
  swiglu = None
16
 
17
  try:
18
+ from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
 
19
  except ImportError:
20
  ColumnParallelLinear, RowParallelLinear = None, None
21
 
 
41
  factory_kwargs = {"device": device, "dtype": dtype}
42
  super().__init__()
43
  out_features = out_features if out_features is not None else in_features
44
+ hidden_features = hidden_features if hidden_features is not None else in_features * 4
 
 
45
  self.return_residual = return_residual
46
  self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
47
  self.activation = activation
48
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
 
 
49
 
50
+ def forward(self, x):
51
+ y = self.fc1(x)
52
  y = self.activation(y)
53
+ y = self.fc2(y)
54
+ return y if not self.return_residual else (y, x)
 
 
55
 
56
 
57
  class ParallelMLP(nn.Module):
 
73
  assert ColumnParallelLinear is not None, "Need to install fused_dense"
74
  assert RowParallelLinear is not None, "Need to install fused_dense"
75
  out_features = out_features if out_features is not None else in_features
76
+ hidden_features = hidden_features if hidden_features is not None else in_features * 4
 
 
77
  self.fc1 = ColumnParallelLinear(
78
  in_features,
79
  hidden_features,
 
119
  hidden_features = (
120
  hidden_features if hidden_features is not None else int(8 * in_features / 3)
121
  )
122
+ hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
 
 
123
  self.return_residual = return_residual
124
+ self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias1, **factory_kwargs)
 
 
125
  self.activation = activation
126
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
 
 
127
 
128
  def forward(self, x):
129
  y = self.fc1(x)
130
  if self.activation == F.sigmoid: # Special case for GLU
131
  y = F.glu(y, dim=-1)
132
+ elif self.activation == F.silu and swiglu is not None: # Special case for SwiGLU
 
 
133
  y, gate = y.chunk(2, dim=-1)
134
  y = swiglu(gate, y)
135
  else:
 
162
  hidden_features = (
163
  hidden_features if hidden_features is not None else int(8 * in_features / 3)
164
  )
165
+ hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
 
 
166
  if ColumnParallelLinear is None or RowParallelLinear is None:
167
  raise ImportError("fused_dense is not installed")
168
  self.fc1 = ColumnParallelLinear(
 
191
  y, gate = y.chunk(2, dim=-1)
192
  y = y * self.activation(gate)
193
  y = self.fc2(y)
194
+ return y
modeling_lora.py CHANGED
@@ -1,29 +1,22 @@
1
  import math
2
  import os
3
  from functools import partial
4
- from typing import Iterator, List, Optional, Tuple, Union
5
 
6
- import numpy as np
7
  import torch
8
  import torch.nn.utils.parametrize as parametrize
9
  from torch import nn
10
  from torch.nn import Parameter
11
- from torch.nn import functional as F
12
  from transformers import PretrainedConfig
13
 
14
- from .configuration_xlm_roberta import XLMRobertaFlashConfig
15
- from .modeling_xlm_roberta import (
16
- XLMRobertaFlashConfig,
17
- XLMRobertaModel,
18
- XLMRobertaPreTrainedModel,
19
- )
20
 
21
 
22
  def initialized_weights(
23
- shape: Tuple[int], num_adaptations: int, init: str = "kaiming"
24
  ) -> torch.Tensor:
25
  weight_data = []
26
- for _ in range(num_adaptations):
27
  new_adaption = torch.zeros(shape)
28
  if init == "kaiming":
29
  nn.init.kaiming_uniform_(new_adaption, a=math.sqrt(5))
@@ -52,16 +45,15 @@ class LoRAParametrization(nn.Module):
52
  WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
53
  SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
54
  """
55
-
56
  def __init__(
57
  self,
58
  fan_in: int,
59
  fan_out: int,
60
  layer_type: str = "linear",
61
- num_adaptations: int = 1,
62
  rank: int = 4,
63
- dropout_p: float = 0.0,
64
- alpha: float = 1,
65
  ):
66
  super().__init__()
67
  # if weight is stored as (fan_out, fan_in), the memory layout of A & B follows (W + BA)x
@@ -71,41 +63,46 @@ class LoRAParametrization(nn.Module):
71
 
72
  if layer_type == "linear":
73
  self.lora_A = nn.Parameter(
74
- initialized_weights((rank, fan_in), num_adaptations, init="kaiming")
75
  )
76
- self.lora_B = nn.Parameter(torch.zeros((num_adaptations, fan_out, rank)))
77
  elif layer_type == "embedding":
78
- self.lora_A = nn.Parameter(torch.zeros((num_adaptations, fan_in, rank)))
79
  self.lora_B = nn.Parameter(
80
  initialized_weights(
81
- (rank, fan_out), num_adaptations=num_adaptations, init="normal"
82
  )
83
  )
84
  else:
85
  raise NotImplementedError
86
 
87
- self.lora_alpha, self.rank = alpha, rank
88
- self.scaling = alpha / rank
89
- self.lora_dropout = nn.Dropout(p=dropout_p) if dropout_p > 0 else lambda x: x
90
- self.dropout_fn = self._dropout if dropout_p > 0 else lambda x: x
 
 
91
  self.register_buffer(
92
  "lora_dropout_mask",
93
  torch.ones(self.swap((1, fan_in)), dtype=self.lora_A.dtype),
94
  persistent=False,
95
  )
 
 
96
 
97
  def _dropout(self, A):
98
  # to mimic the original implementation: A @ dropout(x), we do (A * dropout(ones)) @ x
99
  return A * self.lora_dropout(self.lora_dropout_mask)
100
 
101
- def lora_forward(self, X, current_task):
 
102
  return (
103
  X
104
  + torch.matmul(
105
  *self.swap(
106
  (
107
- self.lora_B[current_task],
108
- self.dropout_fn(self.lora_A[current_task]),
109
  )
110
  )
111
  ).view(X.shape)
@@ -113,191 +110,115 @@ class LoRAParametrization(nn.Module):
113
  )
114
 
115
  def forward(self, X):
116
- return X
 
 
117
 
118
  @classmethod
119
  def from_linear(
120
  cls,
121
  layer: nn.Module,
122
- num_adaptations: int,
123
- rank: int,
124
- dropout_p: float,
125
- alpha: float,
126
  ):
127
  assert isinstance(layer, nn.Linear)
128
  fan_out, fan_in = layer.weight.shape
129
  return cls(
130
  fan_in,
131
  fan_out,
132
- num_adaptations=num_adaptations,
133
  layer_type="linear",
134
  rank=rank,
135
- dropout_p=dropout_p,
136
- alpha=alpha,
137
  )
138
 
139
  @classmethod
140
  def from_embedding(
141
- cls,
142
- layer: nn.Module,
143
- num_adaptations: int,
144
- rank: int,
145
- dropout_p: float,
146
- alpha: float,
147
  ):
148
  assert isinstance(layer, nn.Embedding)
149
  fan_in, fan_out = layer.weight.shape
150
  return cls(
151
  fan_in,
152
  fan_out,
153
- num_adaptations=num_adaptations,
154
  layer_type="embedding",
155
  rank=rank,
156
- dropout_p=dropout_p,
157
- alpha=alpha,
158
  )
159
 
160
  @classmethod
161
  def add_to_layer(
162
- cls,
163
- layer: nn.Module,
164
- num_adaptations: int,
165
- rank: int,
166
- dropout_p: float,
167
- alpha: float,
168
  ):
169
- """
170
- Registering LoRA adapters to all embedding and linear layers.
171
- Additionally, we implement a custom forward function for LoRA parametrization.
172
- This function modifies the layer's forward pass to optionally use task-specific
173
- parameters. When a `task_id` is provided, it employs a LoRA parametrization
174
- to modify the original weights according to the specific task. This allows
175
- the layer to adapt dynamically to different tasks at runtime. If no `task_id`
176
- is specified, the layer uses its original weights.
177
- """
178
  if isinstance(layer, nn.Linear):
179
  parametrize.register_parametrization(
180
  layer,
181
  "weight",
182
  cls.from_linear(
183
  layer,
184
- num_adaptations=num_adaptations,
185
  rank=rank,
186
- dropout_p=dropout_p,
187
- alpha=alpha,
188
  ),
189
  )
190
-
191
- def new_forward(self, input, task_id=None, residual=False):
192
- if task_id is not None:
193
- weights = self.parametrizations.weight[0].lora_forward(
194
- self.weight, current_task=task_id
195
- )
196
- else:
197
- weights = self.weight
198
-
199
- out = F.linear(input, weights, self.bias)
200
-
201
- if residual:
202
- return out, input
203
- return out
204
-
205
- layer.forward = new_forward.__get__(layer, layer.__class__)
206
-
207
  elif isinstance(layer, nn.Embedding):
208
  parametrize.register_parametrization(
209
  layer,
210
  "weight",
211
  cls.from_embedding(
212
  layer,
213
- num_adaptations=num_adaptations,
214
  rank=rank,
215
- dropout_p=dropout_p,
216
- alpha=alpha,
217
  ),
218
  )
219
 
220
- def new_forward(self, input, task_id=None):
221
- if task_id is not None:
222
- weights = self.parametrizations.weight[0].lora_forward(
223
- self.weight, current_task=task_id
224
- )
225
- else:
226
- weights = self.weight
227
-
228
- out = F.embedding(
229
- input,
230
- weights,
231
- self.padding_idx,
232
- self.max_norm,
233
- self.norm_type,
234
- self.scale_grad_by_freq,
235
- self.sparse,
236
- )
237
-
238
- return out
239
 
240
- layer.forward = new_forward.__get__(layer, layer.__class__)
 
 
 
 
241
 
242
 
243
  class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
244
- """
245
- A wrapper class around the Jina XLM-RoBERTa model that integrates LoRA (Low-Rank Adaptation) adapters.
246
- """
247
-
248
- def __init__(
249
- self,
250
- config: XLMRobertaFlashConfig,
251
- roberta: Optional[XLMRobertaModel] = None,
252
- add_pooling_layer: bool = True,
253
- ):
254
  super().__init__(config)
 
255
  if roberta is None:
256
  self.roberta = XLMRobertaModel(config, add_pooling_layer=add_pooling_layer)
257
  else:
258
  self.roberta = roberta
259
 
260
- self._lora_adaptations = config.lora_adaptations
261
- if (
262
- not isinstance(self._lora_adaptations, list)
263
- or len(self._lora_adaptations) < 1
264
- ):
265
- raise ValueError(
266
- f"`lora_adaptations` must be a list and contain at least one element"
267
- )
268
- self._task_instructions = config.task_instructions
269
- if (
270
- not isinstance(self._task_instructions, dict)
271
- or len(self._task_instructions) != len(self._lora_adaptations)
272
- or not all(
273
- [v in self._lora_adaptations for v in self._task_instructions.keys()]
274
- )
275
- ):
276
- raise ValueError(
277
- f"`task_instructions` must be a dict and contain the same number of elements "
278
- f"as `lora_adaptations` with all keys in `task_instructions` present in `lora_adaptations`."
279
- )
280
- self._adaptation_map = {
281
- name: idx for idx, name in enumerate(self._lora_adaptations)
282
- }
283
- self._rank = config.lora_rank
284
- self._dropout_p = config.lora_dropout_p
285
- self._alpha = config.lora_alpha
286
- self._register_lora(
287
- num_adaptations=len(self._lora_adaptations),
288
- rank=self._rank,
289
- dropout_p=self._dropout_p,
290
- alpha=self._alpha,
291
- )
292
- self.main_params_trainable = config.lora_main_params_trainable
293
-
294
- @property
295
- def rotary_emb_base(self):
296
- return self.roberta.rotary_emb_base
297
 
298
- @rotary_emb_base.setter
299
- def rotary_emb_base(self, base):
300
- self.roberta.rotary_emb_base = base
 
301
 
302
  @property
303
  def main_params_trainable(self):
@@ -316,6 +237,13 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
316
  if "lora" not in name:
317
  param.requires_grad_(val)
318
 
 
 
319
  @classmethod
320
  def from_pretrained(
321
  cls,
@@ -331,44 +259,56 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
331
  use_safetensors: bool = None,
332
  **kwargs,
333
  ):
334
- for key in list(kwargs.keys()):
335
- if key in config.to_dict():
336
- config.update({key: kwargs.pop(key)})
337
- if config.load_trained_adapters: # checkpoint already contains LoRA adapters
338
  return super().from_pretrained(
339
  pretrained_model_name_or_path,
340
  *model_args,
341
- config=config,
342
- cache_dir=cache_dir,
343
- ignore_mismatched_sizes=ignore_mismatched_sizes,
344
- force_download=force_download,
345
- local_files_only=local_files_only,
346
- token=token,
347
- revision=revision,
348
- use_safetensors=use_safetensors,
349
- **kwargs,
350
- )
351
- else: # initializing new adapters
352
- roberta = XLMRobertaModel.from_pretrained(
353
- pretrained_model_name_or_path,
354
- *model_args,
355
- use_flash_attn=config.use_flash_attn,
356
- **kwargs,
357
  )
 
 
358
  return cls(config, roberta=roberta)
359
 
360
- def _register_lora(self, num_adaptations, rank, dropout_p, alpha):
361
  self.apply(
362
  partial(
363
  LoRAParametrization.add_to_layer,
364
- num_adaptations=num_adaptations,
365
  rank=rank,
366
- dropout_p=dropout_p,
367
- alpha=alpha,
368
  )
369
  )
370
 
371
- def forward(self, *args, **kwargs):
 
 
 
372
  return self.roberta(*args, **kwargs)
373
 
374
  def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
@@ -383,44 +323,3 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
383
  ):
384
  if "lora" in name or self.main_params_trainable:
385
  yield name, param
386
-
387
- @torch.inference_mode()
388
- def encode(
389
- self,
390
- sentences: Union[str, List[str]],
391
- *args,
392
- task: Optional[str] = None,
393
- **kwargs,
394
- ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
395
- """
396
- Computes sentence embeddings.
397
- sentences(`str` or `List[str]`):
398
- Sentence or sentences to be encoded
399
- task(`str`, *optional*, defaults to `None`):
400
- Specifies the task for which the encoding is intended. If `task` is not provided,
401
- all LoRA adapters are disabled, and the model reverts to its original,
402
- general-purpose weights.
403
- """
404
- if task and task not in self._lora_adaptations:
405
- raise ValueError(
406
- f"Unsupported task '{task}'. "
407
- f"Supported tasks are: {', '.join(self.config.lora_adaptations)}."
408
- f"Alternatively, don't pass the `task` argument to disable LoRA."
409
- )
410
- adapter_mask = None
411
- if task:
412
- task_id = self._adaptation_map[task]
413
- num_examples = 1 if isinstance(sentences, str) else len(sentences)
414
- adapter_mask = torch.full(
415
- (num_examples,), task_id, dtype=torch.int32, device=self.device
416
- )
417
- if isinstance(sentences, str):
418
- sentences = self._task_instructions[task] + sentences
419
- else:
420
- sentences = [
421
- self._task_instructions[task] + sentence for sentence in sentences
422
- ]
423
- return self.roberta.encode(
424
- sentences, *args, adapter_mask=adapter_mask, **kwargs
425
- )
426
-
 
1
  import math
2
  import os
3
  from functools import partial
4
+ from typing import Iterator, Optional, Tuple, Union
5
 
 
6
  import torch
7
  import torch.nn.utils.parametrize as parametrize
8
  from torch import nn
9
  from torch.nn import Parameter
 
10
  from transformers import PretrainedConfig
11
 
12
+ from .modeling_xlm_roberta import XLMRobertaModel, XLMRobertaPreTrainedModel, XLMRobertaFlashConfig
 
13
 
14
 
15
  def initialized_weights(
16
+ shape: Tuple[int], num_adaptions: int, init: str = "kaiming"
17
  ) -> torch.Tensor:
18
  weight_data = []
19
+ for _ in range(num_adaptions):
20
  new_adaption = torch.zeros(shape)
21
  if init == "kaiming":
22
  nn.init.kaiming_uniform_(new_adaption, a=math.sqrt(5))
 
45
  WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
46
  SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
47
  """
 
48
  def __init__(
49
  self,
50
  fan_in: int,
51
  fan_out: int,
52
  layer_type: str = "linear",
53
+ num_adaptions: int = 1,
54
  rank: int = 4,
55
+ lora_dropout_p: float = 0.0,
56
+ lora_alpha: float = 1,
57
  ):
58
  super().__init__()
59
  # if weight is stored as (fan_out, fan_in), the memory layout of A & B follows (W + BA)x
 
63
 
64
  if layer_type == "linear":
65
  self.lora_A = nn.Parameter(
66
+ initialized_weights((rank, fan_in), num_adaptions, init="kaiming")
67
  )
68
+ self.lora_B = nn.Parameter(torch.zeros((num_adaptions, fan_out, rank)))
69
  elif layer_type == "embedding":
70
+ self.lora_A = nn.Parameter(torch.zeros((num_adaptions, fan_in, rank)))
71
  self.lora_B = nn.Parameter(
72
  initialized_weights(
73
+ (rank, fan_out), num_adaptions=num_adaptions, init="normal"
74
  )
75
  )
76
  else:
77
  raise NotImplementedError
78
 
79
+ self.lora_alpha, self.rank = lora_alpha, rank
80
+ self.scaling = lora_alpha / rank
81
+ self.lora_dropout = (
82
+ nn.Dropout(p=lora_dropout_p) if lora_dropout_p > 0 else lambda x: x
83
+ )
84
+ self.dropout_fn = self._dropout if lora_dropout_p > 0 else lambda x: x
85
  self.register_buffer(
86
  "lora_dropout_mask",
87
  torch.ones(self.swap((1, fan_in)), dtype=self.lora_A.dtype),
88
  persistent=False,
89
  )
90
+ self.forward_fn = lambda x: x
91
+ self.current_task = None
92
 
93
  def _dropout(self, A):
94
  # to mimic the original implementation: A @ dropout(x), we do (A * dropout(ones)) @ x
95
  return A * self.lora_dropout(self.lora_dropout_mask)
96
 
97
+ def lora_forward(self, X):
98
+ assert self.current_task is not None
99
  return (
100
  X
101
  + torch.matmul(
102
  *self.swap(
103
  (
104
+ self.lora_B[self.current_task],
105
+ self.dropout_fn(self.lora_A[self.current_task]),
106
  )
107
  )
108
  ).view(X.shape)
 
110
  )
111
 
112
  def forward(self, X):
113
+ return self.forward_fn(X)
114
+
115
+ @property
116
+ def current_task(self):
117
+ return self._current_task
118
+
119
+ @current_task.setter
120
+ def current_task(self, task: Union[None, int]):
121
+ self._current_task = task
122
+ if task is None:
123
+ self.forward_fn = lambda x: x
124
+ else:
125
+ self.forward_fn = self.lora_forward
126
 
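The parametrization above rewrites a layer's weight on the fly: with a task selected, the effective weight is the frozen weight plus a low-rank update built from that task's `lora_B @ lora_A`, scaled by `lora_alpha / rank`. A minimal standalone sketch of that arithmetic (dropout omitted, names illustrative):

```python
import torch

def lora_effective_weight(W, lora_A, lora_B, task_id, lora_alpha=1.0):
    # W: (fan_out, fan_in); lora_A: (num_adaptions, rank, fan_in); lora_B: (num_adaptions, fan_out, rank)
    rank = lora_A.shape[1]
    scaling = lora_alpha / rank
    delta = lora_B[task_id] @ lora_A[task_id]   # rank-limited (fan_out, fan_in) update
    return W + scaling * delta

W = torch.randn(8, 16)
A = torch.randn(3, 4, 16)   # 3 adapters, rank 4
B = torch.zeros(3, 8, 4)    # zero-initialized, so the update starts as a no-op
print(torch.allclose(lora_effective_weight(W, A, B, task_id=0), W))  # True
```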
127
  @classmethod
128
  def from_linear(
129
  cls,
130
  layer: nn.Module,
131
+ num_adaptions: int = 1,
132
+ rank: int = 4,
133
+ lora_dropout_p: float = 0.0,
134
+ lora_alpha: int = 1,
135
  ):
136
  assert isinstance(layer, nn.Linear)
137
  fan_out, fan_in = layer.weight.shape
138
  return cls(
139
  fan_in,
140
  fan_out,
141
+ num_adaptions=num_adaptions,
142
  layer_type="linear",
143
  rank=rank,
144
+ lora_dropout_p=lora_dropout_p,
145
+ lora_alpha=lora_alpha,
146
  )
147
 
148
  @classmethod
149
  def from_embedding(
150
+ cls, layer, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1
 
 
 
 
 
151
  ):
152
  assert isinstance(layer, nn.Embedding)
153
  fan_in, fan_out = layer.weight.shape
154
  return cls(
155
  fan_in,
156
  fan_out,
157
+ num_adaptions=num_adaptions,
158
  layer_type="embedding",
159
  rank=rank,
160
+ lora_dropout_p=lora_dropout_p,
161
+ lora_alpha=lora_alpha,
162
  )
163
 
164
  @classmethod
165
  def add_to_layer(
166
+ cls, layer, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1
 
 
 
 
 
167
  ):
 
 
 
 
 
 
 
 
 
168
  if isinstance(layer, nn.Linear):
169
  parametrize.register_parametrization(
170
  layer,
171
  "weight",
172
  cls.from_linear(
173
  layer,
174
+ num_adaptions=num_adaptions,
175
  rank=rank,
176
+ lora_dropout_p=lora_dropout_p,
177
+ lora_alpha=lora_alpha,
178
  ),
179
)
180
  elif isinstance(layer, nn.Embedding):
181
  parametrize.register_parametrization(
182
  layer,
183
  "weight",
184
  cls.from_embedding(
185
  layer,
186
+ num_adaptions=num_adaptions,
187
  rank=rank,
188
+ lora_dropout_p=lora_dropout_p,
189
+ lora_alpha=lora_alpha,
190
  ),
191
  )
192
 
193
+ @staticmethod
194
+ def select_task_for_layer(layer: nn.Module, task_idx: Optional[int] = None):
195
+ if isinstance(layer, LoRAParametrization):
196
+ layer.current_task = task_idx
197
 
198
+ @staticmethod
199
+ def merge_lora_into_layer(layer: nn.Module):
200
+ if hasattr(layer, "parametrizations"):
201
+ for attr_name in layer.parametrizations.keys():
202
+ parametrize.remove_parametrizations(layer, attr_name, leave_parametrized=True)
203
 
204
 
205
  class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
206
+ def __init__(self, config: XLMRobertaFlashConfig, roberta: Optional[XLMRobertaModel] = None, add_pooling_layer=True):
207
  super().__init__(config)
208
+
209
  if roberta is None:
210
  self.roberta = XLMRobertaModel(config, add_pooling_layer=add_pooling_layer)
211
  else:
212
  self.roberta = roberta
213
 
214
+ self._is_merged = False
215
+ self._num_adaptions = config.num_loras
216
+ self._register_lora(self._num_adaptions)
217
 
218
+ self.main_params_trainable = False
219
+ self._task_idx = None
220
+ # By default, we select the first LoRA
221
+ self.current_task = 0
222
 
223
  @property
224
  def main_params_trainable(self):
 
237
  if "lora" not in name:
238
  param.requires_grad_(val)
239
 
240
+ def merge_lora(self):
241
+ """Merges currently selected LoRA into main weights."""
242
+ if self._is_merged:
243
+ raise Exception('LoRA has already been merged, cannot merge again')
244
+ self._is_merged = True
245
+ self.apply(LoRAParametrization.merge_lora_into_layer)
246
+
247
  @classmethod
248
  def from_pretrained(
249
  cls,
 
259
  use_safetensors: bool = None,
260
  **kwargs,
261
  ):
262
+ config = XLMRobertaFlashConfig.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
263
+ if config.load_trained_adapters:
 
 
264
  return super().from_pretrained(
265
  pretrained_model_name_or_path,
266
  *model_args,
267
+ **kwargs
268
  )
269
+ else:
270
+ roberta = XLMRobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
271
  return cls(config, roberta=roberta)
272
 
273
+ def _register_lora(self, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1):
274
  self.apply(
275
  partial(
276
  LoRAParametrization.add_to_layer,
277
+ num_adaptions=num_adaptions,
278
  rank=rank,
279
+ lora_dropout_p=lora_dropout_p,
280
+ lora_alpha=lora_alpha,
281
  )
282
  )
283
 
284
+ @property
285
+ def current_task(self):
286
+ """ Which LoRA is currently selected
287
+ :return: Integer or None (when LoRA is disabled)
288
+ """
289
+ return self._task_idx
290
+
291
+ @current_task.setter
292
+ def current_task(self, task_idx: Union[None, int]):
293
+ """Set the LoRA that is to be used.
294
+ The LoRA is specified by `task_idx`, which may be an integer >= 0,
295
+ indexing the available LoRAs. If it is None, no LoRA is used.
296
+ :param task_idx: Which LoRA to use
297
+ :return:
298
+ """
299
+ if self._is_merged:
300
+ raise Exception('LoRA has been merged, cannot select new task')
301
+ assert task_idx is None or 0 <= task_idx < self._num_adaptions
302
+ if self._task_idx != task_idx:
303
+ # In this case, we need to update the LoRAs everywhere
304
+ self._task_idx = task_idx
305
+ self.apply(
306
+ partial(LoRAParametrization.select_task_for_layer, task_idx=task_idx)
307
+ )
308
+
309
+ def forward(self, *args, current_task: Union[None, int] = -1, **kwargs):
310
+ if current_task is None or current_task >= 0:
311
+ self.current_task = current_task
312
  return self.roberta(*args, **kwargs)
313
 
314
  def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
 
323
  ):
324
  if "lora" in name or self.main_params_trainable:
325
  yield name, param
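Taken together, this class routes adapter selection through the `current_task` setter (which walks the module tree with `select_task_for_layer`) and offers a one-way `merge_lora()` built on `torch.nn.utils.parametrize.remove_parametrizations(..., leave_parametrized=True)`, which bakes the currently parametrized value into the plain weight. A small self-contained sketch of that mechanism, using a toy parametrization instead of the LoRA class above:

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class AddDelta(nn.Module):
    """Toy parametrization: effective weight = stored weight + learned delta."""
    def __init__(self, shape):
        super().__init__()
        self.delta = nn.Parameter(torch.ones(shape))
    def forward(self, W):
        return W + self.delta

layer = nn.Linear(4, 4)
parametrize.register_parametrization(layer, "weight", AddDelta(layer.weight.shape))
w_effective = layer.weight.detach().clone()          # includes the delta

# "Merge": make the parametrized value the permanent weight and drop the hook.
parametrize.remove_parametrizations(layer, "weight", leave_parametrized=True)
assert torch.allclose(layer.weight, w_effective)
```

Once merged this way, there is no parametrization left to switch, which is why the class above refuses both a second `merge_lora()` and any later change of `current_task`.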
 
 
modeling_xlm_roberta.py CHANGED
@@ -13,30 +13,39 @@ import re
13
  from collections import OrderedDict
14
  from collections.abc import Sequence
15
  from functools import partial
16
- from typing import List, Optional, Tuple, Union
17
-
18
  import numpy as np
 
19
  import torch
20
  import torch.nn as nn
21
  import torch.nn.functional as F
22
  import torch.utils.checkpoint
23
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
24
- from transformers import AutoTokenizer, PretrainedConfig
25
- from transformers.modeling_outputs import MaskedLMOutput, SequenceClassifierOutput
26
  from transformers.modeling_utils import PreTrainedModel
 
 
 
27
  from transformers.models.bert.modeling_bert import (
28
  BaseModelOutputWithPoolingAndCrossAttentions,
29
  BertForPreTrainingOutput,
30
  )
31
- from transformers.models.xlm_roberta.modeling_xlm_roberta import XLMRobertaLMHead
32
 
33
- from .rotary import RotaryEmbedding
34
- from .block import Block
 
 
 
 
 
 
35
  from .configuration_xlm_roberta import XLMRobertaFlashConfig
 
36
  from .embedding import XLMRobertaEmbeddings
37
  from .mha import MHA
38
  from .mlp import FusedMLP, Mlp
39
- from .xlm_padding import index_first_axis_residual, pad_input, unpad_input
 
40
 
41
  try:
42
  from flash_attn.ops.fused_dense import FusedDense
@@ -52,7 +61,7 @@ except ImportError:
52
  try:
53
  from flash_attn.losses.cross_entropy import CrossEntropyLoss
54
  except ImportError:
55
- CrossEntropyLoss = torch.nn.CrossEntropyLoss
56
 
57
  try:
58
  from tqdm.autonotebook import trange
@@ -64,11 +73,13 @@ logger = logging.getLogger(__name__)
64
 
65
 
66
  def get_use_flash_attn(config: XLMRobertaFlashConfig):
67
- if not getattr(config, "use_flash_attn", False) or not torch.cuda.is_available():
 
 
68
  return False
69
  if importlib.util.find_spec("flash_attn") is None:
70
  logger.warning(
71
- "flash_attn is not installed. Using PyTorch native attention implementation."
72
  )
73
  return False
74
  return True
@@ -80,9 +91,9 @@ def create_mixer_cls(config, cross_attn=False, return_residual=False):
80
  rotary_kwargs = {}
81
  if config.position_embedding_type == "rotary":
82
  rotary_kwargs["rotary_emb_dim"] = getattr(
83
- config, "rotary_emb_dim", config.hidden_size / config.num_attention_heads
84
  )
85
- rotary_kwargs["rotary_emb_base"] = config.rotary_emb_base
86
  rotary_kwargs["rotary_emb_scale_base"] = getattr(
87
  config, "rotary_emb_scale_base", None
88
  )
@@ -98,7 +109,6 @@ def create_mixer_cls(config, cross_attn=False, return_residual=False):
98
  fused_bias_fc=fused_bias_fc,
99
  use_flash_attn=use_flash_attn,
100
  return_residual=return_residual,
101
- use_alibi=config.position_embedding_type == "alibi",
102
  **rotary_kwargs,
103
  )
104
  return mixer_cls
@@ -180,7 +190,6 @@ class XLMRobertaEncoder(nn.Module):
180
  def __init__(self, config: XLMRobertaFlashConfig):
181
  super().__init__()
182
  self.use_flash_attn = get_use_flash_attn(config)
183
- self.use_reentrant = config.use_reentrant
184
  self.layers = nn.ModuleList(
185
  [create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
186
  )
@@ -194,72 +203,46 @@ class XLMRobertaEncoder(nn.Module):
194
  def gradient_checkpointing(self, value):
195
  self._grad_checkpointing = value
196
 
197
- def forward(
198
- self,
199
- hidden_states,
200
- key_padding_mask=None,
201
- subset_mask=None,
202
- adapter_mask=None,
203
- output_hidden_states: Optional[bool] = None,
204
- ):
205
  """If subset_mask is not None, we only want output for the subset of the sequence.
206
  This means that we only compute the last layer output for these tokens.
207
  subset_mask: (batch, seqlen), dtype=torch.bool
208
  """
209
-
210
- all_hidden_states = () if output_hidden_states else None
211
-
212
- if output_hidden_states and subset_mask:
213
- raise ValueError('output_hidden_states is not supported for subset_masks')
214
-
215
  if key_padding_mask is None or not self.use_flash_attn:
216
- mixer_kwargs = {"adapter_mask": adapter_mask}
217
- if key_padding_mask is not None:
218
- mixer_kwargs["key_padding_mask"] = key_padding_mask.bool()
 
 
219
  for layer in self.layers:
220
- if output_hidden_states:
221
- all_hidden_states = all_hidden_states + (hidden_states,)
222
  if self._grad_checkpointing:
223
  hidden_states = torch.utils.checkpoint.checkpoint(
224
  layer,
225
  hidden_states,
226
- use_reentrant=self.use_reentrant,
227
  mixer_kwargs=mixer_kwargs,
228
  )
229
  else:
230
  hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
231
- if output_hidden_states:
232
- all_hidden_states = all_hidden_states + (hidden_states,)
233
  if subset_mask is not None:
234
  hidden_states = hidden_states[subset_mask]
235
  else:
236
  batch, seqlen = hidden_states.shape[:2]
237
- if output_hidden_states:
238
- all_hidden_states = all_hidden_states + (hidden_states,)
239
- hidden_states, indices, cu_seqlens, max_seqlen_in_batch, cu_adapter_mask = (
240
- unpad_input(hidden_states, key_padding_mask, adapter_mask)
241
  )
242
- mixer_kwargs = {
243
- "cu_seqlens": cu_seqlens,
244
- "max_seqlen": max_seqlen_in_batch,
245
- "adapter_mask": cu_adapter_mask,
246
- }
247
-
248
  if subset_mask is None:
249
  for layer in self.layers:
250
  if self._grad_checkpointing:
251
  hidden_states = torch.utils.checkpoint.checkpoint(
252
  layer,
253
  hidden_states,
254
- use_reentrant=self.use_reentrant,
255
  mixer_kwargs=mixer_kwargs,
256
  )
257
  else:
258
  hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
259
- if output_hidden_states:
260
- all_hidden_states = all_hidden_states + (
261
- pad_input(hidden_states, indices, batch, seqlen),
262
- )
263
  hidden_states = pad_input(hidden_states, indices, batch, seqlen)
264
  else:
265
  for layer in self.layers[:-1]:
@@ -267,7 +250,7 @@ class XLMRobertaEncoder(nn.Module):
267
  hidden_states = torch.utils.checkpoint.checkpoint(
268
  layer,
269
  hidden_states,
270
- use_reentrant=self.use_reentrant,
271
  mixer_kwargs=mixer_kwargs,
272
  )
273
  else:
@@ -305,14 +288,14 @@ class XLMRobertaEncoder(nn.Module):
305
  torch.utils.checkpoint.checkpoint(
306
  self.layers[-1],
307
  hidden_states_subset,
308
- use_reentrant=self.use_reentrant,
309
  mixer_kwargs=mixer_kwargs,
310
  )
311
  else:
312
  hidden_states = self.layers[-1](
313
  hidden_states_subset, mixer_kwargs=mixer_kwargs
314
  )
315
- return all_hidden_states if output_hidden_states else hidden_states
316
 
317
 
318
  class XLMRobertaPooler(nn.Module):
@@ -325,28 +308,11 @@ class XLMRobertaPooler(nn.Module):
325
  self.dense = linear_cls(config.hidden_size, config.hidden_size)
326
  self.activation = nn.Tanh()
327
 
328
- def forward(self, hidden_states, pool=True, adapter_mask=None):
329
  # We "pool" the model by simply taking the hidden state corresponding
330
  # to the first token.
331
  first_token_tensor = hidden_states[:, 0] if pool else hidden_states
332
- if adapter_mask is not None:
333
- unique_tasks = torch.unique(adapter_mask)
334
- pool_dtype = next(self.dense.parameters()).dtype
335
- pooled_output = torch.empty(
336
- first_token_tensor.shape[0],
337
- self.dense.out_features,
338
- dtype=pool_dtype,
339
- device=first_token_tensor.device,
340
- )
341
- for task_id in unique_tasks:
342
- task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
343
- task_first_token_tensor = first_token_tensor[task_indices]
344
- task_pooled_output = self.dense(
345
- task_first_token_tensor, task_id=task_id
346
- )
347
- pooled_output[task_indices] = task_pooled_output
348
- else:
349
- pooled_output = self.dense(first_token_tensor)
350
  pooled_output = self.activation(pooled_output)
351
  return pooled_output
352
 
@@ -425,7 +391,6 @@ class XLMRobertaPreTrainedModel(PreTrainedModel):
425
  config_class = XLMRobertaFlashConfig
426
  base_model_prefix = "roberta"
427
  supports_gradient_checkpointing = True
428
- _supports_param_buffer_assignment = False
429
 
430
  def _set_gradient_checkpointing(self, module, value=False):
431
  if isinstance(module, XLMRobertaEncoder):
@@ -437,11 +402,12 @@ class XLMRobertaPreTrainedModel(PreTrainedModel):
437
  *args,
438
  **kwargs,
439
  ):
440
- if not "torch_dtype" in kwargs:
441
- kwargs["torch_dtype"] = "auto"
442
  return super().from_pretrained(*args, **kwargs)
443
 
444
 
 
445
  class XLMRobertaModel(XLMRobertaPreTrainedModel):
446
  def __init__(self, config: XLMRobertaFlashConfig, add_pooling_layer=True):
447
  super().__init__(config)
@@ -459,14 +425,11 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
459
  "gelu_fast",
460
  "gelu_pytorch_tanh",
461
  ]
 
462
  self.embeddings = XLMRobertaEmbeddings(
463
  config.hidden_size,
464
  config.vocab_size,
465
- (
466
- config.max_position_embeddings
467
- if config.position_embedding_type == "absolute"
468
- else -1
469
- ),
470
  config.type_vocab_size,
471
  padding_idx=config.pad_token_id,
472
  )
@@ -476,25 +439,19 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
476
  self.pooler = XLMRobertaPooler(config) if add_pooling_layer else None
477
 
478
  self.apply(partial(_init_weights, initializer_range=config.initializer_range))
479
- self.tokenizer = AutoTokenizer.from_pretrained(
480
- self.name_or_path, trust_remote_code=True
481
- )
482
- self._rotary_emb_base = config.rotary_emb_base
483
 
484
  @torch.inference_mode()
485
  def encode(
486
- self: "XLMRobertaModel",
487
  sentences: Union[str, List[str]],
488
  batch_size: int = 32,
489
  show_progress_bar: Optional[bool] = None,
490
- output_value: str = "sentence_embedding",
491
  convert_to_numpy: bool = True,
492
  convert_to_tensor: bool = False,
493
  device: Optional[torch.device] = None,
494
- normalize_embeddings: bool = True,
495
- truncate_dim: Optional[int] = None,
496
- adapter_mask: Optional[torch.Tensor] = None,
497
- task: Optional[str] = None,
498
  **tokenizer_kwargs,
499
  ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
500
  """
@@ -520,12 +477,10 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
520
  Overwrites any setting from convert_to_numpy
521
  device(`torch.device`, *optional*, defaults to None):
522
  Which torch.device to use for the computation
523
- normalize_embeddings(`bool`, *optional*, defaults to True):
524
  If set to true, returned vectors will have length 1. In that case, the
525
  faster dot-product (util.dot_score) instead of cosine similarity can
526
  be used.
527
- truncate_dim(`int`, *optional*, defaults to None):
528
- The dimension to truncate sentence embeddings to. `None` does no truncation.
529
  tokenizer_kwargs(`Dict[str, Any]`, *optional*, defaults to {}):
530
  Keyword arguments for the tokenizer
531
  Returns:
@@ -533,6 +488,12 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
533
  If convert_to_tensor, a stacked tensor is returned.
534
  If convert_to_numpy, a numpy matrix is returned.
535
  """
 
 
 
 
 
 
536
  is_training = self.training
537
  self.eval()
538
 
@@ -545,12 +506,12 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
545
  if convert_to_tensor:
546
  convert_to_numpy = False
547
 
548
- if output_value != "sentence_embedding":
549
  convert_to_tensor = False
550
  convert_to_numpy = False
551
 
552
  input_was_string = False
553
- if isinstance(sentences, str) or not hasattr(sentences, "__len__"):
554
  sentences = [sentences]
555
  input_was_string = True
556
 
@@ -561,11 +522,11 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
561
  inverse_permutation = np.argsort(permutation)
562
  sentences = [sentences[idx] for idx in permutation]
563
 
564
- tokenizer_kwargs["padding"] = tokenizer_kwargs.get("padding", True)
565
- tokenizer_kwargs["max_length"] = tokenizer_kwargs.get(
566
- "max_length", self.tokenizer.init_kwargs.get("model_max_length", 8192)
567
  )
568
- tokenizer_kwargs["truncation"] = tokenizer_kwargs.get("truncation", True)
569
 
570
  all_embeddings = []
571
 
@@ -583,51 +544,41 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
583
  for i in range_iter:
584
  encoded_input = self.tokenizer(
585
  sentences[i : i + batch_size],
586
- return_tensors="pt",
587
  **tokenizer_kwargs,
588
  ).to(self.device)
589
- lora_arguments = (
590
- {"adapter_mask": adapter_mask[i : i + batch_size]}
591
- if adapter_mask is not None
592
- else {}
593
- )
594
- token_embs = self.forward(**encoded_input, **lora_arguments)[0]
595
 
596
  # Accumulate in fp32 to avoid overflow
597
  token_embs = token_embs.float()
598
 
599
- if output_value == "token_embeddings":
600
  raise NotImplementedError
601
  elif output_value is None:
602
  raise NotImplementedError
603
  else:
604
- if self.config.emb_pooler == "cls":
605
  embeddings = self.cls_pooling(
606
- token_embs, encoded_input["attention_mask"]
607
  )
608
  else:
609
  embeddings = self.mean_pooling(
610
- token_embs, encoded_input["attention_mask"]
611
  )
612
 
 
 
 
 
 
613
  all_embeddings.extend(embeddings)
614
 
615
  all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
616
 
617
- truncate_dim = truncate_dim or self.config.truncate_dim
618
- if truncate_dim:
619
- all_embeddings = self.truncate_embeddings(all_embeddings, truncate_dim)
620
-
621
- if normalize_embeddings:
622
- all_embeddings = [
623
- torch.nn.functional.normalize(embedding, p=2, dim=0)
624
- for embedding in all_embeddings
625
- ]
626
-
627
  if convert_to_tensor:
628
  all_embeddings = torch.stack(all_embeddings)
629
  elif convert_to_numpy:
630
- all_embeddings = np.asarray([emb.cpu().numpy() for emb in all_embeddings])
631
 
632
  if input_was_string:
633
  all_embeddings = all_embeddings[0]
@@ -635,20 +586,6 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
635
  self.train(is_training)
636
  return all_embeddings
637
 
638
- def truncate_embeddings(self, embeddings, truncate_dim):
639
- if not self.config.matryoshka_dimensions:
640
- logger.warning(
641
- "Matryoshka embeddings are not supported, so dimension truncation will not be performed."
642
- )
643
- return embeddings
644
- elif truncate_dim in self.config.matryoshka_dimensions:
645
- return [tensor[:truncate_dim] for tensor in embeddings]
646
- else:
647
- raise ValueError(
648
- f"The provided `truncate_dim` value of {truncate_dim} is not supported. "
649
- f"Supported dimensions are {self.config.matryoshka_dimensions}."
650
- )
651
-
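The removed `truncate_embeddings` helper simply slices each vector down to one of the configured Matryoshka dimensions. An equivalent standalone sketch (the dimension list is illustrative; the model reads it from `config.matryoshka_dimensions`):

```python
import torch

matryoshka_dimensions = [64, 128, 256]   # illustrative only

def truncate_embeddings(embeddings, truncate_dim):
    if truncate_dim not in matryoshka_dimensions:
        raise ValueError(f"Unsupported truncate_dim {truncate_dim}; choose from {matryoshka_dimensions}")
    return [emb[:truncate_dim] for emb in embeddings]

embs = [torch.randn(256), torch.randn(256)]
print(truncate_embeddings(embs, 64)[0].shape)   # torch.Size([64])
```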
652
  def mean_pooling(
653
  self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
654
  ):
@@ -659,21 +596,12 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
659
  input_mask_expanded.sum(1), min=1e-9
660
  )
661
 
662
- def cls_pooling(self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor):
663
- return token_embeddings[:, 0]
664
 
665
- @property
666
- def rotary_emb_base(self):
667
- return self._rotary_emb_base
668
-
669
- @rotary_emb_base.setter
670
- def rotary_emb_base(self, base):
671
- if not isinstance(base, (int, float)):
672
- raise TypeError("Base must be an integer or float")
673
- logger.info(f"Changing RoPE base value to {base}")
674
- for layer in self.encoder.layers:
675
- layer.mixer.rotary_emb.base = base
676
- self._rotary_emb_base = base
677
 
678
  def forward(
679
  self,
@@ -683,7 +611,6 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
683
  attention_mask=None,
684
  masked_tokens_mask=None,
685
  return_dict=None,
686
- output_hidden_states=None,
687
  **kwargs,
688
  ):
689
  """If masked_tokens_mask is not None (i.e. last_layer_subset == True in XLMForPreTraining),
@@ -691,12 +618,12 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
691
  layer output for these tokens.
692
  masked_tokens_mask: (batch, seqlen), dtype=torch.bool
693
  """
694
- adapter_mask = kwargs.pop("adapter_mask", None)
695
  if kwargs:
696
  for key, value in kwargs.items():
697
  if value is not None:
698
  logger.warning(
699
- "Flash attention implementation does not support kwargs: %s",
700
  key,
701
  )
702
 
@@ -705,10 +632,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
705
  )
706
 
707
  hidden_states = self.embeddings(
708
- input_ids,
709
- position_ids=position_ids,
710
- token_type_ids=token_type_ids,
711
- adapter_mask=adapter_mask,
712
  )
713
  # TD [2022-12:18]: Don't need to force residual in fp32
714
  # BERT puts embedding LayerNorm before embedding dropout.
@@ -732,24 +656,12 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
732
  subset_mask = None
733
 
734
  sequence_output = self.encoder(
735
- hidden_states,
736
- key_padding_mask=attention_mask,
737
- subset_mask=subset_mask,
738
- adapter_mask=adapter_mask,
739
- output_hidden_states=output_hidden_states,
740
  )
741
 
742
- if output_hidden_states:
743
- all_hidden_states = sequence_output
744
- sequence_output = sequence_output[-1]
745
- else:
746
- all_hidden_states = None
747
-
748
  if masked_tokens_mask is None:
749
  pooled_output = (
750
- self.pooler(sequence_output, adapter_mask=adapter_mask)
751
- if self.pooler is not None
752
- else None
753
  )
754
  else:
755
  # TD [2022-03-01]: the indexing here is very tricky.
@@ -763,9 +675,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
763
  pool_input = sequence_output[first_col_mask[subset_mask]]
764
  sequence_output = sequence_output[masked_tokens_mask[subset_mask]]
765
  pooled_output = (
766
- self.pooler(pool_input, pool=False, adapter_mask=adapter_mask)
767
- if self.pooler is not None
768
- else None
769
  )
770
 
771
  if not return_dict:
@@ -774,7 +684,6 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
774
  return BaseModelOutputWithPoolingAndCrossAttentions(
775
  last_hidden_state=sequence_output,
776
  pooler_output=pooled_output,
777
- hidden_states=all_hidden_states,
778
  )
779
 
780
 
@@ -871,6 +780,103 @@ class XLMRobertaForMaskedLM(XLMRobertaPreTrainedModel):
871
  )
872
 
873
 
 
 
 
 
874
  def remap_state_dict(state_dict, config: PretrainedConfig):
875
  """
876
  Map the state_dict of a Huggingface BERT model to be flash_attn compatible.
@@ -1022,47 +1028,47 @@ def inv_remap_state_dict(state_dict, config: PretrainedConfig):
1022
  if not last_layer_subset or d != (config.num_hidden_layers - 1):
1023
  Wqkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.weight")
1024
  Wqkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.bias")
1025
- state_dict[f"bert.encoder.layers.{d}.attention.self.query.weight"] = (
1026
- Wqkv_weights[: Wqkv_weights.shape[0] // 3, :]
1027
- )
1028
- state_dict[f"bert.encoder.layers.{d}.attention.self.key.weight"] = (
1029
- Wqkv_weights[
1030
- Wqkv_weights.shape[0] // 3 : 2 * Wqkv_weights.shape[0] // 3, :
1031
- ]
1032
- )
1033
- state_dict[f"bert.encoder.layers.{d}.attention.self.value.weight"] = (
1034
- Wqkv_weights[2 * Wqkv_weights.shape[0] // 3 :, :]
1035
- )
1036
- state_dict[f"bert.encoder.layers.{d}.attention.self.query.bias"] = (
1037
- Wqkv_biases[: Wqkv_biases.shape[0] // 3]
1038
- )
1039
- state_dict[f"bert.encoder.layers.{d}.attention.self.key.bias"] = (
1040
- Wqkv_biases[Wqkv_biases.shape[0] // 3 : 2 * Wqkv_biases.shape[0] // 3]
1041
- )
1042
- state_dict[f"bert.encoder.layers.{d}.attention.self.value.bias"] = (
1043
- Wqkv_biases[2 * Wqkv_biases.shape[0] // 3 :]
1044
- )
1045
  else:
1046
  Wq_weight = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.weight")
1047
  Wkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.weight")
1048
  Wq_bias = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.bias")
1049
  Wkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.bias")
1050
- state_dict[f"bert.encoder.layers.{d}.attention.self.query.weight"] = (
1051
- Wq_weight
1052
- )
1053
- state_dict[f"bert.encoder.layers.{d}.attention.self.key.weight"] = (
1054
- Wkv_weights[: Wkv_weights.shape[0] // 2, :]
1055
- )
1056
- state_dict[f"bert.encoder.layers.{d}.attention.self.value.weight"] = (
1057
- Wkv_weights[Wkv_weights.shape[0] // 2 :, :]
1058
- )
1059
  state_dict[f"bert.encoder.layers.{d}.attention.self.query.bias"] = Wq_bias
1060
  state_dict[f"bert.encoder.layers.{d}.attention.self.key.bias"] = Wkv_biases[
1061
  : Wkv_biases.shape[0] // 2
1062
  ]
1063
- state_dict[f"bert.encoder.layers.{d}.attention.self.value.bias"] = (
1064
- Wkv_biases[Wkv_biases.shape[0] // 2 :]
1065
- )
1066
 
1067
  def inv_key_mapping_ln(key):
1068
  key = re.sub(r"bert.emb_ln.", "bert.embeddings.LayerNorm.", key)
@@ -1142,18 +1148,14 @@ class XLMRobertaClassificationHead(nn.Module):
1142
 
1143
  def __init__(self, config):
1144
  super().__init__()
1145
- fused_bias_fc = getattr(config, "fused_bias_fc", False)
1146
- if fused_bias_fc and FusedDense is None:
1147
- raise ImportError("fused_dense is not installed")
1148
- linear_cls = nn.Linear if not fused_bias_fc else FusedDense
1149
- self.dense = linear_cls(config.hidden_size, config.hidden_size)
1150
  classifier_dropout = (
1151
  config.classifier_dropout
1152
  if config.classifier_dropout is not None
1153
  else config.hidden_dropout_prob
1154
  )
1155
  self.dropout = nn.Dropout(classifier_dropout)
1156
- self.out_proj = linear_cls(config.hidden_size, config.num_labels)
1157
 
1158
  def forward(self, features, **kwargs):
1159
  x = features[:, 0, :] # take <s> token (equiv. to [CLS])
@@ -1251,4 +1253,4 @@ class XLMRobertaForSequenceClassification(XLMRobertaPreTrainedModel):
1251
  logits=logits,
1252
  hidden_states=outputs.hidden_states,
1253
  attentions=outputs.attentions,
1254
- )
 
13
  from collections import OrderedDict
14
  from collections.abc import Sequence
15
  from functools import partial
 
 
16
  import numpy as np
17
+
18
  import torch
19
  import torch.nn as nn
20
  import torch.nn.functional as F
21
  import torch.utils.checkpoint
22
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
23
+ from einops import rearrange
24
+ from transformers import PretrainedConfig
25
  from transformers.modeling_utils import PreTrainedModel
26
+ from transformers.modeling_outputs import MaskedLMOutput, SequenceClassifierOutput
27
+ from transformers.models.xlm_roberta.modeling_xlm_roberta import XLMRobertaLMHead
28
+
29
  from transformers.models.bert.modeling_bert import (
30
  BaseModelOutputWithPoolingAndCrossAttentions,
31
  BertForPreTrainingOutput,
32
  )
 
33
 
34
+ from typing import List, Optional, Tuple, Union
35
+
36
+ from .xlm_padding import (
37
+ index_first_axis,
38
+ index_first_axis_residual,
39
+ pad_input,
40
+ unpad_input,
41
+ )
42
  from .configuration_xlm_roberta import XLMRobertaFlashConfig
43
+ from .block import Block
44
  from .embedding import XLMRobertaEmbeddings
45
  from .mha import MHA
46
  from .mlp import FusedMLP, Mlp
47
+ from .stochastic_depth import StochasticDepth
48
+
49
 
50
  try:
51
  from flash_attn.ops.fused_dense import FusedDense
 
61
  try:
62
  from flash_attn.losses.cross_entropy import CrossEntropyLoss
63
  except ImportError:
64
+ CrossEntropyLoss = None
65
 
66
  try:
67
  from tqdm.autonotebook import trange
 
73
 
74
 
75
  def get_use_flash_attn(config: XLMRobertaFlashConfig):
76
+ if not getattr(config, "use_flash_attn", False):
77
+ return False
78
+ if not torch.cuda.is_available():
79
  return False
80
  if importlib.util.find_spec("flash_attn") is None:
81
  logger.warning(
82
+ 'flash_attn is not installed. Using PyTorch native attention implementation.'
83
  )
84
  return False
85
  return True
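With this check, flash attention is only used when the config asks for it, CUDA is available, and `flash_attn` is importable; otherwise the model silently falls back to the native PyTorch attention path. A hedged sketch of forcing that fallback at load time (the repository id is a placeholder, and passing `use_flash_attn` through `from_pretrained` assumes the config accepts it as an override):

```python
from transformers import AutoModel

# Placeholder repo id; trust_remote_code is required for this custom architecture.
model = AutoModel.from_pretrained(
    "org/xlm-roberta-flash-checkpoint",
    trust_remote_code=True,
    use_flash_attn=False,   # assumed config override: use PyTorch attention even on a GPU machine
)
```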
 
91
  rotary_kwargs = {}
92
  if config.position_embedding_type == "rotary":
93
  rotary_kwargs["rotary_emb_dim"] = getattr(
94
+ config, "rotary_emb_dim", config.hidden_size
95
  )
96
+ rotary_kwargs["rotary_emb_base"] = getattr(config, "rotary_emb_base", 10000.0)
97
  rotary_kwargs["rotary_emb_scale_base"] = getattr(
98
  config, "rotary_emb_scale_base", None
99
  )
 
109
  fused_bias_fc=fused_bias_fc,
110
  use_flash_attn=use_flash_attn,
111
  return_residual=return_residual,
 
112
  **rotary_kwargs,
113
  )
114
  return mixer_cls
 
190
  def __init__(self, config: XLMRobertaFlashConfig):
191
  super().__init__()
192
  self.use_flash_attn = get_use_flash_attn(config)
 
193
  self.layers = nn.ModuleList(
194
  [create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
195
  )
 
203
  def gradient_checkpointing(self, value):
204
  self._grad_checkpointing = value
205
 
206
+ def forward(self, hidden_states, key_padding_mask=None, subset_mask=None):
207
  """If subset_mask is not None, we only want output for the subset of the sequence.
208
  This means that we only compute the last layer output for these tokens.
209
  subset_mask: (batch, seqlen), dtype=torch.bool
210
  """
 
 
 
 
 
 
211
  if key_padding_mask is None or not self.use_flash_attn:
212
+ mixer_kwargs = (
213
+ {"key_padding_mask": key_padding_mask.bool()}
214
+ if key_padding_mask is not None
215
+ else None
216
+ )
217
  for layer in self.layers:
 
 
218
  if self._grad_checkpointing:
219
  hidden_states = torch.utils.checkpoint.checkpoint(
220
  layer,
221
  hidden_states,
222
+ use_reentrant=False,
223
  mixer_kwargs=mixer_kwargs,
224
  )
225
  else:
226
  hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
 
 
227
  if subset_mask is not None:
228
  hidden_states = hidden_states[subset_mask]
229
  else:
230
  batch, seqlen = hidden_states.shape[:2]
231
+ hidden_states, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
232
+ hidden_states, key_padding_mask
 
 
233
  )
234
+ mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch}
 
 
 
 
 
235
  if subset_mask is None:
236
  for layer in self.layers:
237
  if self._grad_checkpointing:
238
  hidden_states = torch.utils.checkpoint.checkpoint(
239
  layer,
240
  hidden_states,
241
+ use_reentrant=False,
242
  mixer_kwargs=mixer_kwargs,
243
  )
244
  else:
245
  hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
 
 
 
 
246
  hidden_states = pad_input(hidden_states, indices, batch, seqlen)
247
  else:
248
  for layer in self.layers[:-1]:
 
250
  hidden_states = torch.utils.checkpoint.checkpoint(
251
  layer,
252
  hidden_states,
253
+ use_reentrant=False,
254
  mixer_kwargs=mixer_kwargs,
255
  )
256
  else:
 
288
  torch.utils.checkpoint.checkpoint(
289
  self.layers[-1],
290
  hidden_states_subset,
291
+ use_reentrant=False,
292
  mixer_kwargs=mixer_kwargs,
293
  )
294
  else:
295
  hidden_states = self.layers[-1](
296
  hidden_states_subset, mixer_kwargs=mixer_kwargs
297
  )
298
+ return hidden_states
299
 
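When flash attention is active, the encoder first strips padding and packs all tokens into one long sequence, handing each block `cu_seqlens` (cumulative sequence lengths) and `max_seqlen`. A standalone sketch of how those quantities relate to an attention mask (illustrative only, not the repository's `unpad_input`):

```python
import torch

attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])   # two sequences of length 3 and 2

seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)        # tensor([3, 2])
cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens, 0, dtype=torch.int32), (1, 0))
max_seqlen = int(seqlens.max())

print(cu_seqlens)   # tensor([0, 3, 5], dtype=torch.int32)
print(max_seqlen)   # 3
# The packed hidden states then have shape (total_tokens, hidden) with total_tokens == 5;
# pad_input reverses the operation after the last layer.
```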
300
 
301
  class XLMRobertaPooler(nn.Module):
 
308
  self.dense = linear_cls(config.hidden_size, config.hidden_size)
309
  self.activation = nn.Tanh()
310
 
311
+ def forward(self, hidden_states, pool=True):
312
  # We "pool" the model by simply taking the hidden state corresponding
313
  # to the first token.
314
  first_token_tensor = hidden_states[:, 0] if pool else hidden_states
315
+ pooled_output = self.dense(first_token_tensor)
316
  pooled_output = self.activation(pooled_output)
317
  return pooled_output
318
 
 
391
  config_class = XLMRobertaFlashConfig
392
  base_model_prefix = "roberta"
393
  supports_gradient_checkpointing = True
 
394
 
395
  def _set_gradient_checkpointing(self, module, value=False):
396
  if isinstance(module, XLMRobertaEncoder):
 
402
  *args,
403
  **kwargs,
404
  ):
405
+ if not 'torch_dtype' in kwargs:
406
+ kwargs['torch_dtype'] = 'auto'
407
  return super().from_pretrained(*args, **kwargs)
408
 
409
 
410
+
411
  class XLMRobertaModel(XLMRobertaPreTrainedModel):
412
  def __init__(self, config: XLMRobertaFlashConfig, add_pooling_layer=True):
413
  super().__init__(config)
 
425
  "gelu_fast",
426
  "gelu_pytorch_tanh",
427
  ]
428
+
429
  self.embeddings = XLMRobertaEmbeddings(
430
  config.hidden_size,
431
  config.vocab_size,
432
+ config.max_position_embeddings,
 
 
 
 
433
  config.type_vocab_size,
434
  padding_idx=config.pad_token_id,
435
  )
 
439
  self.pooler = XLMRobertaPooler(config) if add_pooling_layer else None
440
 
441
  self.apply(partial(_init_weights, initializer_range=config.initializer_range))
442
+
 
 
 
443
 
444
  @torch.inference_mode()
445
  def encode(
446
+ self: 'XLMRobertaModel',
447
  sentences: Union[str, List[str]],
448
  batch_size: int = 32,
449
  show_progress_bar: Optional[bool] = None,
450
+ output_value: str = 'sentence_embedding',
451
  convert_to_numpy: bool = True,
452
  convert_to_tensor: bool = False,
453
  device: Optional[torch.device] = None,
454
+ normalize_embeddings: bool = False,
 
 
 
455
  **tokenizer_kwargs,
456
  ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
457
  """
 
477
  Overwrites any setting from convert_to_numpy
478
  device(`torch.device`, *optional*, defaults to None):
479
  Which torch.device to use for the computation
480
+ normalize_embeddings(`bool`, *optional*, defaults to False):
481
  If set to true, returned vectors will have length 1. In that case, the
482
  faster dot-product (util.dot_score) instead of cosine similarity can
483
  be used.
 
 
484
  tokenizer_kwargs(`Dict[str, Any]`, *optional*, defaults to {}):
485
  Keyword arguments for the tokenizer
486
  Returns:
 
488
  If convert_to_tensor, a stacked tensor is returned.
489
  If convert_to_numpy, a numpy matrix is returned.
490
  """
491
+ from transformers import AutoTokenizer
492
+
493
+ self.tokenizer = AutoTokenizer.from_pretrained(
494
+ self.name_or_path, trust_remote_code=True
495
+ )
496
+
497
  is_training = self.training
498
  self.eval()
499
 
 
506
  if convert_to_tensor:
507
  convert_to_numpy = False
508
 
509
+ if output_value != 'sentence_embedding':
510
  convert_to_tensor = False
511
  convert_to_numpy = False
512
 
513
  input_was_string = False
514
+ if isinstance(sentences, str) or not hasattr(sentences, '__len__'):
515
  sentences = [sentences]
516
  input_was_string = True
517
 
 
522
  inverse_permutation = np.argsort(permutation)
523
  sentences = [sentences[idx] for idx in permutation]
524
 
525
+ tokenizer_kwargs['padding'] = tokenizer_kwargs.get('padding', True)
526
+ tokenizer_kwargs['max_length'] = tokenizer_kwargs.get(
527
+ 'max_length', self.tokenizer.init_kwargs.get('model_max_length', 8192)
528
  )
529
+ tokenizer_kwargs['truncation'] = tokenizer_kwargs.get('truncation', True)
530
 
531
  all_embeddings = []
532
 
 
544
  for i in range_iter:
545
  encoded_input = self.tokenizer(
546
  sentences[i : i + batch_size],
547
+ return_tensors='pt',
548
  **tokenizer_kwargs,
549
  ).to(self.device)
550
+ token_embs = self.forward(**encoded_input)[0]
 
 
 
 
 
551
 
552
  # Accumulate in fp32 to avoid overflow
553
  token_embs = token_embs.float()
554
 
555
+ if output_value == 'token_embeddings':
556
  raise NotImplementedError
557
  elif output_value is None:
558
  raise NotImplementedError
559
  else:
560
+ if self.config.emb_pooler == 'cls':
561
  embeddings = self.cls_pooling(
562
+ token_embs, encoded_input['attention_mask']
563
  )
564
  else:
565
  embeddings = self.mean_pooling(
566
+ token_embs, encoded_input['attention_mask']
567
  )
568
 
569
+ if normalize_embeddings:
570
+ embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
571
+
572
+ if convert_to_numpy:
573
+ embeddings = embeddings.cpu()
574
  all_embeddings.extend(embeddings)
575
 
576
  all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
577
 
 
 
578
  if convert_to_tensor:
579
  all_embeddings = torch.stack(all_embeddings)
580
  elif convert_to_numpy:
581
+ all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
582
 
583
  if input_was_string:
584
  all_embeddings = all_embeddings[0]
 
586
  self.train(is_training)
587
  return all_embeddings
588
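For reference, a hedged usage sketch of the `encode` helper above (the repository id is a placeholder, and `trust_remote_code=True` is assumed because the modeling code lives in the repo):

```python
from transformers import AutoModel

model = AutoModel.from_pretrained("org/xlm-roberta-flash-checkpoint", trust_remote_code=True)  # placeholder id

embeddings = model.encode(
    ["How is the weather today?", "Wie ist das Wetter heute?"],
    batch_size=32,
    convert_to_numpy=True,
)
print(embeddings.shape)  # (2, hidden_size)
```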
 
 
 
589
  def mean_pooling(
590
  self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
591
  ):
 
596
  input_mask_expanded.sum(1), min=1e-9
597
  )
598
 
 
 
599
 
600
+ def cls_pooling(
601
+ self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
602
+ ):
603
+ return token_embeddings[:,0]
604
+
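Both poolers above reduce the `(batch, seqlen, hidden)` token embeddings to one vector per sentence: CLS pooling takes the first token, mean pooling averages over non-padding tokens. The same arithmetic in a compact standalone form:

```python
import torch

def mean_pool(token_embeddings, attention_mask):
    # (batch, seqlen, hidden) * (batch, seqlen, 1) -> masked sum / token count
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return (token_embeddings * mask).sum(1) / torch.clamp(mask.sum(1), min=1e-9)

def cls_pool(token_embeddings):
    return token_embeddings[:, 0]

toks = torch.randn(2, 5, 8)
mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
print(mean_pool(toks, mask).shape, cls_pool(toks).shape)  # torch.Size([2, 8]) torch.Size([2, 8])
```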
 
 
605
 
606
  def forward(
607
  self,
 
611
  attention_mask=None,
612
  masked_tokens_mask=None,
613
  return_dict=None,
 
614
  **kwargs,
615
  ):
616
  """If masked_tokens_mask is not None (i.e. last_layer_subset == True in XLMForPreTraining),
 
618
  layer output for these tokens.
619
  masked_tokens_mask: (batch, seqlen), dtype=torch.bool
620
  """
621
+
622
  if kwargs:
623
  for key, value in kwargs.items():
624
  if value is not None:
625
  logger.warning(
626
+ 'Flash attention implementation does not support kwargs: %s',
627
  key,
628
  )
629
 
 
632
  )
633
 
634
  hidden_states = self.embeddings(
635
+ input_ids, position_ids=position_ids, token_type_ids=token_type_ids
 
 
 
636
  )
637
  # TD [2022-12:18]: Don't need to force residual in fp32
638
  # BERT puts embedding LayerNorm before embedding dropout.
 
656
  subset_mask = None
657
 
658
  sequence_output = self.encoder(
659
+ hidden_states, key_padding_mask=attention_mask, subset_mask=subset_mask
 
 
 
 
660
  )
661
 
 
 
 
 
 
 
662
  if masked_tokens_mask is None:
663
  pooled_output = (
664
+ self.pooler(sequence_output) if self.pooler is not None else None
 
 
665
  )
666
  else:
667
  # TD [2022-03-01]: the indexing here is very tricky.
 
675
  pool_input = sequence_output[first_col_mask[subset_mask]]
676
  sequence_output = sequence_output[masked_tokens_mask[subset_mask]]
677
  pooled_output = (
678
+ self.pooler(pool_input, pool=False) if self.pooler is not None else None
 
 
679
  )
680
 
681
  if not return_dict:
 
684
  return BaseModelOutputWithPoolingAndCrossAttentions(
685
  last_hidden_state=sequence_output,
686
  pooler_output=pooled_output,
 
687
  )
688
 
689
 
 
780
  )
781
 
782
 
783
+ # class XLMRobertaForPreTraining(XLMRobertaPreTrainedModel):
784
+ # def __init__(self, config: XLMRobertaFlashConfig):
785
+ # super().__init__(config)
786
+ # # If dense_seq_output, we only need to pass the hidden states for the masked out tokens
787
+ # # (around 15%) to the classifier heads.
788
+ # self.dense_seq_output = getattr(config, "dense_seq_output", False)
789
+ # # If last_layer_subset, we only need the compute the last layer for a subset of tokens
790
+ # # (e.g., the tokens we need to compute the masked LM loss and the next-sentence prediction).
791
+ # self.last_layer_subset = getattr(config, "last_layer_subset", False)
792
+ # if self.last_layer_subset:
793
+ # assert self.dense_seq_output, "last_layer_subset requires dense_seq_output"
794
+ # use_xentropy = getattr(config, "use_xentropy", False)
795
+ # if use_xentropy and CrossEntropyLoss is None:
796
+ # raise ImportError("xentropy_cuda is not installed")
797
+ # loss_cls = (
798
+ # nn.CrossEntropyLoss
799
+ # if not use_xentropy
800
+ # else partial(CrossEntropyLoss, inplace_backward=True)
801
+ # )
802
+ #
803
+ # self.xlm = XLMRobertaModel(config)
804
+ # self.cls = XLMRobertaPreTrainingHeads(config)
805
+ # self.mlm_loss = loss_cls(ignore_index=0)
806
+ # self.nsp_loss = loss_cls(ignore_index=-1)
807
+ #
808
+ # # Initialize weights and apply final processing
809
+ # self.apply(partial(_init_weights, initializer_range=config.initializer_range))
810
+ # self.tie_weights()
811
+ #
812
+ # def tie_weights(self):
813
+ # self.cls.predictions.decoder.weight = self.xlm.embeddings.word_embeddings.weight
814
+ #
815
+ # def forward(
816
+ # self,
817
+ # input_ids,
818
+ # position_ids=None,
819
+ # token_type_ids=None,
820
+ # attention_mask=None,
821
+ # labels=None,
822
+ # next_sentence_label=None,
823
+ # ):
824
+ # """
825
+ # If labels are provided, they must be 0 for masked out tokens (as specified in the attention
826
+ # mask).
827
+ # Outputs:
828
+ # if `labels` and `next_sentence_label` are not `None`:
829
+ # Outputs the total_loss which is the sum of the masked language modeling loss and the next
830
+ # sentence classification loss.
831
+ # if `labels` or `next_sentence_label` is `None`:
832
+ # Outputs a tuple comprising
833
+ # - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and
834
+ # - the next sentence classification logits of shape [batch_size, 2].
835
+ #
836
+ # """
837
+ # masked_tokens_mask = labels > 0 if (self.last_layer_subset and labels is not None) else None
838
+ # outputs = self.xlm(
839
+ # input_ids,
840
+ # position_ids=position_ids,
841
+ # token_type_ids=token_type_ids,
842
+ # attention_mask=attention_mask.bool() if attention_mask is not None else None,
843
+ # masked_tokens_mask=masked_tokens_mask,
844
+ # )
845
+ # sequence_output, pooled_output = outputs.last_hidden_state, outputs.pooler_output
846
+ # if self.dense_seq_output and labels is not None:
847
+ # masked_token_idx = torch.nonzero(labels.flatten() > 0, as_tuple=False).flatten()
848
+ # if not self.last_layer_subset:
849
+ # sequence_output = index_first_axis(
850
+ # rearrange(sequence_output, "b s d -> (b s) d"), masked_token_idx
851
+ # )
852
+ # prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
853
+ #
854
+ # total_loss = None
855
+ # if labels is not None and next_sentence_label is not None:
856
+ # if (
857
+ # self.dense_seq_output and labels is not None
858
+ # ): # prediction_scores are already flattened
859
+ # masked_lm_loss = self.mlm_loss(
860
+ # prediction_scores, labels.flatten()[masked_token_idx]
861
+ # )
862
+ # else:
863
+ # masked_lm_loss = self.mlm_loss(
864
+ # rearrange(prediction_scores, "... v -> (...) v"),
865
+ # rearrange(labels, "... -> (...)"),
866
+ # )
867
+ # next_sentence_loss = self.nsp_loss(
868
+ # rearrange(seq_relationship_score, "... t -> (...) t"),
869
+ # rearrange(next_sentence_label, "... -> (...)"),
870
+ # )
871
+ # total_loss = masked_lm_loss.float() + next_sentence_loss.float()
872
+ #
873
+ # return BertForPreTrainingOutput(
874
+ # loss=total_loss,
875
+ # prediction_logits=prediction_scores,
876
+ # seq_relationship_logits=seq_relationship_score,
877
+ # )
878
+
879
+
880
  def remap_state_dict(state_dict, config: PretrainedConfig):
881
  """
882
  Map the state_dict of a Huggingface BERT model to be flash_attn compatible.
 
1028
  if not last_layer_subset or d != (config.num_hidden_layers - 1):
1029
  Wqkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.weight")
1030
  Wqkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.bias")
1031
+ state_dict[
1032
+ f"bert.encoder.layers.{d}.attention.self.query.weight"
1033
+ ] = Wqkv_weights[: Wqkv_weights.shape[0] // 3, :]
1034
+ state_dict[
1035
+ f"bert.encoder.layers.{d}.attention.self.key.weight"
1036
+ ] = Wqkv_weights[
1037
+ Wqkv_weights.shape[0] // 3 : 2 * Wqkv_weights.shape[0] // 3, :
1038
+ ]
1039
+ state_dict[
1040
+ f"bert.encoder.layers.{d}.attention.self.value.weight"
1041
+ ] = Wqkv_weights[2 * Wqkv_weights.shape[0] // 3 :, :]
1042
+ state_dict[
1043
+ f"bert.encoder.layers.{d}.attention.self.query.bias"
1044
+ ] = Wqkv_biases[: Wqkv_biases.shape[0] // 3]
1045
+ state_dict[
1046
+ f"bert.encoder.layers.{d}.attention.self.key.bias"
1047
+ ] = Wqkv_biases[Wqkv_biases.shape[0] // 3 : 2 * Wqkv_biases.shape[0] // 3]
1048
+ state_dict[
1049
+ f"bert.encoder.layers.{d}.attention.self.value.bias"
1050
+ ] = Wqkv_biases[2 * Wqkv_biases.shape[0] // 3 :]
1051
  else:
1052
  Wq_weight = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.weight")
1053
  Wkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.weight")
1054
  Wq_bias = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.bias")
1055
  Wkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.bias")
1056
+ state_dict[
1057
+ f"bert.encoder.layers.{d}.attention.self.query.weight"
1058
+ ] = Wq_weight
1059
+ state_dict[
1060
+ f"bert.encoder.layers.{d}.attention.self.key.weight"
1061
+ ] = Wkv_weights[: Wkv_weights.shape[0] // 2, :]
1062
+ state_dict[
1063
+ f"bert.encoder.layers.{d}.attention.self.value.weight"
1064
+ ] = Wkv_weights[Wkv_weights.shape[0] // 2 :, :]
1065
  state_dict[f"bert.encoder.layers.{d}.attention.self.query.bias"] = Wq_bias
1066
  state_dict[f"bert.encoder.layers.{d}.attention.self.key.bias"] = Wkv_biases[
1067
  : Wkv_biases.shape[0] // 2
1068
  ]
1069
+ state_dict[
1070
+ f"bert.encoder.layers.{d}.attention.self.value.bias"
1071
+ ] = Wkv_biases[Wkv_biases.shape[0] // 2 :]
1072
 
1073
  def inv_key_mapping_ln(key):
1074
  key = re.sub(r"bert.emb_ln.", "bert.embeddings.LayerNorm.", key)
 
1148
 
1149
  def __init__(self, config):
1150
  super().__init__()
1151
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
 
 
 
 
1152
  classifier_dropout = (
1153
  config.classifier_dropout
1154
  if config.classifier_dropout is not None
1155
  else config.hidden_dropout_prob
1156
  )
1157
  self.dropout = nn.Dropout(classifier_dropout)
1158
+ self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
1159
 
1160
  def forward(self, features, **kwargs):
1161
  x = features[:, 0, :] # take <s> token (equiv. to [CLS])
 
1253
  logits=logits,
1254
  hidden_states=outputs.hidden_states,
1255
  attentions=outputs.attentions,
1256
+ )
modeling_xlm_roberta_for_glue.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
1
+ from typing import Optional, Union, Tuple
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
6
+ from transformers.modeling_outputs import SequenceClassifierOutput, QuestionAnsweringModelOutput, TokenClassifierOutput
7
+
8
+ from .modeling_xlm_roberta import XLMRobertaPreTrainedModel, XLMRobertaModel
9
+ from .configuration_xlm_roberta import XLMRobertaFlashConfig
10
+
11
+
12
+ class XLMRobertaForSequenceClassification(XLMRobertaPreTrainedModel):
13
+ def __init__(self, config: XLMRobertaFlashConfig):
14
+ super().__init__(config)
15
+ self.num_labels = config.num_labels
16
+ self.config = config
17
+
18
+ self.roberta = XLMRobertaModel(config)
19
+ classifier_dropout = (
20
+ config.classifier_dropout
21
+ if config.classifier_dropout is not None
22
+ else config.hidden_dropout_prob
23
+ )
24
+ self.dropout = nn.Dropout(classifier_dropout)
25
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
26
+
27
+ # Initialize weights and apply final processing
28
+ self.post_init()
29
+
30
+
31
+ def forward(
32
+ self,
33
+ input_ids: Optional[torch.Tensor] = None,
34
+ attention_mask: Optional[torch.Tensor] = None,
35
+ token_type_ids: Optional[torch.Tensor] = None,
36
+ position_ids: Optional[torch.Tensor] = None,
37
+ head_mask: Optional[torch.Tensor] = None,
38
+ inputs_embeds: Optional[torch.Tensor] = None,
39
+ labels: Optional[torch.Tensor] = None,
40
+ output_attentions: Optional[bool] = None,
41
+ output_hidden_states: Optional[bool] = None,
42
+ return_dict: Optional[bool] = None,
43
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
44
+ r"""
45
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
46
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
47
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
48
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
49
+ """
50
+ return_dict = (
51
+ return_dict if return_dict is not None else self.config.use_return_dict
52
+ )
53
+
54
+ assert head_mask is None
55
+ assert inputs_embeds is None
56
+ assert output_attentions is None
57
+ assert output_hidden_states is None
58
+ assert return_dict
59
+ outputs = self.roberta(
60
+ input_ids,
61
+ attention_mask=attention_mask,
62
+ token_type_ids=token_type_ids,
63
+ position_ids=position_ids,
64
+ head_mask=head_mask,
65
+ inputs_embeds=inputs_embeds,
66
+ output_attentions=output_attentions,
67
+ output_hidden_states=output_hidden_states,
68
+ return_dict=return_dict,
69
+ )
70
+
71
+ pooled_output = outputs[1]
72
+
73
+ pooled_output = self.dropout(pooled_output)
74
+ logits = self.classifier(pooled_output)
75
+
76
+ loss = None
77
+ if labels is not None:
78
+ if self.config.problem_type is None:
79
+ if self.num_labels == 1:
80
+ self.config.problem_type = "regression"
81
+ elif self.num_labels > 1 and (
82
+ labels.dtype == torch.long or labels.dtype == torch.int
83
+ ):
84
+ self.config.problem_type = "single_label_classification"
85
+ else:
86
+ self.config.problem_type = "multi_label_classification"
87
+
88
+ if self.config.problem_type == "regression":
89
+ loss_fct = MSELoss()
90
+ if self.num_labels == 1:
91
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
92
+ else:
93
+ loss = loss_fct(logits, labels)
94
+ elif self.config.problem_type == "single_label_classification":
95
+ loss_fct = CrossEntropyLoss()
96
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
97
+ elif self.config.problem_type == "multi_label_classification":
98
+ loss_fct = BCEWithLogitsLoss()
99
+ loss = loss_fct(logits, labels)
100
+ if not return_dict:
101
+ output = (logits,) + outputs[2:]
102
+ return ((loss,) + output) if loss is not None else output
103
+
104
+ return SequenceClassifierOutput(
105
+ loss=loss,
106
+ logits=logits,
107
+ hidden_states=outputs.hidden_states,
108
+ attentions=outputs.attentions,
109
+ )
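The new GLUE head follows the standard `transformers` sequence-classification pattern: pooled output, dropout, linear classifier, with the loss picked from `problem_type`. A hedged usage sketch (repository id is a placeholder, and it assumes the repo's `auto_map` wires this class to the auto-class):

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

repo = "org/xlm-roberta-flash-checkpoint"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(repo)
model = AutoModelForSequenceClassification.from_pretrained(repo, trust_remote_code=True, num_labels=2)

batch = tokenizer(["great movie", "terrible movie"], padding=True, return_tensors="pt")
out = model(**batch, labels=torch.tensor([1, 0]))
print(out.loss, out.logits.shape)  # scalar loss, torch.Size([2, 2])
```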
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cfa8fa7c7e120199548fe7149512c0adfe58f6bc13ce19f09b895aa25e8af910
3
+ size 1113232188
rotary.py DELETED
@@ -1,659 +0,0 @@
1
- # This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py
2
- # Commit id: 3566596ad867ee415dd3c12616dd50c610176f6c
3
- # Rotary varlen support from https://github.com/Dao-AILab/flash-attention/pull/556
4
-
5
- # Copyright (c) 2023, Tri Dao.
6
-
7
- from typing import Optional, Tuple, Union
8
-
9
- import torch
10
- from einops import rearrange, repeat
11
-
12
- if torch.cuda.is_available():
13
- try:
14
- from flash_attn.ops.triton.rotary import apply_rotary
15
- except ImportError:
16
-
17
- def apply_rotary(*args, **kwargs):
18
- raise RuntimeError(
19
- "FlashAttention is not installed. To proceed with training, please install FlashAttention. "
20
- "For inference, you have two options: either install FlashAttention or disable it by setting use_flash_attn=False when loading the model."
21
- )
22
-
23
-
24
- def rotate_half(x, interleaved=False):
25
- if not interleaved:
26
- x1, x2 = x.chunk(2, dim=-1)
27
- return torch.cat((-x2, x1), dim=-1)
28
- else:
29
- x1, x2 = x[..., ::2], x[..., 1::2]
30
- return rearrange(
31
- torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
32
- )
33
-
34
-
35
- def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
36
- """
37
- x: (batch_size, seqlen, nheads, headdim)
38
- cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
39
- """
40
- ro_dim = cos.shape[-1] * 2
41
- assert ro_dim <= x.shape[-1]
42
- cos, sin = (
43
- cos[: x.shape[1]],
44
- sin[: x.shape[1]],
45
- )
46
- cos = repeat(
47
- cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
48
- )
49
- sin = repeat(
50
- sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
51
- )
52
- return torch.cat(
53
- [
54
- x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
55
- x[..., ro_dim:],
56
- ],
57
- dim=-1,
58
- )
59
-
60
-
61
- class ApplyRotaryEmb(torch.autograd.Function):
62
- @staticmethod
63
- def forward(
64
- ctx,
65
- x,
66
- cos,
67
- sin,
68
- interleaved=False,
69
- inplace=False,
70
- seqlen_offsets: Union[int, torch.Tensor] = 0,
71
- cu_seqlens: Optional[torch.Tensor] = None,
72
- max_seqlen: Optional[int] = None,
73
- ):
74
- out = apply_rotary(
75
- x,
76
- cos,
77
- sin,
78
- seqlen_offsets=seqlen_offsets,
79
- cu_seqlens=cu_seqlens,
80
- max_seqlen=max_seqlen,
81
- interleaved=interleaved,
82
- inplace=inplace,
83
- )
84
-
85
- if isinstance(seqlen_offsets, int):
86
- ctx.save_for_backward(
87
- cos, sin, cu_seqlens
88
- ) # Can't save int with save_for_backward
89
- ctx.seqlen_offsets = seqlen_offsets
90
- else:
91
- ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)
92
- ctx.seqlen_offsets = None
93
- ctx.interleaved = interleaved
94
- ctx.inplace = inplace
95
- ctx.max_seqlen = max_seqlen
96
- return out if not inplace else x
97
-
98
- @staticmethod
99
- def backward(ctx, do):
100
- seqlen_offsets = ctx.seqlen_offsets
101
- if seqlen_offsets is None:
102
- cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
103
- else:
104
- cos, sin, cu_seqlens = ctx.saved_tensors
105
- # TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with
106
- # "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works.
107
- if not ctx.interleaved and not ctx.inplace:
108
- do = do.clone()
109
-
110
- dx = apply_rotary(
111
- do,
112
- cos,
113
- sin,
114
- seqlen_offsets=seqlen_offsets,
115
- cu_seqlens=cu_seqlens,
116
- max_seqlen=ctx.max_seqlen,
117
- interleaved=ctx.interleaved,
118
- inplace=ctx.inplace,
119
- conjugate=True,
120
- )
121
- return dx, None, None, None, None, None, None, None
122
-
123
-
124
- def apply_rotary_emb(
125
- x,
126
- cos,
127
- sin,
128
- interleaved=False,
129
- inplace=False,
130
- seqlen_offsets: Union[int, torch.Tensor] = 0,
131
- cu_seqlens: Optional[torch.Tensor] = None,
132
- max_seqlen: Optional[int] = None,
133
- ):
134
- """
135
- Arguments:
136
- x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
137
- else (total_seqlen, nheads, headdim)
138
- cos, sin: (seqlen_rotary, rotary_dim / 2)
139
- interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
140
- of 1st half and 2nd half (GPT-NeoX style).
141
- inplace: if True, apply rotary embedding in-place.
142
- seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
143
- Most commonly used in inference when we have KV cache.
144
- cu_seqlens: (batch + 1,) or None
145
- max_seqlen: int
146
- Return:
147
- out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
148
- else (total_seqlen, nheads, headdim)
149
- rotary_dim must be <= headdim
150
- Apply rotary embedding to the first rotary_dim of x.
151
- """
152
- return ApplyRotaryEmb.apply(
153
- x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen
154
- )
155
-
156
-
157
- # For backward compatibility
158
- apply_rotary_emb_func = apply_rotary_emb
159
-
160
-
161
- class ApplyRotaryEmbQKV_(torch.autograd.Function):
162
- @staticmethod
163
- def forward(
164
- ctx,
165
- qkv,
166
- cos,
167
- sin,
168
- cos_k=None,
169
- sin_k=None,
170
- interleaved=False,
171
- seqlen_offsets: Union[int, torch.Tensor] = 0,
172
- cu_seqlens: Optional[torch.Tensor] = None,
173
- max_seqlen: Optional[int] = None,
174
- use_flash_attn: bool = True,
175
- ):
176
- # batch, seqlen, three, nheads, headdim = qkv.shape
177
- assert qkv.shape[-3] == 3
178
- if cos_k is None and sin_k is None and qkv.is_contiguous():
179
-
180
- if use_flash_attn:
181
- # Call 1 kernel instead of 2 kernels
182
- # We need qkv to be contiguous so that when we reshape to combine (3, nheads)
183
- # dimensions, we get the same tensor
184
- qk = rearrange(qkv[..., :2, :, :], "... t h d -> ... (t h) d")
185
- # qk = qkv[:, :, :2].reshape(batch, seqlen, -1, headdim)
186
- apply_rotary(
187
- qk,
188
- cos,
189
- sin,
190
- seqlen_offsets=seqlen_offsets,
191
- interleaved=interleaved,
192
- inplace=True,
193
- cu_seqlens=cu_seqlens,
194
- max_seqlen=max_seqlen,
195
- )
196
- else:
197
- q_rot = apply_rotary_emb_torch(
198
- qkv[:, :, 0],
199
- cos,
200
- sin,
201
- interleaved=interleaved,
202
- )
203
- k_rot = apply_rotary_emb_torch(
204
- qkv[:, :, 1],
205
- cos,
206
- sin,
207
- interleaved=interleaved,
208
- )
209
- qkv = torch.stack((q_rot, k_rot, qkv[:, :, 2]), dim=2)
210
- else:
211
- cos_k = cos if cos_k is None else cos_k
212
- sin_k = sin if sin_k is None else sin_k
213
- q, k = qkv[..., 0, :, :], qkv[..., 1, :, :]
214
- apply_rotary(
215
- q,
216
- cos,
217
- sin,
218
- seqlen_offsets,
219
- interleaved=interleaved,
220
- inplace=True,
221
- cu_seqlens=cu_seqlens,
222
- max_seqlen=max_seqlen,
223
- )
224
- apply_rotary(
225
- k,
226
- cos_k,
227
- sin_k,
228
- seqlen_offsets,
229
- interleaved=interleaved,
230
- inplace=True,
231
- cu_seqlens=cu_seqlens,
232
- max_seqlen=max_seqlen,
233
- )
234
- ctx.save_for_backward(cos, sin, cos_k, sin_k)
235
- if isinstance(seqlen_offsets, int):
236
- ctx.save_for_backward(cos, sin, cos_k, sin_k, cu_seqlens)
237
- ctx.seqlen_offsets = seqlen_offsets
238
- else:
239
- ctx.save_for_backward(cos, sin, cos_k, sin_k, cu_seqlens, seqlen_offsets)
240
- ctx.seqlen_offsets = None
241
- ctx.max_seqlen = max_seqlen
242
- ctx.interleaved = interleaved
243
- return qkv
244
-
245
- @staticmethod
246
- def backward(ctx, dqkv):
247
- seqlen_offsets = ctx.seqlen_offsets
248
- if seqlen_offsets is None:
249
- cos, sin, cos_k, sin_k, cu_seqlens, seqlen_offsets = ctx.saved_tensors
250
- else:
251
- cos, sin, cos_k, sin_k, cu_seqlens = ctx.saved_tensors
252
- if cos_k is None and sin_k is None and dqkv.is_contiguous():
253
- # Call 1 kernel instead of 2 kernels
254
- # We need dqkv to be contiguous so that when we reshape to combine (3, nheads)
255
- # dimensions, we get the same tensor
256
- dqk = rearrange(dqkv[..., :2, :, :], "... t h d -> ... (t h) d")
257
- apply_rotary(
258
- dqk,
259
- cos,
260
- sin,
261
- seqlen_offsets=seqlen_offsets,
262
- interleaved=ctx.interleaved,
263
- inplace=True,
264
- conjugate=True,
265
- cu_seqlens=cu_seqlens,
266
- max_seqlen=ctx.max_seqlen,
267
- )
268
- else:
269
- cos_k = cos if cos_k is None else cos_k
270
- sin_k = sin if sin_k is None else sin_k
271
- dq, dk = dqkv[..., 0, :, :], dqkv[..., 1, :, :]
272
- apply_rotary(
273
- dq,
274
- cos,
275
- sin,
276
- seqlen_offsets,
277
- interleaved=ctx.interleaved,
278
- inplace=True,
279
- conjugate=True,
280
- cu_seqlens=cu_seqlens,
281
- max_seqlen=ctx.max_seqlen,
282
- )
283
- apply_rotary(
284
- dk,
285
- cos_k,
286
- sin_k,
287
- seqlen_offsets,
288
- interleaved=ctx.interleaved,
289
- inplace=True,
290
- conjugate=True,
291
- cu_seqlens=cu_seqlens,
292
- max_seqlen=ctx.max_seqlen,
293
- )
294
- return dqkv, None, None, None, None, None, None, None, None, None
295
-
296
-
297
- def apply_rotary_emb_qkv_(
298
- qkv,
299
- cos,
300
- sin,
301
- cos_k=None,
302
- sin_k=None,
303
- interleaved=False,
304
- seqlen_offsets: Union[int, torch.Tensor] = 0,
305
- cu_seqlens: Optional[torch.Tensor] = None,
306
- max_seqlen: Optional[int] = None,
307
- use_flash_attn=True,
308
- ):
309
- """
310
- Arguments:
311
- qkv: (batch_size, seqlen, 3, nheads, headdim) if cu_seqlens is None
312
- else (total_seqlen, 3, nheads, headdim)
313
- cos, sin: (seqlen, rotary_dim / 2)
314
- cos_k, sin_k: (seqlen, rotary_dim / 2), optional
315
- interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
316
- 1st half and 2nd half (GPT-NeoX style).
317
- seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount.
318
- Most commonly used in inference when we have KV cache.
319
- cu_seqlens: (batch + 1,) or None
320
- max_seqlen: int
321
- Return:
322
- qkv: (batch_size, seqlen, 3, nheads, headdim) if cu_seqlens is None
323
- else (total_seqlen, 3, nheads, headdim)
324
- rotary_dim must be <= headdim
325
- Apply rotary embedding *inplace* to the first rotary_dim of Q and K.
326
- """
327
- return ApplyRotaryEmbQKV_.apply(
328
- qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets, cu_seqlens, max_seqlen, use_flash_attn,
329
- )
330
-
331
-
332
- class ApplyRotaryEmbKV_(torch.autograd.Function):
333
- @staticmethod
334
- def forward(
335
- ctx,
336
- kv,
337
- cos,
338
- sin,
339
- interleaved=False,
340
- seqlen_offsets: Union[int, torch.Tensor] = 0,
341
- cu_seqlens: Optional[torch.Tensor] = None,
342
- max_seqlen: Optional[int] = None,
343
- ):
344
- # batch, seqlen, two, nheads, headdim = kv.shape
345
- assert kv.shape[-3] == 2
346
- k = kv[..., 0, :, :]
347
- apply_rotary(
348
- k,
349
- cos,
350
- sin,
351
- seqlen_offsets=seqlen_offsets,
352
- interleaved=interleaved,
353
- inplace=True,
354
- cu_seqlens=cu_seqlens,
355
- max_seqlen=max_seqlen,
356
- )
357
- if isinstance(seqlen_offsets, int):
358
- ctx.save_for_backward(
359
- cos, sin, cu_seqlens
360
- ) # Can't save int with save_for_backward
361
- ctx.seqlen_offsets = seqlen_offsets
362
- else:
363
- ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)
364
- ctx.seqlen_offsets = None
365
- ctx.max_seqlen = max_seqlen
366
- ctx.interleaved = interleaved
367
- return kv
368
-
369
- @staticmethod
370
- def backward(ctx, dkv):
371
- seqlen_offsets = ctx.seqlen_offsets
372
- if seqlen_offsets is None:
373
- cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
374
- else:
375
- cos, sin, cu_seqlens = ctx.saved_tensors
376
- apply_rotary(
377
- dkv[..., 0, :, :],
378
- cos,
379
- sin,
380
- seqlen_offsets=seqlen_offsets,
381
- interleaved=ctx.interleaved,
382
- inplace=True,
383
- conjugate=True,
384
- cu_seqlens=cu_seqlens,
385
- max_seqlen=ctx.max_seqlen,
386
- )
387
- return dkv, None, None, None, None, None, None
388
-
389
-
390
- apply_rotary_emb_kv_ = ApplyRotaryEmbKV_.apply
391
-
392
-
393
- def apply_rotary_emb_kv_(
394
- kv,
395
- cos,
396
- sin,
397
- interleaved=False,
398
- seqlen_offsets: Union[int, torch.Tensor] = 0,
399
- cu_seqlens: Optional[torch.Tensor] = None,
400
- max_seqlen: Optional[int] = None,
401
- ):
402
- """
403
- Arguments:
404
- kv: (batch_size, seqlen, 2, nheads, headdim) if cu_seqlens is None
405
- else (total_seqlen, 2, nheads, headdim)
406
- cos, sin: (seqlen, rotary_dim / 2)
407
- interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
408
- 1st half and 2nd half (GPT-NeoX style).
409
- seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount.
410
- Most commonly used in inference when we have KV cache.
411
- cu_seqlens: (batch + 1,) or None
412
- max_seqlen: int
413
- Return:
414
- kv: (batch_size, seqlen, 2, nheads, headdim) if cu_seqlens is None
415
- else (total_seqlen, 2, nheads, headdim)
416
- rotary_dim must be <= headdim
417
- Apply rotary embedding *inplace* to the first rotary_dim of K.
418
- """
419
- return ApplyRotaryEmbKV_.apply(
420
- kv, cos, sin, interleaved, seqlen_offsets, cu_seqlens, max_seqlen
421
- )
422
-
423
-
424
- class RotaryEmbedding(torch.nn.Module):
425
- """
426
- The rotary position embeddings from RoFormer_ (Su et. al).
427
- A crucial insight from the method is that the query and keys are
428
- transformed by rotation matrices which depend on the relative positions.
429
-
430
- Other implementations are available in the Rotary Transformer repo_ and in
431
- GPT-NeoX_, GPT-NeoX was an inspiration
432
-
433
- .. _RoFormer: https://arxiv.org/abs/2104.09864
434
- .. _repo: https://github.com/ZhuiyiTechnology/roformer
435
- .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
436
-
437
- If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
438
- A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
439
- Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
440
- """
441
-
442
- def __init__(
443
- self,
444
- dim: int,
445
- base=10000.0,
446
- interleaved=False,
447
- scale_base=None,
448
- pos_idx_in_fp32=True,
449
- device=None,
450
- use_flash_attn=True,
451
- ):
452
- """
453
- interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
454
- of 1st half and 2nd half (GPT-NeoX style).
455
- pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
456
- otherwise they might be in lower precision.
457
- This option was added because previously (before 2023-07-02), when we construct
458
- the position indices, we use the dtype of self.inv_freq. In most cases this would
459
- be fp32, but if the model is trained in pure bf16 (not mixed precision), then
460
- self.inv_freq would be bf16, and the position indices are also in bf16.
461
- Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
462
- embeddings for some positions will coincide.
463
- To maintain compatibility with models previously trained in pure bf16,
464
- we add this option.
465
- """
466
- super().__init__()
467
- self.dim = dim
468
- self._base = float(base)
469
- self.pos_idx_in_fp32 = pos_idx_in_fp32
470
- self.use_flash_attn = use_flash_attn
471
- # Generate and save the inverse frequency buffer (non trainable)
472
- inv_freq = self._compute_inv_freq(device)
473
- self.register_buffer("inv_freq", inv_freq, persistent=False)
474
- self.interleaved = interleaved
475
- self.scale_base = scale_base
476
- scale = (
477
- (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim)
478
- / (1.4 * dim)
479
- if scale_base is not None
480
- else None
481
- )
482
- self.register_buffer("scale", scale, persistent=False)
483
-
484
- self._seq_len_cached = 0
485
- self._cos_cached = None
486
- self._sin_cached = None
487
- self._cos_k_cached = None
488
- self._sin_k_cached = None
489
-
490
- @property
491
- def base(self):
492
- return self._base
493
-
494
- @base.setter
495
- def base(self, new_base):
496
- new_base = float(new_base)
497
- if new_base > 0:
498
- if self._base != new_base: # only update if the base value has changed
499
- self._base = new_base
500
- self._update_cos_sin_cache(
501
- self._seq_len_cached,
502
- device=self.inv_freq.device,
503
- dtype=self._cos_cached.dtype if self._cos_cached is not None else None,
504
- rotary_base_changed=True,
505
- )
506
- else:
507
- raise ValueError("Rotary base value must be positive")
508
-
509
- def _compute_inv_freq(self, device=None):
510
- return 1.0 / (
511
- self.base
512
- ** (
513
- torch.arange(0, self.dim, 2, device=device, dtype=torch.float32)
514
- / self.dim
515
- )
516
- )
517
-
518
- def _update_cos_sin_cache(
519
- self, seqlen, device=None, dtype=None, rotary_base_changed=False
520
- ):
521
- # Reset the tables if the sequence length has changed,
522
- # if we're on a new device (possibly due to tracing for instance),
523
- # or if we're switching from inference mode to training
524
- # or if the rotary base value was changed
525
- if (
526
- seqlen > self._seq_len_cached
527
- or self._cos_cached is None
528
- or self._cos_cached.device != device
529
- or self._cos_cached.dtype != dtype
530
- or (self.training and self._cos_cached.is_inference())
531
- or rotary_base_changed
532
- ):
533
- self._seq_len_cached = seqlen
534
- # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
535
- # And the output of arange can be quite large, so bf16 would lose a lot of precision.
536
- # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
537
- if rotary_base_changed:
538
- self.inv_freq = self._compute_inv_freq(device=device)
539
- if self.pos_idx_in_fp32:
540
- t = torch.arange(seqlen, device=device, dtype=torch.float32)
541
- # We want fp32 here as well since inv_freq will be multiplied with t, and the output
542
- # will be large. Having it in bf16 will lose a lot of precision and cause the
543
- # cos & sin output to change significantly.
544
- # We want to recompute self.inv_freq if it was not loaded in fp32
545
- if self.inv_freq.dtype != torch.float32:
546
- inv_freq = self._compute_inv_freq(device=device)
547
- else:
548
- inv_freq = self.inv_freq
549
- else:
550
- t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
551
- inv_freq = self.inv_freq
552
-
553
- # Don't do einsum, it converts fp32 to fp16 under AMP
554
- # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
555
- freqs = torch.outer(t, inv_freq)
556
- if self.scale is None:
557
- self._cos_cached = torch.cos(freqs).to(dtype)
558
- self._sin_cached = torch.sin(freqs).to(dtype)
559
- else:
560
- power = (
561
- torch.arange(
562
- seqlen, dtype=self.scale.dtype, device=self.scale.device
563
- )
564
- - seqlen // 2
565
- ) / self.scale_base
566
- scale = self.scale.to(device=power.device) ** rearrange(
567
- power, "s -> s 1"
568
- )
569
- # We want the multiplication by scale to happen in fp32
570
- self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
571
- self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
572
- self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
573
- self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
574
-
575
- def forward(
576
- self,
577
- qkv: torch.Tensor,
578
- kv: Optional[torch.Tensor] = None,
579
- seqlen_offset: Union[int, torch.Tensor] = 0,
580
- cu_seqlens: Optional[torch.Tensor] = None,
581
- max_seqlen: Optional[int] = None,
582
- ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
583
- """
584
- qkv: (batch, seqlen, 3, nheads, headdim) if kv is none,
585
- else it's just q of shape (batch, seqlen, nheads, headdim)
586
- kv: (batch, seqlen, 2, nheads, headdim)
587
- seqlen_offset: (batch_size,) or int. Each sequence in x is shifted by this amount.
588
- Most commonly used in inference when we have KV cache.
589
- If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one
590
- should pass in max_seqlen, which will update the cos / sin cache up to that length.
591
- Apply rotary embedding *inplace* to qkv and / or kv.
592
- """
593
- if cu_seqlens is not None:
594
- assert max_seqlen is not None
595
- seqlen = qkv.shape[1] if max_seqlen is None else max_seqlen
596
- if max_seqlen is not None:
597
- self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
598
- elif isinstance(seqlen_offset, int):
599
- self._update_cos_sin_cache(
600
- seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype
601
- )
602
- if kv is None:
603
- if self.scale is None:
604
- return apply_rotary_emb_qkv_(
605
- qkv,
606
- self._cos_cached,
607
- self._sin_cached,
608
- interleaved=self.interleaved,
609
- seqlen_offsets=seqlen_offset,
610
- cu_seqlens=cu_seqlens,
611
- max_seqlen=max_seqlen,
612
- use_flash_attn=self.use_flash_attn,
613
- )
614
- else:
615
- return apply_rotary_emb_qkv_(
616
- qkv,
617
- self._cos_cached,
618
- self._sin_cached,
619
- self._cos_k_cached,
620
- self._sin_k_cached,
621
- interleaved=self.interleaved,
622
- seqlen_offsets=seqlen_offset,
623
- cu_seqlens=cu_seqlens,
624
- max_seqlen=max_seqlen,
625
- use_flash_attn=self.use_flash_attn,
626
- )
627
- else:
628
- q = qkv
629
- q = apply_rotary_emb_func(
630
- q,
631
- self._cos_cached,
632
- self._sin_cached,
633
- interleaved=self.interleaved,
634
- inplace=True,
635
- seqlen_offsets=seqlen_offset,
636
- cu_seqlens=cu_seqlens,
637
- max_seqlen=max_seqlen,
638
- )
639
- if self.scale is None:
640
- kv = apply_rotary_emb_kv_(
641
- kv,
642
- self._cos_cached,
643
- self._sin_cached,
644
- interleaved=self.interleaved,
645
- seqlen_offsets=seqlen_offset,
646
- cu_seqlens=cu_seqlens,
647
- max_seqlen=max_seqlen,
648
- )
649
- else:
650
- kv = apply_rotary_emb_kv_(
651
- kv,
652
- self._cos_k_cached,
653
- self._sin_k_cached,
654
- interleaved=self.interleaved,
655
- seqlen_offsets=seqlen_offset,
656
- cu_seqlens=cu_seqlens,
657
- max_seqlen=max_seqlen,
658
- )
659
- return q, kv
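For reference, the GPT-NeoX-style rotation implemented by the deleted `apply_rotary_emb_torch` fallback can be sketched standalone. This is an illustration only (it assumes the full head dimension is rotated), not a replacement for the Triton kernel the file wrapped.

```
import torch


def rotate_half(x):
    # Non-interleaved (GPT-NeoX style): split into halves and map (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_reference(x, cos, sin):
    # x: (batch, seqlen, nheads, headdim); cos, sin: (seqlen, headdim / 2)
    seqlen = x.shape[1]
    cos = torch.cat((cos, cos), dim=-1)[None, :seqlen, None, :]
    sin = torch.cat((sin, sin), dim=-1)[None, :seqlen, None, :]
    return x * cos + rotate_half(x) * sin


# Smoke test: position 0 has zero rotation angle, so it must come back unchanged.
batch, seqlen, nheads, headdim = 2, 8, 4, 64
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, headdim, 2).float() / headdim))
freqs = torch.outer(torch.arange(seqlen).float(), inv_freq)   # (seqlen, headdim / 2)
q = torch.randn(batch, seqlen, nheads, headdim)
q_rot = apply_rotary_reference(q, freqs.cos(), freqs.sin())
assert torch.allclose(q_rot[:, 0], q[:, 0], atol=1e-6)
```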
stochastic_depth.py CHANGED
@@ -34,7 +34,7 @@
34
 
35
  import torch
36
  import torch.fx
37
- from torch import Tensor, nn
38
 
39
 
40
  def stochastic_depth(
 
34
 
35
  import torch
36
  import torch.fx
37
+ from torch import nn, Tensor
38
 
39
 
40
  def stochastic_depth(
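`stochastic_depth.py` (adapted from torchvision) randomly drops whole residual branches during training. A minimal sketch of the per-sample "row" variant, under the assumption that it mirrors torchvision's `stochastic_depth` semantics:

```
import torch


def stochastic_depth_row(x, p, training=True):
    # Drop the residual branch for a random subset of samples and rescale the
    # survivors, so the expected output is unchanged. Identity at inference time.
    if p == 0.0 or not training:
        return x
    survival = 1.0 - p
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)   # one Bernoulli draw per sample
    mask = torch.empty(shape, dtype=x.dtype, device=x.device).bernoulli_(survival)
    return x * mask / survival


# Typical use inside a residual block:
#   hidden_states = residual + stochastic_depth_row(branch_output, p=0.1, training=True)
```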
tokenizer.json ADDED
The diff for this file is too large to render; see the raw file in the repository.
 
tokenizer_config.json ADDED
@@ -0,0 +1,4 @@
1
+ {
2
+ "model_max_length": 512,
3
+ "tokenizer_class": "XLMRobertaTokenizer"
4
+ }
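The added `tokenizer_config.json` pins the tokenizer class and a 512-token maximum length. A quick check that the limit is picked up (loading from a local checkout of this repository is an assumption; any directory containing these tokenizer files behaves the same):

```
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(".")   # local checkout (hypothetical path)
print(type(tokenizer).__name__)                  # an XLM-RoBERTa tokenizer class
print(tokenizer.model_max_length)                # 512, taken from tokenizer_config.json

# With truncation enabled, over-long inputs are cut to model_max_length tokens.
encoded = tokenizer("hello " * 1000, truncation=True)
assert len(encoded["input_ids"]) <= tokenizer.model_max_length
```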
xlm_padding.py CHANGED
@@ -18,9 +18,7 @@ class IndexFirstAxis(torch.autograd.Function):
18
  # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
19
  # return input[indices]
20
  return torch.gather(
21
- rearrange(input, "b ... -> b (...)"),
22
- 0,
23
- repeat(indices, "z -> z d", d=second_dim),
24
  ).reshape(-1, *other_shape)
25
 
26
  @staticmethod
@@ -36,9 +34,7 @@ class IndexFirstAxis(torch.autograd.Function):
36
  )
37
  # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
38
  # grad_input[indices] = grad_output
39
- grad_input.scatter_(
40
- 0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output
41
- )
42
  return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
43
 
44
 
@@ -102,7 +98,7 @@ class IndexFirstAxisResidual(torch.autograd.Function):
102
  index_first_axis_residual = IndexFirstAxisResidual.apply
103
 
104
 
105
- def unpad_input(hidden_states, attention_mask, adapter_mask=None):
106
  """
107
  Arguments:
108
  hidden_states: (batch, seqlen, ...)
@@ -116,16 +112,7 @@ def unpad_input(hidden_states, attention_mask, adapter_mask=None):
116
  seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
117
  indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
118
  max_seqlen_in_batch = seqlens_in_batch.max().item()
119
- cu_seqlens = F.pad(
120
- torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)
121
- )
122
-
123
- cu_adapter_mask = (
124
- torch.repeat_interleave(adapter_mask, cu_seqlens[1:] - cu_seqlens[:-1])
125
- if adapter_mask is not None
126
- else None
127
- )
128
-
129
  # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
130
  # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
131
  # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
@@ -136,7 +123,6 @@ def unpad_input(hidden_states, attention_mask, adapter_mask=None):
136
  indices,
137
  cu_seqlens,
138
  max_seqlen_in_batch,
139
- cu_adapter_mask,
140
  )
141
 
142
 
@@ -194,18 +180,14 @@ def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_leng
194
  """
195
  length = attention_mask_in_length.sum(dim=-1)
196
  seqlen = attention_mask_in_length.size(-1)
197
- attention_mask_2d = torch.arange(
198
- seqlen, device=length.device, dtype=length.dtype
199
- ).expand(len(length), seqlen) < length.unsqueeze(1)
200
- real_indices_idx = torch.nonzero(
201
- attention_mask_in_length.flatten(), as_tuple=False
202
- ).flatten()
203
  seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx]
204
  indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten()
205
  max_seqlen_in_batch = seqlens_in_batch.max().item()
206
- cu_seqlens = F.pad(
207
- torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)
208
- )
209
  # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
210
  # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
211
  # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
@@ -233,4 +215,4 @@ def pad_input(hidden_states, indices, batch, seqlen):
233
  # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
234
  # output[indices] = hidden_states
235
  output = index_put_first_axis(hidden_states, indices, batch * seqlen)
236
- return rearrange(output, "(b s) ... -> b s ...", b=batch)
 
18
  # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
19
  # return input[indices]
20
  return torch.gather(
21
+ rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)
 
 
22
  ).reshape(-1, *other_shape)
23
 
24
  @staticmethod
 
34
  )
35
  # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
36
  # grad_input[indices] = grad_output
37
+ grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output)
 
 
38
  return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
39
 
40
 
 
98
  index_first_axis_residual = IndexFirstAxisResidual.apply
99
 
100
 
101
+ def unpad_input(hidden_states, attention_mask):
102
  """
103
  Arguments:
104
  hidden_states: (batch, seqlen, ...)
 
112
  seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
113
  indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
114
  max_seqlen_in_batch = seqlens_in_batch.max().item()
115
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
116
  # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
117
  # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
118
  # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
 
123
  indices,
124
  cu_seqlens,
125
  max_seqlen_in_batch,
 
126
  )
127
 
128
 
 
180
  """
181
  length = attention_mask_in_length.sum(dim=-1)
182
  seqlen = attention_mask_in_length.size(-1)
183
+ attention_mask_2d = torch.arange(seqlen, device=length.device, dtype=length.dtype).expand(len(length),
184
+ seqlen) < length.unsqueeze(
185
+ 1)
186
+ real_indices_idx = torch.nonzero(attention_mask_in_length.flatten(), as_tuple=False).flatten()
 
 
187
  seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx]
188
  indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten()
189
  max_seqlen_in_batch = seqlens_in_batch.max().item()
190
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
 
 
191
  # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
192
  # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
193
  # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
 
215
  # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
216
  # output[indices] = hidden_states
217
  output = index_put_first_axis(hidden_states, indices, batch * seqlen)
218
+ return rearrange(output, "(b s) ... -> b s ...", b=batch)
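`unpad_input` packs only the real (non-padded) tokens of a batch into one `(total_tokens, ...)` tensor, with `cu_seqlens` recording where each sequence starts, and `pad_input` scatters them back. A minimal round-trip sketch of the same bookkeeping with plain PyTorch ops (illustrative; the functions above use the faster gather/scatter path):

```
import torch
import torch.nn.functional as F

batch, seqlen, dim = 3, 6, 4
hidden = torch.randn(batch, seqlen, dim)
attention_mask = torch.tensor([[1, 1, 1, 0, 0, 0],
                               [1, 1, 1, 1, 1, 1],
                               [1, 1, 0, 0, 0, 0]])

# Pack: keep only positions where the mask is set.
seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)                      # (batch,)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
packed = hidden.reshape(-1, dim)[indices]                                    # (total_tokens, dim)
print(cu_seqlens.tolist())   # [0, 3, 9, 11]; sequence i is packed[cu_seqlens[i]:cu_seqlens[i + 1]]

# Unpack: scatter the packed tokens back into a zero-padded (batch, seqlen, dim) tensor.
restored = torch.zeros(batch * seqlen, dim, dtype=hidden.dtype)
restored[indices] = packed
restored = restored.reshape(batch, seqlen, dim)
assert torch.equal(restored[attention_mask == 1], packed)   # real tokens round-trip
assert restored[attention_mask == 0].abs().sum() == 0       # padding stays zero
```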