# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import librosa
import logging
import json
import random
import tarfile
from subprocess import PIPE, Popen
from urllib.parse import urlparse

import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from wenet.text.base_tokenizer import BaseTokenizer

torchaudio.utils.sox_utils.set_buffer_size(16500)

AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'])


def url_opener(data):
    """ Give url or local file, return file descriptor
        Inplace operation.

        Args:
            data(Iterable[{src}]): samples whose 'src' is a url or a local
                file path

        Returns:
            Iterable[{src, stream}]; network sources additionally carry
            'process' (the downloading subprocess)
    """
    for sample in data:
        assert 'src' in sample
        # TODO(Binbin Zhang): support HTTP
        url = sample['src']
        try:
            pr = urlparse(url)
            # local file
            if pr.scheme == '' or pr.scheme == 'file':
                # For 'file://...' urls open the parsed path; a plain path
                # is opened verbatim so '#'/'?' in file names are kept.
                stream = open(pr.path if pr.scheme == 'file' else url, 'rb')
            # network file, such as HTTP(HDFS/OSS/S3)/HTTPS/SCP
            else:
                # argv list without a shell: characters such as '&', ';'
                # or spaces in the url are never interpreted by a shell.
                process = Popen(['wget', '-q', '-O', '-', url], stdout=PIPE)
                sample.update(process=process)
                stream = process.stdout
            sample.update(stream=stream)
            yield sample
        except Exception as ex:
            logging.warning('Failed to open {}'.format(url))


def tar_file_and_group(data):
    """ Expand a stream of open tar files into a stream of tar file contents.
        Files sharing the same prefix (name up to the last '.') are grouped
        into one example; a group is dropped if any member fails to parse.

        Args:
            data: Iterable[{src, stream}]

        Returns:
            Iterable[{key, wav, txt, sample_rate}]
    """
    for sample in data:
        assert 'stream' in sample
        stream = None
        try:
            stream = tarfile.open(fileobj=sample['stream'], mode="r:*")
            prev_prefix = None
            example = {}
            valid = True
            for tarinfo in stream:
                name = tarinfo.name
                pos = name.rfind('.')
                assert pos > 0
                prefix, postfix = name[:pos], name[pos + 1:]
                # A new prefix means the previous group is complete.
                if prev_prefix is not None and prefix != prev_prefix:
                    example['key'] = prev_prefix
                    if valid:
                        yield example
                    example = {}
                    valid = True
                with stream.extractfile(tarinfo) as file_obj:
                    try:
                        if postfix == 'txt':
                            example['txt'] = file_obj.read().decode(
                                'utf8').strip()
                        elif postfix in AUDIO_FORMAT_SETS:
                            waveform, sample_rate = torchaudio.load(file_obj)
                            example['wav'] = waveform
                            example['sample_rate'] = sample_rate
                        else:
                            example[postfix] = file_obj.read()
                    except Exception as ex:
                        valid = False
                        logging.warning('error to parse {}'.format(name))
                prev_prefix = prefix
            # Flush the last group, applying the same validity rule as above.
            if prev_prefix is not None and valid:
                example['key'] = prev_prefix
                yield example
        except Exception as ex:
            logging.warning(
                'In tar_file_and_group: {} when processing {}'.format(
                    ex, sample['src']))
        finally:
            if stream is not None:
                stream.close()
            if 'process' in sample:
                sample['process'].communicate()
            sample['stream'].close()


def parse_raw(data):
    """ Parse key/wav/txt from json line

        Args:
            data: Iterable[str], str is a json line has key/wav/txt
                (optionally 'start'/'end' in seconds to read a segment)

        Returns:
            Iterable[{key, wav, txt, sample_rate}] plus any other json fields
    """
    for sample in data:
        assert 'src' in sample
        json_line = sample['src']
        obj = json.loads(json_line)
        assert 'key' in obj
        assert 'wav' in obj
        assert 'txt' in obj
        key = obj['key']
        wav_file = obj['wav']
        txt = obj['txt']
        try:
            if 'start' in obj:
                assert 'end' in obj
                sample_rate = torchaudio.info(wav_file).sample_rate
                start_frame = int(obj['start'] * sample_rate)
                end_frame = int(obj['end'] * sample_rate)
                waveform, _ = torchaudio.load(filepath=wav_file,
                                              num_frames=end_frame -
                                              start_frame,
                                              frame_offset=start_frame)
            else:
                waveform, sample_rate = torchaudio.load(wav_file)
            example = copy.deepcopy(obj)  # copy and keep all the fields
            example['wav'] = waveform  # overwrite wav
            example['sample_rate'] = sample_rate
            yield example
        except Exception as ex:
            logging.warning('Failed to read {}'.format(wav_file))


