# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
#
# ## Citations
#
# ```bibtex
# @inproceedings{yao2021wenet,
#   title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
#   author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
#   booktitle={Proc. Interspeech},
#   year={2021},
#   address={Brno, Czech Republic },
#   organization={IEEE}
# }
#
# @article{zhang2022wenet,
#   title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
#   author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
#   journal={arXiv preprint arXiv:2203.15455},
#   year={2022}
# }
# ```

"""Composable data-pipeline stages for speech recognition training.

Every public function takes an iterable (usually the output of the previous
stage) and lazily yields transformed items, so the stages can be chained:
open -> parse -> filter/resample/perturb -> features -> tokenize ->
augment -> shuffle/sort -> batch -> padding.
"""

import json
import logging
import random
import re
import tarfile
from subprocess import PIPE, Popen
from urllib.parse import urlparse
from urllib.request import url2pathname

import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
from torch.nn.utils.rnn import pad_sequence

AUDIO_FORMAT_SETS = {"flac", "mp3", "m4a", "ogg", "opus", "wav", "wma"}


def url_opener(data):
    """Open each sample's url or local file and attach a readable stream.

    Inplace operation: the incoming dict is updated and yielded again.

    Local paths (no scheme) and ``file://`` urls are opened directly in
    binary mode. Any other scheme is fetched by spawning ``wget`` and reading
    its stdout; the ``Popen`` object is stored under ``process`` so a later
    stage can reap it. Samples that cannot be opened are logged and skipped.

    Args:
        data(Iterable[{src}]): samples whose ``src`` is a url or local file

    Returns:
        Iterable[{src, stream}] (plus ``process`` for remote urls)
    """
    for sample in data:
        assert "src" in sample
        url = sample["src"]
        try:
            pr = urlparse(url)
            if pr.scheme == "":
                stream = open(url, "rb")
            elif pr.scheme == "file":
                # "file:///a/b.tar" must be converted to "/a/b.tar";
                # passing the url itself to open() always fails.
                stream = open(url2pathname(pr.path), "rb")
            else:
                # Network file (e.g. HTTP/HTTPS). Pass an argv list instead
                # of a shell string so characters such as '&', ';' or spaces
                # in the url are neither interpreted nor injected by a shell.
                process = Popen(["wget", "-q", "-O", "-", url], stdout=PIPE)
                sample.update(process=process)
                stream = process.stdout
        except Exception as ex:
            logging.warning("Failed to open %s: %s", url, ex)
            continue
        sample.update(stream=stream)
        yield sample


def tar_file_and_group(data):
    """Expand a stream of open tar files into a stream of tar file contents.

    Members that share the same prefix (the name up to the last '.') are
    grouped into one example, keyed by that prefix. ``txt`` members are
    decoded as UTF-8 text, members whose extension is in
    ``AUDIO_FORMAT_SETS`` are decoded with ``torchaudio.load``, and any other
    member is stored as raw bytes under its extension. A group in which any
    member fails to parse is logged and dropped.

    Args:
        data: Iterable[{src, stream}]

    Returns:
        Iterable[{key, wav, txt, sample_rate}]
    """
    for sample in data:
        assert "stream" in sample
        # "r|*": sequential (non-seekable) read with transparent compression,
        # so this also works on a wget pipe.
        stream = tarfile.open(fileobj=sample["stream"], mode="r|*")
        try:
            prev_prefix = None
            example = {}
            valid = True
            for tarinfo in stream:
                name = tarinfo.name
                pos = name.rfind(".")
                assert pos > 0
                prefix, postfix = name[:pos], name[pos + 1 :]
                # A new prefix closes the previous group.
                if prev_prefix is not None and prefix != prev_prefix:
                    example["key"] = prev_prefix
                    if valid:
                        yield example
                    example = {}
                    valid = True
                with stream.extractfile(tarinfo) as file_obj:
                    try:
                        if postfix == "txt":
                            example["txt"] = file_obj.read().decode("utf8").strip()
                        elif postfix in AUDIO_FORMAT_SETS:
                            waveform, sample_rate = torchaudio.load(file_obj)
                            example["wav"] = waveform
                            example["sample_rate"] = sample_rate
                        else:
                            example[postfix] = file_obj.read()
                    except Exception as ex:
                        valid = False
                        logging.warning("error to parse %s: %s", name, ex)
                prev_prefix = prefix
            # Flush the last group, applying the same validity check as above.
            if prev_prefix is not None and valid:
                example["key"] = prev_prefix
                yield example
        finally:
            # Runs on normal exhaustion and when the consumer closes the
            # generator early, so the tar stream and any wget child are
            # always released.
            stream.close()
            if "process" in sample:
                sample["process"].communicate()
            sample["stream"].close()


