ASR_Model_Comparison / dataset.py
j-tobias
added backend
752ce9b
raw
history blame
3.6 kB
from datasets import load_dataset
from datasets import Audio
class Dataset:
    """Streaming wrapper around common ASR benchmark corpora.

    Selects one of a fixed set of speech datasets, streams its test split
    from the Hugging Face Hub, keeps the first ``n`` examples, resamples
    audio to 16 kHz, and locates the transcript column so references can
    be filtered and normalised before scoring.
    """

    # Candidate transcript column names, probed in this order by _get_text().
    _TEXT_COLUMNS = ("text", "sentence", "normalized_text", "transcript")

    # load_dataset(...) keyword arguments per selectable corpus.
    # NOTE(review): gated corpora pass token=True and assume a prior HF login.
    _LOAD_KWARGS = {
        "LibriSpeech Clean": dict(path="librispeech_asr", name="all", split="test.clean", streaming=True),
        "LibriSpeech Other": dict(path="librispeech_asr", name="all", split="test.other", streaming=True),
        "Common Voice": dict(path="mozilla-foundation/common_voice_11_0", name="en", revision="streaming", split="test", streaming=True, token=True, trust_remote_code=True),
        "VoxPopuli": dict(path="facebook/voxpopuli", name="en", split="test", streaming=True, trust_remote_code=True),
        "TEDLIUM": dict(path="LIUM/tedlium", name="release3", split="test", streaming=True, trust_remote_code=True),
        "GigaSpeech": dict(path="speechcolab/gigaspeech", name="xs", split="test", streaming=True, token=True, trust_remote_code=True),
        "SPGISpeech": dict(path="kensho/spgispeech", name="S", split="test", streaming=True, token=True, trust_remote_code=True),
        "AMI": dict(path="edinburghcstr/ami", name="ihm", split="test", streaming=True, trust_remote_code=True),
    }

    def __init__(self, n: int = 100):
        """Initialize an unloaded dataset wrapper.

        Args:
            n: Number of examples to keep from the streamed test split.
        """
        self.n = n
        # Selectable corpus names; "OWN" is a user-provided placeholder
        # that loads nothing. Order is part of the public contract.
        self.options = ['LibriSpeech Clean', 'LibriSpeech Other', 'Common Voice', 'VoxPopuli', 'TEDLIUM', 'GigaSpeech', 'SPGISpeech', 'AMI', 'OWN']
        self.selected = None  # currently selected option name
        self.dataset = None   # streaming dataset once load() succeeds
        self.text = None      # detected transcript column name

    def get_options(self):
        """Return the list of selectable dataset names."""
        return self.options

    def _check_text(self):
        """Detect the transcript column from the first streamed sample.

        Returns:
            The transcript column *name* (also stored in ``self.text``).

        BUGFIX: previously the result of ``_get_text`` was dropped and
        ``None`` was returned implicitly, so ``filter()`` ended up calling
        ``dataset.filter(..., input_columns=[None])``. The column name is
        what callers need, so that is returned now.
        """
        sample = next(iter(self.dataset))
        self._get_text(sample)  # sets self.text as a side effect
        return self.text

    def _get_text(self, sample):
        """Return the transcript of ``sample``; remember its column name.

        Args:
            sample: A single dataset example (mapping of column -> value).

        Raises:
            ValueError: If no known transcript column is present.
        """
        for column in self._TEXT_COLUMNS:
            if column in sample:
                self.text = column
                return sample[column]
        raise ValueError(f"Sample: {sample.keys()} has no transcript.")

    def filter(self, input_column: str = None):
        """Drop examples whose reference text is empty or unscored.

        Args:
            input_column: Name of the transcript column; auto-detected
                from the stream when ``None``.

        Returns:
            The filtered streaming dataset (also stored on ``self``).
        """
        if input_column is None:
            # Reuse a previously detected column, else probe the stream.
            input_column = self.text if self.text is not None else self._check_text()

        def is_target_text_in_range(ref):
            # TEDLIUM marks unscored segments with this sentinel phrase;
            # blank references are useless for WER and are dropped too.
            ref = ref.strip()
            return ref != "ignore time segment in scoring" and ref != ""

        self.dataset = self.dataset.filter(is_target_text_in_range, input_columns=[input_column])
        return self.dataset

    def normalised(self, normalise):
        """Apply ``normalise`` to every example via ``datasets.map``."""
        self.dataset = self.dataset.map(normalise)

    def _select(self, option: str):
        """Validate ``option`` against the known choices and record it.

        Raises:
            ValueError: If ``option`` is not in ``self.options``.
        """
        if option not in self.options:
            raise ValueError(f"This value is not an option, please see: {self.options}")
        self.selected = option

    def _preprocess(self):
        """Truncate the stream to ``n`` examples and resample to 16 kHz."""
        self.dataset = self.dataset.take(self.n)
        self.dataset = self.dataset.cast_column("audio", Audio(sampling_rate=16000))

    def load(self, option: str = None):
        """Load the chosen corpus as a streaming dataset and preprocess it.

        Args:
            option: One of ``get_options()``. "OWN" is a placeholder for a
                user-supplied dataset and loads nothing.

        Raises:
            ValueError: If ``option`` is not a known choice.
        """
        self._select(option)
        # BUGFIX: "OWN" previously fell through to _preprocess() with
        # self.dataset still None and crashed; it is now skipped entirely.
        if option != "OWN":
            self.dataset = load_dataset(**self._LOAD_KWARGS[option])
            self._preprocess()