iohadrubin commited on
Commit
931060f
1 Parent(s): b695b9a

Upload neox_model_old.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. neox_model_old.py +783 -0
neox_model_old.py ADDED
@@ -0,0 +1,783 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 The EleutherAI and The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Flax GPT NeoX model."""
16
+
17
+ from typing import Optional, Tuple
18
+
19
+ import flax.linen as nn
20
+ import jax
21
+ import jax.numpy as jnp
22
+ from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
23
+ from flax.linen import combine_masks, make_causal_mask
24
+ from flax.linen.attention import dot_product_attention_weights
25
+ from flax.traverse_util import flatten_dict, unflatten_dict
26
+ from jax import lax
27
+
28
+ from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
29
+ from transformers.modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring
30
+ from transformers.models.gpt_neox.configuration_gpt_neox import GPTNeoXConfig
31
+ from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging
32
+
33
+
34
+ logger = logging.get_logger(__name__)
35
+
36
+ _CHECKPOINT_FOR_DOC = "EleutherAI/gpt-neox-20b"
37
+ _CONFIG_FOR_DOC = "GPTNeoXConfig"
38
+
39
+
40
+ GPT_NEOX_START_DOCSTRING = r"""
41
+ This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
42
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
43
+ etc.)
44
+
45
+ This model is also a Flax nn
46
+ [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
47
+ regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.
48
+
49
+ Finally, this model supports inherent JAX features such as:
50
+
51
+ - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
52
+ - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
53
+ - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
54
+ - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
55
+
56
+ Parameters:
57
+ config ([`GPTNeoXConfig`]): Model configuration class with all the parameters of the model.
58
+ Initializing with a config file does not load the weights associated with the model, only the
59
+ configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
60
+ dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
61
+ The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
62
+ `jax.numpy.bfloat16` (on TPUs).
63
+
64
+ This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
65
+ specified all the computation will be performed with the given `dtype`.
66
+
67
+ **Note that this only specifies the dtype of the computation and does not influence the dtype of model
68
+ parameters.**
69
+
70
+ If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
71
+ [`~FlaxPreTrainedModel.to_bf16`].
72
+ """
73
+
74
+ GPT_NEOX_INPUTS_DOCSTRING = r"""
75
+ Args:
76
+ input_ids (`numpy.ndarray` of shape `(batch_size, input_ids_length)`):
77
+ `input_ids_length` = `sequence_length`. Indices of input sequence tokens in the vocabulary.
78
+
79
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
80
+ [`PreTrainedTokenizer.__call__`] for details.
81
+
82
+ [What are input IDs?](../glossary#input-ids)
83
+ attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
84
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
85
+
86
+ - 1 for tokens that are **not masked**,
87
+ - 0 for tokens that are **masked**.
88
+
89
+ [What are attention masks?](../glossary#attention-mask)
90
+ position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
91
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
92
+ config.max_position_embeddings - 1]`.
93
+ past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
94
+ Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
95
+ auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
96
+ output_attentions (`bool`, *optional*):
97
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
98
+ tensors for more detail.
99
+ output_hidden_states (`bool`, *optional*):
100
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
101
+ more detail.
102
+ return_dict (`bool`, *optional*):
103
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
104
+ """
105
+
106
+
107
+ def rotate_half(hidden_states):
108
+ first_half = hidden_states[..., : hidden_states.shape[-1] // 2]
109
+ second_half = hidden_states[..., hidden_states.shape[-1] // 2 :]
110
+ return jnp.concatenate((-second_half, first_half), axis=-1)
111
+
112
+
113
+ class FlaxGPTNeoXRotaryEmbedding(nn.Module):
114
+ dim: int
115
+ max_position_embeddings: int
116
+ base: int = 10000
117
+ dtype: jnp.dtype = jnp.float32
118
+
119
+ def setup(self):
120
+ self.inv_freq = 1.0 / (self.base ** (jnp.arange(0, self.dim, 2).astype(self.dtype) / self.dim))
121
+ self.cos_cached, self.sin_cached = self._compute_cos_sin(self.max_position_embeddings)
122
+
123
+ def _get_cos_sin_cache(self, seq_len):
124
+ if seq_len > self.max_position_embeddings:
125
+ return self._compute_cos_sin(seq_len)
126
+ else:
127
+ return self.cos_cached, self.sin_cached
128
+
129
+ def _compute_cos_sin(self, seq_len):
130
+ t = jnp.arange(seq_len, dtype=self.inv_freq.dtype)
131
+ freqs = jnp.outer(t, self.inv_freq)
132
+ emb = jnp.concatenate((freqs, freqs), axis=-1)
133
+ cos = jnp.expand_dims(jnp.expand_dims(jnp.cos(emb), 0), 0)
134
+ sin = jnp.expand_dims(jnp.expand_dims(jnp.sin(emb), 0), 0)
135
+ return cos, sin
136
+
137
+ def __call__(self, seq_len=None):
138
+ cos_cached, sin_cached = self._get_cos_sin_cache(seq_len)
139
+ return cos_cached[:seq_len, ...], sin_cached[:seq_len, ...]
140
+
141
+
142
+ class FlaxGPTNeoXLinearScalingRotaryEmbedding(FlaxGPTNeoXRotaryEmbedding):
143
+ """FlaxGPTNeoXRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
144
+
145
+ scaling_factor: float = 1.0
146
+
147
+ def _compute_cos_sin(self, seq_len):
148
+ t = jnp.arange(seq_len, dtype=self.inv_freq.dtype)
149
+ t = t / self.scaling_factor
150
+ freqs = jnp.outer(t, self.inv_freq)
151
+ emb = jnp.concatenate((freqs, freqs), axis=-1)
152
+ cos = jnp.expand_dims(jnp.expand_dims(jnp.cos(emb), 0), 0)
153
+ sin = jnp.expand_dims(jnp.expand_dims(jnp.sin(emb), 0), 0)
154
+ return cos, sin
155
+
156
+
157
+ class FlaxGPTNeoXDynamicNTKScalingRotaryEmbedding(FlaxGPTNeoXRotaryEmbedding):
158
+ """FlaxGPTNeoXRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
159
+
160
+ scaling_factor: float = 1.0
161
+
162
+ def _compute_cos_sin(self, seq_len):
163
+ if seq_len > self.max_position_embeddings:
164
+ base = self.base * (
165
+ (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
166
+ ) ** (self.dim / (self.dim - 2))
167
+ inv_freq = 1.0 / (base ** (jnp.arange(0, self.dim, 2, dtype=self.dtype) / self.dim))
168
+ else:
169
+ inv_freq = self.inv_freq
170
+
171
+ t = jnp.arange(seq_len, dtype=self.dtype)
172
+
173
+ freqs = jnp.outer(t, inv_freq)
174
+ emb = jnp.concatenate((freqs, freqs), axis=-1)
175
+ cos = jnp.expand_dims(jnp.expand_dims(jnp.cos(emb), 0), 0)
176
+ sin = jnp.expand_dims(jnp.expand_dims(jnp.sin(emb), 0), 0)
177
+ return cos, sin
178
+
179
+
180
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
181
+ gather_indices = position_ids[:, :, None, None] # [bs, seq_len, 1, 1]
182
+ gather_indices = jnp.repeat(gather_indices, cos.shape[1], axis=1)
183
+ gather_indices = jnp.repeat(gather_indices, cos.shape[3], axis=3)
184
+ cos = jnp.take_along_axis(cos.repeat(gather_indices.shape[0], axis=0), gather_indices, axis=2)
185
+ sin = jnp.take_along_axis(sin.repeat(gather_indices.shape[0], axis=0), gather_indices, axis=2)
186
+ q_embed = (q * cos) + (rotate_half(q) * sin)
187
+ k_embed = (k * cos) + (rotate_half(k) * sin)
188
+ return q_embed, k_embed
189
+
190
+
191
+ class FlaxGPTNeoXAttention(nn.Module):
192
+ config: GPTNeoXConfig
193
+ dtype: jnp.dtype = jnp.float32
194
+
195
+ def setup(self):
196
+ config = self.config
197
+ self.num_attention_heads = config.num_attention_heads
198
+ self.hidden_size = config.hidden_size
199
+ self.head_size = self.hidden_size // self.num_attention_heads
200
+ self.rotary_ndims = int(self.head_size * config.rotary_pct)
201
+ self.norm_factor = jnp.sqrt(self.head_size)
202
+ self.query_key_value = nn.Dense(
203
+ 3 * config.hidden_size,
204
+ dtype=self.dtype,
205
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
206
+ )
207
+ self.dense = nn.Dense(
208
+ config.hidden_size,
209
+ dtype=self.dtype,
210
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
211
+ )
212
+
213
+ if config.rope_scaling is None:
214
+ max_seq_length = config.max_position_embeddings
215
+ else:
216
+ max_seq_length = int(config.max_position_embeddings * config.rope_scaling["factor"])
217
+
218
+ self.causal_mask = make_causal_mask(jnp.ones((1, max_seq_length), dtype="bool"), dtype="bool")
219
+ self._init_rope()
220
+
221
+ def _init_rope(self):
222
+ if self.config.rope_scaling is None:
223
+ self.rotary_emb = FlaxGPTNeoXRotaryEmbedding(
224
+ self.rotary_ndims, self.config.max_position_embeddings, base=self.config.rotary_emb_base
225
+ )
226
+ else:
227
+ scaling_type = self.config.rope_scaling["type"]
228
+ scaling_factor = self.config.rope_scaling["factor"]
229
+ if scaling_type == "linear":
230
+ self.rotary_emb = FlaxGPTNeoXLinearScalingRotaryEmbedding(
231
+ self.rotary_ndims,
232
+ self.config.max_position_embeddings,
233
+ base=self.config.rotary_emb_base,
234
+ scaling_factor=scaling_factor,
235
+ )
236
+ elif scaling_type == "dynamic":
237
+ self.rotary_emb = FlaxGPTNeoXDynamicNTKScalingRotaryEmbedding(
238
+ self.rotary_ndims,
239
+ self.config.max_position_embeddings,
240
+ base=self.config.rotary_emb_base,
241
+ scaling_factor=scaling_factor,
242
+ )
243
+ else:
244
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
245
+
246
+ @nn.compact
247
+ # Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoSelfAttention._concatenate_to_cache
248
+ def _concatenate_to_cache(self, key, value, query, attention_mask):
249
+ """
250
+ This function takes projected key, value states from a single input token and concatenates the states to cached
251
+ states from previous steps. This function is slighly adapted from the official Flax repository:
252
+ https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
253
+ """
254
+ # detect if we're initializing by absence of existing cache data.
255
+ is_initialized = self.has_variable("cache", "cached_key")
256
+ cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
257
+ cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
258
+ cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
259
+
260
+ if is_initialized:
261
+ *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
262
+ # update key, value caches with our new 1d spatial slices
263
+ cur_index = cache_index.value
264
+ indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
265
+ key = lax.dynamic_update_slice(cached_key.value, key, indices)
266
+ value = lax.dynamic_update_slice(cached_value.value, value, indices)
267
+ cached_key.value = key
268
+ cached_value.value = value
269
+ num_updated_cache_vectors = query.shape[1]
270
+ cache_index.value = cache_index.value + num_updated_cache_vectors
271
+ # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
272
+ pad_mask = jnp.broadcast_to(
273
+ jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
274
+ tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
275
+ )
276
+ attention_mask = combine_masks(pad_mask, attention_mask)
277
+ return key, value, attention_mask
278
+
279
+ def _split_heads(self, hidden_states):
280
+ return hidden_states.reshape(hidden_states.shape[:-1] + (self.num_attention_heads, self.head_size * 3))
281
+
282
+ def _merge_heads(self, hidden_states):
283
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.hidden_size,))
284
+
285
+ def __call__(
286
+ self,
287
+ hidden_states,
288
+ attention_mask,
289
+ position_ids,
290
+ deterministic: bool = True,
291
+ init_cache: bool = False,
292
+ output_attentions: bool = False,
293
+ ):
294
+ qkv = self.query_key_value(hidden_states)
295
+ batch, seq_len, _ = qkv.shape
296
+
297
+ # proj q, k, v
298
+ fused_qkv = self.query_key_value(hidden_states)
299
+ fused_qkv = self._split_heads(fused_qkv)
300
+ query, key, value = jnp.split(fused_qkv, 3, axis=-1)
301
+
302
+ cos, sin = self.rotary_emb(seq_len)
303
+ if self.rotary_ndims is not None:
304
+ k_rot = key[..., : self.rotary_ndims]
305
+ k_pass = key[..., self.rotary_ndims :]
306
+
307
+ q_rot = query[..., : self.rotary_ndims]
308
+ q_pass = query[..., self.rotary_ndims :]
309
+
310
+ q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin, position_ids)
311
+
312
+ key = jnp.concatenate([k_rot, k_pass], axis=-1)
313
+ query = jnp.concatenate([q_rot, q_pass], axis=-1)
314
+ else:
315
+ query, key = apply_rotary_pos_emb(query, key, cos, sin, position_ids)
316
+
317
+ query_length, key_length = query.shape[1], key.shape[1]
318
+
319
+ if self.has_variable("cache", "cached_key"):
320
+ mask_shift = self.variables["cache"]["cache_index"]
321
+ max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
322
+
323
+ causal_mask = lax.dynamic_slice(
324
+ self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
325
+ )
326
+ else:
327
+ causal_mask = self.causal_mask[:, :, :query_length, :key_length]
328
+
329
+ causal_mask = jnp.broadcast_to(causal_mask, (batch,) + causal_mask.shape[1:])
330
+ attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
331
+ attention_mask = combine_masks(attention_mask, causal_mask)
332
+
333
+ dropout_rng = None
334
+ if not deterministic and self.config.attention_dropout > 0.0:
335
+ dropout_rng = self.make_rng("dropout")
336
+
337
+ # During fast autoregressive decoding, we feed one position at a time,
338
+ # and cache the keys and values step by step.
339
+ if self.has_variable("cache", "cached_key") or init_cache:
340
+ key, value, attention_mask = self._concatenate_to_cache(key, value, query, attention_mask)
341
+
342
+ # transform boolean mask into float mask
343
+ attention_bias = lax.select(
344
+ attention_mask > 0,
345
+ jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
346
+ jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
347
+ )
348
+
349
+ attn_weights = dot_product_attention_weights(
350
+ query,
351
+ key,
352
+ bias=attention_bias,
353
+ dropout_rng=dropout_rng,
354
+ dropout_rate=self.config.attention_dropout,
355
+ deterministic=deterministic,
356
+ dtype=jnp.promote_types(self.dtype, jnp.float32),
357
+ precision=None,
358
+ )
359
+ attn_output = jnp.einsum("bhqk,bkhd->bqhd", attn_weights, value)
360
+ attn_output = self._merge_heads(attn_output)
361
+ attn_output = self.dense(attn_output)
362
+
363
+ outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
364
+ return outputs
365
+
366
+
367
+ class FlaxGPTNeoXMLP(nn.Module):
368
+ config: GPTNeoXConfig
369
+ dtype: jnp.dtype = jnp.float32
370
+
371
+ def setup(self):
372
+ embed_dim = self.config.hidden_size
373
+ kernel_init = jax.nn.initializers.normal(self.config.initializer_range)
374
+
375
+ self.dense_h_to_4h = nn.Dense(self.config.intermediate_size, dtype=self.dtype, kernel_init=kernel_init)
376
+ self.dense_4h_to_h = nn.Dense(embed_dim, dtype=self.dtype, kernel_init=kernel_init)
377
+
378
+ self.act = ACT2FN[self.config.hidden_act]
379
+
380
+ def __call__(self, hidden_states):
381
+ hidden_states = self.dense_h_to_4h(hidden_states)
382
+ hidden_states = self.act(hidden_states)
383
+ hidden_states = self.dense_4h_to_h(hidden_states)
384
+ return hidden_states
385
+
386
+
387
+ class FlaxGPTNeoXBlock(nn.Module):
388
+ config: GPTNeoXConfig
389
+ dtype: jnp.dtype = jnp.float32
390
+
391
+ def setup(self):
392
+ self.use_parallel_residual = self.config.use_parallel_residual
393
+ self.input_layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
394
+ self.attention = FlaxGPTNeoXAttention(self.config, dtype=self.dtype)
395
+ self.post_attention_dropout = nn.Dropout(rate=self.config.hidden_dropout)
396
+ self.post_attention_layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
397
+
398
+ self.mlp = FlaxGPTNeoXMLP(self.config, dtype=self.dtype)
399
+ self.post_mlp_dropout = nn.Dropout(rate=self.config.hidden_dropout)
400
+
401
+ def __call__(
402
+ self,
403
+ hidden_states,
404
+ attention_mask=None,
405
+ position_ids=None,
406
+ deterministic: bool = True,
407
+ init_cache: bool = False,
408
+ output_attentions: bool = False,
409
+ ):
410
+ attn_outputs = self.attention(
411
+ self.input_layernorm(hidden_states),
412
+ attention_mask=attention_mask,
413
+ position_ids=position_ids,
414
+ deterministic=deterministic,
415
+ init_cache=init_cache,
416
+ output_attentions=output_attentions,
417
+ )
418
+ attn_output = attn_outputs[0]
419
+ attn_output = self.post_attention_dropout(attn_output, deterministic=deterministic)
420
+
421
+ if self.use_parallel_residual:
422
+ # pseudocode:
423
+ # x = x + attn(ln1(x)) + mlp(ln2(x))
424
+ mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
425
+ mlp_output = self.post_mlp_dropout(mlp_output, deterministic=deterministic)
426
+ hidden_states = mlp_output + attn_output + hidden_states
427
+ else:
428
+ # pseudocode:
429
+ # x = x + attn(ln1(x))
430
+ # x = x + mlp(ln2(x))
431
+ attn_output = attn_output + hidden_states
432
+ mlp_output = self.mlp(self.post_attention_layernorm(attn_output))
433
+ mlp_output = self.post_mlp_dropout(mlp_output, deterministic=deterministic)
434
+ hidden_states = mlp_output + attn_output
435
+
436
+ return (hidden_states,) + attn_outputs[1:]
437
+
438
+
439
+ class FlaxGPTNeoXPreTrainedModel(FlaxPreTrainedModel):
440
+ """
441
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
442
+ models.
443
+ """
444
+
445
+ config_class = GPTNeoXConfig
446
+ base_model_prefix = "gpt_neox"
447
+ module_class: nn.Module = None
448
+
449
+ # Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoPreTrainedModel.__init__ with GPTNeo->GPTNeoX
450
+ def __init__(
451
+ self,
452
+ config: GPTNeoXConfig,
453
+ input_shape: Tuple = (1, 1),
454
+ seed: int = 0,
455
+ dtype: jnp.dtype = jnp.float32,
456
+ _do_init: bool = True,
457
+ **kwargs,
458
+ ):
459
+ module = self.module_class(config=config, dtype=dtype, **kwargs)
460
+ super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
461
+
462
+ # Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoPreTrainedModel.init_weights with GPTNeo->GPTNeoX
463
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
464
+ # init input tensors
465
+ input_ids = jnp.zeros(input_shape, dtype="i4")
466
+ attention_mask = jnp.ones_like(input_ids)
467
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
468
+ params_rng, dropout_rng = jax.random.split(rng)
469
+ rngs = {"params": params_rng, "dropout": dropout_rng}
470
+
471
+ random_params = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)["params"]
472
+
473
+ if params is not None:
474
+ random_params = flatten_dict(unfreeze(random_params))
475
+ params = flatten_dict(unfreeze(params))
476
+ for missing_key in self._missing_keys:
477
+ params[missing_key] = random_params[missing_key]
478
+ self._missing_keys = set()
479
+ return freeze(unflatten_dict(params))
480
+ else:
481
+ return random_params
482
+
483
+ # Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoPreTrainedModel.init_cache
484
+ def init_cache(self, batch_size, max_length):
485
+ r"""
486
+ Args:
487
+ batch_size (`int`):
488
+ batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
489
+ max_length (`int`):
490
+ maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
491
+ cache.
492
+ """
493
+ # init input variables to retrieve cache
494
+ input_ids = jnp.ones((batch_size, max_length))
495
+ attention_mask = jnp.ones_like(input_ids)
496
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
497
+
498
+ init_variables = self.module.init(
499
+ jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
500
+ )
501
+ return unfreeze(init_variables["cache"])
502
+
503
+ @add_start_docstrings_to_model_forward(GPT_NEOX_INPUTS_DOCSTRING)
504
+ def __call__(
505
+ self,
506
+ input_ids,
507
+ attention_mask=None,
508
+ position_ids=None,
509
+ params: dict = None,
510
+ past_key_values: dict = None,
511
+ dropout_rng: jax.random.PRNGKey = None,
512
+ train: bool = False,
513
+ output_attentions: Optional[bool] = None,
514
+ output_hidden_states: Optional[bool] = None,
515
+ return_dict: Optional[bool] = None,
516
+ ):
517
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
518
+ output_hidden_states = (
519
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
520
+ )
521
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
522
+
523
+ batch_size, sequence_length = input_ids.shape
524
+
525
+ if position_ids is None:
526
+ if past_key_values is not None:
527
+ raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.")
528
+
529
+ position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
530
+
531
+ if attention_mask is None:
532
+ attention_mask = jnp.ones((batch_size, sequence_length))
533
+
534
+ # Handle any PRNG if needed
535
+ rngs = {}
536
+ if dropout_rng is not None:
537
+ rngs["dropout"] = dropout_rng
538
+
539
+ inputs = {"params": params or self.params}
540
+
541
+ # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be changed by FlaxGPTNeoXAttention module
542
+ if past_key_values:
543
+ inputs["cache"] = past_key_values
544
+ mutable = ["cache"]
545
+ else:
546
+ mutable = False
547
+
548
+ outputs = self.module.apply(
549
+ inputs,
550
+ jnp.array(input_ids, dtype="i4"),
551
+ jnp.array(attention_mask, dtype="i4"),
552
+ jnp.array(position_ids, dtype="i4"),
553
+ not train,
554
+ False,
555
+ output_attentions,
556
+ output_hidden_states,
557
+ return_dict,
558
+ rngs=rngs,
559
+ mutable=mutable,
560
+ )
561
+
562
+ # add updated cache to model output
563
+ if past_key_values is not None and return_dict:
564
+ outputs, past_key_values = outputs
565
+ outputs["past_key_values"] = unfreeze(past_key_values["cache"])
566
+ return outputs
567
+ elif past_key_values is not None and not return_dict:
568
+ outputs, past_key_values = outputs
569
+ outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
570
+
571
+ return outputs
572
+
573
+
574
+ class FlaxGPTNeoXBlockCollection(nn.Module):
575
+ config: GPTNeoXConfig
576
+ dtype: jnp.dtype = jnp.float32
577
+
578
+ def setup(self):
579
+ self.blocks = [
580
+ FlaxGPTNeoXBlock(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
581
+ ]
582
+
583
+ def __call__(
584
+ self,
585
+ hidden_states,
586
+ attention_mask=None,
587
+ position_ids=None,
588
+ deterministic: bool = True,
589
+ init_cache: bool = False,
590
+ output_attentions: bool = False,
591
+ output_hidden_states: bool = False,
592
+ return_dict: bool = True,
593
+ ):
594
+ all_attentions = () if output_attentions else None
595
+ all_hidden_states = () if output_hidden_states else None
596
+
597
+ for block in self.blocks:
598
+ if output_hidden_states:
599
+ all_hidden_states += (hidden_states,)
600
+
601
+ layer_outputs = block(
602
+ hidden_states,
603
+ attention_mask,
604
+ position_ids=position_ids,
605
+ deterministic=deterministic,
606
+ init_cache=init_cache,
607
+ output_attentions=output_attentions,
608
+ )
609
+ hidden_states = layer_outputs[0]
610
+
611
+ if output_attentions:
612
+ all_attentions += (layer_outputs[1],)
613
+
614
+ # this contains possible `None` values - `FlaxGPTNeoXModule` will filter them out
615
+ outputs = (hidden_states, all_hidden_states, all_attentions)
616
+
617
+ return outputs
618
+
619
+
620
+ class FlaxGPTNeoXModule(nn.Module):
621
+ config: GPTNeoXConfig
622
+ dtype: jnp.dtype = jnp.float32
623
+
624
+ def setup(self):
625
+ self.embed_dim = self.config.hidden_size
626
+
627
+ self.embed_in = nn.Embed(
628
+ self.config.vocab_size,
629
+ self.config.hidden_size,
630
+ embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
631
+ )
632
+ self.emb_dropout = nn.Dropout(self.config.hidden_dropout)
633
+ self.layers = FlaxGPTNeoXBlockCollection(self.config, dtype=self.dtype)
634
+ self.final_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
635
+
636
+ def __call__(
637
+ self,
638
+ input_ids,
639
+ attention_mask=None,
640
+ position_ids=None,
641
+ deterministic=True,
642
+ init_cache: bool = False,
643
+ output_attentions: bool = False,
644
+ output_hidden_states: bool = False,
645
+ return_dict: bool = True,
646
+ ):
647
+ input_embeds = self.embed_in(input_ids.astype("i4"))
648
+ hidden_states = self.emb_dropout(input_embeds, deterministic=deterministic)
649
+
650
+ outputs = self.layers(
651
+ hidden_states,
652
+ attention_mask,
653
+ position_ids=position_ids,
654
+ deterministic=deterministic,
655
+ init_cache=init_cache,
656
+ output_attentions=output_attentions,
657
+ output_hidden_states=output_hidden_states,
658
+ return_dict=return_dict,
659
+ )
660
+
661
+ hidden_states = outputs[0]
662
+ hidden_states = self.final_layer_norm(hidden_states)
663
+
664
+ if output_hidden_states:
665
+ all_hidden_states = outputs[1] + (hidden_states,)
666
+ outputs = (hidden_states, all_hidden_states) + outputs[2:]
667
+ else:
668
+ outputs = (hidden_states,) + outputs[1:]
669
+
670
+ if not return_dict:
671
+ return tuple(v for v in outputs if v is not None)
672
+
673
+ return FlaxBaseModelOutput(
674
+ last_hidden_state=hidden_states,
675
+ hidden_states=outputs[1],
676
+ attentions=outputs[-1],
677
+ )
678
+
679
+
680
+ @add_start_docstrings(
681
+ "The bare GPTNeoX Model transformer outputting raw hidden-states without any specific head on top.",
682
+ GPT_NEOX_START_DOCSTRING,
683
+ )
684
+ class FlaxGPTNeoXModel(FlaxGPTNeoXPreTrainedModel):
685
+ module_class = FlaxGPTNeoXModule
686
+
687
+
688
+ append_call_sample_docstring(
689
+ FlaxGPTNeoXModel,
690
+ _CHECKPOINT_FOR_DOC,
691
+ FlaxCausalLMOutput,
692
+ _CONFIG_FOR_DOC,
693
+ )
694
+
695
+
696
+ class FlaxGPTNeoXForCausalLMModule(nn.Module):
697
+ config: GPTNeoXConfig
698
+ dtype: jnp.dtype = jnp.float32
699
+
700
+ def setup(self):
701
+ self.gpt_neox = FlaxGPTNeoXModule(self.config, dtype=self.dtype)
702
+ self.embed_out = nn.Dense(
703
+ self.config.vocab_size,
704
+ dtype=self.dtype,
705
+ use_bias=False,
706
+ kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
707
+ )
708
+
709
+ def __call__(
710
+ self,
711
+ input_ids,
712
+ attention_mask=None,
713
+ position_ids=None,
714
+ deterministic: bool = True,
715
+ init_cache: bool = False,
716
+ output_attentions: bool = False,
717
+ output_hidden_states: bool = False,
718
+ return_dict: bool = True,
719
+ ):
720
+ outputs = self.gpt_neox(
721
+ input_ids,
722
+ attention_mask,
723
+ position_ids,
724
+ deterministic=deterministic,
725
+ init_cache=init_cache,
726
+ output_attentions=output_attentions,
727
+ output_hidden_states=output_hidden_states,
728
+ return_dict=return_dict,
729
+ )
730
+
731
+ hidden_states = outputs[0]
732
+
733
+ lm_logits = self.embed_out(hidden_states)
734
+
735
+ if not return_dict:
736
+ return (lm_logits,) + outputs[1:]
737
+
738
+ return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
739
+
740
+
741
+ @add_start_docstrings(
742
+ """
743
+ The GPTNeoX Model transformer with a language modeling head on top.
744
+ """,
745
+ GPT_NEOX_START_DOCSTRING,
746
+ )
747
+ # Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoForCausalLM with GPTNeo->GPTNeoX
748
+ class FlaxGPTNeoXForCausalLM(FlaxGPTNeoXPreTrainedModel):
749
+ module_class = FlaxGPTNeoXForCausalLMModule
750
+
751
+ def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
752
+ # initializing the cache
753
+ batch_size, seq_length = input_ids.shape
754
+
755
+ past_key_values = self.init_cache(batch_size, max_length)
756
+ # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
757
+ # But since GPTNeoX uses a causal mask, those positions are masked anyways.
758
+ # Thus we can create a single static attention_mask here, which is more efficient for compilation
759
+ extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
760
+ if attention_mask is not None:
761
+ position_ids = attention_mask.cumsum(axis=-1) - 1
762
+ extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
763
+ else:
764
+ position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
765
+
766
+ return {
767
+ "past_key_values": past_key_values,
768
+ "attention_mask": extended_attention_mask,
769
+ "position_ids": position_ids,
770
+ }
771
+
772
+ def update_inputs_for_generation(self, model_outputs, model_kwargs):
773
+ model_kwargs["past_key_values"] = model_outputs.past_key_values
774
+ model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
775
+ return model_kwargs
776
+
777
+
778
+ append_call_sample_docstring(
779
+ FlaxGPTNeoXForCausalLM,
780
+ _CHECKPOINT_FOR_DOC,
781
+ FlaxCausalLMOutput,
782
+ _CONFIG_FOR_DOC,
783
+ )