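"""
Estimates token-count bucket bins for Lhotse dynamic bucketing from a sample of the input dataset.

Example invocation (the script name, paths, and values below are illustrative only):

    python estimate_token_bins.py input_cfg.yaml \
        --tokenizer tokenizer.model \
        --buckets 30 --sub-buckets 5 \
        --num_examples 10000

The script prints a ``bucket_duration_bins`` list that can be pasted into the dataloader config.
"""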
import argparse |
|
import ast |
|
import math |
|
from functools import partial |
|
from itertools import islice |
|
from pathlib import Path |
|
from typing import Callable, Iterable |
|
|
|
import numpy as np |
|
import pandas as pd |
|
from lhotse.cut import Cut |
|
from omegaconf import OmegaConf |
|
|
|
import nemo.collections.speechlm2.data.salm_dataset  # noqa: F401  (not referenced directly below; likely imported for its registration side effects)
|
from nemo.collections.asr.data.audio_to_text_lhotse import TokenizerWrapper |
|
from nemo.collections.common.data.lhotse.cutset import read_cutset_from_config |
|
from nemo.collections.common.data.lhotse.dataloader import LhotseDataLoadingConfig, tokenize, tokenize_with_prompt |
|
from nemo.collections.common.data.lhotse.sampling import ( |
|
MultimodalFixedBucketBatchSizeConstraint2D, |
|
MultimodalSamplingConstraint, |
|
TokenCountFilter, |
|
TokenPerTokenFilter, |
|
) |
|
from nemo.collections.common.prompts.formatter import PromptFormatter |
|
from nemo.collections.common.tokenizers import AggregateTokenizer, AutoTokenizer, SentencePieceTokenizer |
|
|
|
|
|
def parse_args(): |
|
parser = argparse.ArgumentParser( |
|
description="Estimate token bins for Lhotse dynamic bucketing using a sample of the input dataset. " |
|
"The dataset is read either from one or more manifest files and supports data weighting. " |
|
"Unlike estimate_duration_bins.py, this script is intended for text data only. " |
|
"It supports 2D bucketing. ", |
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter, |
|
) |
|
parser.add_argument( |
|
"input", |
|
help='Path to a data input configuration YAML file. ' |
|
'This is the only type of input specification supported for text data.', |
|
) |
|
parser.add_argument( |
|
"-t", |
|
"--tokenizer", |
|
nargs="+", |
|
required=True, |
|
help="Path to one or more SPE tokenizers. More than one means we'll use AggregateTokenizer and --langs argument must also be used. When provided, we'll estimate a 2D distribution for input and output sequence lengths.", |
|
) |
|
parser.add_argument( |
|
"-a", "--langs", nargs="+", help="Language names for each of AggregateTokenizer sub-tokenizers." |
|
) |
|
parser.add_argument( |
|
"-b", |
|
"--buckets", |
|
type=int, |
|
default=30, |
|
help="The desired number of buckets (dim0 => covers input sequence length / audio duration).", |
|
) |
|
parser.add_argument( |
|
"-s", |
|
"--sub-buckets", |
|
type=int, |
|
default=None, |
|
help="The desired number of sub-buckets (dim1 => covers output sequence length / num_tokens). " |
|
"If not provided, we'll only perform 1D bucketing. ", |
|
) |
|
parser.add_argument( |
|
"-n", |
|
"--num_examples", |
|
type=int, |
|
default=-1, |
|
help="The number of examples (utterances) to estimate the bins. -1 means use all data " |
|
"(be careful: it could be iterated over infinitely).", |
|
) |
|
parser.add_argument( |
|
"-l", |
|
"--min_tokens", |
|
type=float, |
|
default=-float("inf"), |
|
help="If specified, we'll filter out examples with less tokens than this number.", |
|
) |
|
parser.add_argument( |
|
"-u", |
|
"--max_tokens", |
|
type=float, |
|
default=float("inf"), |
|
help="If specified, we'll filter out examples with more tokens than this number.", |
|
) |
|
parser.add_argument( |
|
"--max_tpt", |
|
type=float, |
|
default=float("inf"), |
|
help="If specified, we'll filter out examples with more output tokens per input token than this. ", |
|
) |
|
parser.add_argument( |
|
"-q", "--quiet", type=bool, default=False, help="When specified, only print the estimated duration bins." |
|
) |
|
parser.add_argument( |
|
"-f", |
|
"--prompt-format", |
|
type=str, |
|
help="When specified, we'll use a prompt formatter in addition to the tokenizer for the purpose of estimating token count bins. " |
|
"This is useful for accurate 2D bucket estimation with models such as EncDecMultiTaskModel (Canary-1B), " |
|
"or any model where the label sequence consists of a user prompt and a model's response.", |
|
) |
|
parser.add_argument( |
|
"-p", |
|
"--prompt", |
|
type=str, |
|
help="Prompt slots provided as a Python list of dicts. It is used together with --prompt-format option." |
|
"For example, with Canary-1B you may use: [{'role':'user','slots':{'source_lang':'en','target_lang':'en','task':'asr','pnc':'yes'}]", |
|
) |
|
parser.add_argument( |
|
"-m", |
|
"--measure-total-length", |
|
        action="store_true",
        help="When specified, we'll measure the total length (context + answer, i.e. input_ids) instead of the context-only length. "
        "Total length is more suitable for decoder-only models, while context-only length is more suitable for encoder-decoder models.",
|
) |
|
return parser.parse_args() |
|
|
|
|
|
def estimate_token_buckets( |
|
cuts: Iterable[Cut], |
|
num_buckets: int, |
|
num_subbuckets: int | None, |
|
quiet: bool, |
|
) -> list[tuple[float, float]]: |
|
""" |
|
This function is based on lhotse.dataset.sampling.dynamic_bucketing.estimate_duration_buckets. |
|
It extends it to a 2D bucketing case. |
|
""" |
|
assert num_buckets > 1 |
|
is_2d = num_subbuckets is not None |
|
|
|
    if is_2d:
        # The dummy bucket/batch-size values are irrelevant here: the constraint object is only
        # used below to measure per-example (input, output) token counts.
        constraint = MultimodalFixedBucketBatchSizeConstraint2D([(0.0, 0.0)], [0], measure_total_length=False)
    else:
        constraint = MultimodalSamplingConstraint(measure_total_length=True)
|
|
|
|
|
num_input_tokens = [] |
|
num_output_tokens = [] |
|
for c in cuts: |
|
ans = constraint.measure_length(c) |
|
if is_2d: |
|
itoks, otoks = ans |
|
num_input_tokens.append(itoks) |
|
num_output_tokens.append(otoks) |
|
else: |
|
num_input_tokens.append(ans) |
|
num_input_tokens = np.array(num_input_tokens, dtype=np.int32) |
|
if is_2d: |
|
num_output_tokens = np.array(num_output_tokens, dtype=np.int32) |
|
joint = np.rec.fromarrays([num_input_tokens, num_output_tokens]) |
|
joint.sort() |
|
num_input_tokens = joint.f0 |
|
num_output_tokens = joint.f1 |
|
else: |
|
num_input_tokens.sort() |
|
|
|
|
|
|
|
size_per_bucket = num_input_tokens.sum() / num_buckets |
|
|
|
if not quiet: |
|
print("Duration distribution:") |
|
print(pd.Series(num_input_tokens).describe(percentiles=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99])) |
|
max_input_tokens = num_input_tokens[-1] |
|
|
|
if is_2d: |
|
tpt = num_output_tokens / num_input_tokens |
|
if not quiet: |
|
print("Output tokens per input token distribution:") |
|
print(pd.Series(tpt).describe(percentiles=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99])) |
|
max_tpt = tpt.max() |
|
del tpt |
|
|
|
bins = [] |
|
bin_indexes = [0] |
|
tot = 0.0 |
|
|
|
    def _estimate_output_token_buckets(max_bucket_duration):
        # Apply the same bin-creation logic to the second dimension (output token count) as to the
        # first dimension (input token count): each sub-bucket should hold roughly the same total
        # number of output tokens.
        nonlocal bins
        num_tokens_bucket = num_output_tokens[bin_indexes[-1] : binidx]
        num_tokens_bucket.sort()
        tokens_per_subbucket = num_tokens_bucket.sum() / num_subbuckets
        tot_toks = 0
        # Whenever we exceed tokens_per_subbucket, emit a new (max input tokens, max output tokens) bin.
        for num_toks in num_tokens_bucket:
            if tot_toks > tokens_per_subbucket:
                bins.append((max_bucket_duration, num_toks))
                tot_toks = 0
            tot_toks += num_toks
        # The last sub-bucket's output threshold is derived from the max output-per-input token ratio.
        bins.append((size, math.ceil(size * max_tpt)))
|
|
|
|
|
    # Iterate over the sorted input lengths; whenever the accumulated token count exceeds
    # size_per_bucket, close the current bucket and start a new one.
    for binidx, size in enumerate(num_input_tokens):
        if tot > size_per_bucket:
            if is_2d:
                _estimate_output_token_buckets(max_bucket_duration=size)
                bin_indexes.append(binidx)
            else:
                bins.append(size)
            tot = 0.0
        tot += size
|
|
|
|
|
    # Add the final bin(s) covering the global maximum input length.
    if is_2d:
        _estimate_output_token_buckets(max_bucket_duration=max_input_tokens)
    else:
        bins.append(max_input_tokens)
|
|
|
return bins |
|
|
|
|
|
def load_tokenizer(paths: list[str], langs: list[str] | None = None) -> TokenizerWrapper:
|
if len(paths) == 1: |
|
(p,) = paths |
|
if Path(p).exists(): |
|
tok = SentencePieceTokenizer(p) |
|
        else:
            # Not an existing local file path: assume a HuggingFace pretrained tokenizer name.
            tok = AutoTokenizer(p, use_fast=True)
|
else: |
|
        assert langs is not None and len(paths) == len(langs), (
            f"Cannot create AggregateTokenizer; each tokenizer must have a language assigned via the --langs option "
            f"(we got --tokenizer={paths} and --langs={langs})"
        )
|
tok = AggregateTokenizer({lang: SentencePieceTokenizer(p) for lang, p in zip(langs, paths)}) |
|
return TokenizerWrapper(tok) |
|
|
|
|
|
def apply_tokenizer(cut, tokenizer=None, prompt: PromptFormatter = None): |
|
if prompt is not None: |
|
cut = tokenize_with_prompt(cut, tokenizer, prompt) |
|
elif tokenizer is not None: |
|
cut = tokenize(cut, tokenizer) |
|
return cut |
|
|
|
|
|
class RejectionsCounter:
    """Wraps a filter predicate and counts how many examples it rejects."""

    def __init__(self, predicate: Callable, message: str):
|
self.predicate = predicate |
|
self.message = message |
|
self.total = 0 |
|
self.rejected = 0 |
|
|
|
def __call__(self, example) -> bool: |
|
ans = self.predicate(example) |
|
self.total += 1 |
|
if not ans: |
|
self.rejected += 1 |
|
return ans |
|
|
|
def print_report(self) -> None: |
|
if self.rejected: |
|
print(f"{self.message} | Rejected {self.rejected}/{self.total} examples.") |
|
|
|
|
|
def main(): |
|
args = parse_args() |
|
|
|
if not args.quiet: |
|
pd.set_option('display.float_format', lambda x: '%.2f' % x) |
|
|
|
tokenizer = None |
|
prompt = None |
|
if args.tokenizer is not None: |
|
tokenizer = load_tokenizer(args.tokenizer, args.langs) |
|
if args.prompt_format is not None: |
|
prompt_defaults = None |
|
if args.prompt is not None: |
|
prompt_defaults = ast.literal_eval(args.prompt) |
|
prompt = PromptFormatter.resolve(args.prompt_format)(tokenizer._tokenizer, defaults=prompt_defaults) |
|
|
|
assert args.input.endswith(".yaml") |
|
config = OmegaConf.merge( |
|
OmegaConf.structured(LhotseDataLoadingConfig), |
|
OmegaConf.from_dotlist([f"input_cfg={args.input}", "force_finite=True", "metadata_only=True"]), |
|
) |
|
cuts, _ = read_cutset_from_config(config) |
|
cuts = cuts.map(partial(apply_tokenizer, tokenizer=tokenizer, prompt=prompt), apply_fn=None) |
|
if hasattr(cuts, "prefetch"): |
|
cuts = cuts.prefetch() |
|
token_filter = RejectionsCounter( |
|
TokenCountFilter(args.min_tokens, args.max_tokens, args.measure_total_length), "Token count filtering" |
|
) |
|
cuts = cuts.filter(token_filter) |
|
tpt_filter = RejectionsCounter(TokenPerTokenFilter(-1, args.max_tpt), "Output tokens per input token filtering") |
|
cuts = cuts.filter(tpt_filter) |
|
if (N := args.num_examples) > 0: |
|
cuts = islice(cuts, N) |
|
|
|
token_bins = estimate_token_buckets( |
|
cuts, |
|
num_buckets=args.buckets, |
|
num_subbuckets=args.sub_buckets, |
|
quiet=args.quiet, |
|
) |
|
if args.sub_buckets is not None: |
|
token_bins = "[" + ','.join(f"[{b:d},{sb:d}]" for b, sb in token_bins) + "]" |
|
else: |
|
token_bins = "[" + ','.join(f"{b:d}" for b in token_bins) + "]" |
|
if args.quiet: |
|
print(token_bins) |
|
return |
|
token_filter.print_report() |
|
tpt_filter.print_report() |
|
print("Use the following options in your config:") |
|
print(f"\tnum_buckets={args.buckets}") |
|
print(f"\tbucket_duration_bins={token_bins}") |
|
|
|
|
|
if __name__ == "__main__": |
|
main() |
|
|