| |
|
|
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
|
|
| import logging |
| import json |
| import random |
| import re |
| import tarfile |
| from subprocess import PIPE, Popen |
| from urllib.parse import urlparse |
|
|
| import torch |
| import torchaudio |
| import torchaudio.compliance.kaldi as kaldi |
| from torch.nn.utils.rnn import pad_sequence |
|
|
# File extensions (without the leading dot) that tar_file_and_group decodes
# as audio via torchaudio rather than keeping as raw bytes.
AUDIO_FORMAT_SETS = {"flac", "mp3", "m4a", "ogg", "opus", "wav", "wma"}
|
|
|
|
def url_opener(data):
    """Give url or local file, return file descriptor.
    Inplace operation.

    Local paths (no scheme) and ``file://`` URLs are opened directly in
    binary mode. Any other scheme is fetched by spawning ``wget`` and
    streaming its stdout; the ``Popen`` handle is stored under ``process``
    so a later stage can reap it.

    Args:
        data(Iterable[{src}]): samples whose ``src`` is a url or local file

    Returns:
        Iterable[{src, stream}] (plus ``process`` for remote urls). Samples
        that cannot be opened are logged and skipped.
    """
    for sample in data:
        assert "src" in sample
        url = sample["src"]
        try:
            pr = urlparse(url)
            if pr.scheme == "" or pr.scheme == "file":
                # A file:// URL must be opened via its path component; the
                # raw "file://..." string is not a valid filesystem path.
                path = pr.path if pr.scheme == "file" else url
                stream = open(path, "rb")
            else:
                # Pass argv as a list (no shell) so characters inside the URL
                # can never be interpreted as shell syntax.
                process = Popen(["wget", "-q", "-O", "-", url], stdout=PIPE)
                sample.update(process=process)
                stream = process.stdout
            sample.update(stream=stream)
        except Exception as ex:
            # Best effort: skip unreadable sources but keep the cause.
            logging.warning("Failed to open {}: {}".format(url, ex))
            continue
        yield sample
|
|
|
|
def tar_file_and_group(data):
    """Expand a stream of open tar files into a stream of tar file contents,
    grouping consecutive members that share the same prefix.

    A member's prefix is its name up to the last '.'; the part after it
    selects how the payload is decoded:
      * ``txt``                      -> ``example["txt"]`` (utf8, stripped)
      * extension in AUDIO_FORMAT_SETS -> ``example["wav"]`` and
        ``example["sample_rate"]`` via torchaudio.load
      * anything else                -> raw bytes under ``example[postfix]``
    A group in which any member fails to decode is dropped, including the
    final group of the archive. Non-regular members (e.g. directories) are
    skipped.

    After each archive is exhausted the tar stream is closed, the download
    process (if the sample has one) is reaped, and the source stream closed.

    Args:
        data: Iterable[{src, stream}] (optionally with ``process``)

    Returns:
        Iterable[{key, wav, txt, sample_rate}]
    """
    for sample in data:
        assert "stream" in sample
        stream = tarfile.open(fileobj=sample["stream"], mode="r|*")
        prev_prefix = None
        example = {}
        valid = True
        for tarinfo in stream:
            # Directories/links carry no payload: extractfile() would return
            # None and the `with` below would fail.
            if not tarinfo.isfile():
                continue
            name = tarinfo.name
            pos = name.rfind(".")
            assert pos > 0
            prefix, postfix = name[:pos], name[pos + 1 :]
            if prev_prefix is not None and prefix != prev_prefix:
                example["key"] = prev_prefix
                if valid:
                    yield example
                example = {}
                valid = True
            with stream.extractfile(tarinfo) as file_obj:
                try:
                    if postfix == "txt":
                        example["txt"] = file_obj.read().decode("utf8").strip()
                    elif postfix in AUDIO_FORMAT_SETS:
                        waveform, sample_rate = torchaudio.load(file_obj)
                        example["wav"] = waveform
                        example["sample_rate"] = sample_rate
                    else:
                        example[postfix] = file_obj.read()
                except Exception as ex:
                    valid = False
                    logging.warning("error to parse {}: {}".format(name, ex))
            prev_prefix = prefix
        # Flush the last group with the same validity rule as the others.
        if prev_prefix is not None and valid:
            example["key"] = prev_prefix
            yield example
        stream.close()
        if "process" in sample:
            sample["process"].communicate()
        sample["stream"].close()
