|
import collections.abc |
|
from pathlib import Path |
|
from typing import Dict |
|
from typing import List |
|
from typing import Tuple |
|
from typing import Union |
|
|
|
import numpy as np |
|
import re |
|
from typeguard import check_argument_types |
|
|
|
|
|
def load_rttm_text(
    path: Union[Path, str]
) -> Dict[str, Tuple[List[str], List[Tuple[str, int, int]], int]]:
    """Read an RTTM file and group its records by recording id.

    Every non-blank line must consist of exactly 9 fields separated by one
    or more spaces::

        <label_type> <utt_id> <channel> <start> <end> <_> <_> <spk_id> <_>

    Note: only speaker information is supported, i.e. ``label_type`` must be
    ``SPEAKER`` or ``END``. For an ``END`` record only ``utt_id`` and the
    5th field (``end``) are used; it sets the recording's duration.

    Args:
        path: Location of the RTTM text file (read as UTF-8).

    Returns:
        A dict mapping each ``utt_id`` to a 3-tuple
        ``(spk_list, spk_event, max_duration)`` where

        * ``spk_list`` is the list of speaker ids in order of first appearance,
        * ``spk_event`` is a list of ``(spk_id, start, end)`` tuples with
          ``start``/``end`` truncated to ``int``,
        * ``max_duration`` is the value of the last ``END`` record seen for
          that ``utt_id`` (``0`` if there is none).

    Raises:
        AssertionError: If a line does not have exactly 9 fields or its
            label type is neither ``SPEAKER`` nor ``END``.
        ValueError: If a ``start``/``end`` field that is used cannot be
            parsed as a number.
    """
    assert check_argument_types()

    data = {}
    with Path(path).open("r", encoding="utf-8") as f:
        for linenum, line in enumerate(f, 1):
            line = line.rstrip()
            # Tolerate blank lines (e.g. an extra newline at the end of file)
            # instead of failing the field-count check on them.
            if not line:
                continue

            sps = re.split(" +", line)

            # RTTM format must have exactly 9 fields.
            # Raised explicitly rather than via ``assert`` so the validation
            # is not stripped under ``python -O``; the exception type is kept.
            if len(sps) != 9:
                raise AssertionError(
                    "{} does not have exactly 9 fields (line {}: found {})".format(
                        path, linenum, len(sps)
                    )
                )
            label_type, utt_id, _channel, start, end, _, _, spk_id, _ = sps

            # Only speaker labels (plus the END extension) are supported.
            if label_type not in ("SPEAKER", "END"):
                raise AssertionError(
                    "{}: unsupported label type '{}' at line {}".format(
                        path, label_type, linenum
                    )
                )

            spk_list, spk_event, max_duration = data.get(utt_id, ([], [], 0))
            if label_type == "END":
                # Parsed the same way as SPEAKER times so "4023" and "4023.0"
                # are both accepted.
                max_duration = int(float(end))
            else:
                if spk_id not in spk_list:
                    spk_list.append(spk_id)
                # Append in place; copying the list for every record would be
                # quadratic in the number of events.
                spk_event.append((spk_id, int(float(start)), int(float(end))))
            data[utt_id] = (spk_list, spk_event, max_duration)

    return data
|
|
|
|
|
class RttmReader(collections.abc.Mapping):
    """Reader class for 'rttm.scp'.

    Examples:
        SPEAKER file1 1 0 1023 <NA> <NA> spk1 <NA>
        SPEAKER file1 2 4000 3023 <NA> <NA> spk2 <NA>
        SPEAKER file1 3 500 4023 <NA> <NA> spk1 <NA>
        END file1 <NA> <NA> 4023 <NA> <NA> <NA> <NA>

        This is an extended version of the standard RTTM format for espnet.
        The differences are:
        1. Sample numbers are used instead of absolute time
        2. An END label represents the duration of a recording
        3. The duration (5th field) is replaced with the end time
        (For standard RTTM,
            see https://catalog.ldc.upenn.edu/docs/LDC2004T12/RTTM-format-v13.pdf)
        ...

        >>> reader = RttmReader('rttm')
        >>> spk_label = reader["file1"]

    """

    def __init__(
        self,
        fname: Union[Path, str],
    ):
        """Load the RTTM file ``fname`` into memory.

        Args:
            fname: Path of the RTTM file; forwarded to ``load_rttm_text``,
                which accepts both ``str`` and ``Path``.
        """
        assert check_argument_types()
        super().__init__()

        self.fname = fname
        # utt_id -> (spk_list, spk_event, max_duration)
        self.data = load_rttm_text(path=fname)

    def __getitem__(self, key):
        """Return the speaker-activity matrix for recording ``key``.

        Returns:
            A float array of shape ``(max_duration, len(spk_list))`` that is
            1 where the speaker of that column is active and 0 elsewhere.
            The ``end`` index of every event is inclusive.

        Raises:
            KeyError: If ``key`` is not a known recording id.
        """
        spk_list, spk_event, max_duration = self.data[key]
        spk_label = np.zeros((max_duration, len(spk_list)))
        for spk_id, start, end in spk_event:
            # ``end + 1`` because the stored end index is treated as inclusive.
            spk_label[start : end + 1, spk_list.index(spk_id)] = 1
        return spk_label

    def __contains__(self, item):
        """Return True if ``item`` is a recording id present in the file."""
        # Must be a real membership test: returning ``item`` itself made
        # ``x in reader`` true for every truthy ``x``.
        return item in self.data

    def __len__(self):
        """Return the number of recordings."""
        return len(self.data)

    def __iter__(self):
        """Iterate over recording ids."""
        return iter(self.data)

    def keys(self):
        """Return a view of the recording ids."""
        return self.data.keys()
|
|