tja-generator / model.py
github-actions[bot]
Sync to HuggingFace Spaces
9df2e22
raw
history blame contribute delete
No virus
2.33 kB
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from preprocess import *
class convNet(nn.Module):
    """
    Copies the neural net used in the paper
    "Improved musical onset detection with Convolutional Neural Networks".
    src: https://ieeexplore.ieee.org/document/6854953

    Maps windows of input features to one sigmoid output per window.
    Layer attribute names (conv1, conv2, fc1, fc2, fc3) are part of the
    ``state_dict`` keys, so they are kept stable for checkpoint loading.
    """

    def __init__(self):
        super().__init__()
        # conv1 expects 3 input channels. fc1's 1120 inputs equal the flattened
        # conv2 output (20 channels x remaining spatial size after two (3, 1)
        # poolings); 20 * 8 * 7 matches an 80 x 15 window.
        # NOTE(review): feature height of 80 is inferred from this arithmetic
        # and soundlen=15 -- confirm against the preprocessing code.
        self.conv1 = nn.Conv2d(3, 10, (3, 7))
        self.conv2 = nn.Conv2d(10, 20, 3)
        self.fc1 = nn.Linear(1120, 256)
        self.fc2 = nn.Linear(256, 120)
        self.fc3 = nn.Linear(120, 1)

    def forward(self, x, istraining=False, minibatch=1):
        """Run the network on a batch of windows.

        Args:
            x: tensor fed to ``conv1``; presumably (batch, 3, freq, time) --
                confirm with the caller.
            istraining: enables dropout when True.
            minibatch: number of windows in ``x``; the conv output is
                flattened to ``(minibatch, -1)`` before the dense layers.

        Returns:
            Tensor of shape ``(minibatch, 1)`` with values in [0, 1].
        """
        x = F.max_pool2d(F.relu(self.conv1(x)), (3, 1))
        x = F.max_pool2d(F.relu(self.conv2(x)), (3, 1))
        x = F.dropout(x.view(minibatch, -1), training=istraining)
        x = F.dropout(F.relu(self.fc1(x)), training=istraining)
        x = F.dropout(F.relu(self.fc2(x)), training=istraining)
        # torch.sigmoid replaces the deprecated F.sigmoid (same result).
        return torch.sigmoid(self.fc3(x))

    def infer_data_builder(self, feats, soundlen=15, minibatch=1):
        """Yield float32 batches of sliding windows taken along axis 2 of ``feats``.

        Window ``i`` is ``feats[:, :, i:i + soundlen]``. Windows are grouped
        into tensors of ``minibatch`` windows; the final batch may be
        smaller. Nothing is yielded when ``feats.shape[2] <= soundlen``.

        NOTE(review): only ``feats.shape[2] - soundlen`` windows are produced,
        so the last full window (starting at ``shape[2] - soundlen``) is
        skipped. Kept unchanged because downstream code may rely on the
        output length -- confirm before changing.
        """
        batch = []
        for i in range(feats.shape[2] - soundlen):
            batch.append(feats[:, :, i : i + soundlen])
            if (i + 1) % minibatch == 0:
                yield torch.from_numpy(np.array(batch)).float()
                batch = []
        if batch:
            # Flush the trailing, partially filled batch.
            yield torch.from_numpy(np.array(batch)).float()

    def infer(self, feats, device, minibatch=1, soundlen=15):
        """Return the network output for every window of ``feats``.

        Args:
            feats: array sliced as ``feats[:, :, a:b]``; windows run along
                axis 2 (see ``infer_data_builder``).
            device: torch device each batch is moved to.
            minibatch: number of windows per forward pass.
            soundlen: window length along axis 2 (default 15, the size fc1
                was built for).

        Returns:
            1-D numpy array with one value per window, in window order.
            An empty float32 array when ``feats`` is too short to form a
            window. (Previously this returned an object array ``[None]``.)
        """
        num_windows = max(feats.shape[2] - soundlen, 0)
        # Ceil division: the last batch may hold fewer than `minibatch` windows.
        num_batches = -(-num_windows // minibatch)
        outputs = []
        with torch.no_grad():
            for x in tqdm(
                self.infer_data_builder(feats, soundlen=soundlen, minibatch=minibatch),
                total=num_batches,
            ):
                output = self(x.to(device), minibatch=x.shape[0])
                outputs.append(output.cpu().numpy().reshape(-1))
        if not outputs:
            return np.empty(0, dtype=np.float32)
        # One concatenation at the end instead of re-copying per batch.
        return np.concatenate(outputs)
if __name__ == "__main__":
    # Smoke check: build the model on the best available device, then
    # print its architecture and total parameter count.
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_name)
    net = convNet().to(device)
    print(net)
    param_count = sum(param.numel() for param in net.parameters())
    print("parameters: ", param_count)