victan commited on
Commit
3e33b3e
1 Parent(s): f183300

Upload seamless_communication/cli/m4t/audio_to_units/audio_to_units.py with huggingface_hub

Browse files
seamless_communication/cli/m4t/audio_to_units/audio_to_units.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the license found in the
4
+ # MIT_LICENSE file in the root directory of this source tree.
5
+
6
+ import argparse
7
+ import logging
8
+
9
+ import torch
10
+
11
+ from seamless_communication.models.unit_extractor import UnitExtractor
12
+
13
+ logging.basicConfig(level=logging.INFO)
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
+ def main():
18
+ parser = argparse.ArgumentParser(
19
+ description="Convert raw audio to units (and optionally audio) using UnitExtractor."
20
+ )
21
+ parser.add_argument("audio", type=str, help="Audio WAV file path.")
22
+ parser.add_argument(
23
+ "--kmeans_uri",
24
+ type=str,
25
+ help="URL path to the K-Means model.",
26
+ default="https://dl.fbaipublicfiles.com/seamlessM4T/models/unit_extraction/kmeans_10k.npy",
27
+ )
28
+ parser.add_argument(
29
+ "--model_name",
30
+ type=str,
31
+ help="Feature extraction model name (`xlsr2_1b_v2`)",
32
+ default="xlsr2_1b_v2",
33
+ )
34
+ parser.add_argument(
35
+ "--out_layer_number",
36
+ type=int,
37
+ help="Layer number of the feature extraction model to pull out features from.",
38
+ default=35,
39
+ )
40
+
41
+ args = parser.parse_args()
42
+
43
+ if torch.cuda.is_available():
44
+ device = torch.device("cuda:0")
45
+ logger.info("Running unit_extraction on the GPU.")
46
+ else:
47
+ device = torch.device("cpu")
48
+ logger.info("Running unit_extraction on the CPU.")
49
+
50
+ unit_extractor = UnitExtractor(args.model_name, args.kmeans_uri, device=device)
51
+ units = unit_extractor.predict(args.audio, args.out_layer_number - 1)
52
+ logger.info(f"Converted to units: {units}")
53
+
54
+
55
+ if __name__ == "__main__":
56
+ main()