Spaces:
Running
on
A10G
Running
on
A10G
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from concurrent.futures import ProcessPoolExecutor | |
from contextlib import contextmanager | |
from functools import wraps, lru_cache | |
import hashlib | |
import json | |
import logging | |
from pathlib import Path | |
import typing as tp | |
import flashy | |
import flashy.distrib | |
import omegaconf | |
import torch | |
from torch.nn.utils.rnn import pad_sequence | |
# Module-level logger, named after this module via `__name__`.
logger = logging.getLogger(__name__)
def model_hash(model: torch.nn.Module) -> str:
    """Compute a SHA-1 fingerprint of a model's parameters.

    Comparing this value across the logs of past experiments helps to track
    regressions in model initialization.

    Args:
        model (torch.nn.Module): Model whose parameters are hashed.
    Returns:
        str: Hexadecimal SHA-1 digest over the raw bytes of every parameter,
            visited in the order given by `model.parameters()`.
    """
    digest = hashlib.sha1()
    for param in model.parameters():
        # Parameters are moved to CPU and converted to numpy to get their raw bytes.
        raw_bytes = param.data.cpu().numpy().tobytes()
        digest.update(raw_bytes)
    return digest.hexdigest()
def dict_from_config(cfg: omegaconf.DictConfig) -> dict:
    """Turn an omegaconf configuration into a regular dictionary.

    The conversion is done with `resolve=True`.

    Args:
        cfg (omegaconf.DictConfig): Original configuration to map to dict.
    Returns:
        dict: Config as dictionary object.
    """
    container = omegaconf.OmegaConf.to_container(cfg, resolve=True)
    # `to_container` is not typed as returning a dict; check it explicitly.
    assert isinstance(container, dict)
    return container
def random_subset(dataset, max_samples: int, seed: int = 42) -> torch.utils.data.Subset:
    """Pick a reproducible random subset of at most `max_samples` items.

    Args:
        dataset: Dataset to sample from; must support `len()`.
        max_samples (int): Maximum number of items to keep.
        seed (int): Seed of the generator used to draw the permutation.
    Returns:
        The `dataset` object itself when it holds no more than `max_samples`
        items, otherwise a `torch.utils.data.Subset` over `max_samples`
        randomly chosen indices.
    """
    total = len(dataset)
    if total <= max_samples:
        return dataset
    # A dedicated generator keeps the selection independent of the global RNG state.
    rng = torch.Generator().manual_seed(seed)
    chosen = torch.randperm(total, generator=rng)[:max_samples]
    return torch.utils.data.Subset(dataset, chosen.tolist())
def get_loader(dataset, num_samples: tp.Optional[int], batch_size: int,
               num_workers: int, seed: int, **kwargs) -> torch.utils.data.DataLoader:
    """Build a dataloader over a dataset, optionally restricted to a random subset.

    Args:
        dataset: Dataset to load.
        num_samples (Optional[int]): Number of samples to limit subset size.
            When None, the dataset is used as is.
        batch_size (int): Batch size.
        num_workers (int): Number of workers for data loading.
        seed (int): Random seed, only used for the subset selection.
        **kwargs: Extra keyword arguments forwarded to `flashy.distrib.loader`.
    Returns:
        The object returned by `flashy.distrib.loader`.
    """
    if num_samples is None:
        source = dataset
    else:
        source = random_subset(dataset, num_samples, seed)
    return flashy.distrib.loader(
        source,
        batch_size=batch_size,
        num_workers=num_workers,
        **kwargs
    )
def get_dataset_from_loader(dataloader):
    """Return the dataset behind a dataloader, unwrapping one level of `Subset`.

    Args:
        dataloader: Object exposing a `dataset` attribute.
    Returns:
        The wrapped dataset when `dataloader.dataset` is a
        `torch.utils.data.Subset`, otherwise `dataloader.dataset` itself.
    """
    wrapped = dataloader.dataset
    is_subset = isinstance(wrapped, torch.utils.data.Subset)
    return wrapped.dataset if is_subset else wrapped
