import tqdm
import torch
from einops import rearrange


def scalar_to_batch_tensor(x, batch_size):
    """broadcast a scalar `x` into a 1-D tensor of length `batch_size`."""
    return torch.tensor(x).repeat(batch_size)
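
# Illustrative example (a minimal sketch of the expected output):
#
#   >>> scalar_to_batch_tensor(0.5, 4)
#   tensor([0.5000, 0.5000, 0.5000, 0.5000])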


def parallelize(
    fn,
    *iterables,
    parallel: str = "thread_map",
    **kwargs
):
    """map `fn` over `iterables` with a tqdm progress bar, using threads, processes, or a sequential loop."""
    if parallel == "thread_map":
        from tqdm.contrib.concurrent import thread_map
        return thread_map(
            fn,
            *iterables,
            **kwargs
        )
    elif parallel == "process_map":
        from tqdm.contrib.concurrent import process_map
        return process_map(
            fn,
            *iterables,
            **kwargs
        )
    elif parallel == "single":
        # sequential fallback: zip the iterables so the calling convention
        # matches thread_map/process_map (fn receives one item per iterable)
        return [fn(*args) for args in tqdm.tqdm(list(zip(*iterables)))]
    else:
        raise ValueError(
            "parallel must be one of 'thread_map', 'process_map', 'single', "
            f"but got {parallel}"
        )
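
# Illustrative usage sketch:
#
#   >>> parallelize(lambda x: x * 2, [1, 2, 3], parallel="single")
#   [2, 4, 6]
#
# "thread_map" suits I/O-bound work; "process_map" requires `fn` to be
# picklable, so lambdas and local functions will not work with it.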


def codebook_flatten(tokens: torch.Tensor):
    """
    flatten a sequence of tokens from (batch, codebook, time) to
    (batch, time * codebook), interleaving the codebooks within each time step
    """
    return rearrange(tokens, "b c t -> b (t c)")
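
# Illustrative example: with 2 codebooks and 3 time steps, flattening is
# time-major, so each time step's codebook tokens stay adjacent.
#
#   >>> t = torch.arange(6).reshape(1, 2, 3)   # values [[0, 1, 2], [3, 4, 5]]
#   >>> codebook_flatten(t)
#   tensor([[0, 3, 1, 4, 2, 5]])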


def codebook_unflatten(flat_tokens: torch.Tensor, n_c: int = None):
    """
    unflatten a sequence of tokens from (batch, time * codebook) back to
    (batch, codebook, time); `n_c` is the number of codebooks and must be given
    """
    tokens = rearrange(flat_tokens, "b (t c) -> b c t", c=n_c)
    return tokens
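
# Illustrative example: unflattening with the matching codebook count inverts
# codebook_flatten exactly.
#
#   >>> flat = torch.tensor([[0, 3, 1, 4, 2, 5]])
#   >>> codebook_unflatten(flat, n_c=2)
#   tensor([[[0, 1, 2],
#            [3, 4, 5]]])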