# gpt-neo / test_models.py

import pytest
import traceback
import logging
from collections import defaultdict
from contextlib import contextmanager

import tensorflow as tf
tf.compat.v1.enable_eager_execution()

import mesh_tensorflow as mtf
from mesh_tensorflow import placement_mesh_impl

from inputs import mlm_sample_text
from models.gpt2 import gpt2
from models.utils import biasmask_attn_weights, entmax, sample_categorical
from sample import sample_autoregressive

# helper functions
@contextmanager
def not_raises(exception):
    """Context manager asserting that the wrapped block does not raise `exception`."""
    try:
        yield
    except exception:
        logging.error(traceback.format_exc())
        pytest.fail("DID RAISE {0}".format(exception))
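
# Illustrative usage of the helper above (a sketch only; `some_callable` is a
# hypothetical stand-in, not part of this suite). Any unexpected exception has
# its traceback logged and the surrounding test is marked as failed:
#
#   with not_raises(Exception):
#       some_callable()
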
# fixtures
params = defaultdict(lambda: None, {
    "n_head": 1,
    "n_ctx": 4,
    "n_embd": 2,
    "n_vocab": 256,
    "embed_dropout": 0.,
    "n_layer": 2,
    "num_microbatches": 1,
    "train_batch_size": 1,
    "causal": True,
    "attention_types": ['global', 'local'],
    "res_dropout": 0.1,
    "rotary_emb": True,
    "activation_function": "gelu",
    "moe_layers": (1,),
    "num_mem_kv": 16,
    "no_weight_tie": True,
    "moe_params": {
        'moe_dropout_rate': 0.0
    },
    "mesh_shape": [],
    "layout": {},
    "local_attention_radius": 128,
    "share_parameters": True,
    "rezero": True
})
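
# Note: because `params` is a defaultdict(lambda: None), keys not listed above
# (e.g. "remove_partial_sequences" and "eos_id", read in test_sampling below)
# resolve to None instead of raising a KeyError.
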
# tests
def test_model():
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")

    seq_len = params["n_ctx"]
    batch_dim = mtf.Dimension("batch", 1)
    sequence_dim = mtf.Dimension("sequence", seq_len)

    features = {
        'inputs': mtf.ones(mesh, mtf.Shape((batch_dim, sequence_dim)), tf.int32),
        'labels': mtf.ones(mesh, mtf.Shape((batch_dim, sequence_dim)), tf.int32)
    }

    # create mask
    num_mem_kv = params.get('num_mem_kv', 0)
    length_dim = mtf.Dimension('sequence', seq_len)
    memory_length_dim = mtf.Dimension('memory_length', seq_len + num_mem_kv)
    embed_sequence_dim = mtf.Dimension('embed_sequence', seq_len)
    embd_dim = mtf.Dimension("embd", params["n_embd"])
    vocab_dim = mtf.Dimension("vocab", params["n_vocab"])

    other_features = {}
    variable_dtype = mtf.VariableDType(tf.float32, tf.float32, tf.float32)

    other_features["attn_bias"] = biasmask_attn_weights(mesh, length_dim, memory_length_dim, variable_dtype)
    other_features["embd_dim"] = embd_dim
    other_features["vocab_dim"] = vocab_dim
    other_features["embed_sequence_dim"] = embed_sequence_dim
    other_features["memory_length_dim"] = memory_length_dim

    with not_raises(Exception):
        logits, _, _ = gpt2.model(features, other_features, params, mesh, variable_dtype=variable_dtype)

        mesh_impl = placement_mesh_impl.PlacementMeshImpl(shape=[], layout={}, devices=[""])
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
        logits = lowering.export_to_tf_tensor(logits)

def test_sampling():
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")

    batch_dim = mtf.Dimension("batch", 1)
    sequence_dim = mtf.Dimension("sequence", 1)

    inputs = mtf.ones(mesh, mtf.Shape((batch_dim, sequence_dim)), tf.int32)
    inputs = mtf.pad(inputs, [0, 3], sequence_dim.name)

    # create mask
    seq_len = params["n_ctx"]
    num_mem_kv = params.get('num_mem_kv', 0)
    length_dim = mtf.Dimension('sequence', seq_len)
    memory_length_dim = mtf.Dimension('memory_length', seq_len + num_mem_kv)
    embed_sequence_dim = mtf.Dimension('embed_sequence', seq_len)
    embd_dim = mtf.Dimension("embd", params["n_embd"])
    vocab_dim = mtf.Dimension("vocab", params["n_vocab"])

    other_features = {}

    other_features["attn_bias"] = biasmask_attn_weights(mesh, length_dim, memory_length_dim, mtf.VariableDType(tf.float32))
    other_features["embd_dim"] = embd_dim
    other_features["vocab_dim"] = vocab_dim
    other_features["embed_sequence_dim"] = embed_sequence_dim
    other_features["memory_length_dim"] = memory_length_dim

    params["mode"] = "predict"

    with not_raises(Exception):
        samples = sample_autoregressive(
            inputs, other_features=other_features, params=params, variable_dtype=mtf.VariableDType(),
            remove_partial_sequences=params["remove_partial_sequences"], stop_at_token=params["eos_id"],
            sampling_use_entmax=True)

        mesh_impl = placement_mesh_impl.PlacementMeshImpl(shape=[], layout={}, devices=[""])
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
        samples = lowering.export_to_tf_tensor(samples)

# mlm
mlm_params = defaultdict(lambda: None, {
    "n_head": 1,
    "n_ctx": 4,
    "n_embd": 1,
    "n_vocab": 256,
    "embed_dropout": 0.,
    "n_layer": 2,
    "num_microbatches": 1,
    "train_batch_size": 1,
    "attention_types": ['global', 'local'],
    "res_dropout": 0.1,
    "mesh_shape": [],
    "layout": {},
    "share_parameters": True,
    "mlm_training": True,
    "mlm_mask_id": 3,
    "mlm_cls_token_id": 4,
    "mlm_random_token_prob": 0.1
})

def test_mlm_sample_text():
    document = tf.random.normal((16,))
    with not_raises(Exception):
        features, labels = mlm_sample_text(mlm_params, document, random_documents=True)
        assert features.shape == (mlm_params['n_ctx'],)

# entmax
def test_entmax():
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")

    length = mtf.Dimension("tensor_length", 8)
    tensor = mtf.range(mesh, length, tf.float32)
    output = entmax(tensor)
    grad = mtf.gradients([output], [tensor])[0]
    sample = sample_categorical(output, length)

    mesh_impl = placement_mesh_impl.PlacementMeshImpl(shape=[], layout={}, devices=[""])
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    sample = lowering.export_to_tf_tensor(sample)
    grad = lowering.export_to_tf_tensor(grad)
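

# Convenience entry point (a sketch; the suite is normally run with
# `pytest test_models.py` from the repository root):
if __name__ == "__main__":
    pytest.main([__file__, "-v"])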