File size: 18,764 Bytes
6c1ac22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
from typing import Optional, Tuple

import jax
from flax import linen as nn
from flax.core import FrozenDict, unfreeze, freeze
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import numpy as jnp
from transformers import FlaxPreTrainedModel
from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
from transformers.modeling_flax_utils import ACT2FN

from .configuration_retnet import RetNetConfig


def rotate_every_two(tensor):
    """Rotate interleaved feature pairs by 90 degrees along the last axis.

    Consecutive pairs ``(x0, x1), (x2, x3), ...`` are mapped to
    ``(-x1, x0), (-x3, x2), ...``. ``theta_shift`` combines this with
    sin/cos tables to apply a rotary position rotation.

    Args:
        tensor: Array of any rank whose last dimension has even length
            (the odd/even halves must have equal size to be stacked).

    Returns:
        Array with the same shape as ``tensor``.
    """
    # ``...`` indexing (instead of a hard-coded ``[:, :, :, ...]``) makes this
    # work for inputs of any rank, not only 4-D ones.
    odd = tensor[..., 1::2]
    even = tensor[..., ::2]
    rotated = jnp.stack((-odd, even), axis=-1)
    # Merge the trailing (pairs, 2) axes back into a single interleaved axis.
    return rotated.reshape(rotated.shape[:-2] + (-1,))


def theta_shift(x, sin, cos):
    """Apply a rotary position rotation to ``x``.

    Computes ``x * cos + rotate_every_two(x) * sin``; ``sin`` and ``cos``
    must broadcast against ``x``.
    """
    scaled = x * cos
    rotated = rotate_every_two(x) * sin
    return scaled + rotated


class FlaxRetNetRelPos(nn.Module):
    """Build rotary sin/cos tables and per-head decay masks for retention.

    ``__call__`` returns, depending on the mode:

    * recurrent (``activate_recurrent=True``): ``((sin, cos), exp(decay))``
      with the rotation evaluated at position ``slen - 1``;
    * chunkwise (``chunkwise_recurrent=True``):
      ``((sin, cos), (mask, cross_decay, inner_decay))`` where ``mask`` is the
      normalized intra-chunk decay mask of shape ``(num_heads, chunk, chunk)``;
    * parallel (default): ``((sin, cos), mask)`` with ``mask`` of shape
      ``(num_heads, slen, slen)``.
    """

    config: RetNetConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        heads = self.config.num_rettention_heads
        half_key_dim = self.config.hidden_size // heads // 2
        # One frequency per rotated pair; each value is repeated so both
        # members of a pair share it.
        frequencies = 1.0 / (10000 ** jnp.linspace(0, 1, half_key_dim))
        self.angle = frequencies.repeat(2).flatten()
        # Per-head log decay: log(1 - 2^(-5 - h)).
        exponents = -5 - jnp.arange(heads, dtype=jnp.float32)
        self.decay = jnp.log(1 - 2**exponents)
        self.recurrent_chunk_size = self.config.recurrent_chunk_size

    def _sin_cos(self, length):
        """Return the (sin, cos) rotation tables for positions ``0..length-1``."""
        positions = jnp.arange(length)
        phases = positions[:, None] * self.angle[None, :]
        return jnp.sin(phases), jnp.cos(phases)

    def _decay_mask(self, length):
        """Return ``(normalized_mask, scale)`` for a causal window of ``length``.

        Entry ``[h, n, m]`` of the unnormalized mask is ``exp(decay[h] * (n - m))``
        for ``m <= n`` and 0 above the diagonal; ``scale`` is the square root of
        each row sum, with shape ``(num_heads, length, 1)``.
        """
        positions = jnp.arange(length)
        causal = jnp.tril(jnp.ones((length, length))).astype(jnp.bool_)
        # Future positions get +inf, which the negative decay turns into exp(-inf) = 0.
        distance = jnp.where(
            causal, positions[:, None] - positions[None, :], float("inf")
        )
        weights = jnp.nan_to_num(jnp.exp(distance * self.decay[:, None, None]))
        scale = jnp.sqrt(weights.sum(axis=-1, keepdims=True))
        return weights / scale, scale

    def __call__(
        self,
        slen: int,
        activate_recurrent: bool = False,
        chunkwise_recurrent: bool = False,
    ):
        if activate_recurrent:
            phase = self.angle * (slen - 1)
            return (jnp.sin(phase), jnp.cos(phase)), jnp.exp(self.decay)

        sin, cos = self._sin_cos(slen)

        if not chunkwise_recurrent:
            full_mask, _ = self._decay_mask(slen)
            return (sin, cos), full_mask

        chunk = self.recurrent_chunk_size
        chunk_mask, scale = self._decay_mask(chunk)
        block_index = jnp.arange(chunk)

        # Decay applied to the carried state when moving one whole chunk forward.
        cross_decay = jnp.exp(self.decay * chunk)[:, None, None]
        # Decay from the previous chunk's end to each in-chunk query position,
        # re-expressed relative to the row normalization used by ``chunk_mask``.
        inner_decay = jnp.exp(self.decay[:, None] * (block_index + 1))
        inner_decay = inner_decay[:, :, None] / (scale / scale[:, -1, None])

        return (sin, cos), (chunk_mask, cross_decay, inner_decay)


