thecollabagepatch committed on
Commit
169ed8c
·
1 Parent(s): e0bae41
Files changed (1) hide show
  1. jam_worker.py +227 -181
jam_worker.py CHANGED
@@ -1,4 +1,4 @@
1
- # jam_worker.py - Updated with robust silence handling
2
  from __future__ import annotations
3
 
4
  import os
@@ -20,6 +20,7 @@ from utils import (
20
  )
21
 
22
  def _dbg_rms_dbfs(x: np.ndarray) -> float:
 
23
  if x.ndim == 2:
24
  x = x.mean(axis=1)
25
  r = float(np.sqrt(np.mean(x * x) + 1e-12))
@@ -27,6 +28,7 @@ def _dbg_rms_dbfs(x: np.ndarray) -> float:
27
 
28
  def _dbg_rms_dbfs_model(x: np.ndarray) -> float:
29
  # x is model-rate, shape [S,C] or [S]
 
30
  if x.ndim == 2:
31
  x = x.mean(axis=1)
32
  r = float(np.sqrt(np.mean(x * x) + 1e-12))
@@ -35,19 +37,6 @@ def _dbg_rms_dbfs_model(x: np.ndarray) -> float:
35
  def _dbg_shape(x):
36
  return tuple(x.shape) if hasattr(x, "shape") else ("-",)
37
 
38
- def _is_silent(audio: np.ndarray, threshold_db: float = -60.0) -> bool:
39
- """Check if audio is effectively silent."""
40
- if audio.size == 0:
41
- return True
42
- if audio.ndim == 2:
43
- audio = audio.mean(axis=1)
44
- rms = float(np.sqrt(np.mean(audio**2)))
45
- return 20.0 * np.log10(max(rms, 1e-12)) < threshold_db
46
-
47
- def _has_energy(audio: np.ndarray, threshold_db: float = -40.0) -> bool:
48
- """Check if audio has significant energy (stricter than just non-silent)."""
49
- return not _is_silent(audio, threshold_db)
50
-
51
  # -----------------------------
52
  # Data classes
53
  # -----------------------------
@@ -66,7 +55,7 @@ class JamParams:
66
  guidance_weight: float = 1.1
67
  temperature: float = 1.1
68
  topk: int = 40
69
- style_ramp_seconds: float = 8.0
70
 
71
 
72
  @dataclass
@@ -121,6 +110,8 @@ class JamWorker(threading.Thread):
121
  self.mrt.temperature = float(self.params.temperature)
122
  self.mrt.topk = int(self.params.topk)
123
 
 
 
124
  # codec/setup
125
  self._codec_fps = float(self.mrt.codec.frame_rate)
126
  JamWorker.FRAMES_PER_SECOND = self._codec_fps
@@ -146,9 +137,8 @@ class JamWorker(threading.Thread):
146
  self._spool = np.zeros((0, 2), dtype=np.float32) # (S,2) target SR
147
  self._spool_written = 0 # absolute frames written into spool
148
 
149
- # Health monitoring
150
- self._silence_streak = 0 # consecutive silent chunks
151
- self._last_good_context_tokens = None # backup of last known good context
152
 
153
  # bar clock: start with offset 0; if you have a downbeat estimator, set base later
154
  self._bar_clock = BarClock(self.params.target_sr, self.params.bpm, self.params.beats_per_bar, base_offset_samples=0)
@@ -173,47 +163,6 @@ class JamWorker(threading.Thread):
173
  # Prepare initial context from combined loop (best musical alignment)
174
  if self.params.combined_loop is not None:
175
  self._install_context_from_loop(self.params.combined_loop)