def parse_speaker(data, speaker_table_path):
    """ Map each sample's 'speaker' name to an integer id.

        Args:
            data: Iterable[{speaker, ...}]
            speaker_table_path: text file with "<name> <int id>" per line

        Returns:
            Iterable of the same samples with 'speaker' replaced by its id
            (0 for names missing from the table)
    """
    speaker_dict = {}
    with open(speaker_table_path, 'r', encoding='utf8') as fin:
        for line in fin:
            arr = line.strip().split()
            speaker_dict[arr[0]] = int(arr[1])
    for sample in data:
        assert 'speaker' in sample
        speaker = sample['speaker']
        sample['speaker'] = speaker_dict.get(speaker, 0)
        yield sample


def filter(data,
           max_length=10240,
           min_length=10,
           token_max_length=200,
           token_min_length=1,
           min_output_input_ratio=0.0005,
           max_output_input_ratio=1):
    """ Filter sample according to feature and label length
        Inplace operation.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]
            max_length: drop utterance which is greater than max_length(10ms)
            min_length: drop utterance which is less than min_length(10ms)
            token_max_length: drop utterance which is greater than
                token_max_length, especially when use char unit for
                english modeling
            token_min_length: drop utterance which is
                less than token_min_length
            min_output_input_ratio: minimal ratio of
                token_length / feats_length(10ms)
            max_output_input_ratio: maximum ratio of
                token_length / feats_length(10ms)

        Returns:
            Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        # Skip incomplete samples explicitly (asserts vanish under -O).
        if not all(k in sample for k in ('sample_rate', 'wav', 'label')):
            continue
        # sample['wav'] is torch.Tensor, we have 100 frames every second
        num_frames = sample['wav'].size(1) / sample['sample_rate'] * 100
        if num_frames < min_length:
            continue
        if num_frames > max_length:
            continue
        if len(sample['label']) < token_min_length:
            continue
        if len(sample['label']) > token_max_length:
            continue
        if num_frames != 0:
            if len(sample['label']) / num_frames < min_output_input_ratio:
                continue
            if len(sample['label']) / num_frames > max_output_input_ratio:
                continue
        yield sample


def resample(data, resample_rate=16000):
    """ Resample data.
        Inplace operation.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]
            resample_rate: target resample rate

        Returns:
            Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        assert 'sample_rate' in sample
        assert 'wav' in sample
        sample_rate = sample['sample_rate']
        waveform = sample['wav']
        if sample_rate != resample_rate:
            sample['sample_rate'] = resample_rate
            sample['wav'] = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=resample_rate)(waveform)
        yield sample


def speed_perturb(data, speeds=None):
    """ Apply speed perturb to the data.
        Inplace operation.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]
            speeds(List[float]): optional speed, one is drawn per sample
                (default [0.9, 1.0, 1.1])

        Returns:
            Iterable[{key, wav, label, sample_rate}]
    """
    if speeds is None:
        speeds = [0.9, 1.0, 1.1]
    for sample in data:
        assert 'sample_rate' in sample
        assert 'wav' in sample
        sample_rate = sample['sample_rate']
        waveform = sample['wav']
        speed = random.choice(speeds)
        if speed != 1.0:
            # 'rate' restores the original sample rate after 'speed'.
            wav, _ = torchaudio.sox_effects.apply_effects_tensor(
                waveform, sample_rate,
                [['speed', str(speed)], ['rate', str(sample_rate)]])
            sample['wav'] = wav

        yield sample


def compute_fbank(data,
                  num_mel_bins=23,
                  frame_length=25,
                  frame_shift=10,
                  dither=0.0):
    """ Extract fbank

        Args:
            data: Iterable[{key, wav, label, sample_rate}]

        Returns:
            Iterable[{key, wav, label, sample_rate, feat}]
    """
    for sample in data:
        assert 'sample_rate' in sample
        assert 'wav' in sample
        assert 'key' in sample
        assert 'label' in sample
        sample_rate = sample['sample_rate']
        waveform = sample['wav']
        # Scale by 2**15, presumably from [-1, 1] floats to the int16 range.
        waveform = waveform * (1 << 15)
        # All existing fields are kept; 'feat' is added.
        mat = kaldi.fbank(waveform,
                          num_mel_bins=num_mel_bins,
                          frame_length=frame_length,
                          frame_shift=frame_shift,
                          dither=dither,
                          energy_floor=0.0,
                          sample_frequency=sample_rate)
        sample['feat'] = mat
        yield sample


