"""Run NatureLM-audio over a set of audio files paths or a directory with audio files.""" | |
import argparse | |
from pathlib import Path | |
import numpy as np | |
import pandas as pd | |
import librosa | |
import torch | |
from NatureLM.config import Config | |
from NatureLM.models import NatureLM | |
from NatureLM.processors import NatureLMAudioProcessor | |
from NatureLM.utils import move_to_device | |
_MAX_LENGTH_SECONDS = 10
_MIN_CHUNK_LENGTH_SECONDS = 0.5
_SAMPLE_RATE = 16000  # Assuming the model uses a sample rate of 16kHz
_AUDIO_FILE_EXTENSIONS = [
    ".wav",
    ".mp3",
    ".flac",
    ".ogg",
    ".mp4",
]  # Add other audio file formats as needed
_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

__root_dir = Path(__file__).parent.parent
_DEFAULT_CONFIG_PATH = __root_dir / "configs" / "inference.yml"


def load_model_and_config(
    cfg_path: str | Path = _DEFAULT_CONFIG_PATH, device: str = _DEVICE
) -> tuple[NatureLM, Config]:
    """Load the NatureLM model and configuration.

    Args:
        cfg_path (str | Path): Path to the inference configuration file.
        device (str): Device to place the model on ("cuda" or "cpu").

    Returns:
        tuple: The loaded model and configuration.
    """
    model = NatureLM.from_pretrained("EarthSpeciesProject/NatureLM-audio")
    model = model.to(device).eval()

    # The tokenizer has no pad token by default; reuse the EOS token so that
    # generation does not warn or fail on padded batches.
    model.llama_tokenizer.pad_token_id = model.llama_tokenizer.eos_token_id
    model.llama_model.generation_config.pad_token_id = (
        model.llama_tokenizer.pad_token_id
    )

    cfg = Config.from_sources(cfg_path)
    return model, cfg
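
# A minimal usage sketch for loading the model (assumes the default config file
# ships with the repository; downloads weights from the Hugging Face Hub on
# first use):
#
#     model, cfg = load_model_and_config()
#     print(type(model).__name__, type(cfg).__name__)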


def output_template(model_output: str, start_time: float, end_time: float) -> str:
    """Format a single windowed model output as a timestamped line."""
    return f"#{start_time:.2f}s - {end_time:.2f}s#: {model_output}\n"


def sliding_window_inference(
    audio: str | Path | np.ndarray,
    query: str,
    processor: NatureLMAudioProcessor,
    model: NatureLM,
    cfg: Config,
    window_length_seconds: float = 10.0,
    hop_length_seconds: float = 10.0,
    input_sr: int = _SAMPLE_RATE,
    device: str = _DEVICE,
) -> list[dict[str, Any]]:
    """Run inference on a long audio file using a sliding-window approach.

    Args:
        audio (str | Path | np.ndarray): Path to an audio file, or an audio array.
        query (str): Query for the model.
        processor (NatureLMAudioProcessor): Audio processor.
        model (NatureLM): NatureLM model.
        cfg (Config): Model configuration.
        window_length_seconds (float): Length of the sliding window in seconds.
        hop_length_seconds (float): Hop length for the sliding window in seconds.
        input_sr (int): Sample rate of the audio; only used when `audio` is an array.
        device (str): Device to run inference on.

    Returns:
        list[dict]: One dictionary per window, containing the start time, end time,
            window number, and the model's prediction.

    Raises:
        ValueError: If the audio is shorter than the minimum chunk length.
        TypeError: If `audio` is not a path or a numpy array.
    """
    if isinstance(audio, (str, Path)):
        audio_array, input_sr = librosa.load(str(audio), sr=None, mono=False)
    elif isinstance(audio, np.ndarray):
        audio_array = audio
        print(f"Using provided sample rate: {input_sr}")
    else:
        raise TypeError(f"Unsupported audio type: {type(audio)}")

    # Down-mix multi-channel audio to mono by averaging over the channel axis
    audio_array = audio_array.squeeze()
    if audio_array.ndim > 1:
        axis_to_average = int(np.argmin(audio_array.shape))
        audio_array = audio_array.mean(axis=axis_to_average)
        audio_array = audio_array.squeeze()

    # Initial check that the audio is long enough
    if audio_array.shape[-1] < int(_MIN_CHUNK_LENGTH_SECONDS * input_sr):
        raise ValueError(
            f"Audio is too short. Minimum length is {_MIN_CHUNK_LENGTH_SECONDS} seconds."
        )

    start = 0
    stride = int(hop_length_seconds * input_sr)
    window_length = int(window_length_seconds * input_sr)
    window_id = 0
    output = []  # Initialize output list
    while True:
        chunk = audio_array[start : start + window_length]
        if chunk.shape[-1] < int(_MIN_CHUNK_LENGTH_SECONDS * input_sr):
            break

        # Resamples, pads, truncates, and creates a torch Tensor
        audio_tensor, prompt_list = processor([chunk], [query], [input_sr])
        input_to_model = {
            "raw_wav": audio_tensor,
            "prompt": prompt_list[0],
            "audio_chunk_sizes": 1,
            "padding_mask": torch.zeros_like(audio_tensor).to(torch.bool),
        }
        input_to_model = move_to_device(input_to_model, device)

        # Generate a prediction for this window
        prediction: str = model.generate(input_to_model, cfg.generate, prompt_list)[0]

        output.append(
            {
                "start_time": start / input_sr,
                # Use the actual chunk length so a final short window's end
                # time is not overstated
                "end_time": (start + chunk.shape[-1]) / input_sr,
                "prediction": prediction,
                "window_number": window_id,
            }
        )
        window_id += 1

        # Move the window; a trailing chunk shorter than a full window is skipped
        start += stride
        if start + window_length > audio_array.shape[-1]:
            break

    return output
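
# A minimal sketch of calling sliding_window_inference directly (the file path
# and query below are hypothetical):
#
#     model, cfg = load_model_and_config()
#     processor = NatureLMAudioProcessor(
#         sample_rate=_SAMPLE_RATE, max_length_seconds=_MAX_LENGTH_SECONDS
#     )
#     windows = sliding_window_inference(
#         "recordings/dawn_chorus.wav",
#         "What species are vocalizing?",
#         processor,
#         model,
#         cfg,
#     )
#     for w in windows:
#         print(w["start_time"], w["end_time"], w["prediction"])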


class Pipeline:
    """Pipeline for running NatureLM-audio inference on a list of audio files or audio arrays."""

    def __init__(
        self, model: NatureLM | None = None, cfg_path: str | Path = _DEFAULT_CONFIG_PATH
    ):
        self.cfg_path = cfg_path

        # Load model and config
        if model is not None:
            self.cfg = Config.from_sources(cfg_path)
            self.model = model
        else:
            # Download the model from the Hugging Face Hub
            self.model, self.cfg = load_model_and_config(cfg_path)

        self.processor = NatureLMAudioProcessor(
            sample_rate=_SAMPLE_RATE, max_length_seconds=_MAX_LENGTH_SECONDS
        )
    def __call__(
        self,
        audios: list[str | Path | np.ndarray] | str | Path | np.ndarray,
        queries: str | list[str],
        window_length_seconds: float = 10.0,
        hop_length_seconds: float = 10.0,
        input_sample_rate: int = _SAMPLE_RATE,
        verbose: bool = False,
    ) -> list[list[dict[str, Any]]]:
        """Run inference on one or more audio files (or arrays) with a single
        query or a list of queries. If multiple queries are provided, they are
        assumed to be in the same order as the audio files. If a single query is
        provided, it is used for all audio files.

        Args:
            audios (list[str | Path | np.ndarray]): Audio file paths or audio arrays,
                or a single path or array.
            queries (str | list[str]): Queries for the model.
            window_length_seconds (float): Length of the sliding window in seconds. Defaults to 10.0.
            hop_length_seconds (float): Hop length for the sliding window in seconds. Defaults to 10.0.
            input_sample_rate (int): Sample rate of the audio. Defaults to 16000,
                which is the model's sample rate. Only used for audio arrays.
            verbose (bool): If True, print the output of the model for each audio file.
                Defaults to False.

        Returns:
            list[list[dict]]: List of model outputs for each audio file. Each output
                is a list of dictionaries containing the start time, end time, and
                prediction for each window of audio.

        Raises:
            ValueError: If the number of audio files and queries do not match.
        """
        if isinstance(audios, (str, Path, np.ndarray)):
            audios = [audios]
        if isinstance(queries, str):
            queries = [queries] * len(audios)
        if len(audios) != len(queries):
            raise ValueError("Number of audio files and queries must match.")

        # Run inference
        results = []
        for audio, query in zip(audios, queries):
            output = sliding_window_inference(
                audio,
                query,
                self.processor,
                self.model,
                self.cfg,
                window_length_seconds,
                hop_length_seconds,
                input_sr=input_sample_rate,
            )
            results.append(output)
            if verbose:
                print(f"Processed {audio}, model output:\n=======\n{output}\n=======")

        return results
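
# A minimal sketch of the Pipeline API (the file names below are hypothetical);
# a single query string is broadcast across all audio files:
#
#     pipeline = Pipeline()
#     results = pipeline(
#         ["recordings/frogs.wav", "recordings/owls.wav"],
#         "What species are vocalizing?",
#         verbose=True,
#     )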


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser("Run NatureLM-audio inference")
    parser.add_argument(
        "-a",
        "--audio",
        type=str,
        required=True,
        help="Path to an audio file or a directory containing audio files",
    )
    parser.add_argument(
        "-q", "--query", type=str, required=True, help="Query for the model"
    )
    parser.add_argument(
        "--cfg-path",
        type=str,
        default="configs/inference.yml",
        help="Path to the configuration file for the model",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="inference_output.jsonl",
        help="Output path for the results (JSON Lines)",
    )
    parser.add_argument(
        "--window_length_seconds",
        type=float,
        default=10.0,
        help="Length of the sliding window in seconds",
    )
    parser.add_argument(
        "--hop_length_seconds",
        type=float,
        default=10.0,
        help="Hop length for the sliding window in seconds",
    )
    args = parser.parse_args()
    return args


def main(
    cfg_path: str | Path,
    audio_path: str | Path,
    query: str,
    output_path: str,
    window_length_seconds: float,
    hop_length_seconds: float,
) -> None:
    """Run the NatureLM-audio inference script.

    It takes the audio file or directory path, query, output path, window
    length, and hop length, processes the audio files, and saves the results
    to a JSON Lines file.

    Args:
        cfg_path (str | Path): Path to the configuration file.
        audio_path (str | Path): Path to the audio file or directory.
        query (str): Query for the model.
        output_path (str): Path to save the output results.
        window_length_seconds (float): Length of the sliding window in seconds.
        hop_length_seconds (float): Hop length for the sliding window in seconds.

    Raises:
        ValueError: If the query is empty, if no audio files are found, or if
            the audio file extension is not supported.
    """
    # Collect audio files
    audio_path = Path(audio_path)
    if audio_path.is_dir():
        audio_paths = []
        print(
            f"Searching for audio files in {str(audio_path)} with extensions {', '.join(_AUDIO_FILE_EXTENSIONS)}"
        )
        for ext in _AUDIO_FILE_EXTENSIONS:
            audio_paths.extend(list(audio_path.rglob(f"*{ext}")))
        print(f"Found {len(audio_paths)} audio files in {str(audio_path)}")
    else:
        # Check that the extension is valid
        if audio_path.suffix not in _AUDIO_FILE_EXTENSIONS:
            raise ValueError(
                f"Invalid audio file extension. Supported extensions are: {', '.join(_AUDIO_FILE_EXTENSIONS)}"
            )
        audio_paths = [audio_path]

    # Check that the query is not empty
    if not query:
        raise ValueError("Query cannot be empty")
    if not audio_paths:
        raise ValueError(
            "No audio files found. Please check the path or file extensions."
        )

    # Load model and config
    model, cfg = load_model_and_config(cfg_path)

    # Load audio processor
    processor = NatureLMAudioProcessor(
        sample_rate=_SAMPLE_RATE, max_length_seconds=_MAX_LENGTH_SECONDS
    )

    # Run inference
    results = {"audio_path": [], "output": []}
    for path in audio_paths:
        output = sliding_window_inference(
            path,
            query,
            processor,
            model,
            cfg,
            window_length_seconds,
            hop_length_seconds,
        )
        results["audio_path"].append(str(path))
        results["output"].append(output)
        print(f"Processed {path}, model output:\n=======\n{output}\n=======\n")

    # Save results as JSON Lines (one record per audio file)
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    df = pd.DataFrame(results)
    df.to_json(output_path, orient="records", lines=True)
    print(f"Results saved to {output_path}")


if __name__ == "__main__":
    args = parse_args()
    main(
        cfg_path=args.cfg_path,
        audio_path=args.audio,
        query=args.query,
        output_path=args.output_path,
        window_length_seconds=args.window_length_seconds,
        hop_length_seconds=args.hop_length_seconds,
    )
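
# Example invocation (script name and audio path are hypothetical):
#
#     python inference.py -a recordings/ -q "What species are vocalizing?" \
#         --output_path results.jsonl --window_length_seconds 10 --hop_length_seconds 5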