Files changed (1) hide show
  1. model.py +87 -177
model.py CHANGED
@@ -130,6 +130,7 @@ def plot_waveform(x=None, w=None, p=None, per=None, sample_idx=0, sr=16000, hop_
130
  axs[current_ax].set_ylabel("Mel Bin")
131
  axs[current_ax].set_xlim([0, max_time])
132
  axs[current_ax].grid(True, axis='x', linestyle='--', alpha=0.3)
 
133
  current_ax += 1
134
 
135
  if p is not None:
@@ -237,52 +238,9 @@ def sinusoids(length, channels, max_timescale=10000):
237
  scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
238
  return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
239
 
240
- class ParameterCycler:
241
- def __init__(self, parameters):
242
- self.parameters = parameters
243
- self.current_idx = 0
244
- def toggle_requires_grad(self):
245
- x = random.randint(0, len(self.parameters) - 1)
246
- for x, param in enumerate(self.parameters):
247
- param.requires_grad = (x == self.current_idx)
248
- print(f"Parameter {x}: requires_grad={param.requires_grad}")
249
- self.current_idx = (self.current_idx + 1) % len(self.parameters)
250
-
251
- def extract_f0(waveform, sampling_rate=16000, hop_length=128, device="cuda:0"):
252
- """Extract F0 from waveform - handle various input types"""
253
- if waveform is None:
254
- return None
255
-
256
- if isinstance(waveform, list):
257
- if len(waveform) == 0:
258
- return None
259
- waveform = waveform[0]
260
- print(f"DEBUG: Converted list to tensor, new type: {type(waveform)}")
261
-
262
- if not isinstance(waveform, torch.Tensor):
263
- waveform = torch.tensor(waveform)
264
-
265
- if isinstance(waveform, torch.Tensor):
266
- if waveform.dim() == 3:
267
- waveform = waveform.squeeze(1)
268
- if waveform.dim() == 2:
269
- waveform = waveform[0]
270
-
271
- wav_np = waveform.detach().cpu().numpy().astype(np.float64)
272
- else:
273
- wav_np = np.array(waveform).astype(np.float64)
274
-
275
- f0, t = pw.dio(wav_np, sampling_rate,
276
- frame_period=hop_length/sampling_rate*1000)
277
- f0 = pw.stonemask(wav_np, f0, t, sampling_rate)
278
-
279
- f0_tensor = torch.from_numpy(f0).float().to(device)
280
- return f0_tensor.unsqueeze(0).unsqueeze(0)
281
-
282
  class rotary(nn.Module):
283
- _seen = set()
284
- def __init__(self, dims, max_ctx=1500, theta=10000, learned_freq=False, radii=False,
285
- learned_radius=False, learned_theta=False, learned_pitch=False, debug: List[str] = [], use_pbias = False):
286
  super().__init__()
287
 
288
  self.use_pbias = use_pbias
@@ -309,49 +267,64 @@ class rotary(nn.Module):
309
  self.freqs = nn.Parameter(torch.tensor(freqs, device=self.device, dtype=self.dtype), requires_grad=True)
310
  self.radius = nn.Parameter(torch.ones(radius, device=self.device, dtype=self.dtype), requires_grad=True)
311
 
312
- def forward(self, x=None, layer=None, enc=None) -> Tensor:
313
-
314
- f0 = enc.get("f0") if enc else None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
315
  if isinstance(x, int):
316
  ctx = x
317
  else:
318
  batch, ctx, dims = x.shape
319
  t = torch.arange(ctx, device=self.device).float()
320
 
 
 
 
 
 
 
 
321
  if f0 is not None:
322
  f0_mean=f0.mean()+1e-8
323
  theta=f0_mean*self.pitch_scale
324
  freqs = 1. / (theta ** (torch.arange(0, self.dims, 2, device=self.device, dtype=self.dtype)[:(self.dims // 2)].float() /self.dims))
325
  else:
326
  freqs = self.freqs
327
-
328
  freqs = torch.einsum('i,j->ij', t, freqs)
329
  freqs = freqs.float()
330
- # print(f"{layer} : {f0_mean} : {theta:.2f} : {ctx} ")
331
  if self.radii:
332
- # radius = self.align_f0(f0, ctx)
333
- radius = enc.get("f0d") if enc else self.radius
 
334
  radius = radius.float()
335
-
336
  else:
337
  radius = self.radius
338
- # freqs = torch.polar(self.radius.unsqueeze(-1), freqs)
339
  freqs = torch.polar(radius.unsqueeze(-1), freqs)
340
 
341
- if "rotary" in self.debug:
342
  if f0 is not None:
343
- key = f"{self._counter}_{theta:.2f}"
344
- if key not in rotary._seen:
345
- if not hasattr(self, '_prev_f0_theta'):
346
- self._prev_f0_theta = theta
347
- # print(f"Step {self._counter}: Theta: {theta:.2f} Hz")
348
- elif abs(self._prev_f0_theta - theta) > 100.0:
349
- # print(f"Step {self._counter}: Theta: {theta:.2f} Hz, freqs: {freqs.shape}")
350
- print(f"{layer} : {f0_mean} : Theta: {theta:.2f} : {theta:.2f} : {ctx} ")
351
- if self.radii:
352
- print(f"radius: {radius} Hz, enc: {layer} Hz, ctx: {ctx}")
353
- self._prev_f0_theta = theta
354
- rotary._seen.add(key)
355
  self._counter += 1
356
  return freqs
357
 
@@ -440,19 +413,19 @@ class MultiheadA(nn.Module):
440
  rbf_scores = torch.exp(-dist_sq / (2 * rbf_sigma**2))
441
  return (1 - rbf_ratio) * dot_scores + rbf_ratio * rbf_scores
442
 
443
- def forward(self, x: Tensor, xa: Tensor = None, mask: Tensor = None, feat=None, layer = None) -> tuple:
444
 
 
445
  scale = (self.dims // self.head) ** -0.25
446
 
447
  z = xa if xa is not None else x
448
  q = self.q(x).to(x.dtype)
449
  k = self.k(z).to(x.dtype)
450
  v = self.v(z).to(x.dtype)
451
- batch, ctx, dims = q.shape
452
 
453
  if self.rotary_emb:
454
- qf = self.rope(q.size(1), layer=layer, feat=feat)
455
- kf = self.rope(k.size(1), layer=layer, feat=feat)
456
 
457
  q = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3)
458
  k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
@@ -466,13 +439,13 @@ class MultiheadA(nn.Module):
466
  k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
467
  v = v.view(*v.shape[:2], self.head, -1).permute(0, 2, 1, 3)
468
  batch, head, ctx, head_dim = q.shape
469
-
470
  if self.rbf:
471
  qk = self.enhanced_attention_scores(q * scale, k * scale, rbf_sigma=1.0, rbf_ratio=0.3)
472
 
473
  qk = (q * scale) @ (k * scale).transpose(-1, -2)
474
- if self.rope.use_pbias:
475
- pbias = self.rope.pbias(feat.get("f0"))
476
  if pbias is not None:
477
  qk = qk + pbias[:,:,:q.shape[2],:q.shape[2]]
478
  token_ids = k[:, :, :, 0]
@@ -484,6 +457,7 @@ class MultiheadA(nn.Module):
484
  mask = mask[:q.shape[2], :q.shape[2]]
485
  qk = qk + mask.unsqueeze(0).unsqueeze(0) * zscale.unsqueeze(-2).expand(qk.shape)
486
  qk = qk * zscale.unsqueeze(-2)
 
487
  w = F.softmax(qk, dim=-1).to(q.dtype)
488
  wv = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
489
 
@@ -539,7 +513,6 @@ class c_gate(nn.Module):
539
  s = self.s_gate(x) * s_feat
540
  w = self.w_gate(x) * w_feat
541
  p = self.p_gate(x) * p_feat
542
-
543
  comb = torch.cat([s, w, p], dim=-1)
544
  return self.integ(comb)
545
 
@@ -590,12 +563,12 @@ class Residual(nn.Module):
590
  if not any([t_gate, m_gate, c_gate]):
591
  self.mlp_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
592
 
593
- def forward(self, x: Tensor, xa: Tensor = None, mask: Tensor = None, feat=None, layer = None):
594
  bln = self.blend
595
- x = x + self.attna(self.lna(x), xa=None, mask=mask, layer=layer, feat=feat)[0]
596
 
597
  if self.attnb and xa is not None:
598
- c = self.attnb(self.lnb(x), xa, mask=None, layer=layer, feat=feat)[0]
599
  b = torch.sigmoid(bln)
600
  x = b * x + (1 - b) * c
601
 
@@ -610,7 +583,7 @@ class Residual(nn.Module):
610
  gate = self.m_gate(normx)
611
  x = x + gate * mlp_out
612
 
613
- elif self.c_gate is not None:
614
  gate_output = self.c_gate(normx, self.features)
615
  x = x + gate_output
616
 
@@ -650,7 +623,7 @@ class PEncoder(nn.Module):
650
  Conv1d(dims//4, dims//2, kernel_size=5, stride=4, padding=2), act_fn,
651
  Conv1d(dims//2, dims, kernel_size=5, stride=5, padding=2),act_fn)
652
 
653
- def forward(self, x, feat=None, layer=None):
654
  x = self.encoder(x).permute(0, 2, 1)
655
  x = x + self.positional(x.shape[1]).to(x.device, x.dtype)
656
  x = nn.functional.dropout(x, p=self.dropout, training=self.training)
@@ -679,7 +652,7 @@ class WEncoder(nn.Module):
679
  self.positional = lambda length: sinusoids(length, dims)
680
  self.norm = RMSNorm(dims)
681
 
682
- def forward(self, x, feat=None, layer=None):
683
  x = self.downsample(x)
684
  x = self.encoder(x)
685
  x = x.permute(0, 2, 1)
@@ -706,49 +679,13 @@ class FEncoder(nn.Module):
706
  self.norm = RMSNorm(dims)
707
  self._norm = RMSNorm(dims)
708
 
709
- def forward(self, x, feat=None, layer=None):
710
  x = self.encoder(x).permute(0, 2, 1)
711
  x = x + self.positional(x.shape[1]).to(x.device, x.dtype)
712
  x = nn.functional.dropout(x, p=self.dropout, training=self.training)
713
  x = self._norm(x)
714
  return x
715
-
716
- class F0Encoder(nn.Module):
717
- def __init__(self, input_dims, dims, head, layer, kernel_size, act, stride=1):
718
- super().__init__()
719
-
720
- self.head_dim = dims // head
721
- self.dropout = 0.01
722
-
723
- act_map = {"gelu": nn.GELU(), "relu": nn.ReLU(), "sigmoid": nn.Sigmoid(),
724
- "tanh": nn.Tanh(), "swish": nn.SiLU(), "tanhshrink": nn.Tanhshrink(),
725
- "softplus": nn.Softplus(), "softshrink": nn.Softshrink(),
726
- "leaky_relu": nn.LeakyReLU(), "elu": nn.ELU()}
727
- act_fn = act_map.get(act, nn.GELU())
728
-
729
- self.encoder = nn.Sequential(
730
- Conv1d(input_dims, dims, kernel_size=kernel_size, stride=stride, padding=kernel_size//2), act_fn,
731
- Conv1d(dims, dims, kernel_size=5, padding=2), act_fn,
732
- Conv1d(dims, dims, kernel_size=3, padding=1, groups=dims), act_fn)
733
-
734
- self.positional = lambda length: sinusoids(length, dims)
735
- self.norm = RMSNorm(dims)
736
- self._norm = RMSNorm(dims)
737
-
738
- def forward(self, x, feat=None, layer=None):
739
- if x.dim() == 3 and x.shape[0] == 1 and x.shape[1] == 1:
740
- pass
741
- elif x.dim() == 2:
742
- x = x.unsqueeze(1)
743
- elif x.dim() == 1:
744
- x = x.unsqueeze(0).unsqueeze(0)
745
- x = self.encoder(x)
746
- x = x.permute(0, 2, 1)
747
- x = x + self.positional(x.shape[1]).to(x.device, x.dtype)
748
- x = nn.functional.dropout(x, p=self.dropout, training=self.training)
749
- x = self._norm(x)
750
- return x
751
-
752
  class AudioEncoder(nn.Module):
753
  _seen = set()
754
  def __init__(self, mels: int, ctx: int, dims: int, head: int, layer: int, debug: List[str], features: List[str],
@@ -760,10 +697,8 @@ class AudioEncoder(nn.Module):
760
  self.ctx = ctx
761
  self.head_dim = dims // head
762
 
763
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
764
- dtype = torch.float32
765
- self.device = device
766
- self.dtype = dtype
767
  self.debug = debug
768
  self._counter = 0
769
 
@@ -772,7 +707,8 @@ class AudioEncoder(nn.Module):
772
  self.f0_rotary = f0_rotary
773
 
774
  self.rope = rotary(
775
- dims=self.head_dim)
 
776
 
777
  act_map = {"gelu": nn.GELU(), "relu": nn.ReLU(), "sigmoid": nn.Sigmoid(), "tanh": nn.Tanh(), "swish": nn.SiLU(),
778
  "tanhshrink": nn.Tanhshrink(), "softplus": nn.Softplus(), "softshrink": nn.Softshrink(), "leaky_relu": nn.LeakyReLU(), "elu": nn.ELU()}
@@ -809,32 +745,27 @@ class AudioEncoder(nn.Module):
809
  FEncoder(input_dims=1, dims=dims, head=head, layer=layer, kernel_size=9, act=act, stride=2)
810
  for _ in range(layer)])
811
 
812
- def forward(self, feat, layer="encoder"):
813
-
814
  if self._counter < 1:
815
- s = feat.get("spectrogram")
816
- w = feat.get("waveform")
817
- p = default(feat.get("f0"), feat.get("pitch"))
818
  plot_waveform(x=s, w=w, p=p, hop_length=128)
819
 
820
  enc = {}
821
- enc.update(feat)
822
-
823
- for f in self.features:
824
- if f in feat and f in self.blocks:
825
- x = feat[f]
826
- for block in self.blocks[f]:
827
- x = block(x, feat=feat, layer=layer)
828
- enc[f] = x
829
 
 
 
 
 
 
 
 
830
  if "encoder" in self.debug and self._counter % 100 == 0:
831
- names = list(feat.keys())
832
- shapes = {k: v.shape for k, v in feat.items()}
833
  print(f"Step {self._counter}: mode: {names}")
834
  print(f"shapes: {shapes}")
835
- for name, param in self.named_parameters():
836
- if param.requires_grad:
837
- print(f"ENCODER LAYER {name}: grad_norm={param.median():.4f}")
838
  self._counter += 1
839
  return enc
840
 
@@ -848,10 +779,8 @@ class TextDecoder(nn.Module):
848
  self.ctx = ctx
849
  self.head_dim = dims // head
850
 
851
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
852
- dtype = torch.float32
853
- self.device = device
854
- self.dtype = dtype
855
  self.debug = debug
856
  self._counter = 0
857
 
@@ -878,23 +807,8 @@ class TextDecoder(nn.Module):
878
 
879
  mask = torch.tril(torch.ones(ctx, ctx), diagonal=0)
880
  self.register_buffer("mask", mask, persistent=False)
881
-
882
- rotary_emb = False
883
- if rotary_emb:
884
- self.rope = rotary(
885
- dims=self.head_dim,
886
- debug = debug,
887
- radii=False,
888
- learned_pitch=False,
889
- learned_freq=False,
890
- learned_theta=False,
891
- learned_radius=False,
892
- )
893
- else:
894
- self.rope = None
895
 
896
- def forward(self, x, feat, order=None, layer='decoder') -> Tensor:
897
-
898
  bln = self.blend
899
  x = x.to(device)
900
  if order is None:
@@ -902,15 +816,13 @@ class TextDecoder(nn.Module):
902
  mask = self.mask[:x.shape[1], :x.shape[1]]
903
  x = self.token(x) + self.positional[:x.shape[1]]
904
  x = F.dropout(x, p=self.dropout, training=self.training)
905
-
906
  for block in self.block:
907
- x = block(x, xa=None, mask=mask, feat=feat, layer=layer)
908
-
909
  for f in order:
910
- if f in feat:
911
- xa = feat[f]
912
  for block in self.blocks[f]:
913
- out = block(x=x, xa=xa, mask=None, feat=feat, layer=layer)
914
  a = torch.sigmoid(bln[f])
915
  x = a * out + (1 - a) * x
916
  x = self.ln_dec(x)
@@ -994,13 +906,11 @@ class Echo(nn.Module):
994
  encoder_inputs["envelope"] = envelope
995
  if phase is not None:
996
  encoder_inputs["phase"] = phase
997
- if f0 is not None:
998
- encoder_inputs["f0"] = f0
999
- if f0d is not None:
1000
- encoder_inputs["f0d"] = f0d
1001
-
1002
- encoder_outputs = self.encoder(encoder_inputs)
1003
- logits = self.decoder(input_ids, encoder_outputs)
1004
 
1005
  loss = None
1006
  if labels is not None:
@@ -1071,7 +981,7 @@ class Echo(nn.Module):
1071
  print(f"{module_type}: {count}")
1072
 
1073
  def register_gradient_hooks(self):
1074
-
1075
  for name, param in self.named_parameters():
1076
  if param.requires_grad:
1077
  if "encoder" in name:
@@ -1096,6 +1006,6 @@ class Echo(nn.Module):
1096
  return None
1097
 
1098
  def reset_counter(self):
 
1099
  self._counter = 0
1100
  print("Counter reset to 0.")
1101
-
 
130
  axs[current_ax].set_ylabel("Mel Bin")
131
  axs[current_ax].set_xlim([0, max_time])
132
  axs[current_ax].grid(True, axis='x', linestyle='--', alpha=0.3)
133
+ # fig.colorbar(im, ax=axs[current_ax])
134
  current_ax += 1
135
 
136
  if p is not None:
 
238
  scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
239
  return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
240
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
  class rotary(nn.Module):
242
+
243
+ def __init__(self, dims, max_ctx=1500, theta=10000, radii=False, debug: List[str] = [], use_pbias = False):
 
244
  super().__init__()
245
 
246
  self.use_pbias = use_pbias
 
267
  self.freqs = nn.Parameter(torch.tensor(freqs, device=self.device, dtype=self.dtype), requires_grad=True)
268
  self.radius = nn.Parameter(torch.ones(radius, device=self.device, dtype=self.dtype), requires_grad=True)
269
 
270
def get_pitch_bias(self, f0):
    """Build a pairwise pitch-similarity bias from an F0 track.

    The track is standardised (zero mean, unit spread) and every pair of
    frames is scored with ``exp(-|a - b| * self.pitch_scale)``, so frames
    with similar pitch score close to 1.

    Args:
        f0: pitch tensor, or ``None``. Singleton dimensions are squeezed
            away before use.
            NOTE(review): presumably one value per frame after squeezing
            (e.g. ``(1, 1, frames)``) - confirm against the caller.

    Returns:
        ``None`` when ``f0`` is ``None``; otherwise the similarity matrix
        with two leading singleton dimensions prepended
        (``(1, 1, frames, frames)`` for a 1-D track).
    """
    if f0 is None:
        return None

    f0_flat = f0.squeeze().float()
    if f0_flat.dim() == 0:
        # A single-frame track is squeezed down to a scalar; restore the
        # frame axis so the pairwise distance below still works.
        f0_flat = f0_flat.unsqueeze(0)

    # std() of a single sample is NaN, which would poison the whole bias.
    if f0_flat.numel() > 1:
        spread = f0_flat.std()
    else:
        spread = f0_flat.new_zeros(())

    f0_norm = (f0_flat - f0_flat.mean()) / (spread + 1e-8)
    column = f0_norm.unsqueeze(1)
    f0_sim = torch.exp(-torch.cdist(column, column) * self.pitch_scale)
    return f0_sim.unsqueeze(0).unsqueeze(0)
278
+
279
def align_f0(self, f0, ctx):
    """Resample a 2-D F0 track to ``ctx`` positions by picking source frames.

    Target position ``i`` reads source frame ``floor(i * frames / ctx)``,
    clamped to the valid range, so the track is decimated when
    ``frames > ctx`` and repeated when ``frames < ctx``.

    Args:
        f0: tensor of shape ``(batch, frames)``.
        ctx: number of target positions.

    Returns:
        Float tensor of shape ``(batch, ctx)``; a leading batch dimension
        of size 1 is squeezed away. When ``frames == ctx`` the input is
        returned unchanged apart from that squeeze and the float cast.
    """
    b, l = f0.shape
    if l == ctx:
        return f0.squeeze(0).float()

    # Index tensors must be integer typed and live on the same device as
    # the tensor being indexed; integer floor division also avoids float
    # rounding when mapping target positions to source frames.
    positions = torch.arange(ctx, device=f0.device, dtype=torch.long)
    src_idx = torch.div(positions * l, ctx, rounding_mode="floor").clamp(0, l - 1)
    batch_idx = torch.arange(b, device=f0.device, dtype=torch.long).unsqueeze(1)
    return f0[batch_idx, src_idx].squeeze(0).float()
289
+
290
def forward(self, x: Tensor, xa: Tensor = None, f0: Tensor = None, mask: Tensor = None, layer = None) -> Tensor:
    """Return complex rotary factors for a sequence of ``ctx`` positions.

    Args:
        x: either an ``int`` sequence length, or a 3-D tensor whose second
            dimension is the sequence length.
        xa: unused here.
        f0: optional pitch tensor. When given, ``f0.mean() * self.pitch_scale``
            is used as the rotary base instead of the stored ``self.freqs``.
        mask: unused here.
        layer: label that only appears in the debug printout.

    Returns:
        Complex tensor from ``torch.polar``: the angle is the outer product
        of position index and per-channel frequency, the magnitude is the
        radius (pitch-derived when ``self.radii`` is set and ``f0`` is
        available, otherwise ``self.radius``).
    """
    if isinstance(x, int):
        ctx = x
    else:
        batch, ctx, dims = x.shape
    t = torch.arange(ctx, device=self.device).float()

    if self.learned_adaptation:
        # NOTE(review): whatever this branch stores in `freqs` is replaced by
        # the if/else directly below, and it requires `x` to be a tensor
        # rather than an int - confirm whether it should return early or go.
        freqs = self.get_f0_adapted_freqs(ctx, f0)
        x_complex = torch.view_as_complex(
            x.float().reshape(*x.shape[:-1], -1, 2).contiguous())
        x_rotated = x_complex * freqs.unsqueeze(0).unsqueeze(0)
        freqs = torch.view_as_real(x_rotated).flatten(3).type_as(x)

    if f0 is not None:
        # Pitch-conditioned base; the epsilon keeps the base strictly
        # positive for an all-zero track.
        f0_mean = f0.mean() + 1e-8
        theta = f0_mean * self.pitch_scale
        freqs = 1. / (theta ** (torch.arange(0, self.dims, 2, device=self.device, dtype=self.dtype)[:(self.dims // 2)].float() / self.dims))
    else:
        freqs = self.freqs
    freqs = torch.einsum('i,j->ij', t, freqs)
    freqs = freqs.float()

    if self.radii and f0 is not None:
        # Per-position magnitude from the pitch track, clamped to
        # [40, 600] and scaled by 1/600.
        radius = self.align_f0(f0, ctx)
        radius = torch.clamp(radius, min=40.0, max=600.0)
        radius = radius / 600.0
        radius = radius.float()
    else:
        # Without a pitch track there is nothing to align, so use the
        # stored radius instead of failing inside align_f0.
        radius = self.radius
    freqs = torch.polar(radius.unsqueeze(-1), freqs)

    if "rotary" in self.debug and self._counter % 100 == 50:
        if f0 is not None:
            print(f"{layer}: {theta:.2f} : {f0_mean:.2f} : {ctx} ")
            if self.radii:
                print(f"radius: {radius} Hz, enc: {layer} Hz, ctx: {ctx}")
    self._counter += 1
    return freqs
330
 
 
413
  rbf_scores = torch.exp(-dist_sq / (2 * rbf_sigma**2))
414
  return (1 - rbf_ratio) * dot_scores + rbf_ratio * rbf_scores
415
 
416
+ def forward(self, x: Tensor, xa: Tensor = None, f0: Tensor = None, mask: Tensor = None, layer = None) -> tuple:
417
 
418
+ batch, ctx, dims = x.shape
419
  scale = (self.dims // self.head) ** -0.25
420
 
421
  z = xa if xa is not None else x
422
  q = self.q(x).to(x.dtype)
423
  k = self.k(z).to(x.dtype)
424
  v = self.v(z).to(x.dtype)
 
425
 
426
  if self.rotary_emb:
427
+ qf = self.rope(q.size(1), f0=f0, layer=layer)
428
+ kf = self.rope(k.size(1), f0=f0, layer=layer)
429
 
430
  q = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3)
431
  k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
 
439
  k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
440
  v = v.view(*v.shape[:2], self.head, -1).permute(0, 2, 1, 3)
441
  batch, head, ctx, head_dim = q.shape
442
+
443
  if self.rbf:
444
  qk = self.enhanced_attention_scores(q * scale, k * scale, rbf_sigma=1.0, rbf_ratio=0.3)
445
 
446
  qk = (q * scale) @ (k * scale).transpose(-1, -2)
447
+ if f0 is not None and self.rope.use_pbias:
448
+ pbias = self.rope.use_pbias(f0)
449
  if pbias is not None:
450
  qk = qk + pbias[:,:,:q.shape[2],:q.shape[2]]
451
  token_ids = k[:, :, :, 0]
 
457
  mask = mask[:q.shape[2], :q.shape[2]]
458
  qk = qk + mask.unsqueeze(0).unsqueeze(0) * zscale.unsqueeze(-2).expand(qk.shape)
459
  qk = qk * zscale.unsqueeze(-2)
460
+
461
  w = F.softmax(qk, dim=-1).to(q.dtype)
462
  wv = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
463
 
 
513
  s = self.s_gate(x) * s_feat
514
  w = self.w_gate(x) * w_feat
515
  p = self.p_gate(x) * p_feat
 
516
  comb = torch.cat([s, w, p], dim=-1)
517
  return self.integ(comb)
518
 
 
563
  if not any([t_gate, m_gate, c_gate]):
564
  self.mlp_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
565
 
566
+ def forward(self, x, xa=None, mask=None, f0=None, mode=None, layer=None):
567
  bln = self.blend
568
+ x = x + self.attna(self.lna(x), xa=None, mask=mask, f0=f0, layer=layer)[0]
569
 
570
  if self.attnb and xa is not None:
571
+ c = self.attnb(self.lnb(x), xa=xa, f0=f0, mask=None, layer=layer)[0]
572
  b = torch.sigmoid(bln)
573
  x = b * x + (1 - b) * c
574
 
 
583
  gate = self.m_gate(normx)
584
  x = x + gate * mlp_out
585
 
586
+ elif self.c_gate and mode is not None:
587
  gate_output = self.c_gate(normx, self.features)
588
  x = x + gate_output
589
 
 
623
  Conv1d(dims//4, dims//2, kernel_size=5, stride=4, padding=2), act_fn,
624
  Conv1d(dims//2, dims, kernel_size=5, stride=5, padding=2),act_fn)
625
 
626
+ def forward(self, x, f0=None, layer=None):
627
  x = self.encoder(x).permute(0, 2, 1)
628
  x = x + self.positional(x.shape[1]).to(x.device, x.dtype)
629
  x = nn.functional.dropout(x, p=self.dropout, training=self.training)
 
652
  self.positional = lambda length: sinusoids(length, dims)
653
  self.norm = RMSNorm(dims)
654
 
655
+ def forward(self, x, f0=None, layer=None):
656
  x = self.downsample(x)
657
  x = self.encoder(x)
658
  x = x.permute(0, 2, 1)
 
679
  self.norm = RMSNorm(dims)
680
  self._norm = RMSNorm(dims)
681
 
682
def forward(self, x, f0=None, layer=None):
    """Encode ``x``, add positional information, then apply dropout and norm.

    The output of ``self.encoder`` has its last two axes swapped; the
    positional table for the resulting length (axis 1) is moved to the same
    device and dtype and added. Dropout uses ``self.dropout`` and is only
    active in training mode. ``f0`` and ``layer`` are accepted but unused.
    """
    hidden = self.encoder(x).permute(0, 2, 1)
    position = self.positional(hidden.shape[1]).to(hidden.device, hidden.dtype)
    hidden = nn.functional.dropout(hidden + position, p=self.dropout, training=self.training)
    return self._norm(hidden)
688
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
689
  class AudioEncoder(nn.Module):
690
  _seen = set()
691
  def __init__(self, mels: int, ctx: int, dims: int, head: int, layer: int, debug: List[str], features: List[str],
 
697
  self.ctx = ctx
698
  self.head_dim = dims // head
699
 
700
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
701
+ self.dtype = torch.float32
 
 
702
  self.debug = debug
703
  self._counter = 0
704
 
 
707
  self.f0_rotary = f0_rotary
708
 
709
  self.rope = rotary(
710
+ dims=self.head_dim,
711
+ )
712
 
713
  act_map = {"gelu": nn.GELU(), "relu": nn.ReLU(), "sigmoid": nn.Sigmoid(), "tanh": nn.Tanh(), "swish": nn.SiLU(),
714
  "tanhshrink": nn.Tanhshrink(), "softplus": nn.Softplus(), "softshrink": nn.Softshrink(), "leaky_relu": nn.LeakyReLU(), "elu": nn.ELU()}
 
745
  FEncoder(input_dims=1, dims=dims, head=head, layer=layer, kernel_size=9, act=act, stride=2)
746
  for _ in range(layer)])
747
 
748
def forward(self, x, f0=None, layer="ENC"):
    """Run every configured feature through its own stack of blocks.

    Args:
        x: mapping from feature name to that feature's input.
        f0: optional pitch input handed to every block; on the first call
            it is also what gets plotted (falling back to ``x["pitch"]``).
        layer: label forwarded to every block.

    Returns:
        Dict with one entry per name in ``self.features`` that is present
        both in ``x`` and in ``self.blocks``, holding the output of the
        last block of that stack.
    """
    # One-off plot of the raw inputs on the very first call.
    if self._counter < 1:
        pitch = x.get("pitch") if f0 is None else f0
        plot_waveform(x=x.get("spectrogram"), w=x.get("waveform"), p=pitch, hop_length=128)

    enc = {}
    for name in self.features:
        if name not in x or name not in self.blocks:
            continue
        hidden = x[name]
        for block in self.blocks[name]:
            hidden = block(hidden, f0=f0, layer=layer)
        enc[name] = hidden

    # Periodic debug dump of the input names and shapes.
    if "encoder" in self.debug and self._counter % 100 == 0:
        names = list(x.keys())
        shapes = {key: value.shape for key, value in x.items()}
        print(f"Step {self._counter}: mode: {names}")
        print(f"shapes: {shapes}")
    self._counter += 1
    return enc
771
 
 
779
  self.ctx = ctx
780
  self.head_dim = dims // head
781
 
782
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
783
+ self.dtype = torch.float32
 
 
784
  self.debug = debug
785
  self._counter = 0
786
 
 
807
 
808
  mask = torch.tril(torch.ones(ctx, ctx), diagonal=0)
809
  self.register_buffer("mask", mask, persistent=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
810
 
811
+ def forward(self, x, enc, f0=None, order=None, layer='DEC') -> Tensor:
 
812
  bln = self.blend
813
  x = x.to(device)
814
  if order is None:
 
816
  mask = self.mask[:x.shape[1], :x.shape[1]]
817
  x = self.token(x) + self.positional[:x.shape[1]]
818
  x = F.dropout(x, p=self.dropout, training=self.training)
 
819
  for block in self.block:
820
+ x = block(x, xa=None, f0=f0, mask=mask, layer=layer)
 
821
  for f in order:
822
+ if f in enc:
823
+ xa = enc[f]
824
  for block in self.blocks[f]:
825
+ out = block(x=x, xa=xa, f0=f0, mask=None, layer=layer)
826
  a = torch.sigmoid(bln[f])
827
  x = a * out + (1 - a) * x
828
  x = self.ln_dec(x)
 
906
  encoder_inputs["envelope"] = envelope
907
  if phase is not None:
908
  encoder_inputs["phase"] = phase
909
+ # if f0 is not None:
910
+ # encoder_inputs["f0"] = f0
911
+
912
+ encoder_outputs = self.encoder(encoder_inputs, f0=f0)
913
+ logits = self.decoder(input_ids, encoder_outputs, f0=f0d)
 
 
914
 
915
  loss = None
916
  if labels is not None:
 
981
  print(f"{module_type}: {count}")
982
 
983
  def register_gradient_hooks(self):
984
+ """Add this method to your Echo model class"""
985
  for name, param in self.named_parameters():
986
  if param.requires_grad:
987
  if "encoder" in name:
 
1006
  return None
1007
 
1008
  def reset_counter(self):
1009
+ """Reset the internal counter for debugging purposes."""
1010
  self._counter = 0
1011
  print("Counter reset to 0.")