# gpt-neo / model_fns.py
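# Builds the Mesh-TensorFlow computation graph for a GPT-style model and returns
# a TPUEstimatorSpec for TRAIN, EVAL or PREDICT mode.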
import mesh_tensorflow as mtf
import tensorflow.compat.v1 as tf
from tensorflow.python.tpu import tpu_estimator
import mesh_tensorflow.transformer as mtf_transformer
from optimizers import get_optimizer
from utils import (create_host_call, get_graph_info, remove_batch_from_layout, simd_mesh_setup, add_mode_to_params,
get_batch_size, auto_layout, auto_layout_and_mesh_shape)
from models.utils import biasmask_attn_weights
from tensorflow.python.ops import resources
from sample import sample_autoregressive
from models.gpt2 import gpt2
import math


def model_fn(features, labels, mode, params):
    # Get global step
    global_step = tf.train.get_global_step()

    # Construct mtf graph + mesh from params
    graph = mtf.Graph()
    mesh_shape = mtf.convert_to_shape(params["mesh_shape"])
    layout_rules = mtf.convert_to_layout_rules(params["layout"])
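    # Both are plain mtf strings. As an illustration (example values, not from any
    # particular config), a mesh_shape of "x:4,y:2" with a layout of "batch:x,embd:y"
    # would split the batch across 4 cores and the embedding dimension across 2.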
    # Mesh setup
    if params["use_tpu"]:
        var_placer, mesh_impl = simd_mesh_setup(params, mesh_shape, layout_rules)
    else:
        var_placer = None
        gpu_ids = params["gpu_ids"]
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
            mesh_shape, layout_rules, gpu_ids)
    # Trainable variable precision
    # Store to checkpoints in master type, train in slice type, compute in activation type
    if params["precision"] == "bfloat16":
        variable_dtype = mtf.VariableDType(master_dtype=tf.bfloat16, slice_dtype=tf.float32,
                                           activation_dtype=tf.bfloat16)
    else:
        variable_dtype = mtf.VariableDType(master_dtype=tf.float32, slice_dtype=tf.float32,
                                           activation_dtype=tf.float32)
    # Build mtf mesh object
    mesh = mtf.Mesh(graph, "my_mesh", var_placer)

    # Build mtf_features & seq length dict for getting number of microbatches
    # We need to pack inputs into a dict to pass into serialize_training_step
    features_dict = {"inputs": features, "labels": labels}
    sequence_length_dict = {"inputs": params["n_ctx"], "labels": params["n_ctx"]}

    params = add_mode_to_params(params, mode)
    batch_size = get_batch_size(params)

    batch_dim = mtf.Dimension("batch", batch_size)
    batch_dims = [batch_dim]
    feature_length = sequence_length_dict["inputs"]
    length_dim = mtf.Dimension("sequence", feature_length)
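    # Cast each feature to int32, reshape it to [batch, sequence], and import it
    # onto the mesh fully replicated (every core sees the full tensor).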
    mtf_features = {}
    for key, x in features_dict.items():
        if x is not None:
            feature_shape = mtf.Shape(batch_dims + [length_dim])
            if type(features_dict[key]) == dict:
                features_dict[key] = features_dict[key]["feature"]
            x = tf.cast(features_dict[key], tf.int32)
            x = tf.reshape(x, feature_shape.to_integer_list)
            mtf_features[key] = mtf.import_fully_replicated(
                mesh, x, feature_shape, name=key)
    # Instantiate dict for dimensions, bias, etc that can be calculated here once then passed into model
    other_features = {}
    memory_length_dim = mtf.Dimension("memory_length", length_dim.size)

    attn_bias = biasmask_attn_weights(mesh, length_dim, memory_length_dim, variable_dtype) if params["causal"] else None

    # Add attn_bias into mtf_features
    other_features["attn_bias"] = attn_bias

    # Define other Dimensions that we'll need inside the model
    embd_dim = mtf.Dimension("embd", params["n_embd"])
    vocab_dim = mtf.Dimension("vocab", params["n_vocab"])
    # We need this because gathering when both the args have the same dimension in them breaks things
    # This dim is specifically for the weights
    # This prevents the "Einsum has lhs dimension without corresponding rhs or output dimension." error
    embed_sequence_dim = mtf.Dimension("embed_sequence", params["n_ctx"])

    other_features["embd_dim"] = embd_dim
    other_features["vocab_dim"] = vocab_dim
    other_features["embed_sequence_dim"] = embed_sequence_dim
    other_features["memory_length_dim"] = memory_length_dim
    if mode == tf.estimator.ModeKeys.PREDICT:
        # Set up the model for prediction
        inputs = mtf_features["inputs"]
        if params["remove_partial_sequences"] is None:
            params["remove_partial_sequences"] = False

        export = params.get("export", False)

        if not export:
            mtf_samples = sample_autoregressive(
                inputs, other_features=other_features, params=params, variable_dtype=variable_dtype,
                remove_partial_sequences=params["remove_partial_sequences"], stop_at_token=params["eos_id"],
                sampling_use_entmax=params['sampling_use_entmax'], max_steps=params["predict_max_steps"])
        else:
            with mtf.utils.outside_all_rewrites():
                with tf.variable_scope('gpt2'):
                    mtf_samples, loss, loss_batch = gpt2.model(mtf_features, other_features, params, mesh,
                                                               variable_dtype=variable_dtype, context=None)

        mtf_samples = mtf.anonymize(mtf_samples)
        inputs = mtf.anonymize(inputs)
        lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=True)
        inputs = lowering.export_to_tf_tensor(inputs)
        outputs = lowering.export_to_tf_tensor(mtf_samples)
        predictions = {
            "inputs": inputs,
            "outputs": outputs}
        def scaffold_fn():
            return tf.train.Scaffold(
                local_init_op=tf.group(
                    tf.train.Scaffold.default_local_init_op(),
                    lowering.copy_masters_to_slices(),
                    name="mtf_local_init_op"),
                ready_op=tf.concat(
                    [tf.report_uninitialized_variables(),
                     resources.report_uninitialized_resources()],
                    axis=0,
                    name="mtf_ready_op"))

        return tpu_estimator.TPUEstimatorSpec(
            mode=tf.estimator.ModeKeys.PREDICT,
            predictions=predictions,
            scaffold_fn=scaffold_fn,
            prediction_hooks=[mtf.MtfRestoreHook(lowering)])
    # We're not predicting, so we better be training or evaluating
    assert (mode == tf.estimator.ModeKeys.TRAIN or mode == tf.estimator.ModeKeys.EVAL)

    if mode == tf.estimator.ModeKeys.TRAIN:
        # Get the number of microbatches per batch for serialized training
        # If the tokens_per_mb_per_replica param is None, this defaults to 1 and no microbatching is performed
        num_microbatches = int(mtf_transformer.utils.serialize_num_microbatches(
            batch_dim=batch_dim,
            sequence_length=sequence_length_dict,
            mesh_shape=mesh_shape,
            layout_rules=layout_rules,
            tokens_per_microbatch_per_replica=params["tokens_per_mb_per_replica"]))
    else:
        num_microbatches = 1

    params["num_microbatches"] = num_microbatches  # Add num microbatches to params
    if num_microbatches > 1:
        # For serialize_training_step we need to modify the model to output results in a dict
        def serialized_fn(mtf_features):
            if params["model"] == "GPT":
                with tf.variable_scope('gpt2'):
                    logits, loss, loss_batch = gpt2.model(mtf_features, other_features, params, mesh,
                                                          variable_dtype=variable_dtype)
                return {"logits": logits, "loss": loss, "loss_batch": loss_batch}
            else:
                raise Exception(f"'{params['model']}' is not a valid model - please select from [GPT]")

        # Serialize the training step - Gradients are accumulated locally and reduced once.
        var_grads, output_dict = mtf.serialize_training_step(mtf_features, serialized_fn, batch_dim, num_microbatches)
        loss = output_dict["loss"]
        loss_batch = output_dict["loss_batch"]
        logits = output_dict["logits"]
    else:
        # If we're not splitting into microbatches, return logits & loss as is
        if params["model"] == "GPT":
            with mtf.utils.outside_all_rewrites():
                with tf.variable_scope('gpt2'):
                    logits, loss, loss_batch = gpt2.model(mtf_features, other_features, params, mesh,
                                                          variable_dtype=variable_dtype, context=None)
        else:
            raise Exception(f"'{params['model']}' is not a valid model - please select from [GPT]")
    # Auto layout generation
    if params["auto_layout"]:
        auto_layout(graph, mesh_shape, logits, loss)
    if params["auto_layout_and_mesh_shape"]:
        auto_layout_and_mesh_shape(graph, params["num_cores"], logits, loss)
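    # TRAIN: build the optimizer update ops (reusing the microbatch gradients if any).
    # EVAL: pre-compute fully-replicated tensors (mean/argmax logits, per-example
    # loss) before lowering so they can be exported as tf tensors.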
    if mode == tf.estimator.ModeKeys.TRAIN:
        # In TRAIN mode, get optimizer
        if params["num_microbatches"] > 1:
            # If we are splitting the batch into microbatches, var grads are created in the serialize_training_step fn
            # So we pass them in here
            _, update_ops, var_grads = get_optimizer(mesh, loss, params, variable_dtype=variable_dtype,
                                                     inp_var_grads=var_grads)
        else:
            # Otherwise, they are created in the get_optimizer fn, so we leave inp_var_grads blank
            _, update_ops, var_grads = get_optimizer(mesh, loss, params, variable_dtype=variable_dtype)

        # Log summaries to tensorboard
        mtf.scalar_summary("loss", loss)

        # Log gradients if in params
        if params["log_grads"] not in [None, False]:
            for g in var_grads:
                grad_norm = mtf.sqrt(mtf.reduce_sum(mtf.square(g)))
                mtf.scalar_summary("grads/norm" + g.name[:-2], grad_norm)
    else:
        # For now, we can only export fully-replicated tensors.
        # This has to be done before lowering or they will not be included in the graph
        mean_logits = mtf.reduce_mean(logits, reduced_dim=vocab_dim)
        max_logits = mtf.argmax(logits, vocab_dim)
        del logits
        fully_replicated_mean_logits = mtf.anonymize(mean_logits)
        fully_replicated_max_logits = mtf.anonymize(max_logits)
        fully_replicated_loss_batch = mtf.anonymize(loss_batch)
# Gets & prints info about no. trainable vars in the model & dimension names
get_graph_info(graph)
# 'lowers' mtf tensors into a tf graph - this enables us to export results as tf tensors
lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=True)
tf_loss = lowering.export_to_tf_tensor(loss)
tf_loss = tf.cast(tf_loss, tf.float32)
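    # In TRAIN mode, scalar summaries are written from the host via host_call
    # (TPUs can't run summary ops directly); remove_summaries() then drops the mtf
    # summary ops so they aren't also lowered into the TPU graph.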
    if mode == tf.estimator.ModeKeys.TRAIN:
        # Use our patched version until mtf updates theirs
        host_call = create_host_call(params['model_path'])
        mtf.utils.remove_summaries()

        # Creates train_op
        tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
        tf_update_ops.append(tf.assign_add(global_step, 1))  # Need to manually increment global_step
        tf.logging.info(f"tf_update_ops: {tf_update_ops}")
        train_op = tf.group(tf_update_ops)
    else:
        tf_mean_logits = lowering.export_to_tf_tensor(fully_replicated_mean_logits)
        tf_max_logits = lowering.export_to_tf_tensor(fully_replicated_max_logits)
        tf_loss_batch = tf.to_float(lowering.export_to_tf_tensor(fully_replicated_loss_batch))
    with mtf.utils.outside_all_rewrites():
        # Copy master variables to slices. Must be called first.
        restore_hook = mtf.MtfRestoreHook(lowering)
        if mode == tf.estimator.ModeKeys.TRAIN:
            # Set up the checkpoint saver and return the TPUEstimatorSpec
            saver = tf.train.Saver(
                tf.global_variables(),
                sharded=True,
                max_to_keep=10,
                keep_checkpoint_every_n_hours=2,
                defer_build=False,
                save_relative_paths=True)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            saver_listener = mtf.MtfCheckpointSaverListener(lowering)
            saver_hook = tf.train.CheckpointSaverHook(
                params["model_path"],
                save_steps=params["steps_per_checkpoint"],
                saver=saver,
                listeners=[saver_listener])

            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.TRAIN,
                loss=tf_loss,
                host_call=host_call,
                train_op=train_op,
                training_hooks=[restore_hook, saver_hook])
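        # EVAL: report perplexity and bits-per-byte over the eval set (the 0.29335
        # factor below is presumably a fixed tokens-per-byte ratio for the eval
        # data), or LAMBADA accuracy / log-perplexity when eval_task == "lambada".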
        elif mode == tf.estimator.ModeKeys.EVAL:
            # Evaluation metrics
            def _perplexity(loss):
                perplexity = tf.exp(loss)
                return tf.metrics.mean(perplexity)

            def _bits_per_byte(loss):
                bpb = loss * (0.29335 / math.log(2))
                return tf.metrics.mean(bpb)

            def _metric_fn(tf_mean_logits, tf_loss_batch):
                mean_logits = tf.metrics.mean(tf_mean_logits)
                loss = tf.reduce_mean(tf_loss_batch)
                perp = _perplexity(loss)
                bpb = _bits_per_byte(loss)
                return {"mean_logits": mean_logits, "perplexity": perp, "bits per byte": bpb}

            def _lambada_metric_fn(labels, tf_max_logits, tf_loss_batch):
                eos_token = params["eos_id"]
                answer_positions = tf.where(tf.math.not_equal(labels, eos_token))
                correct_answers = tf.gather_nd(tf.math.equal(tf_max_logits, labels), answer_positions)
                accuracy = tf.metrics.mean(tf.cast(correct_answers, tf.float32))
                # tf_loss_batch may have z_loss and other terms added to it,
                # so this may need to be calculated separately in the future
                answer_loss = tf.gather_nd(tf_loss_batch, answer_positions)
                log_perplexity = tf.metrics.mean(answer_loss)
                return {"lambada_acc": accuracy, "lambada_log_ppl": log_perplexity}

            eval_task = params["eval_task"]
            if eval_task == "lambada":
                eval_metrics = (_lambada_metric_fn, [labels, tf_max_logits, tf_loss_batch])
            else:
                eval_metrics = (_metric_fn, [tf_mean_logits, tf_loss_batch])

            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.EVAL,
                evaluation_hooks=[restore_hook],
                loss=tf_loss,
                eval_metrics=eval_metrics)
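

# A minimal usage sketch (hypothetical values; the real entry point lives in the
# repo's run scripts): model_fn is handed to a TPUEstimator, which calls it once
# per mode with the params dict assembled from the JSON config, e.g.
#
#   run_config = tf.estimator.tpu.RunConfig(model_dir=params["model_path"])
#   estimator = tf.estimator.tpu.TPUEstimator(
#       use_tpu=params["use_tpu"],
#       model_fn=model_fn,
#       config=run_config,
#       train_batch_size=params["train_batch_size"],
#       eval_batch_size=params["eval_batch_size"],
#       params=params)
#   estimator.train(input_fn=input_fn, max_steps=params["train_steps"])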