#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Module for audio data processing including segmentation and positional encoding."""
from typing import List, Optional, Tuple

import numpy as np

from chorus_detection.audio.processor import AudioFeature
from chorus_detection.config import SR, HOP_LENGTH, MAX_FRAMES, MAX_METERS, N_FEATURES
from chorus_detection.utils.logging import logger


def segment_data_meters(data: np.ndarray, meter_grid: np.ndarray) -> List[np.ndarray]:
    """Divide song data into segments based on measure grid frames.

    Args:
        data: The song data to be segmented
        meter_grid: The grid of frame indices marking the start of each measure

    Returns:
        A list of song data segments, one per measure
    """
    # Slice the data at consecutive measure boundaries and cast each segment
    # to float32 for consistent downstream processing
    meter_segments = [data[start:end].astype(np.float32)
                      for start, end in zip(meter_grid[:-1], meter_grid[1:])]
    return meter_segments
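
# Usage sketch for segment_data_meters (illustrative, not part of the original
# pipeline; the toy shapes and boundary indices below are assumptions):
#
#   >>> data = np.arange(20, dtype=np.float64).reshape(10, 2)   # 10 frames, 2 features
#   >>> meter_grid = np.array([0, 4, 7, 10])                    # measure-start frame indices
#   >>> segments = segment_data_meters(data, meter_grid)
#   >>> [s.shape for s in segments]
#   [(4, 2), (3, 2), (3, 2)]
#   >>> segments[0].dtype
#   dtype('float32')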


def positional_encoding(position: int, d_model: int) -> np.ndarray:
    """Generate sinusoidal positional encodings for a sequence.

    Args:
        position: The number of positions (sequence length) to encode
        d_model: The dimension of the model

    Returns:
        A (position, d_model) array of positional encodings
    """
    # Compute the angle for every (position, dimension) pair; each dimension
    # pair (2i, 2i + 1) shares the frequency 1 / 10000^(2i / d_model)
    positions = np.arange(position)[:, np.newaxis]
    dim_indices = np.arange(d_model)[np.newaxis, :]
    angles = positions / np.power(10000, (2 * (dim_indices // 2)) / np.float32(d_model))

    # Apply sine to even indices and cosine to odd indices
    encodings = np.zeros((position, d_model), dtype=np.float32)
    encodings[:, 0::2] = np.sin(angles[:, 0::2])
    encodings[:, 1::2] = np.cos(angles[:, 1::2])
    return encodings
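
# Sanity-check sketch for positional_encoding (illustrative only). Position 0
# has every angle equal to 0, so its sine columns are 0 and its cosine columns
# are 1, matching the standard transformer formulation
# PE(pos, 2i) = sin(pos / 10000^(2i / d_model)),
# PE(pos, 2i + 1) = cos(pos / 10000^(2i / d_model)):
#
#   >>> enc = positional_encoding(4, 6)
#   >>> enc.shape
#   (4, 6)
#   >>> enc[0]
#   array([0., 1., 0., 1., 0., 1.], dtype=float32)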


def apply_hierarchical_positional_encoding(segments: List[np.ndarray]) -> List[np.ndarray]:
    """Apply positional encoding at the meter and frame levels to a list of segments.

    Args:
        segments: The list of segments to encode

    Returns:
        The list of segments with positional encoding applied
    """
    if not segments:
        logger.warning("No segments to encode")
        return []

    n_features = segments[0].shape[1]
    # One measure-level encoding row per segment
    measure_level_encodings = positional_encoding(len(segments), n_features)

    encoded_segments = []
    for i, segment in enumerate(segments):
        # Frame-level encoding within the segment, plus the segment's
        # measure-level row broadcast across all of its frames
        frame_level_encoding = positional_encoding(len(segment), n_features)
        encoded_segments.append(segment + frame_level_encoding + measure_level_encodings[i])
    return encoded_segments
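
# Illustrative sketch (assumed toy shapes): each segment gains a frame-level
# encoding plus the single measure-level row for its index, so shapes are
# preserved:
#
#   >>> segs = [np.zeros((5, 8), dtype=np.float32), np.zeros((3, 8), dtype=np.float32)]
#   >>> encoded = apply_hierarchical_positional_encoding(segs)
#   >>> [e.shape for e in encoded]
#   [(5, 8), (3, 8)]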


def pad_song(encoded_segments: List[np.ndarray], max_frames: int = MAX_FRAMES,
             max_meters: int = MAX_METERS, n_features: int = N_FEATURES) -> np.ndarray:
    """Pad or truncate the encoded segments to fixed dimensions.

    Args:
        encoded_segments: The encoded segments to pad or truncate
        max_frames: The maximum number of frames per segment
        max_meters: The maximum number of meters (segments) per song
        n_features: The number of features per frame

    Returns:
        The padded or truncated song as a (max_meters, max_frames, n_features) array
    """
    if not encoded_segments:
        logger.warning("No encoded segments to pad")
        return np.zeros((max_meters, max_frames, n_features), dtype=np.float32)

    # Pad or truncate each meter/segment to exactly max_frames frames
    padded_meters = []
    for meter in encoded_segments:
        truncated_meter = meter[:max_frames]
        if truncated_meter.shape[0] < max_frames:
            padding = ((0, max_frames - truncated_meter.shape[0]), (0, 0))
            padded_meter = np.pad(truncated_meter, padding, 'constant', constant_values=0)
        else:
            padded_meter = truncated_meter
        padded_meters.append(padded_meter)

    # Truncate to max_meters, or pad with all-zero meters up to max_meters
    padding_meter = np.zeros((max_frames, n_features), dtype=np.float32)
    if len(padded_meters) > max_meters:
        padded_song = np.array(padded_meters[:max_meters])
    else:
        padded_song = np.array(padded_meters + [padding_meter] * (max_meters - len(padded_meters)))
    return padded_song
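
# Shape-check sketch for pad_song (toy dimensions; the real defaults come from
# chorus_detection.config). Short meters are zero-padded, long ones truncated,
# and missing meters filled with all-zero frames:
#
#   >>> segs = [np.ones((3, 4), dtype=np.float32), np.ones((7, 4), dtype=np.float32)]
#   >>> padded = pad_song(segs, max_frames=5, max_meters=4, n_features=4)
#   >>> padded.shape
#   (4, 5, 4)
#   >>> float(padded[0, 3:].sum())  # frames beyond the first segment are zeros
#   0.0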


def process_audio(audio_path: str, trim_silence: bool = True, sr: int = SR,
                  hop_length: int = HOP_LENGTH) -> Tuple[Optional[np.ndarray], Optional[AudioFeature]]:
    """Process an audio file, extracting features and applying positional encoding.

    Args:
        audio_path: The path to the audio file
        trim_silence: Whether to trim silence from the audio
        sr: The sample rate to use when loading the audio
        hop_length: The hop length to use for feature extraction

    Returns:
        A tuple of (processed audio, audio features), or (None, None) on failure
    """
    logger.info(f"Processing audio file: {audio_path}")
    try:
        # Optionally strip silence before feature extraction
        if trim_silence:
            from chorus_detection.audio.processor import strip_silence
            strip_silence(audio_path)

        # Extract features and build the meter grid
        audio_features = AudioFeature(audio_path=audio_path, sr=sr, hop_length=hop_length)
        audio_features.extract_features()
        audio_features.create_meter_grid()

        # Segment the audio data by meter grid
        audio_segments = segment_data_meters(
            audio_features.combined_features, audio_features.meter_grid)

        # Apply hierarchical positional encoding
        encoded_audio_segments = apply_hierarchical_positional_encoding(audio_segments)

        # Pad the song to fixed dimensions and add a batch dimension
        processed_audio = np.expand_dims(pad_song(encoded_audio_segments), axis=0)
        logger.info(f"Audio processing complete: {processed_audio.shape}")
        return processed_audio, audio_features
    except Exception as e:
        logger.error(f"Error processing audio: {e}")
        import traceback
        logger.debug(traceback.format_exc())
        return None, None
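
# End-to-end usage sketch ("song.mp3" is a placeholder path): on success the
# batched output has shape (1, MAX_METERS, MAX_FRAMES, N_FEATURES); on failure
# both return values are None:
#
#   >>> processed_audio, audio_features = process_audio("song.mp3")
#   >>> if processed_audio is not None:
#   ...     assert processed_audio.shape == (1, MAX_METERS, MAX_FRAMES, N_FEATURES)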