File size: 9,222 Bytes
fbb1d85
aaef8e0
fbb1d85
aaef8e0
 
bdf49c6
aaef8e0
 
 
bdf49c6
aaef8e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bdf49c6
 
aaef8e0
 
bdf49c6
aaef8e0
 
bdf49c6
 
 
 
 
 
aaef8e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bdf49c6
aaef8e0
 
 
 
 
bdf49c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fbb1d85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bdf49c6
 
 
fbb1d85
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
import os
import json
import unittest
from pathlib import Path
from zipfile import ZipFile
from typing import List, Dict, Any, Union
from tempfile import TemporaryDirectory


def validate_zip(submission_track: str, submission_zip: Union[Path, str]):
    """
    Validates the submission format and contents
    Args:
        submission_track: the track of the submission, one of `NOTSOFAR-SC`, `NOTSOFAR-MC`,
            `DASR-Constrained-LM` or `DASR-Unconstrained-LM`
        submission_zip: path to the submission zip file
    Raises:
        ValueError: if the track is unknown, if the file is not a readable zip archive,
            or if the extracted submission is invalid

    """
    # Local import: only `ZipFile` is imported from `zipfile` at module level.
    from zipfile import BadZipFile

    notsofar_tracks = ['NOTSOFAR-SC', 'NOTSOFAR-MC']
    dasr_tracks = ['DASR-Constrained-LM', 'DASR-Unconstrained-LM']

    # Reject an unknown track before doing any (potentially expensive) extraction.
    if submission_track not in notsofar_tracks + dasr_tracks:
        raise ValueError(f'Invalid submission track: {submission_track}')

    with TemporaryDirectory() as temp_dir:
        try:
            with ZipFile(submission_zip, 'r') as submission_zip_file:
                submission_zip_file.extractall(temp_dir)
        except BadZipFile as err:
            # `BadZipFile` is not a `ValueError`; convert it so callers only need to
            # handle the documented exception type for a malformed submission.
            raise ValueError(f'Invalid submission zip file: {err}') from err

        submission_dir = Path(temp_dir)
        if submission_track in notsofar_tracks:
            validate_notsofar_submission(submission_dir=submission_dir)
        else:
            validate_dasr_submission(submission_dir=submission_dir)


def validate_notsofar_submission(submission_dir: Path):
    """
    Checks an extracted NOTSOFAR submission directory.

    The required hypothesis file must be present; the optional file is only
    validated when it exists. Every file that is present must be a json list
    whose entries carry all the expected fields.
    Args:
        submission_dir: path to the submission directory
    Raises:
        ValueError: if a required file is missing or a present file is invalid
    """
    entry_fields = ['session_id', 'words', 'speaker', 'start_time', 'end_time']

    # (file name, is it mandatory?) - required files are checked before optional ones
    expected_files = [
        ('tcp_wer_hyp.json', True),
        ('tc_orc_wer_ref.json', False),
    ]

    for name, mandatory in expected_files:
        candidate = submission_dir / name
        if candidate.exists():
            validate_json_file_structure(candidate, entry_fields)
        elif mandatory:
            raise ValueError(f'Missing {name}')


def validate_dasr_submission(submission_dir: Path):
    """
    Checks an extracted DASR submission directory.

    The submission must contain a `dev` entry holding one json file per
    dataset; each file must be a json list whose entries carry all the
    expected fields.
    Args:
        submission_dir: path to the submission directory
    Raises:
        ValueError: if `dev` or any dataset file is missing, or a file is invalid

    """
    entry_fields = ['session_id', 'words', 'speaker', 'start_time', 'end_time']

    dev_dir = submission_dir / 'dev'
    if not dev_dir.exists():
        raise ValueError('Missing `dev` directory, expecting a directory named `dev` with the submission files in it.')

    # Files are checked in this order; the first problem found is reported.
    for name in ('chime6.json', 'dipco.json', 'mixer6.json', 'notsofar1.json'):
        candidate = dev_dir / name
        if candidate.exists():
            validate_json_file_structure(candidate, entry_fields)
        else:
            raise ValueError(f'Missing {name}')


def validate_json_file_structure(file_path: Path, fields: List[str]):
    """
    Validates the structure of a json file
    Args:
        file_path: path to the json file
        fields: list of fields that are required in each entry
    Raises:
        ValueError: if the file is not valid UTF-8 json (decoding and parsing errors
            are `ValueError` subclasses), if the top-level value is not a list,
            if an entry is not a dictionary, or if an entry lacks a required field

    """
    # Read as UTF-8 explicitly so the result does not depend on the platform's
    # default text encoding.
    with open(file_path, 'r', encoding='utf-8') as json_file:
        json_data: Any = json.load(json_file)

    if not isinstance(json_data, list):
        raise ValueError(f'Invalid `{file_path.name}` format, expecting a list of entries')

    for data in json_data:
        # Without this check a numeric entry raises TypeError on the `in` test below,
        # and a list/str entry would be matched by element/substring instead of by key.
        if not isinstance(data, dict):
            raise ValueError(f'Invalid `{file_path.name}` format, expecting each entry to be a dictionary')
        if not all(field in data for field in fields):
            raise ValueError(f'Invalid `{file_path.name}` format, fields: {fields} are required in each entry')


####################################################################################################
# Tests
####################################################################################################

