Spaces:
Running
Running
#!/usr/bin/python3 | |
# -*- coding: utf-8 -*- | |
""" | |
https://github.com/jzi040941/PercepNet | |
https://arxiv.org/abs/2008.04259 | |
https://modelzoo.co/model/percepnet | |
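太复杂了。 (Too complicated to reproduce in full here.)
(1) pytorch 模型只是整个 pipeline 中的一部分。 — The PyTorch model is only one part of the whole pipeline.
(2) 训练样本需经过基音分析、频谱包络之类的计算。 — Training samples must first go through pitch analysis, spectral-envelope estimation and similar feature computation.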
""" | |
import torch | |
import torch.nn as nn | |
class PercepNet(nn.Module):
    """PercepNet estimator network (only the PyTorch part of the pipeline).

    Port of the network in
    https://github.com/jzi040941/PercepNet/blob/main/rnn_train.py#L105
    (original note: "4.1% of an x86 CPU core").

    Architecture, as built below:

    * a per-frame ``Linear(input_dim -> 128)`` + ReLU;
    * two causal 1-D convolutions: 128 -> 512 (kernel 5, ReLU) and
      512 -> 512 (kernel 3, tanh);
    * three stacked 512-unit GRUs;
    * a ``gb`` head: one more 512-unit GRU. Its output is concatenated with
      the conv output and the three GRU outputs (5 * 512 features) and
      mapped to ``num_bands`` sigmoid values;
    * an ``rb`` head: a 128-unit GRU over ``[gru3_out, conv_out]``
      (1024 features), mapped to ``num_bands`` sigmoid values.

    Submodule attribute names and ``nn.Sequential`` layouts are kept
    unchanged, so existing ``state_dict`` checkpoints still load.

    Args:
        input_dim: Number of features per input frame (default 70).
        num_bands: Number of outputs produced by each head (default 34, the
            value the layers were originally hard-coded to).
    """

    def __init__(self, input_dim: int = 70, num_bands: int = 34):
        super().__init__()
        conv1_kernel, conv2_kernel = 5, 3
        self.fc = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
        )
        # Padding each conv by kernel_size - 1 on both sides, then dropping
        # the trailing kernel_size - 1 output frames in forward(), makes the
        # conv causal: output frame t only depends on input frames <= t.
        # (Original note: padding chosen to align with the C++ DNN.)
        self.conv1 = nn.Sequential(
            nn.Conv1d(128, 512, conv1_kernel, stride=1, padding=conv1_kernel - 1),
            nn.ReLU(),
        )
        self.conv2 = nn.Sequential(
            nn.Conv1d(512, 512, conv2_kernel, stride=1, padding=conv2_kernel - 1),
            nn.Tanh(),
        )
        self.gru1 = nn.GRU(512, 512, 1, batch_first=True)
        self.gru2 = nn.GRU(512, 512, 1, batch_first=True)
        self.gru3 = nn.GRU(512, 512, 1, batch_first=True)
        self.gru_gb = nn.GRU(512, 512, 1, batch_first=True)
        # Input is cat(gru3_out, conv_out): 512 + 512 features.
        self.gru_rb = nn.GRU(1024, 128, 1, batch_first=True)
        self.fc_gb = nn.Sequential(
            nn.Linear(512 * 5, num_bands),
            nn.Sigmoid(),
        )
        self.fc_rb = nn.Sequential(
            nn.Linear(128, num_bands),
            nn.Sigmoid(),
        )

    @staticmethod
    def _causal_conv(block: nn.Sequential, x: torch.Tensor) -> torch.Tensor:
        """Apply a padded conv block and drop its trailing look-ahead frames.

        ``block[0]`` must be an ``nn.Conv1d`` with symmetric integer padding.
        The number of frames removed from the end of the time axis equals
        that padding, so the output length matches the input length.
        """
        out = block(x)
        pad = block[0].padding[0]
        return out[:, :, :-pad] if pad else out

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network on a batch of frame sequences.

        Args:
            x: Input of shape ``[batch, time, input_dim]`` (batch-first,
                matching the ``batch_first=True`` GRUs).

        Returns:
            Tensor of shape ``[batch, time, 2 * num_bands]``. Along the last
            axis it holds the ``gb`` head output followed by the ``rb`` head
            output. All values are sigmoid outputs in [0, 1]. Presumably
            these are the per-band gains g_b and pitch-filter strengths r_b
            of the PercepNet paper; confirm against the C++ side.
        """
        # [b, t, f] -> [b, 128, t]: Conv1d wants channels before time.
        x = self.fc(x).permute(0, 2, 1)
        x = self._causal_conv(self.conv1, x)
        # Back to batch-first [b, t, 512] for the GRUs.
        conv_out = self._causal_conv(self.conv2, x).permute(0, 2, 1)

        # Final hidden states are not used; every GRU starts from the default
        # zero state.
        gru1_out, _ = self.gru1(conv_out)
        gru2_out, _ = self.gru2(gru1_out)
        gru3_out, _ = self.gru3(gru2_out)
        gru_gb_out, _ = self.gru_gb(gru3_out)

        gb = self.fc_gb(
            torch.cat((conv_out, gru1_out, gru2_out, gru3_out, gru_gb_out), dim=-1)
        )

        # Original note: "concat rb need fix". The rb input layout may differ
        # from the reference design; it is kept as-is for checkpoint
        # compatibility.
        rb_out, _ = self.gru_rb(torch.cat((gru3_out, conv_out), dim=-1))
        rb = self.fc_rb(rb_out)
        return torch.cat((gb, rb), dim=-1)
def main():
    """Build a default PercepNet, run it on random input and print the output shape."""
    batch_size, num_frames, num_features = 20, 8, 70
    net = PercepNet()
    dummy_input = torch.randn(batch_size, num_frames, num_features)
    print(net(dummy_input).shape)


if __name__ == "__main__":
    main()