gpt-neo / models /utils.py
aliabd
full working demo
c6e7238
import tensorflow as tf
import mesh_tensorflow as mtf
from functools import partial
def entmax_backward(explicit_inputs, all_inputs, forward_operations, outputs, output_grads, alpha=1.3, dim=None,
n_iter=50):
x, = explicit_inputs
y, = outputs
dY, = output_grads
gppr = mtf.where(mtf.greater(y, 0), mtf.pow(y, (2 - alpha)), mtf.zeros_like(y))
dX = dY * gppr
q = mtf.reduce_sum(dX, reduced_dim=dim) / mtf.reduce_sum(gppr, reduced_dim=dim)
dX = dX - q * gppr
return dX,
def entmax_forward(x, alpha=1.3, dim=None, n_iter=50):
assert alpha > 1 and alpha < 2, 'alpha must be between 1 and 2'
_gp = lambda x, alpha: x ** (alpha - 1)
_gp_inv = lambda x, alpha: mtf.pow(x, (1 / (alpha - 1)))
_p = lambda x, alpha: _gp_inv(mtf.relu(x), alpha)
dim = x.shape[-1] if dim is None else dim
d = dim.size
x = x * (alpha - 1)
max_val = mtf.reduce_max(x, reduced_dim=dim)
tau_lo = max_val - _gp(1, alpha)
tau_hi = max_val - _gp(1 / d, alpha)
f_lo = mtf.reduce_sum(_p(x - tau_lo, alpha), reduced_dim=dim) - 1
dm = tau_hi - tau_lo
for _ in range(n_iter):
dm = dm / 2
tau_m = tau_lo + dm
p_m = _p(x - tau_m, alpha)
f_m = mtf.reduce_sum(p_m, reduced_dim=dim) - 1
mask = mtf.greater_equal((f_m * f_lo), 0)
tau_lo = mtf.where(mask, tau_m, tau_lo)
p_m = p_m / mtf.reduce_sum(p_m, reduced_dim=dim)
return p_m
def entmax(x, alpha=1.3, dim=None, n_iter=50):
kwargs = dict(alpha=alpha, dim=dim, n_iter=n_iter)
return mtf.custom_gradient(
partial(entmax_forward, **kwargs),
partial(entmax_backward, **kwargs),
[x]
)
def entmax_cross_entropy_with_logits(logits, targets, vocab_dim, z_loss=0.0):
if targets.dtype.is_integer:
# hard targets
if (set(targets.shape.dims) != set(logits.shape.dims).difference([vocab_dim])):
raise ValueError(
"softmax_cross_entropy_with_logits with hard targets "
"dims in targets=%s should be dims in logits=%s other than "
"vocab_dim=%s" % (targets, logits, vocab_dim))
targets = mtf.one_hot(targets, vocab_dim, dtype=logits.dtype)
elif set(targets.shape.dims) != set(logits.shape.dims):
raise ValueError(
"softmax_cross_entropy_with_logits with soft targets "
"dims in targets=%s should be dims in logits=%s" % (targets, logits))
if vocab_dim not in logits.shape.dims:
raise ValueError("vocab_dim must be in logits.shape.dims")
log_entmax = mtf.log(entmax(logits, dim=vocab_dim))
loss = mtf.negative(
mtf.reduce_sum(log_entmax * targets, reduced_dim=vocab_dim))
return loss
def sample_categorical(x, dim=None):
dim = x.shape[-1] if dim is None else dim
cdf = mtf.cumsum(x, dim)
rand_uniform = mtf.random_uniform(x.mesh, x.shape - dim, minval=0, maxval=1)
mask = mtf.cast(mtf.greater(cdf, rand_uniform), tf.int32)
return mtf.argmax(mask, dim)
def biasmask_attn_weights(mesh, nd, ns, variable_dtype):
# The old mask_attn_weights applied directly to the QK;
# this returns a bias that the attention code from mtf adds to the attention matrix.
# w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst.
# n_src and n_dest are both the same, i.e equal to sequence length
# We rename ns because we want bias to have shape [batch, heads, memory_length, sequence] to match up with QK^T
# Information flows from k and v (memory_length) to q (sequence)
i = mtf.range(mesh, nd, tf.int32) + ns.size - nd.size
j = mtf.range(mesh, ns, tf.int32)
i, j = map(lambda t: mtf.broadcast(t, [nd, ns]), (i, j))
dtype = variable_dtype.activation_dtype
return mtf.cast(mtf.less(i, j), dtype) * -1e10
def parse_inputs(mtf_features, other_features):
# Parse inputs and labels from the mtf_features / other_features input dicts
# All dimensions are defined inside model_fn for efficiency
x = mtf_features["inputs"]
batch_dim = x.shape[0]
sequence_dim = x.shape[1]
embd_dim = other_features["embd_dim"]
vocab_dim = other_features["vocab_dim"]
embed_sequence_dim = other_features["embed_sequence_dim"]
return x, batch_dim, sequence_dim, embd_dim, vocab_dim, embed_sequence_dim