import math

import mesh_tensorflow as mtf
import mesh_tensorflow.transformer as mtf_transformer
import tensorflow.compat.v1 as tf

from models.activations import get_activation_fn

# Sentinel default used to detect whether a caller explicitly passed an `axis`.
sentinel = object()


def exists(x):
    return x is not None


def identity(x, *args, **kwargs):
    return x


def is_incremental_inference(context):
    # True when decoding one token at a time with a Transformer `context`.
    return exists(context) and context.mode == "incremental"


def norm(x, axis, epsilon=1e-8):
    # Subtract the mean and divide by the standard deviation along `axis`.
    x -= mtf.reduce_mean(x, reduced_dim=axis, name="norm_reduce_mean_u")
    s = mtf.reduce_mean(mtf.square(x), reduced_dim=axis, name="norm_reduce_mean_s")
    return x * mtf.rsqrt(s + epsilon)


def rezero(x, scope, dtype):
    # ReZero residual gate: a single learned scalar `g` initialised to zero.
    with tf.variable_scope(scope):
        g = mtf.get_variable(x.mesh, "g", [], initializer=tf.constant_initializer(0), dtype=dtype)
        return x * g


def scale_norm(x, scope, *, variable_dtype, axis=sentinel, epsilon=1e-5, params=None):
    """Normalize, then rescale with a single learned scalar gain."""
    if axis is sentinel:
        axis = x.shape[-1]

    with tf.variable_scope(scope):
        g = mtf.get_variable(x.mesh, "g", [], initializer=tf.constant_initializer(1),
                             master_dtype=variable_dtype.master_dtype,
                             slice_dtype=variable_dtype.slice_dtype,
                             activation_dtype=variable_dtype.activation_dtype)

        x = norm(x, axis, epsilon)
        x = x * g
        return x


def layer_norm(x, scope, *, variable_dtype, axis=sentinel, epsilon=1e-5, params=None):
    """Normalize to mean = 0, std = 1, then do a diagonal affine transform."""
    if axis is sentinel:
        axis = x.shape[-1]

    with tf.variable_scope(scope):
        n_state = x.shape[-1]

        g = mtf.get_variable(x.mesh, "g", [n_state], initializer=tf.constant_initializer(1),
                             master_dtype=variable_dtype.master_dtype,
                             slice_dtype=variable_dtype.slice_dtype,
                             activation_dtype=variable_dtype.activation_dtype)
        b = mtf.get_variable(x.mesh, "b", [n_state], initializer=tf.constant_initializer(0),
                             master_dtype=variable_dtype.master_dtype,
                             slice_dtype=variable_dtype.slice_dtype,
                             activation_dtype=variable_dtype.activation_dtype)

        x = norm(x, axis, epsilon)
        x = x * g + b
        return x


def linear_attention(q, k, v):
    batch_dim, seq_dim, head_dim, dim_out = (v.shape[0], v.shape[1], v.shape[2], v.shape[3])
    q = mtf.rename_dimension(q, "features_per_head", "features_per_head_in")
    k = mtf.rename_dimension(k, "features_per_head", "features_per_head_in")

    dim_in = k.shape[-1]

    q = mtf.softmax(q, dim_in)
    k = mtf.softmax(k, seq_dim)

    context = mtf.einsum([k, v], output_shape=[batch_dim, head_dim, dim_in, dim_out])
    attn = mtf.einsum([q, context], output_shape=[batch_dim, seq_dim, head_dim, dim_out])
    return attn


def causal_linear_attention(q, k, v, eps=1e-6):
    batch_dim, seq_dim, head_dim, dim_out = (v.shape[0], v.shape[1], v.shape[2], v.shape[3])
    q = mtf.rename_dimension(q, "features_per_head", "features_per_head_in")
    k = mtf.rename_dimension(k, "features_per_head", "features_per_head_in")

    dim_in = k.shape[-1]

    q = mtf.softmax(q, dim_in)
    k = mtf.exp(k)

    cumulative_k = mtf.cumsum(k, seq_dim) + eps
    D_inv = 1. / mtf.einsum([q, cumulative_k], output_shape=[batch_dim, seq_dim, head_dim])

    context = mtf.einsum([k, v], output_shape=[batch_dim, seq_dim, head_dim, dim_in, dim_out])
    cumulative_context = mtf.cumsum(context, seq_dim)

    attn = mtf.einsum([q, cumulative_context, D_inv], output_shape=[batch_dim, seq_dim, head_dim, dim_out])
    return attn
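
# Note on the two linear-attention kernels above: neither materialises the full
# (sequence x sequence) attention matrix; `k` is contracted with `v` first, so the
# cost grows linearly with sequence length. In the causal variant the i-th output
# is, schematically,
#
#     attn_i = q_i . (sum_{j <= i} k_j v_j^T) / (q_i . sum_{j <= i} k_j)
#
# which is what the two mtf.cumsum calls and the D_inv normaliser compute.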


def linear(x, scope, nf, *, w_init_stdev=0.02, variable_dtype, params=None, scale=False):
    # Optionally rescale the initializer (GPT-2-style) by depth and/or input width.
    if params["scale_by_depth"] and scale:
        # Scale by sqrt(n_layer); applied only to the final projection of a residual block.
        w_init_stdev = w_init_stdev * (1. / math.sqrt(params["n_layer"]))
    if params["scale_by_in"]:
        # Scale by sqrt(dim_in).
        w_init_stdev = w_init_stdev * (1. / math.sqrt(x.shape[-1].size))

    with tf.variable_scope("conv1d_main"):
        c = mtf.layers.dense(x, new_dims=[nf], reduced_dims=[x.shape[-1]], name=scope, use_bias=True,
                             kernel_initializer=tf.random_normal_initializer(stddev=w_init_stdev),
                             variable_dtype=variable_dtype,
                             )
        return c


def memory_key_values(k, v, num_mem_kv, dim_batch, dim_heads, variable_dtype, mesh):
    """Learned memory key / value vectors (all-attention paper), prepended to the sequence."""
    dim_mem_kv = mtf.Dimension("mem_kv_sequence", num_mem_kv)
    emb_dim = k.shape[-1]
    mem_std = 1 / math.sqrt(emb_dim.size)

    mem_k = mtf.get_variable(mesh, "mem_k", mtf.Shape([dim_mem_kv, dim_heads, emb_dim]),
                             initializer=tf.random_normal_initializer(stddev=mem_std),
                             master_dtype=variable_dtype.master_dtype,
                             slice_dtype=variable_dtype.slice_dtype,
                             activation_dtype=variable_dtype.activation_dtype,
                             )
    mem_v = mtf.get_variable(mesh, "mem_v", mtf.Shape([dim_mem_kv, dim_heads, emb_dim]),
                             initializer=tf.random_normal_initializer(stddev=mem_std),
                             master_dtype=variable_dtype.master_dtype,
                             slice_dtype=variable_dtype.slice_dtype,
                             activation_dtype=variable_dtype.activation_dtype)

    mem_k, mem_v = map(lambda t: mtf.broadcast(t, [dim_batch, dim_mem_kv, dim_heads, emb_dim]),
                       (mem_k, mem_v))
    mem_k, mem_v = map(lambda t: mtf.rename_dimension(t, "mem_kv_sequence", "sequence"),
                       (mem_k, mem_v))

    k = mtf.concat([mem_k, k], "sequence")
    v = mtf.concat([mem_v, v], "sequence")
    return k, v


def attn(x, scope, n_state, *, attention_type, params, bias, dim_seq, memory_length_dim, variable_dtype, context=None, pos_emb=None):
    # x :: [batch, seq, n_embd]
    x_shape, dim_batch, *_, dim_embd, mesh = x.shape, *x.shape, x.mesh

    # n_state must divide evenly across heads.
    assert n_state.size % params["n_head"] == 0

    dim_heads = mtf.Dimension("heads", params["n_head"])

    num_mem_kv = params.get("num_mem_kv", 0)
    use_num_mem_kv = num_mem_kv > 0

    with tf.variable_scope(scope):
        # Compute attention inputs.
        dim_kv = mtf.Dimension("features_per_head", params["n_embd"] // params["n_head"])
        mtfparams = mtf.transformer.attention.attention_params_simple(
            x.mesh,
            io_dim=dim_embd,
            kv_dim=dim_kv,
            heads_dim=dim_heads,
            variable_dtype=variable_dtype
        )
        q = mtfparams.compute_q(x)
        k = mtfparams.compute_k(x)
        v = mtfparams.compute_v(x)

        if is_incremental_inference(context):
            # One-hot merge of the newly computed k / v at the current decode position
            # into the cached states from previous steps.
            one_hot = mtf.one_hot(context.position - 1, dim_seq, dtype=variable_dtype.master_dtype)
            inv_one_hot = 1.0 - one_hot
            old_k, old_v = context.get_states(2)
            k = old_k * inv_one_hot + k * one_hot
            v = old_v * inv_one_hot + v * one_hot

        if exists(context):
            context.record_new_states([k, v])

        if exists(pos_emb):
            # Rotary position embeddings: k always sees the full-sequence tables;
            # during incremental inference q only needs the current position.
            cos, sin = pos_emb
            k = apply_rotary_emb(k, cos, sin)

            if is_incremental_inference(context):
                seq_dim = cos.shape.get_dim_by_name('sequence')
                cos = mtf.gather(cos, context.position - 1, seq_dim)
                sin = mtf.gather(sin, context.position - 1, seq_dim)

            q = apply_rotary_emb(q, cos, sin)

        with tf.variable_scope("attention"):
            if attention_type == "local":
                # Local (windowed) attention over a fixed radius.
                radius = params.get("local_attention_radius", 256)

                if is_incremental_inference(context):
                    q *= one_hot

                a = mtf_transformer.attention.local_attention_1d(
                    q, k, v,
                    length_dim=k.shape[1],
                    key_dim=dim_kv,
                    value_dim=dim_kv,
                    radius=radius,
                    length_dim_num_splits=1,
                    fully_autoregressive=params["causal"],
                    attention_kwargs={},
                )

                if is_incremental_inference(context):
                    a = mtf.gather(a, context.position - 1, dim_seq)

            elif attention_type == "global":
                # Full attention over the whole (memory) sequence.
                broadcasted_bias = None  # attention() accepts a None bias
                if exists(bias):
                    if not is_incremental_inference(context):
                        broadcasted_bias = mtf.broadcast(bias, [dim_batch, dim_heads, bias.shape[-2], bias.shape[-1]])
                    else:
                        # During incremental inference, only the bias row for the current position is needed.
                        bias = mtf.gather(bias, context.position - 1, dim_seq)
                        broadcasted_bias = mtf.broadcast(bias, [dim_batch, dim_heads, bias.shape[-1]])

                if use_num_mem_kv:
                    k, v = memory_key_values(k, v, num_mem_kv, dim_batch, dim_heads, variable_dtype, mesh)

                k = mtf.replace_dimensions(k, k.shape[1], memory_length_dim)
                v = mtf.replace_dimensions(v, v.shape[1], memory_length_dim)

                attn_dropout_rate = params["attn_dropout"] if params["mode"] == "train" else 0

                a = mtf_transformer.attention.attention(
                    q, k, v,
                    memory_length_dim=memory_length_dim,
                    key_dim=dim_kv,
                    value_dim=dim_kv,
                    bias=broadcasted_bias,
                    dropout_rate=attn_dropout_rate
                )

            elif attention_type == "linear":
                linear_attn_fn = causal_linear_attention if params["causal"] else linear_attention
                a = linear_attn_fn(q, k, v)

            else:
                raise NotImplementedError("Unknown attention type {}!".format(attention_type))

        with tf.variable_scope("compute_output"):
            a = mtfparams.compute_output(a, x_shape)

        with tf.variable_scope("compute_output_bias"):
            b = mtf.get_variable(x.mesh, "o_b", [dim_embd], initializer=tf.constant_initializer(0),
                                 master_dtype=variable_dtype.master_dtype,
                                 slice_dtype=variable_dtype.slice_dtype,
                                 activation_dtype=variable_dtype.activation_dtype)
            a += b

        if params["mode"] == "train" and params["res_dropout"] > 0:
            a = mtf.dropout(a, rate=params["res_dropout"], name="res_dropout")
        return a


def mlp(x, scope, n_state, *, variable_dtype, params):
    activation_fn = get_activation_fn(params)
    with tf.variable_scope(scope):
        nx = x.shape[-1]
        h = activation_fn(linear(x, "c_fc", n_state, variable_dtype=variable_dtype, params=params))
        h2 = linear(h, "c_proj", nx, variable_dtype=variable_dtype, params=params, scale=True)
        if params["mode"] == "train" and params["res_dropout"] > 0:
            h2 = mtf.dropout(h2, rate=params["res_dropout"], name="mlp_dropout")
        return h2


def mlp_glu(x, scope, n_state, *, variable_dtype, params):
    activation_fn = get_activation_fn(params)
    with tf.variable_scope(scope):
        nx = x.shape[-1]
        # `variable_dtype` must be forwarded here as well: `linear` takes it as a
        # required keyword-only argument.
        h = linear(x, "c_fc", n_state, variable_dtype=variable_dtype, params=params)

        # Split the projection into a value half and a gate half (GLU).
        h, gate = mtf.split(h, h.shape[-1], 2)
        h *= activation_fn(gate)

        h2 = linear(h, "c_proj", nx, variable_dtype=variable_dtype, params=params, scale=True)
        if params["mode"] == "train" and params["res_dropout"] > 0:
            h2 = mtf.dropout(h2, rate=params["res_dropout"], name="mlp_dropout")
        return h2
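
# Note: because `mlp_glu` splits the first projection into value and gate halves,
# the `n_state` dimension it receives must have an even size, and the effective
# hidden width of the feed-forward is n_state / 2.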


def axial_positional_emb(embd_dim, mesh, params, variable_dtype):
    # Instead of one large positional embedding table, learn two smaller factors
    # whose broadcast sum covers the full sequence length.
    axial_dim_1, axial_dim_2 = params["axial_pos_emb"]

    axial_dim = mtf.Dimension("axial_dim", axial_dim_1 * axial_dim_2)
    dim_axials = [mtf.Dimension(f"axial_dim_{i}", t) for i, t in enumerate((axial_dim_1, axial_dim_2))]

    axial_wpe_1 = mtf.get_variable(mesh, "axial_wpe_1", mtf.Shape([dim_axials[0], embd_dim]),
                                   initializer=tf.random_normal_initializer(stddev=0.01),
                                   master_dtype=variable_dtype.master_dtype,
                                   slice_dtype=variable_dtype.slice_dtype,
                                   activation_dtype=variable_dtype.activation_dtype)

    axial_wpe_2 = mtf.get_variable(mesh, "axial_wpe_2", mtf.Shape([dim_axials[1], embd_dim]),
                                   initializer=tf.random_normal_initializer(stddev=0.01),
                                   master_dtype=variable_dtype.master_dtype,
                                   slice_dtype=variable_dtype.slice_dtype,
                                   activation_dtype=variable_dtype.activation_dtype)

    axial_wpe_1, axial_wpe_2 = map(lambda t: mtf.broadcast(t, [dim_axials[0], dim_axials[1], embd_dim]),
                                   (axial_wpe_1, axial_wpe_2))
    wpe = (axial_wpe_1 + axial_wpe_2) / 2

    wpe = mtf.reshape(wpe, [axial_dim, embd_dim])

    return wpe


def rotary_positional_emb(mesh, sequence_dim, params, variable_dtype):
    """Build the cos / sin tables for rotary position embeddings (RoPE)."""
    dtype = variable_dtype.master_dtype
    dim_head = params["n_embd"] // params["n_head"]

    dim_head = mtf.Dimension("features_per_head", dim_head)
    half_dim_head = mtf.Dimension("half_features_per_head", dim_head.size // 2)

    # Frequencies theta_i = 10000^(-2i / dim_head) for i in [0, dim_head / 2).
    dim_range = mtf.range(mesh, half_dim_head, dtype) * 2 / dim_head.size
    half_freqs = 1. / mtf.pow(mtf.constant(mesh, 10000, dtype=dtype), dim_range)

    # Outer product of positions and frequencies: angles[t, i] = t * theta_i.
    seq = mtf.range(mesh, sequence_dim, dtype)
    half_freqs = mtf.einsum([half_freqs, seq], [sequence_dim, half_dim_head])

    # Duplicate the half table so it lines up with the full per-head feature dimension.
    freqs = mtf.concat((half_freqs, half_freqs), half_dim_head.name)
    freqs = mtf.rename_dimension(freqs, half_dim_head.name, dim_head.name)
    return mtf.cos(freqs), mtf.sin(freqs)
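
# Together with rotate_half / apply_rotary_emb below, these tables apply the rotation
#     x' = x * cos(angles) + rotate_half(x) * sin(angles),
# which rotates feature i together with feature i + dim_head/2 by the angle
# t * theta_i at sequence position t (the "two-halves" RoPE layout, rather than
# interleaved even/odd pairs).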


def rotate_half(x):
    # Map (x1, x2) -> (-x2, x1), where x1 / x2 are the two halves of the head dimension.
    dim_head_name = "features_per_head"
    dim_head = x.shape.get_dim_by_name(dim_head_name)
    half_dim_head_size = dim_head.size // 2
    x1 = mtf.slice(x, 0, half_dim_head_size, dim_head_name)
    x2 = mtf.slice(x, half_dim_head_size, half_dim_head_size, dim_head_name)
    return mtf.concat((-x2, x1), dim_head.name)


def apply_rotary_emb(x, cos, sin):
    rotated_x = rotate_half(x)
    return x * cos + rotated_x * sin
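
# Usage sketch (illustrative only; the actual call sites live in the model
# definition and may differ): the caller builds the rotary tables once and
# threads them into `attn` via `pos_emb`, roughly
#
#     cos, sin = rotary_positional_emb(mesh, sequence_dim, params, variable_dtype)
#     out = attn(h, "attn", n_state, attention_type="global", params=params,
#                bias=bias, dim_seq=sequence_dim, memory_length_dim=memory_length_dim,
#                variable_dtype=variable_dtype, context=context, pos_emb=(cos, sin))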