176
- # Save this as our "good" context backup
177
- if hasattr(self.state, 'context_tokens') and self.state.context_tokens is not None:
178
- self._last_good_context_tokens = np.copy(self.state.context_tokens)
179
-
180
- # ---------- NEW: Health monitoring methods ----------
181
-
182
- def _check_model_health(self, new_chunk: np.ndarray) -> bool:
183
- """Check if the model output looks healthy."""
184
- if _is_silent(new_chunk, threshold_db=-80.0):
185
- self._silence_streak += 1
186
- print(f"⚠️ Silent chunk detected (streak: {self._silence_streak})")
187
- return False
188
- else:
189
- if self._silence_streak > 0:
190
- print(f"✅ Audio resumed after {self._silence_streak} silent chunks")
191
- self._silence_streak = 0
192
- return True
193
-
194
- def _recover_from_silence(self):
195
- """Attempt to recover from silence by restoring last good context."""
196
- print("🔧 Attempting recovery from silence...")
197
-
198
- if self._last_good_context_tokens is not None:
199
- # Restore last known good context
200
- try:
201
- new_state = self.mrt.init_state()
202
- new_state.context_tokens = np.copy(self._last_good_context_tokens)
203
- self.state = new_state
204
- self._model_stream = None # Reset stream to start fresh
205
- print(" Restored last good context")
206
- except Exception as e:
207
- print(f" Context restoration failed: {e}")
208
-
209
- # If we have the original loop, rebuild context from it
210
- elif self.params.combined_loop is not None:
211
- try:
212
- self._install_context_from_loop(self.params.combined_loop)
213
- self._model_stream = None
214
- print(" Rebuilt context from original loop")
215
- except Exception as e:
216
- print(f" Context rebuild failed: {e}")
217
 
218
  # ---------- lifecycle ----------
219
 
@@ -299,7 +248,13 @@ class JamWorker(threading.Thread):
299
  return toks
300
 
301
  def _encode_exact_context_tokens(self, loop: au.Waveform) -> np.ndarray:
302
- """Build *exactly* context_length_frames worth of tokens, ensuring bar alignment."""
 
 
 
 
 
 
303
  wav = loop.as_stereo().resample(self._model_sr)
304
  data = wav.samples.astype(np.float32, copy=False)
305
  if data.ndim == 1:
@@ -334,14 +289,8 @@ class JamWorker(threading.Thread):
334
 
335
  # final snap to *exact* ctx samples
336
  if ctx.shape[0] < ctx_samps:
337
- # Instead of zero padding, repeat the audio to fill
338
- shortfall = ctx_samps - ctx.shape[0]
339
- if ctx.shape[0] > 0:
340
- fill = np.tile(ctx, (int(np.ceil(shortfall / ctx.shape[0])) + 1, 1))[:shortfall]
341
- ctx = np.concatenate([fill, ctx], axis=0)
342
- else:
343
- print("⚠️ Zero-length context, using fallback")
344
- ctx = np.zeros((ctx_samps, 2), dtype=np.float32)
345
  elif ctx.shape[0] > ctx_samps:
346
  ctx = ctx[-ctx_samps:]
347
 
@@ -352,20 +301,79 @@ class JamWorker(threading.Thread):
352
 
353
  # Force expected (F,D) at *return time*
354
  tokens = self._coerce_tokens(tokens)
355
-
356
- # Validate that we don't have a silent context
357
- if _is_silent(ctx, threshold_db=-80.0):
358
- print("⚠️ Generated silent context - this may cause issues")
359
-
360
  return tokens
361
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
362
  def _install_context_from_loop(self, loop: au.Waveform):
363
  # Build exact-length, bar-locked context tokens
364
  context_tokens = self._encode_exact_context_tokens(loop)
365
  s = self.mrt.init_state()
366
  s.context_tokens = context_tokens
367
  self.state = s
368
- self._last_good_context_tokens = np.copy(context_tokens)
369
 
370
  def reseed_from_waveform(self, wav: au.Waveform):
371
  """Immediate reseed: replace context from provided wave (bar-locked, exact length)."""
@@ -375,11 +383,14 @@ class JamWorker(threading.Thread):
375
  s.context_tokens = context_tokens
376
  self.state = s
377
  self._model_stream = None # drop model-domain continuity so next chunk starts cleanly
378
- self._last_good_context_tokens = np.copy(context_tokens)
379
- self._silence_streak = 0 # Reset health monitoring
380
 
381
  def reseed_splice(self, recent_wav: au.Waveform, anchor_bars: float):
382
- """Queue a *seamless* reseed by token splicing instead of full restart."""
 
 
 
 
383
  new_ctx = self._encode_exact_context_tokens(recent_wav) # coerce to (F,D)
384
  F, D = self._expected_token_shape()
385
 
@@ -408,20 +419,44 @@ class JamWorker(threading.Thread):
408
  "tokens": spliced,
409
  "debug": {"F": F, "D": D, "splice_frames": splice_frames, "frames_per_bar": frames_per_bar}
410
  }
 
411
 
412
- # ---------- REWRITTEN: core streaming helpers ----------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
413
 
414
  def _append_model_chunk_and_spool(self, wav: au.Waveform) -> None:
