|
import logging |
|
from pathlib import Path |
|
from typing import Sequence |
|
from typing import Union |
|
import warnings |
|
|
|
import torch |
|
from typeguard import check_argument_types |
|
from typing import Collection |
|
|
|
from espnet2.train.reporter import Reporter |
|
|
|
|
|
def _replace_symlink(link: Path, target_name: str) -> None:
    """Point ``link`` at ``target_name``, removing any existing entry first."""
    # ``exists()`` follows symlinks and is False for a dangling link, so
    # ``is_symlink()`` is checked too; otherwise a stale link would make
    # ``symlink_to`` fail.
    if link.is_symlink() or link.exists():
        link.unlink()
    link.symlink_to(target_name)


@torch.no_grad()
def average_nbest_models(
    output_dir: Path,
    reporter: Reporter,
    best_model_criterion: Sequence[Sequence[str]],
    nbest: Union[Collection[int], int],
) -> None:
    """Generate averaged model from n-best models

    For every criterion ``(phase, key, mode)`` for which ``reporter.has(phase,
    key)`` is true, the epochs returned by
    ``reporter.sort_epochs_and_values(phase, key, mode)`` are used (the first
    entry is treated as the best epoch) and, for each requested ``n``:

    * ``n == 1``: ``<phase>.<key>.ave_1best.pth`` becomes a symlink to the
      best ``<epoch>epoch.pth`` (no averaging needed).
    * ``n >= 2``: the first ``n`` checkpoints are loaded on CPU, summed
      entry by entry and divided by ``n``; the result is saved as
      ``<phase>.<key>.ave_<n>best.pth``. Entries whose dtype string starts
      with ``torch.int`` are summed but NOT divided (presumably integer
      counters such as batch-norm step counts -- confirm with the model).

    Finally ``<phase>.<key>.ave.pth`` is made a symlink to the file of the
    largest usable ``n``.

    Args:
        output_dir: The directory contains the model file for each epoch
            (``<epoch>epoch.pth``); all outputs are written here too.
        reporter: Reporter instance
        best_model_criterion: Give criterions to decide the best model.
            e.g. [("valid", "loss", "min"), ("train", "acc", "max")]
        nbest: Number of best checkpoints to average, either a single int
            or a collection of ints. An empty collection emits a warning
            and falls back to ``[1]``. For a given criterion, values larger
            than the number of available epochs are dropped (falling back
            to ``[1]`` when nothing remains) and ``0`` is skipped.
    """
    # Function-scope import: ``copy`` is only needed here and the module's
    # import block is left untouched.
    import copy

    assert check_argument_types()
    nbests = [nbest] if isinstance(nbest, int) else list(nbest)
    if not nbests:
        warnings.warn("At least 1 nbest values are required")
        nbests = [1]

    # 1. Collect, per criterion, the (epoch, value) pairs of the best epochs.
    #    Only as many as the largest requested n are ever needed.
    max_nbest = max(nbests)
    nbest_epochs = [
        (ph, k, reporter.sort_epochs_and_values(ph, k, m)[:max_nbest])
        for ph, k, m in best_model_criterion
        if reporter.has(ph, k)
    ]

    # Cache of checkpoints already read from disk, keyed by epoch, shared
    # across all n and all criteria. Entries must never be modified.
    _loaded = {}
    for ph, cr, epoch_and_values in nbest_epochs:
        # Drop requests for more models than there are epochs.
        _nbests = [i for i in nbests if i <= len(epoch_and_values)]
        if not _nbests:
            _nbests = [1]

        for n in _nbests:
            if n == 0:
                continue

            if n == 1:
                # The averaged model is the same as the best model.
                e, _ = epoch_and_values[0]
                op = output_dir / f"{e}epoch.pth"
                _replace_symlink(output_dir / f"{ph}.{cr}.ave_1best.pth", op.name)
                continue

            op = output_dir / f"{ph}.{cr}.ave_{n}best.pth"
            logging.info(
                f"Averaging {n}best models: " f'criterion="{ph}.{cr}": {op}'
            )

            # 2.a. Accumulate the n best checkpoints.
            avg = None
            for e, _ in epoch_and_values[:n]:
                if e not in _loaded:
                    _loaded[e] = torch.load(
                        output_dir / f"{e}epoch.pth",
                        map_location="cpu",
                    )
                states = _loaded[e]

                if avg is None:
                    # Shallow-copy so the accumulation below rebinds keys in
                    # a private mapping. Using the cached dict itself would
                    # overwrite ``_loaded[e]`` with the running sum/average
                    # and corrupt every later average that reuses this epoch
                    # (another n, or another criterion).
                    avg = copy.copy(states)
                else:
                    for k in avg:
                        # Out-of-place add: cached tensors stay untouched.
                        avg[k] = avg[k] + states[k]

            for k in avg:
                if str(avg[k].dtype).startswith("torch.int"):
                    # Integer entries keep their accumulated sum.
                    continue
                avg[k] = avg[k] / n

            # 2.b. Save the averaged model.
            torch.save(avg, op)

        # 3. *.*.ave.pth is a symlink to the model averaged over the most
        #    epochs. Skip it when no positive n was requested: nothing was
        #    produced above, so the link would be dangling.
        largest = max(_nbests)
        if largest > 0:
            _replace_symlink(
                output_dir / f"{ph}.{cr}.ave.pth",
                f"{ph}.{cr}.ave_{largest}best.pth",
            )
|
|