EXP
Browse files
infer/lib/predictors/DJCM/DJCM.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
from scipy.signal import medfilt
|
| 8 |
+
|
| 9 |
+
sys.path.append(os.getcwd())
|
| 10 |
+
|
| 11 |
+
from main.library.predictors.DJCM.spec import Spectrogram
|
| 12 |
+
|
| 13 |
+
# Module-wide defaults used by the DJCM predictor below:
#   SAMPLE_RATE   - multiplied by ``segment_len`` to get the segment size in
#                   samples; ``SAMPLE_RATE // 100`` is used as the hop length
#                   (presumably the audio sample rate in Hz - confirm with callers).
#   WINDOW_LENGTH - default window length handed to the spectrogram extractor
#                   and to the network constructor.
#   N_CLASS       - number of pitch bins: passed to the network as ``n_class``
#                   and used as the length of the cents mapping.
SAMPLE_RATE, WINDOW_LENGTH, N_CLASS = 16000, 1024, 360
|
| 14 |
+
|
| 15 |
+
class DJCM:
    """Pitch (f0) estimator that runs a DJCM network through PyTorch or ONNX Runtime.

    The audio is cut into overlapping fixed-length segments, every segment is
    turned into a magnitude spectrogram, the model maps each spectrogram to a
    per-frame salience over ``N_CLASS`` pitch bins, and the salience is decoded
    into one frequency value per frame.
    """

    def __init__(
        self,
        model_path,
        device = "cpu",
        is_half = False,
        onnx = False,
        svs = False,
        providers = ["CPUExecutionProvider"],
        batch_size = 1,
        segment_len = 5.12,
        kernel_size = 3
    ):
        """Load the model and precompute the segmentation/decoding constants.

        Args:
            model_path: Path given to ``onnxruntime.InferenceSession`` (ONNX)
                or ``torch.load`` (PyTorch state dict).
            device: Torch device that holds the model, the spectrogram
                extractor and the segment tensors.
            is_half: Cast the PyTorch model and its input to half precision.
                Not applied on the ONNX path.
            onnx: Run the model with ONNX Runtime instead of PyTorch.
            svs: Forwarded to the network constructor; also selects a
                2048-sample window instead of ``WINDOW_LENGTH``.
            providers: Execution providers for ONNX Runtime. The list is only
                read here, never mutated.
            batch_size: Number of segments sent through the model per call.
            segment_len: Segment length; multiplied by ``SAMPLE_RATE`` to get
                the segment size in samples.
            kernel_size: Median-filter kernel applied to the decoded f0, or
                ``None`` to skip filtering.
        """
        super().__init__()

        # Resolve the window length into a local name. Assigning to
        # WINDOW_LENGTH inside this function would turn it into a local
        # variable and raise UnboundLocalError whenever ``svs`` is False.
        window_length = 2048 if svs else WINDOW_LENGTH
        hop_length = int(SAMPLE_RATE // 100)

        self.onnx = onnx

        if self.onnx:
            import onnxruntime as ort

            sess_options = ort.SessionOptions()
            sess_options.log_severity_level = 3
            self.model = ort.InferenceSession(model_path, sess_options=sess_options, providers=providers)
        else:
            from main.library.predictors.DJCM.model import DJCMM

            model = DJCMM(1, 1, 1, svs=svs, window_length=window_length, n_class=N_CLASS)
            model.load_state_dict(torch.load(model_path, map_location="cpu", weights_only=True))
            model.eval()
            if is_half: model = model.half()
            self.model = model.to(device)

        self.batch_size = batch_size
        self.seg_len = int(segment_len * SAMPLE_RATE)
        self.seg_frames = int(self.seg_len // hop_length)

        self.device = device
        self.is_half = is_half
        self.kernel_size = kernel_size

        self.spec_extractor = Spectrogram(hop_length, window_length).to(device)
        # Cents value of every pitch bin (20-cent spacing), zero-padded by 4 on
        # each side so a 9-bin window around any peak stays in range.
        cents_mapping = 20 * np.arange(N_CLASS) + 1997.3794084376191
        self.cents_mapping = np.pad(cents_mapping, (4, 4))

    def spec2hidden(self, spec):
        """Run the model on a batch of spectrograms and return its output tensor.

        On the ONNX path the input is moved to a float32 NumPy array and the
        first model output is wrapped back into a tensor on ``self.device``.
        """
        if self.onnx:
            spec = spec.cpu().numpy().astype(np.float32)

            hidden = torch.as_tensor(
                self.model.run(
                    [self.model.get_outputs()[0].name],
                    {self.model.get_inputs()[0].name: spec}
                )[0],
                device=self.device
            )
        else:
            if self.is_half: spec = spec.half()
            hidden = self.model(spec)

        return hidden

    def infer_from_audio(self, audio, thred=0.03):
        """Estimate one f0 value per hop for ``audio``.

        Args:
            audio: Tensor or NumPy array of samples; extra singleton
                dimensions are squeezed away.
            thred: Frames whose peak salience is not above this value are
                reported as 0.

        Returns:
            1-D NumPy array of frequencies, median-filtered when
            ``self.kernel_size`` is not ``None``.
        """
        if torch.is_tensor(audio): audio = audio.cpu().numpy()
        if audio.ndim > 1: audio = audio.squeeze()

        n_frames = audio.shape[-1] // int(SAMPLE_RATE // 100) + 1

        with torch.no_grad():
            hidden = self.inference(self.pad_audio(audio))[:n_frames]

        salience = hidden.cpu().numpy()
        # Keep the (frames, bins) layout even for a single frame; a plain
        # squeeze(0) would collapse it to 1-D and break the per-row argmax.
        salience = salience.reshape(-1, salience.shape[-1])

        f0 = self.decode(salience, thred)
        if self.kernel_size is not None: f0 = medfilt(f0, kernel_size=self.kernel_size)

        return f0

    def infer_from_audio_with_pitch(self, audio, thred=0.03, f0_min=50, f0_max=1100):
        """Like ``infer_from_audio`` but zero every value outside ``[f0_min, f0_max]``."""
        f0 = self.infer_from_audio(audio, thred)
        f0[(f0 < f0_min) | (f0 > f0_max)] = 0

        return f0

    def to_local_average_cents(self, salience, thred=0.05):
        """Convert a salience matrix to cents via a weighted average around each peak.

        For every row the 9 bins centred on the arg-max are averaged, weighted
        by their salience. Rows whose maximum salience is ``<= thred`` yield 0.

        Args:
            salience: 2-D array, one row per frame and one column per pitch bin.
            thred: Peak-salience threshold.

        Returns:
            1-D array with one cents value per row (empty for empty input).
        """
        center = np.argmax(salience, axis=1) + 4
        salience = np.pad(salience, ((0, 0), (4, 4)))

        # Column indices of the 9-bin window around each row's peak.
        window = (center - 4)[:, None] + np.arange(9)
        rows = np.arange(salience.shape[0])[:, None]

        local_salience = salience[rows, window]
        local_cents = self.cents_mapping[window]

        averaged = np.sum(local_salience * local_cents, 1) / np.sum(local_salience, 1)
        averaged[np.max(salience, axis=1) <= thred] = 0

        return averaged

    def decode(self, hidden, thred=0.03):
        """Turn salience into frequency using ``10 * 2 ** (cents / 1200)``.

        Frames that decoded to 0 cents map to exactly 10 and are reset to 0.
        """
        f0 = 10 * (2 ** (self.to_local_average_cents(hidden, thred=thred) / 1200))
        f0[f0 == 10] = 0
        return f0

    def pad_audio(self, audio):
        """Zero-pad ``audio`` and slice it into half-overlapping segments.

        A quarter segment of zeros is put in front and enough zeros behind so
        that windows of ``self.seg_len`` samples, taken every half segment,
        cover the whole signal.

        Returns:
            Tensor of shape ``(num_segments, 1, seg_len)`` on ``self.device``.
        """
        audio_len = audio.shape[-1]

        seg_nums = int(np.ceil(audio_len / self.seg_len)) + 1
        pad_len = int(seg_nums * self.seg_len - audio_len + self.seg_len // 2)

        left_pad = np.zeros(int(self.seg_len // 4), dtype=np.float32)
        right_pad = np.zeros(int(pad_len - self.seg_len // 4), dtype=np.float32)
        padded_audio = np.concatenate([left_pad, audio, right_pad], axis=-1)

        segments = [
            padded_audio[start: start + int(self.seg_len)]
            for start in range(
                0,
                len(padded_audio) - int(self.seg_len) + 1,
                int(self.seg_len // 2)
            )
        ]

        segments = np.stack(segments, axis=0)
        segments = torch.from_numpy(segments).unsqueeze(1).to(self.device)

        return segments

    def inference(self, segments):
        """Run the model over all segments and stitch the central halves together.

        Segments are processed ``self.batch_size`` at a time. Only the middle
        part of every segment's output (from one quarter to three quarters of
        ``self.seg_frames``) is kept, then everything is concatenated along
        the frame axis.
        """
        hidden_segments = torch.cat([
            self.spec2hidden(self.spec_extractor(segments[i:i + self.batch_size].float()))
            for i in range(0, len(segments), self.batch_size)
        ], dim=0)

        hidden = torch.cat([
            seg[self.seg_frames // 4: int(self.seg_frames * 0.75)]
            for seg in hidden_segments
        ], dim=0)

        return hidden
|
infer/lib/predictors/DJCM/decoder.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
|
| 8 |
+
sys.path.append(os.getcwd())
|
| 9 |
+
|
| 10 |
+
from main.library.predictors.DJCM.encoder import ResEncoderBlock
|
| 11 |
+
from main.library.predictors.DJCM.utils import ResConvBlock, BiGRU, init_bn, init_layer
|
| 12 |
+
|
| 13 |
+
class ResDecoderBlock(nn.Module):
    """One decoder stage: BatchNorm + ReLU, transposed convolution, skip merge, refinement.

    The transposed convolution uses ``stride`` both as kernel size and as
    stride, with zero padding. Its output is concatenated with the skip tensor
    on the channel axis and passed through ``n_blocks`` ``ResConvBlock`` layers.
    """

    def __init__(self, in_channels, out_channels, n_blocks, stride):
        super().__init__()
        self.conv1 = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=stride,
            stride=stride,
            padding=(0, 0),
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(in_channels, momentum=0.01)

        # The first refinement block sees the upsampled features plus the skip
        # tensor, hence twice the channel count on its input.
        refine_blocks = [ResConvBlock(out_channels * 2, out_channels)]
        refine_blocks.extend(
            ResConvBlock(out_channels, out_channels) for _ in range(n_blocks - 1)
        )
        self.conv = nn.ModuleList(refine_blocks)

        self.init_weights()

    def init_weights(self):
        """Reset the batch norm and re-initialise the transposed convolution."""
        init_bn(self.bn1)
        init_layer(self.conv1)

    def forward(self, x, concat):
        """Upsample ``x``, concatenate ``concat`` on dim 1 and refine the result."""
        upsampled = self.conv1(F.relu_(self.bn1(x)))
        merged = torch.cat((upsampled, concat), dim=1)

        for block in self.conv:
            merged = block(merged)

        return merged
|
| 37 |
+
|
| 38 |
+
class Decoder(nn.Module):
    """Stack of six ``ResDecoderBlock`` stages going from 384 down to 32 channels."""

    def __init__(self, n_blocks):
        super().__init__()
        # (in_channels, out_channels) of each stage, in execution order.
        channel_plan = (
            (384, 384),
            (384, 384),
            (384, 256),
            (256, 128),
            (128, 64),
            (64, 32),
        )
        self.de_blocks = nn.ModuleList([
            ResDecoderBlock(n_in, n_out, n_blocks, (1, 2))
            for n_in, n_out in channel_plan
        ])

    def forward(self, x, concat_tensors):
        """Apply every stage, feeding it the skip tensors from last to first."""
        for depth, block in enumerate(self.de_blocks, start=1):
            x = block(x, concat_tensors[-depth])

        return x
|
| 55 |
+
|
| 56 |
+
class PE_Decoder(nn.Module):
    """Decoder head that produces ``n_class`` sigmoid outputs per sequence step.

    Pipeline: ``Decoder`` -> ``ResEncoderBlock`` without pooling -> 1x1
    convolution to a single channel -> ``BiGRU`` -> linear layer -> sigmoid.
    """

    def __init__(self, n_blocks, seq_layers=1, window_length = 1024, n_class = 360):
        super().__init__()
        n_bins = window_length // 2

        self.de_blocks = Decoder(n_blocks)
        self.after_conv1 = ResEncoderBlock(32, 32, n_blocks, None)
        self.after_conv2 = nn.Conv2d(32, 1, (1, 1))
        self.fc = nn.Sequential(
            BiGRU((1, n_bins), 1, seq_layers),
            nn.Linear(n_bins, n_class),
            nn.Sigmoid(),
        )
        init_layer(self.after_conv2)

    def forward(self, x, concat_tensors):
        """Decode ``x`` with the skip tensors and return the sigmoid output.

        Dimension 1 of the result is squeezed away when it has size 1.
        """
        decoded = self.de_blocks(x, concat_tensors)
        single_channel = self.after_conv2(self.after_conv1(decoded))
        return self.fc(single_channel).squeeze(1)
|
| 78 |
+
|
| 79 |
+
class SVS_Decoder(nn.Module):
    """Decoder head whose 1x1 convolution emits ``in_channels * 4`` channels.

    Pipeline: ``Decoder`` -> ``ResEncoderBlock`` without pooling -> 1x1
    convolution.
    """

    def __init__(self, in_channels, n_blocks):
        super().__init__()
        self.de_blocks = Decoder(n_blocks)
        self.after_conv1 = ResEncoderBlock(32, 32, n_blocks, None)
        self.after_conv2 = nn.Conv2d(32, in_channels * 4, (1, 1))
        self.init_weights()

    def init_weights(self):
        """Re-initialise the final 1x1 convolution."""
        init_layer(self.after_conv2)

    def forward(self, x, concat_tensors):
        """Decode ``x`` with the skip tensors and project it with the 1x1 convolution."""
        decoded = self.de_blocks(x, concat_tensors)
        refined = self.after_conv1(decoded)
        return self.after_conv2(refined)
|
infer/lib/predictors/DJCM/encoder.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
sys.path.append(os.getcwd())
|
| 7 |
+
|
| 8 |
+
from main.library.predictors.DJCM.utils import ResConvBlock
|
| 9 |
+
|
| 10 |
+
class ResEncoderBlock(nn.Module):
    """Chain of ``n_blocks`` ``ResConvBlock`` layers with optional max pooling.

    When ``kernel_size`` is ``None`` no pooling layer is created and
    ``forward`` returns a single tensor; otherwise it returns the pair
    ``(features, pooled_features)``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        n_blocks,
        kernel_size
    ):
        super().__init__()
        # Only the first block changes the channel count.
        stages = [ResConvBlock(in_channels, out_channels)]
        stages.extend(
            ResConvBlock(out_channels, out_channels) for _ in range(n_blocks - 1)
        )
        self.conv = nn.ModuleList(stages)

        if kernel_size is None:
            self.pool = None
        else:
            self.pool = nn.MaxPool2d(kernel_size)

    def forward(self, x):
        """Run the residual blocks; also return the pooled tensor if pooling is on."""
        for stage in self.conv:
            x = stage(x)

        if self.pool is None:
            return x

        return x, self.pool(x)
|
| 42 |
+
|
| 43 |
+
class Encoder(nn.Module):
    """Six pooled ``ResEncoderBlock`` stages: in_channels -> 32 -> 64 -> 128 -> 256 -> 384 -> 384.

    Every stage uses a ``(1, 2)`` max-pooling kernel.
    """

    def __init__(
        self,
        in_channels,
        n_blocks
    ):
        super().__init__()
        widths = (in_channels, 32, 64, 128, 256, 384, 384)
        self.en_blocks = nn.ModuleList([
            ResEncoderBlock(n_in, n_out, n_blocks, (1, 2))
            for n_in, n_out in zip(widths[:-1], widths[1:])
        ])

    def forward(self, x):
        """Return the final pooled tensor and the list of pre-pooling features.

        The list holds one entry per stage, in stage order.
        """
        concat_tensors = []

        for block in self.en_blocks:
            skip, x = block(x)
            concat_tensors.append(skip)

        return x, concat_tensors
|
infer/lib/predictors/DJCM/model.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
sys.path.append(os.getcwd())
|
| 7 |
+
|
| 8 |
+
from main.library.predictors.DJCM.utils import init_bn
|
| 9 |
+
from main.library.predictors.DJCM.decoder import PE_Decoder, SVS_Decoder
|
| 10 |
+
from main.library.predictors.DJCM.encoder import ResEncoderBlock, Encoder
|
| 11 |
+
|
| 12 |
+
class LatentBlocks(nn.Module):
    """``latent_layers`` un-pooled ``ResEncoderBlock`` stages, 384 channels in and out."""

    def __init__(
        self,
        n_blocks,
        latent_layers
    ):
        super().__init__()
        stages = []
        for _ in range(latent_layers):
            stages.append(ResEncoderBlock(384, 384, n_blocks, None))

        self.latent_blocks = nn.ModuleList(stages)

    def forward(self, x):
        """Pass ``x`` through every latent stage in order."""
        out = x
        for stage in self.latent_blocks:
            out = stage(out)

        return out
|
| 34 |
+
|
| 35 |
+
class DJCMM(nn.Module):
    """DJCM network: optional SVS branch followed by the pitch-estimation branch.

    The pitch branch (``pe_encoder`` -> ``pe_latent`` -> ``pe_decoder``) is
    always built. With ``svs=True`` a second encoder/latent/decoder trio is
    created, and its output is used to re-weight the input spectrogram before
    the pitch branch runs.
    """

    def __init__(
        self,
        in_channels,
        n_blocks,
        latent_layers,
        svs=False,
        window_length=1024,
        n_class=360
    ):
        super().__init__()
        # Batch norm over the frequency bins (window_length // 2 + 1 of them).
        self.bn = nn.BatchNorm2d(window_length // 2 + 1, momentum=0.01)
        self.pe_encoder = Encoder(in_channels, n_blocks)
        self.pe_latent = LatentBlocks(n_blocks, latent_layers)
        self.pe_decoder = PE_Decoder(
            n_blocks,
            window_length=window_length,
            n_class=n_class,
        )

        self.svs = svs

        if svs:
            self.svs_encoder = Encoder(in_channels, n_blocks)
            self.svs_latent = LatentBlocks(n_blocks, latent_layers)
            self.svs_decoder = SVS_Decoder(in_channels, n_blocks)

        init_bn(self.bn)

    def spec(self, x, spec_m):
        """Combine the SVS decoder output ``x`` with the spectrogram ``spec_m``.

        The channel axis of ``x`` is split into groups of four. Slot 0 (after
        a sigmoid) scales the detached ``spec_m``, slot 3 is added to the
        product, and the sum goes through a ReLU.
        """
        bs, c, time_steps, freqs_steps = x.shape
        grouped = x.reshape(bs, c // 4, 4, time_steps, freqs_steps)

        mask = grouped[:, :, 0, :, :].sigmoid()
        residual = grouped[:, :, 3, :, :]

        return (spec_m.detach() * mask + residual).relu()

    def forward(self, spec):
        """Map a spectrogram batch to the output of the pitch decoder.

        The input has dims 1 and 3 swapped so the batch norm runs over dim 3,
        is swapped back, and the last index of the final axis is dropped
        before encoding.
        """
        normed = self.bn(spec.transpose(1, 3)).transpose(1, 3)
        x = normed[..., :-1]

        if self.svs:
            x, concat_tensors = self.svs_encoder(x)
            separated = self.svs_decoder(self.svs_latent(x), concat_tensors)
            # Pad the last axis by one so it lines up with ``spec`` again,
            # then drop that extra index after mixing.
            padded = nn.functional.pad(separated, pad=(0, 1))
            x = self.spec(padded, spec)[..., :-1]

        x, concat_tensors = self.pe_encoder(x)
        latent = self.pe_latent(x)

        return self.pe_decoder(latent, concat_tensors)
|
infer/lib/predictors/DJCM/spec.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
|
| 8 |
+
sys.path.append(os.getcwd())
|
| 9 |
+
|
| 10 |
+
class Spectrogram(nn.Module):
    """Magnitude STFT layer with a Hann window and a lower clamp.

    Args:
        hop_length: Hop between successive frames, in samples.
        win_length: Window length, in samples.
        n_fft: FFT size; defaults to ``win_length`` when ``None``.
        clamp: Minimum value of the returned magnitudes.
    """

    def __init__(
        self,
        hop_length,
        win_length,
        n_fft=None,
        clamp=1e-10
    ):
        super().__init__()
        if n_fft is None:
            n_fft = win_length

        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.clamp = clamp
        # Non-persistent: the window follows .to(device) but is not saved in
        # the state dict.
        self.register_buffer("window", torch.hann_window(win_length), persistent=False)

    def forward(self, audio, center=True):
        """Return clamped magnitudes of shape ``(bs, c, frames, bins)``.

        Args:
            audio: Tensor of shape ``(bs, c, samples)``.
            center: Passed to ``torch.stft``; not used on the fallback path.
        """
        bs, c, segment_samples = audio.shape
        flat = audio.reshape(bs * c, segment_samples)

        needs_fallback = str(flat.device).startswith(("ocl", "privateuseone"))

        if not needs_fallback:
            spectrum = torch.stft(
                flat,
                n_fft=self.n_fft,
                hop_length=self.hop_length,
                win_length=self.win_length,
                window=self.window,
                center=center,
                pad_mode="reflect",
                return_complex=True
            )
            magnitude = (spectrum.real.pow(2) + spectrum.imag.pow(2)).sqrt()
        else:
            # These devices use the project's own STFT module, built lazily
            # on first use and cached on the instance.
            if not hasattr(self, "stft"):
                from main.library.backends.utils import STFT

                self.stft = STFT(
                    filter_length=self.n_fft,
                    hop_length=self.hop_length,
                    win_length=self.win_length
                ).to(flat.device)

            magnitude = self.stft.transform(flat, 1e-9)

        framed = magnitude.transpose(1, 2).clamp(self.clamp, np.inf)
        n_frames, n_bins = framed.shape[1], framed.shape[2]

        return framed.reshape(bs, c, n_frames, n_bins)
|
infer/lib/predictors/DJCM/utils.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from torch import nn
|
| 4 |
+
from einops.layers.torch import Rearrange
|
| 5 |
+
|
| 6 |
+
def init_layer(layer):
    """Xavier-uniform initialise ``layer.weight`` and zero its bias, if it has one."""
    nn.init.xavier_uniform_(layer.weight)

    bias = getattr(layer, "bias", None)
    if bias is None:
        return

    bias.data.fill_(0.0)
|
| 11 |
+
|
| 12 |
+
def init_bn(bn):
    """Reset a batch-norm layer: bias and running mean to 0, weight and running variance to 1."""
    targets = (
        (bn.bias, 0.0),
        (bn.weight, 1.0),
        (bn.running_mean, 0.0),
        (bn.running_var, 1.0),
    )
    for tensor, value in targets:
        tensor.data.fill_(value)
|
| 17 |
+
|
| 18 |
+
class BiGRU(nn.Module):
    """Patch rearrangement followed by a bidirectional GRU.

    The input of shape ``(b, c, w * p1, h * p2)`` is rearranged to
    ``(b, w * h, p1 * p2 * c)`` and fed to a batch-first bidirectional GRU
    whose hidden size is half the patch dimension.

    Args:
        patch_size: ``(patch_width, patch_height)``, used as ``p1`` and ``p2``.
        channels: Number of input channels ``c``.
        depth: Number of stacked GRU layers.
    """

    def __init__(
        self,
        patch_size,
        channels,
        depth
    ):
        super().__init__()
        patch_width, patch_height = patch_size
        patch_dim = channels * patch_height * patch_width

        self.to_patch_embedding = nn.Sequential(
            Rearrange(
                'b c (w p1) (h p2) -> b (w h) (p1 p2 c)',
                p1=patch_width,
                p2=patch_height
            )
        )

        self.gru = nn.GRU(
            patch_dim,
            patch_dim // 2,
            num_layers=depth,
            batch_first=True,
            bidirectional=True
        )

    def forward(self, x):
        """Return the GRU output sequence for ``x`` (the hidden state is dropped).

        If the GRU call raises ``RuntimeError``, cuDNN is disabled for the
        whole process and the call is tried once more.
        """
        x = self.to_patch_embedding(x)

        try:
            return self.gru(x)[0]
        except RuntimeError:
            # Only backend failures should trigger the retry. A bare except
            # would also catch KeyboardInterrupt and SystemExit.
            # NOTE: this switches cuDNN off globally and never re-enables it.
            torch.backends.cudnn.enabled = False
            return self.gru(x)[0]
|
| 53 |
+
|
| 54 |
+
class ResConvBlock(nn.Module):
    """Residual block made of two BatchNorm -> PReLU -> 3x3 convolution steps.

    When the input and output channel counts differ, the skip path goes
    through a 1x1 convolution; otherwise the input is added unchanged.
    """

    def __init__(
        self,
        in_planes,
        out_planes
    ):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(in_planes, momentum=0.01)
        self.bn2 = nn.BatchNorm2d(out_planes, momentum=0.01)
        self.act1 = nn.PReLU()
        self.act2 = nn.PReLU()
        self.conv1 = nn.Conv2d(in_planes, out_planes, (3, 3), padding=(1, 1), bias=False)
        self.conv2 = nn.Conv2d(out_planes, out_planes, (3, 3), padding=(1, 1), bias=False)

        # A projection is only needed when the channel count changes.
        self.is_shortcut = False
        if in_planes != out_planes:
            self.shortcut = nn.Conv2d(in_planes, out_planes, (1, 1))
            self.is_shortcut = True

        self.init_weights()

    def init_weights(self):
        """Reset both batch norms and re-initialise every convolution."""
        for norm in (self.bn1, self.bn2):
            init_bn(norm)

        for conv in (self.conv1, self.conv2):
            init_layer(conv)

        if self.is_shortcut:
            init_layer(self.shortcut)

    def forward(self, x):
        """Return the residual branch added to the (possibly projected) input."""
        branch = self.conv1(self.act1(self.bn1(x)))
        branch = self.conv2(self.act2(self.bn2(branch)))

        if not self.is_shortcut:
            return branch + x

        return self.shortcut(x) + branch
|