def parse_raw(data):
    """Parse key/wav/txt from json lines and load the referenced audio.

    If the json object carries ``start``/``end`` (seconds), only that segment
    of the file is loaded. Samples whose audio cannot be read are logged and
    skipped.

    Args:
        data: Iterable[{src}], where ``src`` is a json line with key/wav/txt

    Returns:
        Iterable[{key, wav, txt, sample_rate}]
    """
    for sample in data:
        assert "src" in sample
        json_line = sample["src"]
        obj = json.loads(json_line)
        assert "key" in obj
        assert "wav" in obj
        assert "txt" in obj
        key = obj["key"]
        wav_file = obj["wav"]
        txt = obj["txt"]
        try:
            if "start" in obj:
                assert "end" in obj
                # Public torchaudio API instead of the deprecated
                # torchaudio.backend.sox_io_backend module.
                sample_rate = torchaudio.info(wav_file).sample_rate
                start_frame = int(obj["start"] * sample_rate)
                end_frame = int(obj["end"] * sample_rate)
                waveform, _ = torchaudio.load(
                    wav_file,
                    frame_offset=start_frame,
                    num_frames=end_frame - start_frame,
                )
            else:
                waveform, sample_rate = torchaudio.load(wav_file)
        except Exception as ex:
            logging.warning("Failed to read %s: %s", wav_file, ex)
            continue
        yield dict(key=key, txt=txt, wav=waveform, sample_rate=sample_rate)


def filter(
    data,
    max_length=10240,
    min_length=10,
    token_max_length=200,
    token_min_length=1,
    min_output_input_ratio=0.0005,
    max_output_input_ratio=1,
):
    """Filter samples according to audio length and label length.

    Inplace operation.

    Args:
        data: Iterable[{key, wav, label, sample_rate}]
        max_length: drop utterance which is greater than max_length (10ms)
        min_length: drop utterance which is less than min_length (10ms)
        token_max_length: drop utterance which is greater than
            token_max_length, especially when using char units for
            English modeling
        token_min_length: drop utterance which is less than token_min_length
        min_output_input_ratio: minimal ratio of
            token_length / feats_length (10ms)
        max_output_input_ratio: maximum ratio of
            token_length / feats_length (10ms)

    Returns:
        Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        assert "sample_rate" in sample
        assert "wav" in sample
        assert "label" in sample
        # sample['wav'] is a torch.Tensor whose dim 1 counts audio samples;
        # express its duration in 10 ms frames (100 frames per second).
        num_frames = sample["wav"].size(1) / sample["sample_rate"] * 100
        num_tokens = len(sample["label"])
        if num_frames < min_length:
            continue
        if num_frames > max_length:
            continue
        if num_tokens < token_min_length:
            continue
        if num_tokens > token_max_length:
            continue
        # Ratio checks are skipped for zero-length audio to avoid dividing
        # by zero (only reachable when min_length <= 0).
        if num_frames != 0:
            if num_tokens / num_frames < min_output_input_ratio:
                continue
            if num_tokens / num_frames > max_output_input_ratio:
                continue
        yield sample


def resample(data, resample_rate=16000):
    """Resample audio to ``resample_rate``.

    Inplace operation. Samples already at the target rate pass through
    unchanged.

    Args:
        data: Iterable[{key, wav, label, sample_rate}]
        resample_rate: target resample rate

    Returns:
        Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        assert "sample_rate" in sample
        assert "wav" in sample
        sample_rate = sample["sample_rate"]
        waveform = sample["wav"]
        if sample_rate != resample_rate:
            sample["sample_rate"] = resample_rate
            sample["wav"] = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=resample_rate
            )(waveform)
        yield sample


