import json
from collections import defaultdict
from itertools import count, islice
from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, TypeVar, Union, overload

import spaces
import torch
from datasets import Audio, Dataset
from dataspeech import pitch_apply, rate_apply, snr_apply, squim_apply

from metadata_to_text import bins_to_text, speaker_level_relative_to_gender

Row = Dict[str, Any]
T = TypeVar("T")

BATCH_SIZE = 20

# Analysis columns copied into every output row. Intermediate columns such as
# "phonemes" are computed but not returned.
_OUTPUT_COLUMNS = (
    "speaking_rate",
    "reverberation",
    "noise",
    "speech_monotony",
    "c50",
    "snr",
    "stoi",
    "pesq",
    "si-sdr",
)

# Continuous-to-text conversions applied with ``bins_to_text``. Each entry is:
# (key in the text-bins JSON, input column, output column -- also the key looked
# up in the bin-edges JSON, extra keyword arguments).
_TEXT_BIN_SPECS = (
    ("speaker_rate_bins", "speaking_rate", "speaking_rate", {}),
    ("snr_bins", "snr", "noise", {"lower_range": None}),
    ("reverberation_bins", "c50", "reverberation", {}),
    ("utterance_level_std", "utterance_pitch_std", "speech_monotony", {}),
)


@overload
def batched(it: Iterable[T], n: int) -> Iterable[List[T]]:
    ...


@overload
def batched(it: Iterable[T], n: int, with_indices: Literal[False]) -> Iterable[List[T]]:
    ...


@overload
def batched(it: Iterable[T], n: int, with_indices: Literal[True]) -> Iterable[Tuple[List[int], List[T]]]:
    ...


def batched(
    it: Iterable[T], n: int, with_indices: bool = False
) -> Union[Iterable[List[T]], Iterable[Tuple[List[int], List[T]]]]:
    """Lazily yield successive lists of up to *n* items from *it*.

    When *with_indices* is true, each batch is yielded as ``(indices, batch)``,
    where ``indices`` are the 0-based positions of the batch's items in *it*.
    """
    it, indices = iter(it), count()
    while batch := list(islice(it, n)):
        yield (list(islice(indices, len(batch))), batch) if with_indices else batch


def _load_json(path: str) -> Dict[str, Any]:
    """Read and parse a UTF-8 encoded JSON file."""
    with open(path, encoding="utf-8") as json_file:
        return json.load(json_file)


def _build_dataset(batch: List[Row], audio_column_name: str, other_columns: Iterable[str]) -> Dataset:
    """Gather the audio column and *other_columns* of *batch* into a Dataset with an Audio feature."""
    wanted = {audio_column_name, *other_columns}
    columns: Dict[str, List[Any]] = defaultdict(list)
    for sample in batch:
        for key, value in sample.items():
            if key not in wanted:
                continue
            if key == audio_column_name:
                # The audio cell is a list of dicts; the first entry's "src"
                # (presumably a path/URL -- verify against the caller) is what
                # the Audio feature decodes.
                columns[key].append(value[0]["src"])
            else:
                columns[key].append(value)
    return Dataset.from_dict(columns).cast_column(audio_column_name, Audio())


def _map_on_gpus(
    dataset: Dataset, apply_fn: Callable[..., Any], audio_column_name: str, **fn_kwargs: Any
) -> Dataset:
    """Run a batched dataspeech ``*_apply`` function with one worker per visible GPU."""
    num_gpus = torch.cuda.device_count()
    return dataset.map(
        apply_fn,
        batched=True,
        batch_size=BATCH_SIZE,
        with_rank=num_gpus > 0,
        # datasets rejects num_proc <= 0, so map in-process when no GPU is visible.
        num_proc=num_gpus if num_gpus > 0 else None,
        # Drop the audio column from the output so the decoded audio is not rewritten.
        remove_columns=[audio_column_name],
        fn_kwargs={"audio_column_name": audio_column_name, **fn_kwargs},
    )


def _extract_continuous_tags(tmp_dataset: Dataset, audio_column_name: str, text_column_name: str) -> Dataset:
    """Run the dataspeech estimators and merge their outputs into the pitch dataset."""
    squim_dataset = _map_on_gpus(tmp_dataset, squim_apply, audio_column_name)
    pitch_dataset = _map_on_gpus(tmp_dataset, pitch_apply, audio_column_name, penn_batch_size=4096)
    snr_dataset = _map_on_gpus(tmp_dataset, snr_apply, audio_column_name)
    # Speaking rate is computed per example (no batching) in a single process.
    rate_dataset = tmp_dataset.map(
        rate_apply,
        with_rank=False,
        num_proc=1,
        remove_columns=[audio_column_name],
        fn_kwargs={"audio_column_name": audio_column_name, "text_column_name": text_column_name},
    )
    return (
        pitch_dataset
        .add_column("snr", snr_dataset["snr"])
        .add_column("c50", snr_dataset["c50"])
        .add_column("speaking_rate", rate_dataset["speaking_rate"])
        .add_column("phonemes", rate_dataset["phonemes"])
        .add_column("stoi", squim_dataset["stoi"])
        .add_column("si-sdr", squim_dataset["sdr"])
        .add_column("pesq", squim_dataset["pesq"])
    )


