victan committed on
Commit 411bb63
Parent: 5688ece

Upload seamless_communication/models/aligner/alignment_extractor.py with huggingface_hub

seamless_communication/models/aligner/alignment_extractor.py ADDED
# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
# This source code is licensed under the license found in the
# MIT_LICENSE file in the root directory of this source tree.

import os
from typing import Any, List, Tuple, Union

import numpy
import torch
import torch.nn as nn
import torchaudio
from fairseq2.typing import DataType, Device
from fairseq2.data.typing import StringLike
from torch import Tensor

from seamless_communication.models.aligner.loader import load_unity2_alignment_model
from seamless_communication.models.unit_extractor import UnitExtractor

try:
    import matplotlib.pyplot as plt

    matplotlib_available = True
except ImportError:
    matplotlib_available = False

class AlignmentExtractor(nn.Module):
    def __init__(
        self,
        aligner_model_name_or_card: str,
        unit_extractor_model_name_or_card: Union[Any, str] = None,
        unit_extractor_output_layer: Union[Any, int] = None,
        unit_extractor_kmeans_model_uri: Union[Any, str] = None,
        device: Device = Device("cpu"),
        dtype: DataType = torch.float32,
    ):
        super().__init__()
        self.device = device
        self.dtype = dtype

        if self.dtype == torch.float16 and self.device == Device("cpu"):
            raise RuntimeError("FP16 only works on GPU, set args accordingly")

        self.alignment_model = load_unity2_alignment_model(
            aligner_model_name_or_card, device=self.device, dtype=self.dtype
        )
        self.alignment_model.eval()

        self.unit_extractor = None
        self.unit_extractor_output_layer = 0

        if unit_extractor_model_name_or_card is not None:
            self.unit_extractor = UnitExtractor(
                unit_extractor_model_name_or_card,
                unit_extractor_kmeans_model_uri,
                device=device,
                dtype=dtype,
            )
            self.unit_extractor_output_layer = unit_extractor_output_layer
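
    # Note: the unit extractor is optional. When it is not configured,
    # extract_alignment must be given precomputed unit IDs (an integer Tensor)
    # instead of raw audio; extract_units will refuse to run without it.
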
    def load_audio(
        self, audio_path: str, sampling_rate: int = 16_000
    ) -> Tuple[Tensor, int]:
        assert os.path.exists(audio_path)
        audio, rate = torchaudio.load(audio_path)
        if rate != sampling_rate:
            audio = torchaudio.functional.resample(audio, rate, sampling_rate)
            rate = sampling_rate
        return audio, rate

    def prepare_audio(self, audio: Union[str, Tensor]) -> Tensor:
        # TODO: switch to fairseq2 data pipeline once it supports resampling
        if isinstance(audio, str):
            audio, _ = self.load_audio(audio, sampling_rate=16_000)
        if audio.ndim > 1:
            # averaging over channels
            assert audio.size(0) < audio.size(
                1
            ), "Expected [Channel, Time] shape, but Channel > Time"
            audio = audio.mean(0)
        assert (
            audio.ndim == 1
        ), "After channel averaging, audio is expected to have shape [Time], i.e. mono"
        audio = audio.to(self.device, self.dtype)

        return audio
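
    # extract_units discretizes 16 kHz audio into unit IDs via the configured
    # UnitExtractor (features from unit_extractor_output_layer, quantized with
    # the kmeans model), producing one unit per 20 ms frame.
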
    def extract_units(self, audio: Tensor) -> Tensor:
        assert isinstance(
            self.unit_extractor, UnitExtractor
        ), "Unit extractor is required to get units from audio tensor"
        units = self.unit_extractor.predict(audio, self.unit_extractor_output_layer)
        return units
    @torch.inference_mode()
    def extract_alignment(
        self,
        audio: Union[str, Tensor],
        text: str,
        plot: bool = False,
        add_trailing_silence: bool = False,
    ) -> Tuple[Tensor, Tensor, List[StringLike]]:
        if isinstance(audio, Tensor) and not torch.is_floating_point(audio):
            # an integer tensor means we received precomputed units as the audio arg
            units = audio
            units = units.to(self.device)
            audio_tensor = None
        else:
            audio_tensor = self.prepare_audio(audio)
            units = self.extract_units(audio_tensor)

        tokenized_unit_ids = self.alignment_model.alignment_frontend.tokenize_unit(
            units
        ).unsqueeze(0)
        tokenized_text_ids = (
            self.alignment_model.alignment_frontend.tokenize_text(
                text, add_trailing_silence=add_trailing_silence
            )
            .to(self.device)
            .unsqueeze(0)
        )
        tokenized_text_tokens = (
            self.alignment_model.alignment_frontend.tokenize_text_to_tokens(
                text, add_trailing_silence=add_trailing_silence
            )
        )
        _, alignment_durations = self.alignment_model(
            tokenized_text_ids, tokenized_unit_ids
        )

        if plot and (audio_tensor is not None):
            self.plot_alignment(
                audio_tensor.cpu(), tokenized_text_tokens, alignment_durations.cpu()
            )

        return alignment_durations, tokenized_text_ids, tokenized_text_tokens
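
    # The returned alignment_durations hold one duration per text token,
    # measured in unit frames (20 ms each); their sum should match the
    # number of units consumed.
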
    def detokenize_text(self, tokenized_text_ids: Tensor) -> StringLike:
        return self.alignment_model.alignment_frontend.decode_text(tokenized_text_ids)
    def plot_alignment(
        self, audio: Tensor, text_tokens: List[StringLike], durations: Tensor
    ) -> None:
        if not matplotlib_available:
            raise RuntimeError(
                "Please `pip install matplotlib` in order to use plot_alignment."
            )
        _, ax = plt.subplots(figsize=(22, 3.5))
        ax.plot(audio, color="gray", linewidth=0.3)
        durations_cumul = numpy.concatenate([numpy.array([0]), numpy.cumsum(durations)])
        # 320 samples per unit frame: 16 kHz audio at the 20 ms unit rate (16_000 * 0.02)
        alignment_ticks = durations_cumul * 320

        ax.vlines(
            alignment_ticks,
            ymax=1,
            ymin=-1,
            color="indigo",
            linestyles="dashed",
            lw=0.5,
        )

        middle_tick_positions = (
            durations_cumul[:-1] + (durations_cumul[1:] - durations_cumul[:-1]) / 2
        )
        ax.set_xticks(middle_tick_positions * 320)
        ax.set_xticklabels(text_tokens, fontsize=13)
        ax.set_xlim(0, len(audio))

        ax.set_ylim(audio.min(), audio.max())
        ax.set_yticks([])
        plt.tight_layout()
        plt.show()
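
For reference, a minimal usage sketch of this class. The card names, output layer, and kmeans checkpoint URI below are assumptions taken from the seamless_communication unit-extraction examples; verify them against the model cards shipped with your installed version.

import torch
from fairseq2.typing import Device

from seamless_communication.models.aligner.alignment_extractor import AlignmentExtractor

# Assumed card names and kmeans URI; check your installed seamless_communication.
extractor = AlignmentExtractor(
    aligner_model_name_or_card="nar_t2u_aligner",
    unit_extractor_model_name_or_card="xlsr2_1b_v2",
    unit_extractor_output_layer=35,
    unit_extractor_kmeans_model_uri=(
        "https://dl.fbaipublicfiles.com/seamlessM4T/models/unit_extraction/kmeans_10k.npy"
    ),
    device=Device("cuda:0") if torch.cuda.is_available() else Device("cpu"),
)

# "example.wav" is a placeholder path; the text should be the audio's transcript.
durations, text_ids, text_tokens = extractor.extract_alignment(
    "example.wav",
    "example transcript",
)
print(text_tokens)  # tokenized text
print(durations)    # one duration per token, in 20 ms unit frames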