File size: 4,876 Bytes
1ecb721
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import numpy as np
from matplotlib import pyplot as plt
from scipy.fft import fft
from scipy.signal import savgol_filter
from tools import rms_normalize

# Line colours used by the plotting helpers in this module, in the order in
# which they are assigned to successive curves (cycling when there are more
# curves than entries). Each entry is an (R, G, B) triple on a 0-255 scale;
# the plotting code divides every channel by 255.0 before passing it on.
# Entries that are commented out are palette options that are currently
# disabled.
colors = [
    # (0, 0, 0),        # black
    # (86, 180, 233),   # sky blue
    # (240, 228, 66),   # yellow
    # (204, 121, 167),  # reddish purple
    (213, 94, 0),       # vermilion
    (0, 114, 178),      # blue
    (230, 159, 0),      # orange
    (0, 158, 115),      # bluish green
]


def plot_psd_multiple_signals(signals_list, labels_list, sample_rate=16000, window_size=500,
                              figsize=(10, 6), save_path=None, normalize=False):
    """
    Plot the smoothed, averaged log2 power spectra of several signal sets on one figure.

    Every signal is first passed through ``rms_normalize``. For each set, the
    squared FFT magnitude is averaged over the signals of the set, the first
    ``sample_length // 2`` bins are kept, ``log2(power + 1)`` is taken and the
    resulting curve is smoothed with a third-order Savitzky-Golay filter.

    Parameters:
    signals_list: List of signal sets; each set is array-like with shape
        [sample_number, sample_length].
    labels_list: List of legend labels, one per signal set.
    sample_rate: Sampling rate of the audio in Hz; used to build the frequency axis.
    window_size: Window length handed to the Savitzky-Golay filter. If the
        spectrum has fewer points than this, the largest odd window that fits
        is used instead; if even that is too short for a third-order fit, the
        curve is plotted unsmoothed.
    figsize: Figure size forwarded to matplotlib.
    save_path: If truthy, the figure is saved to this path and then closed;
        otherwise the figure is shown with ``plt.show()``.
    normalize: If True, each smoothed curve is divided by its own mean.

    Raises:
    AssertionError: If ``signals_list`` and ``labels_list`` differ in length.
    """
    # Every signal set needs exactly one label.
    assert len(signals_list) == len(labels_list), "每组信号必须有一个对应的标签。"

    signals_list = [np.array([rms_normalize(signal) for signal in signals]) for signals in signals_list]

    polyorder = 3
    fig, ax = plt.subplots(figsize=figsize)

    for index, (signals, label) in enumerate(zip(signals_list, labels_list)):
        n_samples = signals.shape[1]
        half = n_samples // 2

        # Power spectrum averaged over all signals of the set.
        power = np.mean(np.abs(fft(signals, axis=1)) ** 2, axis=0)
        freqs = np.fft.fftfreq(n_samples, 1 / sample_rate)[:half]

        # "+ 1" keeps the logarithm finite (and non-negative) for empty bins.
        log_power = np.log2(power[:half] + 1)

        if window_size <= half:
            psd_smoothed = savgol_filter(log_power, window_size, polyorder)
        else:
            # The requested window does not fit into the data, which would
            # make savgol_filter raise; fall back to the largest odd window
            # that fits, or skip smoothing when that is too short for the fit.
            window = half if half % 2 else half - 1
            if window > polyorder:
                psd_smoothed = savgol_filter(log_power, window, polyorder)
            else:
                psd_smoothed = log_power

        if normalize:
            psd_smoothed = psd_smoothed / np.mean(psd_smoothed)

        rgb = [channel / 255.0 for channel in colors[index % len(colors)]]
        ax.plot(freqs, psd_smoothed, label=label, color=rgb, linewidth=1)

    ax.set_xlabel('Frequency (Hz)')
    ax.set_ylabel('Mean Log-Amplitude')
    ax.legend()

    if save_path:
        fig.savefig(save_path)
        # Release the figure; otherwise repeated calls accumulate open figures.
        plt.close(fig)
    else:
        plt.show()


def plot_amplitude_over_time(signals_list, labels_list, sample_rate=16000, window_size=500,
                             figsize=(10, 6), save_path=None, normalize=False, start_time=0):
    """
    Plot the smoothed mean log2 amplitude of several signal sets over time on one figure.

    Every signal is first passed through ``rms_normalize`` and truncated to the
    samples from ``start_time`` onwards. For each set, the absolute amplitude is
    averaged over the signals of the set, ``log2(amplitude + 1)`` is taken and
    the resulting curve is smoothed with a third-order Savitzky-Golay filter.

    Parameters:
    signals_list: List of signal sets; each set is array-like with shape
        [sample_number, sample_length]. Sets may differ in sample_length; each
        curve gets its own time axis.
    labels_list: List of legend labels, one per signal set.
    sample_rate: Sampling rate of the audio in Hz; used to build the time axis.
    window_size: Window length handed to the Savitzky-Golay filter. If a curve
        has fewer points than this, the largest odd window that fits is used
        instead; if even that is too short for a third-order fit, the curve is
        plotted unsmoothed.
    figsize: Figure size forwarded to matplotlib.
    save_path: If truthy, the figure is saved to this path and then closed;
        otherwise the figure is shown with ``plt.show()``.
    normalize: If True, each smoothed curve is divided by its own mean.
    start_time: Time in seconds at which plotting starts; earlier samples are
        dropped. Negative values are treated as 0.

    Raises:
    AssertionError: If ``signals_list`` and ``labels_list`` differ in length.
    """
    assert len(signals_list) == len(labels_list), f"len(signals_list) != len(labels_list) for " \
                                                  f"len(signals_list) = {len(signals_list)} and len(labels_list) = {len(labels_list)}"

    # Clamp at 0: a negative index would select the *tail* of each signal
    # instead of "everything from the start".
    start_sample = max(0, int(start_time * sample_rate))

    # Normalize on the full signal first, then drop the leading samples.
    signals_list = [np.array([rms_normalize(signal)[start_sample:] for signal in signals]) for signals in signals_list]

    polyorder = 3
    fig, ax = plt.subplots(figsize=figsize)

    for index, (signals, label) in enumerate(zip(signals_list, labels_list)):
        n_points = signals.shape[1]

        # Build the time axis per set so that sets of different length (or an
        # empty signals_list) do not break the plot.
        time_axis = np.arange(start_sample, start_sample + n_points) / sample_rate

        amplitude_mean = np.mean(np.abs(signals), axis=0)
        log_amplitude = np.log2(amplitude_mean + 1)

        if window_size <= n_points:
            amplitude_smoothed = savgol_filter(log_amplitude, window_size, polyorder)
        else:
            # The requested window does not fit into the data, which would
            # make savgol_filter raise; fall back to the largest odd window
            # that fits, or skip smoothing when that is too short for the fit.
            window = n_points if n_points % 2 else n_points - 1
            if window > polyorder:
                amplitude_smoothed = savgol_filter(log_amplitude, window, polyorder)
            else:
                amplitude_smoothed = log_amplitude

        if normalize:
            amplitude_smoothed = amplitude_smoothed / np.mean(amplitude_smoothed)

        rgb = [channel / 255.0 for channel in colors[index % len(colors)]]
        ax.plot(time_axis, amplitude_smoothed, label=label, color=rgb, linewidth=1)

    ax.set_xlabel('Time (seconds)')
    ax.set_ylabel('Mean Log-Amplitude')
    ax.legend()

    if save_path:
        fig.savefig(save_path)
        # Release the figure; otherwise repeated calls accumulate open figures.
        plt.close(fig)
    else:
        plt.show()