victan commited on
Commit
19d7906
1 Parent(s): d698c13

Upload seamless_communication/cli/streaming/evaluate.py with huggingface_hub

Browse files
seamless_communication/cli/streaming/evaluate.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # MIT_LICENSE file in the root directory of this source tree.
6
+
7
+ import argparse
8
+ import logging
9
+
10
+ from fairseq2.assets import asset_store, download_manager
11
+ from seamless_communication.cli.streaming.scorers.seamless_quality_scorer import (
12
+ SeamlessQualityScorer,
13
+ )
14
+ from seamless_communication.streaming.agents.seamless_s2st import SeamlessS2STAgent
15
+ from seamless_communication.streaming.agents.seamless_streaming_s2st import (
16
+ SeamlessStreamingS2STAgent,
17
+ )
18
+ from seamless_communication.streaming.agents.seamless_streaming_s2t import (
19
+ SeamlessStreamingS2TAgent,
20
+ )
21
+ from simuleval.cli import evaluate
22
+
23
+ logging.basicConfig(
24
+ level=logging.INFO,
25
+ format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
26
+ )
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
+ def main() -> None:
32
+ parser = argparse.ArgumentParser(
33
+ add_help=False,
34
+ description="Streaming evaluation of Seamless UnitY models",
35
+ conflict_handler="resolve",
36
+ )
37
+
38
+ parser.add_argument(
39
+ "--task",
40
+ choices=["s2st", "s2tt", "asr"],
41
+ required=True,
42
+ type=str,
43
+ help="Target language to translate/transcribe into.",
44
+ )
45
+ parser.add_argument(
46
+ "--expressive",
47
+ action="store_true",
48
+ default=False,
49
+ help="Expressive streaming S2ST inference",
50
+ )
51
+
52
+ args, _ = parser.parse_known_args()
53
+
54
+ model_configs = dict(
55
+ source_segment_size=320,
56
+ device="cuda:0",
57
+ dtype="fp16",
58
+ min_starting_wait_w2vbert=192,
59
+ decision_threshold=0.5,
60
+ no_early_stop=True,
61
+ max_len_a=0,
62
+ max_len_b=100,
63
+ )
64
+
65
+ eval_configs = dict(quality_metrics="SEAMLESS_QUALITY_SCORER")
66
+ if args.task == "s2st":
67
+ model_configs["min_unit_chunk_size"] = 50
68
+ eval_configs["latency_metrics"] = "StartOffset EndOffset"
69
+
70
+ if args.expressive:
71
+ agent_class = SeamlessS2STAgent
72
+ else:
73
+ agent_class = SeamlessStreamingS2STAgent
74
+ elif args.task in ["s2tt", "asr"]:
75
+ assert args.expressive is False, "S2TT inference cannot be expressive."
76
+ agent_class = SeamlessStreamingS2TAgent
77
+ parser.add_argument(
78
+ "--unity-model-name",
79
+ type=str,
80
+ help="Unity model name.",
81
+ default="seamless_streaming_unity",
82
+ )
83
+ args, _ = parser.parse_known_args()
84
+ asset_card = asset_store.retrieve_card(name=args.unity_model_name)
85
+ tokenizer_uri = asset_card.field("tokenizer").as_uri()
86
+ tokenizer_path = download_manager.download_tokenizer(
87
+ tokenizer_uri, asset_card.name, force=False, progress=True
88
+ )
89
+ eval_configs["latency_metrics"] = "AL LAAL"
90
+ eval_configs["eval_latency_unit"] = "spm"
91
+ eval_configs["eval_latency_spm_model"] = tokenizer_path
92
+
93
+ base_config = dict(
94
+ dataloader="fairseq2_s2tt",
95
+ dataloader_class="seamless_communication.streaming.dataloaders.s2tt.SimulEvalSpeechToTextDataloader",
96
+ )
97
+
98
+ evaluate(agent_class, {**base_config, **model_configs, **eval_configs}, parser)
99
+
100
+
101
+ if __name__ == "__main__":
102
+ main()