File size: 4,436 Bytes
f7fb447
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
"""Functionality for detecting reflections in a room impulse response (RIR)."""

from abc import ABC, abstractmethod
from enum import Enum
from typing import Tuple, Union

import numpy as np
from scipy.signal import find_peaks, find_peaks_cwt


# pylint: disable=too-few-public-methods
class ReflectionDetectionStrategy(ABC):
    """Common interface shared by every reflection detection algorithm.

    Concrete strategies provide their algorithm by overriding the static
    method `get_indeces_of_reflections`.
    """

    @staticmethod
    @abstractmethod
    def get_indeces_of_reflections(intensity_magnitude: np.ndarray) -> np.ndarray:
        """Return the indices of the reflections found in the signal.

        This is the abstract hook that concrete reflection detection
        strategies must override; the base version does nothing.

        Args:
            intensity_magnitude (np.ndarray): intensity magnitude signal.

        Returns:
            np.ndarray: an array with the indices of the detected reflections.
        """
        return None


# pylint: disable=too-few-public-methods
class CorrelationReflectionDetectionStrategy(ReflectionDetectionStrategy):
    """Placeholder for a correlation-based reflection detection algorithm.

    The algorithm itself is not implemented: calling the strategy always
    raises.
    """

    @staticmethod
    def get_indeces_of_reflections(intensity_magnitude: np.ndarray) -> np.ndarray:
        """Reject the call, as correlation-based detection is not implemented.

        Args:
            intensity_magnitude (np.ndarray): intensity magnitude signal (unused).

        Raises:
            NotImplementedError: always.
        """
        message = "Implemented method is Scipy's find_peaks"
        raise NotImplementedError(message)


# pylint: disable=too-few-public-methods
class ThresholdReflectionDetectionStrategy(ReflectionDetectionStrategy):
    """Placeholder for a threshold-based reflection detection algorithm.

    The algorithm itself is not implemented: calling the strategy always
    raises.
    """

    @staticmethod
    def get_indeces_of_reflections(intensity_magnitude: np.ndarray) -> np.ndarray:
        """Reject the call, as threshold-based detection is not implemented.

        Args:
            intensity_magnitude (np.ndarray): intensity magnitude signal (unused).

        Raises:
            NotImplementedError: always.
        """
        message = "Implemented method is Scipy's find_peaks"
        raise NotImplementedError(message)


# pylint: disable=too-few-public-methods
class NeighborReflectionDetectionStrategy(ReflectionDetectionStrategy):
    """Reflection detector built on SciPy's `find_peaks`, which locates
    local maxima by comparing neighboring values."""

    @staticmethod
    def get_indeces_of_reflections(intensity_magnitude: np.ndarray) -> np.ndarray:
        """Locate the local maxima of the intensity magnitude signal.

        `find_peaks` is called with its default settings, i.e. no extra
        conditions are imposed on the peaks.

        Args:
            intensity_magnitude (np.ndarray): intensity magnitude signal.

        Returns:
            np.ndarray: an array with the indices of the peaks.
        """
        # find_peaks yields a (peak indices, peak properties) pair; only the
        # indices are of interest here.
        peak_indices, _properties = find_peaks(intensity_magnitude)
        return peak_indices


class WaveletReflectionDetectionStrategy(ReflectionDetectionStrategy):
    """Algorithm for detecting reflections based on the wavelet transform."""

    @staticmethod
    def get_indeces_of_reflections(intensity_magnitude: np.ndarray) -> np.ndarray:
        """Find peaks in the intensity magnitude signal with SciPy's
        `find_peaks_cwt`, using the widths 5 to 14 (inclusive).

        Args:
            intensity_magnitude (np.ndarray): intensity magnitude signal.

        Returns:
            np.ndarray: an integer array with the indices of all the peaks
                found; empty when no peak is found.
        """
        peak_indices = find_peaks_cwt(intensity_magnitude, widths=np.arange(5, 15))
        # Unlike find_peaks, find_peaks_cwt returns the peak indices directly
        # (there is no properties element to drop), so the whole result is
        # kept. The integer cast keeps an empty result usable as an index
        # array.
        return np.asarray(peak_indices, dtype=np.intp)


class ReflectionDetectionStrategies(Enum):
    """Enum class for accessing the existing `ReflectionDetectionStrategy`s.

    The value of each member is the strategy class itself (not an instance).
    """

    # Placeholder strategy: raises NotImplementedError when called.
    CORRELATION = CorrelationReflectionDetectionStrategy
    # Placeholder strategy: raises NotImplementedError when called.
    THRESHOLD = ThresholdReflectionDetectionStrategy
    # Peak detection through scipy.signal.find_peaks.
    SCIPY = NeighborReflectionDetectionStrategy
    # Peak detection through scipy.signal.find_peaks_cwt.
    WAVELET = WaveletReflectionDetectionStrategy


def detect_reflections(
    intensity: np.ndarray,
    azimuth: np.ndarray,
    elevation: np.ndarray,
    detection_strategy: Union[
        ReflectionDetectionStrategy, ReflectionDetectionStrategies
    ] = NeighborReflectionDetectionStrategy,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Analyze the normalized intensity to look for reflections and keep only
    the samples of intensity, azimuth and elevation where one was detected.

    The detection strategy is run on `intensity` to obtain the indices of the
    reflections. Index 0 is then prepended as the direct sound, and the three
    input arrays are sampled at the resulting indices.

    Args:
        intensity (np.ndarray): normalized intensity array.
        azimuth (np.ndarray): array with horizontal angles with respect to the XZ plane.
        elevation (np.ndarray): array with vertical angles with respect to the XY plane.
        detection_strategy (Union[ReflectionDetectionStrategy,
            ReflectionDetectionStrategies]): algorithm used to find the
            reflections. Either a strategy exposing
            `get_indeces_of_reflections` (the class itself is enough, as the
            method is static) or a member of `ReflectionDetectionStrategies`,
            in which case the member's value is used. Defaults to
            `NeighborReflectionDetectionStrategy`.

    Returns:
        intensity (np.ndarray): intensities of the direct sound and the reflections.
        azimuth (np.ndarray): azimuths of the direct sound and the reflections.
        elevation (np.ndarray): elevations of the direct sound and the reflections.
        reflections_indeces (np.ndarray): indices used to sample the arrays
            above; the first one is always 0 (direct sound).
    """
    # Enum members only wrap the strategy class: unwrap it before use
    if isinstance(detection_strategy, ReflectionDetectionStrategies):
        detection_strategy = detection_strategy.value

    reflections_indeces = detection_strategy.get_indeces_of_reflections(intensity)
    # Add direct sound (sample 0) in front of the detected reflections
    reflections_indeces = np.insert(reflections_indeces, 0, 0)
    return (
        intensity[reflections_indeces],
        azimuth[reflections_indeces],
        elevation[reflections_indeces],
        reflections_indeces,
    )