File size: 6,661 Bytes
ad0da04
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""Module for audio data processing including segmentation and positional encoding."""

from typing import List, Optional, Tuple, Any

import librosa
import numpy as np

from chorus_detection.audio.processor import AudioFeature
from chorus_detection.config import SR, HOP_LENGTH, MAX_FRAMES, MAX_METERS, N_FEATURES
from chorus_detection.utils.logging import logger


def segment_data_meters(data: np.ndarray, meter_grid: np.ndarray) -> List[np.ndarray]:
    """Slice song data into per-measure segments using the meter grid.

    Each consecutive pair of grid entries (start, end) delimits one measure;
    the slice is cast to float32 so downstream processing sees a uniform dtype.

    Args:
        data: Frame-level feature data for the whole song
        meter_grid: Frame indices marking the start of each measure

    Returns:
        One float32 array per measure, in song order
    """
    # Pair each measure start with the next measure start and slice in one pass.
    return [
        data[start:end].astype(np.float32)
        for start, end in zip(meter_grid[:-1], meter_grid[1:])
    ]


def positional_encoding(position: int, d_model: int) -> np.ndarray:
    """Build a sinusoidal positional-encoding matrix.

    Even feature columns carry sine waves and odd columns carry cosine waves,
    with wavelengths scaled by 10000^(2i/d_model) as in the standard
    transformer formulation.

    Args:
        position: Number of positions (rows) to encode
        d_model: Encoding dimensionality (columns)

    Returns:
        A float32 array of shape (position, d_model)
    """
    pos_col = np.arange(position).reshape(-1, 1)
    dim_row = np.arange(d_model).reshape(1, -1)

    # Each pair of dimensions (2i, 2i+1) shares one frequency.
    angle_matrix = pos_col / np.power(10000, (2 * (dim_row // 2)) / np.float32(d_model))

    pe = np.zeros((position, d_model), dtype=np.float32)
    pe[:, 0::2] = np.sin(angle_matrix[:, 0::2])
    pe[:, 1::2] = np.cos(angle_matrix[:, 1::2])
    return pe


def apply_hierarchical_positional_encoding(segments: List[np.ndarray]) -> List[np.ndarray]:
    """Add two-level (measure + frame) positional encodings to each segment.

    Every segment receives a frame-level encoding sized to its own length,
    plus the single measure-level encoding row corresponding to its position
    in the song.

    Args:
        segments: Per-measure feature arrays of shape (frames, n_features)

    Returns:
        New arrays with both encodings summed in; empty list if no input
    """
    if not segments:
        logger.warning("No segments to encode")
        return []

    feature_dim = segments[0].shape[1]

    # One encoding row per measure, shared feature dimensionality.
    meter_encoding = positional_encoding(len(segments), feature_dim)

    return [
        segment + positional_encoding(len(segment), feature_dim) + meter_encoding[idx]
        for idx, segment in enumerate(segments)
    ]


def pad_song(encoded_segments: List[np.ndarray], max_frames: int = MAX_FRAMES,
             max_meters: int = MAX_METERS, n_features: int = N_FEATURES) -> np.ndarray:
    """Normalize a song to a fixed (max_meters, max_frames, n_features) shape.

    Each meter is truncated/zero-padded to max_frames rows, then the meter
    list itself is truncated/zero-padded to max_meters entries.

    Args:
        encoded_segments: Per-meter encoded feature arrays
        max_frames: Target frame count per meter
        max_meters: Target meter count per song
        n_features: Feature count per frame

    Returns:
        A numpy array of shape (max_meters, max_frames, n_features)
    """
    if not encoded_segments:
        logger.warning("No encoded segments to pad")
        return np.zeros((max_meters, max_frames, n_features), dtype=np.float32)

    # Normalize each kept meter to exactly max_frames rows.
    fixed_meters = []
    for meter in encoded_segments[:max_meters]:
        clipped = meter[:max_frames]
        deficit = max_frames - clipped.shape[0]
        if deficit > 0:
            clipped = np.pad(clipped, ((0, deficit), (0, 0)),
                             'constant', constant_values=0)
        fixed_meters.append(clipped)

    # Fill out the meter axis with all-zero meters if the song is short.
    if len(fixed_meters) < max_meters:
        blank = np.zeros((max_frames, n_features), dtype=np.float32)
        fixed_meters.extend([blank] * (max_meters - len(fixed_meters)))

    return np.array(fixed_meters)


def process_audio(audio_path: str, trim_silence: bool = True, sr: int = SR,
                  hop_length: int = HOP_LENGTH) -> Tuple[Optional[np.ndarray], Optional[AudioFeature]]:
    """Run the full audio pipeline: load, featurize, segment, encode, pad.

    On any failure the error is logged (with a debug-level traceback) and
    (None, None) is returned instead of raising, so callers can treat this
    as best-effort.

    Args:
        audio_path: Path to the audio file
        trim_silence: Whether to strip leading/trailing silence first
        sr: Sample rate for loading
        hop_length: Hop length for feature extraction

    Returns:
        (batched processed audio of shape (1, meters, frames, features),
         the AudioFeature object), or (None, None) on error
    """
    logger.info(f"Processing audio file: {audio_path}")

    try:
        # Optional pre-pass: remove silence in place before feature extraction.
        if trim_silence:
            from chorus_detection.audio.processor import strip_silence
            strip_silence(audio_path)

        # Extract features and the measure grid from the (possibly trimmed) file.
        audio_features = AudioFeature(audio_path=audio_path, sr=sr, hop_length=hop_length)
        audio_features.extract_features()
        audio_features.create_meter_grid()

        # Pipeline: measure segmentation -> hierarchical PE -> fixed-size pad.
        segments = segment_data_meters(
            audio_features.combined_features, audio_features.meter_grid)
        encoded = apply_hierarchical_positional_encoding(segments)

        # Leading axis of size 1 so the result is model-ready as a batch.
        processed_audio = np.expand_dims(pad_song(encoded), axis=0)

        logger.info(f"Audio processing complete: {processed_audio.shape}")
        return processed_audio, audio_features

    except Exception as e:
        logger.error(f"Error processing audio: {e}")

        import traceback
        logger.debug(traceback.format_exc())

        return None, None