# modified from https://github.com/dhchoi99/NANSY
# We have modified the implementation of dhchoi99 to be fully differentiable.
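# This module computes the "yingram" pitch feature from NANSY: framed raw
# audio is turned into cumulative mean normalized difference functions (YIN)
# and then sampled at midi-spaced, generally non-integer time lags.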
import math
import torch
from yin import (differenceFunctionTorch,
                 cumulativeMeanNormalizedDifferenceFunctionTorch)


class Pitch(torch.nn.Module):

    def __init__(self,
                 sr=22050,
                 w_step=256,
                 W=2048,
                 tau_max=2048,
                 midi_start=5,
                 midi_end=85,
                 octave_range=12):
        super(Pitch, self).__init__()
        self.sr = sr
        self.w_step = w_step
        self.W = W
        self.tau_max = tau_max
        # Frames the signal into windows of length W taken every w_step samples.
        self.unfold = torch.nn.Unfold((1, self.W),
                                      1,
                                      0,
                                      stride=(1, self.w_step))
        # Precompute the (generally non-integer) lag c(m) of every midi bin,
        # plus its floor/ceil neighbors used for linear interpolation.
        midis = list(range(midi_start, midi_end))
        self.len_midis = len(midis)
        c_ms = torch.tensor([self.midi_to_lag(m, octave_range) for m in midis])
        self.register_buffer('c_ms', c_ms)
        self.register_buffer('c_ms_ceil', torch.ceil(self.c_ms).long())
        self.register_buffer('c_ms_floor', torch.floor(self.c_ms).long())

    def midi_to_lag(self, m: int, octave_range: float = 12):
        """Converts a midi note to its time lag, eq. (4).

        Args:
            m: midi note number
            octave_range: number of midi bins per octave

        Returns:
            lag: time lag (tau, c(m)) in samples calculated from midi, eq. (4)
        """
        f = 440 * math.pow(2, (m - 69) / octave_range)
        lag = self.sr / f
        return lag
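    # Worked example with the defaults (illustrative only): m = 69 (A4) gives
    # f = 440 * 2 ** ((69 - 69) / 12) = 440 Hz, so the lag is
    # 22050 / 440, roughly 50.11 samples.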

    def yingram_from_cmndf(self, cmndfs: torch.Tensor) -> torch.Tensor:
        """Yingram calculator from cMNDFs (cumulative mean normalized
        difference functions).

        Args:
            cmndfs: cumulative mean normalized difference functions,
                torch.Tensor of shape [B * frames, tau_max];
                for details, see yin.py or eq. (1) and (2)

        Returns:
            y: batch yingram, torch.Tensor of shape [B * frames, len_midis]
        """
        # Linearly interpolate each cmndf at the non-integer lag c(m) of
        # every midi bin.
        y = (cmndfs[:, self.c_ms_ceil] -
             cmndfs[:, self.c_ms_floor]) / (self.c_ms_ceil - self.c_ms_floor).unsqueeze(0) * (
                 self.c_ms - self.c_ms_floor).unsqueeze(0) + cmndfs[:, self.c_ms_floor]
        return y
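    # Illustration of the interpolation above (numbers made up): if a midi bin
    # has lag c(m) = 50.11, then floor = 50, ceil = 51, and its yingram value
    # is cmndfs[:, 50] + 0.11 * (cmndfs[:, 51] - cmndfs[:, 50]), i.e. the
    # cmndf linearly interpolated at the non-integer lag.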

    def yingram(self, x: torch.Tensor):
        """Calculates the yingram from raw audio (multi segment).

        Args:
            x: raw audio, torch.Tensor of shape (B, T)

        Returns:
            yingram: torch.Tensor of shape (B, F, t') with
                F = midi_end - midi_start (80 with the defaults)
        """
        B, T = x.shape

        # (B, T) -> (B * frames, W)
        frames = self.unfold(x.view(B, 1, 1, T))
        frames = frames.permute(0, 2, 1).contiguous().view(-1, self.W)
        # If not using a GPU, or if torch is not compatible, the implemented
        # numpy batch function is still fine.
        dfs = differenceFunctionTorch(frames, frames.shape[-1], self.tau_max)
        cmndfs = cumulativeMeanNormalizedDifferenceFunctionTorch(
            dfs, self.tau_max)
        yingram = self.yingram_from_cmndf(cmndfs)  # [B * frames, F]
        yingram = yingram.view(B, -1, self.len_midis).permute(0, 2, 1)  # [B, F, t']
        return yingram
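    # Shape walkthrough with the defaults (illustrative only): a 1 s input at
    # sr=22050 has shape (B, 22050); unfolding with W=2048 and w_step=256
    # yields (22050 - 2048) // 256 + 1 = 79 frames, and midi_start=5 to
    # midi_end=85 gives 80 bins, so yingram() returns (B, 80, 79).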

    def crop_scope(self, x, yin_start, scope_shift):
        # x: yingram tensor [B, C, T]; scope_shift: per-sample shift, tensor [B].
        # Crops a band of self.yin_scope midi bins starting at yin_start,
        # shifted per sample. Note: self.yin_scope is not set in __init__ and
        # must be assigned externally before calling this method.
        return torch.stack([
            x[i, yin_start + scope_shift[i]:yin_start + self.yin_scope +
              scope_shift[i], :] for i in range(x.shape[0])
        ], dim=0)
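    # Minimal usage sketch for crop_scope (illustrative; the scope width and
    # start bin below are made-up values, not from this file):
    #
    #     pitch = Pitch()
    #     pitch.yin_scope = 20                      # assumed to come from config
    #     y = pitch.yingram(wav)                    # [B, 80, T']
    #     shift = torch.zeros(wav.shape[0], dtype=torch.long)
    #     band = pitch.crop_scope(y, 15, shift)     # [B, 20, T']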


if __name__ == '__main__':
    import librosa as rosa
    import matplotlib.pyplot as plt

    wav = torch.tensor(rosa.load('LJ001-0002.wav', sr=22050,
                                 mono=True)[0]).unsqueeze(0)
    #    wav = torch.randn(1, 40965)

    # Pad so the length is a multiple of the hop size (256).
    wav = torch.nn.functional.pad(wav, (0, (-wav.shape[1]) % 256))
    #    wav = wav[:, :8096]
    print(wav.shape)
    pitch = Pitch()

    with torch.no_grad():
        # Pad by W // 2 on both sides so the yingram frame count matches a
        # center-padded STFT with the same hop.
        ps = pitch.yingram(torch.nn.functional.pad(wav, (1024, 1024)))
        ps = torch.nn.functional.pad(ps, (0, 0, 8, 8), mode='replicate')
        print(ps.shape)
        # STFT with the same hop size (256); frame counts should match.
        spec = torch.stft(wav, 1024, 256, return_complex=False)
        print(spec.shape)
        plt.subplot(2, 1, 1)
        plt.pcolor(ps[0].numpy(), cmap='magma')  # full yingram
        plt.colorbar()
        plt.subplot(2, 1, 2)
        plt.pcolor(ps[0][15:65, :].numpy(), cmap='magma')  # mid-range bins
        plt.colorbar()
        plt.show()