File size: 2,784 Bytes
1af34cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://github.com/jzi040941/PercepNet

https://arxiv.org/abs/2008.04259

https://modelzoo.co/model/percepnet

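Too complicated.
(1) The PyTorch model is only one part of the whole pipeline.
(2) Training samples must first go through pitch analysis, spectral-envelope computation, and similar processing.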

"""
import torch
import torch.nn as nn


class PercepNet(nn.Module):
    """
    PyTorch network modelled on
    https://github.com/jzi040941/PercepNet/blob/main/rnn_train.py#L105

    Layout: a frame-wise dense layer, two causal 1-D convolutions along the
    time axis, three chained single-layer GRUs, and two output heads
    (``gb`` and ``rb``) that each produce 34 sigmoid values per frame.
    NOTE(review): ``gb`` / ``rb`` presumably are the band gains and
    pitch-filter strengths of the PercepNet paper -- confirm upstream.

    4.1% of an x86 CPU core
    """

    def __init__(self, input_dim=70):
        """
        :param input_dim: size of the feature vector of one input frame.
        """
        super().__init__()

        # Frame-wise projection of the input features to 128 channels.
        self.fc = nn.Sequential(nn.Linear(input_dim, 128), nn.ReLU())

        # Each convolution pads (kernel_size - 1) frames on both sides;
        # forward() cuts the surplus trailing frames again.
        # padding for align with c++ dnn
        self.conv1 = nn.Sequential(
            nn.Conv1d(128, 512, kernel_size=5, stride=1, padding=4),
            nn.ReLU(),
        )
        self.conv2 = nn.Sequential(
            nn.Conv1d(512, 512, kernel_size=3, stride=1, padding=2),
            nn.Tanh(),
        )

        # Three separate one-layer GRUs (not a single 3-layer GRU) so that
        # forward() can reuse the output of every layer in the gb concat.
        self.gru1 = nn.GRU(512, 512, num_layers=1, batch_first=True)
        self.gru2 = nn.GRU(512, 512, num_layers=1, batch_first=True)
        self.gru3 = nn.GRU(512, 512, num_layers=1, batch_first=True)

        # Head-specific recurrent layers.
        self.gru_gb = nn.GRU(512, 512, num_layers=1, batch_first=True)
        # 1024 inputs = gru3 output (512) + conv output (512).
        self.gru_rb = nn.GRU(1024, 128, num_layers=1, batch_first=True)

        # 512 * 5 inputs = conv output + gru1 + gru2 + gru3 + gru_gb outputs.
        self.fc_gb = nn.Sequential(nn.Linear(512 * 5, 34), nn.Sigmoid())
        self.fc_rb = nn.Sequential(nn.Linear(128, 34), nn.Sigmoid())

    def forward(self, x: torch.Tensor):
        """
        :param x: input of shape [b, t, f], where f equals ``input_dim``.
        :return: tensor of shape [b, t, 68]; the first 34 values per frame
            come from the gb head, the last 34 from the rb head.
        """
        # [b, t, f] -> [b, 128, t]: Conv1d expects channels before time.
        hidden = self.fc(x).transpose(1, 2)

        # causal conv: drop the extra trailing frames produced by the padding.
        hidden = self.conv1(hidden)[..., :-4]
        conv_feat = self.conv2(hidden)[..., :-2]
        # [b, 512, t] -> [b, t, 512] for the batch-first GRUs.
        conv_feat = conv_feat.transpose(1, 2)

        # Only the per-frame outputs are used; final hidden states are dropped.
        h1, _ = self.gru1(conv_feat)
        h2, _ = self.gru2(h1)
        h3, _ = self.gru3(h2)

        # gb head: skip-concatenate the conv features with every GRU output.
        h_gb, _ = self.gru_gb(h3)
        gb = self.fc_gb(torch.cat((conv_feat, h1, h2, h3, h_gb), dim=-1))

        # rb head.
        # concat rb need fix
        h_rb, _ = self.gru_rb(torch.cat((h3, conv_feat), dim=-1))
        rb = self.fc_rb(h_rb)

        return torch.cat((gb, rb), dim=-1)


def main():
    """Smoke test: push a random [20, 8, 70] batch through the net and print the output shape."""
    net = PercepNet()
    dummy_batch = torch.randn(20, 8, 70)
    print(net(dummy_batch).shape)


# Run the shape smoke test only when executed as a script, not on import.
if __name__ == "__main__":
    main()