"""GPT-like model in Mesh-Tensorflow"""
import tensorflow.compat.v1 as tf
import mesh_tensorflow.transformer as mtf_transformer
from models.utils import parse_inputs, entmax_cross_entropy_with_logits
from models.layers import *
# --------------------------------------------------------------------------------
# TRANSFORMER BLOCK:
def block(params, scope, layer_num, bias, sequence_dim, memory_length_dim, pos_emb, variable_dtype, context=None):
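    """Builds a single transformer block as a closure.

    Returns a function `fn(x) -> (x, aux_loss)` so the whole block can be
    wrapped in `mtf.recompute_grad` for gradient checkpointing (see `model`).
    """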
    use_mlp_glu = params["mlp_glu"] == True
    use_scale_norm = params["scalenorm"] == True
    use_moe = exists(params["moe_layers"]) and (layer_num in params["moe_layers"])
    use_rezero = params["rezero"] == True
    macaron_attention = params["macaron"] == True

    def fn(x):
        with tf.variable_scope(scope):
            nx = x.shape[-1]  # Grab last dimension from input

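            # Pick the pre-normalization: ReZero skips normalization entirely and
            # instead scales each residual branch by a learned scalar initialized
            # to zero; ScaleNorm normalizes activations to a single learned scale;
            # the default is standard LayerNorm.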
            if use_rezero:
                prenorm = identity
            elif use_scale_norm:
                prenorm = scale_norm
            else:
                prenorm = layer_norm

            pre_residual_fn = rezero if use_rezero else identity

            attention_type = params["attention_types"][layer_num]

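            # Macaron-style blocks (Lu et al. 2019) sandwich attention between two
            # half-step feedforward sublayers, each scaled by 0.5, instead of a
            # single full-step FFN after attention.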
            if macaron_attention:
                mult = 0.5
                mlp_fn = mlp_glu if use_mlp_glu else mlp
                intermediate_size = nx.size * 4 * (2 if use_mlp_glu else 1)
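                # (GLU variants double the intermediate width so the projection can
                # be split into a value half and a gate half.)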
                # Define intermediate layer of mlp - to split
                dim_intermediate_expanded = mtf.Dimension("intermediate_expanded", intermediate_size)
                m = mlp_fn(x, "mlp_macaron", dim_intermediate_expanded, variable_dtype=variable_dtype, params=params)
                x = x + (m * mult)
            else:
                mult = 1

if attention_type != "none":
res_x = prenorm(x, "norm_1", variable_dtype=variable_dtype, params=params)
a = attn(res_x, "attn", nx, attention_type=attention_type,
params=params, bias=bias, dim_seq=sequence_dim, memory_length_dim=memory_length_dim,
variable_dtype=variable_dtype, context=context, pos_emb=pos_emb)
else:
a = x
x = x + pre_residual_fn(a, "norm_rezero_1", dtype=variable_dtype)
res_x = prenorm(x, "norm_2", variable_dtype=variable_dtype, params=params)
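            # Feedforward sublayer: either a mixture-of-experts layer (which also
            # returns a load-balancing auxiliary loss) or a dense, optionally
            # GLU-gated, MLP.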
            if use_moe:
                moe_params = mtf.transformer.moe.HParams()
                mtf.transformer.moe.set_default_moe_hparams(moe_params)
                moe_params.add_hparam("moe_min_expert_capacity", 1)
                moe_params.add_hparam("moe_use_experts_attention", False)

                # Override defaults
                for k, v in params["moe_params"].items():
                    moe_params.add_hparam(k, v)

                moe_train = params["mode"] == "train"

                m, aux_loss = mtf.transformer.moe.transformer_moe_layer_v1(
                    res_x, x.shape[-1], moe_params,
                    train=moe_train,
                    mesh_shape=params["mesh_shape"],
                    layout=params["layout"],
                    activation=params.get("moe_activation", "relu"),
                    variable_dtype=variable_dtype,
                    num_microbatches=params["num_microbatches"])
                m = mtf.dropout(m, rate=params["res_dropout"], name="moe_dropout")
            else:
                mlp_fn = mlp_glu if use_mlp_glu else mlp
                intermediate_size = nx.size * 4 * (2 if use_mlp_glu else 1)
                # Define intermediate layer of mlp - to split
                dim_intermediate_expanded = mtf.Dimension("intermediate_expanded", intermediate_size)
                m = mlp_fn(res_x, "mlp", dim_intermediate_expanded, variable_dtype=variable_dtype, params=params)
                aux_loss = mtf.zeros(x.mesh, mtf.Shape([]), dtype=variable_dtype.slice_dtype)

            x = x + pre_residual_fn(m * mult, "norm_rezero_2", dtype=variable_dtype)
            return x, aux_loss

    return fn

# --------------------------------------------------------------------------------
# GPT2 MODEL:

def model(mtf_features, other_features, params, mesh, variable_dtype, context=None):
"""A GPT style model implemented in mesh tensorflow."""
x, batch_dim, sequence_dim, embd_dim, vocab_dim, embed_sequence_dim = parse_inputs(mtf_features, other_features)
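    # When sampling incrementally, only the newest token (at context.position - 1)
    # is fed through the stack each step, so gather it and collapse the sequence
    # dimension.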
    if is_incremental_inference(context):
        # reshape inputs if in inference mode
        x = mtf.gather(x, context.position - 1, sequence_dim)
        x = mtf.reshape(x, [batch_dim])

    use_axial_pos_emb = exists(params["axial_pos_emb"])
    use_rotary_emb = exists(params["rotary_emb"])

    # Text encoding
    wte = mtf.get_variable(mesh, "wte", mtf.Shape([vocab_dim, embd_dim]),
                           initializer=tf.random_normal_initializer(stddev=0.02),
                           master_dtype=variable_dtype.master_dtype,
                           slice_dtype=variable_dtype.slice_dtype,
                           activation_dtype=variable_dtype.activation_dtype)

    with tf.variable_scope("token_embd"):
        # Text embedding
        h = mtf.gather(wte, x, vocab_dim)
        if params["embed_dropout"] > 0 and params["mode"] == "train":
            h = mtf.dropout(h, rate=params["embed_dropout"], name="wte_dropout")

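    # Positional information comes in one of three flavors: rotary embeddings are
    # applied inside each attention layer (so nothing is added to the input here),
    # axial embeddings factorize the position table into smaller per-axis tables,
    # and the default is a learned absolute-position table added to the token
    # embeddings.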
    # Position encoding
    if use_rotary_emb:
        wpe = None
        layer_pos_emb = rotary_positional_emb(mesh, sequence_dim, params, variable_dtype)
    elif use_axial_pos_emb:
        wpe = axial_positional_emb(embd_dim, mesh, params, variable_dtype)
        layer_pos_emb = None
    else:
        # Use standard position encoding
        wpe = mtf.get_variable(mesh, "wpe", mtf.Shape([embed_sequence_dim, embd_dim]),
                               initializer=tf.random_normal_initializer(stddev=0.01),
                               master_dtype=variable_dtype.master_dtype,
                               slice_dtype=variable_dtype.slice_dtype,
                               activation_dtype=variable_dtype.activation_dtype)
        layer_pos_emb = None

    if exists(wpe):
        with tf.variable_scope("pos_embd"):
            # Positional embedding
            position_indices = mtf.range(mesh, sequence_dim, tf.int64) if not is_incremental_inference(context) else (
                context.position - 1)
            pos_emb = mtf.gather(wpe, position_indices, wpe.shape[0])
            if params["embed_dropout"] > 0 and params["mode"] == "train":
                pos_emb = mtf.dropout(pos_emb, rate=params["embed_dropout"], name="wpe_dropout")
            h += pos_emb

    aux_losses = 0  # instantiate auxiliary losses (for MoE models)

    for layer in range(params["n_layer"]):
        # attn blocks
        share_parameters = exists(params["share_parameters"]) and params["share_parameters"] == True
        block_scope = f"h{layer}" if not share_parameters else ""

        block_fn = block(params=params, scope=block_scope, layer_num=layer,
                         bias=other_features["attn_bias"],
                         sequence_dim=sequence_dim,
                         memory_length_dim=other_features["memory_length_dim"],
                         pos_emb=layer_pos_emb,
                         variable_dtype=variable_dtype,
                         context=context)

        # Enable gradient checkpointing when configured and in train mode
        recompute_grad = params["recompute_grad"] and params["mode"] == "train"
        h, loss = block_fn(h) if not recompute_grad else mtf.recompute_grad(block_fn, [h])
        aux_losses += loss

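    # Output head: either an untied linear projection to the vocabulary, or the
    # standard GPT-2 weight tying, which reuses the token embedding matrix `wte`
    # as the output projection via an einsum.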
    no_weight_tie_emb = params["no_weight_tie"] == True
    if no_weight_tie_emb:
        with tf.variable_scope("wte_final_linear"):
            logits = linear(h, "linear_out", vocab_dim, variable_dtype=variable_dtype, params=params)
    else:
        # Layer normalize & affine transform
        h = layer_norm(h, "ln_f", variable_dtype=variable_dtype)
        seq_dim = sequence_dim if not is_incremental_inference(context) else mtf.Dimension("sequence", 1)

        with tf.variable_scope("wte_final_einsum"):
            # Equivalent to tf.matmul
            logits = mtf.einsum([h, wte], output_shape=[batch_dim, seq_dim, vocab_dim])

if params["mode"] in ["train", "eval"]:
labels = mtf_features["labels"]
z_loss = params.get("z_loss", 1e-4) # an auxiliary loss used to stabilize mtf xentropy
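        # (The z-loss penalizes large softmax log-normalizers, which keeps logit
        # magnitudes from drifting, particularly under low-precision training.)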

        # Go to full precision for the logits
        logits = mtf.cast(logits, tf.float32)

        use_entmax_loss = params.get("entmax_loss", False)
        loss_fn = mtf.layers.softmax_cross_entropy_with_logits if not use_entmax_loss else entmax_cross_entropy_with_logits

        with tf.variable_scope("xentropy_final"):
            loss_batch = loss_fn(logits=logits, targets=labels,
                                 vocab_dim=logits.shape[-1], z_loss=z_loss)

        # For non-autoregressive models (masked language modeling training)
        # Make sure labels with padding tokens are not counted in the loss
        if not params["causal"]:
            padding_id = params.get("padding_id", 0)
            loss_batch = mtf.where(mtf.not_equal(labels, padding_id), loss_batch, mtf.zeros_like(loss_batch))

        with tf.variable_scope("reduce_mean_final"):
            loss = mtf.reduce_mean(loss_batch)

        loss += aux_losses  # Add on auxiliary losses (currently only used for MoE)
        loss /= params["num_microbatches"]

        # Convert to train dtype
        loss = mtf.cast(loss, variable_dtype.slice_dtype)
    else:
        loss = None
        loss_batch = None

    # Cast back to checkpoint dtype
    logits = mtf.cast(logits, variable_dtype.master_dtype)
    return logits, loss, loss_batch
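
# --------------------------------------------------------------------------------
# USAGE SKETCH (illustrative only):
#
# A minimal sketch of how `model` is typically driven, assuming the surrounding
# infrastructure (mesh construction, `mtf_features` holding "inputs"/"labels",
# and an `other_features` dict carrying "attn_bias" and "memory_length_dim") is
# built elsewhere, e.g. in a model_fn. The `params` keys mirror the ones read in
# this file; the values shown are hypothetical.
#
#   graph = mtf.Graph()
#   mesh = mtf.Mesh(graph, "my_mesh")
#   params = {
#       "n_layer": 12, "mode": "train", "causal": True,
#       "attention_types": ["global"] * 12,
#       "mlp_glu": False, "scalenorm": False, "rezero": False, "macaron": False,
#       "moe_layers": None, "share_parameters": None,
#       "axial_pos_emb": None, "rotary_emb": None,
#       "embed_dropout": 0.1, "res_dropout": 0.1,
#       "no_weight_tie": False, "recompute_grad": True,
#       "num_microbatches": 1,
#   }
#   logits, loss, loss_batch = model(mtf_features, other_features, params,
#                                    mesh, variable_dtype)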