from concurrent.futures import ProcessPoolExecutor
from functools import wraps
import hashlib
import logging
import typing as tp

import flashy
import flashy.distrib
import omegaconf
import torch
from torch.nn.utils.rnn import pad_sequence


logger = logging.getLogger(__name__)


def dict_from_config(cfg: omegaconf.DictConfig) -> dict:
    """Convenience function to map an omegaconf configuration to a dictionary.

    Args:
        cfg (omegaconf.DictConfig): Original configuration to map to dict.
    Returns:
        dict: Config as dictionary object.
    """
    dct = omegaconf.OmegaConf.to_container(cfg, resolve=True)
    assert isinstance(dct, dict)
    return dct


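# Illustrative usage sketch for `dict_from_config` (not part of the original module).
# The configuration below is a made-up example built with omegaconf.
def _example_dict_from_config() -> dict:
    cfg = omegaconf.OmegaConf.create({"model": {"dim": 512}, "lr": 1e-4})
    # Resolves interpolations and returns a plain dict: {'model': {'dim': 512}, 'lr': 0.0001}
    return dict_from_config(cfg)

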
def random_subset(dataset, max_samples: int, seed: int = 42) -> torch.utils.data.Subset:
    """Return a random, seeded subset of `dataset` with at most `max_samples` items.

    The dataset is returned unchanged if it is already small enough.
    """
    if max_samples >= len(dataset):
        return dataset

    generator = torch.Generator().manual_seed(seed)
    perm = torch.randperm(len(dataset), generator=generator)
    return torch.utils.data.Subset(dataset, perm[:max_samples].tolist())


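# Illustrative usage sketch for `random_subset` (not part of the original module):
# keeps a reproducible 10-sample subset of a toy TensorDataset.
def _example_random_subset() -> torch.utils.data.Dataset:
    full = torch.utils.data.TensorDataset(torch.arange(100).float())
    subset = random_subset(full, max_samples=10, seed=0)
    assert len(subset) == 10
    return subset

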
def get_loader(dataset, num_samples: tp.Optional[int], batch_size: int,
               num_workers: int, seed: int, **kwargs) -> torch.utils.data.DataLoader:
    """Convenience function to load dataset into a dataloader with optional subset sampling.

    Args:
        dataset: Dataset to load.
        num_samples (Optional[int]): Number of samples to limit subset size.
        batch_size (int): Batch size.
        num_workers (int): Number of workers for data loading.
        seed (int): Random seed.
    """
    if num_samples is not None:
        dataset = random_subset(dataset, num_samples, seed)

    dataloader = flashy.distrib.loader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        **kwargs
    )
    return dataloader


def get_dataset_from_loader(dataloader):
    """Return the underlying dataset of a dataloader, unwrapping any `Subset`."""
    dataset = dataloader.dataset
    if isinstance(dataset, torch.utils.data.Subset):
        return dataset.dataset
    else:
        return dataset


def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None):
    """torch.multinomial with arbitrary number of dimensions, and number of candidates on the last dimension.

    Args:
        input (torch.Tensor): The input tensor containing probabilities.
        num_samples (int): Number of samples to draw.
        replacement (bool): Whether to draw with replacement or not.
    Keyword args:
        generator (torch.Generator): A pseudorandom number generator for sampling.
    Returns:
        torch.Tensor: Last dimension contains num_samples indices
            sampled from the multinomial probability distribution
            located in the last dimension of tensor input.
    """
    input_ = input.reshape(-1, input.shape[-1])
    output_ = torch.multinomial(input_, num_samples=num_samples, replacement=replacement, generator=generator)
    output = output_.reshape(*list(input.shape[:-1]), -1)
    return output


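# Illustrative usage sketch for `multinomial` (not part of the original module):
# samples one index per distribution from a made-up [batch, codebooks, vocab] tensor.
def _example_multinomial() -> torch.Tensor:
    probs = torch.softmax(torch.randn(2, 4, 16), dim=-1)
    samples = multinomial(probs, num_samples=1)
    assert samples.shape == (2, 4, 1)
    return samples

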
def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor:
    """Sample next token from top K values along the last dimension of the input probs tensor.

    Args:
        probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
        k (int): The k in "top-k".
    Returns:
        torch.Tensor: Sampled tokens.
    """
    top_k_value, _ = torch.topk(probs, k, dim=-1)
    min_value_top_k = top_k_value[..., [-1]]
    probs *= (probs >= min_value_top_k).float()
    probs.div_(probs.sum(dim=-1, keepdim=True))
    next_token = multinomial(probs, num_samples=1)
    return next_token


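# Illustrative usage sketch for `sample_top_k` (not part of the original module).
# Note that the function renormalizes `probs` in place, hence the clone below.
def _example_sample_top_k() -> torch.Tensor:
    probs = torch.softmax(torch.randn(2, 16), dim=-1)
    tokens = sample_top_k(probs.clone(), k=5)  # one token per row, shape [2, 1]
    assert tokens.shape == (2, 1)
    return tokens

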
def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
    """Sample next token from top P probabilities along the last dimension of the input probs tensor.

    Args:
        probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
        p (float): The p in "top-p".
    Returns:
        torch.Tensor: Sampled tokens.
    """
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p
    probs_sort *= (~mask).float()
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token


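# Illustrative usage sketch for `sample_top_p` (not part of the original module):
# nucleus sampling over a made-up batch of 16-way distributions.
def _example_sample_top_p() -> torch.Tensor:
    probs = torch.softmax(torch.randn(2, 16), dim=-1)
    tokens = sample_top_p(probs, p=0.9)  # indices into the original (unsorted) vocab axis
    assert tokens.shape == (2, 1)
    return tokens

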
class DummyPoolExecutor:
    """Dummy pool executor to use when we actually have only 1 worker
    (e.g. instead of ProcessPoolExecutor).
    """
    class DummyResult:
        def __init__(self, func, *args, **kwargs):
            self.func = func
            self.args = args
            self.kwargs = kwargs

        def result(self):
            return self.func(*self.args, **self.kwargs)

    def __init__(self, workers, mp_context=None):
        pass

    def submit(self, func, *args, **kwargs):
        return DummyPoolExecutor.DummyResult(func, *args, **kwargs)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        return


def get_pool_executor(num_workers: int, mp_context=None):
    return ProcessPoolExecutor(num_workers, mp_context) if num_workers > 1 else DummyPoolExecutor(1)


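# Illustrative usage sketch for `get_pool_executor` (not part of the original module).
# With a single worker, the dummy executor runs the submitted job in-process when
# `result()` is called; with more workers a real ProcessPoolExecutor is returned.
def _example_get_pool_executor() -> int:
    with get_pool_executor(num_workers=1) as pool:
        future = pool.submit(pow, 2, 10)
        return future.result()  # 1024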
def length_to_mask(lengths: torch.Tensor, max_len: tp.Optional[int] = None) -> torch.Tensor:
    """Utility function to convert a tensor of sequence lengths to a mask (useful when working on padded sequences).
    For example: [3, 5] => [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]

    Args:
        lengths (torch.Tensor): tensor with lengths
        max_len (int): can set the max length manually. Defaults to None.
    Returns:
        torch.Tensor: mask with 0s where there are pad tokens, else 1s
    """
    assert len(lengths.shape) == 1, "Length shape should be 1 dimensional."
    final_length = lengths.max().item() if not max_len else max_len
    final_length = max(final_length, 1)  # avoid a zero-size mask when all lengths are zero
    return torch.arange(final_length)[None, :].to(lengths.device) < lengths[:, None]


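# Illustrative usage sketch for `length_to_mask` (not part of the original module),
# reproducing the docstring example.
def _example_length_to_mask() -> torch.Tensor:
    mask = length_to_mask(torch.tensor([3, 5]))
    expected = torch.tensor([[True, True, True, False, False],
                             [True, True, True, True, True]])
    assert torch.equal(mask, expected)
    return mask

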
def hash_trick(word: str, vocab_size: int) -> int:
    """Hash trick to pair each word with an index.

    Args:
        word (str): word we wish to convert to an index
        vocab_size (int): size of the vocabulary
    Returns:
        int: index of the word in the embedding LUT
    """
    hash = int(hashlib.sha256(word.encode("utf-8")).hexdigest(), 16)
    return hash % vocab_size


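# Illustrative usage sketch for `hash_trick` (not part of the original module):
# the mapping is deterministic and always lands inside the vocabulary.
def _example_hash_trick() -> int:
    index = hash_trick("guitar", vocab_size=1000)
    assert 0 <= index < 1000
    assert index == hash_trick("guitar", vocab_size=1000)  # stable across calls
    return index

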
def with_rank_rng(base_seed: int = 1234):
    """Decorator for a function so that the function will use a Random Number Generator
    whose state depends on the GPU rank. The original RNG state is restored upon returning.

    Args:
        base_seed (int): Random seed.
    """
    def _decorator(fun: tp.Callable):
        @wraps(fun)
        def _decorated(*args, **kwargs):
            state = torch.get_rng_state()
            seed = base_seed ^ flashy.distrib.rank()
            torch.manual_seed(seed)
            logger.debug('Rank dependent seed set to %d', seed)
            try:
                return fun(*args, **kwargs)
            finally:
                torch.set_rng_state(state)
                logger.debug('RNG state restored.')
        return _decorated
    return _decorator


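# Illustrative usage sketch for `with_rank_rng` (not part of the original module).
# It assumes the environment lets flashy resolve a rank when the wrapped function is
# called (this is an assumption of the sketch, not a guarantee of this module).
@with_rank_rng(base_seed=4321)
def _example_rank_dependent_noise() -> torch.Tensor:
    # Same tensor on every call for a given rank: the decorator reseeds torch's
    # global RNG from base_seed XOR rank and restores the previous state afterwards.
    return torch.randn(3)

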
def collate(tensors: tp.List[torch.Tensor], dim: int = 0) -> tp.Tuple[torch.Tensor, torch.Tensor]:
    """Get a list of tensors and collate them to a single tensor according to the following logic:
    - `dim` specifies the time dimension which will be stacked and padded.
    - The output will contain 1 new dimension (dimension index 0) which will be the size
    of the original list.

    Args:
        tensors (tp.List[torch.Tensor]): List of tensors to collate.
        dim (int): Dimension which will be stacked and padded.
    Returns:
        tp.Tuple[torch.Tensor, torch.Tensor]:
            torch.Tensor: Stacked and padded tensor. The output will contain 1 new dimension
                (dimension index 0) which will be the size of the original list.
            torch.Tensor: Tensor containing length of original tensor sizes (without padding).
    """
    tensors = [x.transpose(0, dim) for x in tensors]
    lens = torch.LongTensor([len(x) for x in tensors])
    padded_tensors = pad_sequence(tensors)
    padded_tensors = padded_tensors.transpose(0, 1)
    padded_tensors = padded_tensors.transpose(1, dim + 1)
    return padded_tensors, lens


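# Illustrative usage sketch for `collate` (not part of the original module):
# two variable-length 1D "sequences" are padded and stacked along a new batch dim.
def _example_collate() -> tp.Tuple[torch.Tensor, torch.Tensor]:
    padded, lens = collate([torch.ones(3), torch.ones(5)])
    assert padded.shape == (2, 5) and lens.tolist() == [3, 5]
    return padded, lens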