victan committed
Commit
90d2634
1 Parent(s): 7933050

Upload seamless_communication/cli/expressivity/evaluate/pretssel_inference_helper.py with huggingface_hub

seamless_communication/cli/expressivity/evaluate/pretssel_inference_helper.py ADDED
@@ -0,0 +1,100 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ # This source code is licensed under the license found in the
+ # MIT_LICENSE file in the root directory of this source tree.
+
+ from typing import List
+
+ import torch
+ from torch.nn import Module
+
+ from fairseq2.typing import DataType, Device
+
+ from fairseq2.assets import asset_store
+ from fairseq2.data import (
+     Collater,
+     SequenceData,
+     VocabularyInfo,
+ )
+ from fairseq2.nn.padding import get_seqs_and_padding_mask
+
+ from seamless_communication.inference import BatchedSpeechOutput
+ from seamless_communication.models.generator.loader import load_pretssel_vocoder_model
+
+
+ class PretsselGenerator(Module):
+     def __init__(
+         self,
+         pretssel_name_or_card: str,
+         vocab_info: VocabularyInfo,
+         device: Device,
+         dtype: DataType = torch.float16,
+     ):
+         super().__init__()
+         # Load the model.
+         if device == torch.device("cpu"):
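+             # fp16 kernels are generally unavailable on CPU, so fall back to fp32.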
+             dtype = torch.float32
+
+         self.device = device
+         self.dtype = dtype
+
+         self.pretssel_model = load_pretssel_vocoder_model(
+             pretssel_name_or_card,
+             device=device,
+             dtype=dtype,
+         )
+         self.pretssel_model.eval()
+
+         vocoder_model_card = asset_store.retrieve_card(pretssel_name_or_card)
+         self.output_sample_rate = vocoder_model_card.field("sample_rate").as_(int)
+
+         self.vocab_info = vocab_info
+         self.unit_collate = Collater(pad_value=vocab_info.pad_idx)
+         self.duration_collate = Collater(pad_value=0)
+         self.unit_eos_token = torch.tensor([vocab_info.eos_idx], device=device)
+
+     @torch.inference_mode()
+     def predict(
+         self,
+         units: List[List[int]],
+         tgt_lang: str,
+         prosody_encoder_input: SequenceData,
+     ) -> BatchedSpeechOutput:
+
+         units_batch, durations = [], []
+         for u in units:
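+             # Tensor.to(tensor) matches the device and dtype of unit_eos_token.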
+             unit = torch.tensor(u).to(self.unit_eos_token)
+
+             # adjust the control symbols for the embedding
+             unit += 4
+             unit = torch.cat([unit, self.unit_eos_token], dim=0)
+
+             unit, duration = torch.unique_consecutive(unit, return_counts=True)
+
+             # adjust for the last eos token
+             duration[-1] = 0
+
+             units_batch.append(unit)
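+             # the factor of 2 appears to upsample unit-rate durations to the
+             # vocoder's frame rate (an assumption; the ratio is not documented here)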
+             durations.append(duration * 2)
+
+         speech_units = self.unit_collate(units_batch)
+         durations = self.duration_collate(durations)["seqs"]
+
+         units_tensor, unit_padding_mask = get_seqs_and_padding_mask(speech_units)
+         prosody_input_seqs, prosody_padding_mask = get_seqs_and_padding_mask(
+             prosody_encoder_input
+         )
+
+         audio_wavs = self.pretssel_model(
+             units_tensor,
+             tgt_lang,
+             prosody_input_seqs,
+             padding_mask=unit_padding_mask,
+             prosody_padding_mask=prosody_padding_mask,
+             durations=durations,
+         )
+         return BatchedSpeechOutput(
+             units=units,
+             audio_wavs=audio_wavs,
+             sample_rate=self.output_sample_rate,
+         )
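
For context, a minimal usage sketch of the uploaded helper (not part of the commit). The asset card name "vocoder_pretssel", the vocabulary parameters, the 80-dim feature shape, and the unit values are illustrative assumptions; consult the seamless_communication asset cards for the real values.

import torch
from fairseq2.data import Collater, VocabularyInfo

from seamless_communication.cli.expressivity.evaluate.pretssel_inference_helper import (
    PretsselGenerator,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Unit vocabulary of the upstream unit generator (illustrative values).
vocab_info = VocabularyInfo(size=10004, unk_idx=3, bos_idx=0, eos_idx=2, pad_idx=1)

generator = PretsselGenerator(
    "vocoder_pretssel",  # assumed asset card name
    vocab_info=vocab_info,
    device=device,
)

# Dummy unit sequences; real units come from the translation model.
units = [[12, 12, 57, 57, 8], [3, 99, 99, 41]]

# predict() expects the prosody encoder input as SequenceData; collating a
# batch of per-frame source features (assumed 80-dim) produces exactly that.
feats = [
    torch.randn(120, 80, device=device, dtype=generator.dtype),
    torch.randn(95, 80, device=device, dtype=generator.dtype),
]
prosody_encoder_input = Collater(pad_value=0)(feats)

out = generator.predict(
    units,
    tgt_lang="spa",
    prosody_encoder_input=prosody_encoder_input,
)
print(out.sample_rate, len(out.audio_wavs))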