victan committed on
Commit 20b6753
1 Parent(s): a0b6a86

Upload seamless_communication/cli/m4t/predict/predict.py with huggingface_hub

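For context, an upload like this one is typically driven by the huggingface_hub client API. The following is a minimal, hypothetical sketch; the repo id and local file name are assumptions, not taken from this commit.

from huggingface_hub import HfApi

# Upload a single file into a Hub repo (repo_id is hypothetical).
api = HfApi()
api.upload_file(
    path_or_fileobj="predict.py",  # local file to upload (assumed name)
    path_in_repo="seamless_communication/cli/m4t/predict/predict.py",
    repo_id="victan/example-repo",  # hypothetical repo id
    commit_message="Upload seamless_communication/cli/m4t/predict/predict.py with huggingface_hub",
)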
seamless_communication/cli/m4t/predict/predict.py ADDED
@@ -0,0 +1,238 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the license found in the
+ # MIT_LICENSE file in the root directory of this source tree.
+
+ import argparse
+ import logging
+ from argparse import Namespace
+ from pathlib import Path
+ from typing import Tuple
+
+ import torch
+ import torchaudio
+ from fairseq2.generation import NGramRepeatBlockProcessor
+
+ from seamless_communication.inference import SequenceGeneratorOptions, Translator
+
+ logging.basicConfig(
+     level=logging.INFO,
+     format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
+ )
+
+ logger = logging.getLogger(__name__)
+
+
+ def add_inference_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
+     parser.add_argument("--task", type=str, help="Task type")
+     parser.add_argument(
+         "--tgt_lang", type=str, help="Target language to translate/transcribe into."
+     )
+     parser.add_argument(
+         "--src_lang",
+         type=str,
+         help="Source language, only required if input is text.",
+         default=None,
+     )
+     parser.add_argument(
+         "--output_path",
+         type=Path,
+         help="Path to save the generated audio.",
+         default=None,
+     )
+     parser.add_argument(
+         "--model_name",
+         type=str,
+         help=(
+             "Base model name (`seamlessM4T_medium`, "
+             "`seamlessM4T_large`, `seamlessM4T_v2_large`)"
+         ),
+         default="seamlessM4T_v2_large",
+     )
+     parser.add_argument(
+         "--vocoder_name",
+         type=str,
+         help="Vocoder model name",
+         default="vocoder_v2",
+     )
+     # Text generation args.
+     parser.add_argument(
+         "--text_generation_beam_size",
+         type=int,
+         help="Beam size for incremental text decoding.",
+         default=5,
+     )
+     parser.add_argument(
+         "--text_generation_max_len_a",
+         type=int,
+         help="`a` in `ax + b` for incremental text decoding.",
+         default=1,
+     )
+     parser.add_argument(
+         "--text_generation_max_len_b",
+         type=int,
+         help="`b` in `ax + b` for incremental text decoding.",
+         default=200,
+     )
+     parser.add_argument(
+         "--text_generation_ngram_blocking",
+         action="store_true",  # `type=bool` would treat any non-empty string as True
+         help=(
+             "Enable ngram_repeat_block for incremental text decoding. "
+             "This blocks hypotheses with repeating ngram tokens."
+         ),
+         default=False,
+     )
+     parser.add_argument(
+         "--no_repeat_ngram_size",
+         type=int,
+         help="Size of ngram repeat block for both text & unit decoding.",
+         default=4,
+     )
+     # Unit generation args.
+     parser.add_argument(
+         "--unit_generation_beam_size",
+         type=int,
+         help=(
+             "Beam size for incremental unit decoding, "
+             "not applicable for the NAR T2U decoder."
+         ),
+         default=5,
+     )
+     parser.add_argument(
+         "--unit_generation_max_len_a",
+         type=int,
+         help=(
+             "`a` in `ax + b` for incremental unit decoding, "
+             "not applicable for the NAR T2U decoder."
+         ),
+         default=25,
+     )
+     parser.add_argument(
+         "--unit_generation_max_len_b",
+         type=int,
+         help=(
+             "`b` in `ax + b` for incremental unit decoding, "
+             "not applicable for the NAR T2U decoder."
+         ),
+         default=50,
+     )
+     parser.add_argument(
+         "--unit_generation_ngram_blocking",
+         action="store_true",  # `type=bool` would treat any non-empty string as True
+         help=(
+             "Enable ngram_repeat_block for incremental unit decoding. "
+             "This blocks hypotheses with repeating ngram tokens."
+         ),
+         default=False,
+     )
+     parser.add_argument(
+         "--unit_generation_ngram_filtering",
+         action="store_true",  # `type=bool` would treat any non-empty string as True
+         help=(
+             "If True, removes consecutive repeated ngrams "
+             "from the decoded unit output."
+         ),
+         default=False,
+     )
+     parser.add_argument(
+         "--text_unk_blocking",
+         action="store_true",  # `type=bool` would treat any non-empty string as True
+         help=(
+             "If True, set penalty of UNK to inf in text generator "
+             "to block unk output."
+         ),
+         default=False,
+     )
+     return parser
+
+
+ def set_generation_opts(
+     args: Namespace,
+ ) -> Tuple[SequenceGeneratorOptions, SequenceGeneratorOptions]:
+     # Set text, unit generation opts.
+     text_generation_opts = SequenceGeneratorOptions(
+         beam_size=args.text_generation_beam_size,
+         soft_max_seq_len=(
+             args.text_generation_max_len_a,
+             args.text_generation_max_len_b,
+         ),
+     )
+     if args.text_unk_blocking:
+         text_generation_opts.unk_penalty = torch.inf
+     if args.text_generation_ngram_blocking:
+         text_generation_opts.step_processor = NGramRepeatBlockProcessor(
+             ngram_size=args.no_repeat_ngram_size
+         )
+
+     unit_generation_opts = SequenceGeneratorOptions(
+         beam_size=args.unit_generation_beam_size,
+         soft_max_seq_len=(
+             args.unit_generation_max_len_a,
+             args.unit_generation_max_len_b,
+         ),
+     )
+     if args.unit_generation_ngram_blocking:
+         unit_generation_opts.step_processor = NGramRepeatBlockProcessor(
+             ngram_size=args.no_repeat_ngram_size
+         )
+     return text_generation_opts, unit_generation_opts
+
+
+ def main() -> None:
+     parser = argparse.ArgumentParser(
+         description="M4T inference on supported tasks using Translator."
+     )
+     parser.add_argument("input", type=str, help="Audio WAV file path or text input.")
+
+     parser = add_inference_arguments(parser)
+     args = parser.parse_args()
+     if not args.task or not args.tgt_lang:
+         raise ValueError(
+             "Please provide the required arguments: task, tgt_lang"
+         )
+
+     if args.task.upper() in {"S2ST", "T2ST"} and args.output_path is None:
+         raise ValueError("output_path must be provided to save the generated audio")
+
+     if torch.cuda.is_available():
+         device = torch.device("cuda:0")
+         dtype = torch.float16
+     else:
+         device = torch.device("cpu")
+         dtype = torch.float32
+
+     logger.info(f"Running inference on {device=} with {dtype=}.")
+
+     translator = Translator(args.model_name, args.vocoder_name, device, dtype=dtype)
+
+     text_generation_opts, unit_generation_opts = set_generation_opts(args)
+
+     logger.info(f"{text_generation_opts=}")
+     logger.info(f"{unit_generation_opts=}")
+     logger.info(
+         f"unit_generation_ngram_filtering={args.unit_generation_ngram_filtering}"
+     )
+
+     text_output, speech_output = translator.predict(
+         args.input,
+         args.task,
+         args.tgt_lang,
+         src_lang=args.src_lang,
+         text_generation_opts=text_generation_opts,
+         unit_generation_opts=unit_generation_opts,
+         unit_generation_ngram_filtering=args.unit_generation_ngram_filtering,
+     )
+
+     if speech_output is not None:
+         logger.info(f"Saving translated audio in {args.tgt_lang}")
+         torchaudio.save(
+             args.output_path,
+             speech_output.audio_wavs[0][0].to(torch.float32).cpu(),
+             sample_rate=speech_output.sample_rate,
+         )
+     logger.info(f"Translated text in {args.tgt_lang}: {text_output[0]}")
+
+
+ if __name__ == "__main__":
+     main()
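
As a usage note, the same inference path can be exercised programmatically through the Translator API this script wraps. A minimal sketch, using the checkpoint names from the script's defaults; the input text, task, and language pair are illustrative, not mandated by this commit.

import torch

from seamless_communication.inference import Translator

# Mirror the device/dtype selection in main().
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
dtype = torch.float16 if device.type == "cuda" else torch.float32

# Build the translator, then run a text-to-text (T2TT) prediction.
translator = Translator("seamlessM4T_v2_large", "vocoder_v2", device, dtype=dtype)
text_output, speech_output = translator.predict(
    "Hello, world!",  # text input (illustrative)
    "t2tt",           # task (illustrative)
    "fra",            # target language (illustrative)
    src_lang="eng",   # required for text input, per the --src_lang help text
)
print(text_output[0])  # speech_output is None for text-only tasks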