def speed_perturb(data, speeds=None):
    """Apply random speed perturbation to the audio.

    Inplace operation. For each sample one factor is drawn uniformly from
    ``speeds``; a factor of 1.0 leaves the sample untouched.

    Args:
        data: Iterable[{key, wav, label, sample_rate}]
        speeds(List[float]): candidate speed factors,
            defaults to [0.9, 1.0, 1.1]

    Returns:
        Iterable[{key, wav, label, sample_rate}]
    """
    if speeds is None:
        speeds = [0.9, 1.0, 1.1]
    for sample in data:
        assert "sample_rate" in sample
        assert "wav" in sample
        sample_rate = sample["sample_rate"]
        waveform = sample["wav"]
        speed = random.choice(speeds)
        if speed != 1.0:
            # "speed" changes tempo; the trailing "rate" effect resamples
            # back to the original sample_rate so downstream stages see an
            # unchanged rate.
            wav, _ = torchaudio.sox_effects.apply_effects_tensor(
                waveform,
                sample_rate,
                [["speed", str(speed)], ["rate", str(sample_rate)]],
            )
            sample["wav"] = wav
        yield sample


def compute_fbank(data, num_mel_bins=23, frame_length=25, frame_shift=10, dither=0.0):
    """Extract Kaldi-compatible log-mel filterbank features.

    Args:
        data: Iterable[{key, wav, label, sample_rate}]
        num_mel_bins: number of mel bins
        frame_length: frame length in milliseconds
        frame_shift: frame shift in milliseconds
        dither: dithering constant (0.0 disables dithering)

    Returns:
        Iterable[{key, feat, label}]
    """
    for sample in data:
        assert "sample_rate" in sample
        assert "wav" in sample
        assert "key" in sample
        assert "label" in sample
        sample_rate = sample["sample_rate"]
        # NOTE(review): presumably wav holds floats normalized to [-1, 1]
        # (torchaudio.load's default); scaling by 2**15 maps them to the
        # int16 range Kaldi features are usually computed on — confirm.
        waveform = sample["wav"] * (1 << 15)
        # Only keep key, feat, label
        mat = kaldi.fbank(
            waveform,
            num_mel_bins=num_mel_bins,
            frame_length=frame_length,
            frame_shift=frame_shift,
            dither=dither,
            energy_floor=0.0,
            sample_frequency=sample_rate,
        )
        yield dict(key=sample["key"], label=sample["label"], feat=mat)


def compute_mfcc(
    data,
    num_mel_bins=23,
    frame_length=25,
    frame_shift=10,
    dither=0.0,
    num_ceps=40,
    high_freq=0.0,
    low_freq=20.0,
):
    """Extract Kaldi-compatible MFCC features.

    Args:
        data: Iterable[{key, wav, label, sample_rate}]
        num_mel_bins: number of mel bins
        frame_length: frame length in milliseconds
        frame_shift: frame shift in milliseconds
        dither: dithering constant (0.0 disables dithering)
        num_ceps: number of cepstra to keep
        high_freq: high cutoff passed to kaldi.mfcc
        low_freq: low cutoff passed to kaldi.mfcc

    Returns:
        Iterable[{key, feat, label}]
    """
    for sample in data:
        assert "sample_rate" in sample
        assert "wav" in sample
        assert "key" in sample
        assert "label" in sample
        sample_rate = sample["sample_rate"]
        # Same int16-range scaling as compute_fbank (assumes normalized
        # float input — confirm against the loader).
        waveform = sample["wav"] * (1 << 15)
        # Only keep key, feat, label
        mat = kaldi.mfcc(
            waveform,
            num_mel_bins=num_mel_bins,
            frame_length=frame_length,
            frame_shift=frame_shift,
            dither=dither,
            num_ceps=num_ceps,
            high_freq=high_freq,
            low_freq=low_freq,
            sample_frequency=sample_rate,
        )
        yield dict(key=sample["key"], label=sample["label"], feat=mat)