class FlaxRetNetFeedForward(nn.Module):
    """Position-wise feed-forward block: fc1 -> activation -> dropout -> fc2 -> dropout."""

    config: RetNetConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        kernel_init = nn.initializers.xavier_normal()
        self.fc1 = nn.Dense(
            self.config.intermediate_size, kernel_init=kernel_init, dtype=self.dtype
        )
        self.fc2 = nn.Dense(
            self.config.hidden_size, kernel_init=kernel_init, dtype=self.dtype
        )
        self.activation_fn = ACT2FN[self.config.hidden_act]
        # Both dropout layers use ``config.dropout``.
        self.activation_dropout = nn.Dropout(rate=self.config.dropout)
        self.dropout = nn.Dropout(rate=self.config.dropout)

    def __call__(
        self,
        hidden_states: jnp.ndarray,
        deterministic: bool = True,
    ) -> jnp.ndarray:
        """Apply the feed-forward block; dropout is active only when ``deterministic`` is False."""
        expanded = self.activation_fn(self.fc1(hidden_states))
        expanded = self.activation_dropout(expanded, deterministic=deterministic)
        projected = self.fc2(expanded)
        return self.dropout(projected, deterministic=deterministic)


class FlaxRetNetRetention(nn.Module):
    """Multi-scale retention with parallel and chunkwise-recurrent forms.

    Queries/keys are projected to ``hidden_size`` and values/gates to
    ``2 * hidden_size``, split across ``config.num_rettention_heads`` heads.
    The step-by-step recurrent (incremental) form is not implemented.
    """

    config: RetNetConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        self.factor = 2  # value/gate width multiplier relative to hidden_size
        self.embed_dim = self.config.hidden_size
        self.num_heads = self.config.num_rettention_heads
        self.head_dim = self.embed_dim * self.factor // self.num_heads  # per-head value dim
        self.key_dim = self.embed_dim // self.num_heads  # per-head query/key dim
        self.scaling = self.key_dim**-0.5

        self.q_proj = nn.Dense(
            self.embed_dim,
            use_bias=True,
            kernel_init=jax.nn.initializers.xavier_normal(),
            dtype=self.dtype,
        )
        self.k_proj = nn.Dense(
            self.embed_dim,
            use_bias=True,
            kernel_init=jax.nn.initializers.xavier_normal(),
            dtype=self.dtype,
        )
        self.v_proj = nn.Dense(
            self.embed_dim * self.factor,
            use_bias=True,
            kernel_init=jax.nn.initializers.xavier_normal(),
            dtype=self.dtype,
        )
        self.g_proj = nn.Dense(
            self.embed_dim * self.factor,
            use_bias=True,
            kernel_init=nn.initializers.xavier_normal(),
            dtype=self.dtype,
        )

        self.out_proj = nn.Dense(
            self.embed_dim,
            use_bias=True,
            kernel_init=jax.nn.initializers.xavier_normal(),
            dtype=self.dtype,
        )

        # Normalizes each head's output over its last (head_dim) axis.
        self.group_norm = nn.LayerNorm(epsilon=1e-6, dtype=self.dtype)

    def parallel_forward(self, qr, kr, v, mask):
        """Parallel retention over the whole sequence.

        Args:
            qr, kr: Rotated queries/keys, ``(bsz, num_heads, tgt_len, key_dim)``.
            v: Values, ``(bsz, tgt_len, num_heads * head_dim)``.
            mask: Decay mask broadcastable to ``(bsz, num_heads, tgt_len, tgt_len)``.

        Returns:
            Array of shape ``(bsz, tgt_len, num_heads, head_dim)``.
        """
        bsz, tgt_len, _ = v.shape

        vr = v.reshape(bsz, tgt_len, self.num_heads, self.head_dim).transpose(
            (0, 2, 1, 3)
        )

        qk_mat = qr @ kr.transpose((0, 1, 3, 2))
        qk_mat = qk_mat * mask
        # Row normalization (no gradient through the normalizer), floored at 1.
        qk_mat /= jnp.abs(
            jax.lax.stop_gradient(qk_mat).sum(axis=-1, keepdims=True)
        ).clip(min=1)
        output = jnp.matmul(qk_mat, vr)
        output = output.transpose((0, 2, 1, 3))

        return output

    def chunk_recurrent_forward(self, qr, kr, v, inner_mask):
        """Chunkwise-recurrent retention.

        Args:
            qr, kr: Rotated queries/keys, ``(bsz, num_heads, tgt_len, key_dim)``.
            v: Values, ``(bsz, tgt_len, num_heads * head_dim)``.
            inner_mask: ``(mask, cross_decay, inner_decay)`` as built by
                ``FlaxRetNetRelPos`` in chunkwise mode; ``mask`` has shape
                ``(num_heads, chunk_len, chunk_len)``.

        Returns:
            Array of shape ``(bsz, num_chunks, chunk_len, num_heads, head_dim)``,
            so reshaping to ``(bsz, tgt_len, -1)`` keeps positions in order.
        """
        mask, cross_decay, inner_decay = inner_mask
        bsz, tgt_len, _ = v.shape
        chunk_len = mask.shape[1]
        num_chunks = tgt_len // chunk_len

        assert tgt_len % chunk_len == 0, (
            f"sequence length {tgt_len} must be a multiple of the chunk size {chunk_len}"
        )

        # -> (bsz, num_chunks, num_heads, chunk_len, key_dim)
        qr = qr.reshape(
            bsz, self.num_heads, num_chunks, chunk_len, self.key_dim
        ).transpose((0, 2, 1, 3, 4))
        kr = kr.reshape(
            bsz, self.num_heads, num_chunks, chunk_len, self.key_dim
        ).transpose((0, 2, 1, 3, 4))
        # -> (bsz, num_chunks, num_heads, chunk_len, head_dim)
        v = v.reshape(
            bsz, num_chunks, chunk_len, self.num_heads, self.head_dim
        ).transpose((0, 1, 3, 2, 4))

        kr_t = kr.transpose((0, 1, 2, 4, 3))

        # Intra-chunk retention. The causal decay mask must be applied before
        # normalizing; previously it was dropped (``qk_mat = qk_mat``), letting
        # positions attend to future tokens in the same chunk.
        qk_mat = (qr @ kr_t) * mask
        inner_scale = jnp.abs(
            jax.lax.stop_gradient(qk_mat).sum(axis=-1, keepdims=True)
        ).clip(min=1)
        qk_mat = qk_mat / inner_scale
        inner_output = jnp.matmul(qk_mat, v)

        # Per-chunk key/value summary. Each value is decayed by its distance to
        # the chunk end (last row of the normalized mask); ``inner_decay`` from
        # FlaxRetNetRelPos is normalized against that same row.
        kv = kr_t @ (v * mask[:, -1, :, None])
        kv = kv.reshape(bsz, num_chunks, self.num_heads, self.key_dim, self.head_dim)

        kv_recurrent = []
        cross_scale = []
        kv_state = jnp.zeros((bsz, self.num_heads, self.key_dim, self.head_dim))
        kv_scale = jnp.ones((bsz, self.num_heads, 1, 1))

        for i in range(num_chunks):
            # Chunk i reads only the state accumulated from chunks < i.
            kv_recurrent.append(kv_state / kv_scale)
            cross_scale.append(kv_scale)

            kv_state = kv_state * cross_decay + kv[:, i]
            kv_scale = (
                jnp.abs(jax.lax.stop_gradient(kv_state).sum(axis=-2, keepdims=True))
                .max(axis=-1, keepdims=True)
                .clip(min=1)
            )

        kv_recurrent = jnp.stack(kv_recurrent, axis=1)
        cross_scale = jnp.stack(cross_scale, axis=1)

        # Bring intra- and cross-chunk contributions to a common scale.
        all_scale = jnp.maximum(inner_scale, cross_scale)
        align_inner_scale = all_scale / inner_scale
        align_cross_scale = all_scale / cross_scale

        cross_output = (qr * inner_decay) @ kv_recurrent
        output = inner_output / align_inner_scale + cross_output / align_cross_scale

        # Put heads after the position axis. The previous (0, 2, 1, 3, 4) moved
        # heads in front of the chunks, so the caller's reshape to
        # (bsz, tgt_len, -1) scrambled heads into the sequence dimension.
        return output.transpose((0, 1, 3, 2, 4))

    def __call__(
        self,
        hidden_states: jnp.ndarray,
        rel_pos: Optional[jnp.ndarray] = None,
        chunkwise_recurrent: bool = True,
        incremental_state=None,
    ) -> jnp.ndarray:
        """Apply gated multi-scale retention.

        Args:
            hidden_states: ``(bsz, tgt_len, hidden_size)`` input.
            rel_pos: ``((sin, cos), inner_mask)`` from ``FlaxRetNetRelPos``.
            chunkwise_recurrent: Unused; the mode is taken from
                ``config.attention_type``.
            incremental_state: Must be ``None``; incremental decoding is not
                implemented.

        Returns:
            ``(bsz, tgt_len, hidden_size)`` output.

        Raises:
            NotImplementedError: if ``incremental_state`` is given.
        """
        bsz, tgt_len, _ = hidden_states.shape
        (sin, cos), inner_mask = rel_pos

        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)
        g = self.g_proj(hidden_states)

        k *= self.scaling
        q = q.reshape(bsz, tgt_len, self.num_heads, self.key_dim).transpose(
            (0, 2, 1, 3)
        )
        k = k.reshape(bsz, tgt_len, self.num_heads, self.key_dim).transpose(
            (0, 2, 1, 3)
        )

        qr = theta_shift(q, sin, cos)
        kr = theta_shift(k, sin, cos)

        if incremental_state is not None:
            raise NotImplementedError
        elif self.config.attention_type == "chunkwise_recurrent":
            output = self.chunk_recurrent_forward(qr, kr, v, inner_mask=inner_mask)
        else:
            output = self.parallel_forward(qr, kr, v, inner_mask)

        # Both paths leave (heads, head_dim) as the last two axes.
        output = self.group_norm(output)
        output = output.reshape(bsz, tgt_len, -1)

        output = nn.swish(g) * output
        output = self.out_proj(output)

        return output


