File size: 5,704 Bytes
7934b29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This script is based on speech_to_text_eval.py and allows you to score the hypotheses
with sclite. A local installation from https://github.com/usnistgov/SCTK is required.
Hypotheses and references are first saved in trn format and are scored after applying a glm
file (if provided).

# Usage

python speech_to_text_sclite.py \
    --asr_model="<Path to ASR Model>" \
    --dataset="<Path to manifest file>" \
    --out_dir="<Path to output dir, should be unique per model evaluated>" \
    --sctk_dir="<Path to root directory where SCTK is installed>" \
    --glm="<OPTIONAL: Path to glm file>" \
    --batch_size=4

"""

import errno
import json
import os
import subprocess
from argparse import ArgumentParser

import torch

from nemo.collections.asr.models import ASRModel
from nemo.collections.asr.parts.utils.manifest_utils import read_manifest
from nemo.utils import logging

# Use torch's CUDA automatic-mixed-precision context manager when it can be imported.
try:
    from torch.cuda.amp import autocast
except ImportError:
    from contextlib import contextmanager

    # Fallback: a do-nothing context manager with the same name, so that
    # `with autocast():` keeps working when torch.cuda.amp is unavailable.
    # The `enabled` argument is accepted and ignored.
    @contextmanager
    def autocast(enabled=None):
        yield


def score_with_sctk(sctk_dir, ref_fname, hyp_fname, out_dir, glm=""):
    """Score a hypothesis trn file against a reference trn file with SCTK's sclite.

    If ``glm`` points to an existing file, both transcripts are first piped
    through ``rfilter1`` and the filtered copies are written to ``out_dir`` as
    ``<basename>.glm``; those filtered copies are then scored. Otherwise the
    input files are scored unchanged.

    Args:
        sctk_dir: SCTK root directory; binaries are looked up under ``bin/``.
        ref_fname: Path to the reference transcripts (trn format).
        hyp_fname: Path to the hypothesis transcripts (trn format).
        out_dir: Directory that receives the glm-filtered transcripts.
        glm: Optional path to a glm file. An empty or non-existent path
            disables filtering.

    Raises:
        FileNotFoundError: If ``sclite`` (or ``rfilter1`` when a glm file is
            used) is missing under ``<sctk_dir>/bin``.
        subprocess.CalledProcessError: If ``rfilter1`` or ``sclite`` exits
            with a non-zero status.
    """
    bin_dir = os.path.join(sctk_dir, "bin")
    sclite_path = os.path.join(bin_dir, "sclite")
    if not os.path.exists(sclite_path):
        raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), sclite_path)

    # apply glm
    if glm and os.path.exists(glm):
        rfilter_path = os.path.join(bin_dir, "rfilter1")
        if not os.path.exists(rfilter_path):
            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), rfilter_path)
        rfilt_cmd = [rfilter_path, glm]
        hypglm = os.path.join(out_dir, os.path.basename(hyp_fname)) + ".glm"
        refglm = os.path.join(out_dir, os.path.basename(ref_fname)) + ".glm"
        for src_fname, dst_fname in ((hyp_fname, hypglm), (ref_fname, refglm)):
            with open(src_fname, "r", encoding='utf-8') as src, open(dst_fname, "w", encoding='utf-8') as dst:
                # check=True: a failed filter must not leave a truncated file
                # that would then be scored as if it were valid.
                subprocess.run(rfilt_cmd, stdin=src, stdout=dst, check=True)
    else:
        refglm = ref_fname
        hypglm = hyp_fname

    # Pass an argument list instead of a shell string so that paths containing
    # spaces or shell metacharacters reach sclite verbatim.
    sclite_cmd = [sclite_path, "-h", hypglm, "-r", refglm, "-i", "wsj", "-o", "all"]
    subprocess.check_output(sclite_cmd)


# Evaluated once at import time; main() uses it to decide whether to move the model to the GPU.
can_gpu = torch.cuda.is_available()


def get_utt_info(manifest_path):
    """Load every utterance record from a JSON-lines manifest.

    Args:
        manifest_path: Path to a UTF-8 text file holding one JSON value per line.

    Returns:
        A list with one parsed entry per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(manifest_path, "r", encoding='utf-8') as utt_f:
        # Whitespace-only lines (e.g. a stray empty line at EOF) carry no
        # record and would otherwise make json.loads raise.
        return [json.loads(line) for line in utt_f if line.strip()]


