support use_flash_attn in from_pretrained
#18
by
michael-guenther
- opened
- README.md +4 -113
- block.py +4 -5
- config.json +31 -0
- configuration_xlm_roberta.py +34 -95
- embedding.py +12 -45
- mha.py +46 -143
- mlp.py +15 -58
- modeling_lora.py +95 -140
- modeling_xlm_roberta.py +217 -191
- modeling_xlm_roberta_for_glue.py +109 -0
- pytorch_model.bin +3 -0
- rotary.py +0 -659
- stochastic_depth.py +1 -1
- tokenizer.json +0 -0
- tokenizer_config.json +4 -0
- xlm_padding.py +10 -28
README.md
CHANGED
@@ -1,114 +1,5 @@
|
|
1 |
-
|
2 |
-
tags:
|
3 |
-
- transformers
|
4 |
-
- xlm-roberta
|
5 |
-
library_name: transformers
|
6 |
-
license: cc-by-nc-4.0
|
7 |
-
language:
|
8 |
-
- multilingual
|
9 |
-
- af
|
10 |
-
- am
|
11 |
-
- ar
|
12 |
-
- as
|
13 |
-
- az
|
14 |
-
- be
|
15 |
-
- bg
|
16 |
-
- bn
|
17 |
-
- br
|
18 |
-
- bs
|
19 |
-
- ca
|
20 |
-
- cs
|
21 |
-
- cy
|
22 |
-
- da
|
23 |
-
- de
|
24 |
-
- el
|
25 |
-
- en
|
26 |
-
- eo
|
27 |
-
- es
|
28 |
-
- et
|
29 |
-
- eu
|
30 |
-
- fa
|
31 |
-
- fi
|
32 |
-
- fr
|
33 |
-
- fy
|
34 |
-
- ga
|
35 |
-
- gd
|
36 |
-
- gl
|
37 |
-
- gu
|
38 |
-
- ha
|
39 |
-
- he
|
40 |
-
- hi
|
41 |
-
- hr
|
42 |
-
- hu
|
43 |
-
- hy
|
44 |
-
- id
|
45 |
-
- is
|
46 |
-
- it
|
47 |
-
- ja
|
48 |
-
- jv
|
49 |
-
- ka
|
50 |
-
- kk
|
51 |
-
- km
|
52 |
-
- kn
|
53 |
-
- ko
|
54 |
-
- ku
|
55 |
-
- ky
|
56 |
-
- la
|
57 |
-
- lo
|
58 |
-
- lt
|
59 |
-
- lv
|
60 |
-
- mg
|
61 |
-
- mk
|
62 |
-
- ml
|
63 |
-
- mn
|
64 |
-
- mr
|
65 |
-
- ms
|
66 |
-
- my
|
67 |
-
- ne
|
68 |
-
- nl
|
69 |
-
- 'no'
|
70 |
-
- om
|
71 |
-
- or
|
72 |
-
- pa
|
73 |
-
- pl
|
74 |
-
- ps
|
75 |
-
- pt
|
76 |
-
- ro
|
77 |
-
- ru
|
78 |
-
- sa
|
79 |
-
- sd
|
80 |
-
- si
|
81 |
-
- sk
|
82 |
-
- sl
|
83 |
-
- so
|
84 |
-
- sq
|
85 |
-
- sr
|
86 |
-
- su
|
87 |
-
- sv
|
88 |
-
- sw
|
89 |
-
- ta
|
90 |
-
- te
|
91 |
-
- th
|
92 |
-
- tl
|
93 |
-
- tr
|
94 |
-
- ug
|
95 |
-
- uk
|
96 |
-
- ur
|
97 |
-
- uz
|
98 |
-
- vi
|
99 |
-
- xh
|
100 |
-
- yi
|
101 |
-
- zh
|
102 |
-
---
|
103 |
-
Core implementation of Jina XLM-RoBERTa
|
104 |
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
- [jinaai/jina-embeddings-v3](https://huggingface.co/jinaai/jina-embeddings-v3)
|
110 |
-
- [jinaai/jina-colbert-v2](https://huggingface.co/jinaai/jina-colbert-v2)
|
111 |
-
|
112 |
-
### Converting weights
|
113 |
-
|
114 |
-
Weights from an [original XLMRoberta model](https://huggingface.co/FacebookAI/xlm-roberta-large) can be converted using the `convert_roberta_weights_to_flash.py` script in the model repository.
|
|
|
1 |
+
# Converting Weights
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
+
```
|
4 |
+
python3 -m "xlm-roberta-flash-implementation".convert_roberta_weights_to_flash --output pytorch_model_xlmr_flash.bin
|
5 |
+
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
block.py
CHANGED
@@ -8,14 +8,15 @@ from typing import Optional
|
|
8 |
|
9 |
import torch
|
10 |
import torch.nn as nn
|
|
|
11 |
from torch import Tensor
|
12 |
|
|
|
13 |
from .mha import MHA
|
14 |
from .mlp import Mlp
|
15 |
-
from .stochastic_depth import StochasticDepth
|
16 |
|
17 |
try:
|
18 |
-
from flash_attn.ops.triton.layer_norm import
|
19 |
except ImportError:
|
20 |
layer_norm_fn, RMSNorm = None, None
|
21 |
|
@@ -232,9 +233,7 @@ class Block(nn.Module):
|
|
232 |
is_rms_norm=isinstance(self.norm1, RMSNorm),
|
233 |
)
|
234 |
if not isinstance(self.mlp, nn.Identity):
|
235 |
-
mlp_out = self.mlp(
|
236 |
-
hidden_states, adapter_mask=mixer_kwargs.get("adapter_mask")
|
237 |
-
)
|
238 |
if self.return_residual: # mlp out is actually a pair here
|
239 |
mlp_out, hidden_states = mlp_out
|
240 |
if not self.fused_dropout_add_ln:
|
|
|
8 |
|
9 |
import torch
|
10 |
import torch.nn as nn
|
11 |
+
import torch.nn.functional as F
|
12 |
from torch import Tensor
|
13 |
|
14 |
+
from .stochastic_depth import StochasticDepth
|
15 |
from .mha import MHA
|
16 |
from .mlp import Mlp
|
|
|
17 |
|
18 |
try:
|
19 |
+
from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm
|
20 |
except ImportError:
|
21 |
layer_norm_fn, RMSNorm = None, None
|
22 |
|
|
|
233 |
is_rms_norm=isinstance(self.norm1, RMSNorm),
|
234 |
)
|
235 |
if not isinstance(self.mlp, nn.Identity):
|
236 |
+
mlp_out = self.mlp(hidden_states)
|
|
|
|
|
237 |
if self.return_residual: # mlp out is actually a pair here
|
238 |
mlp_out, hidden_states = mlp_out
|
239 |
if not self.fused_dropout_add_ln:
|
config.json
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"auto_map": {
|
3 |
+
"AutoConfig": "configuration_xlm_roberta.XLMRobertaFlashConfig",
|
4 |
+
"AutoModel": "modeling_xlm_roberta.XLMRobertaModel",
|
5 |
+
"AutoModelForPreTraining": "modeling_xlm_roberta.XLMRobertaForPreTraining",
|
6 |
+
"AutoModelForMaskedLM": "modeling_xlm_roberta.XLMRobertaForMaskedLM",
|
7 |
+
"AutoModelForSequenceClassification":"modeling_xlm_roberta.XLMRobertaForSequenceClassification"
|
8 |
+
},
|
9 |
+
"architectures": [
|
10 |
+
"XLMRobertaModel"
|
11 |
+
],
|
12 |
+
"attention_probs_dropout_prob": 0.1,
|
13 |
+
"bos_token_id": 0,
|
14 |
+
"eos_token_id": 2,
|
15 |
+
"hidden_act": "gelu",
|
16 |
+
"hidden_dropout_prob": 0.1,
|
17 |
+
"hidden_size": 768,
|
18 |
+
"initializer_range": 0.02,
|
19 |
+
"intermediate_size": 3072,
|
20 |
+
"layer_norm_eps": 1e-05,
|
21 |
+
"max_position_embeddings": 8194,
|
22 |
+
"num_attention_heads": 12,
|
23 |
+
"num_hidden_layers": 12,
|
24 |
+
"output_past": true,
|
25 |
+
"pad_token_id": 1,
|
26 |
+
"position_embedding_type": "absolute",
|
27 |
+
"transformers_version": "4.17.0.dev0",
|
28 |
+
"type_vocab_size": 1,
|
29 |
+
"use_cache": false,
|
30 |
+
"vocab_size": 250002
|
31 |
+
}
|
configuration_xlm_roberta.py
CHANGED
@@ -1,94 +1,42 @@
|
|
1 |
-
from typing import Any, Dict, List, Optional, Union
|
2 |
-
|
3 |
-
import torch
|
4 |
from transformers import PretrainedConfig
|
5 |
-
|
6 |
|
7 |
class XLMRobertaFlashConfig(PretrainedConfig):
|
8 |
-
|
9 |
-
model_type = "xlm-roberta"
|
10 |
-
|
11 |
def __init__(
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
matryoshka_dimensions: Optional[List[int]] = None,
|
44 |
-
truncate_dim: Optional[int] = None,
|
45 |
-
**kwargs: Dict[str, Any],
|
46 |
):
|
47 |
-
|
48 |
-
Initialize the XLMRobertaFlashConfig configuration.
|
49 |
|
50 |
-
Args:
|
51 |
-
vocab_size (int): Size of the vocabulary.
|
52 |
-
hidden_size (int): Dimensionality of the encoder layers and the pooler layer.
|
53 |
-
num_hidden_layers (int): Number of hidden layers in the Transformer encoder.
|
54 |
-
num_attention_heads (int): Number of attention heads for each attention layer in the Transformer encoder.
|
55 |
-
intermediate_size (int): Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer.
|
56 |
-
hidden_act (str): The activation function to use.
|
57 |
-
hidden_dropout_prob (float): The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
58 |
-
attention_probs_dropout_prob (float): The dropout ratio for the attention probabilities.
|
59 |
-
max_position_embeddings (int): The maximum length of the position embeddings.
|
60 |
-
type_vocab_size (int): The vocabulary size of the token type ids.
|
61 |
-
initializer_range (float): The standard deviation for initializing all weight matrices.
|
62 |
-
layer_norm_eps (float): The epsilon used by the layer normalization layers.
|
63 |
-
pad_token_id (int): The ID of the padding token.
|
64 |
-
bos_token_id (int): The ID of the beginning-of-sequence token.
|
65 |
-
eos_token_id (int): The ID of the end-of-sequence token.
|
66 |
-
position_embedding_type (str): Type of position embeddings. Options are 'absolute', 'alibi', or 'rotary'.
|
67 |
-
rotary_emb_base (float): Base for rotary embeddings.
|
68 |
-
use_cache (bool): Whether or not the model should return the last key/values attentions (not used by all models).
|
69 |
-
use_reentrant (bool): Whether or not the model should enable the 'use_reentrant' flag in gradient checkpointing.
|
70 |
-
classifier_dropout (Optional[float]): The dropout ratio for the classification head.
|
71 |
-
lora_adaptations (Optional[List[str]]): LoRA adaptations configuration.
|
72 |
-
lora_prompts (Optional[Dict[str, str]]): LoRA prompts configuration.
|
73 |
-
lora_rank (int): Rank for LoRA adaptations.
|
74 |
-
lora_dropout_p (float): Dropout probability for LoRA adaptations.
|
75 |
-
lora_alpha (int): Alpha parameter for LoRA.
|
76 |
-
lora_main_params_trainable (bool): Whether to make the main model parameters trainable when using LoRA.
|
77 |
-
load_trained_adapters (bool): Whether to load trained adapters.
|
78 |
-
use_flash_attn (bool): Whether to use FlashAttention.
|
79 |
-
torch_dtype (Optional[Union[str, torch.dtype]]): Data type for the tensors.
|
80 |
-
emb_pooler (Optional[str]): Pooling layer configuration.
|
81 |
-
matryoshka_dimensions (Optional[List[int]]): Configuration for matryoshka dimension reduction.
|
82 |
-
truncate_dim (Optional[int]): Dimension to truncate embeddings to, if any.
|
83 |
-
**kwargs (Dict[str, Any]): Additional keyword arguments passed to the configuration.
|
84 |
-
"""
|
85 |
-
|
86 |
-
super().__init__(
|
87 |
-
pad_token_id=pad_token_id,
|
88 |
-
bos_token_id=bos_token_id,
|
89 |
-
eos_token_id=eos_token_id,
|
90 |
-
**kwargs,
|
91 |
-
)
|
92 |
|
93 |
self.vocab_size = vocab_size
|
94 |
self.hidden_size = hidden_size
|
@@ -103,13 +51,10 @@ class XLMRobertaFlashConfig(PretrainedConfig):
|
|
103 |
self.initializer_range = initializer_range
|
104 |
self.layer_norm_eps = layer_norm_eps
|
105 |
self.position_embedding_type = position_embedding_type
|
106 |
-
self.rotary_emb_base = rotary_emb_base
|
107 |
self.use_cache = use_cache
|
108 |
-
self.use_reentrant = use_reentrant
|
109 |
self.classifier_dropout = classifier_dropout
|
110 |
self.load_trained_adapters = load_trained_adapters
|
111 |
self.lora_adaptations = lora_adaptations
|
112 |
-
self.task_instructions = task_instructions
|
113 |
self.lora_rank = lora_rank
|
114 |
self.lora_dropout_p = lora_dropout_p
|
115 |
self.lora_alpha = lora_alpha
|
@@ -118,13 +63,7 @@ class XLMRobertaFlashConfig(PretrainedConfig):
|
|
118 |
self.emb_pooler = emb_pooler
|
119 |
self.matryoshka_dimensions = matryoshka_dimensions
|
120 |
self.truncate_dim = truncate_dim
|
121 |
-
if (
|
122 |
-
torch_dtype
|
123 |
-
and hasattr(torch, torch_dtype)
|
124 |
-
and type(getattr(torch, torch_dtype)) is torch.dtype
|
125 |
-
):
|
126 |
self.torch_dtype = getattr(torch, torch_dtype)
|
127 |
else:
|
128 |
self.torch_dtype = torch_dtype
|
129 |
-
if not self.use_flash_attn or not torch.cuda.is_available():
|
130 |
-
self.torch_dtype = torch.float32
|
|
|
|
|
|
|
|
|
1 |
from transformers import PretrainedConfig
|
2 |
+
import torch
|
3 |
|
4 |
class XLMRobertaFlashConfig(PretrainedConfig):
|
|
|
|
|
|
|
5 |
def __init__(
|
6 |
+
self,
|
7 |
+
vocab_size=30522,
|
8 |
+
hidden_size=768,
|
9 |
+
num_hidden_layers=12,
|
10 |
+
num_attention_heads=12,
|
11 |
+
intermediate_size=3072,
|
12 |
+
hidden_act="gelu",
|
13 |
+
hidden_dropout_prob=0.1,
|
14 |
+
attention_probs_dropout_prob=0.1,
|
15 |
+
max_position_embeddings=512,
|
16 |
+
type_vocab_size=2,
|
17 |
+
initializer_range=0.02,
|
18 |
+
layer_norm_eps=1e-12,
|
19 |
+
pad_token_id=1,
|
20 |
+
bos_token_id=0,
|
21 |
+
eos_token_id=2,
|
22 |
+
position_embedding_type="absolute",
|
23 |
+
use_cache=True,
|
24 |
+
classifier_dropout=None,
|
25 |
+
lora_adaptations=None,
|
26 |
+
lora_rank=4,
|
27 |
+
lora_dropout_p=0.0,
|
28 |
+
lora_alpha=1,
|
29 |
+
lora_main_params_trainable=False,
|
30 |
+
load_trained_adapters=False,
|
31 |
+
use_flash_attn=True,
|
32 |
+
torch_dtype=None,
|
33 |
+
emb_pooler=None,
|
34 |
+
matryoshka_dimensions=None,
|
35 |
+
truncate_dim=None,
|
36 |
+
**kwargs,
|
|
|
|
|
|
|
37 |
):
|
38 |
+
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
|
|
39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
|
41 |
self.vocab_size = vocab_size
|
42 |
self.hidden_size = hidden_size
|
|
|
51 |
self.initializer_range = initializer_range
|
52 |
self.layer_norm_eps = layer_norm_eps
|
53 |
self.position_embedding_type = position_embedding_type
|
|
|
54 |
self.use_cache = use_cache
|
|
|
55 |
self.classifier_dropout = classifier_dropout
|
56 |
self.load_trained_adapters = load_trained_adapters
|
57 |
self.lora_adaptations = lora_adaptations
|
|
|
58 |
self.lora_rank = lora_rank
|
59 |
self.lora_dropout_p = lora_dropout_p
|
60 |
self.lora_alpha = lora_alpha
|
|
|
63 |
self.emb_pooler = emb_pooler
|
64 |
self.matryoshka_dimensions = matryoshka_dimensions
|
65 |
self.truncate_dim = truncate_dim
|
66 |
+
if torch_dtype and hasattr(torch, torch_dtype) and type(getattr(torch, torch_dtype)) is torch.dtype:
|
|
|
|
|
|
|
|
|
67 |
self.torch_dtype = getattr(torch, torch_dtype)
|
68 |
else:
|
69 |
self.torch_dtype = torch_dtype
|
|
|
|
embedding.py
CHANGED
@@ -5,8 +5,10 @@
|
|
5 |
|
6 |
import torch
|
7 |
import torch.nn as nn
|
8 |
-
from
|
9 |
-
|
|
|
|
|
10 |
|
11 |
|
12 |
class XLMRobertaEmbeddings(nn.Module):
|
@@ -36,60 +38,25 @@ class XLMRobertaEmbeddings(nn.Module):
|
|
36 |
max_position_embeddings, embed_dim, **factory_kwargs
|
37 |
)
|
38 |
if self.type_vocab_size > 0:
|
39 |
-
self.token_type_embeddings = nn.Embedding(
|
40 |
-
type_vocab_size, embed_dim, **factory_kwargs
|
41 |
-
)
|
42 |
|
43 |
-
def forward(
|
44 |
-
self, input_ids, position_ids=None, token_type_ids=None, adapter_mask=None
|
45 |
-
):
|
46 |
"""
|
47 |
input_ids: (batch, seqlen)
|
48 |
position_ids: (batch, seqlen)
|
49 |
token_type_ids: (batch, seqlen)
|
50 |
-
adapter_mask: (batch, 1)
|
51 |
"""
|
52 |
batch_size, seqlen = input_ids.shape
|
53 |
-
|
54 |
-
unique_tasks = torch.unique(adapter_mask)
|
55 |
-
embedding_dtype = next(self.word_embeddings.parameters()).dtype
|
56 |
-
embeddings = torch.empty(
|
57 |
-
*input_ids.shape,
|
58 |
-
self.word_embeddings.embedding_dim,
|
59 |
-
dtype=embedding_dtype,
|
60 |
-
device=input_ids.device
|
61 |
-
)
|
62 |
-
for task_id in unique_tasks:
|
63 |
-
task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
|
64 |
-
task_input_ids = input_ids[task_indices]
|
65 |
-
task_embeddings = self.word_embeddings(task_input_ids, task_id=task_id)
|
66 |
-
embeddings[task_indices] = task_embeddings
|
67 |
-
else:
|
68 |
-
embeddings = self.word_embeddings(input_ids)
|
69 |
if self.max_position_embeddings > 0:
|
70 |
if position_ids is None:
|
71 |
-
position_ids =
|
72 |
-
|
73 |
-
).to(input_ids.device)
|
74 |
position_embeddings = self.position_embeddings(position_ids)
|
75 |
embeddings = embeddings + position_embeddings
|
76 |
if self.type_vocab_size > 0:
|
77 |
if token_type_ids is None:
|
78 |
-
token_type_ids = torch.zeros(
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
if adapter_mask is not None:
|
83 |
-
unique_tasks = torch.unique(adapter_mask)
|
84 |
-
for task_id in unique_tasks:
|
85 |
-
task_token_type_embeddings = self.token_type_embeddings(
|
86 |
-
token_type_ids, task_id=task_id
|
87 |
-
)
|
88 |
-
task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
|
89 |
-
embeddings[task_indices] = (
|
90 |
-
embeddings[task_indices] + task_token_type_embeddings
|
91 |
-
)
|
92 |
-
else:
|
93 |
-
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
94 |
-
embeddings = embeddings + token_type_embeddings
|
95 |
return embeddings
|
|
|
5 |
|
6 |
import torch
|
7 |
import torch.nn as nn
|
8 |
+
from einops import rearrange
|
9 |
+
from torch import Tensor
|
10 |
+
|
11 |
+
from transformers.models.xlm_roberta.modeling_xlm_roberta import create_position_ids_from_input_ids
|
12 |
|
13 |
|
14 |
class XLMRobertaEmbeddings(nn.Module):
|
|
|
38 |
max_position_embeddings, embed_dim, **factory_kwargs
|
39 |
)
|
40 |
if self.type_vocab_size > 0:
|
41 |
+
self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim, **factory_kwargs)
|
|
|
|
|
42 |
|
43 |
+
def forward(self, input_ids, position_ids=None, token_type_ids=None):
|
|
|
|
|
44 |
"""
|
45 |
input_ids: (batch, seqlen)
|
46 |
position_ids: (batch, seqlen)
|
47 |
token_type_ids: (batch, seqlen)
|
|
|
48 |
"""
|
49 |
batch_size, seqlen = input_ids.shape
|
50 |
+
embeddings = self.word_embeddings(input_ids)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
if self.max_position_embeddings > 0:
|
52 |
if position_ids is None:
|
53 |
+
position_ids =create_position_ids_from_input_ids(input_ids, padding_idx=self.word_embeddings.padding_idx).to(input_ids.device)
|
54 |
+
# position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
|
|
|
55 |
position_embeddings = self.position_embeddings(position_ids)
|
56 |
embeddings = embeddings + position_embeddings
|
57 |
if self.type_vocab_size > 0:
|
58 |
if token_type_ids is None:
|
59 |
+
token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=input_ids.device)
|
60 |
+
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
61 |
+
embeddings = embeddings + token_type_embeddings
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
return embeddings
|
mha.py
CHANGED
@@ -1,6 +1,5 @@
|
|
1 |
# This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py
|
2 |
# Commit id: 6bbc532388e61185a92e2a563126739967b4c8c5
|
3 |
-
# Rotary varlen support from https://github.com/Dao-AILab/flash-attention/pull/556
|
4 |
|
5 |
# Copyright (c) 2023, Tri Dao.
|
6 |
|
@@ -12,23 +11,27 @@ import torch.nn as nn
|
|
12 |
from einops import rearrange, repeat
|
13 |
|
14 |
try:
|
15 |
-
from flash_attn import (
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
|
|
|
|
20 |
except ImportError:
|
21 |
flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None
|
22 |
flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None
|
23 |
flash_attn_with_kvcache = None
|
24 |
|
25 |
try:
|
26 |
-
from flash_attn.ops.fused_dense import
|
27 |
-
RowParallelLinear)
|
28 |
except ImportError:
|
29 |
FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None
|
30 |
|
31 |
-
|
|
|
|
|
|
|
32 |
|
33 |
|
34 |
# From https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742
|
@@ -44,9 +47,7 @@ def get_alibi_slopes(nheads):
|
|
44 |
closest_power_of_2 = 2 ** math.floor(math.log2(nheads))
|
45 |
return (
|
46 |
get_slopes_power_of_2(closest_power_of_2)
|
47 |
-
+ get_alibi_slopes(2 * closest_power_of_2)[0::2][
|
48 |
-
: nheads - closest_power_of_2
|
49 |
-
]
|
50 |
)
|
51 |
|
52 |
|
@@ -71,9 +72,7 @@ class FlashSelfAttention(nn.Module):
|
|
71 |
deterministic=False,
|
72 |
):
|
73 |
super().__init__()
|
74 |
-
assert
|
75 |
-
flash_attn_varlen_qkvpacked_func is not None
|
76 |
-
), "FlashAttention is not installed"
|
77 |
assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed"
|
78 |
self.causal = causal
|
79 |
self.softmax_scale = softmax_scale
|
@@ -153,9 +152,7 @@ class FlashCrossAttention(nn.Module):
|
|
153 |
deterministic=False,
|
154 |
):
|
155 |
super().__init__()
|
156 |
-
assert
|
157 |
-
flash_attn_varlen_kvpacked_func is not None
|
158 |
-
), "FlashAttention is not installed"
|
159 |
assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
|
160 |
self.causal = causal
|
161 |
self.softmax_scale = softmax_scale
|
@@ -321,10 +318,7 @@ class CrossAttention(nn.Module):
|
|
321 |
scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
|
322 |
if key_padding_mask is not None:
|
323 |
padding_mask = torch.full(
|
324 |
-
(batch_size, seqlen_k),
|
325 |
-
-10000.0,
|
326 |
-
dtype=scores.dtype,
|
327 |
-
device=scores.device,
|
328 |
)
|
329 |
padding_mask.masked_fill_(key_padding_mask, 0.0)
|
330 |
# TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
|
@@ -436,26 +430,20 @@ class MHA(nn.Module):
|
|
436 |
else:
|
437 |
alibi_slopes = None
|
438 |
if window_size != (-1, -1):
|
439 |
-
assert (
|
440 |
-
use_flash_attn
|
441 |
-
), "Local (sliding window) attention code path requires flash_attn"
|
442 |
|
443 |
self.num_heads = num_heads
|
444 |
self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
|
445 |
assert (
|
446 |
self.num_heads % self.num_heads_kv == 0
|
447 |
), "num_heads must be divisible by num_heads_kv"
|
448 |
-
assert
|
449 |
-
self.embed_dim % num_heads == 0
|
450 |
-
), "embed_dim must be divisible by num_heads"
|
451 |
self.head_dim = self.embed_dim // num_heads
|
452 |
qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
|
453 |
kv_dim = 2 * self.head_dim * self.num_heads_kv
|
454 |
|
455 |
if self.rotary_emb_dim > 0:
|
456 |
-
assert
|
457 |
-
not cross_attn
|
458 |
-
), "MHA with rotary embedding does not support cross-attention yet"
|
459 |
assert RotaryEmbedding is not None, "rotary_emb is not installed"
|
460 |
self.rotary_emb = RotaryEmbedding(
|
461 |
self.rotary_emb_dim,
|
@@ -463,41 +451,29 @@ class MHA(nn.Module):
|
|
463 |
scale_base=rotary_emb_scale_base,
|
464 |
interleaved=rotary_emb_interleaved,
|
465 |
device=device,
|
466 |
-
use_flash_attn=use_flash_attn,
|
467 |
)
|
468 |
|
469 |
if fused_bias_fc and FusedDense is None:
|
470 |
raise ImportError("fused_dense is not installed")
|
471 |
-
|
472 |
linear_cls = nn.Linear if not fused_bias_fc else FusedDense
|
473 |
linear_resid_cls = (
|
474 |
-
LinearResidual
|
475 |
-
if not fused_bias_fc
|
476 |
-
else partial(FusedDense, return_residual=True)
|
477 |
)
|
478 |
wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
|
479 |
inner_attn_cls = (
|
480 |
-
partial(
|
481 |
-
FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size
|
482 |
-
)
|
483 |
if use_flash_attn
|
484 |
else SelfAttention
|
485 |
)
|
486 |
inner_cross_attn_cls = (
|
487 |
-
partial(
|
488 |
-
FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size
|
489 |
-
)
|
490 |
if use_flash_attn
|
491 |
else CrossAttention
|
492 |
)
|
493 |
if not self.cross_attn:
|
494 |
-
self.Wqkv = wqkv_cls(
|
495 |
-
embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs
|
496 |
-
)
|
497 |
else:
|
498 |
-
self.Wq = linear_cls(
|
499 |
-
embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs
|
500 |
-
)
|
501 |
self.Wkv = wqkv_cls(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs)
|
502 |
if self.dwconv:
|
503 |
if self.num_heads_kv == self.num_heads:
|
@@ -508,9 +484,7 @@ class MHA(nn.Module):
|
|
508 |
self.dwconv_q = nn.Conv1d(
|
509 |
embed_dim, embed_dim, kernel_size=3, padding=2, groups=embed_dim
|
510 |
)
|
511 |
-
self.dwconv_kv = nn.Conv1d(
|
512 |
-
kv_dim, kv_dim, kernel_size=3, padding=2, groups=kv_dim
|
513 |
-
)
|
514 |
self.inner_attn = inner_attn_cls(
|
515 |
causal=causal,
|
516 |
softmax_scale=softmax_scale,
|
@@ -519,9 +493,7 @@ class MHA(nn.Module):
|
|
519 |
self.inner_cross_attn = inner_cross_attn_cls(
|
520 |
causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
|
521 |
)
|
522 |
-
self.out_proj = linear_cls(
|
523 |
-
embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs
|
524 |
-
)
|
525 |
|
526 |
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
|
527 |
dtype = self.out_proj.weight.dtype if dtype is None else dtype
|
@@ -539,9 +511,7 @@ class MHA(nn.Module):
|
|
539 |
def _update_kv_cache(self, kv, inference_params):
|
540 |
"""kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
|
541 |
assert not self.dwconv, "Generation does not support dwconv yet"
|
542 |
-
assert
|
543 |
-
self.layer_idx is not None
|
544 |
-
), "Generation requires layer_idx in the constructor"
|
545 |
return _update_kv_cache(kv, inference_params, self.layer_idx)
|
546 |
|
547 |
def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
|
@@ -557,10 +527,7 @@ class MHA(nn.Module):
|
|
557 |
self.rotary_emb._update_cos_sin_cache(
|
558 |
inference_params.max_seqlen, device=q.device, dtype=q.dtype
|
559 |
)
|
560 |
-
rotary_cos, rotary_sin =
|
561 |
-
self.rotary_emb._cos_cached,
|
562 |
-
self.rotary_emb._sin_cached,
|
563 |
-
)
|
564 |
else:
|
565 |
rotary_cos, rotary_sin = None, None
|
566 |
batch = q.shape[0]
|
@@ -582,9 +549,7 @@ class MHA(nn.Module):
|
|
582 |
cache_seqlens=cache_seqlens,
|
583 |
softmax_scale=self.inner_cross_attn.softmax_scale,
|
584 |
causal=self.inner_cross_attn.causal,
|
585 |
-
rotary_interleaved=
|
586 |
-
self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False
|
587 |
-
),
|
588 |
alibi_slopes=alibi_slopes,
|
589 |
)
|
590 |
return context
|
@@ -629,7 +594,6 @@ class MHA(nn.Module):
|
|
629 |
max_seqlen=None,
|
630 |
mixer_subset=None,
|
631 |
inference_params=None,
|
632 |
-
adapter_mask=None,
|
633 |
**kwargs,
|
634 |
):
|
635 |
"""
|
@@ -655,6 +619,7 @@ class MHA(nn.Module):
|
|
655 |
assert key_padding_mask is None
|
656 |
assert self.use_flash_attn
|
657 |
assert not self.dwconv
|
|
|
658 |
if key_padding_mask is not None:
|
659 |
assert cu_seqlens is None
|
660 |
assert max_seqlen is None
|
@@ -678,50 +643,19 @@ class MHA(nn.Module):
|
|
678 |
else inference_params.seqlen_offset
|
679 |
)
|
680 |
)
|
681 |
-
rotary_max_seqlen =
|
682 |
-
|
683 |
-
if inference_params is not None
|
684 |
-
else max_seqlen
|
685 |
-
)
|
686 |
if not self.cross_attn and self.num_heads_kv == self.num_heads:
|
687 |
assert x_kv is None and mixer_subset is None
|
688 |
-
|
689 |
-
|
690 |
-
unique_tasks = torch.unique(adapter_mask)
|
691 |
-
qkv_dtype = next(self.Wqkv.parameters()).dtype
|
692 |
-
qkv = torch.empty(
|
693 |
-
*x.shape[:-1],
|
694 |
-
self.Wqkv.out_features,
|
695 |
-
dtype=qkv_dtype,
|
696 |
-
device=x.device,
|
697 |
-
)
|
698 |
-
for task_id in unique_tasks:
|
699 |
-
task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
|
700 |
-
task_tensor = x[task_indices]
|
701 |
-
if not self.return_residual:
|
702 |
-
task_qkv = self.Wqkv(task_tensor, task_id=task_id)
|
703 |
-
else:
|
704 |
-
task_qkv, _ = self.Wqkv(
|
705 |
-
task_tensor, task_id=task_id, residual=True
|
706 |
-
)
|
707 |
-
qkv[task_indices] = task_qkv
|
708 |
else:
|
709 |
-
|
710 |
-
qkv = self.Wqkv(x)
|
711 |
-
else:
|
712 |
-
if hasattr(self.Wqkv, "parametrizations"):
|
713 |
-
qkv, x = self.Wqkv(x, residual=True)
|
714 |
-
else:
|
715 |
-
qkv, x = self.Wqkv(x)
|
716 |
-
|
717 |
if self.dwconv:
|
718 |
qkv = rearrange(
|
719 |
-
self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2],
|
720 |
-
"b d s -> b s d",
|
721 |
).contiguous()
|
722 |
-
qkv = rearrange(
|
723 |
-
qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim
|
724 |
-
)
|
725 |
if (
|
726 |
inference_params is None
|
727 |
or inference_params.seqlen_offset == 0
|
@@ -730,18 +664,13 @@ class MHA(nn.Module):
|
|
730 |
):
|
731 |
if self.rotary_emb_dim > 0:
|
732 |
qkv = self.rotary_emb(
|
733 |
-
qkv,
|
734 |
-
seqlen_offset=seqlen_offset,
|
735 |
-
cu_seqlens=cu_seqlens,
|
736 |
-
max_seqlen=rotary_max_seqlen,
|
737 |
)
|
738 |
if inference_params is None:
|
739 |
if not self.checkpointing:
|
740 |
context = self.inner_attn(qkv, **kwargs)
|
741 |
else:
|
742 |
-
context = torch.utils.checkpoint.checkpoint(
|
743 |
-
self.inner_attn, qkv, **kwargs
|
744 |
-
)
|
745 |
else:
|
746 |
context = self._update_kvcache_attention(
|
747 |
qkv[:, :, 0], qkv[:, :, 1:], inference_params
|
@@ -770,17 +699,13 @@ class MHA(nn.Module):
|
|
770 |
q = qkv[..., : self.num_heads * self.head_dim]
|
771 |
kv = qkv[..., self.num_heads * self.head_dim :]
|
772 |
q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
|
773 |
-
kv = rearrange(
|
774 |
-
kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim
|
775 |
-
)
|
776 |
if self.dwconv:
|
777 |
q = rearrange(
|
778 |
-
self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2],
|
779 |
-
"b d s -> b s d",
|
780 |
).contiguous()
|
781 |
kv = rearrange(
|
782 |
-
self.dwconv_kv(rearrange(kv, "b s d -> b d s"))[..., :-2],
|
783 |
-
"b d s -> b s d",
|
784 |
).contiguous()
|
785 |
if (
|
786 |
inference_params is None
|
@@ -790,11 +715,7 @@ class MHA(nn.Module):
|
|
790 |
):
|
791 |
if self.rotary_emb_dim > 0:
|
792 |
q, kv = self.rotary_emb(
|
793 |
-
q,
|
794 |
-
kv,
|
795 |
-
seqlen_offset=seqlen_offset,
|
796 |
-
cu_seqlens=cu_seqlens,
|
797 |
-
max_seqlen=rotary_max_seqlen,
|
798 |
)
|
799 |
if inference_params is None:
|
800 |
if not self.checkpointing:
|
@@ -806,25 +727,7 @@ class MHA(nn.Module):
|
|
806 |
else:
|
807 |
context = self._update_kvcache_attention(q, kv, inference_params)
|
808 |
else:
|
809 |
-
context = self._apply_rotary_update_kvcache_attention(
|
810 |
-
|
811 |
-
)
|
812 |
-
|
813 |
-
inp = rearrange(context, "... h d -> ... (h d)")
|
814 |
-
if adapter_mask is not None:
|
815 |
-
unique_tasks = torch.unique(adapter_mask)
|
816 |
-
out_dtype = next(self.out_proj.parameters()).dtype
|
817 |
-
out = torch.empty(
|
818 |
-
*inp.shape[:-1],
|
819 |
-
self.out_proj.out_features,
|
820 |
-
dtype=out_dtype,
|
821 |
-
device=inp.device,
|
822 |
-
)
|
823 |
-
for task_id in unique_tasks:
|
824 |
-
task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
|
825 |
-
task_tensor = inp[task_indices]
|
826 |
-
task_out = self.out_proj(task_tensor, task_id=task_id)
|
827 |
-
out[task_indices] = task_out
|
828 |
-
else:
|
829 |
-
out = self.out_proj(inp)
|
830 |
return out if not self.return_residual else (out, x)
|
|
|
|
1 |
# This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py
|
2 |
# Commit id: 6bbc532388e61185a92e2a563126739967b4c8c5
|
|
|
3 |
|
4 |
# Copyright (c) 2023, Tri Dao.
|
5 |
|
|
|
11 |
from einops import rearrange, repeat
|
12 |
|
13 |
try:
|
14 |
+
from flash_attn import (
|
15 |
+
flash_attn_kvpacked_func,
|
16 |
+
flash_attn_qkvpacked_func,
|
17 |
+
flash_attn_varlen_kvpacked_func,
|
18 |
+
flash_attn_varlen_qkvpacked_func,
|
19 |
+
flash_attn_with_kvcache,
|
20 |
+
)
|
21 |
except ImportError:
|
22 |
flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None
|
23 |
flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None
|
24 |
flash_attn_with_kvcache = None
|
25 |
|
26 |
try:
|
27 |
+
from flash_attn.ops.fused_dense import ColumnParallelLinear, FusedDense, RowParallelLinear
|
|
|
28 |
except ImportError:
|
29 |
FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None
|
30 |
|
31 |
+
try:
|
32 |
+
from flash_attn.layers.rotary import RotaryEmbedding
|
33 |
+
except ImportError:
|
34 |
+
RotaryEmbedding = None
|
35 |
|
36 |
|
37 |
# From https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742
|
|
|
47 |
closest_power_of_2 = 2 ** math.floor(math.log2(nheads))
|
48 |
return (
|
49 |
get_slopes_power_of_2(closest_power_of_2)
|
50 |
+
+ get_alibi_slopes(2 * closest_power_of_2)[0::2][: nheads - closest_power_of_2]
|
|
|
|
|
51 |
)
|
52 |
|
53 |
|
|
|
72 |
deterministic=False,
|
73 |
):
|
74 |
super().__init__()
|
75 |
+
assert flash_attn_varlen_qkvpacked_func is not None, "FlashAttention is not installed"
|
|
|
|
|
76 |
assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed"
|
77 |
self.causal = causal
|
78 |
self.softmax_scale = softmax_scale
|
|
|
152 |
deterministic=False,
|
153 |
):
|
154 |
super().__init__()
|
155 |
+
assert flash_attn_varlen_kvpacked_func is not None, "FlashAttention is not installed"
|
|
|
|
|
156 |
assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
|
157 |
self.causal = causal
|
158 |
self.softmax_scale = softmax_scale
|
|
|
318 |
scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
|
319 |
if key_padding_mask is not None:
|
320 |
padding_mask = torch.full(
|
321 |
+
(batch_size, seqlen_k), -10000.0, dtype=scores.dtype, device=scores.device
|
|
|
|
|
|
|
322 |
)
|
323 |
padding_mask.masked_fill_(key_padding_mask, 0.0)
|
324 |
# TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
|
|
|
430 |
else:
|
431 |
alibi_slopes = None
|
432 |
if window_size != (-1, -1):
|
433 |
+
assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"
|
|
|
|
|
434 |
|
435 |
self.num_heads = num_heads
|
436 |
self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
|
437 |
assert (
|
438 |
self.num_heads % self.num_heads_kv == 0
|
439 |
), "num_heads must be divisible by num_heads_kv"
|
440 |
+
assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
|
|
|
|
|
441 |
self.head_dim = self.embed_dim // num_heads
|
442 |
qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
|
443 |
kv_dim = 2 * self.head_dim * self.num_heads_kv
|
444 |
|
445 |
if self.rotary_emb_dim > 0:
|
446 |
+
assert not cross_attn, "MHA with rotary embedding does not support cross-attention yet"
|
|
|
|
|
447 |
assert RotaryEmbedding is not None, "rotary_emb is not installed"
|
448 |
self.rotary_emb = RotaryEmbedding(
|
449 |
self.rotary_emb_dim,
|
|
|
451 |
scale_base=rotary_emb_scale_base,
|
452 |
interleaved=rotary_emb_interleaved,
|
453 |
device=device,
|
|
|
454 |
)
|
455 |
|
456 |
if fused_bias_fc and FusedDense is None:
|
457 |
raise ImportError("fused_dense is not installed")
|
|
|
458 |
linear_cls = nn.Linear if not fused_bias_fc else FusedDense
|
459 |
linear_resid_cls = (
|
460 |
+
LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True)
|
|
|
|
|
461 |
)
|
462 |
wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
|
463 |
inner_attn_cls = (
|
464 |
+
partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size)
|
|
|
|
|
465 |
if use_flash_attn
|
466 |
else SelfAttention
|
467 |
)
|
468 |
inner_cross_attn_cls = (
|
469 |
+
partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size)
|
|
|
|
|
470 |
if use_flash_attn
|
471 |
else CrossAttention
|
472 |
)
|
473 |
if not self.cross_attn:
|
474 |
+
self.Wqkv = wqkv_cls(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs)
|
|
|
|
|
475 |
else:
|
476 |
+
self.Wq = linear_cls(embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs)
|
|
|
|
|
477 |
self.Wkv = wqkv_cls(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs)
|
478 |
if self.dwconv:
|
479 |
if self.num_heads_kv == self.num_heads:
|
|
|
484 |
self.dwconv_q = nn.Conv1d(
|
485 |
embed_dim, embed_dim, kernel_size=3, padding=2, groups=embed_dim
|
486 |
)
|
487 |
+
self.dwconv_kv = nn.Conv1d(kv_dim, kv_dim, kernel_size=3, padding=2, groups=kv_dim)
|
|
|
|
|
488 |
self.inner_attn = inner_attn_cls(
|
489 |
causal=causal,
|
490 |
softmax_scale=softmax_scale,
|
|
|
493 |
self.inner_cross_attn = inner_cross_attn_cls(
|
494 |
causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
|
495 |
)
|
496 |
+
self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs)
|
|
|
|
|
497 |
|
498 |
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
|
499 |
dtype = self.out_proj.weight.dtype if dtype is None else dtype
|
|
|
511 |
def _update_kv_cache(self, kv, inference_params):
|
512 |
"""kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
|
513 |
assert not self.dwconv, "Generation does not support dwconv yet"
|
514 |
+
assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
|
|
|
|
|
515 |
return _update_kv_cache(kv, inference_params, self.layer_idx)
|
516 |
|
517 |
def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
|
|
|
527 |
self.rotary_emb._update_cos_sin_cache(
|
528 |
inference_params.max_seqlen, device=q.device, dtype=q.dtype
|
529 |
)
|
530 |
+
rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
|
|
|
|
|
|
|
531 |
else:
|
532 |
rotary_cos, rotary_sin = None, None
|
533 |
batch = q.shape[0]
|
|
|
549 |
cache_seqlens=cache_seqlens,
|
550 |
softmax_scale=self.inner_cross_attn.softmax_scale,
|
551 |
causal=self.inner_cross_attn.causal,
|
552 |
+
rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
|
|
|
|
|
553 |
alibi_slopes=alibi_slopes,
|
554 |
)
|
555 |
return context
|
|
|
594 |
max_seqlen=None,
|
595 |
mixer_subset=None,
|
596 |
inference_params=None,
|
|
|
597 |
**kwargs,
|
598 |
):
|
599 |
"""
|
|
|
619 |
assert key_padding_mask is None
|
620 |
assert self.use_flash_attn
|
621 |
assert not self.dwconv
|
622 |
+
assert self.rotary_emb_dim == 0
|
623 |
if key_padding_mask is not None:
|
624 |
assert cu_seqlens is None
|
625 |
assert max_seqlen is None
|
|
|
643 |
else inference_params.seqlen_offset
|
644 |
)
|
645 |
)
|
646 |
+
rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
|
647 |
+
batch, seqlen = x.shape[:2]
|
|
|
|
|
|
|
648 |
if not self.cross_attn and self.num_heads_kv == self.num_heads:
|
649 |
assert x_kv is None and mixer_subset is None
|
650 |
+
if not self.return_residual:
|
651 |
+
qkv = self.Wqkv(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
652 |
else:
|
653 |
+
qkv, x = self.Wqkv(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
654 |
if self.dwconv:
|
655 |
qkv = rearrange(
|
656 |
+
self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
|
|
|
657 |
).contiguous()
|
658 |
+
qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
|
|
|
|
|
659 |
if (
|
660 |
inference_params is None
|
661 |
or inference_params.seqlen_offset == 0
|
|
|
664 |
):
|
665 |
if self.rotary_emb_dim > 0:
|
666 |
qkv = self.rotary_emb(
|
667 |
+
qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
|
|
|
|
|
|
|
668 |
)
|
669 |
if inference_params is None:
|
670 |
if not self.checkpointing:
|
671 |
context = self.inner_attn(qkv, **kwargs)
|
672 |
else:
|
673 |
+
context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
|
|
|
|
|
674 |
else:
|
675 |
context = self._update_kvcache_attention(
|
676 |
qkv[:, :, 0], qkv[:, :, 1:], inference_params
|
|
|
699 |
q = qkv[..., : self.num_heads * self.head_dim]
|
700 |
kv = qkv[..., self.num_heads * self.head_dim :]
|
701 |
q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
|
702 |
+
kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
|
|
|
|
|
703 |
if self.dwconv:
|
704 |
q = rearrange(
|
705 |
+
self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
|
|
|
706 |
).contiguous()
|
707 |
kv = rearrange(
|
708 |
+
self.dwconv_kv(rearrange(kv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
|
|
|
709 |
).contiguous()
|
710 |
if (
|
711 |
inference_params is None
|
|
|
715 |
):
|
716 |
if self.rotary_emb_dim > 0:
|
717 |
q, kv = self.rotary_emb(
|
718 |
+
q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
|
|
|
|
|
|
|
|
|
719 |
)
|
720 |
if inference_params is None:
|
721 |
if not self.checkpointing:
|
|
|
727 |
else:
|
728 |
context = self._update_kvcache_attention(q, kv, inference_params)
|
729 |
else:
|
730 |
+
context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
|
731 |
+
out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
732 |
return out if not self.return_residual else (out, x)
|
733 |
+
|
mlp.py
CHANGED
@@ -8,14 +8,14 @@ import torch.nn as nn
|
|
8 |
import torch.nn.functional as F
|
9 |
from torch.distributed import ProcessGroup
|
10 |
|
|
|
11 |
try:
|
12 |
from flash_attn.ops.activations import swiglu
|
13 |
except ImportError:
|
14 |
swiglu = None
|
15 |
|
16 |
try:
|
17 |
-
from flash_attn.ops.fused_dense import
|
18 |
-
RowParallelLinear)
|
19 |
except ImportError:
|
20 |
ColumnParallelLinear, RowParallelLinear = None, None
|
21 |
|
@@ -41,48 +41,17 @@ class Mlp(nn.Module):
|
|
41 |
factory_kwargs = {"device": device, "dtype": dtype}
|
42 |
super().__init__()
|
43 |
out_features = out_features if out_features is not None else in_features
|
44 |
-
hidden_features =
|
45 |
-
hidden_features if hidden_features is not None else in_features * 4
|
46 |
-
)
|
47 |
self.return_residual = return_residual
|
48 |
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
|
49 |
self.activation = activation
|
50 |
-
self.fc2 = nn.Linear(
|
51 |
-
hidden_features, out_features, bias=bias2, **factory_kwargs
|
52 |
-
)
|
53 |
-
|
54 |
-
def forward(self, x, adapter_mask=None):
|
55 |
-
if adapter_mask is not None:
|
56 |
-
unique_tasks = torch.unique(adapter_mask)
|
57 |
-
fc1_dtype = next(self.fc1.parameters()).dtype
|
58 |
-
y = torch.empty(
|
59 |
-
*x.shape[:-1], self.fc1.out_features, dtype=fc1_dtype, device=x.device
|
60 |
-
)
|
61 |
-
for task_id in unique_tasks:
|
62 |
-
task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
|
63 |
-
task_tensor = x[task_indices]
|
64 |
-
task_y = self.fc1(task_tensor, task_id=task_id)
|
65 |
-
y[task_indices] = task_y
|
66 |
-
else:
|
67 |
-
y = self.fc1(x)
|
68 |
|
|
|
|
|
69 |
y = self.activation(y)
|
70 |
-
|
71 |
-
|
72 |
-
unique_tasks = torch.unique(adapter_mask)
|
73 |
-
fc2_dtype = next(self.fc2.parameters()).dtype
|
74 |
-
out = torch.empty(
|
75 |
-
*y.shape[:-1], self.fc2.out_features, dtype=fc2_dtype, device=y.device
|
76 |
-
)
|
77 |
-
for task_id in unique_tasks:
|
78 |
-
task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
|
79 |
-
task_tensor = y[task_indices]
|
80 |
-
task_out = self.fc2(task_tensor, task_id=task_id)
|
81 |
-
out[task_indices] = task_out
|
82 |
-
else:
|
83 |
-
out = self.fc2(y)
|
84 |
-
|
85 |
-
return out if not self.return_residual else (out, x)
|
86 |
|
87 |
|
88 |
class ParallelMLP(nn.Module):
|
@@ -104,9 +73,7 @@ class ParallelMLP(nn.Module):
|
|
104 |
assert ColumnParallelLinear is not None, "Need to install fused_dense"
|
105 |
assert RowParallelLinear is not None, "Need to install fused_dense"
|
106 |
out_features = out_features if out_features is not None else in_features
|
107 |
-
hidden_features =
|
108 |
-
hidden_features if hidden_features is not None else in_features * 4
|
109 |
-
)
|
110 |
self.fc1 = ColumnParallelLinear(
|
111 |
in_features,
|
112 |
hidden_features,
|
@@ -152,25 +119,17 @@ class GatedMlp(nn.Module):
|
|
152 |
hidden_features = (
|
153 |
hidden_features if hidden_features is not None else int(8 * in_features / 3)
|
154 |
)
|
155 |
-
hidden_features = (
|
156 |
-
(hidden_features + multiple_of - 1) // multiple_of * multiple_of
|
157 |
-
)
|
158 |
self.return_residual = return_residual
|
159 |
-
self.fc1 = nn.Linear(
|
160 |
-
in_features, 2 * hidden_features, bias=bias1, **factory_kwargs
|
161 |
-
)
|
162 |
self.activation = activation
|
163 |
-
self.fc2 = nn.Linear(
|
164 |
-
hidden_features, out_features, bias=bias2, **factory_kwargs
|
165 |
-
)
|
166 |
|
167 |
def forward(self, x):
|
168 |
y = self.fc1(x)
|
169 |
if self.activation == F.sigmoid: # Special case for GLU
|
170 |
y = F.glu(y, dim=-1)
|
171 |
-
elif
|
172 |
-
self.activation == F.silu and swiglu is not None
|
173 |
-
): # Special case for SwiGLU
|
174 |
y, gate = y.chunk(2, dim=-1)
|
175 |
y = swiglu(gate, y)
|
176 |
else:
|
@@ -203,9 +162,7 @@ class ParallelGatedMlp(nn.Module):
|
|
203 |
hidden_features = (
|
204 |
hidden_features if hidden_features is not None else int(8 * in_features / 3)
|
205 |
)
|
206 |
-
hidden_features = (
|
207 |
-
(hidden_features + multiple_of - 1) // multiple_of * multiple_of
|
208 |
-
)
|
209 |
if ColumnParallelLinear is None or RowParallelLinear is None:
|
210 |
raise ImportError("fused_dense is not installed")
|
211 |
self.fc1 = ColumnParallelLinear(
|
@@ -234,4 +191,4 @@ class ParallelGatedMlp(nn.Module):
|
|
234 |
y, gate = y.chunk(2, dim=-1)
|
235 |
y = y * self.activation(gate)
|
236 |
y = self.fc2(y)
|
237 |
-
return y
|
|
|
8 |
import torch.nn.functional as F
|
9 |
from torch.distributed import ProcessGroup
|
10 |
|
11 |
+
|
12 |
try:
|
13 |
from flash_attn.ops.activations import swiglu
|
14 |
except ImportError:
|
15 |
swiglu = None
|
16 |
|
17 |
try:
|
18 |
+
from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
|
|
|
19 |
except ImportError:
|
20 |
ColumnParallelLinear, RowParallelLinear = None, None
|
21 |
|
|
|
41 |
factory_kwargs = {"device": device, "dtype": dtype}
|
42 |
super().__init__()
|
43 |
out_features = out_features if out_features is not None else in_features
|
44 |
+
hidden_features = hidden_features if hidden_features is not None else in_features * 4
|
|
|
|
|
45 |
self.return_residual = return_residual
|
46 |
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
|
47 |
self.activation = activation
|
48 |
+
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
|
50 |
+
def forward(self, x):
|
51 |
+
y = self.fc1(x)
|
52 |
y = self.activation(y)
|
53 |
+
y = self.fc2(y)
|
54 |
+
return y if not self.return_residual else (y, x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
|
56 |
|
57 |
class ParallelMLP(nn.Module):
|
|
|
73 |
assert ColumnParallelLinear is not None, "Need to install fused_dense"
|
74 |
assert RowParallelLinear is not None, "Need to install fused_dense"
|
75 |
out_features = out_features if out_features is not None else in_features
|
76 |
+
hidden_features = hidden_features if hidden_features is not None else in_features * 4
|
|
|
|
|
77 |
self.fc1 = ColumnParallelLinear(
|
78 |
in_features,
|
79 |
hidden_features,
|
|
|
119 |
hidden_features = (
|
120 |
hidden_features if hidden_features is not None else int(8 * in_features / 3)
|
121 |
)
|
122 |
+
hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
|
|
|
|
|
123 |
self.return_residual = return_residual
|
124 |
+
self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias1, **factory_kwargs)
|
|
|
|
|
125 |
self.activation = activation
|
126 |
+
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
|
|
|
|
|
127 |
|
128 |
def forward(self, x):
|
129 |
y = self.fc1(x)
|
130 |
if self.activation == F.sigmoid: # Special case for GLU
|
131 |
y = F.glu(y, dim=-1)
|
132 |
+
elif self.activation == F.silu and swiglu is not None: # Special case for SwiGLU
|
|
|
|
|
133 |
y, gate = y.chunk(2, dim=-1)
|
134 |
y = swiglu(gate, y)
|
135 |
else:
|
|
|
162 |
hidden_features = (
|
163 |
hidden_features if hidden_features is not None else int(8 * in_features / 3)
|
164 |
)
|
165 |
+
hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
|
|
|
|
|
166 |
if ColumnParallelLinear is None or RowParallelLinear is None:
|
167 |
raise ImportError("fused_dense is not installed")
|
168 |
self.fc1 = ColumnParallelLinear(
|
|
|
191 |
y, gate = y.chunk(2, dim=-1)
|
192 |
y = y * self.activation(gate)
|
193 |
y = self.fc2(y)
|
194 |
+
return y
|
modeling_lora.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import math
|
2 |
import os
|
|
|
3 |
from functools import partial
|
4 |
from typing import Iterator, List, Optional, Tuple, Union
|
5 |
|
@@ -8,15 +9,12 @@ import torch
|
|
8 |
import torch.nn.utils.parametrize as parametrize
|
9 |
from torch import nn
|
10 |
from torch.nn import Parameter
|
11 |
-
from torch.nn import functional as F
|
12 |
from transformers import PretrainedConfig
|
13 |
|
14 |
-
from .
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
XLMRobertaPreTrainedModel,
|
19 |
-
)
|
20 |
|
21 |
|
22 |
def initialized_weights(
|
@@ -93,19 +91,22 @@ class LoRAParametrization(nn.Module):
|
|
93 |
torch.ones(self.swap((1, fan_in)), dtype=self.lora_A.dtype),
|
94 |
persistent=False,
|
95 |
)
|
|
|
|
|
96 |
|
97 |
def _dropout(self, A):
|
98 |
# to mimic the original implementation: A @ dropout(x), we do (A * dropout(ones)) @ x
|
99 |
return A * self.lora_dropout(self.lora_dropout_mask)
|
100 |
|
101 |
-
def lora_forward(self, X
|
|
|
102 |
return (
|
103 |
X
|
104 |
+ torch.matmul(
|
105 |
*self.swap(
|
106 |
(
|
107 |
-
self.lora_B[current_task],
|
108 |
-
self.dropout_fn(self.lora_A[current_task]),
|
109 |
)
|
110 |
)
|
111 |
).view(X.shape)
|
@@ -113,7 +114,19 @@ class LoRAParametrization(nn.Module):
|
|
113 |
)
|
114 |
|
115 |
def forward(self, X):
|
116 |
-
return X
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
117 |
|
118 |
@classmethod
|
119 |
def from_linear(
|
@@ -166,15 +179,6 @@ class LoRAParametrization(nn.Module):
|
|
166 |
dropout_p: float,
|
167 |
alpha: float,
|
168 |
):
|
169 |
-
"""
|
170 |
-
Registering LoRA adapters to all embedding and linear layers.
|
171 |
-
Additionally, we implement a custom forward function for LoRA parametrization.
|
172 |
-
This function modifies the layer's forward pass to optionally use task-specific
|
173 |
-
parameters. When a `task_id` is provided, it employs a LoRA parametrization
|
174 |
-
to modify the original weights according to the specific task. This allows
|
175 |
-
the layer to adapt dynamically to different tasks at runtime. If no `task_id`
|
176 |
-
is specified, the layer uses its original weights.
|
177 |
-
"""
|
178 |
if isinstance(layer, nn.Linear):
|
179 |
parametrize.register_parametrization(
|
180 |
layer,
|
@@ -187,23 +191,6 @@ class LoRAParametrization(nn.Module):
|
|
187 |
alpha=alpha,
|
188 |
),
|
189 |
)
|
190 |
-
|
191 |
-
def new_forward(self, input, task_id=None, residual=False):
|
192 |
-
if task_id is not None:
|
193 |
-
weights = self.parametrizations.weight[0].lora_forward(
|
194 |
-
self.weight, current_task=task_id
|
195 |
-
)
|
196 |
-
else:
|
197 |
-
weights = self.weight
|
198 |
-
|
199 |
-
out = F.linear(input, weights, self.bias)
|
200 |
-
|
201 |
-
if residual:
|
202 |
-
return out, input
|
203 |
-
return out
|
204 |
-
|
205 |
-
layer.forward = new_forward.__get__(layer, layer.__class__)
|
206 |
-
|
207 |
elif isinstance(layer, nn.Embedding):
|
208 |
parametrize.register_parametrization(
|
209 |
layer,
|
@@ -217,43 +204,22 @@ class LoRAParametrization(nn.Module):
|
|
217 |
),
|
218 |
)
|
219 |
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
)
|
225 |
-
else:
|
226 |
-
weights = self.weight
|
227 |
-
|
228 |
-
out = F.embedding(
|
229 |
-
input,
|
230 |
-
weights,
|
231 |
-
self.padding_idx,
|
232 |
-
self.max_norm,
|
233 |
-
self.norm_type,
|
234 |
-
self.scale_grad_by_freq,
|
235 |
-
self.sparse,
|
236 |
-
)
|
237 |
-
|
238 |
-
return out
|
239 |
-
|
240 |
-
layer.forward = new_forward.__get__(layer, layer.__class__)
|
241 |
|
242 |
|
243 |
class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
|
244 |
-
"""
|
245 |
-
A wrapper class around the Jina XLM-RoBERTa model that integrates LoRA (Low-Rank Adaptation) adapters.
|
246 |
-
"""
|
247 |
-
|
248 |
def __init__(
|
249 |
self,
|
250 |
config: XLMRobertaFlashConfig,
|
251 |
-
roberta: Optional[XLMRobertaModel] = None
|
252 |
-
add_pooling_layer: bool = True,
|
253 |
):
|
254 |
super().__init__(config)
|
|
|
255 |
if roberta is None:
|
256 |
-
self.roberta = XLMRobertaModel(config
|
257 |
else:
|
258 |
self.roberta = roberta
|
259 |
|
@@ -263,19 +229,7 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
|
|
263 |
or len(self._lora_adaptations) < 1
|
264 |
):
|
265 |
raise ValueError(
|
266 |
-
f
|
267 |
-
)
|
268 |
-
self._task_instructions = config.task_instructions
|
269 |
-
if (
|
270 |
-
not isinstance(self._task_instructions, dict)
|
271 |
-
or len(self._task_instructions) != len(self._lora_adaptations)
|
272 |
-
or not all(
|
273 |
-
[v in self._lora_adaptations for v in self._task_instructions.keys()]
|
274 |
-
)
|
275 |
-
):
|
276 |
-
raise ValueError(
|
277 |
-
f"`task_instructions` must be a dict and contain the same number of elements "
|
278 |
-
f"as `lora_adaptations` with all keys in `task_instructions` present in `lora_adaptations`."
|
279 |
)
|
280 |
self._adaptation_map = {
|
281 |
name: idx for idx, name in enumerate(self._lora_adaptations)
|
@@ -290,14 +244,9 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
|
|
290 |
alpha=self._alpha,
|
291 |
)
|
292 |
self.main_params_trainable = config.lora_main_params_trainable
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
return self.roberta.rotary_emb_base
|
297 |
-
|
298 |
-
@rotary_emb_base.setter
|
299 |
-
def rotary_emb_base(self, base):
|
300 |
-
self.roberta.rotary_emb_base = base
|
301 |
|
302 |
@property
|
303 |
def main_params_trainable(self):
|
@@ -331,30 +280,16 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
|
|
331 |
use_safetensors: bool = None,
|
332 |
**kwargs,
|
333 |
):
|
334 |
-
|
335 |
-
|
336 |
-
|
337 |
-
|
|
|
338 |
return super().from_pretrained(
|
339 |
-
pretrained_model_name_or_path,
|
340 |
-
*model_args,
|
341 |
-
config=config,
|
342 |
-
cache_dir=cache_dir,
|
343 |
-
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
344 |
-
force_download=force_download,
|
345 |
-
local_files_only=local_files_only,
|
346 |
-
token=token,
|
347 |
-
revision=revision,
|
348 |
-
use_safetensors=use_safetensors,
|
349 |
-
**kwargs,
|
350 |
-
)
|
351 |
-
else: # initializing new adapters
|
352 |
-
roberta = XLMRobertaModel.from_pretrained(
|
353 |
-
pretrained_model_name_or_path,
|
354 |
-
*model_args,
|
355 |
-
use_flash_attn=config.use_flash_attn,
|
356 |
-
**kwargs,
|
357 |
)
|
|
|
|
|
358 |
return cls(config, roberta=roberta)
|
359 |
|
360 |
def _register_lora(self, num_adaptations, rank, dropout_p, alpha):
|
@@ -368,7 +303,39 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
|
|
368 |
)
|
369 |
)
|
370 |
|
371 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
372 |
return self.roberta(*args, **kwargs)
|
373 |
|
374 |
def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
|
@@ -387,40 +354,28 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
|
|
387 |
@torch.inference_mode()
|
388 |
def encode(
|
389 |
self,
|
390 |
-
sentences: Union[str, List[str]],
|
391 |
*args,
|
392 |
-
task:
|
393 |
**kwargs,
|
394 |
) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
|
395 |
"""
|
396 |
-
Computes sentence embeddings
|
397 |
-
|
398 |
-
|
399 |
-
|
400 |
-
|
401 |
-
|
402 |
-
|
|
|
|
|
403 |
"""
|
404 |
-
if task
|
405 |
-
|
406 |
-
|
407 |
-
|
408 |
-
|
409 |
-
|
410 |
-
|
411 |
-
|
412 |
-
task_id = self._adaptation_map[task]
|
413 |
-
num_examples = 1 if isinstance(sentences, str) else len(sentences)
|
414 |
-
adapter_mask = torch.full(
|
415 |
-
(num_examples,), task_id, dtype=torch.int32, device=self.device
|
416 |
-
)
|
417 |
-
if isinstance(sentences, str):
|
418 |
-
sentences = self._task_instructions[task] + sentences
|
419 |
-
else:
|
420 |
-
sentences = [
|
421 |
-
self._task_instructions[task] + sentence for sentence in sentences
|
422 |
-
]
|
423 |
-
return self.roberta.encode(
|
424 |
-
sentences, *args, adapter_mask=adapter_mask, **kwargs
|
425 |
-
)
|
426 |
|
|
|
|
1 |
import math
|
2 |
import os
|
3 |
+
import warnings
|
4 |
from functools import partial
|
5 |
from typing import Iterator, List, Optional, Tuple, Union
|
6 |
|
|
|
9 |
import torch.nn.utils.parametrize as parametrize
|
10 |
from torch import nn
|
11 |
from torch.nn import Parameter
|
|
|
12 |
from transformers import PretrainedConfig
|
13 |
|
14 |
+
from .modeling_xlm_roberta import XLMRobertaFlashConfig, XLMRobertaModel, XLMRobertaPreTrainedModel
|
15 |
+
|
16 |
+
|
17 |
+
LORA_NO_UPDATE = '__lora_no_update__'
|
|
|
|
|
18 |
|
19 |
|
20 |
def initialized_weights(
|
|
|
91 |
torch.ones(self.swap((1, fan_in)), dtype=self.lora_A.dtype),
|
92 |
persistent=False,
|
93 |
)
|
94 |
+
self.forward_fn = lambda x: x
|
95 |
+
self.current_task = None
|
96 |
|
97 |
def _dropout(self, A):
|
98 |
# to mimic the original implementation: A @ dropout(x), we do (A * dropout(ones)) @ x
|
99 |
return A * self.lora_dropout(self.lora_dropout_mask)
|
100 |
|
101 |
+
def lora_forward(self, X):
|
102 |
+
assert self.current_task is not None
|
103 |
return (
|
104 |
X
|
105 |
+ torch.matmul(
|
106 |
*self.swap(
|
107 |
(
|
108 |
+
self.lora_B[self.current_task],
|
109 |
+
self.dropout_fn(self.lora_A[self.current_task]),
|
110 |
)
|
111 |
)
|
112 |
).view(X.shape)
|
|
|
114 |
)
|
115 |
|
116 |
def forward(self, X):
|
117 |
+
return self.forward_fn(X)
|
118 |
+
|
119 |
+
@property
|
120 |
+
def current_task(self):
|
121 |
+
return self._current_task
|
122 |
+
|
123 |
+
@current_task.setter
|
124 |
+
def current_task(self, task: Union[None, int]):
|
125 |
+
self._current_task = task
|
126 |
+
if task is None:
|
127 |
+
self.forward_fn = lambda x: x
|
128 |
+
else:
|
129 |
+
self.forward_fn = self.lora_forward
|
130 |
|
131 |
@classmethod
|
132 |
def from_linear(
|
|
|
179 |
dropout_p: float,
|
180 |
alpha: float,
|
181 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
182 |
if isinstance(layer, nn.Linear):
|
183 |
parametrize.register_parametrization(
|
184 |
layer,
|
|
|
191 |
alpha=alpha,
|
192 |
),
|
193 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
194 |
elif isinstance(layer, nn.Embedding):
|
195 |
parametrize.register_parametrization(
|
196 |
layer,
|
|
|
204 |
),
|
205 |
)
|
206 |
|
207 |
+
@staticmethod
|
208 |
+
def select_task_for_layer(layer: nn.Module, task_idx: Optional[int] = None):
|
209 |
+
if isinstance(layer, LoRAParametrization):
|
210 |
+
layer.current_task = task_idx
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
211 |
|
212 |
|
213 |
class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
|
|
|
|
|
|
|
|
|
214 |
def __init__(
|
215 |
self,
|
216 |
config: XLMRobertaFlashConfig,
|
217 |
+
roberta: Optional[XLMRobertaModel] = None
|
|
|
218 |
):
|
219 |
super().__init__(config)
|
220 |
+
|
221 |
if roberta is None:
|
222 |
+
self.roberta = XLMRobertaModel(config)
|
223 |
else:
|
224 |
self.roberta = roberta
|
225 |
|
|
|
229 |
or len(self._lora_adaptations) < 1
|
230 |
):
|
231 |
raise ValueError(
|
232 |
+
f'`lora_adaptations` must be a list and contain at least one element'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
233 |
)
|
234 |
self._adaptation_map = {
|
235 |
name: idx for idx, name in enumerate(self._lora_adaptations)
|
|
|
244 |
alpha=self._alpha,
|
245 |
)
|
246 |
self.main_params_trainable = config.lora_main_params_trainable
|
247 |
+
self._task_idx = None
|
248 |
+
# By default, disable LoRA until it's specified which adapter/task to use
|
249 |
+
self.current_task = None
|
|
|
|
|
|
|
|
|
|
|
250 |
|
251 |
@property
|
252 |
def main_params_trainable(self):
|
|
|
280 |
use_safetensors: bool = None,
|
281 |
**kwargs,
|
282 |
):
|
283 |
+
config = XLMRobertaFlashConfig.from_pretrained(
|
284 |
+
pretrained_model_name_or_path, *model_args, **kwargs
|
285 |
+
)
|
286 |
+
|
287 |
+
if config.load_trained_adapters:
|
288 |
return super().from_pretrained(
|
289 |
+
pretrained_model_name_or_path, *model_args, **kwargs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
290 |
)
|
291 |
+
else:
|
292 |
+
roberta = XLMRobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
293 |
return cls(config, roberta=roberta)
|
294 |
|
295 |
def _register_lora(self, num_adaptations, rank, dropout_p, alpha):
|
|
|
303 |
)
|
304 |
)
|
305 |
|
306 |
+
@property
|
307 |
+
def current_task(self):
|
308 |
+
"""Which LoRA is currently selected
|
309 |
+
:return: Integer or None (when LoRA is disabled)
|
310 |
+
"""
|
311 |
+
return self._task_idx
|
312 |
+
|
313 |
+
@current_task.setter
|
314 |
+
def current_task(self, task_name: Union[None, str]):
|
315 |
+
"""Set the LoRA that is to be used.
|
316 |
+
The LoRA is specified by `task_idx`, which may be an integer >= 0,
|
317 |
+
indexing the available LoRAs. If it is None, no LoRA is used.
|
318 |
+
:param task_name: Which LoRA to use
|
319 |
+
:return:
|
320 |
+
"""
|
321 |
+
if task_name and task_name not in self._lora_adaptations:
|
322 |
+
raise ValueError(
|
323 |
+
f"Unsupported task '{task_name}'. "
|
324 |
+
f"Supported tasks are: {', '.join(self.config.lora_adaptations)}."
|
325 |
+
f"Alternatively, set `task` to `None` if you want to disable LoRA."
|
326 |
+
)
|
327 |
+
task_idx = self._adaptation_map[task_name] if task_name else None
|
328 |
+
if self._task_idx != task_idx:
|
329 |
+
# In this case, we need to update the LoRAs everywhere
|
330 |
+
self._task_idx = task_idx
|
331 |
+
self.apply(
|
332 |
+
partial(LoRAParametrization.select_task_for_layer, task_idx=task_idx)
|
333 |
+
)
|
334 |
+
|
335 |
+
def forward(self, *args, task: Union[str, None] = LORA_NO_UPDATE, **kwargs):
|
336 |
+
if task != LORA_NO_UPDATE:
|
337 |
+
self.current_task = task
|
338 |
+
|
339 |
return self.roberta(*args, **kwargs)
|
340 |
|
341 |
def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
|
|
|
354 |
@torch.inference_mode()
|
355 |
def encode(
|
356 |
self,
|
|
|
357 |
*args,
|
358 |
+
task: Union[str, None] = LORA_NO_UPDATE,
|
359 |
**kwargs,
|
360 |
) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
|
361 |
"""
|
362 |
+
Computes sentence embeddings
|
363 |
+
|
364 |
+
task(`str`, *optional*, defaults to `LORA_NO_UPDATE`):
|
365 |
+
Specifies the task for which the encoding is intended. This parameter controls the
|
366 |
+
use of specialized LoRA adapters that are tuned for specific tasks. If `task` is set
|
367 |
+
to `LORA_NO_UPDATE`, there will be no update to the current task, retaining the
|
368 |
+
existing adapter configuration. If `task` is explicitly set to `None`, all LoRA
|
369 |
+
adapters are disabled, and the model reverts to its original, general-purpose weights.
|
370 |
+
If `task` is set to a specific LoRA adaptation, that adaptation is activated.
|
371 |
"""
|
372 |
+
if task != LORA_NO_UPDATE:
|
373 |
+
if not task:
|
374 |
+
warnings.warn(
|
375 |
+
f"Task-specific embeddings are disabled. To enable, specify the `task` "
|
376 |
+
f"argument with one of the supported tasks: {', '.join(self.config.lora_adaptations)}",
|
377 |
+
category=UserWarning,
|
378 |
+
)
|
379 |
+
self.current_task = task
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
380 |
|
381 |
+
return self.roberta.encode(*args, **kwargs)
|
modeling_xlm_roberta.py
CHANGED
@@ -13,30 +13,39 @@ import re
|
|
13 |
from collections import OrderedDict
|
14 |
from collections.abc import Sequence
|
15 |
from functools import partial
|
16 |
-
from typing import List, Optional, Tuple, Union
|
17 |
-
|
18 |
import numpy as np
|
|
|
19 |
import torch
|
20 |
import torch.nn as nn
|
21 |
import torch.nn.functional as F
|
22 |
import torch.utils.checkpoint
|
23 |
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
24 |
-
from
|
25 |
-
from transformers
|
26 |
from transformers.modeling_utils import PreTrainedModel
|
|
|
|
|
|
|
27 |
from transformers.models.bert.modeling_bert import (
|
28 |
BaseModelOutputWithPoolingAndCrossAttentions,
|
29 |
BertForPreTrainingOutput,
|
30 |
)
|
31 |
-
from transformers.models.xlm_roberta.modeling_xlm_roberta import XLMRobertaLMHead
|
32 |
|
33 |
-
from
|
34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
from .configuration_xlm_roberta import XLMRobertaFlashConfig
|
|
|
36 |
from .embedding import XLMRobertaEmbeddings
|
37 |
from .mha import MHA
|
38 |
from .mlp import FusedMLP, Mlp
|
39 |
-
from .
|
|
|
40 |
|
41 |
try:
|
42 |
from flash_attn.ops.fused_dense import FusedDense
|
@@ -64,11 +73,13 @@ logger = logging.getLogger(__name__)
|
|
64 |
|
65 |
|
66 |
def get_use_flash_attn(config: XLMRobertaFlashConfig):
|
67 |
-
if not getattr(config, "use_flash_attn", False)
|
|
|
|
|
68 |
return False
|
69 |
if importlib.util.find_spec("flash_attn") is None:
|
70 |
logger.warning(
|
71 |
-
|
72 |
)
|
73 |
return False
|
74 |
return True
|
@@ -80,9 +91,9 @@ def create_mixer_cls(config, cross_attn=False, return_residual=False):
|
|
80 |
rotary_kwargs = {}
|
81 |
if config.position_embedding_type == "rotary":
|
82 |
rotary_kwargs["rotary_emb_dim"] = getattr(
|
83 |
-
config, "rotary_emb_dim", config.hidden_size
|
84 |
)
|
85 |
-
rotary_kwargs["rotary_emb_base"] = config.
|
86 |
rotary_kwargs["rotary_emb_scale_base"] = getattr(
|
87 |
config, "rotary_emb_scale_base", None
|
88 |
)
|
@@ -98,7 +109,6 @@ def create_mixer_cls(config, cross_attn=False, return_residual=False):
|
|
98 |
fused_bias_fc=fused_bias_fc,
|
99 |
use_flash_attn=use_flash_attn,
|
100 |
return_residual=return_residual,
|
101 |
-
use_alibi=config.position_embedding_type == "alibi",
|
102 |
**rotary_kwargs,
|
103 |
)
|
104 |
return mixer_cls
|
@@ -180,7 +190,6 @@ class XLMRobertaEncoder(nn.Module):
|
|
180 |
def __init__(self, config: XLMRobertaFlashConfig):
|
181 |
super().__init__()
|
182 |
self.use_flash_attn = get_use_flash_attn(config)
|
183 |
-
self.use_reentrant = config.use_reentrant
|
184 |
self.layers = nn.ModuleList(
|
185 |
[create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
|
186 |
)
|
@@ -194,72 +203,46 @@ class XLMRobertaEncoder(nn.Module):
|
|
194 |
def gradient_checkpointing(self, value):
|
195 |
self._grad_checkpointing = value
|
196 |
|
197 |
-
def forward(
|
198 |
-
self,
|
199 |
-
hidden_states,
|
200 |
-
key_padding_mask=None,
|
201 |
-
subset_mask=None,
|
202 |
-
adapter_mask=None,
|
203 |
-
output_hidden_states: Optional[bool] = None,
|
204 |
-
):
|
205 |
"""If subset_mask is not None, we only want output for the subset of the sequence.
|
206 |
This means that we only compute the last layer output for these tokens.
|
207 |
subset_mask: (batch, seqlen), dtype=torch.bool
|
208 |
"""
|
209 |
-
|
210 |
-
all_hidden_states = () if output_hidden_states else None
|
211 |
-
|
212 |
-
if output_hidden_states and subset_mask:
|
213 |
-
raise ValueError('output_hidden_states is not supported for subset_masks')
|
214 |
-
|
215 |
if key_padding_mask is None or not self.use_flash_attn:
|
216 |
-
mixer_kwargs =
|
217 |
-
|
218 |
-
|
|
|
|
|
219 |
for layer in self.layers:
|
220 |
-
if output_hidden_states:
|
221 |
-
all_hidden_states = all_hidden_states + (hidden_states,)
|
222 |
if self._grad_checkpointing:
|
223 |
hidden_states = torch.utils.checkpoint.checkpoint(
|
224 |
layer,
|
225 |
hidden_states,
|
226 |
-
use_reentrant=
|
227 |
mixer_kwargs=mixer_kwargs,
|
228 |
)
|
229 |
else:
|
230 |
hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
|
231 |
-
if output_hidden_states:
|
232 |
-
all_hidden_states = all_hidden_states + (hidden_states,)
|
233 |
if subset_mask is not None:
|
234 |
hidden_states = hidden_states[subset_mask]
|
235 |
else:
|
236 |
batch, seqlen = hidden_states.shape[:2]
|
237 |
-
|
238 |
-
|
239 |
-
hidden_states, indices, cu_seqlens, max_seqlen_in_batch, cu_adapter_mask = (
|
240 |
-
unpad_input(hidden_states, key_padding_mask, adapter_mask)
|
241 |
)
|
242 |
-
mixer_kwargs = {
|
243 |
-
"cu_seqlens": cu_seqlens,
|
244 |
-
"max_seqlen": max_seqlen_in_batch,
|
245 |
-
"adapter_mask": cu_adapter_mask,
|
246 |
-
}
|
247 |
-
|
248 |
if subset_mask is None:
|
249 |
for layer in self.layers:
|
250 |
if self._grad_checkpointing:
|
251 |
hidden_states = torch.utils.checkpoint.checkpoint(
|
252 |
layer,
|
253 |
hidden_states,
|
254 |
-
use_reentrant=
|
255 |
mixer_kwargs=mixer_kwargs,
|
256 |
)
|
257 |
else:
|
258 |
hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
|
259 |
-
if output_hidden_states:
|
260 |
-
all_hidden_states = all_hidden_states + (
|
261 |
-
pad_input(hidden_states, indices, batch, seqlen),
|
262 |
-
)
|
263 |
hidden_states = pad_input(hidden_states, indices, batch, seqlen)
|
264 |
else:
|
265 |
for layer in self.layers[:-1]:
|
@@ -267,7 +250,7 @@ class XLMRobertaEncoder(nn.Module):
|
|
267 |
hidden_states = torch.utils.checkpoint.checkpoint(
|
268 |
layer,
|
269 |
hidden_states,
|
270 |
-
use_reentrant=
|
271 |
mixer_kwargs=mixer_kwargs,
|
272 |
)
|
273 |
else:
|
@@ -305,14 +288,14 @@ class XLMRobertaEncoder(nn.Module):
|
|
305 |
torch.utils.checkpoint.checkpoint(
|
306 |
self.layers[-1],
|
307 |
hidden_states_subset,
|
308 |
-
use_reentrant=
|
309 |
mixer_kwargs=mixer_kwargs,
|
310 |
)
|
311 |
else:
|
312 |
hidden_states = self.layers[-1](
|
313 |
hidden_states_subset, mixer_kwargs=mixer_kwargs
|
314 |
)
|
315 |
-
return
|
316 |
|
317 |
|
318 |
class XLMRobertaPooler(nn.Module):
|
@@ -325,28 +308,11 @@ class XLMRobertaPooler(nn.Module):
|
|
325 |
self.dense = linear_cls(config.hidden_size, config.hidden_size)
|
326 |
self.activation = nn.Tanh()
|
327 |
|
328 |
-
def forward(self, hidden_states, pool=True
|
329 |
# We "pool" the model by simply taking the hidden state corresponding
|
330 |
# to the first token.
|
331 |
first_token_tensor = hidden_states[:, 0] if pool else hidden_states
|
332 |
-
|
333 |
-
unique_tasks = torch.unique(adapter_mask)
|
334 |
-
pool_dtype = next(self.dense.parameters()).dtype
|
335 |
-
pooled_output = torch.empty(
|
336 |
-
first_token_tensor.shape[0],
|
337 |
-
self.dense.out_features,
|
338 |
-
dtype=pool_dtype,
|
339 |
-
device=first_token_tensor.device,
|
340 |
-
)
|
341 |
-
for task_id in unique_tasks:
|
342 |
-
task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
|
343 |
-
task_first_token_tensor = first_token_tensor[task_indices]
|
344 |
-
task_pooled_output = self.dense(
|
345 |
-
task_first_token_tensor, task_id=task_id
|
346 |
-
)
|
347 |
-
pooled_output[task_indices] = task_pooled_output
|
348 |
-
else:
|
349 |
-
pooled_output = self.dense(first_token_tensor)
|
350 |
pooled_output = self.activation(pooled_output)
|
351 |
return pooled_output
|
352 |
|
@@ -425,7 +391,6 @@ class XLMRobertaPreTrainedModel(PreTrainedModel):
|
|
425 |
config_class = XLMRobertaFlashConfig
|
426 |
base_model_prefix = "roberta"
|
427 |
supports_gradient_checkpointing = True
|
428 |
-
_supports_param_buffer_assignment = False
|
429 |
|
430 |
def _set_gradient_checkpointing(self, module, value=False):
|
431 |
if isinstance(module, XLMRobertaEncoder):
|
@@ -437,11 +402,12 @@ class XLMRobertaPreTrainedModel(PreTrainedModel):
|
|
437 |
*args,
|
438 |
**kwargs,
|
439 |
):
|
440 |
-
if not
|
441 |
-
kwargs[
|
442 |
return super().from_pretrained(*args, **kwargs)
|
443 |
|
444 |
|
|
|
445 |
class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
446 |
def __init__(self, config: XLMRobertaFlashConfig, add_pooling_layer=True):
|
447 |
super().__init__(config)
|
@@ -459,14 +425,11 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
459 |
"gelu_fast",
|
460 |
"gelu_pytorch_tanh",
|
461 |
]
|
|
|
462 |
self.embeddings = XLMRobertaEmbeddings(
|
463 |
config.hidden_size,
|
464 |
config.vocab_size,
|
465 |
-
|
466 |
-
config.max_position_embeddings
|
467 |
-
if config.position_embedding_type == "absolute"
|
468 |
-
else -1
|
469 |
-
),
|
470 |
config.type_vocab_size,
|
471 |
padding_idx=config.pad_token_id,
|
472 |
)
|
@@ -476,25 +439,20 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
476 |
self.pooler = XLMRobertaPooler(config) if add_pooling_layer else None
|
477 |
|
478 |
self.apply(partial(_init_weights, initializer_range=config.initializer_range))
|
479 |
-
|
480 |
-
self.name_or_path, trust_remote_code=True
|
481 |
-
)
|
482 |
-
self._rotary_emb_base = config.rotary_emb_base
|
483 |
|
484 |
@torch.inference_mode()
|
485 |
def encode(
|
486 |
-
self:
|
487 |
sentences: Union[str, List[str]],
|
488 |
batch_size: int = 32,
|
489 |
show_progress_bar: Optional[bool] = None,
|
490 |
-
output_value: str =
|
491 |
convert_to_numpy: bool = True,
|
492 |
convert_to_tensor: bool = False,
|
493 |
device: Optional[torch.device] = None,
|
494 |
-
normalize_embeddings: bool =
|
495 |
truncate_dim: Optional[int] = None,
|
496 |
-
adapter_mask: Optional[torch.Tensor] = None,
|
497 |
-
task: Optional[str] = None,
|
498 |
**tokenizer_kwargs,
|
499 |
) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
|
500 |
"""
|
@@ -520,7 +478,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
520 |
Overwrites any setting from convert_to_numpy
|
521 |
device(`torch.device`, *optional*, defaults to None):
|
522 |
Which torch.device to use for the computation
|
523 |
-
normalize_embeddings(`bool`, *optional*, defaults to
|
524 |
If set to true, returned vectors will have length 1. In that case, the
|
525 |
faster dot-product (util.dot_score) instead of cosine similarity can
|
526 |
be used.
|
@@ -533,6 +491,12 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
533 |
If convert_to_tensor, a stacked tensor is returned.
|
534 |
If convert_to_numpy, a numpy matrix is returned.
|
535 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
536 |
is_training = self.training
|
537 |
self.eval()
|
538 |
|
@@ -545,12 +509,12 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
545 |
if convert_to_tensor:
|
546 |
convert_to_numpy = False
|
547 |
|
548 |
-
if output_value !=
|
549 |
convert_to_tensor = False
|
550 |
convert_to_numpy = False
|
551 |
|
552 |
input_was_string = False
|
553 |
-
if isinstance(sentences, str) or not hasattr(sentences,
|
554 |
sentences = [sentences]
|
555 |
input_was_string = True
|
556 |
|
@@ -561,11 +525,11 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
561 |
inverse_permutation = np.argsort(permutation)
|
562 |
sentences = [sentences[idx] for idx in permutation]
|
563 |
|
564 |
-
tokenizer_kwargs[
|
565 |
-
tokenizer_kwargs[
|
566 |
-
|
567 |
)
|
568 |
-
tokenizer_kwargs[
|
569 |
|
570 |
all_embeddings = []
|
571 |
|
@@ -583,33 +547,33 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
583 |
for i in range_iter:
|
584 |
encoded_input = self.tokenizer(
|
585 |
sentences[i : i + batch_size],
|
586 |
-
return_tensors=
|
587 |
**tokenizer_kwargs,
|
588 |
).to(self.device)
|
589 |
-
|
590 |
-
{"adapter_mask": adapter_mask[i : i + batch_size]}
|
591 |
-
if adapter_mask is not None
|
592 |
-
else {}
|
593 |
-
)
|
594 |
-
token_embs = self.forward(**encoded_input, **lora_arguments)[0]
|
595 |
|
596 |
# Accumulate in fp32 to avoid overflow
|
597 |
token_embs = token_embs.float()
|
598 |
|
599 |
-
if output_value ==
|
600 |
raise NotImplementedError
|
601 |
elif output_value is None:
|
602 |
raise NotImplementedError
|
603 |
else:
|
604 |
-
if self.config.emb_pooler ==
|
605 |
embeddings = self.cls_pooling(
|
606 |
-
token_embs, encoded_input[
|
607 |
)
|
608 |
else:
|
609 |
embeddings = self.mean_pooling(
|
610 |
-
token_embs, encoded_input[
|
611 |
)
|
612 |
|
|
|
|
|
|
|
|
|
|
|
613 |
all_embeddings.extend(embeddings)
|
614 |
|
615 |
all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
|
@@ -618,16 +582,10 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
618 |
if truncate_dim:
|
619 |
all_embeddings = self.truncate_embeddings(all_embeddings, truncate_dim)
|
620 |
|
621 |
-
if normalize_embeddings:
|
622 |
-
all_embeddings = [
|
623 |
-
torch.nn.functional.normalize(embedding, p=2, dim=0)
|
624 |
-
for embedding in all_embeddings
|
625 |
-
]
|
626 |
-
|
627 |
if convert_to_tensor:
|
628 |
all_embeddings = torch.stack(all_embeddings)
|
629 |
elif convert_to_numpy:
|
630 |
-
all_embeddings = np.asarray([emb.
|
631 |
|
632 |
if input_was_string:
|
633 |
all_embeddings = all_embeddings[0]
|
@@ -635,19 +593,18 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
635 |
self.train(is_training)
|
636 |
return all_embeddings
|
637 |
|
|
|
638 |
def truncate_embeddings(self, embeddings, truncate_dim):
|
639 |
if not self.config.matryoshka_dimensions:
|
640 |
logger.warning(
|
641 |
-
|
642 |
)
|
643 |
return embeddings
|
644 |
elif truncate_dim in self.config.matryoshka_dimensions:
|
645 |
return [tensor[:truncate_dim] for tensor in embeddings]
|
646 |
else:
|
647 |
-
raise ValueError(
|
648 |
-
|
649 |
-
f"Supported dimensions are {self.config.matryoshka_dimensions}."
|
650 |
-
)
|
651 |
|
652 |
def mean_pooling(
|
653 |
self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
|
@@ -659,21 +616,12 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
659 |
input_mask_expanded.sum(1), min=1e-9
|
660 |
)
|
661 |
|
662 |
-
def cls_pooling(self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor):
|
663 |
-
return token_embeddings[:, 0]
|
664 |
|
665 |
-
|
666 |
-
|
667 |
-
|
668 |
-
|
669 |
-
|
670 |
-
def rotary_emb_base(self, base):
|
671 |
-
if not isinstance(base, (int, float)):
|
672 |
-
raise TypeError("Base must be an integer or float")
|
673 |
-
logger.info(f"Changing RoPE base value to {base}")
|
674 |
-
for layer in self.encoder.layers:
|
675 |
-
layer.mixer.rotary_emb.base = base
|
676 |
-
self._rotary_emb_base = base
|
677 |
|
678 |
def forward(
|
679 |
self,
|
@@ -683,7 +631,6 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
683 |
attention_mask=None,
|
684 |
masked_tokens_mask=None,
|
685 |
return_dict=None,
|
686 |
-
output_hidden_states=None,
|
687 |
**kwargs,
|
688 |
):
|
689 |
"""If masked_tokens_mask is not None (i.e. last_layer_subset == True in XLMForPreTraining),
|
@@ -691,12 +638,12 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
691 |
layer output for these tokens.
|
692 |
masked_tokens_mask: (batch, seqlen), dtype=torch.bool
|
693 |
"""
|
694 |
-
|
695 |
if kwargs:
|
696 |
for key, value in kwargs.items():
|
697 |
if value is not None:
|
698 |
logger.warning(
|
699 |
-
|
700 |
key,
|
701 |
)
|
702 |
|
@@ -705,10 +652,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
705 |
)
|
706 |
|
707 |
hidden_states = self.embeddings(
|
708 |
-
input_ids,
|
709 |
-
position_ids=position_ids,
|
710 |
-
token_type_ids=token_type_ids,
|
711 |
-
adapter_mask=adapter_mask,
|
712 |
)
|
713 |
# TD [2022-12:18]: Don't need to force residual in fp32
|
714 |
# BERT puts embedding LayerNorm before embedding dropout.
|
@@ -732,24 +676,12 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
732 |
subset_mask = None
|
733 |
|
734 |
sequence_output = self.encoder(
|
735 |
-
hidden_states,
|
736 |
-
key_padding_mask=attention_mask,
|
737 |
-
subset_mask=subset_mask,
|
738 |
-
adapter_mask=adapter_mask,
|
739 |
-
output_hidden_states=output_hidden_states,
|
740 |
)
|
741 |
|
742 |
-
if output_hidden_states:
|
743 |
-
all_hidden_states = sequence_output
|
744 |
-
sequence_output = sequence_output[-1]
|
745 |
-
else:
|
746 |
-
all_hidden_states = None
|
747 |
-
|
748 |
if masked_tokens_mask is None:
|
749 |
pooled_output = (
|
750 |
-
self.pooler(sequence_output
|
751 |
-
if self.pooler is not None
|
752 |
-
else None
|
753 |
)
|
754 |
else:
|
755 |
# TD [2022-03-01]: the indexing here is very tricky.
|
@@ -763,9 +695,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
763 |
pool_input = sequence_output[first_col_mask[subset_mask]]
|
764 |
sequence_output = sequence_output[masked_tokens_mask[subset_mask]]
|
765 |
pooled_output = (
|
766 |
-
self.pooler(pool_input, pool=False
|
767 |
-
if self.pooler is not None
|
768 |
-
else None
|
769 |
)
|
770 |
|
771 |
if not return_dict:
|
@@ -774,7 +704,6 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
774 |
return BaseModelOutputWithPoolingAndCrossAttentions(
|
775 |
last_hidden_state=sequence_output,
|
776 |
pooler_output=pooled_output,
|
777 |
-
hidden_states=all_hidden_states,
|
778 |
)
|
779 |
|
780 |
|
@@ -871,6 +800,103 @@ class XLMRobertaForMaskedLM(XLMRobertaPreTrainedModel):
|
|
871 |
)
|
872 |
|
873 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
874 |
def remap_state_dict(state_dict, config: PretrainedConfig):
|
875 |
"""
|
876 |
Map the state_dict of a Huggingface BERT model to be flash_attn compatible.
|
@@ -1022,47 +1048,47 @@ def inv_remap_state_dict(state_dict, config: PretrainedConfig):
|
|
1022 |
if not last_layer_subset or d != (config.num_hidden_layers - 1):
|
1023 |
Wqkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.weight")
|
1024 |
Wqkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.bias")
|
1025 |
-
state_dict[
|
1026 |
-
|
1027 |
-
|
1028 |
-
state_dict[
|
1029 |
-
|
1030 |
-
|
1031 |
-
]
|
1032 |
-
|
1033 |
-
state_dict[
|
1034 |
-
|
1035 |
-
|
1036 |
-
state_dict[
|
1037 |
-
|
1038 |
-
|
1039 |
-
state_dict[
|
1040 |
-
|
1041 |
-
|
1042 |
-
state_dict[
|
1043 |
-
|
1044 |
-
|
1045 |
else:
|
1046 |
Wq_weight = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.weight")
|
1047 |
Wkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.weight")
|
1048 |
Wq_bias = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.bias")
|
1049 |
Wkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.bias")
|
1050 |
-
state_dict[
|
1051 |
-
|
1052 |
-
|
1053 |
-
state_dict[
|
1054 |
-
|
1055 |
-
|
1056 |
-
state_dict[
|
1057 |
-
|
1058 |
-
|
1059 |
state_dict[f"bert.encoder.layers.{d}.attention.self.query.bias"] = Wq_bias
|
1060 |
state_dict[f"bert.encoder.layers.{d}.attention.self.key.bias"] = Wkv_biases[
|
1061 |
: Wkv_biases.shape[0] // 2
|
1062 |
]
|
1063 |
-
state_dict[
|
1064 |
-
|
1065 |
-
|
1066 |
|
1067 |
def inv_key_mapping_ln(key):
|
1068 |
key = re.sub(r"bert.emb_ln.", "bert.embeddings.LayerNorm.", key)
|
|
|
13 |
from collections import OrderedDict
|
14 |
from collections.abc import Sequence
|
15 |
from functools import partial
|
|
|
|
|
16 |
import numpy as np
|
17 |
+
|
18 |
import torch
|
19 |
import torch.nn as nn
|
20 |
import torch.nn.functional as F
|
21 |
import torch.utils.checkpoint
|
22 |
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
23 |
+
from einops import rearrange
|
24 |
+
from transformers import PretrainedConfig
|
25 |
from transformers.modeling_utils import PreTrainedModel
|
26 |
+
from transformers.modeling_outputs import MaskedLMOutput,SequenceClassifierOutput
|
27 |
+
from transformers.models.xlm_roberta.modeling_xlm_roberta import XLMRobertaLMHead
|
28 |
+
|
29 |
from transformers.models.bert.modeling_bert import (
|
30 |
BaseModelOutputWithPoolingAndCrossAttentions,
|
31 |
BertForPreTrainingOutput,
|
32 |
)
|
|
|
33 |
|
34 |
+
from typing import List, Optional, Tuple, Union
|
35 |
+
|
36 |
+
from .xlm_padding import (
|
37 |
+
index_first_axis,
|
38 |
+
index_first_axis_residual,
|
39 |
+
pad_input,
|
40 |
+
unpad_input,
|
41 |
+
)
|
42 |
from .configuration_xlm_roberta import XLMRobertaFlashConfig
|
43 |
+
from .block import Block
|
44 |
from .embedding import XLMRobertaEmbeddings
|
45 |
from .mha import MHA
|
46 |
from .mlp import FusedMLP, Mlp
|
47 |
+
from .stochastic_depth import StochasticDepth
|
48 |
+
|
49 |
|
50 |
try:
|
51 |
from flash_attn.ops.fused_dense import FusedDense
|
|
|
73 |
|
74 |
|
75 |
def get_use_flash_attn(config: XLMRobertaFlashConfig):
|
76 |
+
if not getattr(config, "use_flash_attn", False):
|
77 |
+
return False
|
78 |
+
if not torch.cuda.is_available():
|
79 |
return False
|
80 |
if importlib.util.find_spec("flash_attn") is None:
|
81 |
logger.warning(
|
82 |
+
'flash_attn is not installed. Using PyTorch native attention implementation.'
|
83 |
)
|
84 |
return False
|
85 |
return True
|
|
|
91 |
rotary_kwargs = {}
|
92 |
if config.position_embedding_type == "rotary":
|
93 |
rotary_kwargs["rotary_emb_dim"] = getattr(
|
94 |
+
config, "rotary_emb_dim", config.hidden_size
|
95 |
)
|
96 |
+
rotary_kwargs["rotary_emb_base"] = getattr(config, "rotary_emb_base", 10000.0)
|
97 |
rotary_kwargs["rotary_emb_scale_base"] = getattr(
|
98 |
config, "rotary_emb_scale_base", None
|
99 |
)
|
|
|
109 |
fused_bias_fc=fused_bias_fc,
|
110 |
use_flash_attn=use_flash_attn,
|
111 |
return_residual=return_residual,
|
|
|
112 |
**rotary_kwargs,
|
113 |
)
|
114 |
return mixer_cls
|
|
|
190 |
def __init__(self, config: XLMRobertaFlashConfig):
|
191 |
super().__init__()
|
192 |
self.use_flash_attn = get_use_flash_attn(config)
|
|
|
193 |
self.layers = nn.ModuleList(
|
194 |
[create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
|
195 |
)
|
|
|
203 |
def gradient_checkpointing(self, value):
|
204 |
self._grad_checkpointing = value
|
205 |
|
206 |
+
def forward(self, hidden_states, key_padding_mask=None, subset_mask=None):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
207 |
"""If subset_mask is not None, we only want output for the subset of the sequence.
|
208 |
This means that we only compute the last layer output for these tokens.
|
209 |
subset_mask: (batch, seqlen), dtype=torch.bool
|
210 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
211 |
if key_padding_mask is None or not self.use_flash_attn:
|
212 |
+
mixer_kwargs = (
|
213 |
+
{"key_padding_mask": key_padding_mask.bool()}
|
214 |
+
if key_padding_mask is not None
|
215 |
+
else None
|
216 |
+
)
|
217 |
for layer in self.layers:
|
|
|
|
|
218 |
if self._grad_checkpointing:
|
219 |
hidden_states = torch.utils.checkpoint.checkpoint(
|
220 |
layer,
|
221 |
hidden_states,
|
222 |
+
use_reentrant=False,
|
223 |
mixer_kwargs=mixer_kwargs,
|
224 |
)
|
225 |
else:
|
226 |
hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
|
|
|
|
|
227 |
if subset_mask is not None:
|
228 |
hidden_states = hidden_states[subset_mask]
|
229 |
else:
|
230 |
batch, seqlen = hidden_states.shape[:2]
|
231 |
+
hidden_states, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
|
232 |
+
hidden_states, key_padding_mask
|
|
|
|
|
233 |
)
|
234 |
+
mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch}
|
|
|
|
|
|
|
|
|
|
|
235 |
if subset_mask is None:
|
236 |
for layer in self.layers:
|
237 |
if self._grad_checkpointing:
|
238 |
hidden_states = torch.utils.checkpoint.checkpoint(
|
239 |
layer,
|
240 |
hidden_states,
|
241 |
+
use_reentrant=False,
|
242 |
mixer_kwargs=mixer_kwargs,
|
243 |
)
|
244 |
else:
|
245 |
hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
|
|
|
|
|
|
|
|
|
246 |
hidden_states = pad_input(hidden_states, indices, batch, seqlen)
|
247 |
else:
|
248 |
for layer in self.layers[:-1]:
|
|
|
250 |
hidden_states = torch.utils.checkpoint.checkpoint(
|
251 |
layer,
|
252 |
hidden_states,
|
253 |
+
use_reentrant=False,
|
254 |
mixer_kwargs=mixer_kwargs,
|
255 |
)
|
256 |
else:
|
|
|
288 |
torch.utils.checkpoint.checkpoint(
|
289 |
self.layers[-1],
|
290 |
hidden_states_subset,
|
291 |
+
use_reentrant=False,
|
292 |
mixer_kwargs=mixer_kwargs,
|
293 |
)
|
294 |
else:
|
295 |
hidden_states = self.layers[-1](
|
296 |
hidden_states_subset, mixer_kwargs=mixer_kwargs
|
297 |
)
|
298 |
+
return hidden_states
|
299 |
|
300 |
|
301 |
class XLMRobertaPooler(nn.Module):
|
|
|
308 |
self.dense = linear_cls(config.hidden_size, config.hidden_size)
|
309 |
self.activation = nn.Tanh()
|
310 |
|
311 |
+
def forward(self, hidden_states, pool=True):
|
312 |
# We "pool" the model by simply taking the hidden state corresponding
|
313 |
# to the first token.
|
314 |
first_token_tensor = hidden_states[:, 0] if pool else hidden_states
|
315 |
+
pooled_output = self.dense(first_token_tensor)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
316 |
pooled_output = self.activation(pooled_output)
|
317 |
return pooled_output
|
318 |
|
|
|
391 |
config_class = XLMRobertaFlashConfig
|
392 |
base_model_prefix = "roberta"
|
393 |
supports_gradient_checkpointing = True
|
|
|
394 |
|
395 |
def _set_gradient_checkpointing(self, module, value=False):
|
396 |
if isinstance(module, XLMRobertaEncoder):
|
|
|
402 |
*args,
|
403 |
**kwargs,
|
404 |
):
|
405 |
+
if not 'torch_dtype' in kwargs:
|
406 |
+
kwargs['torch_dtype'] = 'auto'
|
407 |
return super().from_pretrained(*args, **kwargs)
|
408 |
|
409 |
|
410 |
+
|
411 |
class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
412 |
def __init__(self, config: XLMRobertaFlashConfig, add_pooling_layer=True):
|
413 |
super().__init__(config)
|
|
|
425 |
"gelu_fast",
|
426 |
"gelu_pytorch_tanh",
|
427 |
]
|
428 |
+
|
429 |
self.embeddings = XLMRobertaEmbeddings(
|
430 |
config.hidden_size,
|
431 |
config.vocab_size,
|
432 |
+
config.max_position_embeddings,
|
|
|
|
|
|
|
|
|
433 |
config.type_vocab_size,
|
434 |
padding_idx=config.pad_token_id,
|
435 |
)
|
|
|
439 |
self.pooler = XLMRobertaPooler(config) if add_pooling_layer else None
|
440 |
|
441 |
self.apply(partial(_init_weights, initializer_range=config.initializer_range))
|
442 |
+
|
|
|
|
|
|
|
443 |
|
444 |
@torch.inference_mode()
|
445 |
def encode(
|
446 |
+
self: 'XLMRobertaModel',
|
447 |
sentences: Union[str, List[str]],
|
448 |
batch_size: int = 32,
|
449 |
show_progress_bar: Optional[bool] = None,
|
450 |
+
output_value: str = 'sentence_embedding',
|
451 |
convert_to_numpy: bool = True,
|
452 |
convert_to_tensor: bool = False,
|
453 |
device: Optional[torch.device] = None,
|
454 |
+
normalize_embeddings: bool = False,
|
455 |
truncate_dim: Optional[int] = None,
|
|
|
|
|
456 |
**tokenizer_kwargs,
|
457 |
) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
|
458 |
"""
|
|
|
478 |
Overwrites any setting from convert_to_numpy
|
479 |
device(`torch.device`, *optional*, defaults to None):
|
480 |
Which torch.device to use for the computation
|
481 |
+
normalize_embeddings(`bool`, *optional*, defaults to False):
|
482 |
If set to true, returned vectors will have length 1. In that case, the
|
483 |
faster dot-product (util.dot_score) instead of cosine similarity can
|
484 |
be used.
|
|
|
491 |
If convert_to_tensor, a stacked tensor is returned.
|
492 |
If convert_to_numpy, a numpy matrix is returned.
|
493 |
"""
|
494 |
+
from transformers import AutoTokenizer
|
495 |
+
|
496 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
497 |
+
self.name_or_path, trust_remote_code=True
|
498 |
+
)
|
499 |
+
|
500 |
is_training = self.training
|
501 |
self.eval()
|
502 |
|
|
|
509 |
if convert_to_tensor:
|
510 |
convert_to_numpy = False
|
511 |
|
512 |
+
if output_value != 'sentence_embedding':
|
513 |
convert_to_tensor = False
|
514 |
convert_to_numpy = False
|
515 |
|
516 |
input_was_string = False
|
517 |
+
if isinstance(sentences, str) or not hasattr(sentences, '__len__'):
|
518 |
sentences = [sentences]
|
519 |
input_was_string = True
|
520 |
|
|
|
525 |
inverse_permutation = np.argsort(permutation)
|
526 |
sentences = [sentences[idx] for idx in permutation]
|
527 |
|
528 |
+
tokenizer_kwargs['padding'] = tokenizer_kwargs.get('padding', True)
|
529 |
+
tokenizer_kwargs['max_length'] = tokenizer_kwargs.get(
|
530 |
+
'max_length', self.tokenizer.init_kwargs.get('model_max_length', 8192)
|
531 |
)
|
532 |
+
tokenizer_kwargs['truncation'] = tokenizer_kwargs.get('truncation', True)
|
533 |
|
534 |
all_embeddings = []
|
535 |
|
|
|
547 |
for i in range_iter:
|
548 |
encoded_input = self.tokenizer(
|
549 |
sentences[i : i + batch_size],
|
550 |
+
return_tensors='pt',
|
551 |
**tokenizer_kwargs,
|
552 |
).to(self.device)
|
553 |
+
token_embs = self.forward(**encoded_input)[0]
|
|
|
|
|
|
|
|
|
|
|
554 |
|
555 |
# Accumulate in fp32 to avoid overflow
|
556 |
token_embs = token_embs.float()
|
557 |
|
558 |
+
if output_value == 'token_embeddings':
|
559 |
raise NotImplementedError
|
560 |
elif output_value is None:
|
561 |
raise NotImplementedError
|
562 |
else:
|
563 |
+
if self.config.emb_pooler == 'cls':
|
564 |
embeddings = self.cls_pooling(
|
565 |
+
token_embs, encoded_input['attention_mask']
|
566 |
)
|
567 |
else:
|
568 |
embeddings = self.mean_pooling(
|
569 |
+
token_embs, encoded_input['attention_mask']
|
570 |
)
|
571 |
|
572 |
+
if normalize_embeddings:
|
573 |
+
embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
|
574 |
+
|
575 |
+
if convert_to_numpy:
|
576 |
+
embeddings = embeddings.cpu()
|
577 |
all_embeddings.extend(embeddings)
|
578 |
|
579 |
all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
|
|
|
582 |
if truncate_dim:
|
583 |
all_embeddings = self.truncate_embeddings(all_embeddings, truncate_dim)
|
584 |
|
|
|
|
|
|
|
|
|
|
|
|
|
585 |
if convert_to_tensor:
|
586 |
all_embeddings = torch.stack(all_embeddings)
|
587 |
elif convert_to_numpy:
|
588 |
+
all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
|
589 |
|
590 |
if input_was_string:
|
591 |
all_embeddings = all_embeddings[0]
|
|
|
593 |
self.train(is_training)
|
594 |
return all_embeddings
|
595 |
|
596 |
+
|
597 |
def truncate_embeddings(self, embeddings, truncate_dim):
|
598 |
if not self.config.matryoshka_dimensions:
|
599 |
logger.warning(
|
600 |
+
'Matryoshka embeddings are not supported, so dimension truncation will not be performed.'
|
601 |
)
|
602 |
return embeddings
|
603 |
elif truncate_dim in self.config.matryoshka_dimensions:
|
604 |
return [tensor[:truncate_dim] for tensor in embeddings]
|
605 |
else:
|
606 |
+
raise ValueError(f'The provided `truncate_dim` value of {truncate_dim} is not supported. '
|
607 |
+
f'Supported dimensions are {self.config.matryoshka_dimensions}.')
|
|
|
|
|
608 |
|
609 |
def mean_pooling(
|
610 |
self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
|
|
|
616 |
input_mask_expanded.sum(1), min=1e-9
|
617 |
)
|
618 |
|
|
|
|
|
619 |
|
620 |
+
def cls_pooling(
|
621 |
+
self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
|
622 |
+
):
|
623 |
+
return token_embeddings[:,0]
|
624 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
625 |
|
626 |
def forward(
|
627 |
self,
|
|
|
631 |
attention_mask=None,
|
632 |
masked_tokens_mask=None,
|
633 |
return_dict=None,
|
|
|
634 |
**kwargs,
|
635 |
):
|
636 |
"""If masked_tokens_mask is not None (i.e. last_layer_subset == True in XLMForPreTraining),
|
|
|
638 |
layer output for these tokens.
|
639 |
masked_tokens_mask: (batch, seqlen), dtype=torch.bool
|
640 |
"""
|
641 |
+
|
642 |
if kwargs:
|
643 |
for key, value in kwargs.items():
|
644 |
if value is not None:
|
645 |
logger.warning(
|
646 |
+
'Flash attention implementation does not support kwargs: %s',
|
647 |
key,
|
648 |
)
|
649 |
|
|
|
652 |
)
|
653 |
|
654 |
hidden_states = self.embeddings(
|
655 |
+
input_ids, position_ids=position_ids, token_type_ids=token_type_ids
|
|
|
|
|
|
|
656 |
)
|
657 |
# TD [2022-12:18]: Don't need to force residual in fp32
|
658 |
# BERT puts embedding LayerNorm before embedding dropout.
|
|
|
676 |
subset_mask = None
|
677 |
|
678 |
sequence_output = self.encoder(
|
679 |
+
hidden_states, key_padding_mask=attention_mask, subset_mask=subset_mask
|
|
|
|
|
|
|
|
|
680 |
)
|
681 |
|
|
|
|
|
|
|
|
|
|
|
|
|
682 |
if masked_tokens_mask is None:
|
683 |
pooled_output = (
|
684 |
+
self.pooler(sequence_output) if self.pooler is not None else None
|
|
|
|
|
685 |
)
|
686 |
else:
|
687 |
# TD [2022-03-01]: the indexing here is very tricky.
|
|
|
695 |
pool_input = sequence_output[first_col_mask[subset_mask]]
|
696 |
sequence_output = sequence_output[masked_tokens_mask[subset_mask]]
|
697 |
pooled_output = (
|
698 |
+
self.pooler(pool_input, pool=False) if self.pooler is not None else None
|
|
|
|
|
699 |
)
|
700 |
|
701 |
if not return_dict:
|
|
|
704 |
return BaseModelOutputWithPoolingAndCrossAttentions(
|
705 |
last_hidden_state=sequence_output,
|
706 |
pooler_output=pooled_output,
|
|
|
707 |
)
|
708 |
|
709 |
|
|
|
800 |
)
|
801 |
|
802 |
|
803 |
+
# class XLMRobertaForPreTraining(XLMRobertaPreTrainedModel):
|
804 |
+
# def __init__(self, config: XLMRobertaFlashConfig):
|
805 |
+
# super().__init__(config)
|
806 |
+
# # If dense_seq_output, we only need to pass the hidden states for the masked out tokens
|
807 |
+
# # (around 15%) to the classifier heads.
|
808 |
+
# self.dense_seq_output = getattr(config, "dense_seq_output", False)
|
809 |
+
# # If last_layer_subset, we only need the compute the last layer for a subset of tokens
|
810 |
+
# # (e.g., the tokens we need to compute the masked LM loss and the next-sentence prediction).
|
811 |
+
# self.last_layer_subset = getattr(config, "last_layer_subset", False)
|
812 |
+
# if self.last_layer_subset:
|
813 |
+
# assert self.dense_seq_output, "last_layer_subset requires dense_seq_output"
|
814 |
+
# use_xentropy = getattr(config, "use_xentropy", False)
|
815 |
+
# if use_xentropy and CrossEntropyLoss is None:
|
816 |
+
# raise ImportError("xentropy_cuda is not installed")
|
817 |
+
# loss_cls = (
|
818 |
+
# nn.CrossEntropyLoss
|
819 |
+
# if not use_xentropy
|
820 |
+
# else partial(CrossEntropyLoss, inplace_backward=True)
|
821 |
+
# )
|
822 |
+
#
|
823 |
+
# self.xlm = XLMRobertaModel(config)
|
824 |
+
# self.cls = XLMRobertaPreTrainingHeads(config)
|
825 |
+
# self.mlm_loss = loss_cls(ignore_index=0)
|
826 |
+
# self.nsp_loss = loss_cls(ignore_index=-1)
|
827 |
+
#
|
828 |
+
# # Initialize weights and apply final processing
|
829 |
+
# self.apply(partial(_init_weights, initializer_range=config.initializer_range))
|
830 |
+
# self.tie_weights()
|
831 |
+
#
|
832 |
+
# def tie_weights(self):
|
833 |
+
# self.cls.predictions.decoder.weight = self.xlm.embeddings.word_embeddings.weight
|
834 |
+
#
|
835 |
+
# def forward(
|
836 |
+
# self,
|
837 |
+
# input_ids,
|
838 |
+
# position_ids=None,
|
839 |
+
# token_type_ids=None,
|
840 |
+
# attention_mask=None,
|
841 |
+
# labels=None,
|
842 |
+
# next_sentence_label=None,
|
843 |
+
# ):
|
844 |
+
# """
|
845 |
+
# If labels are provided, they must be 0 for masked out tokens (as specified in the attention
|
846 |
+
# mask).
|
847 |
+
# Outputs:
|
848 |
+
# if `labels` and `next_sentence_label` are not `None`:
|
849 |
+
# Outputs the total_loss which is the sum of the masked language modeling loss and the next
|
850 |
+
# sentence classification loss.
|
851 |
+
# if `labels` or `next_sentence_label` is `None`:
|
852 |
+
# Outputs a tuple comprising
|
853 |
+
# - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and
|
854 |
+
# - the next sentence classification logits of shape [batch_size, 2].
|
855 |
+
#
|
856 |
+
# """
|
857 |
+
# masked_tokens_mask = labels > 0 if (self.last_layer_subset and labels is not None) else None
|
858 |
+
# outputs = self.xlm(
|
859 |
+
# input_ids,
|
860 |
+
# position_ids=position_ids,
|
861 |
+
# token_type_ids=token_type_ids,
|
862 |
+
# attention_mask=attention_mask.bool() if attention_mask is not None else None,
|
863 |
+
# masked_tokens_mask=masked_tokens_mask,
|
864 |
+
# )
|
865 |
+
# sequence_output, pooled_output = outputs.last_hidden_state, outputs.pooler_output
|
866 |
+
# if self.dense_seq_output and labels is not None:
|
867 |
+
# masked_token_idx = torch.nonzero(labels.flatten() > 0, as_tuple=False).flatten()
|
868 |
+
# if not self.last_layer_subset:
|
869 |
+
# sequence_output = index_first_axis(
|
870 |
+
# rearrange(sequence_output, "b s d -> (b s) d"), masked_token_idx
|
871 |
+
# )
|
872 |
+
# prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
|
873 |
+
#
|
874 |
+
# total_loss = None
|
875 |
+
# if labels is not None and next_sentence_label is not None:
|
876 |
+
# if (
|
877 |
+
# self.dense_seq_output and labels is not None
|
878 |
+
# ): # prediction_scores are already flattened
|
879 |
+
# masked_lm_loss = self.mlm_loss(
|
880 |
+
# prediction_scores, labels.flatten()[masked_token_idx]
|
881 |
+
# )
|
882 |
+
# else:
|
883 |
+
# masked_lm_loss = self.mlm_loss(
|
884 |
+
# rearrange(prediction_scores, "... v -> (...) v"),
|
885 |
+
# rearrange(labels, "... -> (...)"),
|
886 |
+
# )
|
887 |
+
# next_sentence_loss = self.nsp_loss(
|
888 |
+
# rearrange(seq_relationship_score, "... t -> (...) t"),
|
889 |
+
# rearrange(next_sentence_label, "... -> (...)"),
|
890 |
+
# )
|
891 |
+
# total_loss = masked_lm_loss.float() + next_sentence_loss.float()
|
892 |
+
#
|
893 |
+
# return BertForPreTrainingOutput(
|
894 |
+
# loss=total_loss,
|
895 |
+
# prediction_logits=prediction_scores,
|
896 |
+
# seq_relationship_logits=seq_relationship_score,
|
897 |
+
# )
|
898 |
+
|
899 |
+
|
900 |
def remap_state_dict(state_dict, config: PretrainedConfig):
|
901 |
"""
|
902 |
Map the state_dict of a Huggingface BERT model to be flash_attn compatible.
|
|
|
1048 |
if not last_layer_subset or d != (config.num_hidden_layers - 1):
|
1049 |
Wqkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.weight")
|
1050 |
Wqkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.bias")
|
1051 |
+
state_dict[
|
1052 |
+
f"bert.encoder.layers.{d}.attention.self.query.weight"
|
1053 |
+
] = Wqkv_weights[: Wqkv_weights.shape[0] // 3, :]
|
1054 |
+
state_dict[
|
1055 |
+
f"bert.encoder.layers.{d}.attention.self.key.weight"
|
1056 |
+
] = Wqkv_weights[
|
1057 |
+
Wqkv_weights.shape[0] // 3 : 2 * Wqkv_weights.shape[0] // 3, :
|
1058 |
+
]
|
1059 |
+
state_dict[
|
1060 |
+
f"bert.encoder.layers.{d}.attention.self.value.weight"
|
1061 |
+
] = Wqkv_weights[2 * Wqkv_weights.shape[0] // 3 :, :]
|
1062 |
+
state_dict[
|
1063 |
+
f"bert.encoder.layers.{d}.attention.self.query.bias"
|
1064 |
+
] = Wqkv_biases[: Wqkv_biases.shape[0] // 3]
|
1065 |
+
state_dict[
|
1066 |
+
f"bert.encoder.layers.{d}.attention.self.key.bias"
|
1067 |
+
] = Wqkv_biases[Wqkv_biases.shape[0] // 3 : 2 * Wqkv_biases.shape[0] // 3]
|
1068 |
+
state_dict[
|
1069 |
+
f"bert.encoder.layers.{d}.attention.self.value.bias"
|
1070 |
+
] = Wqkv_biases[2 * Wqkv_biases.shape[0] // 3 :]
|
1071 |
else:
|
1072 |
Wq_weight = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.weight")
|
1073 |
Wkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.weight")
|
1074 |
Wq_bias = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.bias")
|
1075 |
Wkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.bias")
|
1076 |
+
state_dict[
|
1077 |
+
f"bert.encoder.layers.{d}.attention.self.query.weight"
|
1078 |
+
] = Wq_weight
|
1079 |
+
state_dict[
|
1080 |
+
f"bert.encoder.layers.{d}.attention.self.key.weight"
|
1081 |
+
] = Wkv_weights[: Wkv_weights.shape[0] // 2, :]
|
1082 |
+
state_dict[
|
1083 |
+
f"bert.encoder.layers.{d}.attention.self.value.weight"
|
1084 |
+
] = Wkv_weights[Wkv_weights.shape[0] // 2 :, :]
|
1085 |
state_dict[f"bert.encoder.layers.{d}.attention.self.query.bias"] = Wq_bias
|
1086 |
state_dict[f"bert.encoder.layers.{d}.attention.self.key.bias"] = Wkv_biases[
|
1087 |
: Wkv_biases.shape[0] // 2
|
1088 |
]
|
1089 |
+
state_dict[
|
1090 |
+
f"bert.encoder.layers.{d}.attention.self.value.bias"
|
1091 |
+
] = Wkv_biases[Wkv_biases.shape[0] // 2 :]
|
1092 |
|
1093 |
def inv_key_mapping_ln(key):
|
1094 |
key = re.sub(r"bert.emb_ln.", "bert.embeddings.LayerNorm.", key)
|
modeling_xlm_roberta_for_glue.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Union, Tuple
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
|
6 |
+
from transformers.modeling_outputs import SequenceClassifierOutput, QuestionAnsweringModelOutput, TokenClassifierOutput
|
7 |
+
|
8 |
+
from .modeling_xlm_roberta import XLMRobertaPreTrainedModel, XLMRobertaModel
|
9 |
+
from .configuration_xlm_roberta import XLMRobertaFlashConfig
|
10 |
+
|
11 |
+
|
12 |
+
class XLMRobertaForSequenceClassification(XLMRobertaPreTrainedModel):
|
13 |
+
def __init__(self, config: XLMRobertaFlashConfig):
|
14 |
+
super().__init__(config)
|
15 |
+
self.num_labels = config.num_labels
|
16 |
+
self.config = config
|
17 |
+
|
18 |
+
self.roberta = XLMRobertaModel(config)
|
19 |
+
classifier_dropout = (
|
20 |
+
config.classifier_dropout
|
21 |
+
if config.classifier_dropout is not None
|
22 |
+
else config.hidden_dropout_prob
|
23 |
+
)
|
24 |
+
self.dropout = nn.Dropout(classifier_dropout)
|
25 |
+
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
26 |
+
|
27 |
+
# Initialize weights and apply final processing
|
28 |
+
self.post_init()
|
29 |
+
|
30 |
+
|
31 |
+
def forward(
|
32 |
+
self,
|
33 |
+
input_ids: Optional[torch.Tensor] = None,
|
34 |
+
attention_mask: Optional[torch.Tensor] = None,
|
35 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
36 |
+
position_ids: Optional[torch.Tensor] = None,
|
37 |
+
head_mask: Optional[torch.Tensor] = None,
|
38 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
39 |
+
labels: Optional[torch.Tensor] = None,
|
40 |
+
output_attentions: Optional[bool] = None,
|
41 |
+
output_hidden_states: Optional[bool] = None,
|
42 |
+
return_dict: Optional[bool] = None,
|
43 |
+
) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
|
44 |
+
r"""
|
45 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
46 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
47 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
48 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
49 |
+
"""
|
50 |
+
return_dict = (
|
51 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
52 |
+
)
|
53 |
+
|
54 |
+
assert head_mask is None
|
55 |
+
assert inputs_embeds is None
|
56 |
+
assert output_attentions is None
|
57 |
+
assert output_hidden_states is None
|
58 |
+
assert return_dict
|
59 |
+
outputs = self.roberta(
|
60 |
+
input_ids,
|
61 |
+
attention_mask=attention_mask,
|
62 |
+
token_type_ids=token_type_ids,
|
63 |
+
position_ids=position_ids,
|
64 |
+
head_mask=head_mask,
|
65 |
+
inputs_embeds=inputs_embeds,
|
66 |
+
output_attentions=output_attentions,
|
67 |
+
output_hidden_states=output_hidden_states,
|
68 |
+
return_dict=return_dict,
|
69 |
+
)
|
70 |
+
|
71 |
+
pooled_output = outputs[1]
|
72 |
+
|
73 |
+
pooled_output = self.dropout(pooled_output)
|
74 |
+
logits = self.classifier(pooled_output)
|
75 |
+
|
76 |
+
loss = None
|
77 |
+
if labels is not None:
|
78 |
+
if self.config.problem_type is None:
|
79 |
+
if self.num_labels == 1:
|
80 |
+
self.config.problem_type = "regression"
|
81 |
+
elif self.num_labels > 1 and (
|
82 |
+
labels.dtype == torch.long or labels.dtype == torch.int
|
83 |
+
):
|
84 |
+
self.config.problem_type = "single_label_classification"
|
85 |
+
else:
|
86 |
+
self.config.problem_type = "multi_label_classification"
|
87 |
+
|
88 |
+
if self.config.problem_type == "regression":
|
89 |
+
loss_fct = MSELoss()
|
90 |
+
if self.num_labels == 1:
|
91 |
+
loss = loss_fct(logits.squeeze(), labels.squeeze())
|
92 |
+
else:
|
93 |
+
loss = loss_fct(logits, labels)
|
94 |
+
elif self.config.problem_type == "single_label_classification":
|
95 |
+
loss_fct = CrossEntropyLoss()
|
96 |
+
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
97 |
+
elif self.config.problem_type == "multi_label_classification":
|
98 |
+
loss_fct = BCEWithLogitsLoss()
|
99 |
+
loss = loss_fct(logits, labels)
|
100 |
+
if not return_dict:
|
101 |
+
output = (logits,) + outputs[2:]
|
102 |
+
return ((loss,) + output) if loss is not None else output
|
103 |
+
|
104 |
+
return SequenceClassifierOutput(
|
105 |
+
loss=loss,
|
106 |
+
logits=logits,
|
107 |
+
hidden_states=outputs.hidden_states,
|
108 |
+
attentions=outputs.attentions,
|
109 |
+
)
|
pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:cfa8fa7c7e120199548fe7149512c0adfe58f6bc13ce19f09b895aa25e8af910
|
3 |
+
size 1113232188
|
rotary.py
DELETED
@@ -1,659 +0,0 @@
|
|
1 |
-
# This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py
|
2 |
-
# Commit id: 3566596ad867ee415dd3c12616dd50c610176f6c
|
3 |
-
# Rotary varlen support from https://github.com/Dao-AILab/flash-attention/pull/556
|
4 |
-
|
5 |
-
# Copyright (c) 2023, Tri Dao.
|
6 |
-
|
7 |
-
from typing import Optional, Tuple, Union
|
8 |
-
|
9 |
-
import torch
|
10 |
-
from einops import rearrange, repeat
|
11 |
-
|
12 |
-
if torch.cuda.is_available():
|
13 |
-
try:
|
14 |
-
from flash_attn.ops.triton.rotary import apply_rotary
|
15 |
-
except ImportError:
|
16 |
-
|
17 |
-
def apply_rotary(*args, **kwargs):
|
18 |
-
raise RuntimeError(
|
19 |
-
"FlashAttention is not installed. To proceed with training, please install FlashAttention. "
|
20 |
-
"For inference, you have two options: either install FlashAttention or disable it by setting use_flash_attn=False when loading the model."
|
21 |
-
)
|
22 |
-
|
23 |
-
|
24 |
-
def rotate_half(x, interleaved=False):
|
25 |
-
if not interleaved:
|
26 |
-
x1, x2 = x.chunk(2, dim=-1)
|
27 |
-
return torch.cat((-x2, x1), dim=-1)
|
28 |
-
else:
|
29 |
-
x1, x2 = x[..., ::2], x[..., 1::2]
|
30 |
-
return rearrange(
|
31 |
-
torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
|
32 |
-
)
|
33 |
-
|
34 |
-
|
35 |
-
def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
|
36 |
-
"""
|
37 |
-
x: (batch_size, seqlen, nheads, headdim)
|
38 |
-
cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
|
39 |
-
"""
|
40 |
-
ro_dim = cos.shape[-1] * 2
|
41 |
-
assert ro_dim <= x.shape[-1]
|
42 |
-
cos, sin = (
|
43 |
-
cos[: x.shape[1]],
|
44 |
-
sin[: x.shape[1]],
|
45 |
-
)
|
46 |
-
cos = repeat(
|
47 |
-
cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
|
48 |
-
)
|
49 |
-
sin = repeat(
|
50 |
-
sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
|
51 |
-
)
|
52 |
-
return torch.cat(
|
53 |
-
[
|
54 |
-
x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
|
55 |
-
x[..., ro_dim:],
|
56 |
-
],
|
57 |
-
dim=-1,
|
58 |
-
)
|
59 |
-
|
60 |
-
|
61 |
-
class ApplyRotaryEmb(torch.autograd.Function):
|
62 |
-
@staticmethod
|
63 |
-
def forward(
|
64 |
-
ctx,
|
65 |
-
x,
|
66 |
-
cos,
|
67 |
-
sin,
|
68 |
-
interleaved=False,
|
69 |
-
inplace=False,
|
70 |
-
seqlen_offsets: Union[int, torch.Tensor] = 0,
|
71 |
-
cu_seqlens: Optional[torch.Tensor] = None,
|
72 |
-
max_seqlen: Optional[int] = None,
|
73 |
-
):
|
74 |
-
out = apply_rotary(
|
75 |
-
x,
|
76 |
-
cos,
|
77 |
-
sin,
|
78 |
-
seqlen_offsets=seqlen_offsets,
|
79 |
-
cu_seqlens=cu_seqlens,
|
80 |
-
max_seqlen=max_seqlen,
|
81 |
-
interleaved=interleaved,
|
82 |
-
inplace=inplace,
|
83 |
-
)
|
84 |
-
|
85 |
-
if isinstance(seqlen_offsets, int):
|
86 |
-
ctx.save_for_backward(
|
87 |
-
cos, sin, cu_seqlens
|
88 |
-
) # Can't save int with save_for_backward
|
89 |
-
ctx.seqlen_offsets = seqlen_offsets
|
90 |
-
else:
|
91 |
-
ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)
|
92 |
-
ctx.seqlen_offsets = None
|
93 |
-
ctx.interleaved = interleaved
|
94 |
-
ctx.inplace = inplace
|
95 |
-
ctx.max_seqlen = max_seqlen
|
96 |
-
return out if not inplace else x
|
97 |
-
|
98 |
-
@staticmethod
|
99 |
-
def backward(ctx, do):
|
100 |
-
seqlen_offsets = ctx.seqlen_offsets
|
101 |
-
if seqlen_offsets is None:
|
102 |
-
cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
|
103 |
-
else:
|
104 |
-
cos, sin, cu_seqlens = ctx.saved_tensors
|
105 |
-
# TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with
|
106 |
-
# "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works.
|
107 |
-
if not ctx.interleaved and not ctx.inplace:
|
108 |
-
do = do.clone()
|
109 |
-
|
110 |
-
dx = apply_rotary(
|
111 |
-
do,
|
112 |
-
cos,
|
113 |
-
sin,
|
114 |
-
seqlen_offsets=seqlen_offsets,
|
115 |
-
cu_seqlens=cu_seqlens,
|
116 |
-
max_seqlen=ctx.max_seqlen,
|
117 |
-
interleaved=ctx.interleaved,
|
118 |
-
inplace=ctx.inplace,
|
119 |
-
conjugate=True,
|
120 |
-
)
|
121 |
-
return dx, None, None, None, None, None, None, None
|
122 |
-
|
123 |
-
|
124 |
-
def apply_rotary_emb(
|
125 |
-
x,
|
126 |
-
cos,
|
127 |
-
sin,
|
128 |
-
interleaved=False,
|
129 |
-
inplace=False,
|
130 |
-
seqlen_offsets: Union[int, torch.Tensor] = 0,
|
131 |
-
cu_seqlens: Optional[torch.Tensor] = None,
|
132 |
-
max_seqlen: Optional[int] = None,
|
133 |
-
):
|
134 |
-
"""
|
135 |
-
Arguments:
|
136 |
-
x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
|
137 |
-
else (total_seqlen, nheads, headdim)
|
138 |
-
cos, sin: (seqlen_rotary, rotary_dim / 2)
|
139 |
-
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
|
140 |
-
of 1st half and 2nd half (GPT-NeoX style).
|
141 |
-
inplace: if True, apply rotary embedding in-place.
|
142 |
-
seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
|
143 |
-
Most commonly used in inference when we have KV cache.
|
144 |
-
cu_seqlens: (batch + 1,) or None
|
145 |
-
max_seqlen: int
|
146 |
-
Return:
|
147 |
-
out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
|
148 |
-
else (total_seqlen, nheads, headdim)
|
149 |
-
rotary_dim must be <= headdim
|
150 |
-
Apply rotary embedding to the first rotary_dim of x.
|
151 |
-
"""
|
152 |
-
return ApplyRotaryEmb.apply(
|
153 |
-
x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen
|
154 |
-
)
|
155 |
-
|
156 |
-
|
157 |
-
# For backward compatibility
|
158 |
-
apply_rotary_emb_func = apply_rotary_emb
|
159 |
-
|
160 |
-
|
161 |
-
class ApplyRotaryEmbQKV_(torch.autograd.Function):
|
162 |
-
@staticmethod
|
163 |
-
def forward(
|
164 |
-
ctx,
|
165 |
-
qkv,
|
166 |
-
cos,
|
167 |
-
sin,
|
168 |
-
cos_k=None,
|
169 |
-
sin_k=None,
|
170 |
-
interleaved=False,
|
171 |
-
seqlen_offsets: Union[int, torch.Tensor] = 0,
|
172 |
-
cu_seqlens: Optional[torch.Tensor] = None,
|
173 |
-
max_seqlen: Optional[int] = None,
|
174 |
-
use_flash_attn: bool = True,
|
175 |
-
):
|
176 |
-
# batch, seqlen, three, nheads, headdim = qkv.shape
|
177 |
-
assert qkv.shape[-3] == 3
|
178 |
-
if cos_k is None and sin_k is None and qkv.is_contiguous():
|
179 |
-
|
180 |
-
if use_flash_attn:
|
181 |
-
# Call 1 kernel instead of 2 kernels
|
182 |
-
# We need qkv to be contiguous so that when we reshape to combine (3, nheads)
|
183 |
-
# dimensions, we get the same tensor
|
184 |
-
qk = rearrange(qkv[..., :2, :, :], "... t h d -> ... (t h) d")
|
185 |
-
# qk = qkv[:, :, :2].reshape(batch, seqlen, -1, headdim)
|
186 |
-
apply_rotary(
|
187 |
-
qk,
|
188 |
-
cos,
|
189 |
-
sin,
|
190 |
-
seqlen_offsets=seqlen_offsets,
|
191 |
-
interleaved=interleaved,
|
192 |
-
inplace=True,
|
193 |
-
cu_seqlens=cu_seqlens,
|
194 |
-
max_seqlen=max_seqlen,
|
195 |
-
)
|
196 |
-
else:
|
197 |
-
q_rot = apply_rotary_emb_torch(
|
198 |
-
qkv[:, :, 0],
|
199 |
-
cos,
|
200 |
-
sin,
|
201 |
-
interleaved=interleaved,
|
202 |
-
)
|
203 |
-
k_rot = apply_rotary_emb_torch(
|
204 |
-
qkv[:, :, 1],
|
205 |
-
cos,
|
206 |
-
sin,
|
207 |
-
interleaved=interleaved,
|
208 |
-
)
|
209 |
-
qkv = torch.stack((q_rot, k_rot, qkv[:, :, 2]), dim=2)
|
210 |
-
else:
|
211 |
-
cos_k = cos if cos_k is None else cos_k
|
212 |
-
sin_k = sin if sin_k is None else sin_k
|
213 |
-
q, k = qkv[..., 0, :, :], qkv[..., 1, :, :]
|
214 |
-
apply_rotary(
|
215 |
-
q,
|
216 |
-
cos,
|
217 |
-
sin,
|
218 |
-
seqlen_offsets,
|
219 |
-
interleaved=interleaved,
|
220 |
-
inplace=True,
|
221 |
-
cu_seqlens=cu_seqlens,
|
222 |
-
max_seqlen=max_seqlen,
|
223 |
-
)
|
224 |
-
apply_rotary(
|
225 |
-
k,
|
226 |
-
cos_k,
|
227 |
-
sin_k,
|
228 |
-
seqlen_offsets,
|
229 |
-
interleaved=interleaved,
|
230 |
-
inplace=True,
|
231 |
-
cu_seqlens=cu_seqlens,
|
232 |
-
max_seqlen=max_seqlen,
|
233 |
-
)
|
234 |
-
ctx.save_for_backward(cos, sin, cos_k, sin_k)
|
235 |
-
if isinstance(seqlen_offsets, int):
|
236 |
-
ctx.save_for_backward(cos, sin, cos_k, sin_k, cu_seqlens)
|
237 |
-
ctx.seqlen_offsets = seqlen_offsets
|
238 |
-
else:
|
239 |
-
ctx.save_for_backward(cos, sin, cos_k, sin_k, cu_seqlens, seqlen_offsets)
|
240 |
-
ctx.seqlen_offsets = None
|
241 |
-
ctx.max_seqlen = max_seqlen
|
242 |
-
ctx.interleaved = interleaved
|
243 |
-
return qkv
|
244 |
-
|
245 |
-
@staticmethod
|
246 |
-
def backward(ctx, dqkv):
|
247 |
-
seqlen_offsets = ctx.seqlen_offsets
|
248 |
-
if seqlen_offsets is None:
|
249 |
-
cos, sin, cos_k, sin_k, cu_seqlens, seqlen_offsets = ctx.saved_tensors
|
250 |
-
else:
|
251 |
-
cos, sin, cos_k, sin_k, cu_seqlens = ctx.saved_tensors
|
252 |
-
if cos_k is None and sin_k is None and dqkv.is_contiguous():
|
253 |
-
# Call 1 kernel instead of 2 kernels
|
254 |
-
# We need dqkv to be contiguous so that when we reshape to combine (3, nheads)
|
255 |
-
# dimensions, we get the same tensor
|
256 |
-
dqk = rearrange(dqkv[..., :2, :, :], "... t h d -> ... (t h) d")
|
257 |
-
apply_rotary(
|
258 |
-
dqk,
|
259 |
-
cos,
|
260 |
-
sin,
|
261 |
-
seqlen_offsets=seqlen_offsets,
|
262 |
-
interleaved=ctx.interleaved,
|
263 |
-
inplace=True,
|
264 |
-
conjugate=True,
|
265 |
-
cu_seqlens=cu_seqlens,
|
266 |
-
max_seqlen=ctx.max_seqlen,
|
267 |
-
)
|
268 |
-
else:
|
269 |
-
cos_k = cos if cos_k is None else cos_k
|
270 |
-
sin_k = sin if sin_k is None else sin_k
|
271 |
-
dq, dk = dqkv[..., 0, :, :], dqkv[..., 1, :, :]
|
272 |
-
apply_rotary(
|
273 |
-
dq,
|
274 |
-
cos,
|
275 |
-
sin,
|
276 |
-
seqlen_offsets,
|
277 |
-
interleaved=ctx.interleaved,
|
278 |
-
inplace=True,
|
279 |
-
conjugate=True,
|
280 |
-
cu_seqlens=cu_seqlens,
|
281 |
-
max_seqlen=ctx.max_seqlen,
|
282 |
-
)
|
283 |
-
apply_rotary(
|
284 |
-
dk,
|
285 |
-
cos_k,
|
286 |
-
sin_k,
|
287 |
-
seqlen_offsets,
|
288 |
-
interleaved=ctx.interleaved,
|
289 |
-
inplace=True,
|
290 |
-
conjugate=True,
|
291 |
-
cu_seqlens=cu_seqlens,
|
292 |
-
max_seqlen=ctx.max_seqlen,
|
293 |
-
)
|
294 |
-
return dqkv, None, None, None, None, None, None, None, None, None
|
295 |
-
|
296 |
-
|
297 |
-
def apply_rotary_emb_qkv_(
|
298 |
-
qkv,
|
299 |
-
cos,
|
300 |
-
sin,
|
301 |
-
cos_k=None,
|
302 |
-
sin_k=None,
|
303 |
-
interleaved=False,
|
304 |
-
seqlen_offsets: Union[int, torch.Tensor] = 0,
|
305 |
-
cu_seqlens: Optional[torch.Tensor] = None,
|
306 |
-
max_seqlen: Optional[int] = None,
|
307 |
-
use_flash_attn=True,
|
308 |
-
):
|
309 |
-
"""
|
310 |
-
Arguments:
|
311 |
-
qkv: (batch_size, seqlen, 3, nheads, headdim) if cu_seqlens is None
|
312 |
-
else (total_seqlen, 3, nheads, headdim)
|
313 |
-
cos, sin: (seqlen, rotary_dim / 2)
|
314 |
-
cos_k, sin_k: (seqlen, rotary_dim / 2), optional
|
315 |
-
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
|
316 |
-
1st half and 2nd half (GPT-NeoX style).
|
317 |
-
seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount.
|
318 |
-
Most commonly used in inference when we have KV cache.
|
319 |
-
cu_seqlens: (batch + 1,) or None
|
320 |
-
max_seqlen: int
|
321 |
-
Return:
|
322 |
-
qkv: (batch_size, seqlen, 3, nheads, headdim) if cu_seqlens is None
|
323 |
-
else (total_seqlen, 3, nheads, headdim)
|
324 |
-
rotary_dim must be <= headdim
|
325 |
-
Apply rotary embedding *inplace* to the first rotary_dim of Q and K.
|
326 |
-
"""
|
327 |
-
return ApplyRotaryEmbQKV_.apply(
|
328 |
-
qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets, cu_seqlens, max_seqlen, use_flash_attn,
|
329 |
-
)
|
330 |
-
|
331 |
-
|
332 |
-
class ApplyRotaryEmbKV_(torch.autograd.Function):
|
333 |
-
@staticmethod
|
334 |
-
def forward(
|
335 |
-
ctx,
|
336 |
-
kv,
|
337 |
-
cos,
|
338 |
-
sin,
|
339 |
-
interleaved=False,
|
340 |
-
seqlen_offsets: Union[int, torch.Tensor] = 0,
|
341 |
-
cu_seqlens: Optional[torch.Tensor] = None,
|
342 |
-
max_seqlen: Optional[int] = None,
|
343 |
-
):
|
344 |
-
# batch, seqlen, two, nheads, headdim = kv.shape
|
345 |
-
assert kv.shape[-3] == 2
|
346 |
-
k = kv[..., 0, :, :]
|
347 |
-
apply_rotary(
|
348 |
-
k,
|
349 |
-
cos,
|
350 |
-
sin,
|
351 |
-
seqlen_offsets=seqlen_offsets,
|
352 |
-
interleaved=interleaved,
|
353 |
-
inplace=True,
|
354 |
-
cu_seqlens=cu_seqlens,
|
355 |
-
max_seqlen=max_seqlen,
|
356 |
-
)
|
357 |
-
if isinstance(seqlen_offsets, int):
|
358 |
-
ctx.save_for_backward(
|
359 |
-
cos, sin, cu_seqlens
|
360 |
-
) # Can't save int with save_for_backward
|
361 |
-
ctx.seqlen_offsets = seqlen_offsets
|
362 |
-
else:
|
363 |
-
ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)
|
364 |
-
ctx.seqlen_offsets = None
|
365 |
-
ctx.max_seqlen = max_seqlen
|
366 |
-
ctx.interleaved = interleaved
|
367 |
-
return kv
|
368 |
-
|
369 |
-
@staticmethod
|
370 |
-
def backward(ctx, dkv):
|
371 |
-
seqlen_offsets = ctx.seqlen_offsets
|
372 |
-
if seqlen_offsets is None:
|
373 |
-
cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
|
374 |
-
else:
|
375 |
-
cos, sin, cu_seqlens = ctx.saved_tensors
|
376 |
-
apply_rotary(
|
377 |
-
dkv[..., 0, :, :],
|
378 |
-
cos,
|
379 |
-
sin,
|
380 |
-
seqlen_offsets=seqlen_offsets,
|
381 |
-
interleaved=ctx.interleaved,
|
382 |
-
inplace=True,
|
383 |
-
conjugate=True,
|
384 |
-
cu_seqlens=cu_seqlens,
|
385 |
-
max_seqlen=ctx.max_seqlen,
|
386 |
-
)
|
387 |
-
return dkv, None, None, None, None, None, None
|
388 |
-
|
389 |
-
|
390 |
-
apply_rotary_emb_kv_ = ApplyRotaryEmbKV_.apply
|
391 |
-
|
392 |
-
|
393 |
-
def apply_rotary_emb_kv_(
|
394 |
-
kv,
|
395 |
-
cos,
|
396 |
-
sin,
|
397 |
-
interleaved=False,
|
398 |
-
seqlen_offsets: Union[int, torch.Tensor] = 0,
|
399 |
-
cu_seqlens: Optional[torch.Tensor] = None,
|
400 |
-
max_seqlen: Optional[int] = None,
|
401 |
-
):
|
402 |
-
"""
|
403 |
-
Arguments:
|
404 |
-
kv: (batch_size, seqlen, 2, nheads, headdim) if cu_seqlens is None
|
405 |
-
else (total_seqlen, 2, nheads, headdim)
|
406 |
-
cos, sin: (seqlen, rotary_dim / 2)
|
407 |
-
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
|
408 |
-
1st half and 2nd half (GPT-NeoX style).
|
409 |
-
seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount.
|
410 |
-
Most commonly used in inference when we have KV cache.
|
411 |
-
cu_seqlens: (batch + 1,) or None
|
412 |
-
max_seqlen: int
|
413 |
-
Return:
|
414 |
-
kv: (batch_size, seqlen, 2, nheads, headdim) if cu_seqlens is None
|
415 |
-
else (total_seqlen, 2, nheads, headdim)
|
416 |
-
rotary_dim must be <= headdim
|
417 |
-
Apply rotary embedding *inplace* to the first rotary_dim of K.
|
418 |
-
"""
|
419 |
-
return ApplyRotaryEmbKV_.apply(
|
420 |
-
kv, cos, sin, interleaved, seqlen_offsets, cu_seqlens, max_seqlen
|
421 |
-
)
|
422 |
-
|
423 |
-
|
424 |
-
class RotaryEmbedding(torch.nn.Module):
|
425 |
-
"""
|
426 |
-
The rotary position embeddings from RoFormer_ (Su et. al).
|
427 |
-
A crucial insight from the method is that the query and keys are
|
428 |
-
transformed by rotation matrices which depend on the relative positions.
|
429 |
-
|
430 |
-
Other implementations are available in the Rotary Transformer repo_ and in
|
431 |
-
GPT-NeoX_, GPT-NeoX was an inspiration
|
432 |
-
|
433 |
-
.. _RoFormer: https://arxiv.org/abs/2104.09864
|
434 |
-
.. _repo: https://github.com/ZhuiyiTechnology/roformer
|
435 |
-
.. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
|
436 |
-
|
437 |
-
If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
|
438 |
-
A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
|
439 |
-
Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
|
440 |
-
"""
|
441 |
-
|
442 |
-
def __init__(
|
443 |
-
self,
|
444 |
-
dim: int,
|
445 |
-
base=10000.0,
|
446 |
-
interleaved=False,
|
447 |
-
scale_base=None,
|
448 |
-
pos_idx_in_fp32=True,
|
449 |
-
device=None,
|
450 |
-
use_flash_attn=True,
|
451 |
-
):
|
452 |
-
"""
|
453 |
-
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
|
454 |
-
of 1st half and 2nd half (GPT-NeoX style).
|
455 |
-
pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
|
456 |
-
otherwise they might be in lower precision.
|
457 |
-
This option was added because previously (before 2023-07-02), when we construct
|
458 |
-
the position indices, we use the dtype of self.inv_freq. In most cases this would
|
459 |
-
be fp32, but if the model is trained in pure bf16 (not mixed precision), then
|
460 |
-
self.inv_freq would be bf16, and the position indices are also in bf16.
|
461 |
-
Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
|
462 |
-
embeddings for some positions will coincide.
|
463 |
-
To maintain compatibility with models previously trained in pure bf16,
|
464 |
-
we add this option.
|
465 |
-
"""
|
466 |
-
super().__init__()
|
467 |
-
self.dim = dim
|
468 |
-
self._base = float(base)
|
469 |
-
self.pos_idx_in_fp32 = pos_idx_in_fp32
|
470 |
-
self.use_flash_attn = use_flash_attn
|
471 |
-
# Generate and save the inverse frequency buffer (non trainable)
|
472 |
-
inv_freq = self._compute_inv_freq(device)
|
473 |
-
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
474 |
-
self.interleaved = interleaved
|
475 |
-
self.scale_base = scale_base
|
476 |
-
scale = (
|
477 |
-
(torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim)
|
478 |
-
/ (1.4 * dim)
|
479 |
-
if scale_base is not None
|
480 |
-
else None
|
481 |
-
)
|
482 |
-
self.register_buffer("scale", scale, persistent=False)
|
483 |
-
|
484 |
-
self._seq_len_cached = 0
|
485 |
-
self._cos_cached = None
|
486 |
-
self._sin_cached = None
|
487 |
-
self._cos_k_cached = None
|
488 |
-
self._sin_k_cached = None
|
489 |
-
|
490 |
-
@property
|
491 |
-
def base(self):
|
492 |
-
return self._base
|
493 |
-
|
494 |
-
@base.setter
|
495 |
-
def base(self, new_base):
|
496 |
-
new_base = float(new_base)
|
497 |
-
if new_base > 0:
|
498 |
-
if self._base != new_base: # only update if the base value has changed
|
499 |
-
self._base = new_base
|
500 |
-
self._update_cos_sin_cache(
|
501 |
-
self._seq_len_cached,
|
502 |
-
device=self.inv_freq.device,
|
503 |
-
dtype=self._cos_cached.dtype if self._cos_cached is not None else None,
|
504 |
-
rotary_base_changed=True,
|
505 |
-
)
|
506 |
-
else:
|
507 |
-
raise ValueError("Rotary base value must be positive")
|
508 |
-
|
509 |
-
def _compute_inv_freq(self, device=None):
|
510 |
-
return 1.0 / (
|
511 |
-
self.base
|
512 |
-
** (
|
513 |
-
torch.arange(0, self.dim, 2, device=device, dtype=torch.float32)
|
514 |
-
/ self.dim
|
515 |
-
)
|
516 |
-
)
|
517 |
-
|
518 |
-
def _update_cos_sin_cache(
|
519 |
-
self, seqlen, device=None, dtype=None, rotary_base_changed=False
|
520 |
-
):
|
521 |
-
# Reset the tables if the sequence length has changed,
|
522 |
-
# if we're on a new device (possibly due to tracing for instance),
|
523 |
-
# or if we're switching from inference mode to training
|
524 |
-
# or if the rotary base value was changed
|
525 |
-
if (
|
526 |
-
seqlen > self._seq_len_cached
|
527 |
-
or self._cos_cached is None
|
528 |
-
or self._cos_cached.device != device
|
529 |
-
or self._cos_cached.dtype != dtype
|
530 |
-
or (self.training and self._cos_cached.is_inference())
|
531 |
-
or rotary_base_changed
|
532 |
-
):
|
533 |
-
self._seq_len_cached = seqlen
|
534 |
-
# We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
|
535 |
-
# And the output of arange can be quite large, so bf16 would lose a lot of precision.
|
536 |
-
# However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
|
537 |
-
if rotary_base_changed:
|
538 |
-
self.inv_freq = self._compute_inv_freq(device=device)
|
539 |
-
if self.pos_idx_in_fp32:
|
540 |
-
t = torch.arange(seqlen, device=device, dtype=torch.float32)
|
541 |
-
# We want fp32 here as well since inv_freq will be multiplied with t, and the output
|
542 |
-
# will be large. Having it in bf16 will lose a lot of precision and cause the
|
543 |
-
# cos & sin output to change significantly.
|
544 |
-
# We want to recompute self.inv_freq if it was not loaded in fp32
|
545 |
-
if self.inv_freq.dtype != torch.float32:
|
546 |
-
inv_freq = self._compute_inv_freq(device=device)
|
547 |
-
else:
|
548 |
-
inv_freq = self.inv_freq
|
549 |
-
else:
|
550 |
-
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
|
551 |
-
inv_freq = self.inv_freq
|
552 |
-
|
553 |
-
# Don't do einsum, it converts fp32 to fp16 under AMP
|
554 |
-
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
555 |
-
freqs = torch.outer(t, inv_freq)
|
556 |
-
if self.scale is None:
|
557 |
-
self._cos_cached = torch.cos(freqs).to(dtype)
|
558 |
-
self._sin_cached = torch.sin(freqs).to(dtype)
|
559 |
-
else:
|
560 |
-
power = (
|
561 |
-
torch.arange(
|
562 |
-
seqlen, dtype=self.scale.dtype, device=self.scale.device
|
563 |
-
)
|
564 |
-
- seqlen // 2
|
565 |
-
) / self.scale_base
|
566 |
-
scale = self.scale.to(device=power.device) ** rearrange(
|
567 |
-
power, "s -> s 1"
|
568 |
-
)
|
569 |
-
# We want the multiplication by scale to happen in fp32
|
570 |
-
self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
|
571 |
-
self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
|
572 |
-
self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
|
573 |
-
self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
|
574 |
-
|
575 |
-
def forward(
|
576 |
-
self,
|
577 |
-
qkv: torch.Tensor,
|
578 |
-
kv: Optional[torch.Tensor] = None,
|
579 |
-
seqlen_offset: Union[int, torch.Tensor] = 0,
|
580 |
-
cu_seqlens: Optional[torch.Tensor] = None,
|
581 |
-
max_seqlen: Optional[int] = None,
|
582 |
-
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
583 |
-
"""
|
584 |
-
qkv: (batch, seqlen, 3, nheads, headdim) if kv is none,
|
585 |
-
else it's just q of shape (batch, seqlen, nheads, headdim)
|
586 |
-
kv: (batch, seqlen, 2, nheads, headdim)
|
587 |
-
seqlen_offset: (batch_size,) or int. Each sequence in x is shifted by this amount.
|
588 |
-
Most commonly used in inference when we have KV cache.
|
589 |
-
If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one
|
590 |
-
should pass in max_seqlen, which will update the cos / sin cache up to that length.
|
591 |
-
Apply rotary embedding *inplace* to qkv and / or kv.
|
592 |
-
"""
|
593 |
-
if cu_seqlens is not None:
|
594 |
-
assert max_seqlen is not None
|
595 |
-
seqlen = qkv.shape[1] if max_seqlen is None else max_seqlen
|
596 |
-
if max_seqlen is not None:
|
597 |
-
self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
|
598 |
-
elif isinstance(seqlen_offset, int):
|
599 |
-
self._update_cos_sin_cache(
|
600 |
-
seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype
|
601 |
-
)
|
602 |
-
if kv is None:
|
603 |
-
if self.scale is None:
|
604 |
-
return apply_rotary_emb_qkv_(
|
605 |
-
qkv,
|
606 |
-
self._cos_cached,
|
607 |
-
self._sin_cached,
|
608 |
-
interleaved=self.interleaved,
|
609 |
-
seqlen_offsets=seqlen_offset,
|
610 |
-
cu_seqlens=cu_seqlens,
|
611 |
-
max_seqlen=max_seqlen,
|
612 |
-
use_flash_attn=self.use_flash_attn,
|
613 |
-
)
|
614 |
-
else:
|
615 |
-
return apply_rotary_emb_qkv_(
|
616 |
-
qkv,
|
617 |
-
self._cos_cached,
|
618 |
-
self._sin_cached,
|
619 |
-
self._cos_k_cached,
|
620 |
-
self._sin_k_cached,
|
621 |
-
interleaved=self.interleaved,
|
622 |
-
seqlen_offsets=seqlen_offset,
|
623 |
-
cu_seqlens=cu_seqlens,
|
624 |
-
max_seqlen=max_seqlen,
|
625 |
-
use_flash_attn=self.use_flash_attn,
|
626 |
-
)
|
627 |
-
else:
|
628 |
-
q = qkv
|
629 |
-
q = apply_rotary_emb_func(
|
630 |
-
q,
|
631 |
-
self._cos_cached,
|
632 |
-
self._sin_cached,
|
633 |
-
interleaved=self.interleaved,
|
634 |
-
inplace=True,
|
635 |
-
seqlen_offsets=seqlen_offset,
|
636 |
-
cu_seqlens=cu_seqlens,
|
637 |
-
max_seqlen=max_seqlen,
|
638 |
-
)
|
639 |
-
if self.scale is None:
|
640 |
-
kv = apply_rotary_emb_kv_(
|
641 |
-
kv,
|
642 |
-
self._cos_cached,
|
643 |
-
self._sin_cached,
|
644 |
-
interleaved=self.interleaved,
|
645 |
-
seqlen_offsets=seqlen_offset,
|
646 |
-
cu_seqlens=cu_seqlens,
|
647 |
-
max_seqlen=max_seqlen,
|
648 |
-
)
|
649 |
-
else:
|
650 |
-
kv = apply_rotary_emb_kv_(
|
651 |
-
kv,
|
652 |
-
self._cos_k_cached,
|
653 |
-
self._sin_k_cached,
|
654 |
-
interleaved=self.interleaved,
|
655 |
-
seqlen_offsets=seqlen_offset,
|
656 |
-
cu_seqlens=cu_seqlens,
|
657 |
-
max_seqlen=max_seqlen,
|
658 |
-
)
|
659 |
-
return q, kv
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
stochastic_depth.py
CHANGED
@@ -34,7 +34,7 @@
|
|
34 |
|
35 |
import torch
|
36 |
import torch.fx
|
37 |
-
from torch import
|
38 |
|
39 |
|
40 |
def stochastic_depth(
|
|
|
34 |
|
35 |
import torch
|
36 |
import torch.fx
|
37 |
+
from torch import nn, Tensor
|
38 |
|
39 |
|
40 |
def stochastic_depth(
|
tokenizer.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
tokenizer_config.json
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"model_max_length": 8194,
|
3 |
+
"tokenizer_class": "XLMRobertaTokenizer"
|
4 |
+
}
|
xlm_padding.py
CHANGED
@@ -18,9 +18,7 @@ class IndexFirstAxis(torch.autograd.Function):
|
|
18 |
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
|
19 |
# return input[indices]
|
20 |
return torch.gather(
|
21 |
-
rearrange(input, "b ... -> b (...)"),
|
22 |
-
0,
|
23 |
-
repeat(indices, "z -> z d", d=second_dim),
|
24 |
).reshape(-1, *other_shape)
|
25 |
|
26 |
@staticmethod
|
@@ -36,9 +34,7 @@ class IndexFirstAxis(torch.autograd.Function):
|
|
36 |
)
|
37 |
# TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
|
38 |
# grad_input[indices] = grad_output
|
39 |
-
grad_input.scatter_(
|
40 |
-
0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output
|
41 |
-
)
|
42 |
return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
|
43 |
|
44 |
|
@@ -102,7 +98,7 @@ class IndexFirstAxisResidual(torch.autograd.Function):
|
|
102 |
index_first_axis_residual = IndexFirstAxisResidual.apply
|
103 |
|
104 |
|
105 |
-
def unpad_input(hidden_states, attention_mask
|
106 |
"""
|
107 |
Arguments:
|
108 |
hidden_states: (batch, seqlen, ...)
|
@@ -116,16 +112,7 @@ def unpad_input(hidden_states, attention_mask, adapter_mask=None):
|
|
116 |
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
117 |
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
118 |
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
119 |
-
cu_seqlens = F.pad(
|
120 |
-
torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)
|
121 |
-
)
|
122 |
-
|
123 |
-
cu_adapter_mask = (
|
124 |
-
torch.repeat_interleave(adapter_mask, cu_seqlens[1:] - cu_seqlens[:-1])
|
125 |
-
if adapter_mask is not None
|
126 |
-
else None
|
127 |
-
)
|
128 |
-
|
129 |
# TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
|
130 |
# bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
|
131 |
# times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
|
@@ -136,7 +123,6 @@ def unpad_input(hidden_states, attention_mask, adapter_mask=None):
|
|
136 |
indices,
|
137 |
cu_seqlens,
|
138 |
max_seqlen_in_batch,
|
139 |
-
cu_adapter_mask,
|
140 |
)
|
141 |
|
142 |
|
@@ -194,18 +180,14 @@ def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_leng
|
|
194 |
"""
|
195 |
length = attention_mask_in_length.sum(dim=-1)
|
196 |
seqlen = attention_mask_in_length.size(-1)
|
197 |
-
attention_mask_2d = torch.arange(
|
198 |
-
|
199 |
-
|
200 |
-
real_indices_idx = torch.nonzero(
|
201 |
-
attention_mask_in_length.flatten(), as_tuple=False
|
202 |
-
).flatten()
|
203 |
seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx]
|
204 |
indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten()
|
205 |
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
206 |
-
cu_seqlens = F.pad(
|
207 |
-
torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)
|
208 |
-
)
|
209 |
# TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
|
210 |
# bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
|
211 |
# times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
|
@@ -233,4 +215,4 @@ def pad_input(hidden_states, indices, batch, seqlen):
|
|
233 |
# output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
|
234 |
# output[indices] = hidden_states
|
235 |
output = index_put_first_axis(hidden_states, indices, batch * seqlen)
|
236 |
-
return rearrange(output, "(b s) ... -> b s ...", b=batch)
|
|
|
18 |
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
|
19 |
# return input[indices]
|
20 |
return torch.gather(
|
21 |
+
rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)
|
|
|
|
|
22 |
).reshape(-1, *other_shape)
|
23 |
|
24 |
@staticmethod
|
|
|
34 |
)
|
35 |
# TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
|
36 |
# grad_input[indices] = grad_output
|
37 |
+
grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output)
|
|
|
|
|
38 |
return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
|
39 |
|
40 |
|
|
|
98 |
index_first_axis_residual = IndexFirstAxisResidual.apply
|
99 |
|
100 |
|
101 |
+
def unpad_input(hidden_states, attention_mask):
|
102 |
"""
|
103 |
Arguments:
|
104 |
hidden_states: (batch, seqlen, ...)
|
|
|
112 |
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
113 |
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
114 |
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
115 |
+
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
116 |
# TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
|
117 |
# bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
|
118 |
# times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
|
|
|
123 |
indices,
|
124 |
cu_seqlens,
|
125 |
max_seqlen_in_batch,
|
|
|
126 |
)
|
127 |
|
128 |
|
|
|
180 |
"""
|
181 |
length = attention_mask_in_length.sum(dim=-1)
|
182 |
seqlen = attention_mask_in_length.size(-1)
|
183 |
+
attention_mask_2d = torch.arange(seqlen, device=length.device, dtype=length.dtype).expand(len(length),
|
184 |
+
seqlen) < length.unsqueeze(
|
185 |
+
1)
|
186 |
+
real_indices_idx = torch.nonzero(attention_mask_in_length.flatten(), as_tuple=False).flatten()
|
|
|
|
|
187 |
seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx]
|
188 |
indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten()
|
189 |
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
190 |
+
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
|
|
|
|
|
191 |
# TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
|
192 |
# bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
|
193 |
# times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
|
|
|
215 |
# output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
|
216 |
# output[indices] = hidden_states
|
217 |
output = index_put_first_axis(hidden_states, indices, batch * seqlen)
|
218 |
+
return rearrange(output, "(b s) ... -> b s ...", b=batch)
|