415
  """
416
- REWRITTEN: Robust audio processing that rejects silent chunks entirely.
417
-
418
- Strategy:
419
- 1. Validate input chunk for silence/issues
420
- 2. REJECT silent chunks - don't add them to spool or model stream
421
- 3. Use healthy crossfading only between good audio
422
- 4. Aggressive recovery when silence detected
 
 
423
  """
424
- # Unpack model-rate samples
 
425
  s = wav.samples.astype(np.float32, copy=False)
426
  if s.ndim == 1:
427
  s = s[:, None]
@@ -429,103 +464,119 @@ class JamWorker(threading.Thread):
429
  if n_samps == 0:
430
  return
431
 
432
- # Health check on new chunk - use stricter threshold
433
- is_healthy = self._check_model_health(s)
434
- is_very_quiet = _is_silent(s, threshold_db=-50.0) # stricter than default -60
435
-
436
- # Get crossfade params
437
  try:
438
  xfade_s = float(self.mrt.config.crossfade_length)
439
  except Exception:
440
  xfade_s = 0.0
441
  xfade_n = int(round(max(0.0, xfade_s) * float(self._model_sr)))
442
 
443
- print(f"[model] chunk len={n_samps} rms={_dbg_rms_dbfs_model(s):+.1f} dBFS healthy={is_healthy} quiet={is_very_quiet}")
444
-
445
- # --- REJECT PROBLEMATIC CHUNKS ---
446
- if not is_healthy or is_very_quiet:
447
- print(f"[REJECT] Discarding unhealthy/quiet chunk - not adding to spool or model stream")
448
-
449
- # Trigger recovery immediately on first bad chunk
450
- if self._silence_streak >= 1:
451
- self._recover_from_silence()
452
-
453
- # Don't process this chunk at all - return early
454
- return
455
-
456
- # Reset silence streak on good chunk
457
- if self._silence_streak > 0:
458
- print(f"✅ Audio resumed after {self._silence_streak} rejected chunks")
459
- self._silence_streak = 0
460
-
461
- # Helper: resample to target SR
462
  def to_target(y: np.ndarray) -> np.ndarray:
463
  return y if self._rs is None else self._rs.process(y, final=False)
464
 
465
- # --- SIMPLIFIED CROSSFADE LOGIC (only for healthy audio) ---
466
-
467
- if self._model_stream is None:
468
- # First chunk - no crossfading needed
469
- self._model_stream = s.copy()
470
-
471
- elif xfade_n <= 0 or n_samps < xfade_n:
472
- # No crossfade configured or chunk too short - simple append
473
- self._model_stream = np.concatenate([self._model_stream, s], axis=0)
474
-
475
- elif _is_silent(self._model_stream[-xfade_n:], threshold_db=-50.0):
476
- # Previous tail is quiet - don't crossfade, just replace
477
- print(f"[crossfade] Replacing quiet tail with new audio")
478
- # Remove quiet tail and append new chunk
479
- self._model_stream = np.concatenate([self._model_stream[:-xfade_n], s], axis=0)
480
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
481
  else:
482
- # Normal crossfade between healthy audio
483
- tail = self._model_stream[-xfade_n:]
484
- head = s[:xfade_n]
485
- body = s[xfade_n:] if n_samps > xfade_n else np.zeros((0, s.shape[1]), dtype=np.float32)
486
-
487
- # Equal power crossfade
488
- t = np.linspace(0.0, 1.0, xfade_n, dtype=np.float32)[:, None]
489
- fade_out = np.cos(t * np.pi / 2.0)
490
- fade_in = np.sin(t * np.pi / 2.0)
491
-
492
- mixed = tail * fade_out + head * fade_in
493
-
494
- print(f"[crossfade] tail rms={_dbg_rms_dbfs_model(tail):+.1f} head rms={_dbg_rms_dbfs_model(head):+.1f} mixed rms={_dbg_rms_dbfs_model(mixed):+.1f}")
495
-
496
- # Update model stream: remove old tail, add mixed section, add body
497
- self._model_stream = np.concatenate([
498
- self._model_stream[:-xfade_n],
499
- mixed,
500
- body
501
- ], axis=0)
502
-
503
- # --- CONVERT AND APPEND TO SPOOL (only healthy audio reaches here) ---
504
-
505
- # Take the new audio from this iteration
 
 
 
 
 
 
 
 
 
 
 
 
506
  if xfade_n > 0 and n_samps >= xfade_n:
507
- # Normal case: body after crossfade region
508
- new_audio = s[xfade_n:] if n_samps > xfade_n else s
 
 
 
 
 
 
 
 
 
 
 
 
 