def __tokenize_by_bpe_model(sp, txt):
    """Tokenize mixed CJK / non-CJK text.

    Each CJK character becomes its own token; every run of non-CJK text is
    upper-cased and encoded with the sentencepiece processor ``sp``.
    """
    tokens = []
    # CJK (China Japan Korea) unicode range is [U+4E00, U+9FFF], ref:
    # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
    pattern = re.compile(r"([\u4e00-\u9fff])")
    # Example:
    #   txt   = "你好 ITS'S OKAY 的"
    #   chars = ["你", "好", " ITS'S OKAY ", "的"]
    chars = pattern.split(txt.upper())
    mix_chars = [w for w in chars if len(w.strip()) > 0]
    for ch_or_w in mix_chars:
        if pattern.fullmatch(ch_or_w) is not None:
            # A single CJK character (e.g. "你"): keep it as is.
            tokens.append(ch_or_w)
        else:
            # Non-CJK text (e.g. " IT'S OKAY "): encode with the BPE model.
            tokens.extend(sp.encode_as_pieces(ch_or_w))
    return tokens


def tokenize(
    data, symbol_table, bpe_model=None, non_lang_syms=None, split_with_space=False
):
    """Convert text to tokens (chars or BPE pieces) and their ids.

    Inplace operation.

    Args:
        data: Iterable[{key, wav, txt, sample_rate}]
        symbol_table: mapping from token to integer id. Tokens missing from
            it map to the id of "<unk>" when present, else are dropped.
        bpe_model: path to a sentencepiece model; when None, text is split
            into characters (or space-separated words)
        non_lang_syms: tokens such as "[NOISE]" or "<laugh>" that are kept
            verbatim instead of being split further
        split_with_space: split on " " instead of per character
            (only used when bpe_model is None)

    Returns:
        Iterable[{key, wav, txt, tokens, label, sample_rate}]
    """
    if non_lang_syms is not None:
        # Matches [..], <..> and {..} groups so they survive as single parts.
        non_lang_syms_pattern = re.compile(r"(\[[^\[\]]+\]|<[^<>]+>|{[^{}]+})")
    else:
        non_lang_syms = {}
        non_lang_syms_pattern = None

    if bpe_model is not None:
        import sentencepiece as spm

        sp = spm.SentencePieceProcessor()
        sp.load(bpe_model)
    else:
        sp = None

    # Id for out-of-vocabulary tokens. "<unk>" is the conventional key; the
    # "" lookup is kept as a fallback for tables that used an empty key.
    unk_id = symbol_table.get("<unk>", symbol_table.get(""))

    for sample in data:
        assert "txt" in sample
        txt = sample["txt"].strip()
        if non_lang_syms_pattern is not None:
            parts = non_lang_syms_pattern.split(txt.upper())
            parts = [w for w in parts if len(w.strip()) > 0]
        else:
            parts = [txt]

        tokens = []
        for part in parts:
            if part in non_lang_syms:
                tokens.append(part)
            elif bpe_model is not None:
                tokens.extend(__tokenize_by_bpe_model(sp, part))
            else:
                if split_with_space:
                    part = part.split(" ")
                for ch in part:
                    # Represent spaces with the sentencepiece-style marker.
                    if ch == " ":
                        ch = "▁"
                    tokens.append(ch)

        label = []
        for ch in tokens:
            if ch in symbol_table:
                label.append(symbol_table[ch])
            elif unk_id is not None:
                label.append(unk_id)
        sample["tokens"] = tokens
        sample["label"] = label
        yield sample


