# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
# This source code is licensed under the license found in the
# MIT_LICENSE file in the root directory of this source tree.
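
r"""Streaming evaluation of Seamless UnitY models with SimulEval.

Picks the streaming agent for the requested task (``s2st``, ``s2tt`` or ``asr``),
assembles default model and evaluation configs, and hands them to
``simuleval.cli.evaluate``.

Example (a sketch only: the module path is assumed, and every flag except
``--task`` and ``--expressive`` comes from SimulEval or the fairseq2 dataloader
and may differ by version)::

    python -m seamless_communication.cli.streaming.evaluate \
        --task s2tt --tgt-lang spa \
        --data-file <manifest.tsv> --audio-root-dir <audio_root> --output <out_dir>
"""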

import argparse
import logging

from fairseq2.assets import asset_store, download_manager
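# Importing SeamlessQualityScorer registers the SEAMLESS_QUALITY_SCORER quality
# metric with SimulEval so it can be referenced by name in `eval_configs` below.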
from seamless_communication.cli.streaming.scorers.seamless_quality_scorer import (
    SeamlessQualityScorer,
)
from seamless_communication.streaming.agents.seamless_s2st import SeamlessS2STAgent
from seamless_communication.streaming.agents.seamless_streaming_s2st import (
    SeamlessStreamingS2STAgent,
)
from seamless_communication.streaming.agents.seamless_streaming_s2t import (
    SeamlessStreamingS2TAgent,
)
from simuleval.cli import evaluate

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
)

logger = logging.getLogger(__name__)


def main() -> None:
    # SimulEval and the selected agent add their own arguments to this parser later,
    # so built-in help is disabled and duplicate definitions resolve to the latest one.
    parser = argparse.ArgumentParser(
        add_help=False,
        description="Streaming evaluation of Seamless UnitY models",
        conflict_handler="resolve",
    )

    parser.add_argument(
        "--task",
        choices=["s2st", "s2tt", "asr"],
        required=True,
        type=str,
        help="Evaluation task: speech-to-speech translation (s2st), "
        "speech-to-text translation (s2tt) or automatic speech recognition (asr).",
    )
    parser.add_argument(
        "--expressive",
        action="store_true",
        default=False,
        help="Expressive streaming S2ST inference",
    )

    args, _ = parser.parse_known_args()

    # Default model configuration shared by all tasks.
    model_configs = dict(
        source_segment_size=320,  # audio fed to the agent per step, in ms
        device="cuda:0",
        dtype="fp16",
        min_starting_wait_w2vbert=192,  # source to accumulate before decoding starts
        decision_threshold=0.5,  # write threshold of the simultaneous policy
        no_early_stop=True,
        max_len_a=0,  # max target length = max_len_a * source_length + max_len_b
        max_len_b=100,
    )

    eval_configs = dict(quality_metrics="SEAMLESS_QUALITY_SCORER")
    if args.task == "s2st":
        model_configs["min_unit_chunk_size"] = 50
        # Speech output: latency is reported as start/end offsets of the generated audio.
        eval_configs["latency_metrics"] = "StartOffset EndOffset"

        if args.expressive:
            agent_class = SeamlessS2STAgent
        else:
            agent_class = SeamlessStreamingS2STAgent
    elif args.task in ["s2tt", "asr"]:
        assert args.expressive is False, "S2TT/ASR inference cannot be expressive."
        agent_class = SeamlessStreamingS2TAgent
        parser.add_argument(
            "--unity-model-name",
            type=str,
            help="Unity model name.",
            default="seamless_streaming_unity",
        )
        args, _ = parser.parse_known_args()
        asset_card = asset_store.retrieve_card(name=args.unity_model_name)
        tokenizer_uri = asset_card.field("tokenizer").as_uri()
        tokenizer_path = download_manager.download_tokenizer(
            tokenizer_uri, asset_card.name, force=False, progress=True
        )
        eval_configs["latency_metrics"] = "AL LAAL"
        eval_configs["eval_latency_unit"] = "spm"
        eval_configs["eval_latency_spm_model"] = tokenizer_path

    base_config = dict(
        dataloader="fairseq2_s2tt",
        dataloader_class="seamless_communication.streaming.dataloaders.s2tt.SimulEvalSpeechToTextDataloader",
    )

    # Hand off to SimulEval: the merged config dict supplies default agent and
    # evaluation settings alongside any remaining command-line arguments.
    evaluate(agent_class, {**base_config, **model_configs, **eval_configs}, parser)


if __name__ == "__main__":
    main()