vampnet / vampnet /util.py
pseeth's picture
Duplicate from hugggof/vampnet
c91e8cc
import tqdm
import torch
from einops import rearrange
def scalar_to_batch_tensor(x, batch_size):
    """
    Build a tensor from ``x`` and tile it ``batch_size`` times.

    For a scalar ``x`` the result is a 1-D tensor of length ``batch_size``
    whose entries all equal ``x``.
    """
    value = torch.tensor(x)
    batched = value.repeat(batch_size)
    return batched
def parallelize(
    fn,
    *iterables,
    parallel: str = "thread_map",
    **kwargs
):
    """
    Map ``fn`` over ``iterables`` while showing a tqdm progress bar.

    Args:
        fn: callable invoked with one item taken from each iterable.
        *iterables: one or more iterables consumed in lockstep, as with ``map``.
        parallel: execution strategy, one of ``"thread_map"``,
            ``"process_map"`` or ``"single"`` (plain serial loop).
        **kwargs: forwarded unchanged to ``thread_map`` / ``process_map``.
            In ``"single"`` mode the executor-only options are dropped and
            the rest is forwarded to the progress bar.

    Returns:
        The result of ``thread_map`` / ``process_map``, or in ``"single"``
        mode a list holding ``fn``'s return value for every item.

    Raises:
        ValueError: if ``parallel`` is not one of the supported modes.
    """
    if parallel == "thread_map":
        from tqdm.contrib.concurrent import thread_map
        return thread_map(
            fn,
            *iterables,
            **kwargs
        )
    elif parallel == "process_map":
        from tqdm.contrib.concurrent import process_map
        return process_map(
            fn,
            *iterables,
            **kwargs
        )
    elif parallel == "single":
        # These options only configure the executor used by the two modes
        # above; they are meaningless for a serial loop, so strip them
        # before handing the remaining options to the progress bar.
        executor_only = ("max_workers", "chunksize", "lock_name")
        bar_kwargs = {
            key: value for key, value in kwargs.items()
            if key not in executor_only
        }
        bar_class = bar_kwargs.pop("tqdm_class", tqdm.tqdm)

        # zip() hides the length from the progress bar, so supply it when
        # every iterable is sized (zip stops at the shortest one).
        if "total" not in bar_kwargs and iterables:
            if all(hasattr(it, "__len__") for it in iterables):
                bar_kwargs["total"] = min(len(it) for it in iterables)

        # Iterate all iterables in lockstep so that several iterables work
        # the same way here as they do in the thread/process modes.
        return [fn(*args) for args in bar_class(zip(*iterables), **bar_kwargs)]
    else:
        raise ValueError(f"parallel must be one of 'thread_map', 'process_map', 'single', but got {parallel}")
def codebook_flatten(tokens: torch.Tensor):
    """
    Collapse (batch, codebook, time) tokens into (batch, time * codebook).

    The flattened axis is time-major: all codebook entries of the first
    time step come first, then those of the second time step, and so on.
    """
    flat_tokens = rearrange(tokens, "b c t -> b (t c)")
    return flat_tokens
def codebook_unflatten(flat_tokens: torch.Tensor, n_c: int = None):
    """
    Expand (batch, time * codebook) tokens back into (batch, codebook, time).

    The flattened axis is read as time-major (the ``(t c)`` grouping in the
    pattern below), i.e. consecutive entries belong to the same time step.

    Args:
        flat_tokens: tensor whose second axis has length ``time * n_c``.
        n_c: number of codebooks. Required even though it has a default,
            because the time length cannot be inferred without it.

    Returns:
        Tensor of shape (batch, n_c, time).

    Raises:
        ValueError: if ``n_c`` is not provided.
    """
    if n_c is None:
        # Without the codebook count the "(t c)" split is ambiguous; fail
        # with a clear message instead of an obscure downstream error.
        raise ValueError("n_c (number of codebooks) must be provided to unflatten tokens")
    return rearrange(flat_tokens, "b (t c) -> b c t", c=n_c)