def spec_aug(data, num_t_mask=2, num_f_mask=2, max_t=50, max_f=10, max_w=80):
    """Apply SpecAugment-style time and frequency masking.

    Inplace operation: ``feat`` is replaced by a masked copy; the original
    tensor is not modified.

    Args:
        data: Iterable[{key, feat, label}]
        num_t_mask: number of time masks to apply
        num_f_mask: number of frequency masks to apply
        max_t: max width of a time mask
        max_f: max width of a frequency mask
        max_w: max width of time warp (accepted but currently unused;
            no time warping is performed)

    Returns:
        Iterable[{key, feat, label}]
    """
    for sample in data:
        assert "feat" in sample
        x = sample["feat"]
        assert isinstance(x, torch.Tensor)
        y = x.clone().detach()
        max_frames = y.size(0)
        max_freq = y.size(1)
        # time mask: zero a random span of rows (frames)
        for _ in range(num_t_mask):
            start = random.randint(0, max_frames - 1)
            length = random.randint(1, max_t)
            end = min(max_frames, start + length)
            y[start:end, :] = 0
        # freq mask: zero a random span of columns (bins)
        for _ in range(num_f_mask):
            start = random.randint(0, max_freq - 1)
            length = random.randint(1, max_f)
            end = min(max_freq, start + length)
            y[:, start:end] = 0
        sample["feat"] = y
        yield sample


def spec_sub(data, max_t=20, num_t_sub=3):
    """Apply spec substitution.

    Inplace operation. Random frame spans are overwritten with an earlier
    span of the original features.
    ref: U2++, section 3.2.3 [https://arxiv.org/abs/2106.05642]

    Args:
        data: Iterable[{key, feat, label}]
        max_t: max width of a time substitution
        num_t_sub: number of time substitutions to apply

    Returns:
        Iterable[{key, feat, label}]
    """
    for sample in data:
        assert "feat" in sample
        x = sample["feat"]
        assert isinstance(x, torch.Tensor)
        y = x.clone().detach()
        max_frames = y.size(0)
        for _ in range(num_t_sub):
            start = random.randint(0, max_frames - 1)
            length = random.randint(1, max_t)
            end = min(max_frames, start + length)
            # Only substitute from an earlier, randomly chosen position;
            # pos <= start keeps the source slice in range.
            pos = random.randint(0, start)
            y[start:end, :] = x[start - pos : end - pos, :]
        sample["feat"] = y
        yield sample


def spec_trim(data, max_t=20):
    """Trim a random number of trailing frames.

    Inplace operation. The trim is skipped when it would remove half of the
    frames or more.
    ref: TrimTail [https://arxiv.org/abs/2211.00522]

    Args:
        data: Iterable[{key, feat, label}]
        max_t: max number of frames to trim

    Returns:
        Iterable[{key, feat, label}]
    """
    for sample in data:
        assert "feat" in sample
        x = sample["feat"]
        assert isinstance(x, torch.Tensor)
        max_frames = x.size(0)
        length = random.randint(1, max_t)
        if length < max_frames / 2:
            y = x.clone().detach()[: max_frames - length]
            sample["feat"] = y
        yield sample


def shuffle(data, shuffle_size=10000):
    """Locally shuffle the data using a buffer of ``shuffle_size`` items.

    Args:
        data: Iterable[{key, feat, label}]
        shuffle_size: buffer size for shuffle

    Returns:
        Iterable[{key, feat, label}]
    """
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= shuffle_size:
            random.shuffle(buf)
            yield from buf
            buf = []
    # The samples left over
    random.shuffle(buf)
    yield from buf