def compute_mfcc(data,
                 num_mel_bins=23,
                 frame_length=25,
                 frame_shift=10,
                 dither=0.0,
                 num_ceps=40,
                 high_freq=0.0,
                 low_freq=20.0):
    """ Extract mfcc

        Args:
            data: Iterable[{key, wav, label, sample_rate}]

        Returns:
            Iterable[{key, wav, label, sample_rate, feat}]
    """
    for sample in data:
        assert 'sample_rate' in sample
        assert 'wav' in sample
        assert 'key' in sample
        assert 'label' in sample
        sample_rate = sample['sample_rate']
        waveform = sample['wav']
        # Scale by 2**15, presumably from [-1, 1] floats to the int16 range.
        waveform = waveform * (1 << 15)
        # All existing fields are kept; 'feat' is added.
        mat = kaldi.mfcc(waveform,
                         num_mel_bins=num_mel_bins,
                         frame_length=frame_length,
                         frame_shift=frame_shift,
                         dither=dither,
                         num_ceps=num_ceps,
                         high_freq=high_freq,
                         low_freq=low_freq,
                         sample_frequency=sample_rate)
        sample['feat'] = mat
        yield sample


def compute_log_mel_spectrogram(data,
                                n_fft=400,
                                hop_length=160,
                                num_mel_bins=80,
                                padding=0):
    """ Extract log mel spectrogram, modified from openai-whisper, see:
        - https://github.com/openai/whisper/blob/main/whisper/audio.py
        - https://github.com/wenet-e2e/wenet/pull/2141#issuecomment-1811765040

        Args:
            data: Iterable[{key, wav, label, sample_rate}]

        Returns:
            Iterable[{key, wav, label, sample_rate, feat}], feat is
            (num_frames, num_mel_bins)
    """
    for sample in data:
        assert 'sample_rate' in sample
        assert 'wav' in sample
        assert 'key' in sample
        assert 'label' in sample
        sample_rate = sample['sample_rate']
        waveform = sample['wav'].squeeze(0)  # (channel=1, sample) -> (sample,)
        if padding > 0:
            waveform = F.pad(waveform, (0, padding))
        window = torch.hann_window(n_fft)
        stft = torch.stft(waveform,
                          n_fft,
                          hop_length,
                          window=window,
                          return_complex=True)
        # Drop the last frame, as whisper does.
        magnitudes = stft[..., :-1].abs()**2

        filters = torch.from_numpy(
            librosa.filters.mel(sr=sample_rate,
                                n_fft=n_fft,
                                n_mels=num_mel_bins))
        mel_spec = filters @ magnitudes

        # NOTE(xcsong): https://github.com/openai/whisper/discussions/269
        log_spec = torch.clamp(mel_spec, min=1e-10).log10()
        log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
        log_spec = (log_spec + 4.0) / 4.0
        sample['feat'] = log_spec.transpose(0, 1)
        yield sample


def tokenize(data, tokenizer: BaseTokenizer, global_prompt_dict=None):
    """ Decode text to chars or BPE
        Inplace operation

        Args:
            data: Iterable[{key, wav, txt, sample_rate}]
            tokenizer: tokenizer whose ``tokenize(text)`` returns a
                ``(tokens, ids)`` pair and which exposes ``eod_id``
            global_prompt_dict: mapping from task name to a list of prompt
                strings; required when samples carry a 'task' field

        Returns:
            Iterable[{key, wav, txt, tokens, label, sample_rate}]; samples
            with a 'task' also get 'prompt' (list of prompt token ids)

        Raises:
            ValueError: a sample has a 'task' but no global_prompt_dict
                was given
    """
    for sample in data:
        assert 'txt' in sample
        if 'task' in sample:
            task_name = sample['task']
            # NOTE(review): as written, `"" in task_name` is always True and
            # `.replace("", "")` returns the text unchanged, so this is a
            # no-op; a special-token literal was presumably lost here.
            # Confirm the intended token.
            if "" in task_name:
                txt = sample['txt'].replace("", "")
            else:
                txt = sample['txt']
        else:
            txt = sample['txt']
        tokens, label = tokenizer.tokenize(txt)
        sample['tokens'] = tokens
        sample['label'] = label + [tokenizer.eod_id]
        if 'task' in sample:
            if global_prompt_dict is None:
                raise ValueError(
                    'global_prompt_dict is required for samples with a task')
            prompt = random.choice(global_prompt_dict[sample['task']])
            # tokenize() returns (tokens, ids); only the integer ids can be
            # turned into the int64 prompt tensor built in `padding`.
            _, prompt_ids = tokenizer.tokenize(prompt)
            sample['prompt'] = prompt_ids
        yield sample