def main():
    """Transcribe a manifest with an ASR model, write trn files and optionally score them.

    Steps:
        1. Parse the command-line arguments.
        2. Load the ASR model, either from a local ``.nemo`` file or by
           pretrained model name, and move it to the GPU when one is available.
        3. Transcribe every ``audio_filepath`` listed in the manifest.
        4. Write ``hyp.trn`` and ``ref.trn`` into ``--out_dir``; each line is
           `` <transcript> (<utt_id>)`` where ``utt_id`` is the audio file's
           basename without its extension.
        5. If ``--sctk_dir`` exists, score the two files via ``score_with_sctk``.
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--asr_model", type=str, default="QuartzNet15x5Base-En", required=False, help="Pass: 'QuartzNet15x5Base-En'",
    )
    parser.add_argument("--dataset", type=str, required=True, help="path to evaluation data")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--out_dir", type=str, required=True, help="Destination dir for output files")
    parser.add_argument("--sctk_dir", type=str, required=False, default="", help="Path to sctk root dir")
    parser.add_argument("--glm", type=str, required=False, default="", help="Path to glm file")
    args = parser.parse_args()
    torch.set_grad_enabled(False)

    os.makedirs(args.out_dir, exist_ok=True)

    use_sctk = os.path.exists(args.sctk_dir)
    if args.sctk_dir and not use_sctk:
        # Scoring was requested but cannot run; say so instead of silently skipping it.
        logging.warning(f"SCTK directory {args.sctk_dir} does not exist, sclite scoring will be skipped")

    if args.asr_model.endswith('.nemo'):
        logging.info(f"Using local ASR model from {args.asr_model}")
        asr_model = ASRModel.restore_from(restore_path=args.asr_model, map_location='cpu')
    else:
        logging.info(f"Using NGC cloud ASR model {args.asr_model}")
        asr_model = ASRModel.from_pretrained(model_name=args.asr_model, map_location='cpu')

    if can_gpu:
        asr_model = asr_model.cuda()

    asr_model.eval()

    manifest_data = read_manifest(args.dataset)

    references = [data['text'] for data in manifest_data]
    audio_filepaths = [data['audio_filepath'] for data in manifest_data]

    with autocast():
        hypotheses = asr_model.transcribe(audio_filepaths, batch_size=args.batch_size)

        # if transcriptions form a tuple (from RNNT), extract just "best" hypothesis
        if isinstance(hypotheses, tuple) and len(hypotheses) == 2:
            hypotheses = hypotheses[0]

    hypfile = os.path.join(args.out_dir, "hyp.trn")
    reffile = os.path.join(args.out_dir, "ref.trn")
    # Write as UTF-8 explicitly: score_with_sctk reads these files back as UTF-8,
    # so relying on the locale's default encoding could corrupt non-ASCII text.
    with open(hypfile, "w", encoding='utf-8') as hyp_f, open(reffile, "w", encoding='utf-8') as ref_f:
        for hypothesis, reference, audio_filepath in zip(hypotheses, references, audio_filepaths):
            utt_id = os.path.splitext(os.path.basename(audio_filepath))[0]
            # rfilter in sctk likes each transcript to have a space at the beginning
            hyp_f.write(" " + hypothesis + " (" + utt_id + ")" + "\n")
            ref_f.write(" " + reference + " (" + utt_id + ")" + "\n")

    if use_sctk:
        score_with_sctk(args.sctk_dir, reffile, hypfile, args.out_dir, glm=args.glm)


if __name__ == '__main__':
    main()  # noqa pylint: disable=no-value-for-parameter