File size: 3,797 Bytes
c1e08a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29e2c21
c1e08a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
"""Module for handling audio input through Gradio interface."""

from typing import Callable

import numpy as np
from scipy import signal

from improvisation_lab.infrastructure.audio.audio_processor import \
    AudioProcessor


class WebAudioProcessor(AudioProcessor):
    """Handle audio input from Gradio interface.

    Incoming chunks are resampled to the processor's sample rate, gated
    against low-amplitude noise, peak-normalized, and then handed to the
    buffering methods inherited from ``AudioProcessor``.
    """

    # Default amplitude below which a sample is treated as noise.
    # [TODO] Set appropriate threshold. The value is on the scale of the raw
    # input samples, since gating runs before normalization.
    NOISE_THRESHOLD: float = 20.0

    def __init__(
        self,
        sample_rate: int,
        callback: Callable[[np.ndarray], None] | None = None,
        buffer_duration: float = 0.3,
    ):
        """Initialize WebAudioProcessor.

        Args:
            sample_rate: Audio sample rate in Hz
            callback: Optional callback function to process audio data
            buffer_duration: Duration of audio buffer in seconds
        """
        super().__init__(sample_rate, callback, buffer_duration)

    def _resample_audio(
        self, audio_data: np.ndarray, original_sr: int, target_sr: int
    ) -> np.ndarray:
        """Resample audio data to target sample rate.

        In the case of Gradio,
        the sample rate of the audio data may not match the target sample rate.

        Args:
            audio_data: numpy array of audio samples
            original_sr: Original sample rate in Hz
            target_sr: Target sample rate in Hz

        Returns:
            Resampled audio data with target sample rate

        Raises:
            ValueError: If either sample rate is not positive
        """
        if original_sr <= 0 or target_sr <= 0:
            raise ValueError(
                f"Sample rates must be positive: {original_sr} -> {target_sr}"
            )
        if original_sr == target_sr or len(audio_data) == 0:
            # Nothing to do; Fourier resampling also cannot handle empty input.
            return audio_data

        number_of_samples = round(len(audio_data) * float(target_sr) / original_sr)
        if number_of_samples == 0:
            # A very short chunk can round down to zero output samples, which
            # signal.resample rejects; return an empty array of matching rank.
            return np.zeros((0,) + audio_data.shape[1:], dtype=np.float64)
        return signal.resample(audio_data, number_of_samples)

    def _normalize_audio(self, audio_data: np.ndarray) -> np.ndarray:
        """Normalize audio data to range [-1, 1] by dividing by maximum absolute value.

        Integer input is converted to float64 before taking absolute values,
        because ``np.abs`` overflows on the most negative value of a signed
        integer dtype (e.g. -32768 for int16) and would yield a wrong peak.

        Args:
            audio_data: numpy array of audio samples

        Returns:
            Normalized audio data with values between -1 and 1.
            Empty input is returned unchanged; all-zero input is returned
            without scaling.
        """
        if audio_data.size == 0:
            return audio_data
        if np.issubdtype(audio_data.dtype, np.integer):
            audio_data = audio_data.astype(np.float64)
        peak = np.max(np.abs(audio_data))
        if peak == 0:
            return audio_data
        return audio_data / peak

    def _remove_low_amplitude_noise(
        self, audio_data: np.ndarray, threshold: float | None = None
    ) -> np.ndarray:
        """Remove low amplitude noise from audio data.

        Applies a threshold to remove low amplitude signals that are likely noise.
        The input array is left untouched; a modified copy is returned.

        Args:
            audio_data: Audio data as numpy array
            threshold: Samples whose absolute value is strictly below this
                are set to zero. Defaults to ``NOISE_THRESHOLD``.

        Returns:
            Audio data with low amplitude noise removed
        """
        if threshold is None:
            threshold = self.NOISE_THRESHOLD
        # Two-sided comparison instead of np.abs(): abs() of the minimum signed
        # integer overflows and would wrongly classify a full-scale negative
        # sample as noise.
        is_noise = (audio_data > -threshold) & (audio_data < threshold)
        cleaned = audio_data.copy()
        cleaned[is_noise] = 0
        return cleaned

    def process_audio(self, audio_input: tuple[int, np.ndarray]) -> None:
        """Process incoming audio data from Gradio.

        Does nothing unless recording has been started.

        Args:
            audio_input: Tuple of (sample_rate, audio_data)
                        where audio_data is a (samples, channels) array
        """
        if not self.is_recording or audio_input is None:
            return

        input_sample_rate, audio_data = audio_input
        # NOTE(review): stop_recording resets the buffer to a 1-D array, so
        # multi-channel input may need a mono downmix before buffering --
        # confirm against AudioProcessor._append_to_buffer.
        if input_sample_rate != self.sample_rate:
            audio_data = self._resample_audio(
                audio_data, input_sample_rate, self.sample_rate
            )
        # Gate before normalizing: the threshold is on the raw sample scale.
        audio_data = self._remove_low_amplitude_noise(audio_data)
        audio_data = self._normalize_audio(audio_data)

        self._append_to_buffer(audio_data)
        self._process_buffer()

    def start_recording(self) -> None:
        """Start accepting audio input from Gradio.

        Raises:
            RuntimeError: If recording is already in progress
        """
        if self.is_recording:
            raise RuntimeError("Recording is already in progress")
        self.is_recording = True

    def stop_recording(self) -> None:
        """Stop accepting audio input from Gradio.

        Also discards any audio still held in the buffer.

        Raises:
            RuntimeError: If recording is not in progress
        """
        if not self.is_recording:
            raise RuntimeError("Recording is not in progress")
        self.is_recording = False
        self._buffer = np.array([], dtype=np.float32)