# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import logging import random import json import tarfile import json import io import pyarrow.parquet as pq from io import BytesIO import torch import torchaudio from torch.nn.utils.rnn import pad_sequence import torch.nn.functional as F import tarfile import json import io import wave import numpy as np import torchaudio import os import sys import json import random import pickle import argparse import itertools import mmap import struct import collections import shutil import multiprocessing as mp from pathlib import Path from tqdm import tqdm from collections import defaultdict from copy import deepcopy from datetime import datetime import pickle from wids import wids import math torchaudio.set_audio_backend('soundfile') AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma']) try: MAIN_SPK_EMBEDDING=torch.load("/workspace/audio_checkpoints/flow_model/spk_embedding/0909/mean_embedding.pt") GPT_SPK_EMBEDDING=torch.load("/workspace/audio_checkpoints/flow_model/spk_embedding/0909/spk_mean_embeddings.pt") except: MAIN_SPK_EMBEDDING=torch.zeros(1,192) GPT_SPK_EMBEDDING=torch.zeros(1,192) def parquet_opener(data, mode='train', tts_data={}): """ Give url or local file, return file descriptor Inplace operation. Args: data(Iterable[str]): url or local file list Returns: Iterable[{src, stream}] """ for sample in data: assert 'src' in sample url = sample['src'] try: df = pq.read_table(url).to_pandas() for i in range(len(df)): if mode == 'inference' and df.loc[i, 'utt'] not in tts_data: continue sample.update(dict(df.loc[i])) if mode == 'train': # NOTE do not return sample directly, must initialize a new dict yield {**sample} else: for index, text in enumerate(tts_data[df.loc[i, 'utt']]): yield {**sample, 'tts_index': index, 'tts_text': text} except Exception as ex: logging.warning('Failed to open {}, ex info {}'.format(url, ex)) def parse_tar_header(header_bytes): header = struct.unpack("!100s8s8s8s12s12s8s1s100s6s2s32s32s8s8s155s", header_bytes) return TarHeader(*header) TarHeader = collections.namedtuple( "TarHeader", [ "name", "mode", "uid", "gid", "size", "mtime", "chksum", "typeflag", "linkname", "magic", "version", "uname", "gname", "devmajor", "devminor", "prefix", ], ) class MMTar: def __init__(self, file_path: Path | str): self.stream = open(file_path, "rb") self.mmap = mmap.mmap(self.stream.fileno(), 0, access=mmap.ACCESS_READ) def __del__(self): try: self.mmap.close() self.stream.close() except: # noqa pass def get_at_offset(self, offset) -> tuple[str, bytes]: header = parse_tar_header(self.mmap[offset : offset + 500]) name = header.name.decode("utf-8").strip("\x00") start = offset + 512 end = start + int(header.size.decode("utf-8")[:-1], 8) return name, self.mmap[start:end] class Tar: def __init__(self, path: Path): self.tar = MMTar(path) indices_path = path.with_suffix(".index") self.index = pickle.loads(indices_path.read_bytes()) self.name_mapping = {} for name, offset, _ in self.index: self.name_mapping[name] = offset def read(self, name: str) -> bytes: return self.tar.get_at_offset(self.name_mapping[name])[1] def cosy_jsonl_opener(data, mode='train', tts_data={}): """ Give url or local file, return file descriptor Inplace operation. Args: data(Iterable[str]): url or local file list Returns: Iterable[{src, stream}] """ for sample in data: assert 'src' in sample cosy_jsonl_path = sample['src'] tar_file_path=cosy_jsonl_path.replace(".vq0907.jsonl",".tar") try: tar_data=Tar(Path(tar_file_path)) with open(cosy_jsonl_path, 'r') as f: for line in f: item=json.loads(line) cosy_token = item['cosy_token'] sample['speech_token']=torch.tensor(cosy_token) sample['speech'], sample['sample_rate']= torchaudio.load(io.BytesIO(tar_data.read(item['filename']))) # print(item['filename']) yield {**sample} except Exception as ex: logging.warning('Failed to open {}, ex info {}'.format(cosy_jsonl_path, ex)) def cosy_jsonl_opener_vq0918_nopool(data, mode='train', tts_data={}): """ Give url or local file, return file descriptor Inplace operation. Args: data(Iterable[str]): url or local file list Returns: Iterable[{src, stream}] """ for sample in data: assert 'src' in sample cosy_jsonl_path = sample['src'] tar_file_path=cosy_jsonl_path.replace(".vq0918-nopool.jsonl",".tar") try: tar_data=Tar(Path(tar_file_path)) with open(cosy_jsonl_path, 'r') as f: # cosy_data = [json.loads(line) for line in f] for line in f: item=json.loads(line) cosy_token = item['cosy_token'] sample['speech_token']=torch.tensor(cosy_token) sample['speech'], sample['sample_rate']= torchaudio.load(io.BytesIO(tar_data.read(item['filename']))) # print(item['filename']) yield {**sample} except Exception as ex: logging.warning('Failed to open {}, ex info {}'.format(cosy_jsonl_path, ex)) def cosy_jsonl_opener_vq0918_pool2(data, mode='train', tts_data={}): """ Give url or local file, return file descriptor Inplace operation. Args: data(Iterable[str]): url or local file list Returns: Iterable[{src, stream}] """ for sample in data: assert 'src' in sample cosy_jsonl_path = sample['src'] tar_file_path=cosy_jsonl_path.replace(".vq0918-pool2.jsonl",".tar") try: tar_data=Tar(Path(tar_file_path)) with open(cosy_jsonl_path, 'r') as f: for line in f: item=json.loads(line) cosy_token = item['cosy_token'] sample['speech_token']=torch.tensor(cosy_token) sample['speech'], sample['sample_rate']= torchaudio.load(io.BytesIO(tar_data.read(item['filename']))) yield {**sample} except Exception as ex: logging.warning('Failed to open {}, ex info {}'.format(cosy_jsonl_path, ex)) def cosy_jsonl_opener_vq0918_pool4(data, mode='train', tts_data={}): """ Give url or local file, return file descriptor Inplace operation. Args: data(Iterable[str]): url or local file list Returns: Iterable[{src, stream}] """ for sample in data: assert 'src' in sample cosy_jsonl_path = sample['src'] tar_file_path=cosy_jsonl_path.replace(".vq0918-pool4.jsonl",".tar") try: tar_data=Tar(Path(tar_file_path)) with open(cosy_jsonl_path, 'r') as f: # cosy_data = [json.loads(line) for line in f] for line in f: item=json.loads(line) cosy_token = item['cosy_token'] sample['speech_token']=torch.tensor(cosy_token) sample['speech'], sample['sample_rate']= torchaudio.load(io.BytesIO(tar_data.read(item['filename']))) # print(item['filename']) yield {**sample} except Exception as ex: logging.warning('Failed to open {}, ex info {}'.format(cosy_jsonl_path, ex)) def cosy_jsonl_opener_vq0918_pool8(data, mode='train', tts_data={}): """ Give url or local file, return file descriptor Inplace operation. Args: data(Iterable[str]): url or local file list Returns: Iterable[{src, stream}] """ for sample in data: assert 'src' in sample cosy_jsonl_path = sample['src'] tar_file_path=cosy_jsonl_path.replace(".vq0918-pool8.jsonl",".tar") try: tar_data=Tar(Path(tar_file_path)) with open(cosy_jsonl_path, 'r') as f: # cosy_data = [json.loads(line) for line in f] for line in f: item=json.loads(line) cosy_token = item['cosy_token'] sample['speech_token']=torch.tensor(cosy_token) sample['speech'], sample['sample_rate']= torchaudio.load(io.BytesIO(tar_data.read(item['filename']))) # print(item['filename']) yield {**sample} except Exception as ex: logging.warning('Failed to open {}, ex info {}'.format(cosy_jsonl_path, ex)) def process_sft_vq0918_pool4(data, mode='train', tts_data={}): for sample in data: assert 'src' in sample token_npy_path = sample['src'] wav_path=token_npy_path.replace(".vq0918-pool4.npy","") # wav_path,token_npy_path=sample['src'].split(' ') try: sample['speech_token']=torch.tensor(np.load(token_npy_path)) sample['speech'], sample['sample_rate']= torchaudio.load(wav_path) if sample['speech'].shape[0] > 1: sample['speech'] = sample['speech'].mean(dim=0, keepdim=True) sample['spk_embedding']=torch.zeros_like(MAIN_SPK_EMBEDDING) yield {**sample} except Exception as ex: logging.warning('Failed to open {}, ex info {}'.format(wav_path, ex)) logging.warning('Failed to open {}'.format(wav_path)) def process_sft_vq0918_pool4_split(data, mode='train',split_token=25, tts_data={}): for sample in data: assert 'src' in sample token_npy_path = sample['src'] wav_path=token_npy_path.replace(".vq0918-pool4.npy","") # wav_path,token_npy_path=sample['src'].split(' ') try: # sample['speech_token']=torch.tensor(np.load(token_npy_path)) # sample['speech'], sample['sample_rate']= torchaudio.load(wav_path) # if sample['speech'].shape[0] > 1: # sample['speech'] = sample['speech'].mean(dim=0, keepdim=True) # sample['spk_embedding']=torch.zeros_like(MAIN_SPK_EMBEDDING) speech_token=torch.tensor(np.load(token_npy_path)) speech,sample_rate= torchaudio.load(wav_path) # split_speech=int(split_token / 12.5 * sample_rate) if speech.shape[0] > 1: speech = speech.mean(dim=0, keepdim=True) sample['spk_embedding']=torch.zeros_like(MAIN_SPK_EMBEDDING) sample['sample_rate']=sample_rate num_splits = (speech_token.size(0) + split_token - 1) // split_token for split_id in range(num_splits): end_token_idx = min((split_id + 1) * split_token, speech_token.size(0)) end_speech_idx=int(np.ceil(end_token_idx / 12.5 * sample_rate)) sample['speech_token']=speech_token[:end_token_idx] sample['speech']=speech[:,:end_speech_idx] print(sample['speech_token'].size(),sample['speech'].size()) yield {**sample} except Exception as ex: logging.warning('Failed to open {}, ex info {}'.format(wav_path, ex)) logging.warning('Failed to open {}'.format(wav_path)) def process_sft_vq0918_pool2(data, mode='train', tts_data={}): for sample in data: assert 'src' in sample token_npy_path = sample['src'].replace(".vq0918-pool4.npy",".vq0918-pool2.npy") wav_path=token_npy_path.replace(".vq0918-pool2.npy","") # wav_path,token_npy_path=sample['src'].split(' ') try: sample['speech_token']=torch.tensor(np.load(token_npy_path)) sample['speech'], sample['sample_rate']= torchaudio.load(wav_path) if sample['speech'].shape[0] > 1: sample['speech'] = sample['speech'].mean(dim=0, keepdim=True) sample['spk_embedding']=torch.zeros_like(MAIN_SPK_EMBEDDING) yield {**sample} except Exception as ex: logging.warning('Failed to open {}, ex info {}'.format(wav_path, ex)) logging.warning('Failed to open {}'.format(wav_path)) def process_sft_vq0918_pool2_split(data, mode='train',split_token=50, tts_data={}): for sample in data: assert 'src' in sample token_npy_path = sample['src'] wav_path=token_npy_path.replace(".vq0918-pool2.npy","") # wav_path,token_npy_path=sample['src'].split(' ') try: # sample['speech_token']=torch.tensor(np.load(token_npy_path)) # sample['speech'], sample['sample_rate']= torchaudio.load(wav_path) # if sample['speech'].shape[0] > 1: # sample['speech'] = sample['speech'].mean(dim=0, keepdim=True) # sample['spk_embedding']=torch.zeros_like(MAIN_SPK_EMBEDDING) speech_token=torch.tensor(np.load(token_npy_path)) speech,sample_rate= torchaudio.load(wav_path) # split_speech=int(split_token / 12.5 * sample_rate) if speech.shape[0] > 1: speech = speech.mean(dim=0, keepdim=True) sample['spk_embedding']=torch.zeros_like(MAIN_SPK_EMBEDDING) sample['sample_rate']=sample_rate num_splits = (speech_token.size(0) + split_token - 1) // split_token for split_id in range(num_splits): end_token_idx = min((split_id + 1) * split_token, speech_token.size(0)) end_speech_idx=int(np.ceil(end_token_idx / 25 * sample_rate)) sample['speech_token']=speech_token[:end_token_idx] sample['speech']=speech[:,:end_speech_idx] print(sample['speech_token'].size(),sample['speech'].size()) yield {**sample} except Exception as ex: logging.warning('Failed to open {}, ex info {}'.format(wav_path, ex)) logging.warning('Failed to open {}'.format(wav_path)) def process_sft_vq0918_pool4_gpt(data, mode='train', tts_data={}): for sample in data: assert 'src' in sample try: entry=json.loads(sample['src']) sample['spk_embedding']=torch.zeros_like(MAIN_SPK_EMBEDDING) for conv in entry["conversations"]: if "response_wav" in conv: wav_path = f"/workspace/audio_data/sft/{conv['response_wav']}" token_npy_path=wav_path.replace(".wav",".wav.vq0918-pool4.npy") sample['speech_token']=torch.tensor(np.load(token_npy_path)) sample['speech'], sample['sample_rate']= torchaudio.load(wav_path) if sample['speech'].shape[0] > 1: sample['speech'] = sample['speech'].mean(dim=0, keepdim=True) sample['spk_embedding']=spk_embedding yield {**sample} except Exception as ex: # logging.warning('Failed to open {}, ex info {}'.format(wav_path, ex)) logging.warning('Failed to open {}'.format(wav_path)) def process_sft_vq0918_pool4_gpt_1010(data, mode='train', tts_data={}): for sample in data: assert 'src' in sample try: entry=json.loads(sample['src']) sample['spk_embedding']=torch.zeros_like(MAIN_SPK_EMBEDDING) for conv in entry["conversations"]: if "response_wav" in conv: wav_path = f"/workspace/audio_data/sft/{conv['response_wav']}" token_npy_path=wav_path.replace(".wav",".wav.vq0918-pool4.npy") sample['speech_token']=torch.tensor(np.load(token_npy_path)) sample['speech'], sample['sample_rate']= torchaudio.load(wav_path) if sample['speech'].shape[0] > 1: sample['speech'] = sample['speech'].mean(dim=0, keepdim=True) sample['spk_embedding']=spk_embedding yield {**sample} if "prompt_wav" in conv: wav_path = f"/workspace/audio_data/sft/{conv['response_wav']}" token_npy_path=wav_path.replace(".wav",".wav.vq0918-pool4.npy") sample['speech_token']=torch.tensor(np.load(token_npy_path)) sample['speech'], sample['sample_rate']= torchaudio.load(wav_path) if sample['speech'].shape[0] > 1: sample['speech'] = sample['speech'].mean(dim=0, keepdim=True) sample['spk_embedding']=spk_embedding yield {**sample} except Exception as ex: # logging.warning('Failed to open {}, ex info {}'.format(wav_path, ex)) logging.warning('Failed to open {}'.format(wav_path)) def filter(data, max_length=10240, min_length=10, token_max_length=200, token_min_length=1, min_output_input_ratio=0.0005, max_output_input_ratio=1, mode='train'): """ Filter sample according to feature and label length Inplace operation. Args:: data: Iterable[{key, wav, label, sample_rate}] max_length: drop utterance which is greater than max_length(10ms) min_length: drop utterance which is less than min_length(10ms) token_max_length: drop utterance which is greater than token_max_length, especially when use char unit for english modeling token_min_length: drop utterance which is less than token_max_length min_output_input_ratio: minimal ration of token_length / feats_length(10ms) max_output_input_ratio: maximum ration of token_length / feats_length(10ms) Returns: Iterable[{key, wav, label, sample_rate}] """ for sample in data: # sample['speech'], sample['sample_rate'] = torchaudio.load(BytesIO(sample['audio_data'])) # del sample['audio_data'] # sample['wav'] is torch.Tensor, we have 100 frames every second num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100 if num_frames < min_length: continue if num_frames > max_length: continue if len(sample['text_token']) < token_min_length: continue if len(sample['text_token']) > token_max_length: continue if len(sample['speech_token']) == 0: continue if num_frames != 0: if len(sample['text_token']) / num_frames < min_output_input_ratio: continue if len(sample['text_token']) / num_frames > max_output_input_ratio: continue yield sample def filter_speech_token(data, max_length=10240, min_length=10, token_max_length=5000, token_min_length=1, min_output_input_ratio=0.0005, max_output_input_ratio=30, mode='train'): """ Filter sample according to feature and label length Inplace operation. Args:: data: Iterable[{key, wav, label, sample_rate}] max_length: drop utterance which is greater than max_length(10ms) min_length: drop utterance which is less than min_length(10ms) token_max_length: drop utterance which is greater than token_max_length, especially when use char unit for english modeling token_min_length: drop utterance which is less than token_max_length min_output_input_ratio: minimal ration of token_length / feats_length(10ms) max_output_input_ratio: maximum ration of token_length / feats_length(10ms) Returns: Iterable[{key, wav, label, sample_rate}] """ for sample in data: # sample['speech'], sample['sample_rate'] = torchaudio.load(BytesIO(sample['audio_data'])) # del sample['audio_data'] # sample['wav'] is torch.Tensor, we have 100 frames every second num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100 if num_frames < min_length: continue if num_frames > max_length: continue if len(sample['speech_token']) < token_min_length: continue if len(sample['speech_token']) > token_max_length: continue if len(sample['speech_token']) == 0: continue if num_frames != 0: if len(sample['speech_token']) / num_frames < min_output_input_ratio: continue if len(sample['speech_token']) / num_frames > max_output_input_ratio: continue yield sample def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'): """ Resample data. Inplace operation. Args: data: Iterable[{key, wav, label, sample_rate}] resample_rate: target resample rate Returns: Iterable[{key, wav, label, sample_rate}] """ for sample in data: assert 'sample_rate' in sample assert 'speech' in sample sample_rate = sample['sample_rate'] waveform = sample['speech'] if sample_rate != resample_rate: if sample_rate < min_sample_rate: continue sample['sample_rate'] = resample_rate sample['speech'] = torchaudio.transforms.Resample( orig_freq=sample_rate, new_freq=resample_rate)(waveform) max_val = sample['speech'].abs().max() if max_val > 1: sample['speech'] /= max_val yield sample def compute_fbank(data, feat_extractor, mode='train'): """ Extract fbank Args: data: Iterable[{key, wav, label, sample_rate}] Returns: Iterable[{key, feat, label}] """ for sample in data: assert 'sample_rate' in sample assert 'speech' in sample # assert 'utt' in sample # assert 'text_token' in sample waveform = sample['speech'] mat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1) sample['speech_feat'] = mat del sample['speech'] yield sample def parse_embedding(data, normalize, mode='train'): """ Parse utt_embedding/spk_embedding Args: data: Iterable[{key, wav, label, sample_rate}] Returns: Iterable[{key, feat, label}] """ for sample in data: sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32) sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], dtype=torch.float32) if normalize: sample['utt_embedding'] = F.normalize(sample['utt_embedding'], dim=0) sample['spk_embedding'] = F.normalize(sample['spk_embedding'], dim=0) yield sample def tokenize(data, get_tokenizer, allowed_special, mode='train'): """ Decode text to chars or BPE Inplace operation Args: data: Iterable[{key, wav, txt, sample_rate}] Returns: Iterable[{key, wav, txt, tokens, label, sample_rate}] """ tokenizer = get_tokenizer() for sample in data: assert 'text' in sample sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special) if mode == 'inference': sample['tts_text_token'] = tokenizer.encode(sample['tts_text'], allowed_special=allowed_special) yield sample def shuffle(data, shuffle_size=10000, mode='train'): """ Local shuffle the data Args: data: Iterable[{key, feat, label}] shuffle_size: buffer size for shuffle Returns: Iterable[{key, feat, label}] """ buf = [] for sample in data: buf.append(sample) if len(buf) >= shuffle_size: random.shuffle(buf) for x in buf: yield x buf = [] # The sample left over random.shuffle(buf) for x in buf: yield x def sort(data, sort_size=500, mode='train'): """ Sort the data by feature length. Sort is used after shuffle and before batch, so we can group utts with similar lengths into a batch, and `sort_size` should be less than `shuffle_size` Args: data: Iterable[{key, feat, label}] sort_size: buffer size for sort Returns: Iterable[{key, feat, label}] """ buf = [] for sample in data: buf.append(sample) if len(buf) >= sort_size: buf.sort(key=lambda x: x['speech_feat'].size(0)) for x in buf: yield x buf = [] # The sample left over buf.sort(key=lambda x: x['speech_feat'].size(0)) for x in buf: yield x def static_batch(data, batch_size=16): """ Static batch the data by `batch_size` Args: data: Iterable[{key, feat, label}] batch_size: batch size Returns: Iterable[List[{key, feat, label}]] """ buf = [] for sample in data: buf.append(sample) if len(buf) >= batch_size: yield buf buf = [] if len(buf) > 0: yield buf def dynamic_batch(data, max_frames_in_batch=12000, mode='train'): """ Dynamic batch the data until the total frames in batch reach `max_frames_in_batch` Args: data: Iterable[{key, feat, label}] max_frames_in_batch: max_frames in one batch Returns: Iterable[List[{key, feat, label}]] """ buf = [] longest_frames = 0 for sample in data: assert 'speech_feat' in sample assert isinstance(sample['speech_feat'], torch.Tensor) new_sample_frames = sample['speech_feat'].size(0) longest_frames = max(longest_frames, new_sample_frames) frames_after_padding = longest_frames * (len(buf) + 1) if frames_after_padding > max_frames_in_batch: yield buf buf = [sample] longest_frames = new_sample_frames else: buf.append(sample) if len(buf) > 0: yield buf def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, mode='train'): """ Wrapper for static/dynamic batch """ if mode == 'inference': return static_batch(data, 1) else: if batch_type == 'static': return static_batch(data, batch_size) elif batch_type == 'dynamic': return dynamic_batch(data, max_frames_in_batch) else: logging.fatal('Unsupported batch type {}'.format(batch_type)) def padding(data, use_spk_embedding, mode='train'): """ Padding the data into training data Args: data: Iterable[List[{key, feat, label}]] Returns: Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)] """ for sample in data: assert isinstance(sample, list) speech_feat_len = torch.tensor([x['speech_feat'].size(1) for x in sample], dtype=torch.int32) order = torch.argsort(speech_feat_len, descending=True) utts = [sample[i]['utt'] for i in order] speech_token = [torch.tensor(sample[i]['speech_token']) for i in order] speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32) speech_token = pad_sequence(speech_token, batch_first=True, padding_value=0) speech_feat = [sample[i]['speech_feat'] for i in order] speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32) speech_feat = pad_sequence(speech_feat, batch_first=True, padding_value=0) text = [sample[i]['text'] for i in order] text_token = [torch.tensor(sample[i]['text_token']) for i in order] text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32) text_token = pad_sequence(text_token, batch_first=True, padding_value=0) utt_embedding = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0) spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0) batch = { "utts": utts, "speech_token": speech_token, "speech_token_len": speech_token_len, "speech_feat": speech_feat, "speech_feat_len": speech_feat_len, "text": text, "text_token": text_token, "text_token_len": text_token_len, "utt_embedding": utt_embedding, "spk_embedding": spk_embedding, } if mode == 'inference': tts_text = [sample[i]['tts_text'] for i in order] tts_index = [sample[i]['tts_index'] for i in order] tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order] tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32) tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1) batch.update({'tts_text': tts_text, 'tts_index': tts_index, 'tts_text_token': tts_text_token, 'tts_text_token_len': tts_text_token_len}) if use_spk_embedding is True: batch["embedding"] = batch["spk_embedding"] else: batch["embedding"] = batch["utt_embedding"] yield batch def padding_speech_token(data, use_spk_embedding, mode='train'): """ Padding the data into training data Args: data: Iterable[List[{key, feat, label}]] Returns: Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)] """ for sample in data: assert isinstance(sample, list) speech_feat_len = torch.tensor([x['speech_feat'].size(1) for x in sample], dtype=torch.int32) order = torch.argsort(speech_feat_len, descending=True) # utts = [sample[i]['utt'] for i in order] # speech_token = [torch.tensor(sample[i]['speech_token']) for i in order] try: speech_token = [sample[i]['speech_token'].clone().detach() for i in order] speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32) speech_token = pad_sequence(speech_token, batch_first=True, padding_value=0) speech_feat = [sample[i]['speech_feat'] for i in order] speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32) speech_feat = pad_sequence(speech_feat, batch_first=True, padding_value=0) batch = { "speech_token": speech_token, "speech_token_len": speech_token_len, "speech_feat": speech_feat, "speech_feat_len": speech_feat_len, } if mode == 'inference': tts_text = [sample[i]['tts_text'] for i in order] tts_index = [sample[i]['tts_index'] for i in order] tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order] tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32) tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1) batch.update({'tts_text': tts_text, 'tts_index': tts_index, 'tts_text_token': tts_text_token, 'tts_text_token_len': tts_text_token_len}) # if use_spk_embedding is True: # batch["embedding"] = batch["spk_embedding"] # else: # batch["embedding"] = batch["utt_embedding"] batch["embedding"]=torch.zeros((batch["speech_feat"].size(0),192),device=batch["speech_feat"].device) yield batch except Exception as ex: logging.warning(' ex info {}'.format(ex)) # assert False def padding_speech_token_spk(data, use_spk_embedding, mode='train'): """ Padding the data into training data Args: data: Iterable[List[{key, feat, label}]] Returns: Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)] """ for sample in data: assert isinstance(sample, list) speech_feat_len = torch.tensor([x['speech_feat'].size(1) for x in sample], dtype=torch.int32) order = torch.argsort(speech_feat_len, descending=True) # utts = [sample[i]['utt'] for i in order] # speech_token = [torch.tensor(sample[i]['speech_token']) for i in order] try: speech_token = [sample[i]['speech_token'].clone().detach() for i in order] speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32) speech_token = pad_sequence(speech_token, batch_first=True, padding_value=0) speech_feat = [sample[i]['speech_feat'] for i in order] speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32) speech_feat = pad_sequence(speech_feat, batch_first=True, padding_value=0) spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0) batch = { "speech_token": speech_token, "speech_token_len": speech_token_len, "speech_feat": speech_feat, "speech_feat_len": speech_feat_len, "spk_embedding": spk_embedding, } if mode == 'inference': tts_text = [sample[i]['tts_text'] for i in order] tts_index = [sample[i]['tts_index'] for i in order] tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order] tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32) tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1) batch.update({'tts_text': tts_text, 'tts_index': tts_index, 'tts_text_token': tts_text_token, 'tts_text_token_len': tts_text_token_len}) # if use_spk_embedding is True: # batch["embedding"] = batch["spk_embedding"] # else: # batch["embedding"] = batch["utt_embedding"] # batch["embedding"]=torch.zeros((batch["speech_feat"].size(0),192),device=batch["speech_feat"].device) batch["embedding"] = batch["spk_embedding"] yield batch except Exception as ex: logging.warning(' ex info {}'.format(ex)) # assert False