hyunwoo3235 committed on
Commit
6c1ac22
1 Parent(s): 3fcb4d5

Upload 8 files

config.json ADDED
@@ -0,0 +1,27 @@
+ {
+   "activation_dropout": 0.1,
+   "architectures": [
+     "RetNetForCausalLM"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_retnet.RetNetConfig",
+     "FlaxAutoModel": "modeling_flax_retnet.FlaxRetNetModel",
+     "FlaxAutoModelForCausalLM": "modeling_flax_retnet.FlaxRetNetForCausalLM"
+   },
+   "attention_type": "parallel",
+   "dropout": 0.1,
+   "hidden_act": "gelu",
+   "hidden_size": 768,
+   "initializer_range": 0.02,
+   "intermediate_size": 1536,
+   "layer_norm_eps": 1e-05,
+   "max_position_embeddings": 512,
+   "model_type": "retnet",
+   "normalize_before": false,
+   "num_hidden_layers": 12,
+   "num_rettention_heads": 4,
+   "output_retentions": false,
+   "recurrent_chunk_size": 512,
+   "transformers_version": "4.29.2",
+   "vocab_size": 50432
+ }
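The auto_map block above registers the custom code shipped in this commit, so the checkpoint can be resolved through the Transformers auto classes with trust_remote_code=True. A minimal loading sketch (the repository id below is a placeholder, not part of this upload):

from transformers import AutoConfig

repo_id = "<namespace>/<this-repo>"  # placeholder: substitute the actual Hub repository id

# "AutoConfig" in auto_map points at configuration_retnet.RetNetConfig
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
print(config.model_type, config.num_rettention_heads)  # retnet 4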
configuration_retnet.py ADDED
@@ -0,0 +1,44 @@
+ from transformers import PretrainedConfig
+
+
+ class RetNetConfig(PretrainedConfig):
+     model_type = "retnet"
+
+     def __init__(
+         self,
+         vocab_size=32000,
+         hidden_size=512,
+         num_hidden_layers=6,
+         num_rettention_heads=8,
+         intermediate_size=2048,
+         hidden_act="gelu",
+         max_position_embeddings=512,
+         initializer_range=0.02,
+         layer_norm_eps=1e-5,
+         dropout=0.1,
+         activation_dropout=0.0,
+         normalize_before=False,
+         attention_type="parallel",
+         recurrent_chunk_size=512,
+         output_retentions=False,
+         output_hidden_states=False,
+         **kwargs
+     ):
+         super().__init__(**kwargs)
+
+         self.vocab_size = vocab_size
+         self.hidden_size = hidden_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_rettention_heads = num_rettention_heads
+         self.intermediate_size = intermediate_size
+         self.hidden_act = hidden_act
+         self.attention_type = attention_type
+         self.max_position_embeddings = max_position_embeddings
+         self.initializer_range = initializer_range
+         self.layer_norm_eps = layer_norm_eps
+         self.dropout = dropout
+         self.normalize_before = normalize_before
+         self.activation_dropout = activation_dropout
+         self.recurrent_chunk_size = recurrent_chunk_size
+         self.output_retentions = output_retentions
+         self.output_hidden_states = output_hidden_states
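For reference, the same configuration can also be built directly from this class; the values below mirror config.json in this commit rather than the constructor defaults (a sketch assuming configuration_retnet.py is importable from the working directory):

from configuration_retnet import RetNetConfig

config = RetNetConfig(
    vocab_size=50432,
    hidden_size=768,
    num_hidden_layers=12,
    num_rettention_heads=4,
    intermediate_size=1536,
    max_position_embeddings=512,
    dropout=0.1,
    activation_dropout=0.1,
    attention_type="parallel",
    recurrent_chunk_size=512,
)
print(config.to_json_string())  # serializes essentially like the config.json above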
flax_model.msgpack ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c90e66fc81a33a732b498e3ec0dcc271eca0defa146129ad245e8ff45387595d
+ size 650154262
generation_config.json ADDED
@@ -0,0 +1,7 @@
+ {
+   "_from_model_config": true,
+   "transformers_version": "4.30.2",
+   "eos_token_id": 0,
+   "pad_token_id": 1,
+   "use_cache": false
+ }
modeling_flax_retnet.py ADDED
@@ -0,0 +1,577 @@
+ from typing import Optional, Tuple
+
+ import jax
+ from flax import linen as nn
+ from flax.core import FrozenDict, unfreeze, freeze
+ from flax.traverse_util import flatten_dict, unflatten_dict
+ from jax import numpy as jnp
+ from transformers import FlaxPreTrainedModel
+ from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
+ from transformers.modeling_flax_utils import ACT2FN
+
+ from .configuration_retnet import RetNetConfig
+
+
+ def rotate_every_two(tensor):
+     rotate_half_tensor = jnp.stack(
+         (-tensor[:, :, :, 1::2], tensor[:, :, :, ::2]), axis=-1
+     )
+     rotate_half_tensor = rotate_half_tensor.reshape(
+         rotate_half_tensor.shape[:-2] + (-1,)
+     )
+     return rotate_half_tensor
+
+
+ def theta_shift(x, sin, cos):
+     return (x * cos) + (rotate_every_two(x) * sin)
+
+
+ class FlaxRetNetRelPos(nn.Module):
+     config: RetNetConfig
+     dtype: jnp.dtype = jnp.float32
+
+     def setup(self) -> None:
+         angle = 1.0 / (
+             10000
+             ** jnp.linspace(
+                 0, 1, self.config.hidden_size // self.config.num_rettention_heads // 2
+             )
+         )
+         self.angle = angle.repeat(2).flatten()
+         self.decay = jnp.log(
+             1
+             - 2
+             ** (-5 - jnp.arange(self.config.num_rettention_heads, dtype=jnp.float32))
+         )
+         self.recurrent_chunk_size = self.config.recurrent_chunk_size
+
+     def __call__(
+         self,
+         slen: int,
+         activate_recurrent: bool = False,
+         chunkwise_recurrent: bool = False,
+     ):
+         if activate_recurrent:
+             sin = jnp.sin(self.angle * (slen - 1))
+             cos = jnp.cos(self.angle * (slen - 1))
+             retention_rel_pos = ((sin, cos), jnp.exp(self.decay))
+         elif chunkwise_recurrent:
+             index = jnp.arange(slen)
+             sin = jnp.sin(index[:, None] * self.angle[None, :])
+             cos = jnp.cos(index[:, None] * self.angle[None, :])
+
+             block_index = jnp.arange(self.recurrent_chunk_size)
+             mask = jnp.tril(
+                 jnp.ones((self.recurrent_chunk_size, self.recurrent_chunk_size))
+             )
+             mask = jnp.where(
+                 ~mask.astype(jnp.bool_),
+                 float("inf"),
+                 block_index[:, None] - block_index[None, :],
+             )
+             mask = jnp.exp(mask * self.decay[:, None, None])
+             mask = jnp.nan_to_num(mask)
+             scale = jnp.sqrt(mask.sum(axis=-1, keepdims=True))
+             mask = mask / scale
+
+             cross_decay = jnp.exp(self.decay * self.recurrent_chunk_size)
+             inner_decay = jnp.exp(self.decay[:, None] * (block_index + 1))
+             cross_decay = cross_decay[:, None, None]
+             inner_decay = inner_decay[:, :, None] / (scale / scale[:, -1, None])
+
+             retention_rel_pos = ((sin, cos), (mask, cross_decay, inner_decay))
+         else:
+             index = jnp.arange(slen)
+             sin = jnp.sin(index[:, None] * self.angle[None, :])
+             cos = jnp.cos(index[:, None] * self.angle[None, :])
+             mask = jnp.tril(jnp.ones((slen, slen)))
+             mask = jnp.where(
+                 ~mask.astype(jnp.bool_), float("inf"), index[:, None] - index[None, :]
+             )
+             mask = jnp.exp(mask * self.decay[:, None, None])
+             mask = jnp.nan_to_num(mask)
+             mask = mask / jnp.sqrt(mask.sum(axis=-1, keepdims=True))
+             retention_rel_pos = ((sin, cos), mask)
+
+         return retention_rel_pos
+
+
+ class FlaxRetNetFeedForward(nn.Module):
+     config: RetNetConfig
+     dtype: jnp.dtype = jnp.float32
+
+     def setup(self) -> None:
+         self.fc1 = nn.Dense(
+             self.config.intermediate_size,
+             kernel_init=nn.initializers.xavier_normal(),
+             dtype=self.dtype,
+         )
+         self.fc2 = nn.Dense(
+             self.config.hidden_size,
+             kernel_init=nn.initializers.xavier_normal(),
+             dtype=self.dtype,
+         )
+         self.activation_fn = ACT2FN[self.config.hidden_act]
+         self.activation_dropout = nn.Dropout(rate=self.config.activation_dropout)
+         self.dropout = nn.Dropout(rate=self.config.dropout)
+
+     def __call__(
+         self,
+         hidden_states: jnp.ndarray,
+         deterministic: bool = True,
+     ) -> jnp.ndarray:
+         hidden_states = self.fc1(hidden_states)
+         hidden_states = self.activation_fn(hidden_states)
+         hidden_states = self.activation_dropout(
+             hidden_states, deterministic=deterministic
+         )
+         hidden_states = self.fc2(hidden_states)
+         hidden_states = self.dropout(hidden_states, deterministic=deterministic)
+
+         return hidden_states
+
+
+ class FlaxRetNetRetention(nn.Module):
+     config: RetNetConfig
+     dtype: jnp.dtype = jnp.float32
+
+     def setup(self) -> None:
+         self.factor = 2
+         self.embed_dim = self.config.hidden_size
+         self.num_heads = self.config.num_rettention_heads
+         self.head_dim = self.embed_dim * self.factor // self.num_heads
+         self.key_dim = self.embed_dim // self.num_heads
+         self.scaling = self.key_dim**-0.5
+
+         self.q_proj = nn.Dense(
+             self.embed_dim,
+             use_bias=True,
+             kernel_init=jax.nn.initializers.xavier_normal(),
+             dtype=self.dtype,
+         )
+         self.k_proj = nn.Dense(
+             self.embed_dim,
+             use_bias=True,
+             kernel_init=jax.nn.initializers.xavier_normal(),
+             dtype=self.dtype,
+         )
+         self.v_proj = nn.Dense(
+             self.embed_dim * self.factor,
+             use_bias=True,
+             kernel_init=jax.nn.initializers.xavier_normal(),
+             dtype=self.dtype,
+         )
+         self.g_proj = nn.Dense(
+             self.embed_dim * self.factor,
+             use_bias=True,
+             kernel_init=nn.initializers.xavier_normal(),
+             dtype=self.dtype,
+         )
+
+         self.out_proj = nn.Dense(
+             self.embed_dim,
+             use_bias=True,
+             kernel_init=jax.nn.initializers.xavier_normal(),
+             dtype=self.dtype,
+         )
+
+         self.group_norm = nn.LayerNorm(epsilon=1e-6, dtype=self.dtype)
+
+     def parallel_forward(self, qr, kr, v, mask):
+         bsz, tgt_len, embed_dim = v.shape
+
+         vr = v.reshape(bsz, tgt_len, self.num_heads, self.head_dim).transpose(
+             (0, 2, 1, 3)
+         )
+
+         qk_mat = qr @ kr.transpose((0, 1, 3, 2))
+         qk_mat = qk_mat * mask
+         qk_mat /= jnp.abs(
+             jax.lax.stop_gradient(qk_mat).sum(axis=-1, keepdims=True)
+         ).clip(min=1)
+         output = jnp.matmul(qk_mat, vr)
+         output = output.transpose((0, 2, 1, 3))
+
+         return output
+
+     def chunk_recurrent_forward(self, qr, kr, v, inner_mask):
+         mask, cross_decay, inner_decay = inner_mask
+         bsz, tgt_len, embed_dim = v.shape
+         chunk_len = mask.shape[1]
+         num_chunks = tgt_len // chunk_len
+
+         assert tgt_len % chunk_len == 0
+
+         qr = qr.reshape(
+             bsz, self.num_heads, num_chunks, chunk_len, self.key_dim
+         ).transpose((0, 2, 1, 3, 4))
+         kr = kr.reshape(
+             bsz, self.num_heads, num_chunks, chunk_len, self.key_dim
+         ).transpose((0, 2, 1, 3, 4))
+         v = v.reshape(
+             bsz, num_chunks, chunk_len, self.num_heads, self.head_dim
+         ).transpose((0, 1, 3, 2, 4))
+
+         kr_t = kr.transpose((0, 1, 2, 4, 3))
+
+         qk_mat = qr @ kr_t
+         qk_mat = qk_mat * mask  # apply the chunk-local causal decay mask
+         inner_scale = jnp.abs(
+             jax.lax.stop_gradient(qk_mat).sum(axis=-1, keepdims=True)
+         ).clip(min=1)
+         qk_mat = qk_mat / inner_scale
+         inner_output = jnp.matmul(qk_mat, v)
+
+         kv = kr_t @ v
+         kv = kv.reshape(bsz, num_chunks, self.num_heads, self.key_dim, self.head_dim)
+
+         kv_recurrent = []
+         cross_scale = []
+         kv_state = jnp.zeros((bsz, self.num_heads, self.key_dim, self.head_dim))
+         kv_scale = jnp.ones((bsz, self.num_heads, 1, 1))
+
+         for i in range(num_chunks):
+             kv_recurrent.append(kv_state / kv_scale)
+             cross_scale.append(kv_scale)
+
+             kv_state = kv_state * cross_decay + kv[:, i]
+             kv_scale = (
+                 jnp.abs(jax.lax.stop_gradient(kv_state).sum(axis=-2, keepdims=True))
+                 .max(axis=-1, keepdims=True)
+                 .clip(min=1)
+             )
+
+         kv_recurrent = jnp.stack(kv_recurrent, axis=1)
+         cross_scale = jnp.stack(cross_scale, axis=1)
+
+         all_scale = jnp.maximum(inner_scale, cross_scale)
+         align_inner_scale = all_scale / inner_scale
+         align_cross_scale = all_scale / cross_scale
+
+         cross_output = (qr * inner_decay) @ kv_recurrent
+         output = inner_output / align_inner_scale + cross_output / align_cross_scale
+
+         output = output.transpose((0, 2, 1, 3, 4))
+         return output
+
+     def __call__(
+         self,
+         hidden_states: jnp.ndarray,
+         rel_pos: Optional[jnp.ndarray] = None,
+         chunkwise_recurrent: bool = True,
+         incremental_state=None,
+     ) -> jnp.ndarray:
+         bsz, tgt_len, _ = hidden_states.shape
+         (sin, cos), inner_mask = rel_pos
+
+         q = self.q_proj(hidden_states)
+         k = self.k_proj(hidden_states)
+         v = self.v_proj(hidden_states)
+         g = self.g_proj(hidden_states)
+
+         k *= self.scaling
+         q = q.reshape(bsz, tgt_len, self.num_heads, self.key_dim).transpose(
+             (0, 2, 1, 3)
+         )
+         k = k.reshape(bsz, tgt_len, self.num_heads, self.key_dim).transpose(
+             (0, 2, 1, 3)
+         )
+
+         qr = theta_shift(q, sin, cos)
+         kr = theta_shift(k, sin, cos)
+
+         if incremental_state is not None:
+             raise NotImplementedError
+         elif self.config.attention_type == "chunkwise_recurrent":
+             output = self.chunk_recurrent_forward(qr, kr, v, inner_mask=inner_mask)
+         else:
+             output = self.parallel_forward(qr, kr, v, inner_mask)
+
+         output = self.group_norm(output)
+         output = output.reshape(bsz, tgt_len, -1)
+
+         output = nn.swish(g) * output
+         output = self.out_proj(output)
+
+         return output
+
+
+ class FlaxRetNetLayer(nn.Module):
+     config: RetNetConfig
+     dtype: jnp.dtype = jnp.float32
+
+     def setup(self) -> None:
+         self.retention = FlaxRetNetRetention(self.config, dtype=self.dtype)
+         self.retention_layer_norm = nn.LayerNorm(
+             epsilon=self.config.layer_norm_eps, dtype=self.dtype
+         )
+
+         self.ffn = FlaxRetNetFeedForward(self.config, dtype=self.dtype)
+         self.final_layer_norm = nn.LayerNorm(
+             epsilon=self.config.layer_norm_eps, dtype=self.dtype
+         )
+
+         self.dropout_module = nn.Dropout(rate=self.config.dropout)
+
+     def __call__(
+         self,
+         hidden_states: jnp.ndarray,
+         retention_rel_pos: Optional[tuple] = None,
+         deterministic: bool = True,
+     ) -> jnp.ndarray:
+         residual = hidden_states
+         hidden_states = self.retention_layer_norm(hidden_states)
+         hidden_states = self.retention(hidden_states, rel_pos=retention_rel_pos)
+         hidden_states = self.dropout_module(hidden_states, deterministic=deterministic)
+         hidden_states = residual + hidden_states
+
+         residual = hidden_states
+         hidden_states = self.final_layer_norm(hidden_states)
+         hidden_states = self.ffn(hidden_states, deterministic=deterministic)
+         hidden_states = residual + hidden_states
+
+         return hidden_states
+
+
+ class FlaxRetNetLayerCollection(nn.Module):
+     config: RetNetConfig
+     dtype: jnp.dtype = jnp.float32
+
+     def setup(self) -> None:
+         self.layers = [
+             FlaxRetNetLayer(self.config, dtype=self.dtype)
+             for _ in range(self.config.num_hidden_layers)
+         ]
+
+     def __call__(
+         self,
+         hidden_states: jnp.ndarray,
+         retention_rel_pos: tuple = None,
+         deterministic: bool = True,
+         output_retentions: bool = False,
+         output_hidden_states: bool = False,
+         return_dict: bool = True,
+     ) -> jnp.ndarray:
+         all_hidden_states = () if output_hidden_states else None
+         all_retentions = () if output_retentions else None
+
+         for layer in self.layers:
+             if output_hidden_states:
+                 all_hidden_states += (hidden_states,)
+
+             layer_outputs = layer(
+                 hidden_states,
+                 retention_rel_pos=retention_rel_pos,
+                 deterministic=deterministic,
+             )
+             hidden_states = layer_outputs
+
+         outputs = (hidden_states, all_hidden_states, all_retentions)
+         return outputs
+
+
+ class FlaxRetNetPretrainedModel(FlaxPreTrainedModel):
+     config_class = RetNetConfig
+     base_model_prefix = "transformer"
+     main_input_name = "input_ids"
+     module_class: nn.Module = None
+
+     def __init__(
+         self,
+         config: RetNetConfig,
+         input_shape: Tuple = (1, 1),
+         seed: int = 0,
+         dtype: jnp.dtype = jnp.float32,
+         _do_init: bool = True,
+         **kwargs
+     ):
+         module = self.module_class(config, dtype=dtype, **kwargs)
+         super().__init__(
+             config,
+             module,
+             input_shape=input_shape,
+             seed=seed,
+             dtype=dtype,
+             _do_init=_do_init,
+         )
+
+     def init_weights(
+         self,
+         rng: jax.random.PRNGKey,
+         input_shape: Tuple,
+         params: FrozenDict = None,
+     ) -> FrozenDict:
+         input_ids = jnp.zeros(input_shape, dtype="i4")
+         attention_mask = jnp.ones_like(input_ids)
+         params_rng, dropout_rng = jax.random.split(rng)
+         rngs = {"params": params_rng, "dropout": dropout_rng}
+
+         module_init_outputs = self.module.init(
+             rngs, input_ids, attention_mask, return_dict=False
+         )
+
+         random_params = module_init_outputs["params"]
+
+         if params is not None:
+             random_params = flatten_dict(unfreeze(random_params))
+             params = flatten_dict(unfreeze(params))
+             for missing_key in self._missing_keys:
+                 params[missing_key] = random_params[missing_key]
+             self._missing_keys = []
+             return freeze(unflatten_dict(params))
+         else:
+             return random_params
+
+     def __call__(
+         self,
+         input_ids: jnp.ndarray,
+         attention_mask: Optional[jnp.ndarray] = None,
+         params: dict = None,
+         dropout_rng: jnp.ndarray = None,
+         train: bool = False,
+         output_retentions: bool = False,
+         output_hidden_states: bool = False,
+         return_dict: bool = True,
+     ):
+         output_retentions = (
+             output_retentions
+             if output_retentions is not None
+             else self.config.output_retentions
+         )
+         output_hidden_states = (
+             output_hidden_states
+             if output_hidden_states is not None
+             else self.config.output_hidden_states
+         )
+         return_dict = (
+             return_dict if return_dict is not None else self.config.return_dict
+         )
+
+         batch_size, sequence_length = input_ids.shape
+
+         if attention_mask is None:
+             attention_mask = jnp.ones((batch_size, sequence_length))
+
+         rngs = {}
+         if dropout_rng is not None:
+             rngs["dropout"] = dropout_rng
+
+         inputs = {"params": params or self.params}
+
+         outputs = self.module.apply(
+             inputs,
+             jnp.array(input_ids, dtype="i4"),
+             jnp.array(attention_mask, dtype="i4"),
+             not train,
+             output_retentions,
+             output_hidden_states,
+             return_dict,
+             rngs=rngs,
+         )
+
+         return outputs
+
+
+ class FlaxRetNetModule(nn.Module):
+     config: RetNetConfig
+     dtype: jnp.dtype = jnp.float32
+
+     def setup(self) -> None:
+         self.embed_tokens = nn.Embed(
+             self.config.vocab_size,
+             self.config.hidden_size,
+             embedding_init=jax.nn.initializers.xavier_normal(),
+             dtype=self.dtype,
+         )
+         self.retnet_rel_pos = FlaxRetNetRelPos(self.config, dtype=self.dtype)
+
+         self.layers = FlaxRetNetLayerCollection(self.config, dtype=self.dtype)
+
+     def __call__(
+         self,
+         input_ids: jnp.ndarray,
+         attention_mask: Optional[jnp.ndarray] = None,
+         deterministic: bool = True,
+         output_retentions: bool = False,
+         output_hidden_states: bool = False,
+         return_dict: bool = True,
+     ):
+         input_embeds = self.embed_tokens(input_ids)
+
+         batch_size, sequence_length = input_embeds.shape[:2]
+         retention_rel_pos = self.retnet_rel_pos(
+             sequence_length,
+             activate_recurrent=False,
+             chunkwise_recurrent=self.config.attention_type == "chunkwise_recurrent",
+         )
+
+         outputs = self.layers(
+             input_embeds,
+             retention_rel_pos=retention_rel_pos,
+             deterministic=deterministic,
+             output_retentions=output_retentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         if not return_dict:
+             return tuple(v for v in outputs if v is not None)
+
+         return FlaxBaseModelOutput(
+             last_hidden_state=outputs[0],
+             hidden_states=outputs[1],
+             attentions=outputs[-1],
+         )
+
+
+ class FlaxRetNetModel(FlaxRetNetPretrainedModel):
+     module_class = FlaxRetNetModule
+
+
+ class FlaxRetNetForCausalLMModule(nn.Module):
+     config: RetNetConfig
+     dtype: jnp.dtype = jnp.float32
+
+     def setup(self) -> None:
+         self.transformer = FlaxRetNetModule(self.config, dtype=self.dtype)
+
+         self.lm_head = nn.Dense(
+             self.config.vocab_size,
+             use_bias=False,
+             kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+             dtype=self.dtype,
+         )
+
+     def __call__(
+         self,
+         input_ids: jnp.ndarray,
+         attention_mask: Optional[jnp.ndarray] = None,
+         deterministic: bool = True,
+         output_retentions: bool = False,
+         output_hidden_states: bool = False,
+         return_dict: bool = True,
+     ):
+         outputs = self.transformer(
+             input_ids,
+             attention_mask=attention_mask,
+             deterministic=deterministic,
+             output_retentions=output_retentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+         hidden_states = outputs[0]
+
+         lm_logits = self.lm_head(hidden_states)
+
+         if not return_dict:
+             return (lm_logits,) + outputs[1:]
+
+         return FlaxCausalLMOutput(
+             logits=lm_logits,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+
+ class FlaxRetNetForCausalLM(FlaxRetNetPretrainedModel):
+     module_class = FlaxRetNetForCausalLMModule
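A quick smoke test of the model defined above, as a sketch: load the checkpoint through FlaxAutoModelForCausalLM (which auto_map resolves to FlaxRetNetForCausalLM) and push a dummy batch through the default parallel retention path. The repository id is again a placeholder.

import jax.numpy as jnp
from transformers import FlaxAutoModelForCausalLM

repo_id = "<namespace>/<this-repo>"  # placeholder: substitute the actual Hub repository id
model = FlaxAutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

# any token ids below vocab_size (50432) will do for a shape check
input_ids = jnp.ones((1, 8), dtype=jnp.int32)
outputs = model(input_ids)

# FlaxCausalLMOutput: logits shaped (batch, seq_len, vocab_size) == (1, 8, 50432)
print(outputs.logits.shape)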
special_tokens_map.json ADDED
@@ -0,0 +1,5 @@
+ {
+   "bos_token": "<|endoftext|>",
+   "eos_token": "<|endoftext|>",
+   "unk_token": "<|endoftext|>"
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,9 @@
+ {
+   "add_prefix_space": false,
+   "bos_token": "<|endoftext|>",
+   "clean_up_tokenization_spaces": true,
+   "eos_token": "<|endoftext|>",
+   "model_max_length": 2048,
+   "tokenizer_class": "GPTNeoXTokenizer",
+   "unk_token": "<|endoftext|>"
+ }
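The tokenizer files follow the GPT-NeoX convention, with a single <|endoftext|> token serving as bos/eos/unk, so AutoTokenizer should pick up the fast tokenizer from tokenizer.json. A brief loading sketch (placeholder repository id):

from transformers import AutoTokenizer

repo_id = "<namespace>/<this-repo>"  # placeholder: substitute the actual Hub repository id
tokenizer = AutoTokenizer.from_pretrained(repo_id)

ids = tokenizer("Retention networks replace attention with retention.")["input_ids"]
print(ids)
print(tokenizer.decode(ids))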