def spec_aug(data, num_t_mask=2, num_f_mask=2, max_t=50, max_f=10, max_w=80):
    """ Do spec augmentation
        Inplace operation

        Args:
            data: Iterable[{key, feat, label}]
            num_t_mask: number of time mask to apply
            num_f_mask: number of freq mask to apply
            max_t: max width of time mask
            max_f: max width of freq mask
            max_w: max width of time warp (currently unused)

        Returns
            Iterable[{key, feat, label}]
    """
    for sample in data:
        assert 'feat' in sample
        x = sample['feat']
        assert isinstance(x, torch.Tensor)
        y = x.clone().detach()
        max_frames = y.size(0)
        max_freq = y.size(1)
        # time mask
        for i in range(num_t_mask):
            start = random.randint(0, max_frames - 1)
            length = random.randint(1, max_t)
            end = min(max_frames, start + length)
            y[start:end, :] = 0
        # freq mask
        for i in range(num_f_mask):
            start = random.randint(0, max_freq - 1)
            length = random.randint(1, max_f)
            end = min(max_freq, start + length)
            y[:, start:end] = 0
        sample['feat'] = y
        yield sample


def spec_sub(data, max_t=20, num_t_sub=3):
    """ Do spec substitute
        Inplace operation
        ref: U2++, section 3.2.3 [https://arxiv.org/abs/2106.05642]

        Args:
            data: Iterable[{key, feat, label}]
            max_t: max width of time substitute
            num_t_sub: number of time substitute to apply

        Returns
            Iterable[{key, feat, label}]
    """
    for sample in data:
        assert 'feat' in sample
        x = sample['feat']
        assert isinstance(x, torch.Tensor)
        y = x.clone().detach()
        max_frames = y.size(0)
        for i in range(num_t_sub):
            start = random.randint(0, max_frames - 1)
            length = random.randint(1, max_t)
            end = min(max_frames, start + length)
            # only substitute the earlier time chosen randomly for current time
            pos = random.randint(0, start)
            y[start:end, :] = x[start - pos:end - pos, :]
        sample['feat'] = y
        yield sample


def spec_trim(data, max_t=20):
    """ Trim trailing frames.
        Inplace operation.
        ref: TrimTail [https://arxiv.org/abs/2211.00522]

        Args:
            data: Iterable[{key, feat, label}]
            max_t: max width of length trimming

        Returns
            Iterable[{key, feat, label}]
    """
    for sample in data:
        assert 'feat' in sample
        x = sample['feat']
        assert isinstance(x, torch.Tensor)
        max_frames = x.size(0)
        length = random.randint(1, max_t)
        # Only trim when less than half of the utterance would be removed.
        if length < max_frames / 2:
            y = x.clone().detach()[:max_frames - length]
            sample['feat'] = y
        yield sample


def shuffle(data, shuffle_size=10000):
    """ Local shuffle the data

        Args:
            data: Iterable[{key, feat, label}]
            shuffle_size: buffer size for shuffle

        Returns:
            Iterable[{key, feat, label}]
    """
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= shuffle_size:
            random.shuffle(buf)
            for x in buf:
                yield x
            buf = []
    # The sample left over
    random.shuffle(buf)
    for x in buf:
        yield x


def sort(data, sort_size=500):
    """ Sort the data by feature length.
        Sort is used after shuffle and before batch, so we can group
        utts with similar lengths into a batch, and `sort_size` should
        be less than `shuffle_size`

        Args:
            data: Iterable[{key, feat, label}]
            sort_size: buffer size for sort

        Returns:
            Iterable[{key, feat, label}]
    """

    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= sort_size:
            buf.sort(key=lambda x: x['feat'].size(0))
            for x in buf:
                yield x
            buf = []
    # The sample left over
    buf.sort(key=lambda x: x['feat'].size(0))
    for x in buf:
        yield x


