# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
from io import BytesIO

import pyarrow.parquet as pq
import pyworld as pw
import torch
import torch.nn.functional as F
import torchaudio
from torch.nn.utils.rnn import pad_sequence

AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}


def parquet_opener(data, mode='train', tts_data={}):
    """ Open parquet files given urls or local paths and yield one sample per row.
        Inplace operation.

        Args:
            data(Iterable[{src}]): url or local file list

        Returns:
            Iterable[{src, ...parquet columns...}]
    """
    for sample in data:
        assert 'src' in sample
        url = sample['src']
        try:
            for df in pq.ParquetFile(url).iter_batches(batch_size=64):
                df = df.to_pandas()
                for i in range(len(df)):
                    sample.update(dict(df.loc[i]))
                    if mode == 'train':
                        # NOTE do not return sample directly, must initialize a new dict
                        yield {**sample}
                    else:
                        for index, text in enumerate(tts_data[df.loc[i, 'utt']]):
                            yield {**sample, 'tts_index': index, 'tts_text': text}
        except Exception as ex:
            logging.warning('Failed to open {}, ex info {}'.format(url, ex))
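

# The stages in this module all consume and yield sample dicts, so a full data
# pipeline is just nested generators. A minimal sketch, assuming a text file
# that lists one parquet path per line and a tokenizer factory (both are
# hypothetical placeholders, not part of this module):
def _example_pipeline(list_file, tokenizer_factory):
    stream = ({'src': line.strip()} for line in open(list_file))
    stream = parquet_opener(stream)
    stream = tokenize(stream, tokenizer_factory, allowed_special='all')
    stream = filter(stream)
    stream = resample(stream)
    return stream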


def filter(data,
           max_length=10240,
           min_length=10,
           token_max_length=200,
           token_min_length=1,
           min_output_input_ratio=0.0005,
           max_output_input_ratio=1,
           mode='train'):
    """ Filter samples according to feature and label length.
        Inplace operation.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]
            max_length: drop utterances longer than max_length (in 10ms frames)
            min_length: drop utterances shorter than min_length (in 10ms frames)
            token_max_length: drop utterances with more than token_max_length
                tokens, especially when using char units for English modeling
            token_min_length: drop utterances with fewer than token_min_length
                tokens
            min_output_input_ratio: minimal ratio of
                token_length / feats_length (10ms)
            max_output_input_ratio: maximum ratio of
                token_length / feats_length (10ms)

        Returns:
            Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        sample['speech'], sample['sample_rate'] = torchaudio.load(BytesIO(sample['audio_data']))
        # downmix multi-channel audio to mono
        sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
        del sample['audio_data']
        # sample['speech'] is a torch.Tensor, we have 100 frames every second
        num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100
        if num_frames < min_length:
            continue
        if num_frames > max_length:
            continue
        if len(sample['text_token']) < token_min_length:
            continue
        if len(sample['text_token']) > token_max_length:
            continue
        if len(sample['speech_token']) == 0:
            continue
        if 'reject_speech_token' in sample and len(sample['reject_speech_token']) == 0:
            continue
        if num_frames != 0:
            if len(sample['text_token']) / num_frames < min_output_input_ratio:
                continue
            if len(sample['text_token']) / num_frames > max_output_input_ratio:
                continue
        yield sample
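

# Note on units for filter() above: num_frames counts 10ms frames (100 per
# second of audio, independent of sample rate), e.g. a 2-second utterance
# gives num_frames = 200, which passes the default min_length=10 /
# max_length=10240 window.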


def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'):
    """ Resample data.
        Inplace operation.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]
            resample_rate: target resample rate

        Returns:
            Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        assert 'sample_rate' in sample
        assert 'speech' in sample
        sample_rate = sample['sample_rate']
        waveform = sample['speech']
        if sample_rate != resample_rate:
            if sample_rate < min_sample_rate:
                continue
            sample['sample_rate'] = resample_rate
            sample['speech'] = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=resample_rate)(waveform)
        max_val = sample['speech'].abs().max()
        if max_val > 1:
            sample['speech'] /= max_val
        yield sample


def truncate(data, truncate_length=24576, mode='train'):
    """ Truncate data.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]
            truncate_length: truncate length

        Returns:
            Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        waveform = sample['speech']
        if waveform.shape[1] > truncate_length:
            start = random.randint(0, waveform.shape[1] - truncate_length)
            waveform = waveform[:, start: start + truncate_length]
        else:
            waveform = torch.concat([waveform, torch.zeros(1, truncate_length - waveform.shape[1])], dim=1)
        sample['speech'] = waveform
        yield sample
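

# Example for truncate() above: with the default truncate_length=24576, a
# 30000-sample waveform is randomly cropped to a 24576-sample window, while a
# 10000-sample one is zero-padded on the right to 24576, so every sample
# leaves this stage with a fixed length.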


def compute_fbank(data,
                  feat_extractor,
                  token_mel_ratio=0,
                  mode='train'):
    """ Extract fbank.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]

        Returns:
            Iterable[{key, feat, label}]
    """
    for sample in data:
        assert 'sample_rate' in sample
        assert 'speech' in sample
        assert 'utt' in sample
        assert 'text_token' in sample
        waveform = sample['speech']
        feat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
        if token_mel_ratio != 0:
            # trim to align speech_token and speech_feat
            token_len = int(min(feat.shape[0] / token_mel_ratio, sample['speech_token'].shape[0]))
            feat = feat[:token_mel_ratio * token_len]
            sample['speech_token'] = sample['speech_token'][:token_len]
        sample['speech_feat'] = feat
        yield sample
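

# Worked example of the trim in compute_fbank(): with token_mel_ratio=2, a
# 101-frame mel and 49 speech tokens give token_len = int(min(101 / 2, 49)) = 49,
# so the mel is cut to 2 * 49 = 98 frames and the tokens stay at 49, restoring
# an exact 2-mel-frames-per-token alignment.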


def compute_f0(data, sample_rate, hop_size, mode='train'):
    """ Extract f0.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]

        Returns:
            Iterable[{key, feat, label}]
    """
    frame_period = hop_size * 1000 / sample_rate
    for sample in data:
        assert 'sample_rate' in sample
        assert 'speech' in sample
        assert 'utt' in sample
        assert 'text_token' in sample
        waveform = sample['speech'].squeeze(dim=0).numpy().astype('double')
        _f0, t = pw.harvest(waveform, sample_rate, frame_period=frame_period)
        if sum(_f0 != 0) < 5:  # this happens when the harvest algorithm fails
            _f0, t = pw.dio(waveform, sample_rate, frame_period=frame_period)  # if harvest fails, fall back to dio
        f0 = pw.stonemask(waveform, _f0, t, sample_rate)
        f0 = F.interpolate(torch.from_numpy(f0).view(1, 1, -1), size=sample['speech_feat'].shape[0], mode='linear').view(-1)
        sample['pitch_feat'] = f0
        yield sample
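

# frame_period in compute_f0() is the pitch hop in milliseconds, e.g.
# hop_size=256 at sample_rate=22050 gives 256 * 1000 / 22050 ≈ 11.61 ms, so
# pyworld emits roughly one F0 value per feature frame before F.interpolate
# stretches the track to exactly speech_feat's length.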


def parse_embedding(data, normalize, mode='train'):
    """ Parse utt_embedding/spk_embedding.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]

        Returns:
            Iterable[{key, feat, label}]
    """
    for sample in data:
        sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32)
        sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], dtype=torch.float32)
        if normalize:
            sample['utt_embedding'] = F.normalize(sample['utt_embedding'], dim=0)
            sample['spk_embedding'] = F.normalize(sample['spk_embedding'], dim=0)
        yield sample


def tokenize(data, get_tokenizer, allowed_special, mode='train'):
    """ Decode text to chars or BPE.
        Inplace operation.

        Args:
            data: Iterable[{key, wav, txt, sample_rate}]

        Returns:
            Iterable[{key, wav, txt, tokens, label, sample_rate}]
    """
    tokenizer = get_tokenizer()
    for sample in data:
        assert 'text' in sample
        sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special)
        yield sample


def shuffle(data, shuffle_size=10000, mode='train'):
    """ Local shuffle the data.

        Args:
            data: Iterable[{key, feat, label}]
            shuffle_size: buffer size for shuffle

        Returns:
            Iterable[{key, feat, label}]
    """
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= shuffle_size:
            random.shuffle(buf)
            for x in buf:
                yield x
            buf = []
    # The samples left over
    random.shuffle(buf)
    for x in buf:
        yield x


def sort(data, sort_size=500, mode='train'):
    """ Sort the data by feature length.
        Sort is used after shuffle and before batch, so we can group
        utts with similar lengths into a batch, and `sort_size` should
        be less than `shuffle_size`.

        Args:
            data: Iterable[{key, feat, label}]
            sort_size: buffer size for sort

        Returns:
            Iterable[{key, feat, label}]
    """
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= sort_size:
            buf.sort(key=lambda x: x['speech_feat'].size(0))
            for x in buf:
                yield x
            buf = []
    # The samples left over
    buf.sort(key=lambda x: x['speech_feat'].size(0))
    for x in buf:
        yield x
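

# Typical stage ordering around sort() (a sketch; the parameter values are
# the defaults above):
#     stream = shuffle(stream, shuffle_size=10000)
#     stream = sort(stream, sort_size=500)  # sort_size < shuffle_size
#     stream = batch(stream, batch_type='dynamic', max_frames_in_batch=12000)
#     stream = padding(stream, use_spk_embedding=False)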


def static_batch(data, batch_size=16):
    """ Static batch the data by `batch_size`.

        Args:
            data: Iterable[{key, feat, label}]
            batch_size: batch size

        Returns:
            Iterable[List[{key, feat, label}]]
    """
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= batch_size:
            yield buf
            buf = []
    if len(buf) > 0:
        yield buf


def dynamic_batch(data, max_frames_in_batch=12000, mode='train'):
    """ Dynamic batch the data until the total frames in batch
        reach `max_frames_in_batch`.

        Args:
            data: Iterable[{key, feat, label}]
            max_frames_in_batch: max_frames in one batch

        Returns:
            Iterable[List[{key, feat, label}]]
    """
    buf = []
    longest_frames = 0
    for sample in data:
        assert 'speech_feat' in sample
        assert isinstance(sample['speech_feat'], torch.Tensor)
        new_sample_frames = sample['speech_feat'].size(0)
        longest_frames = max(longest_frames, new_sample_frames)
        frames_after_padding = longest_frames * (len(buf) + 1)
        if frames_after_padding > max_frames_in_batch:
            yield buf
            buf = [sample]
            longest_frames = new_sample_frames
        else:
            buf.append(sample)
    if len(buf) > 0:
        yield buf
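

# Worked example of the packing bound in dynamic_batch(): with
# max_frames_in_batch=12000 and samples of 900, 1000 and 1100 frames, the cost
# after the third sample is 1100 * 3 = 3300 <= 12000, so all three share one
# batch; if a 7000-frame sample arrives next, 7000 * 4 = 28000 > 12000, the
# buffer is yielded and the new sample starts the next batch.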


def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, mode='train'):
    """ Wrapper for static/dynamic batch.
    """
    if batch_type == 'static':
        return static_batch(data, batch_size)
    elif batch_type == 'dynamic':
        return dynamic_batch(data, max_frames_in_batch)
    else:
        logging.fatal('Unsupported batch type {}'.format(batch_type))


def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False):
    """ Padding the data into training data.

        Args:
            data: Iterable[List[{key, feat, label}]]

        Returns:
            Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
    """
    for sample in data:
        assert isinstance(sample, list)
        # speech_feat is (T, num_mels); sort by length T (dim 0), longest first
        speech_feat_len = torch.tensor([x['speech_feat'].size(0) for x in sample],
                                       dtype=torch.int32)
        order = torch.argsort(speech_feat_len, descending=True)

        utts = [sample[i]['utt'] for i in order]
        speech = [sample[i]['speech'].squeeze(dim=0) for i in order]
        speech_len = torch.tensor([i.size(0) for i in speech], dtype=torch.int32)
        speech = pad_sequence(speech, batch_first=True, padding_value=0)
        speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
        speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
        speech_token = pad_sequence(speech_token,
                                    batch_first=True,
                                    padding_value=0)
        speech_feat = [sample[i]['speech_feat'] for i in order]
        speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
        speech_feat = pad_sequence(speech_feat,
                                   batch_first=True,
                                   padding_value=0)
        text = [sample[i]['text'] for i in order]
        text_token = [torch.tensor(sample[i]['text_token']) for i in order]
        text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
        text_token = pad_sequence(text_token, batch_first=True, padding_value=0)
        utt_embedding = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
        spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
        batch = {
            "utts": utts,
            "speech": speech,
            "speech_len": speech_len,
            "speech_token": speech_token,
            "speech_token_len": speech_token_len,
            "speech_feat": speech_feat,
            "speech_feat_len": speech_feat_len,
            "text": text,
            "text_token": text_token,
            "text_token_len": text_token_len,
            "utt_embedding": utt_embedding,
            "spk_embedding": spk_embedding,
        }
        if gan is True:
            # in gan train, we need pitch_feat
            pitch_feat = [sample[i]['pitch_feat'] for i in order]
            pitch_feat_len = torch.tensor([i.size(0) for i in pitch_feat], dtype=torch.int32)
            pitch_feat = pad_sequence(pitch_feat,
                                      batch_first=True,
                                      padding_value=0)
            batch["pitch_feat"] = pitch_feat
            batch["pitch_feat_len"] = pitch_feat_len
        else:
            # only gan train needs speech, delete it to save memory
            del batch["speech"]
            del batch["speech_len"]
        if dpo is True:
            reject_speech_token = [torch.tensor(sample[i]['reject_speech_token']) for i in order]
            reject_speech_token_len = torch.tensor([i.size(0) for i in reject_speech_token], dtype=torch.int32)
            reject_speech_token = pad_sequence(reject_speech_token,
                                               batch_first=True,
                                               padding_value=0)
            batch['reject_speech_token'] = reject_speech_token
            batch['reject_speech_token_len'] = reject_speech_token_len
        if use_spk_embedding is True:
            batch["embedding"] = batch["spk_embedding"]
        else:
            batch["embedding"] = batch["utt_embedding"]
        yield batch
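

# pad_sequence pads along dim 0 to the longest item in the batch, e.g. token
# tensors of lengths 3 and 1 become a (2, 3) tensor:
#     pad_sequence([torch.tensor([7, 8, 9]), torch.tensor([5])],
#                  batch_first=True, padding_value=0)
#     -> tensor([[7, 8, 9],
#                [5, 0, 0]])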