class FlaxRetNetLayer(nn.Module):
    """Pre-norm RetNet block: retention and feed-forward sub-layers, each with a residual."""

    config: RetNetConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        eps = self.config.layer_norm_eps
        self.retention = FlaxRetNetRetention(self.config, dtype=self.dtype)
        self.retention_layer_norm = nn.LayerNorm(epsilon=eps, dtype=self.dtype)
        self.ffn = FlaxRetNetFeedForward(self.config, dtype=self.dtype)
        self.final_layer_norm = nn.LayerNorm(epsilon=eps, dtype=self.dtype)
        self.dropout_module = nn.Dropout(rate=self.config.dropout)

    def __call__(
        self,
        hidden_states: jnp.ndarray,
        retention_rel_pos: Optional[tuple] = None,
        deterministic: bool = True,
    ) -> jnp.ndarray:
        """Run the block; ``retention_rel_pos`` is forwarded to the retention sub-layer."""
        normed = self.retention_layer_norm(hidden_states)
        retained = self.retention(normed, rel_pos=retention_rel_pos)
        retained = self.dropout_module(retained, deterministic=deterministic)
        hidden_states = hidden_states + retained

        normed = self.final_layer_norm(hidden_states)
        mixed = self.ffn(normed, deterministic=deterministic)
        return hidden_states + mixed


