Update model.py #1
by sin2piusc · opened

model.py CHANGED
@@ -130,6 +130,7 @@ def plot_waveform(x=None, w=None, p=None, per=None, sample_idx=0, sr=16000, hop_
        axs[current_ax].set_ylabel("Mel Bin")
        axs[current_ax].set_xlim([0, max_time])
        axs[current_ax].grid(True, axis='x', linestyle='--', alpha=0.3)
+       # fig.colorbar(im, ax=axs[current_ax])
        current_ax += 1

    if p is not None:
@@ -237,52 +238,9 @@ def sinusoids(length, channels, max_timescale=10000):
    scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)

-class ParameterCycler:
-    def __init__(self, parameters):
-        self.parameters = parameters
-        self.current_idx = 0
-    def toggle_requires_grad(self):
-        x = random.randint(0, len(self.parameters) - 1)
-        for x, param in enumerate(self.parameters):
-            param.requires_grad = (x == self.current_idx)
-            print(f"Parameter {x}: requires_grad={param.requires_grad}")
-        self.current_idx = (self.current_idx + 1) % len(self.parameters)
-
-def extract_f0(waveform, sampling_rate=16000, hop_length=128, device="cuda:0"):
-    """Extract F0 from waveform - handle various input types"""
-    if waveform is None:
-        return None
-
-    if isinstance(waveform, list):
-        if len(waveform) == 0:
-            return None
-        waveform = waveform[0]
-        print(f"DEBUG: Converted list to tensor, new type: {type(waveform)}")
-
-    if not isinstance(waveform, torch.Tensor):
-        waveform = torch.tensor(waveform)
-
-    if isinstance(waveform, torch.Tensor):
-        if waveform.dim() == 3:
-            waveform = waveform.squeeze(1)
-        if waveform.dim() == 2:
-            waveform = waveform[0]
-
-        wav_np = waveform.detach().cpu().numpy().astype(np.float64)
-    else:
-        wav_np = np.array(waveform).astype(np.float64)
-
-    f0, t = pw.dio(wav_np, sampling_rate,
-                   frame_period=hop_length/sampling_rate*1000)
-    f0 = pw.stonemask(wav_np, f0, t, sampling_rate)
-
-    f0_tensor = torch.from_numpy(f0).float().to(device)
-    return f0_tensor.unsqueeze(0).unsqueeze(0)
-
 class rotary(nn.Module):

-    def __init__(self, dims, max_ctx=1500, theta=10000,
-        learned_radius=False, learned_theta=False, learned_pitch=False, debug: List[str] = [], use_pbias = False):
+    def __init__(self, dims, max_ctx=1500, theta=10000, radii=False, debug: List[str] = [], use_pbias = False):
        super().__init__()

        self.use_pbias = use_pbias
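Note: with `ParameterCycler` and the in-model `extract_f0` gone, the model now expects pitch to arrive as a precomputed `f0` tensor. A minimal offline sketch that mirrors the deleted helper (same pyworld `dio`/`stonemask` calls as above; function name and shapes are illustrative):

```python
import numpy as np
import pyworld as pw
import torch

def extract_f0(wav: torch.Tensor, sr: int = 16000, hop_length: int = 128) -> torch.Tensor:
    """F0 track for a mono waveform, one value per hop; returns shape (1, frames)."""
    wav_np = wav.detach().cpu().numpy().astype(np.float64)  # pyworld expects float64
    f0, t = pw.dio(wav_np, sr, frame_period=hop_length / sr * 1000)  # frame period in ms
    f0 = pw.stonemask(wav_np, f0, t, sr)  # refine the coarse DIO estimate
    return torch.from_numpy(f0).float().unsqueeze(0)
```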
@@ -309,49 +267,64 @@ class rotary(nn.Module):
        self.freqs = nn.Parameter(torch.tensor(freqs, device=self.device, dtype=self.dtype), requires_grad=True)
        self.radius = nn.Parameter(torch.ones(radius, device=self.device, dtype=self.dtype), requires_grad=True)

-    def
-
+    def get_pitch_bias(self, f0):
+        if f0 is None:
+            return None
+        f0_flat = f0.squeeze().float()
+        f0_norm = (f0_flat - f0_flat.mean()) / (f0_flat.std() + 1e-8)
+        f0_sim = torch.exp(-torch.cdist(f0_norm.unsqueeze(1),
+                                        f0_norm.unsqueeze(1)) * self.pitch_scale)
+        return f0_sim.unsqueeze(0).unsqueeze(0)
+
+    def align_f0(self, f0, ctx):
+        b, l = f0.shape
+        if l == ctx:
+            return f0.squeeze(0).float()
+        frames_per_token = l / ctx
+        idx = torch.arange(ctx, device=self.device, dtype=torch.float32)
+        src_idx = (idx * frames_per_token).long().clamp(0, l-1)
+        batch_idx = torch.arange(b, device=self.device, dtype=torch.float32).unsqueeze(1)
+        f0 = f0[batch_idx, src_idx]
+        return f0.squeeze(0).float()
+
+    def forward(self, x: Tensor, xa: Tensor = None, f0: Tensor = None, mask: Tensor = None, layer = None) -> Tensor:
+
        if isinstance(x, int):
            ctx = x
        else:
            batch, ctx, dims = x.shape
        t = torch.arange(ctx, device=self.device).float()

+        if self.learned_adaptation:
+            freqs = self.get_f0_adapted_freqs(ctx, f0)
+            x_complex = torch.view_as_complex(
+                x.float().reshape(*x.shape[:-1], -1, 2).contiguous())
+            x_rotated = x_complex * freqs.unsqueeze(0).unsqueeze(0)
+            freqs = torch.view_as_real(x_rotated).flatten(3).type_as(x)
+
        if f0 is not None:
            f0_mean=f0.mean()+1e-8
            theta=f0_mean*self.pitch_scale
            freqs = 1. / (theta ** (torch.arange(0, self.dims, 2, device=self.device, dtype=self.dtype)[:(self.dims // 2)].float() /self.dims))
        else:
            freqs = self.freqs
-
        freqs = torch.einsum('i,j->ij', t, freqs)
        freqs = freqs.float()

        if self.radii:
-
-            radius =
+            radius = self.align_f0(f0, ctx)
+            radius = torch.clamp(radius, min=40.0, max=600.0)
+            radius = radius / 600.0
            radius = radius.float()
-
        else:
            radius = self.radius
-        # freqs = torch.polar(self.radius.unsqueeze(-1), freqs)
        freqs = torch.polar(radius.unsqueeze(-1), freqs)

-        if "rotary" in self.debug:
+        if "rotary" in self.debug and self._counter % 100 == 50:
            if f0 is not None:
-
-                if
-
-                    self._prev_f0_theta = theta
-                    # print(f"Step {self._counter}: Theta: {theta:.2f} Hz")
-                elif abs(self._prev_f0_theta - theta) > 100.0:
-                    # print(f"Step {self._counter}: Theta: {theta:.2f} Hz, freqs: {freqs.shape}")
-                    print(f"{layer} : {f0_mean} : Theta: {theta:.2f} : {theta:.2f} : {ctx} ")
-                    if self.radii:
-                        print(f"radius: {radius} Hz, enc: {layer} Hz, ctx: {ctx}")
-                    self._prev_f0_theta = theta
-            rotary._seen.add(key)
+                print(f"{layer}: {theta:.2f} : {f0_mean:.2f} : {ctx} ")
+            if self.radii:
+                print(f"radius: {radius} Hz, enc: {layer} Hz, ctx: {ctx}")
        self._counter += 1
        return freqs

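The new `align_f0` nearest-neighbor-samples a frame-rate f0 track down to `ctx` positions so it can act as a per-position radius for `torch.polar`. One caveat: the hunk builds `batch_idx` with `dtype=torch.float32`, and PyTorch rejects floating-point index tensors, so the `l != ctx` path would raise an `IndexError` as written; the sketch below uses an integer index instead. Standalone illustration (shapes assumed, batch of 1):

```python
import torch

def align_f0(f0: torch.Tensor, ctx: int) -> torch.Tensor:
    """Downsample a (batch, frames) pitch track to ctx values per sequence."""
    b, l = f0.shape
    if l == ctx:
        return f0.squeeze(0).float()
    frames_per_token = l / ctx
    idx = torch.arange(ctx, dtype=torch.float32)
    src_idx = (idx * frames_per_token).long().clamp(0, l - 1)  # nearest source frame
    batch_idx = torch.arange(b).unsqueeze(1)  # integer index tensor, shape (b, 1)
    return f0[batch_idx, src_idx].squeeze(0).float()

f0 = torch.linspace(100.0, 200.0, 40).unsqueeze(0)  # one utterance, 40 f0 frames
print(align_f0(f0, 8))  # 8 values sampled evenly along the track
```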
@@ -440,19 +413,19 @@ class MultiheadA(nn.Module):
        rbf_scores = torch.exp(-dist_sq / (2 * rbf_sigma**2))
        return (1 - rbf_ratio) * dot_scores + rbf_ratio * rbf_scores

-    def forward(self, x: Tensor, xa: Tensor = None,
+    def forward(self, x: Tensor, xa: Tensor = None, f0: Tensor = None, mask: Tensor = None, layer = None) -> tuple:

+        batch, ctx, dims = x.shape
        scale = (self.dims // self.head) ** -0.25

        z = xa if xa is not None else x
        q = self.q(x).to(x.dtype)
        k = self.k(z).to(x.dtype)
        v = self.v(z).to(x.dtype)
-        batch, ctx, dims = q.shape

        if self.rotary_emb:
-            qf = self.rope(q.size(1),
-            kf = self.rope(k.size(1),
+            qf = self.rope(q.size(1), f0=f0, layer=layer)
+            kf = self.rope(k.size(1), f0=f0, layer=layer)

            q = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3)
            k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
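`qf` and `kf` are the complex outputs of `rotary.forward` (`torch.polar` of radius and angle); where they get applied to `q` and `k` is outside the visible hunks. For reference, the usual application with this representation, using the same `view_as_complex`/`view_as_real` round trip the PR itself uses in `rotary.forward`, looks like this (a sketch, not necessarily this repo's exact helper):

```python
import torch

def apply_rope(x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    """Rotate channel pairs of x by complex freqs of shape (ctx, head_dim // 2)."""
    # x: (batch, head, ctx, head_dim) -> complex pairs in the last dimension
    x_c = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2).contiguous())
    x_rot = x_c * freqs.unsqueeze(0).unsqueeze(0)  # broadcast over batch and head
    return torch.view_as_real(x_rot).flatten(3).type_as(x)
```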
@@ -466,13 +439,13 @@ class MultiheadA(nn.Module):
        k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
        v = v.view(*v.shape[:2], self.head, -1).permute(0, 2, 1, 3)
        batch, head, ctx, head_dim = q.shape

        if self.rbf:
            qk = self.enhanced_attention_scores(q * scale, k * scale, rbf_sigma=1.0, rbf_ratio=0.3)

        qk = (q * scale) @ (k * scale).transpose(-1, -2)
-        if self.rope.use_pbias:
-            pbias = self.rope.
+        if f0 is not None and self.rope.use_pbias:
+            pbias = self.rope.use_pbias(f0)
        if pbias is not None:
            qk = qk + pbias[:,:,:q.shape[2],:q.shape[2]]
        token_ids = k[:, :, :, 0]
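Note that `rotary.use_pbias` is the boolean constructor flag, while the callable this PR defines is `get_pitch_bias`; as written, `self.rope.use_pbias(f0)` would raise `TypeError: 'bool' object is not callable`, and `pbias` is only bound when the guard fires. A sketch of the presumably intended wiring (hypothetical free function, for clarity):

```python
from typing import Optional
from torch import Tensor

def add_pitch_bias(qk: Tensor, rope, f0: Optional[Tensor], ctx: int) -> Tensor:
    """Add the (1, 1, L, L) pitch-similarity bias to attention scores, if enabled."""
    if f0 is None or not rope.use_pbias:  # use_pbias is a flag, not a method
        return qk
    pbias = rope.get_pitch_bias(f0)
    if pbias is None:
        return qk
    return qk + pbias[:, :, :ctx, :ctx]
```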
@@ -484,6 +457,7 @@ class MultiheadA(nn.Module):
            mask = mask[:q.shape[2], :q.shape[2]]
            qk = qk + mask.unsqueeze(0).unsqueeze(0) * zscale.unsqueeze(-2).expand(qk.shape)
            qk = qk * zscale.unsqueeze(-2)
+
        w = F.softmax(qk, dim=-1).to(q.dtype)
        wv = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)

@@ -539,7 +513,6 @@ class c_gate(nn.Module):
        s = self.s_gate(x) * s_feat
        w = self.w_gate(x) * w_feat
        p = self.p_gate(x) * p_feat
-
        comb = torch.cat([s, w, p], dim=-1)
        return self.integ(comb)

@@ -590,12 +563,12 @@ class Residual(nn.Module):
        if not any([t_gate, m_gate, c_gate]):
            self.mlp_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())

-    def forward(self, x
+    def forward(self, x, xa=None, mask=None, f0=None, mode=None, layer=None):
        bln = self.blend
-        x = x + self.attna(self.lna(x), xa=None, mask=mask,
+        x = x + self.attna(self.lna(x), xa=None, mask=mask, f0=f0, layer=layer)[0]

        if self.attnb and xa is not None:
-            c = self.attnb(self.lnb(x), xa, mask=None, layer=layer
+            c = self.attnb(self.lnb(x), xa=xa, f0=f0, mask=None, layer=layer)[0]
            b = torch.sigmoid(bln)
            x = b * x + (1 - b) * c

@@ -610,7 +583,7 @@ class Residual(nn.Module):
            gate = self.m_gate(normx)
            x = x + gate * mlp_out

-        elif self.c_gate is not None:
+        elif self.c_gate and mode is not None:
            gate_output = self.c_gate(normx, self.features)
            x = x + gate_output

@@ -650,7 +623,7 @@ class PEncoder(nn.Module):
            Conv1d(dims//4, dims//2, kernel_size=5, stride=4, padding=2), act_fn,
            Conv1d(dims//2, dims, kernel_size=5, stride=5, padding=2),act_fn)

-    def forward(self, x,
+    def forward(self, x, f0=None, layer=None):
        x = self.encoder(x).permute(0, 2, 1)
        x = x + self.positional(x.shape[1]).to(x.device, x.dtype)
        x = nn.functional.dropout(x, p=self.dropout, training=self.training)
@@ -679,7 +652,7 @@ class WEncoder(nn.Module):
        self.positional = lambda length: sinusoids(length, dims)
        self.norm = RMSNorm(dims)

-    def forward(self, x,
+    def forward(self, x, f0=None, layer=None):
        x = self.downsample(x)
        x = self.encoder(x)
        x = x.permute(0, 2, 1)
@@ -706,49 +679,13 @@ class FEncoder(nn.Module):
        self.norm = RMSNorm(dims)
        self._norm = RMSNorm(dims)

-    def forward(self, x,
+    def forward(self, x, f0=None, layer=None):
        x = self.encoder(x).permute(0, 2, 1)
        x = x + self.positional(x.shape[1]).to(x.device, x.dtype)
        x = nn.functional.dropout(x, p=self.dropout, training=self.training)
        x = self._norm(x)
        return x
-
-class F0Encoder(nn.Module):
-    def __init__(self, input_dims, dims, head, layer, kernel_size, act, stride=1):
-        super().__init__()
-
-        self.head_dim = dims // head
-        self.dropout = 0.01
-
-        act_map = {"gelu": nn.GELU(), "relu": nn.ReLU(), "sigmoid": nn.Sigmoid(),
-                   "tanh": nn.Tanh(), "swish": nn.SiLU(), "tanhshrink": nn.Tanhshrink(),
-                   "softplus": nn.Softplus(), "softshrink": nn.Softshrink(),
-                   "leaky_relu": nn.LeakyReLU(), "elu": nn.ELU()}
-        act_fn = act_map.get(act, nn.GELU())
-
-        self.encoder = nn.Sequential(
-            Conv1d(input_dims, dims, kernel_size=kernel_size, stride=stride, padding=kernel_size//2), act_fn,
-            Conv1d(dims, dims, kernel_size=5, padding=2), act_fn,
-            Conv1d(dims, dims, kernel_size=3, padding=1, groups=dims), act_fn)
-
-        self.positional = lambda length: sinusoids(length, dims)
-        self.norm = RMSNorm(dims)
-        self._norm = RMSNorm(dims)
-
-    def forward(self, x, feat=None, layer=None):
-        if x.dim() == 3 and x.shape[0] == 1 and x.shape[1] == 1:
-            pass
-        elif x.dim() == 2:
-            x = x.unsqueeze(1)
-        elif x.dim() == 1:
-            x = x.unsqueeze(0).unsqueeze(0)
-        x = self.encoder(x)
-        x = x.permute(0, 2, 1)
-        x = x + self.positional(x.shape[1]).to(x.device, x.dtype)
-        x = nn.functional.dropout(x, p=self.dropout, training=self.training)
-        x = self._norm(x)
-        return x
-
+
 class AudioEncoder(nn.Module):
    _seen = set()
    def __init__(self, mels: int, ctx: int, dims: int, head: int, layer: int, debug: List[str], features: List[str],
@@ -760,10 +697,8 @@ class AudioEncoder(nn.Module):
        self.ctx = ctx
        self.head_dim = dims // head

-        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-        dtype = torch.float32
-        self.device = device
-        self.dtype = dtype
+        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        self.dtype = torch.float32
        self.debug = debug
        self._counter = 0

@@ -772,7 +707,8 @@ class AudioEncoder(nn.Module):
        self.f0_rotary = f0_rotary

        self.rope = rotary(
-            dims=self.head_dim
+            dims=self.head_dim,
+            )

        act_map = {"gelu": nn.GELU(), "relu": nn.ReLU(), "sigmoid": nn.Sigmoid(), "tanh": nn.Tanh(), "swish": nn.SiLU(),
                   "tanhshrink": nn.Tanhshrink(), "softplus": nn.Softplus(), "softshrink": nn.Softshrink(), "leaky_relu": nn.LeakyReLU(), "elu": nn.ELU()}
@@ -809,32 +745,27 @@ class AudioEncoder(nn.Module):
            FEncoder(input_dims=1, dims=dims, head=head, layer=layer, kernel_size=9, act=act, stride=2)
            for _ in range(layer)])

-    def forward(self,
-
+    def forward(self, x, f0=None, layer="ENC"):
        if self._counter < 1:
-            s =
-            w =
-            p =
+            s = x.get("spectrogram")
+            w = x.get("waveform")
+            p = f0 if f0 is not None else x.get("pitch")
            plot_waveform(x=s, w=w, p=p, hop_length=128)

        enc = {}
-        enc.update(feat)
-
-        for f in self.features:
-            if f in feat and f in self.blocks:
-                x = feat[f]
-                for block in self.blocks[f]:
-                    x = block(x, feat=feat, layer=layer)
-                enc[f] = x
+        for y in self.features:
+            if y in x and y in self.blocks:
+                f = x[y]
+                for block in self.blocks[y]:
+                    f = block(f, f0=f0, layer=layer)
+                enc[y] = f

        if "encoder" in self.debug and self._counter % 100 == 0:
-            names = list(
-            shapes = {k: v.shape for k, v in
+            names = list(x.keys())
+            shapes = {k: v.shape for k, v in x.items()}
            print(f"Step {self._counter}: mode: {names}")
            print(f"shapes: {shapes}")
-            for name, param in self.named_parameters():
-                if param.requires_grad:
-                    print(f"ENCODER LAYER {name}: grad_norm={param.median():.4f}")
        self._counter += 1
        return enc

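`AudioEncoder.forward` now takes the feature dict itself as `x` (the old `feat` kwarg and the `enc.update(feat)` passthrough are gone), so the returned dict holds only encoded features. A usage sketch, with shapes and an `f0` length that are assumptions, given `encoder = AudioEncoder(...)` built with `features=["spectrogram", "waveform"]`:

```python
import torch

feat = {
    "spectrogram": torch.randn(1, 128, 3000),  # (batch, mels, frames), assumed sizes
    "waveform": torch.randn(1, 1, 480000),     # 30 s of audio at 16 kHz
}
f0 = torch.full((1, 3750), 120.0)              # precomputed pitch track, hypothetical length
enc = encoder(feat, f0=f0)                     # {"spectrogram": ..., "waveform": ...}
```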
@@ -848,10 +779,8 @@ class TextDecoder(nn.Module):
        self.ctx = ctx
        self.head_dim = dims // head

-        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-        dtype = torch.float32
-        self.device = device
-        self.dtype = dtype
+        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        self.dtype = torch.float32
        self.debug = debug
        self._counter = 0

@@ -878,23 +807,8 @@ class TextDecoder(nn.Module):

        mask = torch.tril(torch.ones(ctx, ctx), diagonal=0)
        self.register_buffer("mask", mask, persistent=False)
-
-        rotary_emb = False
-        if rotary_emb:
-            self.rope = rotary(
-                dims=self.head_dim,
-                debug = debug,
-                radii=False,
-                learned_pitch=False,
-                learned_freq=False,
-                learned_theta=False,
-                learned_radius=False,
-            )
-        else:
-            self.rope = None

-    def forward(self, x,
-
+    def forward(self, x, enc, f0=None, order=None, layer='DEC') -> Tensor:
        bln = self.blend
        x = x.to(device)
        if order is None:
@@ -902,15 +816,13 @@ class TextDecoder(nn.Module):
            mask = self.mask[:x.shape[1], :x.shape[1]]
        x = self.token(x) + self.positional[:x.shape[1]]
        x = F.dropout(x, p=self.dropout, training=self.training)
-
        for block in self.block:
-            x = block(x, xa=None,
-
+            x = block(x, xa=None, f0=f0, mask=mask, layer=layer)
        for f in order:
-            if f in
-                xa =
+            if f in enc:
+                xa = enc[f]
                for block in self.blocks[f]:
-                    out = block(x=x, xa=xa,
+                    out = block(x=x, xa=xa, f0=f0, mask=None, layer=layer)
                    a = torch.sigmoid(bln[f])
                    x = a * out + (1 - a) * x
        x = self.ln_dec(x)
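Each conditioning stream is folded into the decoder state through a learned scalar gate: with `a = sigmoid(bln[f])`, the update `x = a * out + (1 - a) * x` is a convex blend, so an untrained logit of 0 mixes the two halves equally. Toy illustration:

```python
import torch

bln = {"spectrogram": torch.tensor(0.0)}  # learned blend logit per feature
x = torch.ones(2, 4)     # current decoder state
out = torch.zeros(2, 4)  # cross-attention output for this feature
a = torch.sigmoid(bln["spectrogram"])
x = a * out + (1 - a) * x
print(x)  # all 0.5: an even mix while the logit is zero
```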
@@ -994,13 +906,11 @@ class Echo(nn.Module):
            encoder_inputs["envelope"] = envelope
        if phase is not None:
            encoder_inputs["phase"] = phase
-        if f0 is not None:
-
-
-
-
-        encoder_outputs = self.encoder(encoder_inputs)
-        logits = self.decoder(input_ids, encoder_outputs)
+        # if f0 is not None:
+        #     encoder_inputs["f0"] = f0
+
+        encoder_outputs = self.encoder(encoder_inputs, f0=f0)
+        logits = self.decoder(input_ids, encoder_outputs, f0=f0d)

        loss = None
        if labels is not None:
@@ -1071,7 +981,7 @@ class Echo(nn.Module):
            print(f"{module_type}: {count}")

    def register_gradient_hooks(self):
-
+        """Add this method to your Echo model class"""
        for name, param in self.named_parameters():
            if param.requires_grad:
                if "encoder" in name:
@@ -1096,6 +1006,6 @@ class Echo(nn.Module):
        return None

    def reset_counter(self):
+        """Reset the internal counter for debugging purposes."""
        self._counter = 0
        print("Counter reset to 0.")
-