# Source code for espnet2.fileio.rttm

import collections.abc
from pathlib import Path
from typing import Dict
from typing import List
from typing import Tuple
from typing import Union

import numpy as np
import re
from typeguard import check_argument_types


def load_rttm_text(
    path: Union[Path, str]
) -> Dict[str, Tuple[List[str], List[Tuple[str, int, int]], int]]:
    """Read an RTTM file.

    Each non-blank line must contain exactly 9 whitespace-separated fields::

        <type> <utt_id> <channel> <start> <end> <NA> <NA> <spk_id> <NA>

    Note: only speaker information is supported now, i.e. the record type
    must be ``SPEAKER`` or ``END``. For both record types the 5th field is
    read as the end position; for ``END`` it is stored as the maximum
    duration of the recording.

    Args:
        path: Path of the RTTM file to read (opened as UTF-8 text).

    Returns:
        A dict mapping each utterance id to a tuple
        ``(spk_list, spk_event, max_duration)``:

        * ``spk_list``: speaker ids in order of first appearance.
        * ``spk_event``: ``(spk_id, start, end)`` integer triples in file
          order.
        * ``max_duration``: end value of the last ``END`` record seen for
          the utterance (0 if there is none).

    Raises:
        AssertionError: If a non-blank line does not have exactly 9 fields,
            or its record type is neither ``SPEAKER`` nor ``END``.
        ValueError: If a start/end field cannot be parsed as a number.
    """
    assert check_argument_types()

    data = {}
    with Path(path).open("r", encoding="utf-8") as f:
        for linenum, line in enumerate(f, 1):
            # str.split() with no separator tolerates tabs, repeated spaces
            # and leading/trailing whitespace.
            sps = line.split()
            if not sps:
                # Blank line (e.g. trailing newline at end of file).
                continue

            # RTTM format must have exactly 9 fields
            assert len(sps) == 9, "{}:{} does not have exactly 9 fields".format(
                path, linenum
            )
            label_type, utt_id, _channel, start, end, _, _, spk_id, _ = sps

            # Only support speaker label now
            assert label_type in ("SPEAKER", "END"), (
                "{}:{} has unsupported label type '{}'".format(
                    path, linenum, label_type
                )
            )

            spk_list, spk_event, max_duration = data.get(utt_id, ([], [], 0))

            # Parse via float so that values such as "4023.0" are accepted
            # for END records exactly as they are for SPEAKER records.
            end_pos = int(float(end))

            if label_type == "END":
                data[utt_id] = (spk_list, spk_event, end_pos)
                continue

            if spk_id not in spk_list:
                spk_list.append(spk_id)

            # Append in place instead of rebuilding the list for every line.
            spk_event.append((spk_id, int(float(start)), end_pos))
            data[utt_id] = (spk_list, spk_event, max_duration)

    return data
class RttmReader(collections.abc.Mapping):
    """Reader class for 'rttm.scp'.

    Examples:
        SPEAKER file1 1 0 1023 <NA> <NA> spk1 <NA>
        SPEAKER file1 2 2000 3023 <NA> <NA> spk2 <NA>
        SPEAKER file1 3 500 4022 <NA> <NA> spk1 <NA>
        END     file1 <NA> <NA> 4023 <NA> <NA> <NA> <NA>

        This is an extended version of the standard RTTM format for espnet.
        The differences are:
        1. Use sample number instead of absolute time
        2. Has an END label to represent the duration of a recording
        3. Replace duration (5th field) with end time
        (For standard RTTM,
            see https://catalog.ldc.upenn.edu/docs/LDC2004T12/RTTM-format-v13.pdf)

        Every line has 9 fields; for END lines the duration is read from
        the 5th field, like the end time of SPEAKER lines.

        ...

        >>> reader = RttmReader('rttm')
        >>> spk_label = reader["file1"]

    Attributes are plain instance fields:
        fname: the path given to the constructor.
        data: mapping ``utt_id -> (spk_list, spk_event, max_duration)`` as
            returned by ``load_rttm_text``.
    """

    def __init__(
        self,
        fname: str,
    ):
        """Load the whole RTTM file ``fname`` into memory."""
        assert check_argument_types()
        super().__init__()

        self.fname = fname
        self.data = load_rttm_text(path=fname)

    def __getitem__(self, key):
        """Return the speaker-activity matrix for utterance ``key``.

        Returns:
            A ``numpy`` array of zeros with shape
            ``(max_duration, len(spk_list))`` in which
            ``[start : end + 1, speaker_column]`` is set to 1 for every
            speaker event (the end position is treated as inclusive).

        Raises:
            KeyError: If ``key`` is not a known utterance id.
        """
        spk_list, spk_event, max_duration = self.data[key]
        spk_label = np.zeros((max_duration, len(spk_list)))
        for spk_id, start, end in spk_event:
            # Slicing clips silently if an event runs past max_duration.
            spk_label[start : end + 1, spk_list.index(spk_id)] = 1
        return spk_label

    def __contains__(self, item):
        """Return True if ``item`` is a known utterance id."""
        # Previously this returned ``item`` itself, so any truthy value was
        # reported as present regardless of the loaded data.
        return item in self.data

    def __len__(self):
        """Return the number of utterances."""
        return len(self.data)

    def __iter__(self):
        """Iterate over utterance ids."""
        return iter(self.data)

    def keys(self):
        """Return a view of the utterance ids."""
        return self.data.keys()