class FlaxRetNetLayerCollection(nn.Module):
    """Stack of ``config.num_hidden_layers`` RetNet layers applied in sequence."""

    config: RetNetConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        self.layers = [
            FlaxRetNetLayer(self.config, dtype=self.dtype)
            for _ in range(self.config.num_hidden_layers)
        ]

    def __call__(
        self,
        hidden_states: jnp.ndarray,
        retention_rel_pos: Optional[tuple] = None,
        deterministic: bool = True,
        output_retentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Tuple[jnp.ndarray, Optional[tuple], Optional[tuple]]:
        """Run every layer.

        Args:
            hidden_states: Input to the first layer.
            retention_rel_pos: Relative-position tuple shared by all layers.
            deterministic: Disables dropout when True.
            output_retentions: When True, ``all_retentions`` is an (empty)
                tuple; retention scores are not collected by the layers.
            output_hidden_states: When True, collect the stack input plus every
                layer's output (``num_hidden_layers + 1`` entries).
            return_dict: Unused here; accepted for signature compatibility.

        Returns:
            ``(hidden_states, all_hidden_states, all_retentions)``; the last two
            are ``None`` when not requested.
        """
        all_hidden_states = () if output_hidden_states else None
        all_retentions = () if output_retentions else None

        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            hidden_states = layer(
                hidden_states,
                retention_rel_pos=retention_rel_pos,
                deterministic=deterministic,
            )

        # Record the last layer's output too; previously it was dropped, so the
        # tuple held only num_hidden_layers entries instead of num_hidden_layers + 1.
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        return hidden_states, all_hidden_states, all_retentions


class FlaxRetNetPretrainedModel(FlaxPreTrainedModel):
    """HuggingFace Flax wrapper shared by RetNet models.

    Subclasses set ``module_class``; this class instantiates it, initializes
    parameters, and exposes a ``__call__`` that runs ``module.apply``.
    """

    config_class = RetNetConfig
    base_model_prefix = "transformer"
    main_input_name = "input_ids"
    module_class: nn.Module = None

    def __init__(
        self,
        config: RetNetConfig,
        input_shape: Tuple = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs
    ):
        # Extra kwargs are forwarded to the Flax module constructor.
        module = self.module_class(config, dtype=dtype, **kwargs)
        super().__init__(
            config,
            module,
            input_shape=input_shape,
            seed=seed,
            dtype=dtype,
            _do_init=_do_init,
        )

    def init_weights(
        self,
        rng: jax.random.PRNGKey,
        input_shape: Tuple,
        params: FrozenDict = None,
    ) -> FrozenDict:
        """Initialize parameters by tracing the module on dummy token ids.

        If ``params`` is given, only the keys listed in ``self._missing_keys``
        are filled from the fresh random initialization.
        """
        input_ids = jnp.zeros(input_shape, dtype="i4")
        attention_mask = jnp.ones_like(input_ids)
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        module_init_outputs = self.module.init(
            rngs, input_ids, attention_mask, return_dict=False
        )

        random_params = module_init_outputs["params"]

        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = []
            return freeze(unflatten_dict(params))
        else:
            return random_params

    def __call__(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        params: dict = None,
        dropout_rng: jnp.ndarray = None,
        train: bool = False,
        output_retentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Run the wrapped module.

        Args:
            input_ids: ``(batch_size, sequence_length)`` token ids.
            attention_mask: Defaults to all ones; cast to int32.
            params: Parameters to use instead of ``self.params``.
            dropout_rng: PRNG key for dropout (needed when ``train`` is True).
            train: Enables dropout (the module is run with ``deterministic=not train``).
            output_retentions, output_hidden_states, return_dict: ``None``
                (the default) falls back to the config value.
        """
        # Previously these defaulted to concrete bools, which made the config
        # fallbacks unreachable. ``output_retentions`` is read with getattr
        # because it is not a standard PretrainedConfig attribute.
        if output_retentions is None:
            output_retentions = getattr(self.config, "output_retentions", False)
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.return_dict

        batch_size, sequence_length = input_ids.shape

        if attention_mask is None:
            attention_mask = jnp.ones((batch_size, sequence_length))

        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        inputs = {"params": params or self.params}

        outputs = self.module.apply(
            inputs,
            jnp.array(input_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            not train,
            output_retentions,
            output_hidden_states,
            return_dict,
            rngs=rngs,
        )

        return outputs


class FlaxRetNetModule(nn.Module):
    """RetNet backbone: token embedding, relative-position tables and the layer stack."""

    config: RetNetConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        self.embed_tokens = nn.Embed(
            self.config.vocab_size,
            self.config.hidden_size,
            embedding_init=jax.nn.initializers.xavier_normal(),
            dtype=self.dtype,
        )
        self.retnet_rel_pos = FlaxRetNetRelPos(self.config, dtype=self.dtype)

        self.layers = FlaxRetNetLayerCollection(self.config, dtype=self.dtype)

    def __call__(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        deterministic: bool = True,
        output_retentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """Embed ``input_ids`` and run the layer stack.

        Args:
            input_ids: ``(batch_size, sequence_length)`` token ids.
            attention_mask: Accepted for API compatibility; not used.
            deterministic: Disables dropout when True.
            output_retentions: Forwarded to the layer stack.
            output_hidden_states: Return per-layer hidden states.
            return_dict: Return ``FlaxBaseModelOutput`` instead of a tuple.

        Returns:
            ``FlaxBaseModelOutput``, or a tuple of its non-``None`` fields.
        """
        input_embeds = self.embed_tokens(input_ids)
        sequence_length = input_embeds.shape[1]

        chunkwise = self.config.attention_type == "chunkwise_recurrent"

        # Chunkwise retention requires a length that is a multiple of the chunk
        # size (e.g. init_weights' default (1, 1) input would otherwise fail).
        # Right-pad with zeros and trim afterwards; the decay masks built by
        # FlaxRetNetRelPos are lower-triangular (causal), so trailing padding
        # is not visible to the real positions.
        pad_len = 0
        if chunkwise:
            pad_len = -sequence_length % self.config.recurrent_chunk_size
            if pad_len:
                input_embeds = jnp.pad(input_embeds, ((0, 0), (0, pad_len), (0, 0)))

        retention_rel_pos = self.retnet_rel_pos(
            sequence_length + pad_len,
            activate_recurrent=False,
            chunkwise_recurrent=chunkwise,
        )

        hidden_states, all_hidden_states, all_retentions = self.layers(
            input_embeds,
            retention_rel_pos=retention_rel_pos,
            deterministic=deterministic,
            output_retentions=output_retentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if pad_len:
            hidden_states = hidden_states[:, :sequence_length]
            if all_hidden_states is not None:
                all_hidden_states = tuple(
                    h[:, :sequence_length] for h in all_hidden_states
                )

        if not return_dict:
            outputs = (hidden_states, all_hidden_states, all_retentions)
            return tuple(v for v in outputs if v is not None)

        return FlaxBaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_retentions,
        )


class FlaxRetNetModel(FlaxRetNetPretrainedModel):
    """Bare RetNet model (no head) backed by ``FlaxRetNetModule``."""

    module_class = FlaxRetNetModule


class FlaxRetNetForCausalLMModule(nn.Module):
    """RetNet backbone followed by a bias-free linear language-modeling head."""

    config: RetNetConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        self.transformer = FlaxRetNetModule(self.config, dtype=self.dtype)

        head_init = jax.nn.initializers.normal(self.config.initializer_range)
        self.lm_head = nn.Dense(
            self.config.vocab_size,
            use_bias=False,
            kernel_init=head_init,
            dtype=self.dtype,
        )

    def __call__(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        deterministic: bool = True,
        output_retentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """Return vocabulary logits for every position.

        Returns ``FlaxCausalLMOutput`` when ``return_dict`` is True, otherwise
        ``(logits,)`` followed by the backbone's remaining tuple entries.
        """
        backbone_out = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            deterministic=deterministic,
            output_retentions=output_retentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        logits = self.lm_head(backbone_out[0])

        if return_dict:
            return FlaxCausalLMOutput(
                logits=logits,
                hidden_states=backbone_out.hidden_states,
                attentions=backbone_out.attentions,
            )
        return (logits,) + backbone_out[1:]


class FlaxRetNetForCausalLM(FlaxRetNetPretrainedModel):
    """RetNet with a language-modeling head, backed by ``FlaxRetNetForCausalLMModule``."""

    module_class = FlaxRetNetForCausalLMModule