|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
import logging |
|
|
|
import torch |
|
|
|
from seamless_communication.models.unit_extractor import UnitExtractor |
|
|
|
# Configure the root logger at import time so that INFO-level messages
# emitted by this script are actually shown.
logging.basicConfig(level=logging.INFO)

# Module-scoped logger; its name follows the module's ``__name__``.
logger = logging.getLogger(__name__)
|
|
|
|
|
def main() -> None:
    """Convert a WAV file to discrete units and log the result.

    Parses the command line (audio path, K-Means model URI, feature
    extraction model name and output layer number), picks ``cuda:0`` when
    CUDA is available and the CPU otherwise, builds a ``UnitExtractor`` and
    logs the units it predicts for the audio file.

    The process exits through ``parser.error`` (status 2) when
    ``--out_layer_number`` is smaller than 1.
    """
    parser = argparse.ArgumentParser(
        description="Convert raw audio to units (and optionally audio) using UnitExtractor."
    )
    parser.add_argument("audio", type=str, help="Audio WAV file path.")
    parser.add_argument(
        "--kmeans_uri",
        type=str,
        help="URL path to the K-Means model.",
        default="https://dl.fbaipublicfiles.com/seamlessM4T/models/unit_extraction/kmeans_10k.npy",
    )
    parser.add_argument(
        "--model_name",
        type=str,
        help="Feature extraction model name (`xlsr2_1b_v2`)",
        default="xlsr2_1b_v2",
    )
    parser.add_argument(
        "--out_layer_number",
        type=int,
        help="Layer number of the feature extraction model to pull out features from.",
        default=35,
    )

    args = parser.parse_args()

    # The layer number has 1 subtracted before it is handed to ``predict``
    # (presumably a 0-based layer index -- confirm against UnitExtractor), so
    # a value below 1 would become negative. Reject it here, before any model
    # is constructed, with a clear usage error.
    if args.out_layer_number < 1:
        parser.error("--out_layer_number must be an integer >= 1.")

    if torch.cuda.is_available():
        device = torch.device("cuda:0")
        logger.info("Running unit_extraction on the GPU.")
    else:
        device = torch.device("cpu")
        logger.info("Running unit_extraction on the CPU.")

    unit_extractor = UnitExtractor(args.model_name, args.kmeans_uri, device=device)
    units = unit_extractor.predict(args.audio, args.out_layer_number - 1)
    # Lazy %-style arguments: the message is only formatted if it is emitted.
    logger.info("Converted to units: %s", units)
|
|
|
|
|
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|