from typing import Optional, Tuple

import jax
from flax import linen as nn
from flax.core import FrozenDict, unfreeze, freeze
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import numpy as jnp
from transformers import FlaxPreTrainedModel
from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
from transformers.modeling_flax_utils import ACT2FN

from .configuration_retnet import RetNetConfig


def rotate_every_two(tensor):
    # Pairwise rotation used by the rotary relative-position embedding:
    # (x0, x1, x2, x3, ...) -> (-x1, x0, -x3, x2, ...).
    rotate_half_tensor = jnp.stack(
        (-tensor[:, :, :, 1::2], tensor[:, :, :, ::2]), axis=-1
    )
    rotate_half_tensor = rotate_half_tensor.reshape(
        rotate_half_tensor.shape[:-2] + (-1,)
    )
    return rotate_half_tensor


def theta_shift(x, sin, cos):
    # Rotate x by the per-position angles: x * cos + rotate(x) * sin.
    return (x * cos) + (rotate_every_two(x) * sin)


class FlaxRetNetRelPos(nn.Module):
    config: RetNetConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        # Rotation angles (one per head channel) and per-head log decay rates.
        angle = 1.0 / (
            10000
            ** jnp.linspace(
                0, 1, self.config.hidden_size // self.config.num_rettention_heads // 2
            )
        )
        self.angle = angle.repeat(2).flatten()
        self.decay = jnp.log(
            1
            - 2
            ** (-5 - jnp.arange(self.config.num_rettention_heads, dtype=jnp.float32))
        )
        self.recurrent_chunk_size = self.config.recurrent_chunk_size

    def __call__(
        self,
        slen: int,
        activate_recurrent: bool = False,
        chunkwise_recurrent: bool = False,
    ):
        if activate_recurrent:
            # Single-step recurrent decoding: angles at the current position
            # plus the per-head decay factor.
            sin = jnp.sin(self.angle * (slen - 1))
            cos = jnp.cos(self.angle * (slen - 1))
            retention_rel_pos = ((sin, cos), jnp.exp(self.decay))
        elif chunkwise_recurrent:
            # Chunkwise-recurrent: rotary tables for the full sequence plus the
            # intra-chunk decay mask, cross-chunk decay and inner decay terms.
            index = jnp.arange(slen)
            sin = jnp.sin(index[:, None] * self.angle[None, :])
            cos = jnp.cos(index[:, None] * self.angle[None, :])

            block_index = jnp.arange(self.recurrent_chunk_size)
            mask = jnp.tril(
                jnp.ones((self.recurrent_chunk_size, self.recurrent_chunk_size))
            )
            mask = jnp.where(
                ~mask.astype(jnp.bool_),
                float("inf"),
                block_index[:, None] - block_index[None, :],
            )
            mask = jnp.exp(mask * self.decay[:, None, None])
            mask = jnp.nan_to_num(mask)
            scale = jnp.sqrt(mask.sum(axis=-1, keepdims=True))
            mask = mask / scale

            cross_decay = jnp.exp(self.decay * self.recurrent_chunk_size)
            inner_decay = jnp.exp(self.decay[:, None] * (block_index + 1))
            cross_decay = cross_decay[:, None, None]
            inner_decay = inner_decay[:, :, None] / (scale / scale[:, -1, None])
            retention_rel_pos = ((sin, cos), (mask, cross_decay, inner_decay))
        else:
            # Fully parallel: rotary tables plus the lower-triangular,
            # decay-weighted mask, normalised row-wise.
            index = jnp.arange(slen)
            sin = jnp.sin(index[:, None] * self.angle[None, :])
            cos = jnp.cos(index[:, None] * self.angle[None, :])
            mask = jnp.tril(jnp.ones((slen, slen)))
            mask = jnp.where(
                ~mask.astype(jnp.bool_), float("inf"), index[:, None] - index[None, :]
            )
            mask = jnp.exp(mask * self.decay[:, None, None])
            mask = jnp.nan_to_num(mask)
            mask = mask / jnp.sqrt(mask.sum(axis=-1, keepdims=True))
            retention_rel_pos = ((sin, cos), mask)

        return retention_rel_pos
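
# Illustrative sketch of what the relative-position module produces (not part
# of the model; the config values below are hypothetical and only the attribute
# names read in this file are assumed):
#
#     config = RetNetConfig(
#         hidden_size=256, num_rettention_heads=4, recurrent_chunk_size=16
#     )
#     rel_pos_module = FlaxRetNetRelPos(config)
#     variables = rel_pos_module.init(jax.random.PRNGKey(0), 8)
#     (sin, cos), decay_mask = rel_pos_module.apply(variables, 8)
#     # sin, cos: (8, 64)    -- per-position rotary angles, 64 = hidden_size // num_heads
#     # decay_mask: (4, 8, 8) -- per-head lower-triangular decay matrix D (parallel path)
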
class FlaxRetNetFeedForward(nn.Module):
    config: RetNetConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        self.fc1 = nn.Dense(
            self.config.intermediate_size,
            kernel_init=nn.initializers.xavier_normal(),
            dtype=self.dtype,
        )
        self.fc2 = nn.Dense(
            self.config.hidden_size,
            kernel_init=nn.initializers.xavier_normal(),
            dtype=self.dtype,
        )
        self.activation_fn = ACT2FN[self.config.hidden_act]
        self.activation_dropout = nn.Dropout(rate=self.config.dropout)
        self.dropout = nn.Dropout(rate=self.config.dropout)

    def __call__(
        self,
        hidden_states: jnp.ndarray,
        deterministic: bool = True,
    ) -> jnp.ndarray:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.activation_dropout(
            hidden_states, deterministic=deterministic
        )
        hidden_states = self.fc2(hidden_states)
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        return hidden_states


class FlaxRetNetRetention(nn.Module):
    config: RetNetConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        # Values and the output gate are projected to factor * hidden_size,
        # so head_dim = factor * key_dim.
        self.factor = 2
        self.embed_dim = self.config.hidden_size
        self.num_heads = self.config.num_rettention_heads
        self.head_dim = self.embed_dim * self.factor // self.num_heads
        self.key_dim = self.embed_dim // self.num_heads
        self.scaling = self.key_dim**-0.5

        self.q_proj = nn.Dense(
            self.embed_dim,
            use_bias=True,
            kernel_init=jax.nn.initializers.xavier_normal(),
            dtype=self.dtype,
        )
        self.k_proj = nn.Dense(
            self.embed_dim,
            use_bias=True,
            kernel_init=jax.nn.initializers.xavier_normal(),
            dtype=self.dtype,
        )
        self.v_proj = nn.Dense(
            self.embed_dim * self.factor,
            use_bias=True,
            kernel_init=jax.nn.initializers.xavier_normal(),
            dtype=self.dtype,
        )
        self.g_proj = nn.Dense(
            self.embed_dim * self.factor,
            use_bias=True,
            kernel_init=nn.initializers.xavier_normal(),
            dtype=self.dtype,
        )
        self.out_proj = nn.Dense(
            self.embed_dim,
            use_bias=True,
            kernel_init=jax.nn.initializers.xavier_normal(),
            dtype=self.dtype,
        )
        self.group_norm = nn.LayerNorm(epsilon=1e-6, dtype=self.dtype)

    def parallel_forward(self, qr, kr, v, mask):
        bsz, tgt_len, embed_dim = v.shape
        vr = v.reshape(bsz, tgt_len, self.num_heads, self.head_dim).transpose(
            (0, 2, 1, 3)
        )

        # Decay-masked retention scores, normalised row-wise (no softmax).
        qk_mat = qr @ kr.transpose((0, 1, 3, 2))
        qk_mat = qk_mat * mask
        qk_mat /= jnp.abs(
            jax.lax.stop_gradient(qk_mat).sum(axis=-1, keepdims=True)
        ).clip(min=1)
        output = jnp.matmul(qk_mat, vr)
        output = output.transpose((0, 2, 1, 3))
        return output

    def chunk_recurrent_forward(self, qr, kr, v, inner_mask):
        mask, cross_decay, inner_decay = inner_mask
        bsz, tgt_len, embed_dim = v.shape
        chunk_len = mask.shape[1]
        num_chunks = tgt_len // chunk_len
        assert tgt_len % chunk_len == 0

        qr = qr.reshape(
            bsz, self.num_heads, num_chunks, chunk_len, self.key_dim
        ).transpose((0, 2, 1, 3, 4))
        kr = kr.reshape(
            bsz, self.num_heads, num_chunks, chunk_len, self.key_dim
        ).transpose((0, 2, 1, 3, 4))
        v = v.reshape(
            bsz, num_chunks, chunk_len, self.num_heads, self.head_dim
        ).transpose((0, 1, 3, 2, 4))

        kr_t = kr.transpose((0, 1, 2, 4, 3))

        # Intra-chunk (parallel) part, weighted by the per-chunk decay mask.
        qk_mat = qr @ kr_t
        qk_mat = qk_mat * mask
        inner_scale = jnp.abs(
            jax.lax.stop_gradient(qk_mat).sum(axis=-1, keepdims=True)
        ).clip(min=1)
        qk_mat = qk_mat / inner_scale
        inner_output = jnp.matmul(qk_mat, v)

        # Cross-chunk recurrent part: kv_state accumulates decayed key-value
        # statistics from previous chunks.
        kv = kr_t @ v
        kv = kv.reshape(bsz, num_chunks, self.num_heads, self.key_dim, self.head_dim)

        kv_recurrent = []
        cross_scale = []
        kv_state = jnp.zeros((bsz, self.num_heads, self.key_dim, self.head_dim))
        kv_scale = jnp.ones((bsz, self.num_heads, 1, 1))

        for i in range(num_chunks):
            kv_recurrent.append(kv_state / kv_scale)
            cross_scale.append(kv_scale)
            kv_state = kv_state * cross_decay + kv[:, i]
            kv_scale = (
                jnp.abs(jax.lax.stop_gradient(kv_state).sum(axis=-2, keepdims=True))
                .max(axis=-1, keepdims=True)
                .clip(min=1)
            )

        kv_recurrent = jnp.stack(kv_recurrent, axis=1)
        cross_scale = jnp.stack(cross_scale, axis=1)

        all_scale = jnp.maximum(inner_scale, cross_scale)
        align_inner_scale = all_scale / inner_scale
        align_cross_scale = all_scale / cross_scale

        cross_output = (qr * inner_decay) @ kv_recurrent
        output = inner_output / align_inner_scale + cross_output / align_cross_scale
        # (bsz, num_chunks, chunk_len, num_heads, head_dim), so the caller's
        # reshape to (bsz, tgt_len, num_heads * head_dim) keeps time and head
        # dimensions aligned, matching parallel_forward.
        output = output.transpose((0, 1, 3, 2, 4))
        return output

    def __call__(
        self,
        hidden_states: jnp.ndarray,
        rel_pos: Optional[tuple] = None,
        chunkwise_recurrent: bool = True,
        incremental_state=None,
    ) -> jnp.ndarray:
        bsz, tgt_len, _ = hidden_states.shape
        (sin, cos), inner_mask = rel_pos

        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)
        g = self.g_proj(hidden_states)

        k *= self.scaling
        q = q.reshape(bsz, tgt_len, self.num_heads, self.key_dim).transpose(
            (0, 2, 1, 3)
        )
        k = k.reshape(bsz, tgt_len, self.num_heads, self.key_dim).transpose(
            (0, 2, 1, 3)
        )

        # Apply the rotary relative-position rotation to queries and keys.
        qr = theta_shift(q, sin, cos)
        kr = theta_shift(k, sin, cos)

        if incremental_state is not None:
            raise NotImplementedError
        elif self.config.attention_type == "chunkwise_recurrent":
            output = self.chunk_recurrent_forward(qr, kr, v, inner_mask=inner_mask)
        else:
            output = self.parallel_forward(qr, kr, v, inner_mask)

        # Per-head normalisation, swish output gate, then output projection.
        output = self.group_norm(output)
        output = output.reshape(bsz, tgt_len, -1)
        output = nn.swish(g) * output
        output = self.out_proj(output)

        return output
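
# The parallel path above is, per head, a softmax-free attention:
# scores = (Q_rot @ K_rot^T) * D, normalised row-wise, then applied to V.
# A minimal standalone sketch of that inner computation (all names and shapes
# below are hypothetical, chosen only to illustrate the broadcasting):
#
#     bsz, heads, slen, key_dim, head_dim = 2, 4, 8, 64, 128
#     q = jnp.ones((bsz, heads, slen, key_dim))
#     k = jnp.ones((bsz, heads, slen, key_dim))
#     v = jnp.ones((bsz, heads, slen, head_dim))
#     decay_mask = jnp.tril(jnp.ones((heads, slen, slen)))      # stand-in for D
#     scores = (q @ k.transpose((0, 1, 3, 2))) * decay_mask     # (bsz, heads, slen, slen)
#     scores = scores / jnp.abs(scores.sum(-1, keepdims=True)).clip(min=1)
#     out = scores @ v                                          # (bsz, heads, slen, head_dim)
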
class FlaxRetNetLayer(nn.Module):
    config: RetNetConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        self.retention = FlaxRetNetRetention(self.config, dtype=self.dtype)
        self.retention_layer_norm = nn.LayerNorm(
            epsilon=self.config.layer_norm_eps, dtype=self.dtype
        )
        self.ffn = FlaxRetNetFeedForward(self.config, dtype=self.dtype)
        self.final_layer_norm = nn.LayerNorm(
            epsilon=self.config.layer_norm_eps, dtype=self.dtype
        )
        self.dropout_module = nn.Dropout(rate=self.config.dropout)

    def __call__(
        self,
        hidden_states: jnp.ndarray,
        retention_rel_pos: Optional[tuple] = None,
        deterministic: bool = True,
    ) -> jnp.ndarray:
        # Pre-norm retention block with residual connection.
        residual = hidden_states
        hidden_states = self.retention_layer_norm(hidden_states)
        hidden_states = self.retention(hidden_states, rel_pos=retention_rel_pos)
        hidden_states = self.dropout_module(hidden_states, deterministic=deterministic)
        hidden_states = residual + hidden_states

        # Pre-norm feed-forward block with residual connection.
        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.ffn(hidden_states, deterministic=deterministic)
        hidden_states = residual + hidden_states

        return hidden_states


class FlaxRetNetLayerCollection(nn.Module):
    config: RetNetConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        self.layers = [
            FlaxRetNetLayer(self.config, dtype=self.dtype)
            for _ in range(self.config.num_hidden_layers)
        ]

    def __call__(
        self,
        hidden_states: jnp.ndarray,
        retention_rel_pos: Optional[tuple] = None,
        deterministic: bool = True,
        output_retentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> jnp.ndarray:
        all_hidden_states = () if output_hidden_states else None
        all_retentions = () if output_retentions else None

        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            layer_outputs = layer(
                hidden_states,
                retention_rel_pos=retention_rel_pos,
                deterministic=deterministic,
            )
            hidden_states = layer_outputs

        # Also record the hidden states produced by the last layer.
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        outputs = (hidden_states, all_hidden_states, all_retentions)
        return outputs
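
# Wiring sketch for a single decoder layer (illustrative only; assumes a config
# whose attention_type selects the parallel retention path and a toy sequence
# length of 8):
#
#     layer = FlaxRetNetLayer(config)
#     x = jnp.zeros((1, 8, config.hidden_size))
#     rel_pos_module = FlaxRetNetRelPos(config)
#     rel_pos = rel_pos_module.apply(rel_pos_module.init(jax.random.PRNGKey(0), 8), 8)
#     params = layer.init(jax.random.PRNGKey(1), x, retention_rel_pos=rel_pos)
#     y = layer.apply(params, x, retention_rel_pos=rel_pos)
#     # y.shape == x.shape; the collection above simply chains
#     # num_hidden_layers of these blocks.
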
class FlaxRetNetPretrainedModel(FlaxPreTrainedModel):
    config_class = RetNetConfig
    base_model_prefix = "transformer"
    main_input_name = "input_ids"
    module_class: nn.Module = None

    def __init__(
        self,
        config: RetNetConfig,
        input_shape: Tuple = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
    ):
        module = self.module_class(config, dtype=dtype, **kwargs)
        super().__init__(
            config,
            module,
            input_shape=input_shape,
            seed=seed,
            dtype=dtype,
            _do_init=_do_init,
        )

    def init_weights(
        self,
        rng: jax.random.PRNGKey,
        input_shape: Tuple,
        params: FrozenDict = None,
    ) -> FrozenDict:
        # Initialise the module with dummy inputs; if partial params are given,
        # fill in any missing keys from the random initialisation.
        input_ids = jnp.zeros(input_shape, dtype="i4")
        attention_mask = jnp.ones_like(input_ids)

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        module_init_outputs = self.module.init(
            rngs, input_ids, attention_mask, return_dict=False
        )
        random_params = module_init_outputs["params"]

        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = []
            return freeze(unflatten_dict(params))
        else:
            return random_params

    def __call__(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        params: dict = None,
        dropout_rng: jnp.ndarray = None,
        train: bool = False,
        output_retentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        # None falls back to the corresponding config defaults.
        output_retentions = (
            output_retentions
            if output_retentions is not None
            else self.config.output_retentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.return_dict
        )

        batch_size, sequence_length = input_ids.shape
        if attention_mask is None:
            attention_mask = jnp.ones((batch_size, sequence_length))

        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        inputs = {"params": params or self.params}

        outputs = self.module.apply(
            inputs,
            jnp.array(input_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            not train,
            output_retentions,
            output_hidden_states,
            return_dict,
            rngs=rngs,
        )

        return outputs


class FlaxRetNetModule(nn.Module):
    config: RetNetConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        self.embed_tokens = nn.Embed(
            self.config.vocab_size,
            self.config.hidden_size,
            embedding_init=jax.nn.initializers.xavier_normal(),
            dtype=self.dtype,
        )
        self.retnet_rel_pos = FlaxRetNetRelPos(self.config, dtype=self.dtype)
        self.layers = FlaxRetNetLayerCollection(self.config, dtype=self.dtype)

    def __call__(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        deterministic: bool = True,
        output_retentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        input_embeds = self.embed_tokens(input_ids)
        batch_size, sequence_length = input_embeds.shape[:2]

        # The relative-position tables depend only on the sequence length and
        # the configured attention type.
        retention_rel_pos = self.retnet_rel_pos(
            sequence_length,
            activate_recurrent=False,
            chunkwise_recurrent=self.config.attention_type == "chunkwise_recurrent",
        )

        outputs = self.layers(
            input_embeds,
            retention_rel_pos=retention_rel_pos,
            deterministic=deterministic,
            output_retentions=output_retentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            return tuple(v for v in outputs if v is not None)

        return FlaxBaseModelOutput(
            last_hidden_state=outputs[0],
            hidden_states=outputs[1],
            attentions=outputs[-1],
        )


class FlaxRetNetModel(FlaxRetNetPretrainedModel):
    module_class = FlaxRetNetModule
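
# Hedged usage sketch for the base wrapper (illustrative; assumes the config
# provides the output_retentions / output_hidden_states / return_dict defaults
# referenced in __call__, and a sequence length compatible with the configured
# attention_type, e.g. a multiple of recurrent_chunk_size for the chunkwise path):
#
#     model = FlaxRetNetModel(config)            # random initialisation
#     input_ids = jnp.ones((1, 16), dtype="i4")
#     outputs = model(input_ids)                 # FlaxBaseModelOutput
#     # outputs.last_hidden_state: (1, 16, config.hidden_size)
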
class FlaxRetNetForCausalLMModule(nn.Module):
    config: RetNetConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        self.transformer = FlaxRetNetModule(self.config, dtype=self.dtype)
        self.lm_head = nn.Dense(
            self.config.vocab_size,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )

    def __call__(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        deterministic: bool = True,
        output_retentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            deterministic=deterministic,
            output_retentions=output_retentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        lm_logits = self.lm_head(hidden_states)

        if not return_dict:
            return (lm_logits,) + outputs[1:]

        return FlaxCausalLMOutput(
            logits=lm_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class FlaxRetNetForCausalLM(FlaxRetNetPretrainedModel):
    module_class = FlaxRetNetForCausalLMModule
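
# Hedged end-to-end sketch for the causal-LM wrapper (illustrative values; how
# RetNetConfig is constructed lives in configuration_retnet and is not assumed
# here beyond the attributes this file reads):
#
#     lm = FlaxRetNetForCausalLM(config)
#     input_ids = jnp.ones((1, 16), dtype="i4")
#     out = lm(input_ids)                        # FlaxCausalLMOutput
#     # out.logits: (1, 16, config.vocab_size)
#     next_token = jnp.argmax(out.logits[:, -1, :], axis=-1)   # greedy decoding step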