509
  else:
510
- # Short chunk or no crossfade: use entire chunk
511
- new_audio = s
512
-
513
- if new_audio.shape[0] > 0:
514
- target_audio = to_target(new_audio)
515
- if target_audio.shape[0] > 0:
516
- print(f"[append] body len={target_audio.shape[0]} rms={_dbg_rms_dbfs(target_audio):+.1f} dBFS")
517
- self._spool = np.concatenate([self._spool, target_audio], axis=0) if self._spool.size else target_audio
518
- self._spool_written += target_audio.shape[0]
519
-
520
- # --- SAVE GOOD CONTEXT ---
521
- # Only save context from healthy chunks
522
- if hasattr(self.state, 'context_tokens') and self.state.context_tokens is not None:
523
- self._last_good_context_tokens = np.copy(self.state.context_tokens)
524
-
525
- # Trim model stream to reasonable length (keep ~30 seconds)
526
- max_model_samples = int(30.0 * self._model_sr)
527
- if self._model_stream.shape[0] > max_model_samples:
528
- self._model_stream = self._model_stream[-max_model_samples:]
529
 
530
  def _should_generate_next_chunk(self) -> bool:
531
  # Allow running ahead relative to whichever is larger: last *consumed*
@@ -562,7 +613,6 @@ class JamWorker(threading.Thread):
562
  "guidance_weight": float(self.params.guidance_weight),
563
  "temperature": float(self.params.temperature),
564
  "topk": int(self.params.topk),
565
- "silence_streak": self._silence_streak, # Add health info
566
  }
567
  chunk = JamChunk(index=self.idx, audio_base64=audio_b64, metadata=meta)
568
 
@@ -587,7 +637,6 @@ class JamWorker(threading.Thread):
587
  # inplace update (no reset)
588
  self.state.context_tokens = spliced
589
  self._pending_token_splice = None
590
- print("[reseed] Token splice applied")
591
  except Exception:
592
  # fallback: full reseed using spliced tokens
593
  new_state = self.mrt.init_state()
@@ -595,7 +644,6 @@ class JamWorker(threading.Thread):
595
  self.state = new_state
596
  self._model_stream = None
597
  self._pending_token_splice = None
598
- print("[reseed] Token splice fallback to full reset")
599
  elif self._pending_reseed is not None:
600
  ctx = self._coerce_tokens(self._pending_reseed["ctx"])
601
  new_state = self.mrt.init_state()
@@ -603,7 +651,6 @@ class JamWorker(threading.Thread):
603
  self.state = new_state
604
  self._model_stream = None
605
  self._pending_reseed = None
606
- print("[reseed] Full reseed applied")
607
 
608
  # ---------- main loop ----------
609
 
@@ -640,10 +687,9 @@ class JamWorker(threading.Thread):
640
  self._emit_ready()
641
 
642
  # finalize resampler (flush) — not strictly necessary here
643
- if self._rs is not None:
644
- tail = self._rs.process(np.zeros((0,2), np.float32), final=True)
645
- if tail.size:
646
- self._spool = np.concatenate([self._spool, tail], axis=0)
647
- self._spool_written += tail.shape[0]
648
  # one last emit attempt
649
- self._emit_ready()
 
1
+ # jam_worker.py - Bar-locked spool rewrite
2
  from __future__ import annotations
3
 
4
  import os
 
20
  )
21
 
22
  def _dbg_rms_dbfs(x: np.ndarray) -> float:
23
+
24
  if x.ndim == 2:
25
  x = x.mean(axis=1)
26
  r = float(np.sqrt(np.mean(x * x) + 1e-12))
 
28
 
29
  def _dbg_rms_dbfs_model(x: np.ndarray) -> float:
30
  # x is model-rate, shape [S,C] or [S]
31
+
32
  if x.ndim == 2:
33
  x = x.mean(axis=1)
34
  r = float(np.sqrt(np.mean(x * x) + 1e-12))
 
37
  def _dbg_shape(x):
38
  return tuple(x.shape) if hasattr(x, "shape") else ("-",)
39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  # -----------------------------
41
  # Data classes
42
  # -----------------------------
 
55
  guidance_weight: float = 1.1
56
  temperature: float = 1.1
57
  topk: int = 40
58
+ style_ramp_seconds: float = 8.0 # 0 => instant (current behavior), try 6.0–10.0 for gentle glides
59
 
60
 
61
  @dataclass
 