def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None):
    """torch.multinomial generalized to any number of leading dimensions.

    The candidates are read from the last dimension; all other dimensions are
    flattened before sampling and restored afterwards.

    Args:
        input (torch.Tensor): The input tensor containing probabilities.
        num_samples (int): Number of samples to draw.
        replacement (bool): Whether to draw with replacement or not.
    Keywords args:
        generator (torch.Generator): A pseudorandom number generator for sampling.
    Returns:
        torch.Tensor: Last dimension contains num_samples indices
            sampled from the multinomial probability distribution
            located in the last dimension of tensor input.
    """
    leading_shape = input.shape[:-1]
    num_candidates = input.shape[-1]
    # torch.multinomial wants a 1D or 2D input: collapse the leading dimensions into one.
    flat_probs = input.reshape(-1, num_candidates)
    flat_samples = torch.multinomial(
        flat_probs, num_samples=num_samples, replacement=replacement, generator=generator)
    return flat_samples.reshape(*leading_shape, -1)
def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor:
    """Draw the next token among the K most likely candidates of the last dimension.

    Note that `probs` is modified in place: entries below the k-th largest
    value are zeroed and the remaining ones are renormalized.

    Args:
        probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
        k (int): The k in "top-k".
    Returns:
        torch.Tensor: Sampled tokens.
    """
    top_values, _ = torch.topk(probs, k, dim=-1)
    # Smallest of the k largest values, kept as a size-1 last dimension for broadcasting.
    threshold = top_values[..., [-1]]
    # Candidates tied with the threshold are all kept.
    keep = (probs >= threshold).float()
    probs.mul_(keep)
    probs.div_(probs.sum(dim=-1, keepdim=True))
    return multinomial(probs, num_samples=1)
def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
    """Draw the next token from the top-P (nucleus) candidates of the last dimension.

    Candidates are sorted by decreasing probability; a candidate is dropped once
    the cumulative probability of the candidates ranked before it exceeds `p`.

    Args:
        probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
        p (float): The p in "top-p".
    Returns:
        torch.Tensor: Sampled tokens.
    """
    sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # `cumulative - sorted_probs` is the probability mass strictly before each candidate.
    keep = ~(cumulative - sorted_probs > p)
    sorted_probs.mul_(keep.float())
    sorted_probs.div_(sorted_probs.sum(dim=-1, keepdim=True))
    picked = multinomial(sorted_probs, num_samples=1)
    # Map positions in the sorted order back to the original candidate indices.
    return torch.gather(sorted_idx, -1, picked)
class DummyPoolExecutor:
    """In-process stand-in for a pool executor, for use when there is only 1 worker
    (e.g. instead of ProcessPoolExecutor).
    """

    class DummyResult:
        """Deferred call: the function only runs when `result()` is invoked."""

        def __init__(self, func, *args, **kwargs):
            self.func = func
            self.args = args
            self.kwargs = kwargs

        def result(self):
            """Call the stored function with the stored arguments and return its value."""
            return self.func(*self.args, **self.kwargs)

    def __init__(self, workers, mp_context=None):
        # Both arguments are accepted for signature compatibility and ignored.
        pass

    def submit(self, func, *args, **kwargs):
        """Wrap the call in a `DummyResult`; nothing is executed at this point."""
        return DummyPoolExecutor.DummyResult(func, *args, **kwargs)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        # Returning None lets any exception raised inside the `with` block propagate.
        return None
def get_pool_executor(num_workers: int, mp_context=None):
    """Return an executor suited to the requested number of workers.

    Args:
        num_workers (int): Number of workers.
        mp_context: Multiprocessing context passed to `ProcessPoolExecutor`;
            not used when a single worker is requested.
    Returns:
        A `ProcessPoolExecutor` when `num_workers > 1`, otherwise a `DummyPoolExecutor`.
    """
    if num_workers > 1:
        return ProcessPoolExecutor(num_workers, mp_context)
    return DummyPoolExecutor(1)
