Spaces:
Running
Running
"""Input preprocessing module.""" | |
from abc import ABC, abstractmethod | |
from enum import Enum | |
import numpy as np | |
from scipy.signal import fftconvolve | |
from aira.engine.filtering import NonCoincidentMicsCorrection | |
from aira.utils import convert_ambisonics_a_to_b | |
# pylint: disable=too-few-public-methods
class InputMode(Enum):
    """Identifiers for the kinds of input data handled by the processors.

    Each member's value is the lowercase string form of its name.
    """

    # Long Sine Sweep measurements.
    LSS = "lss"
    # A-format signals.
    AFORMAT = "aformat"
    # B-format signals.
    BFORMAT = "bformat"
# pylint: disable=too-few-public-methods
class InputProcessor(ABC):
    """Base interface for input processors.

    Concrete processors must override :meth:`process`. The class cannot be
    instantiated directly because ``process`` is declared abstract.
    """

    @abstractmethod
    def process(self, input_dict: dict) -> dict:
        """Process ``input_dict`` and return the resulting dictionary.

        Abstract method to be overridden by concrete implementations of
        input processing.

        Parameters
        ----------
        input_dict : dict
            Dictionary holding the input data to process.

        Returns
        -------
        dict
            The processed dictionary.
        """
# pylint: disable=too-few-public-methods
class LSSInputProcessor(InputProcessor):
    """Processor for input data in LSS (Long Sine Sweep) mode."""

    def process(self, input_dict: dict) -> dict:
        """Turn Long Sine Sweep (LSS) measurements into impulse responses.

        When ``input_dict["input_mode"]`` is ``InputMode.LSS``, every 1-D slice
        of ``input_dict["stacked_signals"]`` along axis 1 is fully convolved
        with ``input_dict["inverse_filter"]``. The result replaces
        ``"stacked_signals"`` and the mode is switched to
        ``InputMode.AFORMAT``. For any other mode the dictionary is returned
        untouched.

        Parameters
        ----------
        input_dict : dict
            Dictionary with LSS measurement arrays. Reads the keys
            ``"input_mode"``, ``"stacked_signals"`` and ``"inverse_filter"``.

        Returns
        -------
        dict
            The same ``input_dict`` object, overwritten with A-Format signals.
        """
        if input_dict["input_mode"] == InputMode.LSS:

            def deconvolve(recording):
                # The inverse filter is looked up on each call, exactly as the
                # per-slice callback needs it.
                return fftconvolve(
                    recording, input_dict["inverse_filter"], mode="full"
                )

            input_dict["stacked_signals"] = np.apply_along_axis(
                deconvolve, 1, input_dict["stacked_signals"]
            )
            input_dict["input_mode"] = InputMode.AFORMAT
        return input_dict
# pylint: disable=too-few-public-methods
class AFormatProcessor(InputProcessor):
    """Processor for input data in AFORMAT mode."""

    def process(self, input_dict: dict) -> dict:
        """Convert A-format arrays into B-format arrays.

        When ``input_dict["input_mode"]`` is ``InputMode.AFORMAT``, rows 0 to 3
        of ``input_dict["stacked_signals"]`` are passed, in order, to
        ``convert_ambisonics_a_to_b``. Its result replaces
        ``"stacked_signals"`` and the mode is switched to
        ``InputMode.BFORMAT``. For any other mode the dictionary is returned
        untouched. For more details see the
        aira.utils.formatter.convert_ambisonics_a_to_b function.

        Parameters
        ----------
        input_dict : dict
            Dictionary with A-format arrays. Reads the keys ``"input_mode"``
            and ``"stacked_signals"``.

        Returns
        -------
        dict
            The same ``input_dict`` object, overwritten with B-format signals.
        """
        if input_dict["input_mode"] == InputMode.AFORMAT:
            a_format = input_dict["stacked_signals"]
            # One 1-D slice per A-format channel, rows 0..3 in order.
            channels = [a_format[row, :] for row in range(4)]
            input_dict["stacked_signals"] = convert_ambisonics_a_to_b(*channels)
            input_dict["input_mode"] = InputMode.BFORMAT
        return input_dict
# pylint: disable=too-few-public-methods
class BFormatProcessor(InputProcessor):
    """Processing when input data is in BFORMAT mode."""

    def process(self, input_dict: dict) -> dict:
        """Correct the B-format frequency response for non-coincident microphones.

        The correction runs only when ``input_dict["input_mode"]`` is
        ``InputMode.BFORMAT`` and ``input_dict["frequency_correction"]`` is
        truthy. In every other case the dictionary is returned untouched. A
        missing ``"frequency_correction"`` key is treated as "disabled".

        When the correction runs, row 0 of ``"stacked_signals"`` is replaced
        by ``NonCoincidentMicsCorrection.correct_omni`` of itself and rows 1
        onwards by ``correct_axis`` of themselves. The array is modified in
        place.

        Parameters
        ----------
        input_dict : dict
            Dictionary with B-format arrays. Reads the keys ``"input_mode"``,
            ``"frequency_correction"`` (optional), ``"sample_rate"`` and
            ``"stacked_signals"``.

        Returns
        -------
        dict
            The same ``input_dict`` object, with frequency-corrected B-format
            arrays when the correction was applied.
        """
        # Both preconditions are required, so failing either one must skip the
        # correction. Joining these checks with `and` would never skip for
        # B-format data and would ignore the "frequency_correction" flag.
        if input_dict["input_mode"] != InputMode.BFORMAT or not input_dict.get(
            "frequency_correction", False
        ):
            return input_dict

        signals = input_dict["stacked_signals"]
        frequency_corrector = NonCoincidentMicsCorrection(input_dict["sample_rate"])
        # Row 0 gets the omni correction; the remaining rows get the axis one.
        signals[0, :] = frequency_corrector.correct_omni(signals[0, :])
        signals[1:, :] = frequency_corrector.correct_axis(signals[1:, :])
        # The mode is already BFORMAT here (checked above), so it is not reassigned.
        return input_dict
# pylint: disable=too-few-public-methods
class InputProcessorChain:
    """Runs a fixed sequence of input processors one after another."""

    def __init__(self):
        # Order matters: each stage may change "input_mode" so that the next
        # stage picks the data up.
        self.processors = [
            LSSInputProcessor(),
            AFormatProcessor(),
            BFormatProcessor(),
        ]

    def process(self, input_dict: dict) -> np.ndarray:
        """Apply every processor in the chain to ``input_dict``.

        Each processor receives the dictionary returned by the previous one.

        Parameters
        ----------
        input_dict : dict
            Contains arrays and input mode data.

        Returns
        -------
        np.ndarray
            The ``"stacked_signals"`` entry of the dictionary returned by the
            last processor.
        """
        current = input_dict
        for stage in self.processors:
            current = stage.process(current)
        return current["stacked_signals"]