| import warnings
|
| from pathlib import Path
|
|
|
| import argbind
|
| import numpy as np
|
| import torch
|
| from audiotools import AudioSignal
|
| from tqdm import tqdm
|
|
|
| from dac import DACFile
|
| from dac.utils import load_model
|
|
|
# Silence every UserWarning for the whole process (this runs at import time,
# so it also affects anything else imported alongside this module).
warnings.simplefilter("ignore", UserWarning)
|
|
|
|
|
@argbind.bind(group="decode", positional=True, without_prefix=True)
# inference_mode already disables autograd, so a stacked no_grad is redundant.
@torch.inference_mode()
def decode(
    input: str,
    output: str = "",
    weights_path: str = "",
    model_tag: str = "latest",
    model_bitrate: str = "8kbps",
    device: str = "cuda",
    model_type: str = "44khz",
    verbose: bool = False,
):
    """Decode audio from ``.dac`` codes and write one ``.wav`` file per input.

    Parameters
    ----------
    input : str
        Path to a single ``.dac`` file, or to a directory that is searched
        recursively for ``.dac`` files.
    output : str, optional
        Path to the output directory, by default "" (the current directory).
        If `input` is a directory, the directory sub-tree relative to `input`
        is re-created in `output`; a single-file input is written directly
        into `output`.
    weights_path : str, optional
        Path to weights file, by default "". If not specified, the weights
        file will be downloaded from the internet using the model_tag and
        model_type.
    model_tag : str, optional
        Tag of the model to use, by default "latest". Ignored if
        `weights_path` is specified.
    model_bitrate : str, optional
        Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults
        to "8kbps".
    device : str, optional
        Device to use, by default "cuda". If "cpu", the model will be loaded
        on the CPU.
    model_type : str, optional
        The type of model to use. Must be one of "44khz", "24khz", or
        "16khz". Defaults to "44khz". Ignored if `weights_path` is specified.
    verbose : bool, optional
        Forwarded as ``verbose`` to ``generator.decompress``; by default False.

    Raises
    ------
    FileNotFoundError
        If `input` does not exist.
    """
    input_path = Path(input)
    # Check the input before loading (possibly downloading) model weights;
    # a mistyped directory would otherwise silently decode nothing.
    if not input_path.exists():
        raise FileNotFoundError(f"Input path does not exist: {input_path}")

    input_is_dir = input_path.is_dir()
    if input_is_dir:
        # Keep only regular files: a sub-directory whose name ends in ".dac"
        # also matches the pattern but cannot be loaded. Sorted so files are
        # processed in a deterministic order.
        input_files = sorted(
            p for p in input_path.glob("**/*.dac") if p.is_file()
        )
    elif input_path.suffix == ".dac":
        input_files = [input_path]
    else:
        input_files = []

    generator = load_model(
        model_type=model_type,
        model_bitrate=model_bitrate,
        tag=model_tag,
        load_path=weights_path,
    )
    generator.to(device)
    generator.eval()

    output = Path(output)
    output.mkdir(parents=True, exist_ok=True)

    for path in tqdm(input_files, desc="Decoding files"):
        artifact = DACFile.load(path)
        recons = generator.decompress(artifact, verbose=verbose)

        if input_is_dir:
            # Mirror the file's location below `input` inside `output`.
            output_path = output / path.relative_to(input_path).with_suffix(".wav")
        else:
            # Single-file input: write the .wav directly into `output`.
            output_path = output / path.with_suffix(".wav").name
        output_path.parent.mkdir(parents=True, exist_ok=True)

        recons.write(output_path)
|
|
|
|
|
if __name__ == "__main__":
    # Parse command-line flags, then call the @argbind.bind-decorated
    # `decode` inside that scope so it picks up the parsed values.
    parsed_cli = argbind.parse_args()
    with argbind.scope(parsed_cli):
        decode()
|
|
|