110
  self.mrt.temperature = float(self.params.temperature)
111
  self.mrt.topk = int(self.params.topk)
112
 
113
+
114
+
115
  # codec/setup
116
  self._codec_fps = float(self.mrt.codec.frame_rate)
117
  JamWorker.FRAMES_PER_SECOND = self._codec_fps
 
137
  self._spool = np.zeros((0, 2), dtype=np.float32) # (S,2) target SR
138
  self._spool_written = 0 # absolute frames written into spool
139
 
140
+ self._pending_tail_model = None # type: Optional[np.ndarray] # last tail at model SR
141
+ self._pending_tail_target_len = 0 # number of target-SR samples last tail contributed
 
142
 
143
  # bar clock: start with offset 0; if you have a downbeat estimator, set base later
144
  self._bar_clock = BarClock(self.params.target_sr, self.params.bpm, self.params.beats_per_bar, base_offset_samples=0)
 
163
  # Prepare initial context from combined loop (best musical alignment)
164
  if self.params.combined_loop is not None:
165
  self._install_context_from_loop(self.params.combined_loop)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
 
167
  # ---------- lifecycle ----------
168
 
 
248
  return toks
249
 
250
  def _encode_exact_context_tokens(self, loop: au.Waveform) -> np.ndarray:
251
+ """Build *exactly* context_length_frames worth of tokens (e.g., 250 @ 25fps),
252
+ while ensuring the *end* of the audio lands on a bar boundary.
253
+ Strategy: take the largest integer number of bars <= ctx_seconds as the tail,
254
+ then left-fill from just before that tail (wrapping if needed) to reach exactly
255
+ ctx_seconds; finally, pad/trim to exact samples and, as a last resort, pad/trim
256
+ tokens to the expected frame count.
257
+ """
258
  wav = loop.as_stereo().resample(self._model_sr)
259
  data = wav.samples.astype(np.float32, copy=False)
260
  if data.ndim == 1:
 
289
 
290
  # final snap to *exact* ctx samples
291
  if ctx.shape[0] < ctx_samps:
292
+ pad = np.zeros((ctx_samps - ctx.shape[0], ctx.shape[1]), dtype=np.float32)
293
+ ctx = np.concatenate([pad, ctx], axis=0)
 
 
 
 
 
 
294
  elif ctx.shape[0] > ctx_samps:
295
  ctx = ctx[-ctx_samps:]
296
 
 
301
 
302
  # Force expected (F,D) at *return time*
303
  tokens = self._coerce_tokens(tokens)
 
 
 
 
 
304
  return tokens
305
 
