# retnet-small-ko / modeling_flax_retnet.py
from typing import Optional, Tuple
import jax
from flax import linen as nn
from flax.core import FrozenDict, unfreeze, freeze
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import numpy as jnp
from transformers import FlaxPreTrainedModel
from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
from transformers.modeling_flax_utils import ACT2FN
from .configuration_retnet import RetNetConfig
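

# Rotary-style relative-position helpers. `rotate_every_two` negates and swaps
# adjacent feature pairs, and `theta_shift` combines that with per-position
# sin/cos tables to rotate queries and keys (xPos/RoPE-style encoding).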
def rotate_every_two(tensor):
rotate_half_tensor = jnp.stack(
(-tensor[:, :, :, 1::2], tensor[:, :, :, ::2]), axis=-1
)
rotate_half_tensor = rotate_half_tensor.reshape(
rotate_half_tensor.shape[:-2] + (-1,)
)
return rotate_half_tensor
def theta_shift(x, sin, cos):
return (x * cos) + (rotate_every_two(x) * sin)
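

# Relative-position module. Per retention head h it precomputes a decay rate
# gamma_h = 1 - 2**(-5 - h) (stored as log-gamma in `self.decay`) plus the
# sin/cos tables used by `theta_shift`, and returns, depending on the mode:
#   * activate_recurrent:   ((sin, cos), exp(decay)) for single-step decoding
#   * chunkwise_recurrent:  ((sin, cos), (intra-chunk decay mask, cross-chunk
#                           decay, inner decay)) for block-wise processing
#   * parallel (default):   ((sin, cos), D) with D[n, m] = gamma**(n - m) for
#                           n >= m and 0 otherwise, each row divided by the
#                           square root of its sum.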
class FlaxRetNetRelPos(nn.Module):
config: RetNetConfig
dtype: jnp.dtype = jnp.float32
def setup(self) -> None:
angle = 1.0 / (
10000
** jnp.linspace(
0, 1, self.config.hidden_size // self.config.num_rettention_heads // 2
)
)
self.angle = angle.repeat(2).flatten()
self.decay = jnp.log(
1
- 2
** (-5 - jnp.arange(self.config.num_rettention_heads, dtype=jnp.float32))
)
self.recurrent_chunk_size = self.config.recurrent_chunk_size
def __call__(
self,
slen: int,
activate_recurrent: bool = False,
chunkwise_recurrent: bool = False,
):
if activate_recurrent:
sin = jnp.sin(self.angle * (slen - 1))
cos = jnp.cos(self.angle * (slen - 1))
retention_rel_pos = ((sin, cos), jnp.exp(self.decay))
elif chunkwise_recurrent:
index = jnp.arange(slen)
sin = jnp.sin(index[:, None] * self.angle[None, :])
cos = jnp.cos(index[:, None] * self.angle[None, :])
block_index = jnp.arange(self.recurrent_chunk_size)
mask = jnp.tril(
jnp.ones((self.recurrent_chunk_size, self.recurrent_chunk_size))
)
mask = jnp.where(
~mask.astype(jnp.bool_),
float("inf"),
block_index[:, None] - block_index[None, :],
)
mask = jnp.exp(mask * self.decay[:, None, None])
mask = jnp.nan_to_num(mask)
scale = jnp.sqrt(mask.sum(axis=-1, keepdims=True))
mask = mask / scale
cross_decay = jnp.exp(self.decay * self.recurrent_chunk_size)
inner_decay = jnp.exp(self.decay[:, None] * (block_index + 1))
cross_decay = cross_decay[:, None, None]
inner_decay = inner_decay[:, :, None] / (scale / scale[:, -1, None])
retention_rel_pos = ((sin, cos), (mask, cross_decay, inner_decay))
else:
index = jnp.arange(slen)
sin = jnp.sin(index[:, None] * self.angle[None, :])
cos = jnp.cos(index[:, None] * self.angle[None, :])
mask = jnp.tril(jnp.ones((slen, slen)))
mask = jnp.where(
~mask.astype(jnp.bool_), float("inf"), index[:, None] - index[None, :]
)
mask = jnp.exp(mask * self.decay[:, None, None])
mask = jnp.nan_to_num(mask)
mask = mask / jnp.sqrt(mask.sum(axis=-1, keepdims=True))
retention_rel_pos = ((sin, cos), mask)
return retention_rel_pos
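

# Position-wise feed-forward block: Dense -> activation -> dropout -> Dense -> dropout.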
class FlaxRetNetFeedForward(nn.Module):
config: RetNetConfig
dtype: jnp.dtype = jnp.float32
def setup(self) -> None:
self.fc1 = nn.Dense(
self.config.intermediate_size,
kernel_init=nn.initializers.xavier_normal(),
dtype=self.dtype,
)
self.fc2 = nn.Dense(
self.config.hidden_size,
kernel_init=nn.initializers.xavier_normal(),
dtype=self.dtype,
)
self.activation_fn = ACT2FN[self.config.hidden_act]
self.activation_dropout = nn.Dropout(rate=self.config.dropout)
self.dropout = nn.Dropout(rate=self.config.dropout)
def __call__(
self,
hidden_states: jnp.ndarray,
deterministic: bool = True,
) -> jnp.ndarray:
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.activation_dropout(
hidden_states, deterministic=deterministic
)
hidden_states = self.fc2(hidden_states)
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
return hidden_states
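

# Multi-scale retention, RetNet's replacement for self-attention. Queries and
# keys are projected to hidden_size, values and the swish gate to
# hidden_size * factor (factor = 2), and each head applies its own decay from
# FlaxRetNetRelPos. The result is group-normalized, gated with swish(g), and
# projected back to hidden_size by out_proj.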
class FlaxRetNetRetention(nn.Module):
config: RetNetConfig
dtype: jnp.dtype = jnp.float32
def setup(self) -> None:
self.factor = 2
self.embed_dim = self.config.hidden_size
self.num_heads = self.config.num_rettention_heads
self.head_dim = self.embed_dim * self.factor // self.num_heads
self.key_dim = self.embed_dim // self.num_heads
self.scaling = self.key_dim**-0.5
self.q_proj = nn.Dense(
self.embed_dim,
use_bias=True,
kernel_init=jax.nn.initializers.xavier_normal(),
dtype=self.dtype,
)
self.k_proj = nn.Dense(
self.embed_dim,
use_bias=True,
kernel_init=jax.nn.initializers.xavier_normal(),
dtype=self.dtype,
)
self.v_proj = nn.Dense(
self.embed_dim * self.factor,
use_bias=True,
kernel_init=jax.nn.initializers.xavier_normal(),
dtype=self.dtype,
)
self.g_proj = nn.Dense(
self.embed_dim * self.factor,
use_bias=True,
kernel_init=nn.initializers.xavier_normal(),
dtype=self.dtype,
)
self.out_proj = nn.Dense(
self.embed_dim,
use_bias=True,
kernel_init=jax.nn.initializers.xavier_normal(),
dtype=self.dtype,
)
self.group_norm = nn.LayerNorm(epsilon=1e-6, dtype=self.dtype)
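
    # Parallel form: ((Q @ K^T) * D) @ V, where D is the decayed causal mask
    # from FlaxRetNetRelPos. Each row of (Q @ K^T) * D is divided by the
    # absolute value of its (stop-gradient) row sum, clipped to at least 1,
    # for numerical stability.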
def parallel_forward(self, qr, kr, v, mask):
bsz, tgt_len, embed_dim = v.shape
vr = v.reshape(bsz, tgt_len, self.num_heads, self.head_dim).transpose(
(0, 2, 1, 3)
)
qk_mat = qr @ kr.transpose((0, 1, 3, 2))
qk_mat = qk_mat * mask
qk_mat /= jnp.abs(
jax.lax.stop_gradient(qk_mat).sum(axis=-1, keepdims=True)
).clip(min=1)
output = jnp.matmul(qk_mat, vr)
output = output.transpose((0, 2, 1, 3))
return output
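
    # Chunkwise-recurrent form: the sequence is split into chunks of length
    # recurrent_chunk_size. Within each chunk retention is computed in parallel
    # (inner_output); contributions from earlier chunks are carried by a
    # recurrent key-value state that decays by cross_decay between chunks
    # (cross_output). Both parts are rescaled to a common scale and summed.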
def chunk_recurrent_forward(self, qr, kr, v, inner_mask):
mask, cross_decay, inner_decay = inner_mask
bsz, tgt_len, embed_dim = v.shape
chunk_len = mask.shape[1]
num_chunks = tgt_len // chunk_len
assert tgt_len % chunk_len == 0
qr = qr.reshape(
bsz, self.num_heads, num_chunks, chunk_len, self.key_dim
).transpose((0, 2, 1, 3, 4))
kr = kr.reshape(
bsz, self.num_heads, num_chunks, chunk_len, self.key_dim
).transpose((0, 2, 1, 3, 4))
v = v.reshape(
bsz, num_chunks, chunk_len, self.num_heads, self.head_dim
).transpose((0, 1, 3, 2, 4))
kr_t = kr.transpose((0, 1, 2, 4, 3))
qk_mat = qr @ kr_t
        qk_mat = qk_mat * mask  # apply the per-head intra-chunk causal decay mask
inner_scale = jnp.abs(
jax.lax.stop_gradient(qk_mat).sum(axis=-1, keepdims=True)
).clip(min=1)
qk_mat = qk_mat / inner_scale
inner_output = jnp.matmul(qk_mat, v)
kv = kr_t @ v
kv = kv.reshape(bsz, num_chunks, self.num_heads, self.key_dim, self.head_dim)
kv_recurrent = []
cross_scale = []
kv_state = jnp.zeros((bsz, self.num_heads, self.key_dim, self.head_dim))
kv_scale = jnp.ones((bsz, self.num_heads, 1, 1))
for i in range(num_chunks):
kv_recurrent.append(kv_state / kv_scale)
cross_scale.append(kv_scale)
kv_state = kv_state * cross_decay + kv[:, i]
kv_scale = (
jnp.abs(jax.lax.stop_gradient(kv_state).sum(axis=-2, keepdims=True))
.max(axis=-1, keepdims=True)
.clip(min=1)
)
kv_recurrent = jnp.stack(kv_recurrent, axis=1)
cross_scale = jnp.stack(cross_scale, axis=1)
all_scale = jnp.maximum(inner_scale, cross_scale)
align_inner_scale = all_scale / inner_scale
align_cross_scale = all_scale / cross_scale
cross_output = (qr * inner_decay) @ kv_recurrent
output = inner_output / align_inner_scale + cross_output / align_cross_scale
        output = output.transpose((0, 1, 3, 2, 4))  # (bsz, num_chunks, chunk_len, num_heads, head_dim)
return output
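
    # Full retention step: project q/k/v/g, scale k, rotate q and k with the
    # sin/cos tables, run the parallel or chunkwise path, then group-normalize,
    # gate with swish(g), and project back to hidden_size.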
def __call__(
self,
hidden_states: jnp.ndarray,
        rel_pos: Optional[tuple] = None,
chunkwise_recurrent: bool = True,
incremental_state=None,
) -> jnp.ndarray:
bsz, tgt_len, _ = hidden_states.shape
(sin, cos), inner_mask = rel_pos
q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
g = self.g_proj(hidden_states)
k *= self.scaling
q = q.reshape(bsz, tgt_len, self.num_heads, self.key_dim).transpose(
(0, 2, 1, 3)
)
k = k.reshape(bsz, tgt_len, self.num_heads, self.key_dim).transpose(
(0, 2, 1, 3)
)
qr = theta_shift(q, sin, cos)
kr = theta_shift(k, sin, cos)
if incremental_state is not None:
raise NotImplementedError
elif self.config.attention_type == "chunkwise_recurrent":
output = self.chunk_recurrent_forward(qr, kr, v, inner_mask=inner_mask)
else:
output = self.parallel_forward(qr, kr, v, inner_mask)
output = self.group_norm(output)
output = output.reshape(bsz, tgt_len, -1)
output = nn.swish(g) * output
output = self.out_proj(output)
return output
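

# One RetNet block with pre-LayerNorm residual structure:
#   x = x + Dropout(Retention(LN(x)))
#   x = x + FFN(LN(x))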
class FlaxRetNetLayer(nn.Module):
config: RetNetConfig
dtype: jnp.dtype = jnp.float32
def setup(self) -> None:
self.retention = FlaxRetNetRetention(self.config, dtype=self.dtype)
self.retention_layer_norm = nn.LayerNorm(
epsilon=self.config.layer_norm_eps, dtype=self.dtype
)
self.ffn = FlaxRetNetFeedForward(self.config, dtype=self.dtype)
self.final_layer_norm = nn.LayerNorm(
epsilon=self.config.layer_norm_eps, dtype=self.dtype
)
self.dropout_module = nn.Dropout(rate=self.config.dropout)
def __call__(
self,
hidden_states: jnp.ndarray,
retention_rel_pos: Optional[tuple] = None,
deterministic: bool = True,
) -> jnp.ndarray:
residual = hidden_states
hidden_states = self.retention_layer_norm(hidden_states)
hidden_states = self.retention(hidden_states, rel_pos=retention_rel_pos)
hidden_states = self.dropout_module(hidden_states, deterministic=deterministic)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.ffn(hidden_states, deterministic=deterministic)
hidden_states = residual + hidden_states
return hidden_states
class FlaxRetNetLayerCollection(nn.Module):
config: RetNetConfig
dtype: jnp.dtype = jnp.float32
def setup(self) -> None:
self.layers = [
FlaxRetNetLayer(self.config, dtype=self.dtype)
for _ in range(self.config.num_hidden_layers)
]
def __call__(
self,
hidden_states: jnp.ndarray,
        retention_rel_pos: Optional[tuple] = None,
deterministic: bool = True,
output_retentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
    ) -> Tuple:
all_hidden_states = () if output_hidden_states else None
all_retentions = () if output_retentions else None
for layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
layer_outputs = layer(
hidden_states,
retention_rel_pos=retention_rel_pos,
deterministic=deterministic,
)
            hidden_states = layer_outputs

        # also collect the final hidden state, matching the usual HF convention
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        outputs = (hidden_states, all_hidden_states, all_retentions)
return outputs
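

# Hugging Face wrapper that handles weight initialization, parameter handling,
# and the public __call__ interface around the Flax modules defined below.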
class FlaxRetNetPretrainedModel(FlaxPreTrainedModel):
config_class = RetNetConfig
base_model_prefix = "transformer"
main_input_name = "input_ids"
module_class: nn.Module = None
def __init__(
self,
config: RetNetConfig,
input_shape: Tuple = (1, 1),
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
**kwargs
):
module = self.module_class(config, dtype=dtype, **kwargs)
super().__init__(
config,
module,
input_shape=input_shape,
seed=seed,
dtype=dtype,
_do_init=_do_init,
)
def init_weights(
self,
rng: jax.random.PRNGKey,
input_shape: Tuple,
params: FrozenDict = None,
) -> FrozenDict:
input_ids = jnp.zeros(input_shape, dtype="i4")
attention_mask = jnp.ones_like(input_ids)
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
module_init_outputs = self.module.init(
rngs, input_ids, attention_mask, return_dict=False
)
random_params = module_init_outputs["params"]
if params is not None:
random_params = flatten_dict(unfreeze(random_params))
params = flatten_dict(unfreeze(params))
for missing_key in self._missing_keys:
params[missing_key] = random_params[missing_key]
self._missing_keys = []
return freeze(unflatten_dict(params))
else:
return random_params
def __call__(
self,
input_ids: jnp.ndarray,
attention_mask: Optional[jnp.ndarray] = None,
        params: Optional[dict] = None,
        dropout_rng: Optional[jax.random.PRNGKey] = None,
        train: bool = False,
        output_retentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
):
output_retentions = (
output_retentions
if output_retentions is not None
else self.config.output_retentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.return_dict
)
batch_size, sequence_length = input_ids.shape
if attention_mask is None:
attention_mask = jnp.ones((batch_size, sequence_length))
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
inputs = {"params": params or self.params}
outputs = self.module.apply(
inputs,
jnp.array(input_ids, dtype="i4"),
jnp.array(attention_mask, dtype="i4"),
not train,
output_retentions,
output_hidden_states,
return_dict,
rngs=rngs,
)
return outputs
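

# Base model: token embedding -> relative-position tables -> stack of RetNet layers.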
class FlaxRetNetModule(nn.Module):
config: RetNetConfig
dtype: jnp.dtype = jnp.float32
def setup(self) -> None:
self.embed_tokens = nn.Embed(
self.config.vocab_size,
self.config.hidden_size,
embedding_init=jax.nn.initializers.xavier_normal(),
dtype=self.dtype,
)
self.retnet_rel_pos = FlaxRetNetRelPos(self.config, dtype=self.dtype)
self.layers = FlaxRetNetLayerCollection(self.config, dtype=self.dtype)
def __call__(
self,
input_ids: jnp.ndarray,
attention_mask: Optional[jnp.ndarray] = None,
deterministic: bool = True,
output_retentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
input_embeds = self.embed_tokens(input_ids)
batch_size, sequence_length = input_embeds.shape[:2]
retention_rel_pos = self.retnet_rel_pos(
sequence_length,
activate_recurrent=False,
chunkwise_recurrent=self.config.attention_type == "chunkwise_recurrent",
)
outputs = self.layers(
input_embeds,
retention_rel_pos=retention_rel_pos,
deterministic=deterministic,
output_retentions=output_retentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
if not return_dict:
return tuple(v for v in outputs if v is not None)
return FlaxBaseModelOutput(
last_hidden_state=outputs[0],
hidden_states=outputs[1],
attentions=outputs[-1],
)
class FlaxRetNetModel(FlaxRetNetPretrainedModel):
module_class = FlaxRetNetModule
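

# Causal language-modeling head: the base RetNet module followed by an untied
# Dense projection to the vocabulary.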
class FlaxRetNetForCausalLMModule(nn.Module):
config: RetNetConfig
dtype: jnp.dtype = jnp.float32
def setup(self) -> None:
self.transformer = FlaxRetNetModule(self.config, dtype=self.dtype)
self.lm_head = nn.Dense(
self.config.vocab_size,
use_bias=False,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
def __call__(
self,
input_ids: jnp.ndarray,
attention_mask: Optional[jnp.ndarray] = None,
deterministic: bool = True,
output_retentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
outputs = self.transformer(
input_ids,
attention_mask=attention_mask,
deterministic=deterministic,
output_retentions=output_retentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
lm_logits = self.lm_head(hidden_states)
if not return_dict:
return (lm_logits,) + outputs[1:]
return FlaxCausalLMOutput(
logits=lm_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
class FlaxRetNetForCausalLM(FlaxRetNetPretrainedModel):
module_class = FlaxRetNetForCausalLMModule
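

# Example usage (a minimal sketch, not part of the original checkpoint code;
# it assumes the accompanying RetNetConfig provides usable defaults for
# vocab_size, hidden_size, num_rettention_heads, recurrent_chunk_size, etc.):
#
#     import jax.numpy as jnp
#     from configuration_retnet import RetNetConfig
#
#     config = RetNetConfig()
#     model = FlaxRetNetForCausalLM(config)
#     # In "chunkwise_recurrent" mode the sequence length must be a multiple
#     # of config.recurrent_chunk_size.
#     input_ids = jnp.ones((1, config.recurrent_chunk_size), dtype="i4")
#     outputs = model(input_ids)
#     outputs.logits.shape  # (1, sequence_length, config.vocab_size)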