File size: 3,603 Bytes
1dc0a7f
 
 
 
 
752ce9b
1dc0a7f
752ce9b
1dc0a7f
752ce9b
 
 
 
 
1dc0a7f
752ce9b
 
1dc0a7f
752ce9b
 
 
 
1dc0a7f
752ce9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1dc0a7f
752ce9b
 
 
 
 
1dc0a7f
752ce9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1dc0a7f
752ce9b
1dc0a7f
752ce9b
 
1dc0a7f
752ce9b
1dc0a7f
752ce9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1dc0a7f
752ce9b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
from datasets import load_dataset
from datasets import Audio



class Dataset:
    """Load a streaming speech dataset (test split) and prepare it for evaluation.

    Typical flow: ``load(option)`` -> optional ``filter()`` -> optional
    ``normalised(fn)``. The prepared dataset is kept on ``self.dataset``.
    """

    # Arguments forwarded to ``datasets.load_dataset`` for each built-in option,
    # as (positional args, keyword args); ``streaming=True`` is added for all of
    # them in ``load``. Insertion order defines the order of ``self.options``.
    _HUB_SOURCES = {
        "LibriSpeech Clean": (("librispeech_asr", "all"), {"split": "test.clean"}),
        "LibriSpeech Other": (("librispeech_asr", "all"), {"split": "test.other"}),
        "Common Voice": (
            ("mozilla-foundation/common_voice_11_0", "en"),
            {"revision": "streaming", "split": "test", "token": True, "trust_remote_code": True},
        ),
        "VoxPopuli": (("facebook/voxpopuli", "en"), {"split": "test", "trust_remote_code": True}),
        "TEDLIUM": (("LIUM/tedlium", "release3"), {"split": "test", "trust_remote_code": True}),
        "GigaSpeech": (
            ("speechcolab/gigaspeech", "xs"),
            {"split": "test", "token": True, "trust_remote_code": True},
        ),
        "SPGISpeech": (
            ("kensho/spgispeech", "S"),
            {"split": "test", "token": True, "trust_remote_code": True},
        ),
        "AMI": (("edinburghcstr/ami", "ihm"), {"split": "test", "trust_remote_code": True}),
    }

    # Candidate transcript column names, checked in priority order.
    _TEXT_COLUMNS = ("text", "sentence", "normalized_text", "transcript")

    def __init__(self, n: int = 100, sampling_rate: int = 16000):
        """
        Args:
            n: Number of samples kept from the loaded dataset.
            sampling_rate: Sampling rate (Hz) the ``audio`` column is cast to.
        """
        self.n = n
        self.sampling_rate = sampling_rate
        self.options = list(self._HUB_SOURCES) + ["OWN"]
        self.selected = None  # last option accepted by load()
        self.dataset = None   # the dataset being prepared
        self.text = None      # detected transcript column name (cached)

    def get_options(self):
        """Return the list of option names accepted by ``load``."""
        return self.options

    def _check_text(self):
        """Detect the transcript column from the first sample, cache and return it.

        Raises:
            ValueError: if the dataset is empty or no known transcript column exists.
        """
        try:
            sample = next(iter(self.dataset))
        except StopIteration:
            raise ValueError("Cannot detect the transcript column: the dataset is empty.") from None
        self._get_text(sample)
        return self.text

    def _get_text(self, sample):
        """Return the transcript of ``sample`` and record its column in ``self.text``.

        Raises:
            ValueError: if none of the known transcript columns is present.
        """
        for column in self._TEXT_COLUMNS:
            if column in sample:
                self.text = column
                return sample[column]
        raise ValueError(f"Sample: {sample.keys()} has no transcript.")

    def filter(self, input_column: str = None):
        """Drop samples whose transcript is empty or a scoring placeholder.

        Args:
            input_column: Transcript column to inspect. Defaults to the cached
                ``self.text``, detecting it from the first sample if unknown.

        Returns:
            The filtered dataset (also stored on ``self.dataset``).
        """
        if input_column is None:
            input_column = self.text if self.text is not None else self._check_text()

        def is_target_text_in_range(ref):
            stripped = ref.strip()
            # NOTE(review): "ignore time segment in scoring" looks like a
            # placeholder transcript used by some corpora (presumably TED-LIUM).
            return stripped != "" and stripped != "ignore time segment in scoring"

        self.dataset = self.dataset.filter(is_target_text_in_range, input_columns=[input_column])
        return self.dataset

    def normalised(self, normalise):
        """Apply ``normalise`` to every sample via ``map``; return the new dataset."""
        self.dataset = self.dataset.map(normalise)
        return self.dataset

    def _select(self, option: str):
        """Validate ``option`` and remember it in ``self.selected``."""
        if option not in self.options:
            raise ValueError(f"This value is not an option, please see: {self.options}")
        self.selected = option

    def _preprocess(self):
        """Keep the first ``n`` samples and cast ``audio`` to ``self.sampling_rate``."""
        self.dataset = self.dataset.take(self.n)
        self.dataset = self.dataset.cast_column("audio", Audio(sampling_rate=self.sampling_rate))

    def load(self, option: str = None):
        """Load ``option`` (streaming), then keep ``n`` samples and resample audio.

        For ``"OWN"`` the caller must assign ``self.dataset`` beforehand.

        Returns:
            The prepared dataset (also stored on ``self.dataset``).

        Raises:
            ValueError: for an unknown option, or ``"OWN"`` without a dataset.
        """
        self._select(option)
        # A different source may use a different transcript column.
        self.text = None

        if option == "OWN":
            if self.dataset is None:
                raise ValueError("Option 'OWN' requires `self.dataset` to be set before calling load().")
        else:
            args, kwargs = self._HUB_SOURCES[option]
            self.dataset = load_dataset(*args, streaming=True, **kwargs)

        self._preprocess()
        return self.dataset