victan committed on
Commit
7ef749b
1 Parent(s): 1d81c94

Upload seamless_communication/inference/translator.py with huggingface_hub

seamless_communication/inference/translator.py ADDED
@@ -0,0 +1,428 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ # This source code is licensed under the license found in the
+ # MIT_LICENSE file in the root directory of this source tree.
+
+ import logging
+ from dataclasses import dataclass
+ from enum import Enum, auto
+ from pathlib import Path
+ from typing import List, Optional, Tuple, Union, cast
+
+ import torch
+ import torch.nn as nn
+ from fairseq2.assets import asset_store
+ from fairseq2.assets.card import AssetCard
+ from fairseq2.data import Collater, SequenceData, StringLike
+ from fairseq2.data.audio import AudioDecoder, WaveformToFbankConverter
+ from fairseq2.data.text import TextTokenizer
+ from fairseq2.memory import MemoryBlock
+ from fairseq2.nn.padding import PaddingMask, get_seqs_and_padding_mask
+ from fairseq2.typing import DataType, Device
+ from torch import Tensor
+
+ from seamless_communication.inference.generator import (
+     SequenceGeneratorOptions,
+     UnitYGenerator,
+ )
+ from seamless_communication.models.unity import (
+     UnitTokenizer,
+     UnitYModel,
+     UnitYNART2UModel,
+     UnitYT2UModel,
+     load_unity_model,
+     load_unity_text_tokenizer,
+     load_unity_unit_tokenizer,
+     unity_archs,
+ )
+ from seamless_communication.models.vocoder import load_vocoder_model
+ from seamless_communication.toxicity import (
+     ETOXBadWordChecker,
+     load_etox_bad_word_checker,
+ )
+ from seamless_communication.toxicity.mintox import mintox_pipeline
+
+ logging.basicConfig(
+     level=logging.INFO,
+     format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
+ )
+
+ logger = logging.getLogger(__name__)
+
+
+ class Task(Enum):
+     S2ST = auto()
+     S2TT = auto()
+     T2ST = auto()
+     T2TT = auto()
+     ASR = auto()
+
+
+ class Modality(Enum):
+     SPEECH = "speech"
+     TEXT = "text"
+
+
+ @dataclass
+ class BatchedSpeechOutput:
+     units: List[List[int]]
+     """The batched list of generated units."""
+
+     audio_wavs: List[Tensor]
+     """The batched list of audio waveforms."""
+
+     sample_rate: int = 16000
+     """Sample rate of the audio waveforms."""
+
+
+ class Translator(nn.Module):
+     def __init__(
+         self,
+         model_name_or_card: Union[str, AssetCard],
+         vocoder_name_or_card: Union[str, AssetCard, None],
+         device: Device,
+         text_tokenizer: Optional[TextTokenizer] = None,
+         apply_mintox: bool = False,
+         dtype: DataType = torch.float16,
+         input_modality: Optional[Modality] = None,
+         output_modality: Optional[Modality] = None,
+     ):
+         super().__init__()
+
+         if isinstance(model_name_or_card, str):
+             model_name_or_card = asset_store.retrieve_card(model_name_or_card)
+
+         assert isinstance(model_name_or_card, AssetCard)
+
+         if input_modality or output_modality:
+             unity_config = unity_archs.get_config(
+                 model_name_or_card.field("model_arch").as_(str)
+             )
+             # Skip loading the text encoder.
+             if input_modality == Modality.SPEECH:
+                 unity_config.use_text_encoder = False
+             # Skip loading the T2U model.
+             if output_modality == Modality.TEXT:
+                 unity_config.t2u_config = None
+             model_name_or_card.field("model_config").set(unity_config)
+
+         # Load the model.
+         if device == torch.device("cpu"):
+             dtype = torch.float32
+
+         self.model = load_unity_model(model_name_or_card, device=device, dtype=dtype)
+         self.model.eval()
+         assert isinstance(self.model, UnitYModel)
+
+         if text_tokenizer is None:
+             self.text_tokenizer: TextTokenizer = load_unity_text_tokenizer(
+                 model_name_or_card
+             )
+         else:
+             self.text_tokenizer = text_tokenizer
+
+         self.unit_tokenizer: Optional[UnitTokenizer] = None
+         if self.model.t2u_model is not None:
+             self.unit_tokenizer = load_unity_unit_tokenizer(model_name_or_card)
+
+         self.bad_word_checker: Optional[ETOXBadWordChecker] = None
+         if apply_mintox:
+             self.bad_word_checker = load_etox_bad_word_checker("mintox")
+
+         self.apply_mintox = apply_mintox
+
+         self.device = device
+         self.decode_audio = AudioDecoder(dtype=torch.float32, device=device)
+         self.convert_to_fbank = WaveformToFbankConverter(
+             num_mel_bins=80,
+             waveform_scale=2**15,
+             channel_last=True,
+             standardize=True,
+             device=device,
+             dtype=dtype,
+         )
+         self.collate = Collater(
+             pad_value=self.text_tokenizer.vocab_info.pad_idx or 0, pad_to_multiple=2
+         )
+         self.vocoder = None
+         if vocoder_name_or_card is not None and (
+             output_modality is None or output_modality == Modality.SPEECH
+         ):
+             self.vocoder = load_vocoder_model(
+                 vocoder_name_or_card, device=device, dtype=dtype
+             )
+             self.vocoder.eval()
+
+     @classmethod
+     def get_prediction(
+         cls,
+         model: UnitYModel,
+         text_tokenizer: TextTokenizer,
+         unit_tokenizer: Optional[UnitTokenizer],
+         seqs: Tensor,
+         padding_mask: Optional[PaddingMask],
+         input_modality: Modality,
+         output_modality: Modality,
+         tgt_lang: str,
+         text_generation_opts: SequenceGeneratorOptions,
+         unit_generation_opts: Optional[SequenceGeneratorOptions],
+         unit_generation_ngram_filtering: bool = False,
+         duration_factor: float = 1.0,
+         prosody_encoder_input: Optional[SequenceData] = None,
+     ) -> Tuple[List[StringLike], Optional[Tensor]]:
+         # We disregard unit generation opts for the NAR T2U decoder.
+         if output_modality != Modality.SPEECH or isinstance(
+             model.t2u_model, UnitYNART2UModel
+         ):
+             unit_generation_opts = None
+
+         generator = UnitYGenerator(
+             model,
+             text_tokenizer,
+             tgt_lang,
+             unit_tokenizer if output_modality == Modality.SPEECH else None,
+             text_opts=text_generation_opts,
+             unit_opts=unit_generation_opts,
+         )
+
+         return generator(
+             seqs,
+             padding_mask,
+             input_modality.value,
+             output_modality.value,
+             ngram_filtering=unit_generation_ngram_filtering,
+             duration_factor=duration_factor,
+             prosody_encoder_input=prosody_encoder_input,
+         )
+
+     @staticmethod
+     def get_modalities_from_task_str(task_str: str) -> Tuple[Modality, Modality]:
+         try:
+             task = Task[task_str.upper()]
+         except KeyError:
+             raise ValueError(f"Unsupported task: {task_str}")
+
+         if task == Task.S2ST:
+             return Modality.SPEECH, Modality.SPEECH
+         # ASR is treated as S2TT with src_lang == tgt_lang
+         elif task == Task.S2TT or task == Task.ASR:
+             return Modality.SPEECH, Modality.TEXT
+         elif task == Task.T2TT:
+             return Modality.TEXT, Modality.TEXT
+         else:
+             return Modality.TEXT, Modality.SPEECH
+
+     @torch.inference_mode()
+     def predict(
+         self,
+         input: Union[str, Tensor, SequenceData],
+         task_str: str,
+         tgt_lang: str,
+         src_lang: Optional[str] = None,
+         text_generation_opts: Optional[SequenceGeneratorOptions] = None,
+         unit_generation_opts: Optional[SequenceGeneratorOptions] = None,
+         spkr: Optional[int] = -1,
+         sample_rate: int = 16000,
+         unit_generation_ngram_filtering: bool = False,
+         duration_factor: float = 1.0,
+         prosody_encoder_input: Optional[SequenceData] = None,
+         src_text: Optional[StringLike] = None,
+     ) -> Tuple[List[StringLike], Optional[BatchedSpeechOutput]]:
+         """
+         The main method used to perform inference on all tasks.
+
+         :param input:
+             Either text, a path to audio, or an audio Tensor.
+         :param task_str:
+             String representing the task.
+             Valid choices are "S2ST", "S2TT", "T2ST", "T2TT", "ASR".
+         :param tgt_lang:
+             Target language to decode into.
+         :param src_lang:
+             Source language of input, only required for T2ST, T2TT tasks.
+         :param text_generation_opts:
+             Text generation hyperparameters for incremental decoding.
+         :param unit_generation_opts:
+             Unit generation hyperparameters for incremental decoding.
+         :param spkr:
+             Speaker id for the vocoder.
+         :param unit_generation_ngram_filtering:
+             If True, removes consecutive repeated ngrams
+             from the decoded unit output.
+         :param src_text:
+             Optional source transcript (obtained by ASR, for instance). This is
+             used for applying the mintox toxicity mitigation. If it is not
+             specified and apply_mintox=True, then src_lang must be specified
+             and ASR will be run on the audio source.
+
+         :returns:
+             - Batched list of translated text.
+             - Translated BatchedSpeechOutput.
+         """
+         input_modality, output_modality = self.get_modalities_from_task_str(task_str)
+
+         if self.apply_mintox and not (src_lang is not None or src_text is not None):
+             raise ValueError(
+                 "`src_lang` must be specified when `apply_mintox` is `True` or you need to specify src_text."
+             )
+
+         if isinstance(input, dict):
+             src = cast(SequenceData, input)
+         elif input_modality == Modality.SPEECH:
+             audio = input
+             if isinstance(audio, str):
+                 with Path(audio).open("rb") as fb:
+                     block = MemoryBlock(fb.read())
+                 decoded_audio = self.decode_audio(block)
+             else:
+                 assert (
+                     audio.dim() <= 2
+                 ), "The audio tensor can't be more than 2 dimensions."
+                 if audio.dim() == 1:
+                     audio = audio.unsqueeze(1)
+                 elif audio.dim() == 2 and audio.size(0) < audio.size(1):
+                     logger.warning(
+                         "Transposing audio tensor from (bsz, seq_len) -> (seq_len, bsz)."
+                     )
+                     audio = audio.transpose(0, 1)
+
+                 decoded_audio = {
+                     "waveform": audio,
+                     "sample_rate": sample_rate,
+                     "format": -1,
+                 }
+             src = self.collate(self.convert_to_fbank(decoded_audio))["fbank"]
+         else:
+             if src_lang is None:
+                 raise ValueError("src_lang must be specified for T2ST, T2TT tasks.")
+
+             text = input
+             assert isinstance(text, str)
+
+             self.token_encoder = self.text_tokenizer.create_encoder(
+                 task="translation", lang=src_lang, mode="source", device=self.device
+             )
+             src = self.collate(self.token_encoder(text))
+
+         assert isinstance(self.model, UnitYModel)
+
+         seqs, padding_mask = get_seqs_and_padding_mask(src)
+
+         if text_generation_opts is None:
+             text_generation_opts = SequenceGeneratorOptions(
+                 beam_size=5, soft_max_seq_len=(1, 200)
+             )
+         if unit_generation_opts is None:
+             unit_generation_opts = SequenceGeneratorOptions(
+                 beam_size=5, soft_max_seq_len=(25, 50)
+             )
+
+         texts, units = self.get_prediction(
+             self.model,
+             self.text_tokenizer,
+             self.unit_tokenizer,
+             seqs,
+             padding_mask,
+             input_modality,
+             output_modality,
+             tgt_lang,
+             text_generation_opts,
+             unit_generation_opts,
+             unit_generation_ngram_filtering=unit_generation_ngram_filtering,
+             duration_factor=duration_factor,
+             prosody_encoder_input=prosody_encoder_input,
+         )
+
+         if self.apply_mintox and task_str != Task.ASR.name:
+             if input_modality == Modality.SPEECH:
+                 if src_text is not None:
+                     src_texts = [src_text]
+                 else:
+                     src_texts, _ = self.predict(
+                         input=input,
+                         task_str=Task.ASR.name,
+                         tgt_lang=tgt_lang,
+                         src_lang=src_lang,
+                         text_generation_opts=text_generation_opts,
+                         unit_generation_opts=unit_generation_opts,
+                         spkr=spkr,
+                         sample_rate=sample_rate,
+                         unit_generation_ngram_filtering=unit_generation_ngram_filtering,
+                     )
+             else:
+                 assert isinstance(input, str)
+
+                 src_texts = [input]
+
+             assert src_lang is not None
+             assert self.unit_tokenizer is not None
+             assert self.bad_word_checker is not None
+
+             texts, units = mintox_pipeline(
+                 model=self.model,
+                 text_tokenizer=self.text_tokenizer,
+                 unit_tokenizer=self.unit_tokenizer,
+                 device=self.device,
+                 src_lang=src_lang,
+                 tgt_lang=tgt_lang,
+                 model_input=src,
+                 input_modality=input_modality,
+                 output_modality=output_modality,
+                 src_texts=src_texts,
+                 original_texts=texts,
+                 original_units=units,
+                 unit_generation_ngram_filtering=unit_generation_ngram_filtering,
+                 text_generation_opts=text_generation_opts,
+                 unit_generation_opts=unit_generation_opts,
+                 bad_word_checker=self.bad_word_checker,
+                 duration_factor=duration_factor,
+                 prosody_encoder_input=prosody_encoder_input,
+             )
+
+         if output_modality == Modality.TEXT:
+             return texts, None
+         else:
+             assert units is not None
+
+             if isinstance(self.model.t2u_model, UnitYT2UModel):
+                 # Remove the lang token for AR UnitY since the vocoder doesn't need it
+                 # in the unit sequence. tgt_lang is fed as an argument to the vocoder.
+                 units = units[:, 1:]
+                 duration_prediction = True
+             else:
+                 # Vocoder duration predictions not required since the NAR
+                 # T2U model already predicts duration in the units.
+                 duration_prediction = False
+
+             audio_wavs = []
+             speech_units = []
+             for i in range(len(units)):
+                 assert self.model.t2u_model is not None
+                 unit_padding_mask = (
+                     units[i] != self.model.t2u_model.target_vocab_info.pad_idx
+                 )
+                 u = units[i][unit_padding_mask]
+                 speech_units.append(u.tolist())
+
+             if self.vocoder is not None:
+                 translated_audio_wav = self.vocoder(
+                     units, tgt_lang, spkr, dur_prediction=duration_prediction
+                 )
+                 for i in range(len(units)):
+                     padding_removed_audio_wav = translated_audio_wav[
+                         i,
+                         :,
+                         : int(
+                             translated_audio_wav.size(-1)
+                             * len(speech_units[i])
+                             / len(units[i])
+                         ),
+                     ].unsqueeze(0)
+                     audio_wavs.append(padding_removed_audio_wav)
+             return (
+                 texts,
+                 BatchedSpeechOutput(
+                     units=speech_units,
+                     audio_wavs=audio_wavs,
+                     sample_rate=sample_rate,
+                 ),
+             )
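
For context, a minimal usage sketch of the uploaded Translator class (not part of this commit). The checkpoint names "seamlessM4T_v2_large" and "vocoder_v2", the language codes, the file paths, and CUDA availability are assumptions here; any compatible asset cards and model-supported language codes would work the same way.

    import torch
    import torchaudio
    from seamless_communication.inference.translator import Translator

    # Load the translator in fp16 on GPU; per __init__ above, dtype falls
    # back to fp32 automatically when the device is CPU.
    translator = Translator(
        "seamlessM4T_v2_large",  # assumed model asset card name
        "vocoder_v2",            # assumed vocoder asset card name
        device=torch.device("cuda:0"),
        dtype=torch.float16,
    )

    # S2TT: translate speech from an audio file into Spanish text.
    texts, _ = translator.predict("input.wav", "S2TT", tgt_lang="spa")
    print(texts[0])

    # T2ST: translate English text into Spanish speech. src_lang is required
    # for text input; speech_output.audio_wavs holds one waveform per item.
    texts, speech_output = translator.predict(
        "Hello, world!", "T2ST", tgt_lang="spa", src_lang="eng"
    )
    torchaudio.save(
        "out.wav",
        speech_output.audio_wavs[0][0].to(torch.float32).cpu(),
        sample_rate=speech_output.sample_rate,
    )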