Upload seamless_communication/cli/streaming/evaluate.py with huggingface_hub
Browse files
seamless_communication/cli/streaming/evaluate.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# MIT_LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import argparse
|
8 |
+
import logging
|
9 |
+
|
10 |
+
from fairseq2.assets import asset_store, download_manager
|
11 |
+
from seamless_communication.cli.streaming.scorers.seamless_quality_scorer import (
|
12 |
+
SeamlessQualityScorer,
|
13 |
+
)
|
14 |
+
from seamless_communication.streaming.agents.seamless_s2st import SeamlessS2STAgent
|
15 |
+
from seamless_communication.streaming.agents.seamless_streaming_s2st import (
|
16 |
+
SeamlessStreamingS2STAgent,
|
17 |
+
)
|
18 |
+
from seamless_communication.streaming.agents.seamless_streaming_s2t import (
|
19 |
+
SeamlessStreamingS2TAgent,
|
20 |
+
)
|
21 |
+
from simuleval.cli import evaluate
|
22 |
+
|
23 |
+
logging.basicConfig(
|
24 |
+
level=logging.INFO,
|
25 |
+
format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
|
26 |
+
)
|
27 |
+
|
28 |
+
logger = logging.getLogger(__name__)
|
29 |
+
|
30 |
+
|
31 |
+
def main() -> None:
|
32 |
+
parser = argparse.ArgumentParser(
|
33 |
+
add_help=False,
|
34 |
+
description="Streaming evaluation of Seamless UnitY models",
|
35 |
+
conflict_handler="resolve",
|
36 |
+
)
|
37 |
+
|
38 |
+
parser.add_argument(
|
39 |
+
"--task",
|
40 |
+
choices=["s2st", "s2tt", "asr"],
|
41 |
+
required=True,
|
42 |
+
type=str,
|
43 |
+
help="Target language to translate/transcribe into.",
|
44 |
+
)
|
45 |
+
parser.add_argument(
|
46 |
+
"--expressive",
|
47 |
+
action="store_true",
|
48 |
+
default=False,
|
49 |
+
help="Expressive streaming S2ST inference",
|
50 |
+
)
|
51 |
+
|
52 |
+
args, _ = parser.parse_known_args()
|
53 |
+
|
54 |
+
model_configs = dict(
|
55 |
+
source_segment_size=320,
|
56 |
+
device="cuda:0",
|
57 |
+
dtype="fp16",
|
58 |
+
min_starting_wait_w2vbert=192,
|
59 |
+
decision_threshold=0.5,
|
60 |
+
no_early_stop=True,
|
61 |
+
max_len_a=0,
|
62 |
+
max_len_b=100,
|
63 |
+
)
|
64 |
+
|
65 |
+
eval_configs = dict(quality_metrics="SEAMLESS_QUALITY_SCORER")
|
66 |
+
if args.task == "s2st":
|
67 |
+
model_configs["min_unit_chunk_size"] = 50
|
68 |
+
eval_configs["latency_metrics"] = "StartOffset EndOffset"
|
69 |
+
|
70 |
+
if args.expressive:
|
71 |
+
agent_class = SeamlessS2STAgent
|
72 |
+
else:
|
73 |
+
agent_class = SeamlessStreamingS2STAgent
|
74 |
+
elif args.task in ["s2tt", "asr"]:
|
75 |
+
assert args.expressive is False, "S2TT inference cannot be expressive."
|
76 |
+
agent_class = SeamlessStreamingS2TAgent
|
77 |
+
parser.add_argument(
|
78 |
+
"--unity-model-name",
|
79 |
+
type=str,
|
80 |
+
help="Unity model name.",
|
81 |
+
default="seamless_streaming_unity",
|
82 |
+
)
|
83 |
+
args, _ = parser.parse_known_args()
|
84 |
+
asset_card = asset_store.retrieve_card(name=args.unity_model_name)
|
85 |
+
tokenizer_uri = asset_card.field("tokenizer").as_uri()
|
86 |
+
tokenizer_path = download_manager.download_tokenizer(
|
87 |
+
tokenizer_uri, asset_card.name, force=False, progress=True
|
88 |
+
)
|
89 |
+
eval_configs["latency_metrics"] = "AL LAAL"
|
90 |
+
eval_configs["eval_latency_unit"] = "spm"
|
91 |
+
eval_configs["eval_latency_spm_model"] = tokenizer_path
|
92 |
+
|
93 |
+
base_config = dict(
|
94 |
+
dataloader="fairseq2_s2tt",
|
95 |
+
dataloader_class="seamless_communication.streaming.dataloaders.s2tt.SimulEvalSpeechToTextDataloader",
|
96 |
+
)
|
97 |
+
|
98 |
+
evaluate(agent_class, {**base_config, **model_configs, **eval_configs}, parser)
|
99 |
+
|
100 |
+
|
101 |
+
if __name__ == "__main__":
|
102 |
+
main()
|