|
|
|
|
def parse_raw(data):
    """Parse key/wav/txt from a json line and load the referenced audio.

    Each ``src`` is one JSON object with at least ``key``, ``wav`` (audio
    path) and ``txt``. If ``start`` is present, ``end`` must be too; both
    are multiplied by the file's sample rate (i.e. they are in seconds) and
    only that segment is loaded.

    Args:
        data: Iterable[{src}], src is a json line that has key/wav/txt

    Returns:
        Iterable[{key, wav, txt, sample_rate}]. Samples whose audio cannot
        be read are logged and skipped.
    """
    for sample in data:
        assert "src" in sample
        obj = json.loads(sample["src"])
        assert "key" in obj
        assert "wav" in obj
        assert "txt" in obj
        key = obj["key"]
        wav_file = obj["wav"]
        txt = obj["txt"]
        try:
            if "start" in obj:
                assert "end" in obj
                # Use the top-level torchaudio API (as the rest of this module
                # does) rather than the deprecated torchaudio.backend modules.
                sample_rate = torchaudio.info(wav_file).sample_rate
                start_frame = int(obj["start"] * sample_rate)
                end_frame = int(obj["end"] * sample_rate)
                waveform, _ = torchaudio.load(
                    wav_file,
                    frame_offset=start_frame,
                    num_frames=end_frame - start_frame,
                )
            else:
                waveform, sample_rate = torchaudio.load(wav_file)
        except Exception as ex:
            # Best effort: skip unreadable audio but keep the cause.
            logging.warning("Failed to read {}: {}".format(wav_file, ex))
            continue
        yield dict(key=key, txt=txt, wav=waveform, sample_rate=sample_rate)
