# VALL-E-X / data/collation.py
# Last commit: 11433cb ("fix") by Plachta
from pathlib import Path
from typing import List, Tuple
import numpy as np
import torch
class TextTokenCollater:
    """Collate lists of text tokens into a padded integer batch.

    The collater owns a small vocabulary built from ``pad_symbol`` (placed
    first), then ``bos_symbol`` (if ``add_bos``), then ``eos_symbol`` (if
    ``add_eos``), then the sorted ``text_tokens``. Beginning- and
    end-of-sequence symbols can be added to every sequence.

    Two entry points exist:

    * :meth:`index` maps *symbolic* tokens to vocabulary indices via
      ``token2idx``.
    * :meth:`__call__` treats its inputs as already-numeric token IDs (ints or
      numeric strings) and passes them through unchanged; only the special
      <bos>/<eos>/<pad> symbols are replaced by their vocabulary indices.

    Example:
        >>> token_collater = TextTokenCollater(text_tokens)
        >>> tokens_batch, tokens_lens = token_collater(texts)

    Returns (both entry points):
        tokens_batch: int64 tensor of shape (B, L)
            B: batch dimension, number of input sentences
            L: length of the longest sentence, including <bos>/<eos>
        tokens_lens: int32 tensor (``torch.IntTensor``) of shape (B,)
            Length of each sentence after adding <bos> and <eos>
            but before padding.
    """

    def __init__(
        self,
        text_tokens: List[str],
        add_eos: bool = True,
        add_bos: bool = True,
        pad_symbol: str = "<pad>",
        bos_symbol: str = "<bos>",
        eos_symbol: str = "<eos>",
    ):
        self.pad_symbol = pad_symbol
        self.add_eos = add_eos
        self.add_bos = add_bos
        self.bos_symbol = bos_symbol
        self.eos_symbol = eos_symbol

        unique_tokens = (
            [pad_symbol]
            + ([bos_symbol] if add_bos else [])
            + ([eos_symbol] if add_eos else [])
            + sorted(text_tokens)
        )
        # NOTE(review): duplicates in ``text_tokens`` are not removed; the
        # dict comprehension keeps the *last* index for a repeated token.
        self.token2idx = {token: idx for idx, token in enumerate(unique_tokens)}
        self.idx2token = list(unique_tokens)

    def index(
        self, tokens_list: List[str]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Map symbolic token sequences to padded vocabulary indices.

        Args:
            tokens_list: batch of token sequences; every token must be a key
                of ``self.token2idx``.

        Returns:
            ``(tokens, tokens_lens)``: an int64 tensor of shape (B, L) padded
            with the index of ``pad_symbol``, and an int32 tensor of the
            per-sequence lengths including <bos>/<eos>. An empty batch yields
            shapes (0, 0) and (0,).

        Raises:
            AssertionError: if a sequence contains a token outside the
                vocabulary (a ``KeyError`` is raised instead under ``-O``).
        """
        prefix = [self.bos_symbol] if self.add_bos else []
        suffix = [self.eos_symbol] if self.add_eos else []

        seqs, seq_lens = [], []
        for tokens in tokens_list:
            unknown = [s for s in tokens if s not in self.token2idx]
            assert not unknown, f"Unknown tokens: {unknown}"
            seq = prefix + list(tokens) + suffix
            seqs.append(seq)
            seq_lens.append(len(seq))

        # default=0 keeps an empty batch from crashing max().
        max_len = max(seq_lens, default=0)
        pad_idx = self.token2idx[self.pad_symbol]
        rows = [
            [self.token2idx[token] for token in seq]
            + [pad_idx] * (max_len - len(seq))
            for seq in seqs
        ]
        # reshape guarantees a 2-D result even when the batch is empty.
        tokens = torch.from_numpy(
            np.array(rows, dtype=np.int64).reshape(len(rows), max_len)
        )
        tokens_lens = torch.IntTensor(seq_lens)
        return tokens, tokens_lens

    def __call__(self, texts: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Collate sequences of already-numeric token IDs into a batch.

        Content tokens are *not* looked up in ``token2idx``; each element must
        be convertible by numpy to int64 (an int or a numeric string). The
        <bos>/<eos>/<pad> symbols are emitted as their vocabulary indices.

        Args:
            texts: batch of token-ID sequences.

        Returns:
            ``(tokens_batch, tokens_lens)``: an int64 tensor of shape (B, L)
            and an int32 tensor of per-sequence lengths including <bos>/<eos>
            but excluding padding.
        """
        tokens_seqs = [list(text) for text in texts]
        max_len = max((len(seq) for seq in tokens_seqs), default=0)

        # The special symbols are strings and cannot be stored in an int64
        # array (this previously raised ValueError whenever <bos>, <eos> or
        # padding was emitted), so substitute their vocabulary indices.
        # NOTE(review): these indices come from this collater's own
        # vocabulary and may overlap IDs produced by the upstream
        # tokenizer -- confirm against the caller.
        prefix = [self.token2idx[self.bos_symbol]] if self.add_bos else []
        suffix = [self.token2idx[self.eos_symbol]] if self.add_eos else []
        pad_idx = self.token2idx[self.pad_symbol]

        rows = [
            prefix + seq + suffix + [pad_idx] * (max_len - len(seq))
            for seq in tokens_seqs
        ]
        width = len(rows[0]) if rows else 0
        tokens_batch = torch.from_numpy(
            np.array(rows, dtype=np.int64).reshape(len(rows), width)
        )
        tokens_lens = torch.IntTensor(
            [len(seq) + len(prefix) + len(suffix) for seq in tokens_seqs]
        )
        return tokens_batch, tokens_lens
def get_text_token_collater() -> TextTokenCollater:
    """Return a TextTokenCollater with vocabulary ``['0']`` and no <bos>/<eos>.

    The resulting vocabulary is just ``<pad>`` followed by ``'0'``.
    NOTE(review): presumably intended for callers that feed already-numeric
    token IDs to ``__call__`` -- confirm against call sites.
    """
    vocabulary = ['0']
    return TextTokenCollater(vocabulary, add_bos=False, add_eos=False)