|
import torch |
|
|
|
|
|
def column_split(x, num_heads, head_size):
    """Separate a fused three-way projection into its per-head components.

    `x` is viewed as `(x.shape[0], num_heads, 3 * head_size)`, so each head
    holds three consecutive chunks of width `head_size`. The first, second
    and third chunk of every head are gathered and flattened back to
    two dimensions.

    Args:
        x: Tensor whose element count equals
            `x.shape[0] * num_heads * 3 * head_size`.
        num_heads: Number of heads packed into `x`.
        head_size: Width of one projection inside a single head.

    Returns:
        A tuple `(x2, x1, v)` of tensors, each reshaped to
        `(x.shape[0], num_heads * head_size)`.
    """
    rows = x.shape[0]
    per_head = x.reshape(rows, num_heads, 3 * head_size)

    # The three projections sit back to back inside each head, so consecutive
    # windows of width `head_size` select x2, x1 and v in that order.
    offsets = (0, head_size, 2 * head_size)
    x2, x1, v = [
        per_head[:, :, start : start + head_size].reshape(rows, -1)
        for start in offsets
    ]
    return x2, x1, v
|
|
|
|
|
def get_init_from_string(init_str):
    """Resolve the dotted name of a supported initializer to the function itself.

    Args:
        init_str: Name of the initializer. Supported values are
            "torch.nn.init.zeros_", "torch.nn.init.xavier_uniform_" and
            "torch.nn.init.xavier_normal_".

    Returns:
        The matching `torch.nn.init` function, or `None` when `init_str` is
        not a string (non-string input is passed over, not rejected).

    Raises:
        ValueError: If `init_str` is a string that names no supported
            initializer.
    """
    # isinstance, unlike an exact type comparison, also accepts str subclasses.
    if not isinstance(init_str, str):
        # Non-string input has always produced None; keep that explicit.
        return None

    supported = {
        "torch.nn.init.zeros_": torch.nn.init.zeros_,
        "torch.nn.init.xavier_uniform_": torch.nn.init.xavier_uniform_,
        "torch.nn.init.xavier_normal_": torch.nn.init.xavier_normal_,
    }
    init_fn = supported.get(init_str)
    if init_fn is None:
        raise ValueError(f"Unrecognized init {init_str}")
    return init_fn
|
|
|
|
|
def print_rank_0(message, debug=False, end="\n"):
    """Print `message` once per job: only on rank 0 when distributed is set up.

    When `torch.distributed` has not been initialized the message is always
    printed. Output is flushed immediately.

    Args:
        message: Object to print.
        debug: Accepted for call-site compatibility; not used by this function.
        end: String appended after the message, forwarded to `print`.
    """
    dist = torch.distributed
    # Every rank other than 0 stays silent once the process group exists.
    if dist.is_initialized() and dist.get_rank() != 0:
        return
    print(message, end=end, flush=True)
|
|
|
|
|
class dotdict(dict):
    """Dictionary whose items can also be read, written and deleted as attributes.

    Attribute access is routed straight to the underlying `dict` methods:
    reading a missing attribute yields `None` instead of raising, and
    deleting a missing attribute raises `KeyError`.
    """

    def __getattr__(self, key, default=None):
        # Only reached when normal attribute lookup fails; mirrors dict.get.
        return dict.get(self, key, default)

    def __setattr__(self, key, value):
        # Attribute assignment stores an item rather than an instance attribute.
        dict.__setitem__(self, key, value)

    def __delattr__(self, key):
        # Removes the item; a missing key surfaces as KeyError.
        dict.__delitem__(self, key)
|
|
|
|
|
def ensure_divisibility(numerator, denominator):
    """Assert that `denominator` divides `numerator` with no remainder.

    Raises:
        AssertionError: If `numerator % denominator` is non-zero.
    """
    remainder = numerator % denominator
    assert remainder == 0, f"{numerator} is not divisible by {denominator}"
|
|
|
|
|
def divide(numerator, denominator):
    """Return `numerator // denominator` after checking the division is exact.

    The check is delegated to `ensure_divisibility`, so an inexact division
    fails there before any quotient is computed.
    """
    ensure_divisibility(numerator, denominator)
    quotient = numerator // denominator
    return quotient
|
|
|
|
|
class VocabUtility:
    """Compute the slice of the vocabulary owned by one partition.

    The vocabulary is split into `world_size` equally sized chunks, and the
    helpers return the first index and the one-past-the-last index of the
    chunk belonging to `rank`, i.e. the half-open range [first, last).
    """

    @staticmethod
    def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size):
        """Return `(first, last)` for `rank` given the size of one partition.

        `world_size` is accepted for signature symmetry and is not used here.
        """
        start = rank * per_partition_vocab_size
        return start, start + per_partition_vocab_size

    @staticmethod
    def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
        """Return `(first, last)` for `rank` given the total vocabulary size.

        The per-partition size comes from `divide`, which requires
        `global_vocab_size` to be an exact multiple of `world_size`.
        """
        shard_size = divide(global_vocab_size, world_size)
        return VocabUtility.vocab_range_from_per_partition_vocab_size(shard_size, rank, world_size)
|
|