victan commited on
Commit
3cbc121
1 Parent(s): 34a59b2

Upload seamless_communication/cli/streaming/scorers/seamless_quality_scorer.py with huggingface_hub

Browse files
seamless_communication/cli/streaming/scorers/seamless_quality_scorer.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # MIT_LICENSE file in the root directory of this source tree.
6
+
7
+ from __future__ import annotations
8
+
9
+ import json
10
+ from argparse import ArgumentParser, Namespace
11
+ from pathlib import Path
12
+ from typing import Dict, Optional
13
+
14
+ import pandas
15
+ from fairseq2.typing import Device
16
+ from seamless_communication.cli.eval_utils import compute_quality_metrics
17
+ from simuleval.evaluator.instance import LogInstance
18
+ from simuleval.evaluator.scorers.quality_scorer import (
19
+ QualityScorer,
20
+ register_quality_scorer,
21
+ )
22
+
23
+
24
+ @register_quality_scorer("SEAMLESS_QUALITY_SCORER")
25
+ class SeamlessQualityScorer(QualityScorer): # type: ignore
26
+ def __init__(
27
+ self,
28
+ tgt_lang: str,
29
+ task: str,
30
+ output_dir: str,
31
+ device: Device = "cuda:0",
32
+ whisper_model_name: str = "large",
33
+ whisper_normalize_text_output: Optional[bool] = None,
34
+ ref_text_col_name: str = "ref_tgt_text",
35
+ pred_text_col_name: str = "pred_tgt_text",
36
+ pred_audio_col_name: str = "pred_tgt_audio",
37
+ ) -> None:
38
+ super().__init__()
39
+ self.tgt_lang = tgt_lang
40
+ self.task = task.upper()
41
+ self.device = device
42
+ self.output_dir = Path(output_dir)
43
+ self.whisper_model_name = whisper_model_name
44
+ self.whisper_normalize_text_output = whisper_normalize_text_output
45
+ if self.whisper_normalize_text_output is None:
46
+ self.whisper_normalize_text_output = (
47
+ False if self.task in ["S2TT", "S2ST", "T2TT"] else True
48
+ )
49
+ self.ref_text_col_name = ref_text_col_name
50
+ self.pred_text_col_name = pred_text_col_name
51
+ self.pred_audio_col_name = pred_audio_col_name
52
+
53
+ def __call__(self, instances: Dict[int, LogInstance]) -> float:
54
+ references = [ins.reference for ins in instances.values()]
55
+ df = pandas.DataFrame({self.ref_text_col_name: references})
56
+ if self.task in ["ASR", "S2TT", "T2TT"]:
57
+ predictions = [ins.prediction for ins in instances.values()]
58
+ df[self.pred_text_col_name] = predictions
59
+ else:
60
+ predictions = [ins.prediction for ins in instances.values()]
61
+ df[self.pred_audio_col_name] = predictions
62
+
63
+ df.to_csv(
64
+ self.output_dir / "results.tsv",
65
+ sep="\t",
66
+ quoting=3,
67
+ encoding="utf-8",
68
+ )
69
+ filename = compute_quality_metrics(
70
+ self.output_dir / "results.tsv",
71
+ self.output_dir,
72
+ self.tgt_lang,
73
+ self.task,
74
+ self.device,
75
+ self.whisper_model_name,
76
+ self.whisper_normalize_text_output,
77
+ self.ref_text_col_name,
78
+ self.pred_text_col_name if self.task in ["ASR", "S2TT", "T2TT"] else None,
79
+ self.pred_audio_col_name,
80
+ )
81
+
82
+ with open(self.output_dir / filename, "r") as f:
83
+ corpus_metric_score = json.load(f)["score"]
84
+
85
+ return corpus_metric_score # type: ignore[no-any-return]
86
+
87
+ @staticmethod
88
+ def add_args(parser: ArgumentParser) -> None:
89
+ parser.add_argument("--task", type=str, help="Task to evaluate", required=True)
90
+ parser.add_argument(
91
+ "--tgt-lang",
92
+ type=str,
93
+ help="Target language to translate/transcribe into.",
94
+ required=True,
95
+ )
96
+ parser.add_argument(
97
+ "--whisper-model-name", type=str, help="Whisper model name", default="large"
98
+ )
99
+ parser.add_argument(
100
+ "--whisper-normalize-text-output",
101
+ action="store_true",
102
+ help="Normalize text output",
103
+ default=None,
104
+ )
105
+ parser.add_argument(
106
+ "--ref-text-col-name",
107
+ type=str,
108
+ help="Reference text column name",
109
+ default="ref_tgt_text",
110
+ )
111
+ parser.add_argument(
112
+ "--pred-text-col-name",
113
+ type=str,
114
+ help="Prediction text column name",
115
+ default="pred_tgt_text",
116
+ )
117
+ parser.add_argument(
118
+ "--pred-audio-col-name",
119
+ type=str,
120
+ help="Prediction audio column name",
121
+ default="pred_tgt_audio",
122
+ )
123
+
124
+ @classmethod
125
+ def from_args(cls, args: Namespace) -> SeamlessQualityScorer:
126
+ return cls(
127
+ tgt_lang=args.tgt_lang,
128
+ task=args.task,
129
+ output_dir=args.output,
130
+ device=getattr(args, "device", "cpu"),
131
+ whisper_model_name=args.whisper_model_name,
132
+ whisper_normalize_text_output=args.whisper_normalize_text_output,
133
+ ref_text_col_name=args.ref_text_col_name,
134
+ pred_text_col_name=args.pred_text_col_name,
135
+ pred_audio_col_name=args.pred_audio_col_name,
136
+ )