306
+ def _encode_exact_context_tokens(self, loop: au.Waveform) -> np.ndarray:
307
+ """Build *exactly* context_length_frames worth of tokens (e.g., 250 @ 25fps),
308
+ while ensuring the *end* of the audio lands on a bar boundary.
309
+ Strategy: take the largest integer number of bars <= ctx_seconds as the tail,
310
+ then left-fill from just before that tail (wrapping if needed) to reach exactly
311
+ ctx_seconds; finally, pad/trim to exact samples and, as a last resort, pad/trim
312
+ tokens to the expected frame count.
313
+ """
314
+ wav = loop.as_stereo().resample(self._model_sr)
315
+ data = wav.samples.astype(np.float32, copy=False)
316
+ if data.ndim == 1:
317
+ data = data[:, None]
318
+
319
+ spb = self._bar_clock.seconds_per_bar()
320
+ ctx_sec = float(self._ctx_seconds)
321
+ sr = int(self._model_sr)
322
+
323
+ # bars that fit fully inside ctx_sec (at least 1)
324
+ bars_fit = max(1, int(ctx_sec // spb))
325
+ tail_len_samps = int(round(bars_fit * spb * sr))
326
+
327
+ # ensure we have enough source by tiling
328
+ need = int(round(ctx_sec * sr)) + tail_len_samps
329
+ if data.shape[0] == 0:
330
+ data = np.zeros((1, 2), dtype=np.float32)
331
+ reps = int(np.ceil(need / float(data.shape[0])))
332
+ tiled = np.tile(data, (reps, 1))
333
+
334
+ end = tiled.shape[0]
335
+ tail = tiled[end - tail_len_samps:end]
336
+
337
+ # left-fill to reach exact ctx samples (keeps end-of-bar alignment)
338
+ ctx_samps = int(round(ctx_sec * sr))
339
+ pad_len = ctx_samps - tail.shape[0]
340
+ if pad_len > 0:
341
+ pre = tiled[end - tail_len_samps - pad_len:end - tail_len_samps]
342
+ ctx = np.concatenate([pre, tail], axis=0)
343
+ else:
344
+ ctx = tail[-ctx_samps:]
345
+
346
+ # final snap to *exact* ctx samples
347
+ if ctx.shape[0] < ctx_samps:
348
+ pad = np.zeros((ctx_samps - ctx.shape[0], ctx.shape[1]), dtype=np.float32)
349
+ ctx = np.concatenate([pad, ctx], axis=0)
350
+ elif ctx.shape[0] > ctx_samps:
351
+ ctx = ctx[-ctx_samps:]
352
+
353
+ exact = au.Waveform(ctx, sr)
354
+ tokens_full = self.mrt.codec.encode(exact).astype(np.int32)
355
+ depth = int(self.mrt.config.decoder_codec_rvq_depth)
356
+ tokens = tokens_full[:, :depth]
357
+
358
+ # Last defense: force expected frame count
359
+ frames = tokens.shape[0]
360
+ exp = int(self._ctx_frames)
361
+ if frames < exp:
362
+ # repeat last frame
363
+ pad = np.repeat(tokens[-1:, :], exp - frames, axis=0)
364
+ tokens = np.concatenate([pad, tokens], axis=0)
365
+ elif frames > exp:
366
+ tokens = tokens[-exp:, :]
367
+ return tokens
368
+
369
+
370
  def _install_context_from_loop(self, loop: au.Waveform):
371
  # Build exact-length, bar-locked context tokens
372
  context_tokens = self._encode_exact_context_tokens(loop)
373
  s = self.mrt.init_state()
374
  s.context_tokens = context_tokens
375
  self.state = s
376
+ self._original_context_tokens = np.copy(context_tokens)
377
 
378
  def reseed_from_waveform(self, wav: au.Waveform):
379
  """Immediate reseed: replace context from provided wave (bar-locked, exact length)."""
 
383
  s.context_tokens = context_tokens
384
  self.state = s
385
  self._model_stream = None # drop model-domain continuity so next chunk starts cleanly
386
+ self._original_context_tokens = np.copy(context_tokens)
 
387
 
388
  def reseed_splice(self, recent_wav: au.Waveform, anchor_bars: float):
389
+ """Queue a *seamless* reseed by token splicing instead of full restart.
390
+ We compute a fresh, bar-locked context token tensor of exact length
391
+ (e.g., 250 frames), then splice only the *tail* corresponding to
392
+ `anchor_bars` so generation continues smoothly without resetting state.
393
+ """
394
  new_ctx = self._encode_exact_context_tokens(recent_wav) # coerce to (F,D)
395
  F, D = self._expected_token_shape()
396
 
 
419
  "tokens": spliced,
420
  "debug": {"F": F, "D": D, "splice_frames": splice_frames, "frames_per_bar": frames_per_bar}
421
  }
422
+
423
 
424
+
425
+ def reseed_from_waveform(self, wav: au.Waveform):
426
+ """Immediate reseed: replace context from provided wave (bar-aligned tail)."""
427
+ wav = wav.as_stereo().resample(self._model_sr)
428
+ tail = take_bar_aligned_tail(wav, self.params.bpm, self.params.beats_per_bar, self._ctx_seconds)
429
+ tokens_full = self.mrt.codec.encode(tail).astype(np.int32)
430
+ depth = int(self.mrt.config.decoder_codec_rvq_depth)
431
+ context_tokens = tokens_full[:, :depth]
432
+
433
+ s = self.mrt.init_state()
434
+ s.context_tokens = context_tokens
435
+ self.state = s
436
+ # reset model stream so next generate starts cleanly
437
+ self._model_stream = None
438
+
439
+ # optional loudness match will be applied per-chunk on emission
440
+
441
+ # also remember this as new "original"
442
+ self._original_context_tokens = np.copy(context_tokens)
443
+
444
+ # ---------- core streaming helpers ----------
445
 
446
  def _append_model_chunk_and_spool(self, wav: au.Waveform) -> None:
447
  """
448
+ Conservative boundary fix:
449
+ - Emit body+tail immediately (target SR), unchanged from your original behavior.
450
+ - On *next* call, compute the mixed overlap (prev tail ⨉ cos + new head ⨉ sin),
451
+ resample it, and overwrite the last `_pending_tail_target_len` samples in the
452
+ target-SR spool with that mixed overlap. Then emit THIS chunk's body+tail and
453
+ remember THIS chunk's tail length at target SR for the next correction.
454
+
455
+ This keeps external timing and bar alignment identical, but removes the audible
456
+ fade-to-zero at chunk ends.
457
  """
458
+
459
+ # ---- unpack model-rate samples ----
460
  s = wav.samples.astype(np.float32, copy=False)
461
  if s.ndim == 1:
462
  s = s[:, None]
 
464
  if n_samps == 0:
465
  return
466
 
467
+ # crossfade length in model samples
 
 
 
 
468
  try:
469
  xfade_s = float(self.mrt.config.crossfade_length)
470
  except Exception:
471
  xfade_s = 0.0
472
  xfade_n = int(round(max(0.0, xfade_s) * float(self._model_sr)))
473
 
474
+ # helper: resample to target SR via your streaming resampler
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
475
  def to_target(y: np.ndarray) -> np.ndarray:
476
  return y if self._rs is None else self._rs.process(y, final=False)
477
 
478
+ # ------------------------------------------
479
+ # (A) If we have a pending model tail, fix the last emitted tail at target SR
480
+ # ------------------------------------------
481
+ if self._pending_tail_model is not None and self._pending_tail_model.shape[0] == xfade_n and xfade_n > 0 and n_samps >= xfade_n:
482
+ head = s[:xfade_n, :]
483
+
484
+ print(f"[model] head len={head.shape[0]} rms={_dbg_rms_dbfs_model(head):+.1f} dBFS")
485
+
486
+ t = np.linspace(0.0, np.pi/2.0, xfade_n, endpoint=False, dtype=np.float32)[:, None]
487
+ cosw = np.cos(t, dtype=np.float32)
488
+ sinw = np.sin(t, dtype=np.float32)
489
+ mixed_model = (self._pending_tail_model * cosw) + (head * sinw) # [xfade_n, C] at model SR
490
+
491
+ y_mixed = to_target(mixed_model.astype(np.float32))
492
+ Lcorr = int(y_mixed.shape[0]) # exact target-SR samples to write
493
+
494
+ # DEBUG: corrected overlap RMS (what we intend to hear at the boundary)
495
+ if y_mixed.size:
496
+ print(f"[append] mixedOverlap len={y_mixed.shape[0]} rms={_dbg_rms_dbfs(y_mixed):+.1f} dBFS")
497
+
498
+ # Overwrite the last `_pending_tail_target_len` samples of the spool with `y_mixed`.
499
+ # Use the *smaller* of the two lengths to be safe.
500
+ Lpop = min(self._pending_tail_target_len, self._spool.shape[0], Lcorr)
501
+ if Lpop > 0 and self._spool.size:
502
+ # Trim last Lpop samples
503
+ self._spool = self._spool[:-Lpop, :]
504
+ self._spool_written -= Lpop
505
+ # Append corrected overlap (trim/pad to Lpop to avoid drift)
506
+ if Lcorr != Lpop:
507
+ if Lcorr > Lpop:
508
+ y_m = y_mixed[-Lpop:, :]
509
+ else:
510
+ pad = np.zeros((Lpop - Lcorr, y_mixed.shape[1]), dtype=np.float32)
511
+ y_m = np.concatenate([y_mixed, pad], axis=0)
512
+ else:
513
+ y_m = y_mixed
514
+ self._spool = np.concatenate([self._spool, y_m], axis=0) if self._spool.size else y_m
515
+ self._spool_written += y_m.shape[0]
516
+
517
+ # For internal continuity, update _model_stream like before
518
+ if self._model_stream is None or self._model_stream.shape[0] < xfade_n:
519
+ self._model_stream = s[xfade_n:].copy()
520
+ else:
521
+ self._model_stream = np.concatenate([self._model_stream[:-xfade_n], mixed_model, s[xfade_n:]], axis=0)
522
  else:
523
+ # First-ever call or too-short to mix: maintain _model_stream minimally
524
+ if xfade_n > 0 and n_samps > xfade_n:
525
+ self._model_stream = s[xfade_n:].copy() if self._model_stream is None else np.concatenate([self._model_stream, s[xfade_n:]], axis=0)
526
+ else:
527
+ self._model_stream = s.copy() if self._model_stream is None else np.concatenate([self._model_stream, s], axis=0)
528
+
529
+ # ------------------------------------------
530
+ # (B) Emit THIS chunk's body and tail (same external behavior)
531
+ # ------------------------------------------
532
+ if xfade_n > 0 and n_samps >= (2 * xfade_n):
533
+ body = s[xfade_n:-xfade_n, :]
534
+ print(f"[model] body len={body.shape[0]} rms={_dbg_rms_dbfs_model(body):+.1f} dBFS")
535
+ if body.size:
536
+ y_body = to_target(body.astype(np.float32))
537
+ if y_body.size:
538
+ # DEBUG: body RMS we are actually appending
539
+ print(f"[append] body len={y_body.shape[0]} rms={_dbg_rms_dbfs(y_body):+.1f} dBFS")
540
+ self._spool = np.concatenate([self._spool, y_body], axis=0) if self._spool.size else y_body
541
+ self._spool_written += y_body.shape[0]
542
+ else:
543
+ # If chunk too short for head+tail split, treat all (minus preroll) as body
544
+ if xfade_n > 0 and n_samps > xfade_n:
545
+ body = s[xfade_n:, :]
546
+ print(f"[model] body(S) len={body.shape[0]} rms={_dbg_rms_dbfs_model(body):+.1f} dBFS")
547
+ y_body = to_target(body.astype(np.float32))
548
+ if y_body.size:
549
+ # DEBUG: body RMS in short-chunk path
550
+ print(f"[append] body(len=short) len={y_body.shape[0]} rms={_dbg_rms_dbfs(y_body):+.1f} dBFS")
551
+ self._spool = np.concatenate([self._spool, y_body], axis=0) if self._spool.size else y_body
552
+ self._spool_written += y_body.shape[0]
553
+ # No tail to remember this round
554
+ self._pending_tail_model = None
555
+ self._pending_tail_target_len = 0
556
+ return
557
+
558
+ # Tail (always remember how many TARGET samples we append)
559
  if xfade_n > 0 and n_samps >= xfade_n:
560
+ tail = s[-xfade_n:, :]
561
+ print(f"[model] tail len={tail.shape[0]} rms={_dbg_rms_dbfs_model(tail):+.1f} dBFS")
562
+ y_tail = to_target(tail.astype(np.float32))
563
+ Ltail = int(y_tail.shape[0])
564
+ if Ltail:
565
+ # DEBUG: tail RMS we are appending now (to be corrected next call)
566
+ print(f"[append] tail len={y_tail.shape[0]} rms={_dbg_rms_dbfs(y_tail):+.1f} dBFS")
567
+ self._spool = np.concatenate([self._spool, y_tail], axis=0) if self._spool.size else y_tail
568
+ self._spool_written += Ltail
569
+ self._pending_tail_model = tail.copy()
570
+ self._pending_tail_target_len = Ltail
571
+ else:
572
+ # Nothing appended (resampler returned nothing yet) — keep model tail but mark zero target len
573
+ self._pending_tail_model = tail.copy()
574
+ self._pending_tail_target_len = 0
575
  else:
576
+ self._pending_tail_model = None
577
+ self._pending_tail_target_len = 0
578
+
579
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
580
 
581
  def _should_generate_next_chunk(self) -> bool:
582
  # Allow running ahead relative to whichever is larger: last *consumed*
 
613
  "guidance_weight": float(self.params.guidance_weight),
614
  "temperature": float(self.params.temperature),
615
  "topk": int(self.params.topk),
 
616
  }
617
  chunk = JamChunk(index=self.idx, audio_base64=audio_b64, metadata=meta)
618
 
 
637
  # inplace update (no reset)
638
  self.state.context_tokens = spliced
639
  self._pending_token_splice = None
 
640
  except Exception:
641
  # fallback: full reseed using spliced tokens
642
  new_state = self.mrt.init_state()
 
644
  self.state = new_state
645
  self._model_stream = None
646
  self._pending_token_splice = None
 
647
  elif self._pending_reseed is not None:
648
  ctx = self._coerce_tokens(self._pending_reseed["ctx"])
649
  new_state = self.mrt.init_state()
 
651
  self.state = new_state
652
  self._model_stream = None
653
  self._pending_reseed = None
 
654
 
655
  # ---------- main loop ----------
656
 
 
687
  self._emit_ready()
688
 
689
  # finalize resampler (flush) — not strictly necessary here
690
+ tail = self._rs.process(np.zeros((0,2), np.float32), final=True)
691
+ if tail.size:
692
+ self._spool = np.concatenate([self._spool, tail], axis=0)
693
+ self._spool_written += tail.shape[0]
 
694
  # one last emit attempt
695
+ self._emit_ready()