def length_to_mask(lengths: torch.Tensor, max_len: tp.Optional[int] = None) -> torch.Tensor:
    """Build a boolean padding mask from a 1D tensor of sequence lengths.

    For example: [3, 5] => [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]

    Args:
        lengths (torch.Tensor): tensor with lengths
        max_len (int): can set the max length manually. Defaults to None,
            in which case the largest value in `lengths` is used.
    Returns:
        torch.Tensor: mask with 0s where there is pad tokens else 1s
    """
    assert len(lengths.shape) == 1, "Length shape should be 1 dimensional."
    if max_len:
        width = max_len
    else:
        # Also taken when max_len is 0, since the check is on truthiness.
        width = lengths.max().item()
    # if all seqs are of len zero we don't want a zero-size tensor
    width = max(width, 1)
    positions = torch.arange(width, device=lengths.device)
    return positions[None, :] < lengths[:, None]
def hash_trick(word: str, vocab_size: int) -> int:
    """Hash trick mapping a word to an index, using SHA-256 of its UTF-8 encoding.

    Args:
        word (str): word we wish to convert to an index
        vocab_size (int): size of the vocabulary
    Returns:
        int: index of the word in the embedding LUT
    """
    digest = hashlib.sha256(word.encode("utf-8")).digest()
    # Reading the digest as a big-endian integer matches parsing its hex form in base 16.
    return int.from_bytes(digest, "big") % vocab_size
def with_rank_rng(base_seed: int = 1234):
    """Decorator for a function so that the function will use a Random Number Generator
    whose state depend on the GPU rank. The original RNG state is restored upon returning.

    The seed is `base_seed ^ flashy.distrib.rank()`. The state saved before the call
    and restored afterwards (even if the function raises) is the one returned by
    `torch.get_rng_state()`.

    Args:
        base_seed (int): Random seed.
    Returns:
        A decorator; the function it returns keeps the metadata of the
        decorated function (name, docstring, ...).
    """
    def _decorator(fun: tp.Callable):
        # `wraps` keeps __name__/__doc__ of `fun` on the wrapper.
        @wraps(fun)
        def _decorated(*args, **kwargs):
            state = torch.get_rng_state()
            seed = base_seed ^ flashy.distrib.rank()
            torch.manual_seed(seed)
            logger.debug('Rank dependent seed set to %d', seed)
            try:
                return fun(*args, **kwargs)
            finally:
                torch.set_rng_state(state)
                logger.debug('RNG state restored.')
        return _decorated
    return _decorator
def collate(tensors: tp.List[torch.Tensor], dim: int = 0) -> tp.Tuple[torch.Tensor, torch.Tensor]:
    """Stack a list of tensors into one batch tensor, zero-padding along `dim`.

    - `dim` specifies the time dimension which will be stacked and padded.
    - The output will contain 1 new dimension (dimension index 0) which will be
      the size of the original list.

    Args:
        tensors (tp.List[torch.Tensor]): List of tensors to collate.
        dim (int): Dimension which will be stacked and padded.
    Returns:
        tp.Tuple[torch.Tensor, torch.Tensor]:
            torch.Tensor: Stacked and padded tensor. The output will contain 1 new dimension
                (dimension index 0) which will be the size of the original list.
            torch.Tensor: Tensor containing length of original tensor sizes (without padding).
    """
    # pad_sequence pads along the first dimension, so bring `dim` to the front.
    time_first = [item.transpose(0, dim) for item in tensors]
    lengths = torch.LongTensor([len(item) for item in time_first])
    # pad_sequence returns (time, batch, ...): put batch first, then move time
    # back to its original position, shifted by one because of the new batch dimension.
    batch = pad_sequence(time_first).transpose(0, 1).transpose(1, dim + 1)
    return batch, lengths