|
|
|
|
def filter(
    data,
    max_length=10240,
    min_length=10,
    token_max_length=200,
    token_min_length=1,
    min_output_input_ratio=0.0005,
    max_output_input_ratio=1,
):
    """Drop samples whose audio length, label length or label/audio ratio
    falls outside the configured bounds.
    Inplace operation.

    Audio length is measured in 10 ms frames: samples / sample_rate * 100.

    Args:
        data: Iterable[{key, wav, label, sample_rate}]
        max_length: drop utterances longer than this many 10 ms frames
        min_length: drop utterances shorter than this many 10 ms frames
        token_max_length: drop utterances with more tokens than this,
            especially relevant when using char units for English
        token_min_length: drop utterances with fewer tokens than this
        min_output_input_ratio: minimal ratio of
            token_length / feats_length(10ms)
        max_output_input_ratio: maximal ratio of
            token_length / feats_length(10ms)

    Returns:
        Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        assert "sample_rate" in sample
        assert "wav" in sample
        assert "label" in sample

        frames = sample["wav"].size(1) / sample["sample_rate"] * 100
        if frames < min_length or frames > max_length:
            continue

        n_tokens = len(sample["label"])
        if n_tokens < token_min_length or n_tokens > token_max_length:
            continue

        # The ratio is undefined for empty audio, so skip that check then.
        if frames != 0:
            ratio = n_tokens / frames
            if ratio < min_output_input_ratio or ratio > max_output_input_ratio:
                continue

        yield sample
|
|
|
|
def resample(data, resample_rate=16000):
    """Resample data.
    Inplace operation.

    Samples already at ``resample_rate`` pass through untouched. For all
    others ``wav`` is resampled and ``sample_rate`` is set to
    ``resample_rate``.

    Args:
        data: Iterable[{key, wav, label, sample_rate}]
        resample_rate: target resample rate

    Returns:
        Iterable[{key, wav, label, sample_rate}]
    """
    # One Resample transform per source rate: constructing it builds the
    # resampling kernel, so reuse it instead of rebuilding it per sample.
    resamplers = {}
    for sample in data:
        assert "sample_rate" in sample
        assert "wav" in sample
        sample_rate = sample["sample_rate"]
        if sample_rate != resample_rate:
            if sample_rate not in resamplers:
                resamplers[sample_rate] = torchaudio.transforms.Resample(
                    orig_freq=sample_rate, new_freq=resample_rate
                )
            sample["sample_rate"] = resample_rate
            sample["wav"] = resamplers[sample_rate](sample["wav"])
        yield sample
|
|
|
|
def speed_perturb(data, speeds=None):
    """Randomly change the playback speed of each waveform.
    Inplace operation.

    For every sample one factor is drawn uniformly from ``speeds``. A factor
    of 1.0 leaves the sample untouched; any other factor runs the sox
    ``speed`` effect followed by ``rate`` back to the original sample rate,
    so ``sample_rate`` itself never changes.

    Args:
        data: Iterable[{key, wav, label, sample_rate}]
        speeds(List[float]): candidate speed factors, [0.9, 1.0, 1.1] if None

    Returns:
        Iterable[{key, wav, label, sample_rate}]
    """
    choices = [0.9, 1.0, 1.1] if speeds is None else speeds
    for sample in data:
        assert "sample_rate" in sample
        assert "wav" in sample
        rate = sample["sample_rate"]
        audio = sample["wav"]
        factor = random.choice(choices)
        if factor != 1.0:
            effects = [["speed", str(factor)], ["rate", str(rate)]]
            perturbed, _ = torchaudio.sox_effects.apply_effects_tensor(
                audio, rate, effects
            )
            sample["wav"] = perturbed
        yield sample
|
|
|
|
def compute_fbank(data, num_mel_bins=23, frame_length=25, frame_shift=10, dither=0.0):
    """Extract filterbank features with torchaudio.compliance.kaldi.fbank.

    The waveform is scaled by 2**15 before extraction (presumably the input
    is float audio in [-1, 1] and this maps it to the int16 range; confirm).
    Only key, label and the new ``feat`` are kept in the emitted dict.

    Args:
        data: Iterable[{key, wav, label, sample_rate}]
        num_mel_bins: number of mel bins
        frame_length: frame length in ms
        frame_shift: frame shift in ms
        dither: dithering constant (0.0 disables it)

    Returns:
        Iterable[{key, feat, label}]
    """
    for sample in data:
        for field in ("sample_rate", "wav", "key", "label"):
            assert field in sample
        scaled = sample["wav"] * (1 << 15)
        feat = kaldi.fbank(
            scaled,
            num_mel_bins=num_mel_bins,
            frame_length=frame_length,
            frame_shift=frame_shift,
            dither=dither,
            energy_floor=0.0,
            sample_frequency=sample["sample_rate"],
        )
        yield {"key": sample["key"], "label": sample["label"], "feat": feat}
|
|
|
|
def compute_mfcc(
    data,
    num_mel_bins=23,
    frame_length=25,
    frame_shift=10,
    dither=0.0,
    num_ceps=40,
    high_freq=0.0,
    low_freq=20.0,
):
    """Extract MFCC features with torchaudio.compliance.kaldi.mfcc.

    The waveform is scaled by 2**15 before extraction (presumably the input
    is float audio in [-1, 1] and this maps it to the int16 range; confirm).
    Only key, label and the new ``feat`` are kept in the emitted dict.

    NOTE(review): torchaudio's kaldi.mfcc is believed to require
    num_ceps <= num_mel_bins, so the defaults here (40 vs 23) presumably
    must be overridden by callers. Confirm.

    Args:
        data: Iterable[{key, wav, label, sample_rate}]
        num_mel_bins: number of mel bins
        frame_length: frame length in ms
        frame_shift: frame shift in ms
        dither: dithering constant (0.0 disables it)
        num_ceps: number of cepstral coefficients to keep
        high_freq: high cutoff frequency passed through to kaldi.mfcc
        low_freq: low cutoff frequency passed through to kaldi.mfcc

    Returns:
        Iterable[{key, feat, label}]
    """
    for sample in data:
        for field in ("sample_rate", "wav", "key", "label"):
            assert field in sample
        scaled = sample["wav"] * (1 << 15)
        feat = kaldi.mfcc(
            scaled,
            num_mel_bins=num_mel_bins,
            frame_length=frame_length,
            frame_shift=frame_shift,
            dither=dither,
            num_ceps=num_ceps,
            high_freq=high_freq,
            low_freq=low_freq,
            sample_frequency=sample["sample_rate"],
        )
        yield {"key": sample["key"], "label": sample["label"], "feat": feat}
|
|
|
|
def __tokenize_by_bpe_model(sp, txt):
    """Tokenize mixed CJK / non-CJK text.

    The text is upper-cased and split around every CJK unified ideograph
    (U+4E00..U+9FFF). Each ideograph becomes its own token; every other
    non-blank run is handed to ``sp.encode_as_pieces`` and its pieces are
    appended in order.
    """
    cjk_char = re.compile(r"([\u4e00-\u9fff])")
    result = []
    for piece in cjk_char.split(txt.upper()):
        if not piece.strip():
            continue
        if cjk_char.fullmatch(piece):
            result.append(piece)
        else:
            result.extend(sp.encode_as_pieces(piece))
    return result
|
|
|
|
def tokenize(
    data, symbol_table, bpe_model=None, non_lang_syms=None, split_with_space=False
):
    """Turn each transcript into tokens (chars, words or BPE pieces) and ids.
    Inplace operation.

    When ``non_lang_syms`` is given, the upper-cased transcript is split
    around bracketed symbols ([..], <..>, {..}); segments found in
    ``non_lang_syms`` are kept as single tokens. Other segments are
    tokenized with the sentencepiece model at ``bpe_model`` if provided,
    otherwise into words (``split_with_space``) or characters, with a
    literal space mapped to "▁". Tokens missing from ``symbol_table`` map to
    ``<unk>`` when present and are otherwise left out of ``label``.

    Args:
        data: Iterable[{key, wav, txt, sample_rate}]

    Returns:
        Iterable[{key, wav, txt, tokens, label, sample_rate}]
    """
    if non_lang_syms is None:
        non_lang_syms = {}
        splitter = None
    else:
        splitter = re.compile(r"(\[[^\[\]]+\]|<[^<>]+>|{[^{}]+})")

    sp = None
    if bpe_model is not None:
        import sentencepiece as spm

        sp = spm.SentencePieceProcessor()
        sp.load(bpe_model)

    for sample in data:
        assert "txt" in sample
        text = sample["txt"].strip()
        if splitter is None:
            segments = [text]
        else:
            segments = [s for s in splitter.split(text.upper()) if s.strip()]

        tokens = []
        for segment in segments:
            if segment in non_lang_syms:
                tokens.append(segment)
            elif bpe_model is not None:
                tokens.extend(__tokenize_by_bpe_model(sp, segment))
            else:
                units = segment.split(" ") if split_with_space else segment
                tokens.extend("▁" if unit == " " else unit for unit in units)

        label = []
        for token in tokens:
            if token in symbol_table:
                label.append(symbol_table[token])
            elif "<unk>" in symbol_table:
                label.append(symbol_table["<unk>"])

        sample["tokens"] = tokens
        sample["label"] = label
        yield sample
|
|
|
|
def spec_aug(data, num_t_mask=2, num_f_mask=2, max_t=50, max_f=10, max_w=80):
    """Apply random time and frequency masking to the features.
    Inplace operation: ``feat`` is replaced by a masked copy, and the
    original tensor is not modified.

    Args:
        data: Iterable[{key, feat, label}]
        num_t_mask: number of time masks to apply
        num_f_mask: number of frequency masks to apply
        max_t: max width of a time mask
        max_f: max width of a frequency mask
        max_w: max width of time warp (unused; no warping is performed)

    Returns
        Iterable[{key, feat, label}]
    """
    for sample in data:
        assert "feat" in sample
        feat = sample["feat"]
        assert isinstance(feat, torch.Tensor)
        masked = feat.clone().detach()
        n_frames, n_bins = masked.size(0), masked.size(1)
        # Zero whole frames (rows).
        for _ in range(num_t_mask):
            lo = random.randint(0, n_frames - 1)
            hi = min(n_frames, lo + random.randint(1, max_t))
            masked[lo:hi, :] = 0
        # Zero whole frequency bins (columns).
        for _ in range(num_f_mask):
            lo = random.randint(0, n_bins - 1)
            hi = min(n_bins, lo + random.randint(1, max_f))
            masked[:, lo:hi] = 0
        sample["feat"] = masked
        yield sample
|
|
|
|
def spec_sub(data, max_t=20, num_t_sub=3):
    """Replace random time spans with earlier spans of the same features.
    Inplace operation: ``feat`` is replaced by a modified copy.
    ref: U2++, section 3.2.3 [https://arxiv.org/abs/2106.05642]

    Args:
        data: Iterable[{key, feat, label}]
        max_t: max width of a substituted span
        num_t_sub: number of substitutions to apply

    Returns
        Iterable[{key, feat, label}]
    """
    for sample in data:
        assert "feat" in sample
        src = sample["feat"]
        assert isinstance(src, torch.Tensor)
        out = src.clone().detach()
        total = out.size(0)
        for _ in range(num_t_sub):
            begin = random.randint(0, total - 1)
            stop = min(total, begin + random.randint(1, max_t))
            shift = random.randint(0, begin)
            # The replacement is read from the untouched original features,
            # shifted back by `shift` frames (shift == 0 is a no-op).
            out[begin:stop, :] = src[begin - shift : stop - shift, :]
        sample["feat"] = out
        yield sample
|
|
|
|
def spec_trim(data, max_t=20):
    """Randomly drop trailing frames. Inplace operation.
    ref: TrimTail [https://arxiv.org/abs/2211.00522]

    A trim length in [1, max_t] is drawn per sample. It is applied only when
    it is less than half the number of frames; otherwise ``feat`` is kept.

    Args:
        data: Iterable[{key, feat, label}]
        max_t: max number of trailing frames to trim

    Returns
        Iterable[{key, feat, label}]
    """
    for sample in data:
        assert "feat" in sample
        feat = sample["feat"]
        assert isinstance(feat, torch.Tensor)
        num_frames = feat.size(0)
        cut = random.randint(1, max_t)
        if cut < num_frames / 2:
            sample["feat"] = feat.clone().detach()[: num_frames - cut]
        yield sample
|
|
|
|
def shuffle(data, shuffle_size=10000):
    """Shuffle the data within consecutive windows of ``shuffle_size``.

    Items are buffered until the window is full, shuffled, emitted, and the
    buffer restarts. The final partial window is shuffled and emitted too.

    Args:
        data: Iterable[{key, feat, label}]
        shuffle_size: buffer size for shuffle

    Returns:
        Iterable[{key, feat, label}]
    """
    pool = []
    for item in data:
        pool.append(item)
        if len(pool) < shuffle_size:
            continue
        random.shuffle(pool)
        for shuffled in pool:
            yield shuffled
        pool = []

    random.shuffle(pool)
    for shuffled in pool:
        yield shuffled
|
|
|
|
def sort(data, sort_size=500):
    """Sort the data by feature length within windows of ``sort_size``.

    Meant to run after shuffle and before batch, so that utterances of
    similar length end up in the same batch; ``sort_size`` should be less
    than ``shuffle_size``. Sorting is stable and ascending by
    ``feat.size(0)``.

    Args:
        data: Iterable[{key, feat, label}]
        sort_size: buffer size for sort

    Returns:
        Iterable[{key, feat, label}]
    """

    def _num_frames(item):
        return item["feat"].size(0)

    pending = []
    for item in data:
        pending.append(item)
        if len(pending) >= sort_size:
            for ordered in sorted(pending, key=_num_frames):
                yield ordered
            pending = []

    for ordered in sorted(pending, key=_num_frames):
        yield ordered
|
|
|
|
def static_batch(data, batch_size=16):
    """Group consecutive samples into lists of ``batch_size``.

    The last batch may be shorter; an empty input yields nothing.

    Args:
        data: Iterable[{key, feat, label}]
        batch_size: batch size

    Returns:
        Iterable[List[{key, feat, label}]]
    """
    chunk = []
    for item in data:
        chunk.append(item)
        if len(chunk) < batch_size:
            continue
        yield chunk
        chunk = []
    if chunk:
        yield chunk
|
|
|
|
def dynamic_batch(data, max_frames_in_batch=12000):
    """Dynamic batch the data until the padded frame count of a batch would
    exceed ``max_frames_in_batch``.

    The padded size of a batch is (longest ``feat`` in it) * (batch length).
    A single sample that alone exceeds the limit is emitted as a batch of
    one; an empty batch is never emitted.

    Args:
        data: Iterable[{key, feat, label}]
        max_frames_in_batch: max padded frames in one batch

    Returns:
        Iterable[List[{key, feat, label}]]
    """
    buf = []
    longest_frames = 0
    for sample in data:
        assert "feat" in sample
        assert isinstance(sample["feat"], torch.Tensor)
        new_sample_frames = sample["feat"].size(0)
        longest_frames = max(longest_frames, new_sample_frames)
        frames_after_padding = longest_frames * (len(buf) + 1)
        # Only flush a non-empty buffer; otherwise an oversized first sample
        # would produce an empty batch that downstream padding cannot handle.
        if buf and frames_after_padding > max_frames_in_batch:
            yield buf
            buf = [sample]
            longest_frames = new_sample_frames
        else:
            buf.append(sample)
    if buf:
        yield buf
|
|
|
|
def batch(data, batch_type="static", batch_size=16, max_frames_in_batch=12000):
    """Wrapper for static/dynamic batch.

    Args:
        data: Iterable[{key, feat, label}]
        batch_type: "static" (fixed ``batch_size``) or "dynamic"
            (bounded by ``max_frames_in_batch``)
        batch_size: batch size for static batching
        max_frames_in_batch: padded-frame budget for dynamic batching

    Returns:
        Iterable[List[{key, feat, label}]]

    Raises:
        ValueError: if ``batch_type`` is not "static" or "dynamic".
    """
    if batch_type == "static":
        return static_batch(data, batch_size)
    if batch_type == "dynamic":
        return dynamic_batch(data, max_frames_in_batch)
    logging.fatal("Unsupported batch type {}".format(batch_type))
    # logging.fatal only logs; fail here instead of returning None and
    # letting the pipeline crash later with an unrelated error.
    raise ValueError("Unsupported batch type {}".format(batch_type))
|
|
|
|
def padding(data):
    """Collate each batch into padded tensors, longest utterance first.

    Features are zero-padded and labels padded with -1 (both batch-first).
    Lengths are int32 and labels int64.

    Args:
        data: Iterable[List[{key, feat, label}]]

    Returns:
        Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
    """
    for batch_items in data:
        assert isinstance(batch_items, list)
        lengths = torch.tensor(
            [item["feat"].size(0) for item in batch_items], dtype=torch.int32
        )
        order = torch.argsort(lengths, descending=True)
        ordered = [batch_items[i] for i in order]

        keys = [item["key"] for item in ordered]
        feats = [item["feat"] for item in ordered]
        labels = [torch.tensor(item["label"], dtype=torch.int64) for item in ordered]

        feats_lengths = torch.tensor([f.size(0) for f in feats], dtype=torch.int32)
        label_lengths = torch.tensor([lab.size(0) for lab in labels], dtype=torch.int32)

        padded_feats = pad_sequence(feats, batch_first=True, padding_value=0)
        padded_labels = pad_sequence(labels, batch_first=True, padding_value=-1)

        yield (keys, padded_feats, padded_labels, feats_lengths, label_lengths)
|
|