| import re | |
| import xml.etree.ElementTree as ET | |
| from pathlib import Path | |
| from torch.utils.data.dataset import Dataset | |
| from torchaudio.sox_effects import apply_effects_file | |
class SWS2013Testset(Dataset):
    """SWS 2013 testset.

    Items are indexed with every query first, followed by every document
    (``data = query_names + audio_names``). Each item is loaded through sox,
    trimmed, and cut into fixed-size overlapping segments.

    Args:
        split: Which part of the testset to load, ``"dev"`` or ``"eval"``.
        **kwargs: Must contain ``sws2013_scoring_root`` (directory holding
            ``sws2013_{split}/`` with the ECF and tlist XML files) and
            ``sws2013_root`` (directory holding ``{split}_queries/`` and
            ``Audio/``).

    Raises:
        AssertionError: If ``split`` is not ``"dev"`` or ``"eval"``.
        KeyError: If a required keyword argument is missing.
    """

    # Target sampling rate (Hz) handed to the sox ``rate`` effect.
    SAMPLE_RATE = 16000
    # Samples per segment: 3 seconds at SAMPLE_RATE.
    SEGMENT_LENGTH = 48000
    # Samples between consecutive segment starts: 0.75 seconds at SAMPLE_RATE.
    SEGMENT_HOP = 12000

    def __init__(self, split, **kwargs):
        assert split in ["dev", "eval"], (
            f"split must be 'dev' or 'eval', got {split!r}"
        )

        scoring_dir = Path(kwargs["sws2013_scoring_root"]) / f"sws2013_{split}"
        audio_names = parse_ecf(scoring_dir / "sws2013.ecf.xml")
        query_names = parse_tlist(scoring_dir / f"sws2013_{split}.tlist.xml")

        self.dataset_root = Path(kwargs["sws2013_root"])
        self.split = split
        self.n_queries = len(query_names)
        self.n_docs = len(audio_names)
        # Queries come first; ``_locate`` relies on this ordering.
        self.data = query_names + audio_names

    def __len__(self):
        """Return the total number of queries plus documents."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load item ``idx`` and split it into overlapping segments.

        Args:
            idx: Position in ``self.data``; negative values count from the end.

        Returns:
            A tuple ``(segments, n_segments, audio_name)`` where ``segments``
            is a tuple of 1-D tensors of ``SEGMENT_LENGTH`` samples each.
        """
        audio_name, audio_path = self._locate(idx)
        wav, _ = apply_effects_file(
            str(audio_path),
            [
                ["channels", "1"],
                ["rate", str(self.SAMPLE_RATE)],
                ["norm"],
                # sox ``vad`` only trims from the front, so the signal is
                # reversed to trim the tail as well, then reversed back.
                ["vad", "-T", "0.25", "-p", "0.1"],
                ["reverse"],
                ["vad", "-T", "0.25", "-p", "0.1"],
                ["reverse"],
                # Append 3 s of silence (one segment length) so the tail is
                # covered and ``unfold`` always yields at least one segment.
                ["pad", "0", "3"],
            ],
        )
        segments = (
            wav.squeeze(0)
            .unfold(0, self.SEGMENT_LENGTH, self.SEGMENT_HOP)
            .unbind(0)
        )
        return segments, len(segments), audio_name

    def _locate(self, idx):
        """Return ``(audio_name, audio_path)`` for item ``idx``.

        Negative indices are normalized before the query/document decision;
        comparing a raw negative index against ``n_queries`` would send
        documents to the queries directory.

        Raises:
            IndexError: If ``idx`` is out of range for ``self.data``.
        """
        audio_name = self.data[idx]  # also validates the index range
        if idx < 0:
            idx += len(self.data)
        subdir = f"{self.split}_queries" if idx < self.n_queries else "Audio"
        audio_path = (self.dataset_root / subdir / audio_name).with_suffix(".wav")
        return audio_name, audio_path

    def collate_fn(self, samples):
        """Collate a mini-batch of data.

        Args:
            samples: Iterable of ``(segments, n_segments, audio_name)`` tuples
                as returned by ``__getitem__``.

        Returns:
            ``(flat_segments, (lengths, audio_names))`` where ``flat_segments``
            is a single list with the segments of all samples in order, and
            ``lengths`` / ``audio_names`` are per-sample tuples that allow the
            flat list to be regrouped.
        """
        segments, lengths, audio_names = zip(*samples)
        segments = [seg for segs in segments for seg in segs]
        return segments, (lengths, audio_names)
def parse_ecf(ecf_path):
    """Collect document audio names from an ECF XML file.

    Args:
        ecf_path: Path (string or ``Path``) of the ``sws2013.ecf.xml`` file.

    Returns:
        A list with one entry per ``<excerpt>`` child of the root element, in
        document order: its ``audio_filename`` attribute with every
        ``"Audio/"`` and ``".wav"`` substring removed.

    Raises:
        KeyError: If an ``<excerpt>`` has no ``audio_filename`` attribute.
    """
    tree = ET.parse(str(ecf_path))
    return [
        node.attrib["audio_filename"].replace("Audio/", "").replace(".wav", "")
        for node in tree.getroot().iterfind("excerpt")
    ]
def parse_tlist(tlist_path):
    """Collect query term ids from a tlist XML file.

    Args:
        tlist_path: Path (string or ``Path``) of the
            ``sws2013_{split}.tlist.xml`` file.

    Returns:
        A list with the ``termid`` attribute of every ``<term>`` child of the
        root element, in document order.

    Raises:
        KeyError: If a ``<term>`` has no ``termid`` attribute.
    """
    root = ET.parse(str(tlist_path)).getroot()
    return [term.attrib["termid"] for term in root.iterfind("term")]