# NOTE: the three lines that followed were GitHub page residue (author avatar text,
# commit message "Flatten directory structure for simpler imports", commit ad0da04)
# accidentally pasted above the shebang; kept here as a comment so the file parses.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Module for audio data processing including segmentation and positional encoding."""
from typing import List, Optional, Tuple, Any
import librosa
import numpy as np
from chorus_detection.audio.processor import AudioFeature
from chorus_detection.config import SR, HOP_LENGTH, MAX_FRAMES, MAX_METERS, N_FEATURES
from chorus_detection.utils.logging import logger
def segment_data_meters(data: np.ndarray, meter_grid: np.ndarray) -> List[np.ndarray]:
    """Split song-level feature data into per-measure segments.

    Args:
        data: Frame-level feature matrix for the whole song
        meter_grid: Frame indices marking the start of each measure

    Returns:
        One float32 array per measure, sliced between consecutive grid entries
    """
    # Pair each measure start with the next one and slice the data between them;
    # cast to float32 so every downstream consumer sees a uniform dtype.
    boundaries = zip(meter_grid[:-1], meter_grid[1:])
    return [data[start:end].astype(np.float32) for start, end in boundaries]
def positional_encoding(position: int, d_model: int) -> np.ndarray:
    """Build a sinusoidal positional-encoding matrix.

    Args:
        position: Number of positions (rows) to encode
        d_model: Encoding dimensionality (columns)

    Returns:
        A float32 array of shape (position, d_model) with sine values in the
        even columns and cosine values in the odd columns
    """
    pos = np.arange(position).reshape(-1, 1)
    dims = np.arange(d_model).reshape(1, -1)
    # Each dimension pair (2k, 2k+1) shares the wavelength 10000^(2k/d_model).
    angle_rads = pos / np.power(10000, (2 * (dims // 2)) / np.float32(d_model))
    pe = np.zeros((position, d_model), dtype=np.float32)
    pe[:, ::2] = np.sin(angle_rads[:, ::2])
    pe[:, 1::2] = np.cos(angle_rads[:, 1::2])
    return pe
def apply_hierarchical_positional_encoding(segments: List[np.ndarray]) -> List[np.ndarray]:
    """Add two-level (frame + measure) positional encodings to each segment.

    Args:
        segments: Per-measure feature arrays of shape (frames, n_features)

    Returns:
        Segments with a frame-level encoding added per row and the segment's
        measure-level encoding row broadcast over the whole segment; an empty
        input yields an empty list (with a warning)
    """
    if not segments:
        logger.warning("No segments to encode")
        return []
    feature_dim = segments[0].shape[1]
    # One encoding row per measure; row i is broadcast across segment i.
    meter_encoding = positional_encoding(len(segments), feature_dim)
    return [
        segment + positional_encoding(len(segment), feature_dim) + meter_encoding[idx]
        for idx, segment in enumerate(segments)
    ]
def pad_song(encoded_segments: List[np.ndarray], max_frames: int = MAX_FRAMES,
             max_meters: int = MAX_METERS, n_features: int = N_FEATURES) -> np.ndarray:
    """Force the encoded segments into a fixed (max_meters, max_frames, n_features) tensor.

    Args:
        encoded_segments: Per-measure arrays of shape (frames, n_features)
        max_frames: Target frame count per measure (longer measures are truncated)
        max_meters: Target measure count (extra measures are dropped)
        n_features: Feature count per frame (used for zero padding)

    Returns:
        A numpy array of shape (max_meters, max_frames, n_features); an all-zero
        tensor is returned (with a warning) when the input is empty
    """
    if not encoded_segments:
        logger.warning("No encoded segments to pad")
        return np.zeros((max_meters, max_frames, n_features), dtype=np.float32)
    # First dimension-fix every measure to exactly max_frames rows.
    fixed_meters = []
    for segment in encoded_segments:
        clipped = segment[:max_frames]
        shortfall = max_frames - clipped.shape[0]
        if shortfall > 0:
            clipped = np.pad(clipped, ((0, shortfall), (0, 0)),
                             'constant', constant_values=0)
        fixed_meters.append(clipped)
    # Then fix the measure count: drop extras or append all-zero filler measures.
    if len(fixed_meters) > max_meters:
        return np.array(fixed_meters[:max_meters])
    filler = np.zeros((max_frames, n_features), dtype=np.float32)
    fixed_meters.extend([filler] * (max_meters - len(fixed_meters)))
    return np.array(fixed_meters)
def process_audio(audio_path: str, trim_silence: bool = True, sr: int = SR,
                  hop_length: int = HOP_LENGTH) -> Tuple[Optional[np.ndarray], Optional[AudioFeature]]:
    """Run the full feature pipeline on one audio file.

    Steps: optional silence trim, feature extraction, meter-grid segmentation,
    hierarchical positional encoding, and fixed-size padding with a leading
    batch dimension.

    Args:
        audio_path: Path to the audio file
        trim_silence: Whether to strip leading/trailing silence in place first
        sr: Sample rate for loading the audio
        hop_length: Hop length for feature extraction

    Returns:
        (processed_audio, audio_features) on success; (None, None) on any
        failure, with the error logged
    """
    logger.info(f"Processing audio file: {audio_path}")
    try:
        if trim_silence:
            # Imported lazily so the trim dependency is only loaded when used.
            from chorus_detection.audio.processor import strip_silence
            strip_silence(audio_path)

        # Extract frame-level features and the measure grid.
        audio_features = AudioFeature(audio_path=audio_path, sr=sr, hop_length=hop_length)
        audio_features.extract_features()
        audio_features.create_meter_grid()

        # Segment by measure, encode positions, pad, and add a batch axis.
        segments = segment_data_meters(
            audio_features.combined_features, audio_features.meter_grid)
        encoded = apply_hierarchical_positional_encoding(segments)
        processed_audio = pad_song(encoded)[np.newaxis, ...]

        logger.info(f"Audio processing complete: {processed_audio.shape}")
        return processed_audio, audio_features
    except Exception as e:
        logger.error(f"Error processing audio: {e}")
        import traceback
        logger.debug(traceback.format_exc())
        return None, None