import collections.abc
import re
from pathlib import Path
from typing import Dict
from typing import List
from typing import Tuple
from typing import Union

import numpy as np
from typeguard import check_argument_types


def load_rttm_text(
    path: Union[Path, str],
) -> Dict[str, Tuple[List[str], List[Tuple[str, int, int]], int]]:
    """Read an (ESPnet-extended) RTTM file.

    Note: only speaker information is supported now.

    Every non-blank line must have exactly 9 whitespace-separated fields.
    The 1st (label type), 2nd (recording id), 4th (start), 5th (end) and
    8th (speaker id) fields are used; the others are ignored. ``SPEAKER``
    lines add a speaker segment, and an ``END`` line sets the recording
    length from its 5th field.

    Args:
        path: Path to the RTTM file.

    Returns:
        A dict mapping each recording id to a tuple
        ``(spk_list, spk_event, max_duration)``:

        * ``spk_list``: speaker ids in order of first appearance.
        * ``spk_event``: ``(spk_id, start, end)`` tuples in file order, with
          start/end converted via ``int(float(...))``.
        * ``max_duration``: the 5th field of the ``END`` line for that
          recording, or 0 if the recording has no ``END`` line.

    Raises:
        AssertionError: If a line does not have exactly 9 fields, or its
            label type is neither ``SPEAKER`` nor ``END``.
        ValueError: If a start/end field that is read is not numeric.
    """
    assert check_argument_types()
    data = {}
    with Path(path).open("r", encoding="utf-8") as f:
        for linenum, line in enumerate(f, 1):
            stripped = line.strip()
            # Tolerate blank lines (e.g. a trailing empty line at end of file).
            if not stripped:
                continue
            sps = re.split(r"\s+", stripped)

            # RTTM format must have exactly 9 fields
            assert len(sps) == 9, "{}:{}: expected exactly 9 fields, got {}".format(
                path, linenum, len(sps)
            )
            label_type, utt_id, _, start, end, _, _, spk_id, _ = sps

            # Only support speaker label now
            assert label_type in ("SPEAKER", "END"), (
                "{}:{}: unsupported label type {!r} "
                "(expected 'SPEAKER' or 'END')".format(path, linenum, label_type)
            )

            spk_list, spk_event, max_duration = data.get(utt_id, ([], [], 0))
            if label_type == "END":
                # The END line carries the recording length in its 5th field;
                # parse it the same way as SPEAKER start/end values.
                data[utt_id] = (spk_list, spk_event, int(float(end)))
                continue

            if spk_id not in spk_list:
                spk_list.append(spk_id)
            # Append in place; concatenating a new list per line is O(n^2).
            spk_event.append((spk_id, int(float(start)), int(float(end))))
            data[utt_id] = (spk_list, spk_event, max_duration)

    return data


class RttmReader(collections.abc.Mapping):
    """Reader class for 'rttm.scp'.

    Examples:
        SPEAKER file1 1 0 1023 <NA> <NA> spk1 <NA>
        SPEAKER file1 2 2000 3023 <NA> <NA> spk2 <NA>
        SPEAKER file1 3 500 4023 <NA> <NA> spk1 <NA>
        END     file1 <NA> <NA> 4023 <NA> <NA> <NA> <NA>

        This is an extended version of the standard RTTM format for espnet.
        The differences are:
            1. Sample numbers are used instead of absolute times.
            2. An END label gives the duration of a recording.
            3. The duration (5th field) is replaced with the end time.
        (For standard RTTM, see https://catalog.ldc.upenn.edu/docs/LDC2004T12/RTTM-format-v13.pdf)
        ...
        >>> reader = RttmReader('rttm')
        >>> spk_label = reader["file1"]

    Looking up a recording id returns a float ``numpy`` array of shape
    ``(max_duration, num_speakers)`` holding 1 where a speaker is active and
    0 elsewhere. Columns follow the order in which speakers first appear in
    the file. A recording without an ``END`` line has ``max_duration == 0``,
    so its label array has zero rows.
    """

    def __init__(
        self,
        fname: Union[Path, str],
    ):
        """Parse the whole RTTM file eagerly.

        Args:
            fname: Path to the RTTM file.
        """
        assert check_argument_types()
        super().__init__()
        self.fname = fname
        self.data = load_rttm_text(path=fname)

    def __getitem__(self, key):
        """Build the frame-level speaker activity matrix for recording ``key``.

        Raises:
            KeyError: If ``key`` is not a recording id in the file.
        """
        spk_list, spk_event, max_duration = self.data[key]
        # Map speaker id -> column once instead of list.index() per event.
        spk_index = {spk_id: col for col, spk_id in enumerate(spk_list)}
        spk_label = np.zeros((max_duration, len(spk_list)))
        for spk_id, start, end in spk_event:
            # `end` is treated as inclusive; NumPy clips slices that run past
            # max_duration, and a slice with start > end + 1 is empty.
            spk_label[start : end + 1, spk_index[spk_id]] = 1
        return spk_label

    def __contains__(self, item):
        # Check membership without building the label array.
        return item in self.data

    def __len__(self):
        return len(self.data)

    def __iter__(self):
        return iter(self.data)

    def keys(self):
        return self.data.keys()