def static_batch(data, batch_size=16):
    """ Static batch the data by `batch_size`

        Args:
            data: Iterable[{key, feat, label}]
            batch_size: batch size

        Returns:
            Iterable[List[{key, feat, label}]]
    """
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= batch_size:
            yield buf
            buf = []
    if len(buf) > 0:
        yield buf


def dynamic_batch(data, max_frames_in_batch=12000):
    """ Dynamic batch the data until the total frames in batch
        reach `max_frames_in_batch`. Frames are counted after padding,
        i.e. longest utterance * batch size. A single utterance longer
        than the limit forms a batch on its own; batches are never empty.

        Args:
            data: Iterable[{key, feat, label}]
            max_frames_in_batch: max_frames in one batch

        Returns:
            Iterable[List[{key, feat, label}]]
    """
    buf = []
    longest_frames = 0
    for sample in data:
        assert 'feat' in sample
        assert isinstance(sample['feat'], torch.Tensor)
        new_sample_frames = sample['feat'].size(0)
        longest_frames = max(longest_frames, new_sample_frames)
        frames_after_padding = longest_frames * (len(buf) + 1)
        # `buf` check: never emit an empty batch when the very first sample
        # already exceeds the limit.
        if frames_after_padding > max_frames_in_batch and buf:
            yield buf
            buf = [sample]
            longest_frames = new_sample_frames
        else:
            buf.append(sample)
    if len(buf) > 0:
        yield buf


def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000):
    """ Wrapper for static/dynamic batch

        Raises:
            ValueError: unsupported `batch_type`
    """
    if batch_type == 'static':
        return static_batch(data, batch_size)
    elif batch_type == 'dynamic':
        return dynamic_batch(data, max_frames_in_batch)
    else:
        logging.fatal('Unsupported batch type {}'.format(batch_type))
        # logging.fatal only logs; fail here instead of returning None.
        raise ValueError('Unsupported batch type {}'.format(batch_type))


def padding(data):
    """ Padding the data into training data

        Args:
            data: Iterable[List[{key, feat, label, wav}]]

        Returns:
            Iterable[dict] with 'keys', 'feats', 'target', 'feats_lengths',
            'target_lengths', 'pcm', 'pcm_length', and optionally
            'speaker', 'prompt', 'prompt_lengths'; every entry is ordered
            by descending feature length
    """
    for sample in data:
        assert isinstance(sample, list)
        feats_length = torch.tensor([x['feat'].size(0) for x in sample],
                                    dtype=torch.int32)
        order = torch.argsort(feats_length, descending=True)
        # Reorder the lengths already computed instead of recomputing them.
        feats_lengths = feats_length[order]
        sorted_feats = [sample[i]['feat'] for i in order]
        sorted_keys = [sample[i]['key'] for i in order]
        sorted_labels = [
            torch.tensor(sample[i]['label'], dtype=torch.int64) for i in order
        ]
        sorted_wavs = [sample[i]['wav'].squeeze(0) for i in order]
        label_lengths = torch.tensor([x.size(0) for x in sorted_labels],
                                     dtype=torch.int32)
        wav_lengths = torch.tensor([x.size(0) for x in sorted_wavs],
                                   dtype=torch.int32)

        padded_feats = pad_sequence(sorted_feats,
                                    batch_first=True,
                                    padding_value=0)
        # -1 marks padded label positions.
        padding_labels = pad_sequence(sorted_labels,
                                      batch_first=True,
                                      padding_value=-1)
        padded_wavs = pad_sequence(sorted_wavs,
                                   batch_first=True,
                                   padding_value=0)

        batch = {
            "keys": sorted_keys,
            "feats": padded_feats,
            "target": padding_labels,
            "feats_lengths": feats_lengths,
            "target_lengths": label_lengths,
            "pcm": padded_wavs,
            "pcm_length": wav_lengths,
        }
        if 'speaker' in sample[0]:
            speaker = torch.tensor([sample[i]['speaker'] for i in order],
                                   dtype=torch.int32)
            batch['speaker'] = speaker
        if 'prompt' in sample[0]:
            sorted_prompts = [
                torch.tensor(sample[i]['prompt'], dtype=torch.int64)
                for i in order
            ]
            prompt_lengths = torch.tensor(
                [x.size(0) for x in sorted_prompts], dtype=torch.int32)
            padding_prompts = pad_sequence(sorted_prompts,
                                           batch_first=True,
                                           padding_value=-1)
            batch['prompt'] = padding_prompts
            batch['prompt_lengths'] = prompt_lengths
        yield batch