# TODO: Move to flashy?
def copy_state(state: tp.Any, device: tp.Union[torch.device, str] = 'cpu',
               dtype: tp.Optional[torch.dtype] = None) -> tp.Any:
    """Recursively copy a state (e.g. a state dict) onto the given device.

    Tensors are detached and copied. Dicts and lists are rebuilt with their values
    copied recursively (dicts come back as plain `dict`). Any other value (int,
    float, str, None, ...) is returned as is instead of being replaced by None.

    Args:
        state (tp.Any): Tensor, dict, list, or arbitrary leaf value.
        device (torch.device or str): Device on which to put the copied tensors.
        dtype (torch.dtype, optional): Target dtype, only applied to floating
            point tensors. Other tensors, and all tensors when None, keep their dtype.
    Returns:
        tp.Any: Copy of `state` with the same nesting.
    """
    if isinstance(state, torch.Tensor):
        if dtype is None or not state.is_floating_point():
            # Never cast integer/bool tensors (e.g. counters) to a float dtype.
            dtype = state.dtype
        return state.detach().to(device=device, dtype=dtype, copy=True)
    if isinstance(state, dict):
        return {k: copy_state(v, device, dtype) for k, v in state.items()}
    if isinstance(state, list):
        return [copy_state(v, device, dtype) for v in state]
    # Non tensor leaves (e.g. ints or floats found in optimizer states) are passed
    # through unchanged; falling off the end here used to turn them into None.
    return state
# TODO: Move to flashy?
@contextmanager
def swap_state(model, state, **kwargs):
    """Context manager that temporarily loads `state` into `model`.

    On entry, the current state dict of the model is copied with `copy_state`
    and `state` is loaded. On exit, the saved state is loaded back, even if the
    body of the `with` block raised.

    Args:
        model: Object exposing `state_dict()` and `load_state_dict()`.
        state: State to load for the duration of the `with` block.
        **kwargs: Extra keyword arguments for the `load_state_dict` call that
            loads `state` (they are not used when restoring the old state).
    """
    old_state = copy_state(model.state_dict())
    model.load_state_dict(state, **kwargs)
    try:
        yield
    finally:
        model.load_state_dict(old_state)
# Unbounded cache on purpose: evicting an entry would let its warning fire again.
@lru_cache(None)
def warn_once(logger, msg):
    """Warn about a given message only once.

    Calls are memoized on the `(logger, msg)` pair, so the same message is
    emitted at most once per logger. Both arguments must be hashable.

    Args:
        logger: Object exposing a `warning(msg)` method.
        msg: Message to log.
    """
    logger.warning(msg)
def is_jsonable(x: tp.Any):
    """Tell whether an object can be serialized with `json.dumps`.

    Args:
        x (tp.Any): Object to test.
    Returns:
        bool: False when `json.dumps` raises TypeError or OverflowError, True
            when it succeeds. Other exceptions are not caught.
    """
    try:
        json.dumps(x)
    except (TypeError, OverflowError):
        return False
    else:
        return True
def load_clap_state_dict(clap_model, path: tp.Union[str, Path]):
    """Load a CLAP checkpoint into `clap_model.model`.

    Wrapper around state dict loading of CLAP model addressing compatibility
    issues between CLAP and AudioCraft HuggingFace transformer version.
    See: https://github.com/LAION-AI/CLAP/issues/118

    Args:
        clap_model: Object whose `model` attribute exposes `load_state_dict`.
        path (str or Path): Checkpoint location, passed to CLAP's `load_state_dict`.
    """
    # Imported lazily, inside the function, as in the original code.
    from clap_module.factory import load_state_dict  # type: ignore
    position_ids_key = 'text_branch.embeddings.position_ids'
    state = load_state_dict(path)
    # Drop this entry when present (see the issue linked above) before loading.
    state.pop(position_ids_key, None)
    clap_model.model.load_state_dict(state)