class TestValidateZip(unittest.TestCase):
    """Exercises `validate_zip` with in-memory generated zips for every supported track."""

    DATA_SAMPLES = 10

    # File sets reused by the DASR test cases
    _DASR_ALL_FILES = ['chime6.json', 'dipco.json', 'mixer6.json', 'notsofar1.json']
    _DASR_PARTIAL_FILES = ['chime6.json', 'dipco.json', 'mixer6.json']

    @classmethod
    def setUpClass(cls):
        # An entry carrying every required field
        complete_entry = {'session_id': 'session_id', 'words': 'words', 'speaker': 'speaker',
                          'start_time': 0.0, 'end_time': 1.0}
        # An entry lacking `speaker` and `end_time`
        partial_entry = {'session_id': 'session_id', 'words': 'words', 'start_time': 0.0}

        cls.valid_data = [dict(complete_entry) for _ in range(cls.DATA_SAMPLES)]
        cls.invalid_data = [dict(partial_entry) for _ in range(cls.DATA_SAMPLES)]

    def setUp(self):
        self.temp_dir = TemporaryDirectory()
        self.submission_zip = Path(self.temp_dir.name) / 'submission.zip'

    def tearDown(self):
        self.temp_dir.cleanup()

    def create_test_data(self, submission_track: str, data: List[Dict[str, Any]], json_file_names: List[str],
                         parent_zip_dir: str = None):
        """
        Writes `data` as json under each of `json_file_names` (optionally nested in
        `parent_zip_dir`) into the test zip and returns the arguments for `validate_zip`.
        """
        track_dir = Path(self.temp_dir.name) / submission_track
        os.makedirs(track_dir, exist_ok=True)

        payload = json.dumps(data)
        if parent_zip_dir:
            member_names = [str(Path(parent_zip_dir) / name) for name in json_file_names]
        else:
            member_names = list(json_file_names)

        with ZipFile(self.submission_zip, 'w') as archive:
            for member_name in member_names:
                archive.writestr(member_name, payload)
        return submission_track, self.submission_zip

    def _expect_accepted(self, *test_data_args):
        # A valid submission makes `validate_zip` return None
        self.assertIsNone(validate_zip(*self.create_test_data(*test_data_args)))

    def _expect_rejected(self, *test_data_args):
        # An invalid submission makes `validate_zip` raise ValueError
        with self.assertRaises(ValueError):
            validate_zip(*self.create_test_data(*test_data_args))

    # ---- NOTSOFAR-SC ----

    def test_NOTSOFAR_SC_valid_data_tcp(self):
        self._expect_accepted('NOTSOFAR-SC', self.valid_data, ['tcp_wer_hyp.json'])

    def test_NOTSOFAR_SC_valid_data_tcp_and_tcorc(self):
        self._expect_accepted('NOTSOFAR-SC', self.valid_data, ['tcp_wer_hyp.json', 'tc_orc_wer_ref.json'])

    def test_NOTSOFAR_SC_missing_tcp_file(self):
        self._expect_rejected('NOTSOFAR-SC', self.valid_data, ['tc_orc_wer_ref.json'])

    def test_NOTSOFAR_SC_invalid_data(self):
        self._expect_rejected('NOTSOFAR-SC', self.invalid_data, ['tcp_wer_hyp.json'])

    # ---- NOTSOFAR-MC ----

    def test_NOTSOFAR_MC_valid_data_tcp(self):
        self._expect_accepted('NOTSOFAR-MC', self.valid_data, ['tcp_wer_hyp.json'])

    def test_NOTSOFAR_MC_valid_data_tcp_and_tcorc(self):
        self._expect_accepted('NOTSOFAR-MC', self.valid_data, ['tcp_wer_hyp.json', 'tc_orc_wer_ref.json'])

    def test_NOTSOFAR_MC_missing_tcp_file(self):
        self._expect_rejected('NOTSOFAR-MC', self.valid_data, ['tc_orc_wer_ref.json'])

    def test_NOTSOFAR_MC_invalid_data(self):
        self._expect_rejected('NOTSOFAR-MC', self.invalid_data, ['tcp_wer_hyp.json'])

    # ---- DASR-Constrained-LM ----

    def test_DASR_Constrained_LM_valid_data(self):
        self._expect_accepted('DASR-Constrained-LM', self.valid_data, self._DASR_ALL_FILES, 'dev')

    def test_DASR_Constrained_LM_invalid_data(self):
        self._expect_rejected('DASR-Constrained-LM', self.invalid_data, self._DASR_ALL_FILES, 'dev')

    def test_DASR_Constrained_LM_missing_dev_dir(self):
        self._expect_rejected('DASR-Constrained-LM', self.valid_data, self._DASR_ALL_FILES)

    def test_DASR_Constrained_LM_missing_json_file(self):
        self._expect_rejected('DASR-Constrained-LM', self.valid_data, self._DASR_PARTIAL_FILES, 'dev')

    # ---- DASR-Unconstrained-LM ----

    def test_DASR_Unconstrained_LM_valid_data(self):
        self._expect_accepted('DASR-Unconstrained-LM', self.valid_data, self._DASR_ALL_FILES, 'dev')

    def test_DASR_Unconstrained_LM_invalid_data(self):
        self._expect_rejected('DASR-Unconstrained-LM', self.invalid_data, self._DASR_ALL_FILES, 'dev')

    def test_DASR_Unconstrained_LM_missing_dev_dir(self):
        self._expect_rejected('DASR-Unconstrained-LM', self.valid_data, self._DASR_ALL_FILES)

    def test_DASR_Unconstrained_LM_missing_json_file(self):
        self._expect_rejected('DASR-Unconstrained-LM', self.valid_data, self._DASR_PARTIAL_FILES, 'dev')


# Run the unit tests defined in this file when it is executed directly as a script.
if __name__ == '__main__':
    unittest.main()