|
import re |
|
from urllib.parse import urlparse |
|
from shutil import rmtree |
|
import logging |
|
import os |
|
from pathlib import Path |
|
import sys |
|
import tensorflow.compat.v1 as tf |
|
import tensorflow.compat.v2 as tf2 |
|
import mesh_tensorflow as mtf |
|
from data.encoders import fetch_encoder |
|
|
|
|
def setup_logging(args): |
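    """Routes TensorFlow logging to stdout and to logs/<model-name>.log, where the name is taken from the model config path."""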
|
Path("logs").mkdir(exist_ok=True) |
|
tf.logging.set_verbosity(logging.INFO) |
|
tf.get_logger().propagate = False |
|
name = os.path.splitext(os.path.basename(args.model))[0] |
|
handlers = [ |
|
logging.FileHandler(f"logs/{name}.log"), |
|
logging.StreamHandler(sys.stdout) |
|
] |
|
logger = logging.getLogger("tensorflow") |
|
logger.handlers = handlers |
|
return logger |
|
|
|
|
|
def get_batch_size(params): |
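    """Looks up the batch size for the current mode, e.g. params["train_batch_size"] when params["mode"] == "train"."""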
|
return params[f"{params['mode']}_batch_size"] |
|
|
|
|
|
def add_mode_to_params(params, mode): |
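    """Translates a tf.estimator ModeKeys value into a "train" / "eval" / "predict" string stored in params["mode"]."""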
|
if mode == tf.estimator.ModeKeys.PREDICT: |
|
params["mode"] = "predict" |
|
elif mode == tf.estimator.ModeKeys.EVAL: |
|
params["mode"] = "eval" |
|
elif mode == tf.estimator.ModeKeys.TRAIN: |
|
params["mode"] = "train" |
|
else: |
|
raise ValueError(f"Invalid mode {mode}") |
|
return params |
|
|
|
|
|
def simd_mesh_setup(params, mesh_shape, layout_rules):
    """Constructs the SIMD mesh implementation and variable placer: instructions on how to evenly split tensors across all TPU cores."""
|
|
|
num_hosts = params["context"].num_hosts |
|
host_placement_fn = params["context"].tpu_host_placement_function |
|
device_list = [host_placement_fn(host_id=i) for i in range(num_hosts)] |
|
tf.logging.info(f"device_list = {device_list}") |
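    # Budget ~300 MB of cache per replica against worker 0 so that
    # BalancedVariablePlacer spreads variables across the remaining hosts.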
|
|
|
|
|
replica_cache_size = 300 * 1000000 |
|
|
|
|
|
worker0_mem = replica_cache_size * params["context"].num_replicas |
|
devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1) |
|
var_placer = mtf.utils.BalancedVariablePlacer(device_list, devices_memory_usage) |
|
mesh_devices = [""] * mesh_shape.size |
|
mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl( |
|
mesh_shape, layout_rules, mesh_devices, params["context"].device_assignment) |
|
|
|
return var_placer, mesh_impl |
|
|
|
|
|
def remove_batch_from_layout(layout): |
|
""" |
|
The tf-mesh layout splits across batch size, remove it. |
|
Useful for prediction steps, when you no longer want large batches. |
|
|
|
:param layout: string describing tf-mesh layout |
|
:return: layout minus batch dimension |
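    Example: "batch:x,heads:y" -> "heads:y"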
|
""" |
|
    dims = layout.split(',')
    return ",".join(d for d in dims if "batch" not in d)
|
|
|
|
|
def yes_or_no(question): |
|
while True: |
|
        reply = input(question + ' (y/n): ').lower().strip()
|
if reply[:1] == 'y': |
|
return True |
|
if reply[:1] == 'n': |
|
return False |
|
|
|
|
|
def remove_gs_or_filepath(path): |
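    """Deletes a model directory: gs:// paths are removed with gsutil, local paths with shutil.rmtree."""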
|
parsed_url = urlparse(path) |
|
if parsed_url.scheme == "gs": |
|
os.system(f"gsutil rm -rf {path}") |
|
return |
|
rmtree(path) |
|
|
|
|
|
def save_config(params_dict, logdir): |
|
print(f"Saving config to {logdir}") |
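    # Build a JSON-style rendering of the params dict and write it to TensorBoard as a text summary.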
|
text = "{\n\n" |
|
total_params = len(params_dict) |
|
    for count, key in enumerate(params_dict):
        config_value = str(params_dict[key])
        # Quote values that look like strings; leave booleans, lists and numeric
        # values (e.g. "epsilon", typically in scientific notation) unquoted.
        if re.search('[a-zA-Z]', config_value) \
                and config_value.lower() not in ('true', 'false') \
                and config_value[0] != '[' \
                and key != "epsilon":
            config_value = f'"{config_value}"'
        separator = '\n\n' if count == total_params - 1 else ',\n\n'
        text += f'"{key}" : {config_value}{separator}'
|
text += '\n\n}' |
|
sess = tf.InteractiveSession() |
|
summary_op = tf.summary.text("run_config", tf.convert_to_tensor(text)) |
|
summary_writer = tf.summary.FileWriter(f"{logdir}/config", sess.graph) |
|
text = sess.run(summary_op) |
|
summary_writer.add_summary(text, 0) |
|
summary_writer.flush() |
|
summary_writer.close() |
|
tf.reset_default_graph() |
|
print('Done!') |
|
|
|
|
|
def expand_attention_types_params(params_list): |
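    """Expands (attention types, repeat count) pairs into a flat list,
    e.g. [[["global", "local"], 2]] -> ["global", "local", "global", "local"]."""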
|
newlist = [] |
|
for item in params_list: |
|
for _ in range(item[1]): |
|
newlist.extend(item[0]) |
|
return newlist |
|
|
|
|
|
def get_n_trainable_vars(graph): |
|
""" |
|
    Gets the number of trainable parameters in an MTF model and prints it.
|
|
|
:param graph: Mesh-Tensorflow graph |
|
:return: None |
|
""" |
|
total_parameters = 0 |
|
for variable in graph.trainable_variables: |
|
shape = variable.shape.dims |
|
variable_parameters = 1 |
|
for dim in shape: |
|
variable_parameters *= dim.size |
|
total_parameters += variable_parameters |
|
print(f"\n\nN TRAINABLE VARS:\n{total_parameters:,}\n\n") |
|
|
|
|
|
def print_dim_names(graph): |
|
""" |
|
Print names of all Dimensions |
|
:param graph: Mesh-Tensorflow graph |
|
:return: None |
|
""" |
|
all_dim_names = [] |
|
for variable in graph.all_variables: |
|
names = variable.shape.dimension_names |
|
all_dim_names.append(names) |
|
|
|
|
|
all_dim_names = [item for sublist in all_dim_names for item in sublist] |
|
unique_dims = list(set(all_dim_names)) |
|
print("ALL DIM NAMES:") |
|
for dim_name in unique_dims: |
|
print(dim_name) |
|
print('\n') |
|
|
|
|
|
def get_graph_info(graph): |
|
""" |
|
    Wrapper fn that calculates the number of trainable vars in an MTF graph & prints all dim names.
|
TODO: how to get un-trainable dim-names too, batch etc. |
|
|
|
:param graph: Mesh-Tensorflow graph |
|
:return: None |
|
""" |
|
get_n_trainable_vars(graph) |
|
print_dim_names(graph) |
|
|
|
|
|
def loss_denominator(targets, num_microbatches): |
|
"""Denominator applied to losses. |
|
|
|
    This is the size of the targets tensor (omitting ensemble dimensions),
    scaled by the number of microbatches.
|
|
|
Args: |
|
targets: a mtf.Tensor |
|
num_microbatches: an integer - greater than one if the step has been |
|
serialized into multiple microbatches to save memory. |
|
Returns: |
|
a float |
|
""" |
|
ret = float(targets.shape.size) * num_microbatches |
|
return float(ret) |
|
|
|
def check_dataset(input_fn, params, global_step=None): |
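    """Decodes and prints the first example produced by input_fn, then exits.
    Useful for eyeballing that a dataset pipeline yields sensible text."""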
|
tf.enable_eager_execution() |
|
if global_step is not None: |
|
dataset = input_fn(params, global_step=global_step) |
|
else: |
|
dataset = input_fn(params) |
|
dataset_iter = dataset.make_one_shot_iterator() |
|
tensor, _ = next(dataset_iter) |
|
enc = fetch_encoder(params) |
|
|
|
for p in tensor[:1]: |
|
txt = enc.decode(p) |
|
|
|
print('-' * 50) |
|
print(txt[:500], '\n\n...\n\n', txt[-500:]) |
|
print('-' * 50) |
|
exit() |
|
|
|
def auto_layout(graph, mesh_shape, logits, loss): |
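    """Prints an auto-selected layout for the graph and exits; the graph should then be re-initialized with that layout."""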
|
layout_rules = mtf.auto_mtf.layout(graph, mesh_shape, [logits, loss]) |
|
print(f"Auto-selected layout:\n{layout_rules}\nRe-initialize graph with selected layout") |
|
quit() |
|
|
|
def auto_layout_and_mesh_shape(graph, num_cores, logits, loss): |
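    """Prints an auto-selected layout and mesh shape for the graph and exits; the graph should then be re-initialized with them."""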
|
layout_rules, mesh_shape = mtf.auto_mtf.layout_and_mesh_shape(graph, num_cores, |
|
[logits, loss], max_mesh_shape_dimensions=4) |
|
print(f"Num cores:\n{num_cores}\nAuto-selected layout:\n{layout_rules}\nAuto-selected mesh shape:\n{mesh_shape}" \ |
|
f"\nRe-initialize graph with selected layout & mesh shape") |
|
quit() |
|
|
|
def create_host_call(model_dir): |
|
"""Construct a host_call writing scalar summaries. |
|
|
|
Borrowed from t2t. |
|
|
|
Args: |
|
model_dir: String containing path to train |
|
Returns: |
|
(fn, args) Pair to be called by TPUEstimator as the host_call. |
|
""" |
|
|
|
graph = tf.get_default_graph() |
|
|
|
summaries = graph.get_collection(mtf.utils.SCALAR_SUMMARIES_COLLECTION_KEY) |
|
|
|
def maybe_cast(tensor): |
|
assert tensor.shape.is_compatible_with([]), tensor.name |
|
if tensor.dtype == tf.int64: |
|
return tf.to_int32(tensor) |
|
if tensor.dtype == tf.bfloat16: |
|
return tf.cast(tensor, tf.float32) |
|
return tensor |
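    # TPUEstimator's host_call requires each tensor to carry a leading dimension,
    # so every scalar summary is reshaped to shape [1] before being passed through.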
|
|
|
reshaped_tensors = [tf.reshape(maybe_cast(t), [1]) for _, t in summaries] |
|
|
|
|
|
|
|
|
|
if not reshaped_tensors: |
|
return None |
|
|
|
def host_call_fn(global_step, *args): |
|
"""Training host call. Creates scalar summaries for training metrics.""" |
|
|
|
|
|
|
|
global_step = tf.cast(global_step[0], tf.int64) |
|
with tf2.summary.create_file_writer(model_dir).as_default(): |
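            # Each remaining arg lines up with an entry in `summaries`; emit one scalar summary per metric at this step.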
|
|
|
|
|
|
|
|
|
|
|
assert len(args) == len(summaries) |
|
for i, tensor in enumerate(args): |
|
name = summaries[i][0] |
|
tf2.summary.scalar(name, tf.reduce_mean(tensor), step=global_step) |
|
return tf.summary.all_v2_summary_ops() |
|
|
|
global_step_t = tf.reshape(tf.to_int32(tf.train.get_global_step()), [1]) |
|
return host_call_fn, [global_step_t] + reshaped_tensors |
|
|
|
|
|
def natural_sort(l): |
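    """Sorts strings so embedded numbers are ordered numerically, e.g. "ckpt-2" before "ckpt-10"."""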
|
convert = lambda text: int(text) if text.isdigit() else text.lower() |
|
    alphanum_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key)]
    return sorted(l, key=alphanum_key)
|
|