def _continuous_tags_to_text(
    enriched_dataset: Dataset,
    gender_column_name: str,
    speaker_id_column_name: str,
    has_speaker_info: bool,
) -> Dataset:
    """Convert the continuous measurements into text labels using the v01 bin files."""
    text_bins_dict = _load_json("./v01_text_bins.json")
    bin_edges_dict = _load_json("./v01_bin_edges.json")

    # The metadata_to_text helpers take and return a list of datasets.
    datasets = [enriched_dataset]
    if has_speaker_info:
        pitch_bin_edges = None
        if "pitch_bins_male" in bin_edges_dict and "pitch_bins_female" in bin_edges_dict:
            pitch_bin_edges = {
                "male": bin_edges_dict["pitch_bins_male"],
                "female": bin_edges_dict["pitch_bins_female"],
            }
        datasets, _ = speaker_level_relative_to_gender(
            datasets,
            text_bins_dict.get("speaker_level_pitch_bins"),
            speaker_id_column_name,
            gender_column_name,
            "utterance_pitch_mean",
            "pitch",
            batch_size=BATCH_SIZE,
            num_workers=1,
            std_tolerance=None,
            save_dir=None,
            only_save_plot=False,
            bin_edges=pitch_bin_edges,
        )

    for bins_key, input_column, output_column, extra_kwargs in _TEXT_BIN_SPECS:
        datasets, _ = bins_to_text(
            datasets,
            text_bins_dict.get(bins_key),
            input_column,
            output_column,
            batch_size=BATCH_SIZE,
            num_workers=1,
            leading_split_for_bins=None,
            std_tolerance=None,
            save_dir=None,
            only_save_plot=False,
            bin_edges=bin_edges_dict.get(output_column),
            **extra_kwargs,
        )
    return datasets[0]


@spaces.GPU(duration=60)
def analyze(
    batch: List[Dict[str, Any]],
    audio_column_name: str,
    text_column_name: str,
    cache: Optional[Dict[str, List[Any]]] = None,
) -> List[Row]:
    """Annotate *batch* with dataspeech acoustic measurements and text tags.

    Every row of *batch* is replaced **in place** by a new dict containing:
    the audio column (set to an empty string), the columns in
    ``_OUTPUT_COLUMNS``, and the original text. When the input rows carry
    ``gender`` and ``speaker_id``, the new dict also holds the
    gender-relative ``pitch`` label and those two fields.

    Args:
        batch: Rows to analyse. Each audio cell is expected to be a list whose
            first element is a dict with a ``"src"`` entry.
        audio_column_name: Name of the audio column.
        text_column_name: Name of the transcript column.
        cache: Currently unused; kept for API compatibility.

    Returns:
        The same *batch* list object, now holding the annotated rows.
    """
    if not batch:
        # Nothing to analyse; the steps below need at least one row.
        return batch

    # TODO: add speaker and gender to app
    speaker_id_column_name = "speaker_id"
    gender_column_name = "gender"
    # Decide once from the first input row; the loop below overwrites the rows.
    has_speaker_info = gender_column_name in batch[0] and speaker_id_column_name in batch[0]

    tmp_dataset = _build_dataset(
        batch, audio_column_name, (text_column_name, speaker_id_column_name, gender_column_name)
    )
    # 1. Extract continuous tags.
    enriched_dataset = _extract_continuous_tags(tmp_dataset, audio_column_name, text_column_name)
    # 2. Map continuous tags to text tags.
    enriched_dataset = _continuous_tags_to_text(
        enriched_dataset, gender_column_name, speaker_id_column_name, has_speaker_info
    )

    # Fetch each column once rather than once per row.
    column_values = {col: enriched_dataset[col] for col in _OUTPUT_COLUMNS}
    pitch_values = enriched_dataset["pitch"] if has_speaker_info else None

    for i, sample in enumerate(batch):
        # NOTE(review): the audio cell is blanked in the output; confirm this is intended.
        new_sample: Row = {audio_column_name: ""}
        for col, values in column_values.items():
            new_sample[col] = values[i]
        if has_speaker_info:
            new_sample["pitch"] = pitch_values[i]
            new_sample[gender_column_name] = sample[gender_column_name]
            new_sample[speaker_id_column_name] = sample[speaker_id_column_name]
        new_sample[text_column_name] = sample[text_column_name]
        batch[i] = new_sample
    return batch


def run_dataspeech(
    rows: Iterable[Row], audio_column_name: str, text_column_name: str
) -> Iterable[Any]:
    """Lazily analyse *rows*, yielding one annotated batch of up to BATCH_SIZE rows at a time."""
    cache: Dict[str, List[Any]] = {}
    for batch in batched(rows, BATCH_SIZE):
        yield analyze(
            batch=batch,
            audio_column_name=audio_column_name,
            text_column_name=text_column_name,
            cache=cache,
        )