phoebeklett commited on
Commit
43430cc
1 Parent(s): e77c26b

Upload 2 files

Browse files
Files changed (2) hide show
  1. configuration.py +247 -0
  2. modeling.py +1585 -0
configuration.py ADDED
@@ -0,0 +1,247 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
4
+ # and OPT implementations in this library. It has been modified from its
5
+ # original forms to accommodate minor architectural differences compared
6
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
7
+ #
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+
20
+ # This code has been adapted from Meta and Huggingface and inherits the above lisence.
21
+ # The original code can be found here:
22
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/configuration_llama.py
23
+
24
+ """Extended Mind LLaMA model configuration"""
25
+
26
+ from transformers.configuration_utils import PretrainedConfig
27
+ from transformers.utils import logging
28
+
29
+ logger = logging.get_logger(__name__)
30
+
31
+
32
+ class ExtendedLlamaConfig(PretrainedConfig):
33
+ r"""
34
+ This is the configuration class to store the configuration of a [`ExtendedLlamaModel`].
35
+ It is used to instantiate an Extended Mind LLaMA model according to the specified arguments,
36
+ defining the model architecture. Instantiating a configuration with the
37
+ defaults will yield a similar configuration to that of the Extended Mind LLaMA-7B.
38
+
39
+ Configuration objects inherit from [`PretrainedConfig`]
40
+ and can be used to control the model outputs.
41
+ Read the documentation from [`PretrainedConfig`] for more information.
42
+
43
+ Args:
44
+ vocab_size (`int`, *optional*, defaults to 32000):
45
+ Vocabulary size of the LLaMA model. Defines the number of different tokens
46
+ that can be represented by the `inputs_ids` passed when calling [`LlamaModel`]
47
+ hidden_size (`int`, *optional*, defaults to 4096):
48
+ Dimension of the hidden representations.
49
+ intermediate_size (`int`, *optional*, defaults to 11008):
50
+ Dimension of the MLP representations.
51
+ num_hidden_layers (`int`, *optional*, defaults to 32):
52
+ Number of hidden layers in the Transformer encoder.
53
+ num_attention_heads (`int`, *optional*, defaults to 32):
54
+ Number of attention heads for each attention layer in the Transformer encoder.
55
+ num_key_value_heads (`int`, *optional*):
56
+ This is the number of key_value heads that should be used to implement
57
+ Grouped Query Attention. If `num_key_value_heads=num_attention_heads`,
58
+ the model will use Multi Head Attention (MHA), if `num_key_value_heads=1
59
+ the model will use Multi Query Attention (MQA) otherwise GQA is used.
60
+ When converting a multi-head checkpoint to a GQA checkpoint,
61
+ each group key and value head should be constructed by meanpooling
62
+ all the original heads within that group. For more details checkout
63
+ [this paper](https://arxiv.org/pdf/2305.13245.pdf).
64
+ If it is not specified, will default to
65
+ `num_attention_heads`.
66
+ pretraining_tp (`int`, *optional*, defaults to `1`):
67
+ Experimental feature. Tensor parallelism rank used during pretraining.
68
+ Please refer to [this document]
69
+ (https://huggingface.co/docs/transformers/parallelism)
70
+ to understand more about it. This value is
71
+ necessary to ensure exact reproducibility of the pretraining results.
72
+ Please refer to [this issue]
73
+ (https://github.com/pytorch/pytorch/issues/76232).
74
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
75
+ The non-linear activation function (function or string) in the decoder.
76
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
77
+ The maximum sequence length that this model might ever be used with.
78
+ Llama 1 supports up to 2048 tokens,
79
+ Llama 2 up to 4096, CodeLlama up to 16384.
80
+ initializer_range (`float`, *optional*, defaults to 0.02):
81
+ The standard deviation of the truncated_normal_initializer
82
+ for initializing all weight matrices.
83
+ rms_norm_eps (`float`, *optional*, defaults to 1e-12):
84
+ The epsilon used by the rms normalization layers.
85
+ use_cache (`bool`, *optional*, defaults to `True`):
86
+ Whether or not the model should return the last key/values attentions
87
+ (not used by all models). Only relevant if `config.is_decoder=True`.
88
+ tie_word_embeddings(`bool`, *optional*, defaults to `False`):
89
+ Whether to tie weight embeddings
90
+ rope_theta (`float`, *optional*, defaults to 10000.0):
91
+ The base period of the RoPE embeddings.
92
+ rope_scaling (`Dict`, *optional*):
93
+ Dictionary containing the scaling configuration for the RoPE embeddings.
94
+ Currently supports two scaling strategies: linear and dynamic.
95
+ Their scaling factor must be an float greater than 1. The expected format
96
+ is `{"type": strategy name, "factor": scaling factor}`.
97
+ When using this flag, don't update `max_position_embeddings`
98
+ to the expected new maximum. See the following thread for more information
99
+ on how these scaling strategies behave:
100
+ https://www.reddit.com/r/LocalLLaMA/comments/
101
+ 14mrgpr/dynamically_scaled_rope_further_increases/.
102
+ This is an experimental feature, subject to breaking API changes in future versions.
103
+
104
+ #### Memory Configuration ####
105
+ use_external_mind (`bool`, *optional*, defaults to `True`):
106
+ Whether to attend to external memories.
107
+ use_external_mind_by_layer (`List[bool]`, *optional*,
108
+ defaults to List[`True`, ..., `True`]):
109
+ Whether to attend to external memories, on each decoder layer.
110
+ topk (`int`, *optional*, defaults to `10`):
111
+ Number of external memories for each query token to retrieve and attend to.
112
+ memory_type (`string`, *optional*, defaults to `manual`):
113
+ Whether to store external memories manually or in a vector database.
114
+ memory_device (`string`, *optional*, defaults to `cpu`):
115
+ Specify device to store memory.
116
+ mask_by_sim (`bool`, *optional*, defaults to `True`):
117
+ Whether or not to mask retrieved memories by similarity.
118
+ sim_threshold (`float`, *optional*, defaults to `0.25`):
119
+ Threshold for masking retrieved memories.
120
+ tokenizer_all_special_ids (`list`, *optional*, defaults to `[0,1,2]`):
121
+ Ids for special tokens to remove from memories.
122
+ remove_special_tokens (`bool`, *optional*, defaults to `True`):
123
+ Remove memories that correspond to tokenizer special ids.
124
+ #### Memory Configuration ####
125
+
126
+ Example:
127
+
128
+ ```python
129
+ >>> from transformers import LlamaModel, LlamaConfig
130
+
131
+ >>> # Initializing a LLaMA llama-7b style configuration
132
+ >>> configuration = LlamaConfig()
133
+
134
+ >>> # Initializing a model from the llama-7b style configuration
135
+ >>> model = LlamaModel(configuration)
136
+
137
+ >>> # Accessing the model configuration
138
+ >>> configuration = model.config
139
+ ```"""
140
+
141
+ model_type = "extended-llama"
142
+ keys_to_ignore_at_inference = ["past_key_values"]
143
+
144
+ def __init__(
145
+ self,
146
+ vocab_size=32000,
147
+ hidden_size=4096,
148
+ intermediate_size=11008,
149
+ num_hidden_layers=32,
150
+ num_attention_heads=32,
151
+ num_key_value_heads=None,
152
+ hidden_act="silu",
153
+ max_position_embeddings=2048,
154
+ initializer_range=0.02,
155
+ rms_norm_eps=1e-5,
156
+ use_cache=True,
157
+ pad_token_id=None,
158
+ bos_token_id=1,
159
+ eos_token_id=2,
160
+ pretraining_tp=1,
161
+ tie_word_embeddings=False,
162
+ rope_theta=10000.0,
163
+ rope_scaling=None,
164
+ memory_config=None,
165
+ **kwargs,
166
+ ):
167
+ if memory_config is None:
168
+ memory_config = {
169
+ "mask_by_sim": False,
170
+ "sim_threshold": 0.25,
171
+ "topk": 10,
172
+ "use_external_mind": True,
173
+ "memory_type": "manual",
174
+ "memory_device": "cpu",
175
+ "tokenizer_all_special_ids": [0, bos_token_id, eos_token_id],
176
+ "use_external_mind_by_layer": [
177
+ True for _ in range(num_hidden_layers)
178
+ ],
179
+ "remove_special_ids": True,
180
+ }
181
+ for key, value in memory_config.items():
182
+ setattr(self, key, value)
183
+
184
+ self.vocab_size = vocab_size
185
+ self.max_position_embeddings = max_position_embeddings
186
+ self.hidden_size = hidden_size
187
+ self.intermediate_size = intermediate_size
188
+ self.num_hidden_layers = num_hidden_layers
189
+ self.num_attention_heads = num_attention_heads
190
+
191
+ # for backward compatibility
192
+ if num_key_value_heads is None:
193
+ num_key_value_heads = num_attention_heads
194
+
195
+ self.num_key_value_heads = num_key_value_heads
196
+ self.hidden_act = hidden_act
197
+ self.initializer_range = initializer_range
198
+ self.rms_norm_eps = rms_norm_eps
199
+ self.pretraining_tp = pretraining_tp
200
+ self.use_cache = use_cache
201
+ self.rope_theta = rope_theta
202
+ self.rope_scaling = rope_scaling
203
+ self._rope_scaling_validation()
204
+
205
+ super().__init__(
206
+ pad_token_id=pad_token_id,
207
+ bos_token_id=bos_token_id,
208
+ eos_token_id=eos_token_id,
209
+ tie_word_embeddings=tie_word_embeddings,
210
+ **kwargs,
211
+ )
212
+
213
+ def _rope_scaling_validation(self):
214
+ """
215
+ Validate the `rope_scaling` configuration.
216
+ """
217
+ if self.rope_scaling is None:
218
+ return
219
+
220
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
221
+ raise ValueError(
222
+ "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
223
+ f"got {self.rope_scaling}"
224
+ )
225
+ rope_scaling_type = self.rope_scaling.get("type", None)
226
+ rope_scaling_factor = self.rope_scaling.get("factor", None)
227
+ if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
228
+ raise ValueError(
229
+ f"""`rope_scaling`'s type field must be one of ['linear', 'dynamic'],
230
+ got {rope_scaling_type}"""
231
+ )
232
+ if (
233
+ rope_scaling_factor is None
234
+ or not isinstance(rope_scaling_factor, float)
235
+ or rope_scaling_factor <= 1.0
236
+ ):
237
+ raise ValueError(
238
+ f"""`rope_scaling`'s factor field must be an float > 1,
239
+ got {rope_scaling_factor}"""
240
+ )
241
+
242
+ # Faiss memory not compatible with Grouped Query Attention
243
+ if self.memory_type=='faiss' and self.num_key_value_heads != self.num_attention_heads:
244
+ raise NotImplementedError(
245
+ 'Faiss memory not compatible with Grouped Query Attention.'
246
+ )
247
+
modeling.py ADDED
@@ -0,0 +1,1585 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
4
+ # and OPT implementations in this library. It has been modified from its
5
+ # original forms to accommodate minor architectural differences compared
6
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
7
+ #
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+
20
+ # This code has been adapted from Meta and Huggingface and inherits the above lisence.
21
+ # The original code can be found here:
22
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
23
+ # We annotate the edited code below with 'EM' comments to indicate where we have made changes.
24
+ """PyTorch Extended LLaMA model."""
25
+ import math
26
+ from typing import List, Optional, Tuple, Union
27
+
28
+ import faiss
29
+ import numpy as np
30
+ import torch
31
+ import torch.nn.functional as F
32
+ import torch.utils.checkpoint
33
+ from einops import rearrange
34
+ from torch import nn
35
+ from torch.linalg import vector_norm
36
+ from torch.nn import CrossEntropyLoss
37
+ from transformers.activations import ACT2FN
38
+ from transformers.modeling_outputs import (
39
+ BaseModelOutputWithPast,
40
+ CausalLMOutputWithPast,
41
+ )
42
+ from transformers.modeling_utils import PreTrainedModel
43
+ from transformers.utils import (
44
+ add_start_docstrings,
45
+ add_start_docstrings_to_model_forward,
46
+ logging,
47
+ replace_return_docstrings,
48
+ )
49
+
50
+ from emts_clean.src.llama.configuration import ExtendedLlamaConfig
51
+
52
+ logger = logging.get_logger(__name__)
53
+
54
+ _CONFIG_FOR_DOC = "ExtendedLlamaConfig"
55
+
56
+
57
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
58
+ def _make_causal_mask(
59
+ input_ids_shape: torch.Size,
60
+ dtype: torch.dtype,
61
+ device: torch.device,
62
+ past_key_values_length: int = 0,
63
+ ):
64
+ """
65
+ Make causal mask used for bi-directional self-attention.
66
+ """
67
+ bsz, tgt_len = input_ids_shape
68
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
69
+ mask_cond = torch.arange(mask.size(-1), device=device)
70
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
71
+ mask = mask.to(dtype)
72
+
73
+ if past_key_values_length > 0:
74
+ mask = torch.cat(
75
+ [
76
+ torch.zeros(
77
+ tgt_len, past_key_values_length, dtype=dtype, device=device
78
+ ),
79
+ mask,
80
+ ],
81
+ dim=-1,
82
+ )
83
+ return mask[None, None, :, :].expand(
84
+ bsz, 1, tgt_len, tgt_len + past_key_values_length
85
+ )
86
+
87
+
88
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
89
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
90
+ """
91
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
92
+ """
93
+ bsz, src_len = mask.size()
94
+ tgt_len = tgt_len if tgt_len is not None else src_len
95
+
96
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
97
+
98
+ inverted_mask = 1.0 - expanded_mask
99
+
100
+ return inverted_mask.masked_fill(
101
+ inverted_mask.to(torch.bool), torch.finfo(dtype).min
102
+ )
103
+
104
+
105
+ class LlamaRMSNorm(nn.Module):
106
+ """LlamaRMSNorm is equivalent to T5LayerNorm"""
107
+
108
+ def __init__(self, hidden_size, eps=1e-6):
109
+ """
110
+ LlamaRMSNorm is equivalent to T5LayerNorm
111
+ """
112
+ super().__init__()
113
+ self.weight = nn.Parameter(torch.ones(hidden_size))
114
+ self.variance_epsilon = eps
115
+
116
+ def forward(self, hidden_states):
117
+ """Apply RMS Norm"""
118
+ input_dtype = hidden_states.dtype
119
+ hidden_states = hidden_states.to(torch.float32)
120
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
121
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
122
+ return self.weight * hidden_states.to(input_dtype)
123
+
124
+
125
+ class LlamaRotaryEmbedding(torch.nn.Module):
126
+ """Rotary Positional Embedding"""
127
+
128
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
129
+ super().__init__()
130
+ self.dim = dim
131
+ self.max_position_embeddings = max_position_embeddings
132
+ self.base = base
133
+ inv_freq = 1.0 / (
134
+ self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
135
+ )
136
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
137
+
138
+ # Build here to make `torch.jit.trace` work.
139
+ self._set_cos_sin_cache(
140
+ seq_len=max_position_embeddings,
141
+ device=self.inv_freq.device,
142
+ dtype=torch.get_default_dtype(),
143
+ )
144
+
145
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
146
+ self.max_seq_len_cached = seq_len
147
+ t = torch.arange(
148
+ self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
149
+ )
150
+
151
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
152
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
153
+ emb = torch.cat((freqs, freqs), dim=-1)
154
+ self.register_buffer(
155
+ "cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False
156
+ )
157
+ self.register_buffer(
158
+ "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False
159
+ )
160
+
161
+ def forward(self, x, seq_len=None):
162
+ # x: [bs, num_attention_heads, seq_len, head_size]
163
+ if seq_len > self.max_seq_len_cached:
164
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
165
+
166
+ return (
167
+ self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
168
+ self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
169
+ )
170
+
171
+
172
+ class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
173
+ """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
174
+
175
+ def __init__(
176
+ self,
177
+ dim,
178
+ max_position_embeddings=2048,
179
+ base=10000,
180
+ device=None,
181
+ scaling_factor=1.0,
182
+ ):
183
+ self.scaling_factor = scaling_factor
184
+ super().__init__(dim, max_position_embeddings, base, device)
185
+
186
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
187
+ self.max_seq_len_cached = seq_len
188
+ t = torch.arange(
189
+ self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
190
+ )
191
+ t = t / self.scaling_factor
192
+
193
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
194
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
195
+ emb = torch.cat((freqs, freqs), dim=-1)
196
+ self.register_buffer(
197
+ "cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False
198
+ )
199
+ self.register_buffer(
200
+ "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False
201
+ )
202
+
203
+
204
+ class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
205
+ """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
206
+
207
+ def __init__(
208
+ self,
209
+ dim,
210
+ max_position_embeddings=2048,
211
+ base=10000,
212
+ device=None,
213
+ scaling_factor=1.0,
214
+ ):
215
+ self.scaling_factor = scaling_factor
216
+ super().__init__(dim, max_position_embeddings, base, device)
217
+
218
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
219
+ self.max_seq_len_cached = seq_len
220
+
221
+ if seq_len > self.max_position_embeddings:
222
+ base = self.base * (
223
+ (self.scaling_factor * seq_len / self.max_position_embeddings)
224
+ - (self.scaling_factor - 1)
225
+ ) ** (self.dim / (self.dim - 2))
226
+ inv_freq = 1.0 / (
227
+ base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
228
+ )
229
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
230
+
231
+ t = torch.arange(
232
+ self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
233
+ )
234
+
235
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
236
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
237
+ emb = torch.cat((freqs, freqs), dim=-1)
238
+ self.register_buffer(
239
+ "cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False
240
+ )
241
+ self.register_buffer(
242
+ "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False
243
+ )
244
+
245
+
246
+ def rotate_half(x):
247
+ """Rotates half the hidden dims of the input."""
248
+ x1 = x[..., : x.shape[-1] // 2]
249
+ x2 = x[..., x.shape[-1] // 2 :]
250
+ return torch.cat((-x2, x1), dim=-1)
251
+
252
+
253
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
254
+ """Apply rotary positional embedding to q and k."""
255
+ # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
256
+ cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
257
+ sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
258
+
259
+ s_q = q.size(
260
+ -2
261
+ )
262
+ # EM: Since we apply rotary pos emb after reading from cache, queries may be shorter
263
+ _q_position_ids = position_ids[:, -s_q:]
264
+ _q_cos = cos[_q_position_ids].unsqueeze(1)
265
+ _q_sin = sin[_q_position_ids].unsqueeze(1)
266
+ q_embed = (q * _q_cos) + (rotate_half(q) * _q_sin)
267
+
268
+ cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
269
+ sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
270
+ k_embed = (k * cos) + (rotate_half(k) * sin)
271
+ return q_embed, k_embed
272
+
273
+
274
+ class LlamaMLP(nn.Module):
275
+ """MLP Module"""
276
+
277
+ def __init__(self, config):
278
+ super().__init__()
279
+ self.config = config
280
+ self.hidden_size = config.hidden_size
281
+ self.intermediate_size = config.intermediate_size
282
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
283
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
284
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
285
+ self.act_fn = ACT2FN[config.hidden_act]
286
+
287
+ def forward(self, x):
288
+ if self.config.pretraining_tp > 1:
289
+ slice = self.intermediate_size // self.config.pretraining_tp
290
+ gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
291
+ up_proj_slices = self.up_proj.weight.split(slice, dim=0)
292
+ down_proj_slices = self.down_proj.weight.split(slice, dim=1)
293
+
294
+ gate_proj = torch.cat(
295
+ [
296
+ F.linear(x, gate_proj_slices[i])
297
+ for i in range(self.config.pretraining_tp)
298
+ ],
299
+ dim=-1,
300
+ )
301
+ up_proj = torch.cat(
302
+ [
303
+ F.linear(x, up_proj_slices[i])
304
+ for i in range(self.config.pretraining_tp)
305
+ ],
306
+ dim=-1,
307
+ )
308
+
309
+ intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
310
+ down_proj = [
311
+ F.linear(intermediate_states[i], down_proj_slices[i])
312
+ for i in range(self.config.pretraining_tp)
313
+ ]
314
+ down_proj = sum(down_proj)
315
+ else:
316
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
317
+
318
+ return down_proj
319
+
320
+
321
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
322
+ """
323
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
324
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
325
+ """
326
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
327
+ if n_rep == 1:
328
+ return hidden_states
329
+ hidden_states = hidden_states[:, :, None, :, :].expand(
330
+ batch, num_key_value_heads, n_rep, slen, head_dim
331
+ )
332
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
333
+
334
+
335
+ class ExtendedLlamaAttention(nn.Module):
336
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
337
+
338
+ def __init__(self, config: ExtendedLlamaConfig):
339
+ super().__init__()
340
+ self.config = config
341
+ self.hidden_size = config.hidden_size
342
+ self.num_heads = config.num_attention_heads
343
+ self.head_dim = self.hidden_size // self.num_heads
344
+ self.num_key_value_heads = config.num_key_value_heads
345
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
346
+ self.max_position_embeddings = config.max_position_embeddings
347
+ self.rope_theta = config.rope_theta
348
+
349
+ if (self.head_dim * self.num_heads) != self.hidden_size:
350
+ raise ValueError(
351
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
352
+ f" and `num_heads`: {self.num_heads})."
353
+ )
354
+ self.q_proj = nn.Linear(
355
+ self.hidden_size, self.num_heads * self.head_dim, bias=False
356
+ )
357
+ self.k_proj = nn.Linear(
358
+ self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
359
+ )
360
+ self.v_proj = nn.Linear(
361
+ self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
362
+ )
363
+ self.o_proj = nn.Linear(
364
+ self.num_heads * self.head_dim, self.hidden_size, bias=False
365
+ )
366
+ self._init_rope()
367
+
368
+ def _init_rope(self):
369
+ if self.config.rope_scaling is None:
370
+ self.rotary_emb = LlamaRotaryEmbedding(
371
+ self.head_dim,
372
+ max_position_embeddings=self.max_position_embeddings,
373
+ base=self.rope_theta,
374
+ )
375
+ else:
376
+ scaling_type = self.config.rope_scaling["type"]
377
+ scaling_factor = self.config.rope_scaling["factor"]
378
+ if scaling_type == "linear":
379
+ self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
380
+ self.head_dim,
381
+ max_position_embeddings=self.max_position_embeddings,
382
+ scaling_factor=scaling_factor,
383
+ base=self.rope_theta,
384
+ )
385
+ elif scaling_type == "dynamic":
386
+ self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
387
+ self.head_dim,
388
+ max_position_embeddings=self.max_position_embeddings,
389
+ scaling_factor=scaling_factor,
390
+ base=self.rope_theta,
391
+ )
392
+ else:
393
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
394
+
395
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
396
+ return (
397
+ tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
398
+ .transpose(1, 2)
399
+ .contiguous()
400
+ )
401
+
402
+ def forward(
403
+ self,
404
+ hidden_states: torch.Tensor,
405
+ attention_mask: Optional[torch.Tensor] = None,
406
+ position_ids: Optional[torch.LongTensor] = None,
407
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
408
+ output_attentions: bool = False,
409
+ output_retrieved_memory_idx: bool = False,
410
+ use_cache: bool = False,
411
+ long_range_past_key_value=None,
412
+ faiss_indexes=None,
413
+ mask_by_sim=False,
414
+ sim_threshold=0.0,
415
+ topk=None,
416
+ current_layer=None,
417
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
418
+ """forward"""
419
+ bsz, q_len, _ = hidden_states.size()
420
+
421
+ if self.config.pretraining_tp > 1:
422
+ key_value_slicing = (
423
+ self.num_key_value_heads * self.head_dim
424
+ ) // self.config.pretraining_tp
425
+ query_slices = self.q_proj.weight.split(
426
+ (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
427
+ )
428
+ key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
429
+ value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
430
+
431
+ query_states = [
432
+ F.linear(hidden_states, query_slices[i])
433
+ for i in range(self.config.pretraining_tp)
434
+ ]
435
+ query_states = torch.cat(query_states, dim=-1)
436
+
437
+ key_states = [
438
+ F.linear(hidden_states, key_slices[i])
439
+ for i in range(self.config.pretraining_tp)
440
+ ]
441
+ key_states = torch.cat(key_states, dim=-1)
442
+
443
+ value_states = [
444
+ F.linear(hidden_states, value_slices[i])
445
+ for i in range(self.config.pretraining_tp)
446
+ ]
447
+ value_states = torch.cat(value_states, dim=-1)
448
+
449
+ else:
450
+ query_states = self.q_proj(hidden_states)
451
+ key_states = self.k_proj(hidden_states)
452
+ value_states = self.v_proj(hidden_states)
453
+
454
+ query_states = query_states.view(
455
+ bsz, q_len, self.num_heads, self.head_dim
456
+ ).transpose(1, 2)
457
+ key_states = key_states.view(
458
+ bsz, q_len, self.num_key_value_heads, self.head_dim
459
+ ).transpose(1, 2)
460
+ value_states = value_states.view(
461
+ bsz, q_len, self.num_key_value_heads, self.head_dim
462
+ ).transpose(1, 2)
463
+
464
+ # EM: Read from cache before position information is added
465
+ if past_key_value is not None:
466
+ # reuse k, v, self_attention
467
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
468
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
469
+
470
+ past_key_value = (key_states, value_states) if use_cache else None
471
+
472
+ kv_seq_len = key_states.shape[-2]
473
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
474
+
475
+ query_states, key_states = apply_rotary_pos_emb(
476
+ query_states, key_states, cos, sin, position_ids
477
+ )
478
+
479
+ # repeat k/v heads if n_kv_heads < n_heads
480
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
481
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
482
+ bsz, nh, s_q, hd = query_states.shape
483
+
484
+ attn_weights = torch.matmul(
485
+ query_states, key_states.transpose(2, 3)
486
+ ) / math.sqrt(self.head_dim)
487
+
488
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
489
+ raise ValueError(
490
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
491
+ f" {attn_weights.size()}"
492
+ )
493
+
494
+ # EM: Retrieve memories from cache or faiss indexes
495
+ if long_range_past_key_value is not None or faiss_indexes is not None:
496
+ if long_range_past_key_value is not None: # manual memories
497
+ k_cache, v_cache = long_range_past_key_value
498
+ k_cache = repeat_kv(k_cache, self.num_key_value_groups)
499
+ v_cache = repeat_kv(v_cache, self.num_key_value_groups)
500
+
501
+ s_cache = k_cache.size(-2)
502
+
503
+ k_cache = k_cache.to(key_states.device)
504
+ v_cache = v_cache.to(key_states.device)
505
+
506
+ # Normalize query and key vectors
507
+ q_n = query_states / vector_norm(
508
+ query_states, ord=2, dim=-1, keepdim=True
509
+ )
510
+ k_n = k_cache / vector_norm(k_cache, ord=2, dim=-1, keepdim=True)
511
+
512
+ sim = q_n.matmul(k_n.transpose(2, 3))
513
+ if s_cache < topk:
514
+ topk = s_cache # number of tokens in cache < topk
515
+ val, idx = torch.topk(sim, k=topk, dim=-1) # Retrieve topk memories
516
+
517
+ reshaped_idx = idx.reshape(bsz, nh, s_q * topk)
518
+
519
+ selected_k = k_cache.gather(
520
+ dim=-2, index=reshaped_idx.unsqueeze(-1).expand(-1, -1, -1, hd)
521
+ )
522
+ selected_v = v_cache.gather(
523
+ dim=-2, index=reshaped_idx.unsqueeze(-1).expand(-1, -1, -1, hd)
524
+ )
525
+
526
+ elif faiss_indexes is not None: # FAISS indexes
527
+ kn_index, kv_index = faiss_indexes
528
+ q_n = query_states / vector_norm(
529
+ query_states, ord=2, dim=-1, keepdim=True
530
+ )
531
+
532
+ # One-hot encoding for layer, head to only retrieve memories from the same layer, head
533
+ one_hot_encodings = (
534
+ F.one_hot(
535
+ torch.arange(
536
+ 0,
537
+ nh * self.config.num_hidden_layers,
538
+ device=query_states.device,
539
+ )
540
+ )
541
+ * 10
542
+ )
543
+ q_n = torch.concat(
544
+ [
545
+ rearrange(q_n, "b h s d -> b (h s) d", h=nh),
546
+ one_hot_encodings[nh * current_layer : nh * (current_layer + 1)]
547
+ .unsqueeze(0)
548
+ .repeat_interleave(repeats=query_states.size(-2), dim=-2),
549
+ ],
550
+ dim=-1,
551
+ ).squeeze()
552
+
553
+ if kn_index.ntotal / (nh * self.config.num_hidden_layers) < topk:
554
+ topk = kn_index.ntotal / (nh * self.config.num_hidden_layers)
555
+
556
+ val, idx = kn_index.search(q_n.to("cpu").detach().numpy(), k=topk)
557
+ val = torch.tensor(val - 100).reshape(bsz, nh, s_q, topk) #Similarity includes scale factor from one-hot encoding
558
+ reshaped_idx = torch.tensor(
559
+ idx % (kn_index.ntotal / (nh * self.config.num_hidden_layers))
560
+ ).reshape(bsz, nh, s_q * topk)
561
+
562
+ selected_k = rearrange(
563
+ torch.tensor(kv_index.reconstruct_batch(idx.flatten()))[:, :hd],
564
+ "(h s) d -> 1 h s d",
565
+ h=nh,
566
+ ).to(query_states.device)
567
+
568
+ selected_v = rearrange(
569
+ torch.tensor(kv_index.reconstruct_batch(idx.flatten()))[:, hd:],
570
+ "(h s) d -> 1 h s d",
571
+ h=nh,
572
+ ).to(query_states.device)
573
+
574
+ attn_weight_cache = torch.matmul(
575
+ query_states, selected_k.transpose(2, 3)
576
+ ) / math.sqrt(self.head_dim)
577
+ # EM: Mask by similarity
578
+ if mask_by_sim:
579
+ sim_mask = (
580
+ rearrange(~(val > sim_threshold).bool(), "b h s i -> b h (s i)")
581
+ .unsqueeze(-2)
582
+ .expand(-1, -1, s_q, -1)
583
+ ).to(query_states.device)
584
+ attn_weight_cache = attn_weight_cache.masked_fill(
585
+ sim_mask, torch.finfo(query_states.dtype).min
586
+ )
587
+ # EM: Concatenate cache and current attention weights, values
588
+ attn_weights = torch.cat([attn_weight_cache, attn_weights], dim=-1)
589
+ value_states = torch.cat([selected_v, value_states], dim=-2)
590
+
591
+ min_val = torch.finfo(attn_weights.dtype).min
592
+
593
+ # EM: Create mask for external memories, queries only attend to their own memories
594
+ def _create_external_memories_mask(k, s_q, device, min_val=min_val):
595
+ mask = torch.ones(s_q, s_q * k, device=device, dtype=torch.float32)
596
+ for i in range(s_q):
597
+ mask[i, i * k : (i + 1) * k] = 0
598
+
599
+ filled = mask.masked_fill(mask.bool(), min_val)
600
+ return filled
601
+
602
+ if attention_mask is not None:
603
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
604
+ raise ValueError(
605
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
606
+ )
607
+ # EM: Concatenate attention mask with external memories mask
608
+ if long_range_past_key_value is not None or faiss_indexes is not None:
609
+ memory_mask = _create_external_memories_mask(
610
+ k=topk, s_q=s_q, device=attn_weights.device
611
+ )
612
+ attention_mask = (
613
+ torch.cat(
614
+ [
615
+ memory_mask,
616
+ attention_mask.squeeze(dim=[0, 1]),
617
+ ],
618
+ dim=1,
619
+ )
620
+ .unsqueeze(dim=0)
621
+ .unsqueeze(dim=1)
622
+ )
623
+ attn_weights = attn_weights + attention_mask
624
+
625
+ # upcast attention to fp32
626
+ attn_weights = nn.functional.softmax(
627
+ attn_weights, dim=-1, dtype=torch.float32
628
+ ).to(query_states.dtype)
629
+ attn_output = torch.matmul(attn_weights, value_states)
630
+
631
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
632
+ raise ValueError(
633
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
634
+ f" {attn_output.size()}"
635
+ )
636
+
637
+ attn_output = attn_output.transpose(1, 2).contiguous()
638
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
639
+
640
+ if self.config.pretraining_tp > 1:
641
+ attn_output = attn_output.split(
642
+ self.hidden_size // self.config.pretraining_tp, dim=2
643
+ )
644
+ o_proj_slices = self.o_proj.weight.split(
645
+ self.hidden_size // self.config.pretraining_tp, dim=1
646
+ )
647
+ attn_output = sum(
648
+ F.linear(attn_output[i], o_proj_slices[i])
649
+ for i in range(self.config.pretraining_tp)
650
+ )
651
+ else:
652
+ attn_output = self.o_proj(attn_output)
653
+
654
+ if not output_attentions:
655
+ attn_weights = None
656
+
657
+ if not output_retrieved_memory_idx:
658
+ reshaped_idx = None
659
+ return attn_output, attn_weights, past_key_value, reshaped_idx
660
+
661
+
662
+ class ExtendedLlamaDecoderLayer(nn.Module):
663
+ """Decoder Layer for LLaMA"""
664
+
665
+ def __init__(self, config: ExtendedLlamaConfig):
666
+ super().__init__()
667
+ self.hidden_size = config.hidden_size
668
+ self.self_attn = ExtendedLlamaAttention(config=config)
669
+ self.mlp = LlamaMLP(config)
670
+ self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
671
+ self.post_attention_layernorm = LlamaRMSNorm(
672
+ config.hidden_size, eps=config.rms_norm_eps
673
+ )
674
+
675
+ def forward(
676
+ self,
677
+ hidden_states: torch.Tensor,
678
+ attention_mask: Optional[torch.Tensor] = None,
679
+ position_ids: Optional[torch.LongTensor] = None,
680
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
681
+ output_attentions: Optional[bool] = False,
682
+ output_retrieved_memory_idx: Optional[bool] = False,
683
+ use_cache: Optional[bool] = False,
684
+ long_range_past_key_value: Optional[Tuple[torch.Tensor]] = None,
685
+ faiss_indexes: Tuple = None,
686
+ mask_by_sim: bool = False,
687
+ sim_threshold: float = None,
688
+ topk: int = None,
689
+ current_layer=None,
690
+ ) -> Tuple[
691
+ torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
692
+ ]:
693
+ """
694
+ Args:
695
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
696
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
697
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
698
+ output_attentions (`bool`, *optional*):
699
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
700
+ returned tensors for more detail.
701
+ use_cache (`bool`, *optional*):
702
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
703
+ (see `past_key_values`).
704
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
705
+ """
706
+
707
+ residual = hidden_states
708
+
709
+ hidden_states = self.input_layernorm(hidden_states)
710
+
711
+ # Self Attention
712
+ (
713
+ hidden_states,
714
+ self_attn_weights,
715
+ present_key_value,
716
+ selected_idx,
717
+ ) = self.self_attn(
718
+ hidden_states=hidden_states,
719
+ attention_mask=attention_mask,
720
+ position_ids=position_ids,
721
+ past_key_value=past_key_value,
722
+ output_attentions=output_attentions,
723
+ output_retrieved_memory_idx=output_retrieved_memory_idx,
724
+ use_cache=use_cache,
725
+ long_range_past_key_value=long_range_past_key_value,
726
+ faiss_indexes=faiss_indexes,
727
+ mask_by_sim=mask_by_sim,
728
+ sim_threshold=sim_threshold,
729
+ topk=topk,
730
+ current_layer=current_layer,
731
+ )
732
+ hidden_states = residual + hidden_states
733
+
734
+ # Fully Connected
735
+ residual = hidden_states
736
+ hidden_states = self.post_attention_layernorm(hidden_states)
737
+ hidden_states = self.mlp(hidden_states)
738
+ hidden_states = residual + hidden_states
739
+
740
+ outputs = (hidden_states,)
741
+
742
+ if output_attentions:
743
+ outputs += (self_attn_weights,)
744
+
745
+ if use_cache:
746
+ outputs += (present_key_value,)
747
+
748
+ if output_retrieved_memory_idx:
749
+ outputs += (selected_idx,)
750
+
751
+ return outputs
752
+
753
+
754
+ LLAMA_START_DOCSTRING = r"""
755
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
756
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
757
+ etc.)
758
+
759
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
760
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
761
+ and behavior.
762
+
763
+ Parameters:
764
+ config ([`ExtendedLlamaConfig`]):
765
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
766
+ load the weights associated with the model, only the configuration. Check out the
767
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
768
+ """
769
+
770
+
771
+ @add_start_docstrings(
772
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
773
+ LLAMA_START_DOCSTRING,
774
+ )
775
+ class LlamaPreTrainedModel(PreTrainedModel):
776
+ """Wrapper class"""
777
+
778
+ config_class = ExtendedLlamaConfig
779
+ base_model_prefix = "model"
780
+ supports_gradient_checkpointing = True
781
+ _no_split_modules = ["LlamaDecoderLayer"]
782
+ _skip_keys_device_placement = "past_key_values"
783
+
784
+ def _init_weights(self, module):
785
+ std = self.config.initializer_range
786
+ if isinstance(module, nn.Linear):
787
+ module.weight.data.normal_(mean=0.0, std=std)
788
+ if module.bias is not None:
789
+ module.bias.data.zero_()
790
+ elif isinstance(module, nn.Embedding):
791
+ module.weight.data.normal_(mean=0.0, std=std)
792
+ if module.padding_idx is not None:
793
+ module.weight.data[module.padding_idx].zero_()
794
+
795
+ def _set_gradient_checkpointing(self, module, value=False):
796
+ if isinstance(module, ExtendedLlamaModel):
797
+ module.gradient_checkpointing = value
798
+
799
+
800
+ LLAMA_INPUTS_DOCSTRING = r"""
801
+ Args:
802
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
803
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
804
+ it.
805
+
806
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
807
+ [`PreTrainedTokenizer.__call__`] for details.
808
+
809
+ [What are input IDs?](../glossary#input-ids)
810
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
811
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
812
+
813
+ - 1 for tokens that are **not masked**,
814
+ - 0 for tokens that are **masked**.
815
+
816
+ [What are attention masks?](../glossary#attention-mask)
817
+
818
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
819
+ [`PreTrainedTokenizer.__call__`] for details.
820
+
821
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
822
+ `past_key_values`).
823
+
824
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
825
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
826
+ information on the default strategy.
827
+
828
+ - 1 indicates the head is **not masked**,
829
+ - 0 indicates the head is **masked**.
830
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
831
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
832
+ config.n_positions - 1]`.
833
+
834
+ [What are position IDs?](../glossary#position-ids)
835
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
836
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
837
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
838
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
839
+
840
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
841
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
842
+
843
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
844
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
845
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
846
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
847
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
848
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
849
+ model's internal embedding lookup matrix.
850
+ use_cache (`bool`, *optional*):
851
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
852
+ `past_key_values`).
853
+ output_attentions (`bool`, *optional*):
854
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
855
+ tensors for more detail.
856
+ output_hidden_states (`bool`, *optional*):
857
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
858
+ more detail.
859
+ return_dict (`bool`, *optional*):
860
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
861
+ """
862
+
863
+
864
+ @add_start_docstrings(
865
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
866
+ LLAMA_START_DOCSTRING,
867
+ )
868
+ class ExtendedLlamaModel(LlamaPreTrainedModel):
869
+ """
870
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
871
+
872
+ Args:
873
+ config: LlamaConfig
874
+ """
875
+
876
+ def __init__(self, config: ExtendedLlamaConfig):
877
+ super().__init__(config)
878
+ self.padding_idx = config.pad_token_id
879
+ self.vocab_size = config.vocab_size
880
+
881
+ self.embed_tokens = nn.Embedding(
882
+ config.vocab_size, config.hidden_size, self.padding_idx
883
+ )
884
+ self.layers = nn.ModuleList(
885
+ [ExtendedLlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]
886
+ )
887
+ self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
888
+
889
+ self.gradient_checkpointing = False
890
+ # Initialize weights and apply final processing
891
+ self.mask_by_sim = config.mask_by_sim
892
+ self.sim_threshold = config.sim_threshold
893
+ self.topk = config.topk
894
+ self.use_external_mind = config.use_external_mind
895
+ self.use_external_mind_by_layer = config.use_external_mind_by_layer
896
+ self.post_init()
897
+
898
+ def get_input_embeddings(self):
899
+ return self.embed_tokens
900
+
901
+ def set_input_embeddings(self, value):
902
+ self.embed_tokens = value
903
+
904
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
905
+ def _prepare_decoder_attention_mask(
906
+ self, attention_mask, input_shape, inputs_embeds, past_key_values_length
907
+ ):
908
+ # create causal mask
909
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
910
+ combined_attention_mask = None
911
+ if input_shape[-1] > 1:
912
+ combined_attention_mask = _make_causal_mask(
913
+ input_shape,
914
+ inputs_embeds.dtype,
915
+ device=inputs_embeds.device,
916
+ past_key_values_length=past_key_values_length,
917
+ )
918
+
919
+ if attention_mask is not None:
920
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
921
+ expanded_attn_mask = _expand_mask(
922
+ attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
923
+ ).to(inputs_embeds.device)
924
+ combined_attention_mask = (
925
+ expanded_attn_mask
926
+ if combined_attention_mask is None
927
+ else expanded_attn_mask + combined_attention_mask
928
+ )
929
+
930
+ return combined_attention_mask
931
+
932
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
933
+ def forward(
934
+ self,
935
+ input_ids: torch.LongTensor = None,
936
+ attention_mask: Optional[torch.Tensor] = None,
937
+ position_ids: Optional[torch.LongTensor] = None,
938
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
939
+ inputs_embeds: Optional[torch.FloatTensor] = None,
940
+ use_cache: Optional[bool] = None,
941
+ output_attentions: Optional[bool] = None,
942
+ output_retrieved_memory_idx: Optional[bool] = None,
943
+ output_hidden_states: Optional[bool] = None,
944
+ return_dict: Optional[bool] = None,
945
+ use_external_mind: Optional[bool] = None,
946
+ long_range_past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
947
+ faiss_indexes: Tuple = None,
948
+ topk: int = None,
949
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
950
+ """forward"""
951
+ output_attentions = (
952
+ output_attentions
953
+ if output_attentions is not None
954
+ else self.config.output_attentions
955
+ )
956
+ output_retrieved_memory_idx = (
957
+ output_retrieved_memory_idx
958
+ if output_retrieved_memory_idx is not None
959
+ else False
960
+ )
961
+ output_hidden_states = (
962
+ output_hidden_states
963
+ if output_hidden_states is not None
964
+ else self.config.output_hidden_states
965
+ )
966
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
967
+
968
+ return_dict = (
969
+ return_dict if return_dict is not None else self.config.use_return_dict
970
+ )
971
+ use_external_mind = (
972
+ use_external_mind
973
+ if use_external_mind is not None
974
+ else self.use_external_mind
975
+ )
976
+ topk = topk if topk is not None else self.topk
977
+
978
+ # retrieve input_ids and inputs_embeds
979
+ if input_ids is not None and inputs_embeds is not None:
980
+ raise ValueError(
981
+ "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
982
+ )
983
+ elif input_ids is not None:
984
+ batch_size, seq_length = input_ids.shape
985
+ elif inputs_embeds is not None:
986
+ batch_size, seq_length, _ = inputs_embeds.shape
987
+ else:
988
+ raise ValueError(
989
+ "You have to specify either decoder_input_ids or decoder_inputs_embeds"
990
+ )
991
+
992
+ seq_length_with_past = seq_length
993
+ past_key_values_length = 0
994
+
995
+ if past_key_values is not None:
996
+ past_key_values_length = past_key_values[0][0].shape[2]
997
+ seq_length_with_past = seq_length_with_past + past_key_values_length
998
+
999
+ # EM: Range of position ids is total seq length since we apply rotary pos emb after reading from cache
1000
+ if position_ids is None:
1001
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
1002
+ position_ids = torch.arange(
1003
+ seq_length_with_past,
1004
+ dtype=torch.long,
1005
+ device=device,
1006
+ )
1007
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length_with_past)
1008
+ else:
1009
+ position_ids = position_ids.view(-1, seq_length_with_past).long()
1010
+
1011
+ if inputs_embeds is None:
1012
+ inputs_embeds = self.embed_tokens(input_ids)
1013
+ # embed positions
1014
+ if attention_mask is None:
1015
+ attention_mask = torch.ones(
1016
+ (batch_size, seq_length_with_past),
1017
+ dtype=torch.bool,
1018
+ device=inputs_embeds.device,
1019
+ )
1020
+ attention_mask = self._prepare_decoder_attention_mask(
1021
+ attention_mask,
1022
+ (batch_size, seq_length),
1023
+ inputs_embeds,
1024
+ past_key_values_length,
1025
+ )
1026
+
1027
+ hidden_states = inputs_embeds
1028
+
1029
+ if self.gradient_checkpointing and self.training:
1030
+ if use_cache:
1031
+ logger.warning_once(
1032
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1033
+ )
1034
+ use_cache = False
1035
+
1036
+ # decoder layers
1037
+ all_hidden_states = () if output_hidden_states else None
1038
+ all_self_attns = () if output_attentions else None
1039
+ next_decoder_cache = () if use_cache else None
1040
+ all_idx = () if output_retrieved_memory_idx else None
1041
+
1042
+ for idx, decoder_layer in enumerate(self.layers):
1043
+ if output_hidden_states:
1044
+ all_hidden_states += (hidden_states,)
1045
+
1046
+ past_key_value = (
1047
+ past_key_values[idx] if past_key_values is not None else None
1048
+ )
1049
+
1050
+ long_range_past_key_value = (
1051
+ long_range_past_key_values[idx]
1052
+ if (
1053
+ long_range_past_key_values is not None
1054
+ and self.use_external_mind_by_layer[idx]
1055
+ and use_external_mind is True
1056
+ )
1057
+ else None
1058
+ )
1059
+
1060
+ if long_range_past_key_value is not None and faiss_indexes is not None:
1061
+ raise NotImplementedError(
1062
+ """Using faiss and passing key value pairs
1063
+ manually are mutually exclusive right now."""
1064
+ )
1065
+
1066
+ if self.gradient_checkpointing and self.training:
1067
+
1068
+ def create_custom_forward(module):
1069
+ def custom_forward(*inputs):
1070
+ # None for past_key_value
1071
+ return module(*inputs, past_key_value, output_attentions)
1072
+
1073
+ return custom_forward
1074
+
1075
+ layer_outputs = torch.utils.checkpoint.checkpoint(
1076
+ create_custom_forward(decoder_layer),
1077
+ hidden_states,
1078
+ attention_mask,
1079
+ position_ids,
1080
+ )
1081
+ else:
1082
+ layer_outputs = decoder_layer(
1083
+ hidden_states,
1084
+ attention_mask=attention_mask,
1085
+ position_ids=position_ids,
1086
+ past_key_value=past_key_value,
1087
+ output_attentions=output_attentions,
1088
+ output_retrieved_memory_idx=output_retrieved_memory_idx,
1089
+ use_cache=use_cache,
1090
+ topk=topk,
1091
+ long_range_past_key_value=long_range_past_key_value,
1092
+ faiss_indexes=faiss_indexes,
1093
+ mask_by_sim=self.mask_by_sim,
1094
+ sim_threshold=self.sim_threshold,
1095
+ current_layer=idx,
1096
+ )
1097
+
1098
+ hidden_states = layer_outputs[0]
1099
+
1100
+ if use_cache:
1101
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
1102
+
1103
+ if output_attentions:
1104
+ all_self_attns += (layer_outputs[1],)
1105
+
1106
+ if output_retrieved_memory_idx:
1107
+ idx = (
1108
+ 3
1109
+ if (use_cache & output_attentions)
1110
+ else 2
1111
+ if (use_cache or output_attentions)
1112
+ else 1
1113
+ )
1114
+ all_idx += (layer_outputs[idx],) # Record which memories were retrieved
1115
+ hidden_states = self.norm(hidden_states)
1116
+
1117
+ # add hidden states from the last decoder layer
1118
+ if output_hidden_states:
1119
+ all_hidden_states += (hidden_states,)
1120
+
1121
+ next_cache = next_decoder_cache if use_cache else None
1122
+ if not return_dict:
1123
+ return tuple(
1124
+ v
1125
+ for v in [
1126
+ hidden_states,
1127
+ next_cache,
1128
+ all_hidden_states,
1129
+ all_self_attns,
1130
+ all_idx,
1131
+ ]
1132
+ if v is not None
1133
+ )
1134
+ return BaseModelOutputWithPast(
1135
+ last_hidden_state=hidden_states,
1136
+ past_key_values=next_cache,
1137
+ hidden_states=all_hidden_states,
1138
+ attentions=(all_self_attns, all_idx), # EM: Return idx of retrieved memories
1139
+ )
1140
+
1141
+
1142
+ class ExtendedLlamaForCausalLM(LlamaPreTrainedModel):
1143
+ """LlamaForCausalLM"""
1144
+
1145
+ _tied_weights_keys = ["lm_head.weight"]
1146
+
1147
+ def __init__(self, config, external_memories=None):
1148
+ super().__init__(config)
1149
+ self.model = ExtendedLlamaModel(config)
1150
+ self.vocab_size = config.vocab_size
1151
+ self.tokenizer_all_special_ids = config.tokenizer_all_special_ids
1152
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1153
+
1154
+ self.use_external_mind = config.use_external_mind
1155
+ self.memory_type = config.memory_type
1156
+ self.memory_device = config.memory_device
1157
+ self.remove_special_ids = config.remove_special_ids
1158
+ self.memory_ids = None
1159
+ self.memories = None
1160
+
1161
+ # EM: Memory token ids
1162
+ if external_memories is not None:
1163
+ self.memory_ids = external_memories
1164
+
1165
+ # Initialize weights and apply final processing
1166
+ self.post_init()
1167
+
1168
+ # EM: Clear memory cache
1169
+ def clear_memory(self):
1170
+ """Clear memory cache."""
1171
+ self.memory_ids = None
1172
+ self.memories = None
1173
+
1174
+ def get_input_embeddings(self):
1175
+ return self.model.embed_tokens
1176
+
1177
+ def set_input_embeddings(self, value):
1178
+ self.model.embed_tokens = value
1179
+
1180
+ def get_output_embeddings(self):
1181
+ return self.lm_head
1182
+
1183
+ def set_output_embeddings(self, new_embeddings):
1184
+ """Set output embeddings."""
1185
+ self.lm_head = new_embeddings
1186
+
1187
+ def set_decoder(self, decoder):
1188
+ """Set decoder."""
1189
+ self.model = decoder
1190
+
1191
+ def get_decoder(self):
1192
+ """Get decoder."""
1193
+ return self.model
1194
+
1195
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
1196
+ @replace_return_docstrings(
1197
+ output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
1198
+ )
1199
+ def forward(
1200
+ self,
1201
+ input_ids: torch.LongTensor = None,
1202
+ attention_mask: Optional[torch.Tensor] = None,
1203
+ position_ids: Optional[torch.LongTensor] = None,
1204
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1205
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1206
+ labels: Optional[torch.LongTensor] = None,
1207
+ use_cache: Optional[bool] = None,
1208
+ output_attentions: Optional[bool] = None,
1209
+ output_hidden_states: Optional[bool] = None,
1210
+ output_retrieved_memory_idx: Optional[bool] = None,
1211
+ return_dict: Optional[bool] = None,
1212
+ use_external_mind: Optional[bool] = None,
1213
+ topk: int = None,
1214
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1215
+ r"""
1216
+ Args:
1217
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1218
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1219
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1220
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1221
+
1222
+ Returns:
1223
+
1224
+ Example:
1225
+
1226
+ ```python
1227
+ >>> from transformers import AutoTokenizer, LlamaForCausalLM
1228
+
1229
+ >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
1230
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
1231
+
1232
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1233
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1234
+
1235
+ >>> # Generate
1236
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1237
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1238
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1239
+ ```"""
1240
+
1241
+ # EM: Generate key value cache once on first call
1242
+ if (
1243
+ self.memory_ids is not None and self.memories is None
1244
+ ):
1245
+ self.memories = self.generate_cache(
1246
+ torch.tensor(self.memory_ids, device=self.device),
1247
+ cache_type=self.memory_type,
1248
+ )
1249
+ # EM: Remove special tokens from memory cache
1250
+ if self.remove_special_ids:
1251
+ idx_to_remove = [
1252
+ token_idx
1253
+ for token_idx, token in enumerate(self.memory_ids[0])
1254
+ if token in self.tokenizer_all_special_ids
1255
+ ]
1256
+ if self.memory_type == "manual":
1257
+ mask = torch.ones(self.memories[0][0].size(), dtype=torch.bool)
1258
+ mask[:, :, idx_to_remove, :] = False
1259
+
1260
+ new_size = (
1261
+ self.memories[0][0].size(0),
1262
+ self.memories[0][0].size(1),
1263
+ -1,
1264
+ self.memories[0][0].size(3),
1265
+ )
1266
+ self.memories = [
1267
+ (ks[mask].view(new_size), vs[mask].view(new_size))
1268
+ for ks, vs in self.memories
1269
+ ]
1270
+ else:
1271
+ kn_index, kv_index = self.memories
1272
+ all_idx_to_remove = [
1273
+ [
1274
+ i
1275
+ for i in range(0, kn_index.ntotal)
1276
+ if (
1277
+ i
1278
+ % (
1279
+ kn_index.ntotal
1280
+ / (
1281
+ self.config.num_attention_heads
1282
+ * self.config.num_hidden_layers
1283
+ )
1284
+ )
1285
+ )
1286
+ == j
1287
+ ]
1288
+ for j in idx_to_remove
1289
+ ]
1290
+ kn_index.remove_ids(
1291
+ np.array(all_idx_to_remove).flatten().astype("int64")
1292
+ )
1293
+ kv_index.remove_ids(
1294
+ np.array(all_idx_to_remove).flatten().astype("int64")
1295
+ )
1296
+
1297
+ output_attentions = (
1298
+ output_attentions
1299
+ if output_attentions is not None
1300
+ else self.config.output_attentions
1301
+ )
1302
+ output_retrieved_memory_idx = (
1303
+ output_retrieved_memory_idx
1304
+ if output_retrieved_memory_idx is not None
1305
+ else False
1306
+ )
1307
+ output_hidden_states = (
1308
+ output_hidden_states
1309
+ if output_hidden_states is not None
1310
+ else self.config.output_hidden_states
1311
+ )
1312
+ return_dict = (
1313
+ return_dict if return_dict is not None else self.config.use_return_dict
1314
+ )
1315
+
1316
+ use_external_mind = (
1317
+ use_external_mind
1318
+ if use_external_mind is not None
1319
+ else self.use_external_mind
1320
+ )
1321
+ topk = topk if topk is not None else None
1322
+
1323
+ long_range_past_key_values = None
1324
+ faiss_indexes = None
1325
+ if hasattr(self, "memories") and isinstance(self.memories, list):
1326
+ long_range_past_key_values = self.memories
1327
+ elif hasattr(self, "memories"):
1328
+ faiss_indexes = self.memories
1329
+
1330
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1331
+ outputs = self.model(
1332
+ input_ids=input_ids,
1333
+ attention_mask=attention_mask,
1334
+ position_ids=position_ids,
1335
+ past_key_values=past_key_values,
1336
+ inputs_embeds=inputs_embeds,
1337
+ use_cache=use_cache,
1338
+ output_attentions=output_attentions,
1339
+ output_retrieved_memory_idx=output_retrieved_memory_idx,
1340
+ output_hidden_states=output_hidden_states,
1341
+ return_dict=return_dict,
1342
+ long_range_past_key_values=long_range_past_key_values,
1343
+ faiss_indexes=faiss_indexes,
1344
+ use_external_mind=use_external_mind,
1345
+ topk=topk,
1346
+ )
1347
+
1348
+ hidden_states = outputs[0]
1349
+ if self.config.pretraining_tp > 1:
1350
+ lm_head_slices = self.lm_head.weight.split(
1351
+ self.vocab_size // self.config.pretraining_tp, dim=0
1352
+ )
1353
+ logits = [
1354
+ F.linear(hidden_states, lm_head_slices[i])
1355
+ for i in range(self.config.pretraining_tp)
1356
+ ]
1357
+ logits = torch.cat(logits, dim=-1)
1358
+ else:
1359
+ logits = self.lm_head(hidden_states)
1360
+ logits = logits.float()
1361
+
1362
+ loss = None
1363
+ if labels is not None:
1364
+ # Shift so that tokens < n predict n
1365
+ shift_logits = logits[..., :-1, :].contiguous()
1366
+ shift_labels = labels[..., 1:].contiguous()
1367
+ # Flatten the tokens
1368
+ loss_fct = CrossEntropyLoss()
1369
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1370
+ shift_labels = shift_labels.view(-1)
1371
+ # Enable model parallelism
1372
+ shift_labels = shift_labels.to(shift_logits.device)
1373
+ loss = loss_fct(shift_logits, shift_labels)
1374
+
1375
+ if not return_dict:
1376
+ output = (logits,) + outputs[1:]
1377
+ return (loss,) + output if loss is not None else output
1378
+
1379
+ return CausalLMOutputWithPast(
1380
+ loss=loss,
1381
+ logits=logits,
1382
+ past_key_values=outputs.past_key_values,
1383
+ hidden_states=outputs.hidden_states,
1384
+ attentions=outputs.attentions,
1385
+ )
1386
+
1387
+ # EM: Add method to generate key-value cache
1388
+ def generate_cache(
1389
+ self,
1390
+ input_ids: torch.LongTensor,
1391
+ stride: int = 512,
1392
+ max_len: int = 3072,
1393
+ cache_type: str = "manual",
1394
+ ):
1395
+ """Stride over memory inputs to get kv pairs"""
1396
+ if cache_type not in ["manual", "faiss"]:
1397
+ raise NotImplementedError(f"Cache type {cache_type} not implemented.")
1398
+
1399
+ prev_end_loc = 0
1400
+ long_range_past_key_values = None
1401
+ faiss_indexes = None
1402
+ for b_idx in range(
1403
+ 0, input_ids.size(-1), stride
1404
+ ): # generate kv-pairs using stride
1405
+ end_loc = min(b_idx + max_len, input_ids.size(-1))
1406
+ trg_len = end_loc - prev_end_loc
1407
+ subseq = input_ids[:, b_idx:end_loc].to(self.model.device)
1408
+ with torch.inference_mode():
1409
+ outputs = self.model(
1410
+ subseq,
1411
+ use_cache=True,
1412
+ use_external_mind=False,
1413
+ )
1414
+ to_cache = [
1415
+ (kv[0][:, :, -trg_len:], kv[1][:, :, -trg_len:])
1416
+ for kv in outputs.past_key_values
1417
+ ]
1418
+ long_range_past_key_values, faiss_indexes = self.cache(
1419
+ to_cache,
1420
+ cache_type,
1421
+ long_range_past_key_values=long_range_past_key_values,
1422
+ faiss_indexes=faiss_indexes,
1423
+ )
1424
+
1425
+ prev_end_loc = end_loc
1426
+ if end_loc == input_ids.size(-1):
1427
+ break
1428
+ if long_range_past_key_values is not None:
1429
+ return long_range_past_key_values
1430
+ else:
1431
+ return faiss_indexes
1432
+
1433
+ # EM: Add method to cache key value pairs
1434
+ def cache(
1435
+ self,
1436
+ to_cache: List,
1437
+ cache_type: str = "manual",
1438
+ long_range_past_key_values: List = None,
1439
+ faiss_indexes: faiss.IndexFlatIP = None,
1440
+ max_length_cache=100000,
1441
+ verbose=False,
1442
+ ):
1443
+ """Cache key value pairs for Extended Mind attention."""
1444
+ if (long_range_past_key_values is not None) & (faiss_indexes is not None):
1445
+ raise NotImplementedError(
1446
+ "Using faiss and passing key value pairs manually are mutually exclusive right now."
1447
+ )
1448
+ # To avoid spinning up a new index for each layer, we add one-hot encodings to the keys so that queries match with the appropriate layer, head
1449
+ if cache_type == "faiss": # add one-hot encoding to match layer, head indices
1450
+ one_hot_encodings = (
1451
+ F.one_hot(
1452
+ torch.arange(
1453
+ 0,
1454
+ self.config.num_attention_heads * self.config.num_hidden_layers,
1455
+ )
1456
+ )
1457
+ * 10
1458
+ )
1459
+ # New indices, one to store normalized keys with one-hot encodings, another to retrieve kv pairs without normalization
1460
+ if faiss_indexes is None:
1461
+ faiss_indexes = (
1462
+ faiss.IndexFlatIP(
1463
+ to_cache[0][0].size(-1) + one_hot_encodings.size(-1)
1464
+ ),
1465
+ faiss.IndexFlatIP(to_cache[0][0].size(-1) * 2),
1466
+ )
1467
+ kn_index, kv_index = faiss_indexes
1468
+ for l_idx, (k, v) in enumerate(to_cache):
1469
+ k_n = (k / vector_norm(k, ord=2, dim=-1, keepdim=True)).to("cpu") #Normalize keys for cosine sim
1470
+ # Indices are 2 dimensional, so flatten
1471
+
1472
+ # Add normalized keys with one-hot encodings
1473
+ k_n = torch.concat(
1474
+ [
1475
+ rearrange(
1476
+ k_n,
1477
+ "b h s d -> b (h s) d",
1478
+ h=self.config.num_attention_heads,
1479
+ ),
1480
+ one_hot_encodings[
1481
+ self.config.num_attention_heads
1482
+ * l_idx : self.config.num_attention_heads
1483
+ * (l_idx + 1)
1484
+ ]
1485
+ .unsqueeze(0)
1486
+ .repeat_interleave(repeats=k.size(-2), dim=-2),
1487
+ ],
1488
+ dim=-1,
1489
+ )
1490
+ kn_index.add(k_n.squeeze().numpy())
1491
+
1492
+ # Add unnormalized keys and values
1493
+ k = rearrange(
1494
+ k, "b h s d -> b (h s) d", h=self.config.num_attention_heads
1495
+ )
1496
+ v = rearrange(
1497
+ v, "b h s d -> b (h s) d", h=self.config.num_attention_heads
1498
+ )
1499
+ kv_index.add(
1500
+ torch.concat([k.squeeze(), v.squeeze()], dim=1).to("cpu").numpy()
1501
+ )
1502
+ else:
1503
+ # Simply use list to store key value pairs
1504
+ if long_range_past_key_values is None:
1505
+ long_range_past_key_values = [
1506
+ (k.to(self.memory_device), v.to(self.memory_device))
1507
+ for k, v in to_cache
1508
+ ]
1509
+ else:
1510
+ long_range_past_key_values = [
1511
+ (
1512
+ torch.concat(
1513
+ [kv[0], to_cache[ind][0].to(self.memory_device)], dim=2
1514
+ ),
1515
+ torch.concat(
1516
+ [kv[1], to_cache[ind][1].to(self.memory_device)], dim=2
1517
+ ),
1518
+ )
1519
+ for ind, kv in enumerate(long_range_past_key_values)
1520
+ ]
1521
+ if (
1522
+ long_range_past_key_values is not None
1523
+ ): # set a limit on manual memory length
1524
+ if long_range_past_key_values[0][0].size(-2) > max_length_cache:
1525
+ long_range_past_key_values = [
1526
+ (kv[0][:, :, -max_length_cache:], kv[1][:, :, -max_length_cache:])
1527
+ for kv in long_range_past_key_values
1528
+ ]
1529
+ if verbose:
1530
+ if cache_type == "faiss":
1531
+ print(f"{kn_index.ntotal} keys in faiss index")
1532
+ else:
1533
+ print(f"{long_range_past_key_values[0][0].size(-2)} cached kvs")
1534
+
1535
+ return (
1536
+ long_range_past_key_values,
1537
+ (kn_index, kv_index) if cache_type == "faiss" else None,
1538
+ )
1539
+
1540
+ def prepare_inputs_for_generation(
1541
+ self,
1542
+ input_ids,
1543
+ past_key_values=None,
1544
+ attention_mask=None,
1545
+ inputs_embeds=None,
1546
+ **kwargs,
1547
+ ):
1548
+ if past_key_values:
1549
+ input_ids = input_ids[:, -1:]
1550
+
1551
+ position_ids = kwargs.get("position_ids", None)
1552
+ if attention_mask is not None and position_ids is None:
1553
+ # create position_ids on the fly for batch generation
1554
+ position_ids = attention_mask.long().cumsum(-1) - 1
1555
+ position_ids.masked_fill_(attention_mask == 0, 1)
1556
+
1557
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1558
+ if inputs_embeds is not None and past_key_values is None:
1559
+ model_inputs = {"inputs_embeds": inputs_embeds}
1560
+ else:
1561
+ model_inputs = {"input_ids": input_ids}
1562
+
1563
+ model_inputs.update(
1564
+ {
1565
+ "position_ids": position_ids,
1566
+ "past_key_values": past_key_values,
1567
+ "use_cache": kwargs.get("use_cache"),
1568
+ "attention_mask": attention_mask,
1569
+ "use_external_mind": kwargs.get("use_external_mind"), # EM: Add config here
1570
+ "topk": kwargs.get("topk"),
1571
+ }
1572
+ )
1573
+ return model_inputs
1574
+
1575
+ @staticmethod
1576
+ def _reorder_cache(past_key_values, beam_idx):
1577
+ reordered_past = ()
1578
+ for layer_past in past_key_values:
1579
+ reordered_past += (
1580
+ tuple(
1581
+ past_state.index_select(0, beam_idx.to(past_state.device))
1582
+ for past_state in layer_past
1583
+ ),
1584
+ )
1585
+ return reordered_past