import torch

# Initializer names understood by ``get_init_from_string``, mapped to the
# corresponding ``torch.nn.init`` function.
_INIT_FUNCTIONS = {
    "torch.nn.init.zeros_": torch.nn.init.zeros_,
    "torch.nn.init.xavier_uniform_": torch.nn.init.xavier_uniform_,
    "torch.nn.init.xavier_normal_": torch.nn.init.xavier_normal_,
}


def grab_first_if_tuple(x):
    """Return ``x[0]`` when ``x`` is a plain ``tuple``, otherwise ``x`` itself.

    Only exact ``tuple`` instances are unpacked; tuple subclasses (for example
    named tuples) and every other object are returned unchanged.
    """
    # Exact type check on purpose: subclasses of tuple must pass through as-is.
    if type(x) is tuple:
        return x[0]
    return x


def column_split(x, num_heads, head_size):
    """Split a fused three-way projection into its per-head column groups.

    ``x`` is reshaped to ``(x.shape[0], num_heads, 3 * head_size)``, so for
    every head the last axis holds three consecutive chunks of ``head_size``
    columns. The split therefore happens alongside the head dimension rather
    than across heads. The number of projections is fixed to three.

    Args:
        x: Tensor with ``x.shape[0] * num_heads * 3 * head_size`` elements.
        num_heads: Number of heads packed into ``x``.
        head_size: Number of columns per head in each projection.

    Returns:
        Tuple ``(x2, x1, v)`` built from the first, second and third chunk of
        every head, each reshaped to ``(x.shape[0], num_heads * head_size)``.
    """
    batch = x.shape[0]
    per_head = x.reshape(batch, num_heads, 3 * head_size)

    first = per_head[:, :, :head_size]
    second = per_head[:, :, head_size : 2 * head_size]
    third = per_head[:, :, 2 * head_size :]

    # Merge the (head, column) axes back into one flat feature axis.
    x2 = first.reshape(batch, -1)
    x1 = second.reshape(batch, -1)
    v = third.reshape(batch, -1)
    return x2, x1, v


def get_init_from_string(init_str):
    """Resolve the name of a ``torch.nn.init`` function to the function.

    Args:
        init_str: One of ``"torch.nn.init.zeros_"``,
            ``"torch.nn.init.xavier_uniform_"`` or
            ``"torch.nn.init.xavier_normal_"``.

    Returns:
        The matching initializer function, or ``None`` when ``init_str`` is
        not a string.

    Raises:
        ValueError: If ``init_str`` is a string that names no known
            initializer.
    """
    # isinstance (rather than an exact type comparison) so that str
    # subclasses are resolved instead of silently yielding None.
    if not isinstance(init_str, str):
        return None
    try:
        return _INIT_FUNCTIONS[init_str]
    except KeyError:
        raise ValueError(f"Unrecognized init {init_str}") from None


def print_rank_0(message, debug=False, end="\n"):
    """Print ``message`` from rank 0 only.

    When no process group is initialized (or the PyTorch build has no
    distributed support at all) the message is always printed.

    Args:
        message: Object to print.
        debug: Unused; kept for backward compatibility with existing callers.
        end: String appended after the message, as for ``print``.
    """
    dist = torch.distributed
    # is_available() guards builds compiled without distributed support, where
    # the remaining torch.distributed helpers may be missing.
    if dist.is_available() and dist.is_initialized() and dist.get_rank() != 0:
        return
    print(message, flush=True, end=end)


class dotdict(dict):
    """Dictionary whose items are also accessible with dot notation.

    Attribute reads go through ``dict.get``, so reading a missing key yields
    ``None`` instead of raising. Attribute assignment and deletion map to item
    assignment and deletion; deleting a missing key raises ``KeyError``.
    """

    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__


def ensure_divisibility(numerator, denominator):
    """Assert that ``numerator`` is divisible by ``denominator``.

    Raises:
        AssertionError: If ``numerator % denominator`` is not zero.
    """
    assert numerator % denominator == 0, "{} is not divisible by {}".format(
        numerator, denominator
    )


def divide(numerator, denominator):
    """Return ``numerator // denominator`` after checking exact divisibility.

    Raises:
        AssertionError: If ``numerator`` is not divisible by ``denominator``.
    """
    ensure_divisibility(numerator, denominator)
    return numerator // denominator


class VocabUtility:
    """Split the vocabulary into ``world_size`` chunks.

    Both helpers return the first and last index of the vocabulary slice that
    belongs to the ``rank`` partition. The range is half-open: indices are in
    ``[first, last)``.
    """

    @staticmethod
    def vocab_range_from_per_partition_vocab_size(
        per_partition_vocab_size, rank, world_size
    ):
        """Return ``(first, last)`` for ``rank`` given the partition size.

        ``world_size`` is accepted for signature symmetry but is not used.
        """
        index_f = rank * per_partition_vocab_size
        index_l = index_f + per_partition_vocab_size
        return index_f, index_l

    @staticmethod
    def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
        """Return ``(first, last)`` for ``rank`` given the full vocabulary size.

        Raises:
            AssertionError: If ``global_vocab_size`` is not divisible by
                ``world_size``.
        """
        per_partition_vocab_size = divide(global_vocab_size, world_size)
        return VocabUtility.vocab_range_from_per_partition_vocab_size(
            per_partition_vocab_size, rank, world_size
        )