# Source code for espnet2.main_funcs.collect_stats

from collections import defaultdict
import logging
from pathlib import Path
from typing import Dict
from typing import Iterable
from typing import List
from typing import Optional
from typing import Tuple

import numpy as np
import torch
from torch.nn.parallel import data_parallel
from torch.utils.data import DataLoader
from typeguard import check_argument_types

from espnet2.fileio.datadir_writer import DatadirWriter
from espnet2.fileio.npy_scp import NpyScpWriter
from espnet2.torch_utils.device_funcs import to_device
from espnet2.torch_utils.forward_adaptor import ForwardAdaptor
from espnet2.train.abs_espnet_model import AbsESPnetModel


def _write_shapes(datadir_writer, keys: List[str], batch: Dict[str, torch.Tensor]) -> None:
    """Record the un-padded shape of every non-``*_lengths`` entry in ``batch``.

    For each entry ``name`` (skipping ``*_lengths`` entries), writes
    ``"d0,d1,..."`` to ``datadir_writer[f"{name}_shape"][utt_id]``. When a
    matching ``f"{name}_lengths"`` entry exists, the first axis of each item
    is truncated to that length first so padding is not reported.
    """
    for name, values in batch.items():
        if name.endswith("_lengths"):
            continue
        lengths_key = f"{name}_lengths"
        lengths = batch[lengths_key] if lengths_key in batch else None
        for i, (utt_id, data) in enumerate(zip(keys, values)):
            if lengths is not None:
                data = data[: int(lengths[i])]
            datadir_writer[f"{name}_shape"][utt_id] = ",".join(map(str, data.shape))


def _extract_feats(model, batch: Dict[str, torch.Tensor], ngpu: int):
    """Call ``model.collect_feats(**batch)``, spreading it over GPUs if ``ngpu > 1``."""
    if ngpu <= 1:
        return model.collect_feats(**batch)
    # Note that data_parallel can parallelize only "forward()"; ForwardAdaptor
    # is presumably what routes forward() to collect_feats — see its definition.
    return data_parallel(
        ForwardAdaptor(model, "collect_feats"),
        (),
        range(ngpu),
        module_kwargs=batch,
    )


def _write_stats(
    mode_dir: Path,
    batch_keys: List[str],
    sum_dict: Dict[str, object],
    sq_dict: Dict[str, object],
    count_dict: Dict[str, int],
) -> None:
    """Dump per-key statistics and the key listings for one mode directory.

    Writes ``<key>_stats.npz`` (with ``count``, ``sum``, ``sum_square``) for
    every accumulated key, plus the ``batch_keys`` and ``stats_keys`` files.
    """
    # Ensure the directory exists even if the iterator yielded nothing.
    mode_dir.mkdir(parents=True, exist_ok=True)
    for key in sum_dict:
        np.savez(
            mode_dir / f"{key}_stats.npz",
            count=count_dict[key],
            sum=sum_dict[key],
            sum_square=sq_dict[key],
        )

    # batch_keys and stats_keys are used by aggregate_stats_dirs.py
    with (mode_dir / "batch_keys").open("w", encoding="utf-8") as f:
        f.write("\n".join(batch_keys) + "\n")
    with (mode_dir / "stats_keys").open("w", encoding="utf-8") as f:
        f.write("\n".join(sum_dict) + "\n")


@torch.no_grad()
def collect_stats(
    model: AbsESPnetModel,
    train_iter: DataLoader and Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
    valid_iter: DataLoader and Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
    output_dir: Path,
    ngpu: Optional[int],
    log_interval: Optional[int],
    write_collected_feats: bool,
) -> None:
    """Run in collect_stats mode.

    Iterates over the training and validation data, deriving the shape
    information of every input and gathering per-feature statistics
    (count, sum, sum of squares). This is meant to be executed before train().

    Args:
        model: Model whose ``collect_feats`` produces the features to summarize.
        train_iter: Iterable of ``(utterance_ids, batch_dict)`` for training.
        valid_iter: Iterable of ``(utterance_ids, batch_dict)`` for validation.
        output_dir: Root directory; results go to ``output_dir/train`` and
            ``output_dir/valid``.
        ngpu: Number of GPUs; ``None`` or 0 means CPU.
        log_interval: Log every this many batches. When ``None`` it is derived
            separately for each iterator from its length (or 100 if it has none).
        write_collected_feats: If True, also dump each utterance's features as
            npy files under ``output_dir/<mode>/collect_feats``.
    """
    assert check_argument_types()

    # ngpu is typed Optional; treat None as "no GPU" rather than crashing on
    # the comparisons below.
    if ngpu is None:
        ngpu = 0
    device = "cuda" if ngpu > 0 else "cpu"

    # Keys of the last processed batch. Kept across modes so an empty
    # valid_iter still records the training keys, and an empty train_iter no
    # longer hits an unbound variable.
    batch_keys: List[str] = []

    for itr, mode in zip([train_iter, valid_iter], ["train", "valid"]):
        # Derive the interval per iterator; reassigning ``log_interval`` itself
        # would make "valid" inherit the value computed from the train length.
        interval = log_interval
        if interval is None:
            try:
                interval = max(len(itr) // 20, 10)
            except TypeError:
                # The iterable has no __len__.
                interval = 100

        sum_dict = defaultdict(int)
        sq_dict = defaultdict(int)
        count_dict = defaultdict(int)
        npy_scp_writers = {}

        try:
            with DatadirWriter(output_dir / mode) as datadir_writer:
                for iiter, (keys, batch) in enumerate(itr, 1):
                    batch = to_device(batch, device)
                    batch_keys = [k for k in batch if not k.endswith("_lengths")]

                    # 1. Write shape file
                    _write_shapes(datadir_writer, keys, batch)

                    # 2. Extract feats
                    feats = _extract_feats(model, batch, ngpu)

                    # 3. Calculate sum and square sum
                    for key, v in feats.items():
                        lengths_key = f"{key}_lengths"
                        lengths = (
                            feats[lengths_key].cpu().numpy()
                            if lengths_key in feats
                            else None
                        )
                        for i, (uttid, seq) in enumerate(zip(keys, v.cpu().numpy())):
                            if lengths is not None:
                                # seq: (Length, Dim, ...); truncate zero-padding
                                seq = seq[: int(lengths[i])]
                            else:
                                # seq: (Dim, ...) -> (1, Dim, ...)
                                seq = seq[None]
                            # Accumulate value, its square, and count
                            sum_dict[key] += seq.sum(0)
                            sq_dict[key] += (seq**2).sum(0)
                            count_dict[key] += len(seq)

                            # 4. [Option] Write derived features as npy format file.
                            if write_collected_feats:
                                writer = npy_scp_writers.get(key)
                                if writer is None:
                                    p = output_dir / mode / "collect_feats"
                                    writer = NpyScpWriter(
                                        p / f"data_{key}", p / f"{key}.scp"
                                    )
                                    npy_scp_writers[key] = writer
                                # Save array as npy file
                                writer[uttid] = seq

                    if iiter % interval == 0:
                        logging.info(f"Niter: {iiter}")
        finally:
            # Release the scp file handles so the listings are flushed to disk.
            for writer in npy_scp_writers.values():
                writer.close()

        _write_stats(output_dir / mode, batch_keys, sum_dict, sq_dict, count_dict)