#!/usr/bin/env python3 | |
# -*- coding: utf-8 -*- | |
"""Module for audio data processing including segmentation and positional encoding.""" | |
from typing import List, Optional, Tuple, Any | |
import librosa | |
import numpy as np | |
from chorus_detection.audio.processor import AudioFeature | |
from chorus_detection.config import SR, HOP_LENGTH, MAX_FRAMES, MAX_METERS, N_FEATURES | |
from chorus_detection.utils.logging import logger | |
def segment_data_meters(data: np.ndarray, meter_grid: np.ndarray) -> List[np.ndarray]:
    """Split song data into per-measure segments.

    Consecutive entries of ``meter_grid`` are treated as [start, end) frame
    boundaries, so N grid points yield N-1 segments.

    Args:
        data: Feature matrix for the whole song (frames x features).
        meter_grid: Frame indices marking the start of each measure.

    Returns:
        One float32 array per measure, sliced from ``data``.
    """
    boundaries = zip(meter_grid[:-1], meter_grid[1:])
    # Slice and cast in a single pass; float32 keeps downstream math consistent.
    return [data[start:end].astype(np.float32) for start, end in boundaries]
def positional_encoding(position: int, d_model: int) -> np.ndarray:
    """Build the standard sinusoidal positional-encoding matrix.

    Even feature columns receive sine values, odd columns cosine values, at
    wavelengths geometrically spaced from 2*pi to 10000*2*pi (Transformer-style).

    Args:
        position: Number of positions (rows) to encode.
        d_model: Model/feature dimension (columns).

    Returns:
        A float32 array of shape (position, d_model).
    """
    pos_col = np.arange(position).reshape(-1, 1)
    dim_row = np.arange(d_model).reshape(1, -1)
    # Each dimension pair (2k, 2k+1) shares the rate 10000^(2k / d_model).
    angle_matrix = pos_col / np.power(10000, (2 * (dim_row // 2)) / np.float32(d_model))
    pe = np.zeros((position, d_model), dtype=np.float32)
    pe[:, 0::2] = np.sin(angle_matrix[:, 0::2])
    pe[:, 1::2] = np.cos(angle_matrix[:, 1::2])
    return pe
def apply_hierarchical_positional_encoding(segments: List[np.ndarray]) -> List[np.ndarray]:
    """Add two-level (measure + frame) positional encodings to each segment.

    Each segment gets a frame-level encoding sized to its own length, plus the
    single measure-level encoding row corresponding to its index in the song.

    Args:
        segments: Per-measure feature arrays (frames x features).

    Returns:
        New arrays with both encodings summed in; empty list if no segments.
    """
    if not segments:
        logger.warning("No segments to encode")
        return []
    feature_dim = segments[0].shape[1]
    # One row per measure; row i is broadcast across all frames of segment i.
    meter_encoding = positional_encoding(len(segments), feature_dim)
    return [
        seg + positional_encoding(len(seg), feature_dim) + meter_encoding[idx]
        for idx, seg in enumerate(segments)
    ]
def pad_song(encoded_segments: List[np.ndarray], max_frames: int = MAX_FRAMES,
             max_meters: int = MAX_METERS, n_features: int = N_FEATURES) -> np.ndarray:
    """Force the encoded segments into a fixed (max_meters, max_frames, n_features) shape.

    Each segment is truncated or zero-padded along the frame axis, then the
    list of segments is truncated or zero-padded along the meter axis.

    Args:
        encoded_segments: Per-measure encoded feature arrays.
        max_frames: Target frame count per measure.
        max_meters: Target number of measures.
        n_features: Feature count per frame (used only for zero fill).

    Returns:
        A 3-D array of shape (max_meters, max_frames, n_features).
    """
    if not encoded_segments:
        logger.warning("No encoded segments to pad")
        return np.zeros((max_meters, max_frames, n_features), dtype=np.float32)

    fitted = []
    for segment in encoded_segments:
        clipped = segment[:max_frames]  # no-op when already short enough
        shortfall = max_frames - clipped.shape[0]
        if shortfall > 0:
            clipped = np.pad(clipped, ((0, shortfall), (0, 0)), 'constant', constant_values=0)
        fitted.append(clipped)

    if len(fitted) > max_meters:
        return np.array(fitted[:max_meters])
    # Fill the remaining meter slots with all-zero measures.
    zero_meter = np.zeros((max_frames, n_features), dtype=np.float32)
    return np.array(fitted + [zero_meter] * (max_meters - len(fitted)))
def process_audio(audio_path: str, trim_silence: bool = True, sr: int = SR,
                  hop_length: int = HOP_LENGTH) -> Tuple[Optional[np.ndarray], Optional[AudioFeature]]:
    """Run the full audio pipeline: features -> meter segmentation -> encoding -> padding.

    Args:
        audio_path: Path to the audio file on disk.
        trim_silence: If True, strip leading/trailing silence in place first.
        sr: Sample rate for loading the audio.
        hop_length: Hop length for feature extraction.

    Returns:
        (processed_audio, features) where processed_audio has a leading batch
        dimension, or (None, None) if any step fails (failure is logged).
    """
    logger.info(f"Processing audio file: {audio_path}")
    try:
        if trim_silence:
            # Imported lazily to avoid a module-level import cycle with processor.
            from chorus_detection.audio.processor import strip_silence
            strip_silence(audio_path)

        # Extract features and derive the meter grid used for segmentation.
        features = AudioFeature(audio_path=audio_path, sr=sr, hop_length=hop_length)
        features.extract_features()
        features.create_meter_grid()

        segments = segment_data_meters(features.combined_features, features.meter_grid)
        encoded = apply_hierarchical_positional_encoding(segments)
        # Pad to fixed dims and prepend a batch axis for the model.
        processed_audio = np.expand_dims(pad_song(encoded), axis=0)

        logger.info(f"Audio processing complete: {processed_audio.shape}")
        return processed_audio, features
    except Exception as e:
        # Boundary handler: log the failure and signal it via (None, None).
        logger.error(f"Error processing audio: {e}")
        import traceback
        logger.debug(traceback.format_exc())
        return None, None