def sort(data, sort_size=500):
    """Sort the data by feature length within buffers of ``sort_size``.

    Sort is used after shuffle and before batch, so utterances with similar
    lengths are grouped into a batch; ``sort_size`` should be less than
    ``shuffle_size``.

    Args:
        data: Iterable[{key, feat, label}]
        sort_size: buffer size for sort

    Returns:
        Iterable[{key, feat, label}]
    """
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= sort_size:
            buf.sort(key=lambda x: x["feat"].size(0))
            yield from buf
            buf = []
    # The samples left over
    buf.sort(key=lambda x: x["feat"].size(0))
    yield from buf


def static_batch(data, batch_size=16):
    """Group the data into batches of ``batch_size`` samples.

    The final batch may be smaller; no empty batch is produced.

    Args:
        data: Iterable[{key, feat, label}]
        batch_size: batch size

    Returns:
        Iterable[List[{key, feat, label}]]
    """
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= batch_size:
            yield buf
            buf = []
    if len(buf) > 0:
        yield buf


def dynamic_batch(data, max_frames_in_batch=12000):
    """Group the data so each batch's padded frame count stays within a budget.

    The padded size of a batch is (longest feature length) * (batch size).
    A sample that alone exceeds ``max_frames_in_batch`` is emitted as a
    single-sample batch; no empty batch is ever produced.

    Args:
        data: Iterable[{key, feat, label}]
        max_frames_in_batch: max padded frames in one batch

    Returns:
        Iterable[List[{key, feat, label}]]
    """
    buf = []
    longest_frames = 0
    for sample in data:
        assert "feat" in sample
        assert isinstance(sample["feat"], torch.Tensor)
        new_sample_frames = sample["feat"].size(0)
        longest_frames = max(longest_frames, new_sample_frames)
        frames_after_padding = longest_frames * (len(buf) + 1)
        # Check `buf` so an oversized first sample does not flush an empty
        # batch (which would break pad_sequence downstream).
        if buf and frames_after_padding > max_frames_in_batch:
            yield buf
            buf = [sample]
            longest_frames = new_sample_frames
        else:
            buf.append(sample)
    if len(buf) > 0:
        yield buf


def batch(data, batch_type="static", batch_size=16, max_frames_in_batch=12000):
    """Dispatch to static_batch or dynamic_batch.

    Raises:
        ValueError: if ``batch_type`` is neither "static" nor "dynamic".
    """
    if batch_type == "static":
        return static_batch(data, batch_size)
    if batch_type == "dynamic":
        return dynamic_batch(data, max_frames_in_batch)
    logging.fatal("Unsupported batch type {}".format(batch_type))
    # Previously this fell through and returned None, which only failed later
    # and far from the cause when the caller iterated it.
    raise ValueError("Unsupported batch type {}".format(batch_type))


def padding(data):
    """Pad each batch into training tensors, ordered longest-first.

    Features are zero-padded and labels are padded with -1.

    Args:
        data: Iterable[List[{key, feat, label}]]

    Returns:
        Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
    """
    for sample in data:
        assert isinstance(sample, list)
        feats_length = torch.tensor(
            [x["feat"].size(0) for x in sample], dtype=torch.int32
        )
        order = torch.argsort(feats_length, descending=True)
        feats_lengths = feats_length[order]
        sorted_feats = [sample[i]["feat"] for i in order]
        sorted_keys = [sample[i]["key"] for i in order]
        sorted_labels = [
            torch.tensor(sample[i]["label"], dtype=torch.int64) for i in order
        ]
        label_lengths = torch.tensor(
            [x.size(0) for x in sorted_labels], dtype=torch.int32
        )
        padded_feats = pad_sequence(sorted_feats, batch_first=True, padding_value=0)
        padding_labels = pad_sequence(sorted_labels, batch_first=True, padding_value=-1)
        yield (sorted_keys, padded_feats, padding_labels, feats_lengths, label_lengths)