Commit
·
169ed8c
1
Parent(s):
e0bae41
reverted
Browse files- jam_worker.py +227 -181
jam_worker.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
# jam_worker.py -
|
2 |
from __future__ import annotations
|
3 |
|
4 |
import os
|
@@ -20,6 +20,7 @@ from utils import (
|
|
20 |
)
|
21 |
|
22 |
def _dbg_rms_dbfs(x: np.ndarray) -> float:
|
|
|
23 |
if x.ndim == 2:
|
24 |
x = x.mean(axis=1)
|
25 |
r = float(np.sqrt(np.mean(x * x) + 1e-12))
|
@@ -27,6 +28,7 @@ def _dbg_rms_dbfs(x: np.ndarray) -> float:
|
|
27 |
|
28 |
def _dbg_rms_dbfs_model(x: np.ndarray) -> float:
|
29 |
# x is model-rate, shape [S,C] or [S]
|
|
|
30 |
if x.ndim == 2:
|
31 |
x = x.mean(axis=1)
|
32 |
r = float(np.sqrt(np.mean(x * x) + 1e-12))
|
@@ -35,19 +37,6 @@ def _dbg_rms_dbfs_model(x: np.ndarray) -> float:
|
|
35 |
def _dbg_shape(x):
|
36 |
return tuple(x.shape) if hasattr(x, "shape") else ("-",)
|
37 |
|
38 |
-
def _is_silent(audio: np.ndarray, threshold_db: float = -60.0) -> bool:
|
39 |
-
"""Check if audio is effectively silent."""
|
40 |
-
if audio.size == 0:
|
41 |
-
return True
|
42 |
-
if audio.ndim == 2:
|
43 |
-
audio = audio.mean(axis=1)
|
44 |
-
rms = float(np.sqrt(np.mean(audio**2)))
|
45 |
-
return 20.0 * np.log10(max(rms, 1e-12)) < threshold_db
|
46 |
-
|
47 |
-
def _has_energy(audio: np.ndarray, threshold_db: float = -40.0) -> bool:
|
48 |
-
"""Check if audio has significant energy (stricter than just non-silent)."""
|
49 |
-
return not _is_silent(audio, threshold_db)
|
50 |
-
|
51 |
# -----------------------------
|
52 |
# Data classes
|
53 |
# -----------------------------
|
@@ -66,7 +55,7 @@ class JamParams:
|
|
66 |
guidance_weight: float = 1.1
|
67 |
temperature: float = 1.1
|
68 |
topk: int = 40
|
69 |
-
style_ramp_seconds: float = 8.0
|
70 |
|
71 |
|
72 |
@dataclass
|
@@ -121,6 +110,8 @@ class JamWorker(threading.Thread):
|
|
121 |
self.mrt.temperature = float(self.params.temperature)
|
122 |
self.mrt.topk = int(self.params.topk)
|
123 |
|
|
|
|
|
124 |
# codec/setup
|
125 |
self._codec_fps = float(self.mrt.codec.frame_rate)
|
126 |
JamWorker.FRAMES_PER_SECOND = self._codec_fps
|
@@ -146,9 +137,8 @@ class JamWorker(threading.Thread):
|
|
146 |
self._spool = np.zeros((0, 2), dtype=np.float32) # (S,2) target SR
|
147 |
self._spool_written = 0 # absolute frames written into spool
|
148 |
|
149 |
-
#
|
150 |
-
self.
|
151 |
-
self._last_good_context_tokens = None # backup of last known good context
|
152 |
|
153 |
# bar clock: start with offset 0; if you have a downbeat estimator, set base later
|
154 |
self._bar_clock = BarClock(self.params.target_sr, self.params.bpm, self.params.beats_per_bar, base_offset_samples=0)
|
@@ -173,47 +163,6 @@ class JamWorker(threading.Thread):
|
|
173 |
# Prepare initial context from combined loop (best musical alignment)
|
174 |
if self.params.combined_loop is not None:
|
175 |
self._install_context_from_loop(self.params.combined_loop)
|
176 |
-
# Save this as our "good" context backup
|
177 |
-
if hasattr(self.state, 'context_tokens') and self.state.context_tokens is not None:
|
178 |
-
self._last_good_context_tokens = np.copy(self.state.context_tokens)
|
179 |
-
|
180 |
-
# ---------- NEW: Health monitoring methods ----------
|
181 |
-
|
182 |
-
def _check_model_health(self, new_chunk: np.ndarray) -> bool:
|
183 |
-
"""Check if the model output looks healthy."""
|
184 |
-
if _is_silent(new_chunk, threshold_db=-80.0):
|
185 |
-
self._silence_streak += 1
|
186 |
-
print(f"⚠️ Silent chunk detected (streak: {self._silence_streak})")
|
187 |
-
return False
|
188 |
-
else:
|
189 |
-
if self._silence_streak > 0:
|
190 |
-
print(f"✅ Audio resumed after {self._silence_streak} silent chunks")
|
191 |
-
self._silence_streak = 0
|
192 |
-
return True
|
193 |
-
|
194 |
-
def _recover_from_silence(self):
|
195 |
-
"""Attempt to recover from silence by restoring last good context."""
|
196 |
-
print("🔧 Attempting recovery from silence...")
|
197 |
-
|
198 |
-
if self._last_good_context_tokens is not None:
|
199 |
-
# Restore last known good context
|
200 |
-
try:
|
201 |
-
new_state = self.mrt.init_state()
|
202 |
-
new_state.context_tokens = np.copy(self._last_good_context_tokens)
|
203 |
-
self.state = new_state
|
204 |
-
self._model_stream = None # Reset stream to start fresh
|
205 |
-
print(" Restored last good context")
|
206 |
-
except Exception as e:
|
207 |
-
print(f" Context restoration failed: {e}")
|
208 |
-
|
209 |
-
# If we have the original loop, rebuild context from it
|
210 |
-
elif self.params.combined_loop is not None:
|
211 |
-
try:
|
212 |
-
self._install_context_from_loop(self.params.combined_loop)
|
213 |
-
self._model_stream = None
|
214 |
-
print(" Rebuilt context from original loop")
|
215 |
-
except Exception as e:
|
216 |
-
print(f" Context rebuild failed: {e}")
|
217 |
|
218 |
# ---------- lifecycle ----------
|
219 |
|
@@ -299,7 +248,13 @@ class JamWorker(threading.Thread):
|
|
299 |
return toks
|
300 |
|
301 |
def _encode_exact_context_tokens(self, loop: au.Waveform) -> np.ndarray:
|
302 |
-
"""Build *exactly* context_length_frames worth of tokens
|
|
|
|
|
|
|
|
|
|
|
|
|
303 |
wav = loop.as_stereo().resample(self._model_sr)
|
304 |
data = wav.samples.astype(np.float32, copy=False)
|
305 |
if data.ndim == 1:
|
@@ -334,14 +289,8 @@ class JamWorker(threading.Thread):
|
|
334 |
|
335 |
# final snap to *exact* ctx samples
|
336 |
if ctx.shape[0] < ctx_samps:
|
337 |
-
|
338 |
-
|
339 |
-
if ctx.shape[0] > 0:
|
340 |
-
fill = np.tile(ctx, (int(np.ceil(shortfall / ctx.shape[0])) + 1, 1))[:shortfall]
|
341 |
-
ctx = np.concatenate([fill, ctx], axis=0)
|
342 |
-
else:
|
343 |
-
print("⚠️ Zero-length context, using fallback")
|
344 |
-
ctx = np.zeros((ctx_samps, 2), dtype=np.float32)
|
345 |
elif ctx.shape[0] > ctx_samps:
|
346 |
ctx = ctx[-ctx_samps:]
|
347 |
|
@@ -352,20 +301,79 @@ class JamWorker(threading.Thread):
|
|
352 |
|
353 |
# Force expected (F,D) at *return time*
|
354 |
tokens = self._coerce_tokens(tokens)
|
355 |
-
|
356 |
-
# Validate that we don't have a silent context
|
357 |
-
if _is_silent(ctx, threshold_db=-80.0):
|
358 |
-
print("⚠️ Generated silent context - this may cause issues")
|
359 |
-
|
360 |
return tokens
|
361 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
362 |
def _install_context_from_loop(self, loop: au.Waveform):
|
363 |
# Build exact-length, bar-locked context tokens
|
364 |
context_tokens = self._encode_exact_context_tokens(loop)
|
365 |
s = self.mrt.init_state()
|
366 |
s.context_tokens = context_tokens
|
367 |
self.state = s
|
368 |
-
self.
|
369 |
|
370 |
def reseed_from_waveform(self, wav: au.Waveform):
|
371 |
"""Immediate reseed: replace context from provided wave (bar-locked, exact length)."""
|
@@ -375,11 +383,14 @@ class JamWorker(threading.Thread):
|
|
375 |
s.context_tokens = context_tokens
|
376 |
self.state = s
|
377 |
self._model_stream = None # drop model-domain continuity so next chunk starts cleanly
|
378 |
-
self.
|
379 |
-
self._silence_streak = 0 # Reset health monitoring
|
380 |
|
381 |
def reseed_splice(self, recent_wav: au.Waveform, anchor_bars: float):
|
382 |
-
"""Queue a *seamless* reseed by token splicing instead of full restart.
|
|
|
|
|
|
|
|
|
383 |
new_ctx = self._encode_exact_context_tokens(recent_wav) # coerce to (F,D)
|
384 |
F, D = self._expected_token_shape()
|
385 |
|
@@ -408,20 +419,44 @@ class JamWorker(threading.Thread):
|
|
408 |
"tokens": spliced,
|
409 |
"debug": {"F": F, "D": D, "splice_frames": splice_frames, "frames_per_bar": frames_per_bar}
|
410 |
}
|
|
|
411 |
|
412 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
413 |
|
414 |
def _append_model_chunk_and_spool(self, wav: au.Waveform) -> None:
|
415 |
"""
|
416 |
-
|
417 |
-
|
418 |
-
|
419 |
-
|
420 |
-
|
421 |
-
|
422 |
-
|
|
|
|
|
423 |
"""
|
424 |
-
|
|
|
425 |
s = wav.samples.astype(np.float32, copy=False)
|
426 |
if s.ndim == 1:
|
427 |
s = s[:, None]
|
@@ -429,103 +464,119 @@ class JamWorker(threading.Thread):
|
|
429 |
if n_samps == 0:
|
430 |
return
|
431 |
|
432 |
-
#
|
433 |
-
is_healthy = self._check_model_health(s)
|
434 |
-
is_very_quiet = _is_silent(s, threshold_db=-50.0) # stricter than default -60
|
435 |
-
|
436 |
-
# Get crossfade params
|
437 |
try:
|
438 |
xfade_s = float(self.mrt.config.crossfade_length)
|
439 |
except Exception:
|
440 |
xfade_s = 0.0
|
441 |
xfade_n = int(round(max(0.0, xfade_s) * float(self._model_sr)))
|
442 |
|
443 |
-
|
444 |
-
|
445 |
-
# --- REJECT PROBLEMATIC CHUNKS ---
|
446 |
-
if not is_healthy or is_very_quiet:
|
447 |
-
print(f"[REJECT] Discarding unhealthy/quiet chunk - not adding to spool or model stream")
|
448 |
-
|
449 |
-
# Trigger recovery immediately on first bad chunk
|
450 |
-
if self._silence_streak >= 1:
|
451 |
-
self._recover_from_silence()
|
452 |
-
|
453 |
-
# Don't process this chunk at all - return early
|
454 |
-
return
|
455 |
-
|
456 |
-
# Reset silence streak on good chunk
|
457 |
-
if self._silence_streak > 0:
|
458 |
-
print(f"✅ Audio resumed after {self._silence_streak} rejected chunks")
|
459 |
-
self._silence_streak = 0
|
460 |
-
|
461 |
-
# Helper: resample to target SR
|
462 |
def to_target(y: np.ndarray) -> np.ndarray:
|
463 |
return y if self._rs is None else self._rs.process(y, final=False)
|
464 |
|
465 |
-
#
|
466 |
-
|
467 |
-
|
468 |
-
|
469 |
-
|
470 |
-
|
471 |
-
|
472 |
-
|
473 |
-
|
474 |
-
|
475 |
-
|
476 |
-
|
477 |
-
|
478 |
-
|
479 |
-
|
480 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
481 |
else:
|
482 |
-
#
|
483 |
-
|
484 |
-
|
485 |
-
|
486 |
-
|
487 |
-
|
488 |
-
|
489 |
-
|
490 |
-
|
491 |
-
|
492 |
-
|
493 |
-
|
494 |
-
|
495 |
-
|
496 |
-
|
497 |
-
|
498 |
-
|
499 |
-
|
500 |
-
|
501 |
-
|
502 |
-
|
503 |
-
|
504 |
-
|
505 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
506 |
if xfade_n > 0 and n_samps >= xfade_n:
|
507 |
-
|
508 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
509 |
else:
|
510 |
-
|
511 |
-
|
512 |
-
|
513 |
-
|
514 |
-
target_audio = to_target(new_audio)
|
515 |
-
if target_audio.shape[0] > 0:
|
516 |
-
print(f"[append] body len={target_audio.shape[0]} rms={_dbg_rms_dbfs(target_audio):+.1f} dBFS")
|
517 |
-
self._spool = np.concatenate([self._spool, target_audio], axis=0) if self._spool.size else target_audio
|
518 |
-
self._spool_written += target_audio.shape[0]
|
519 |
-
|
520 |
-
# --- SAVE GOOD CONTEXT ---
|
521 |
-
# Only save context from healthy chunks
|
522 |
-
if hasattr(self.state, 'context_tokens') and self.state.context_tokens is not None:
|
523 |
-
self._last_good_context_tokens = np.copy(self.state.context_tokens)
|
524 |
-
|
525 |
-
# Trim model stream to reasonable length (keep ~30 seconds)
|
526 |
-
max_model_samples = int(30.0 * self._model_sr)
|
527 |
-
if self._model_stream.shape[0] > max_model_samples:
|
528 |
-
self._model_stream = self._model_stream[-max_model_samples:]
|
529 |
|
530 |
def _should_generate_next_chunk(self) -> bool:
|
531 |
# Allow running ahead relative to whichever is larger: last *consumed*
|
@@ -562,7 +613,6 @@ class JamWorker(threading.Thread):
|
|
562 |
"guidance_weight": float(self.params.guidance_weight),
|
563 |
"temperature": float(self.params.temperature),
|
564 |
"topk": int(self.params.topk),
|
565 |
-
"silence_streak": self._silence_streak, # Add health info
|
566 |
}
|
567 |
chunk = JamChunk(index=self.idx, audio_base64=audio_b64, metadata=meta)
|
568 |
|
@@ -587,7 +637,6 @@ class JamWorker(threading.Thread):
|
|
587 |
# inplace update (no reset)
|
588 |
self.state.context_tokens = spliced
|
589 |
self._pending_token_splice = None
|
590 |
-
print("[reseed] Token splice applied")
|
591 |
except Exception:
|
592 |
# fallback: full reseed using spliced tokens
|
593 |
new_state = self.mrt.init_state()
|
@@ -595,7 +644,6 @@ class JamWorker(threading.Thread):
|
|
595 |
self.state = new_state
|
596 |
self._model_stream = None
|
597 |
self._pending_token_splice = None
|
598 |
-
print("[reseed] Token splice fallback to full reset")
|
599 |
elif self._pending_reseed is not None:
|
600 |
ctx = self._coerce_tokens(self._pending_reseed["ctx"])
|
601 |
new_state = self.mrt.init_state()
|
@@ -603,7 +651,6 @@ class JamWorker(threading.Thread):
|
|
603 |
self.state = new_state
|
604 |
self._model_stream = None
|
605 |
self._pending_reseed = None
|
606 |
-
print("[reseed] Full reseed applied")
|
607 |
|
608 |
# ---------- main loop ----------
|
609 |
|
@@ -640,10 +687,9 @@ class JamWorker(threading.Thread):
|
|
640 |
self._emit_ready()
|
641 |
|
642 |
# finalize resampler (flush) — not strictly necessary here
|
643 |
-
|
644 |
-
|
645 |
-
|
646 |
-
|
647 |
-
self._spool_written += tail.shape[0]
|
648 |
# one last emit attempt
|
649 |
-
self._emit_ready()
|
|
|
1 |
+
# jam_worker.py - Bar-locked spool rewrite
|
2 |
from __future__ import annotations
|
3 |
|
4 |
import os
|
|
|
20 |
)
|
21 |
|
22 |
def _dbg_rms_dbfs(x: np.ndarray) -> float:
|
23 |
+
|
24 |
if x.ndim == 2:
|
25 |
x = x.mean(axis=1)
|
26 |
r = float(np.sqrt(np.mean(x * x) + 1e-12))
|
|
|
28 |
|
29 |
def _dbg_rms_dbfs_model(x: np.ndarray) -> float:
|
30 |
# x is model-rate, shape [S,C] or [S]
|
31 |
+
|
32 |
if x.ndim == 2:
|
33 |
x = x.mean(axis=1)
|
34 |
r = float(np.sqrt(np.mean(x * x) + 1e-12))
|
|
|
37 |
def _dbg_shape(x):
|
38 |
return tuple(x.shape) if hasattr(x, "shape") else ("-",)
|
39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
# -----------------------------
|
41 |
# Data classes
|
42 |
# -----------------------------
|
|
|
55 |
guidance_weight: float = 1.1
|
56 |
temperature: float = 1.1
|
57 |
topk: int = 40
|
58 |
+
style_ramp_seconds: float = 8.0 # 0 => instant (current behavior), try 6.0–10.0 for gentle glides
|
59 |
|
60 |
|
61 |
@dataclass
|
|
|
110 |
self.mrt.temperature = float(self.params.temperature)
|
111 |
self.mrt.topk = int(self.params.topk)
|
112 |
|
113 |
+
|
114 |
+
|
115 |
# codec/setup
|
116 |
self._codec_fps = float(self.mrt.codec.frame_rate)
|
117 |
JamWorker.FRAMES_PER_SECOND = self._codec_fps
|
|
|
137 |
self._spool = np.zeros((0, 2), dtype=np.float32) # (S,2) target SR
|
138 |
self._spool_written = 0 # absolute frames written into spool
|
139 |
|
140 |
+
self._pending_tail_model = None # type: Optional[np.ndarray] # last tail at model SR
|
141 |
+
self._pending_tail_target_len = 0 # number of target-SR samples last tail contributed
|
|
|
142 |
|
143 |
# bar clock: start with offset 0; if you have a downbeat estimator, set base later
|
144 |
self._bar_clock = BarClock(self.params.target_sr, self.params.bpm, self.params.beats_per_bar, base_offset_samples=0)
|
|
|
163 |
# Prepare initial context from combined loop (best musical alignment)
|
164 |
if self.params.combined_loop is not None:
|
165 |
self._install_context_from_loop(self.params.combined_loop)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
166 |
|
167 |
# ---------- lifecycle ----------
|
168 |
|
|
|
248 |
return toks
|
249 |
|
250 |
def _encode_exact_context_tokens(self, loop: au.Waveform) -> np.ndarray:
|
251 |
+
"""Build *exactly* context_length_frames worth of tokens (e.g., 250 @ 25fps),
|
252 |
+
while ensuring the *end* of the audio lands on a bar boundary.
|
253 |
+
Strategy: take the largest integer number of bars <= ctx_seconds as the tail,
|
254 |
+
then left-fill from just before that tail (wrapping if needed) to reach exactly
|
255 |
+
ctx_seconds; finally, pad/trim to exact samples and, as a last resort, pad/trim
|
256 |
+
tokens to the expected frame count.
|
257 |
+
"""
|
258 |
wav = loop.as_stereo().resample(self._model_sr)
|
259 |
data = wav.samples.astype(np.float32, copy=False)
|
260 |
if data.ndim == 1:
|
|
|
289 |
|
290 |
# final snap to *exact* ctx samples
|
291 |
if ctx.shape[0] < ctx_samps:
|
292 |
+
pad = np.zeros((ctx_samps - ctx.shape[0], ctx.shape[1]), dtype=np.float32)
|
293 |
+
ctx = np.concatenate([pad, ctx], axis=0)
|
|
|
|
|
|
|
|
|
|
|
|
|
294 |
elif ctx.shape[0] > ctx_samps:
|
295 |
ctx = ctx[-ctx_samps:]
|
296 |
|
|
|
301 |
|
302 |
# Force expected (F,D) at *return time*
|
303 |
tokens = self._coerce_tokens(tokens)
|
|
|
|
|
|
|
|
|
|
|
304 |
return tokens
|
305 |
|
306 |
+
def _encode_exact_context_tokens(self, loop: au.Waveform) -> np.ndarray:
|
307 |
+
"""Build *exactly* context_length_frames worth of tokens (e.g., 250 @ 25fps),
|
308 |
+
while ensuring the *end* of the audio lands on a bar boundary.
|
309 |
+
Strategy: take the largest integer number of bars <= ctx_seconds as the tail,
|
310 |
+
then left-fill from just before that tail (wrapping if needed) to reach exactly
|
311 |
+
ctx_seconds; finally, pad/trim to exact samples and, as a last resort, pad/trim
|
312 |
+
tokens to the expected frame count.
|
313 |
+
"""
|
314 |
+
wav = loop.as_stereo().resample(self._model_sr)
|
315 |
+
data = wav.samples.astype(np.float32, copy=False)
|
316 |
+
if data.ndim == 1:
|
317 |
+
data = data[:, None]
|
318 |
+
|
319 |
+
spb = self._bar_clock.seconds_per_bar()
|
320 |
+
ctx_sec = float(self._ctx_seconds)
|
321 |
+
sr = int(self._model_sr)
|
322 |
+
|
323 |
+
# bars that fit fully inside ctx_sec (at least 1)
|
324 |
+
bars_fit = max(1, int(ctx_sec // spb))
|
325 |
+
tail_len_samps = int(round(bars_fit * spb * sr))
|
326 |
+
|
327 |
+
# ensure we have enough source by tiling
|
328 |
+
need = int(round(ctx_sec * sr)) + tail_len_samps
|
329 |
+
if data.shape[0] == 0:
|
330 |
+
data = np.zeros((1, 2), dtype=np.float32)
|
331 |
+
reps = int(np.ceil(need / float(data.shape[0])))
|
332 |
+
tiled = np.tile(data, (reps, 1))
|
333 |
+
|
334 |
+
end = tiled.shape[0]
|
335 |
+
tail = tiled[end - tail_len_samps:end]
|
336 |
+
|
337 |
+
# left-fill to reach exact ctx samples (keeps end-of-bar alignment)
|
338 |
+
ctx_samps = int(round(ctx_sec * sr))
|
339 |
+
pad_len = ctx_samps - tail.shape[0]
|
340 |
+
if pad_len > 0:
|
341 |
+
pre = tiled[end - tail_len_samps - pad_len:end - tail_len_samps]
|
342 |
+
ctx = np.concatenate([pre, tail], axis=0)
|
343 |
+
else:
|
344 |
+
ctx = tail[-ctx_samps:]
|
345 |
+
|
346 |
+
# final snap to *exact* ctx samples
|
347 |
+
if ctx.shape[0] < ctx_samps:
|
348 |
+
pad = np.zeros((ctx_samps - ctx.shape[0], ctx.shape[1]), dtype=np.float32)
|
349 |
+
ctx = np.concatenate([pad, ctx], axis=0)
|
350 |
+
elif ctx.shape[0] > ctx_samps:
|
351 |
+
ctx = ctx[-ctx_samps:]
|
352 |
+
|
353 |
+
exact = au.Waveform(ctx, sr)
|
354 |
+
tokens_full = self.mrt.codec.encode(exact).astype(np.int32)
|
355 |
+
depth = int(self.mrt.config.decoder_codec_rvq_depth)
|
356 |
+
tokens = tokens_full[:, :depth]
|
357 |
+
|
358 |
+
# Last defense: force expected frame count
|
359 |
+
frames = tokens.shape[0]
|
360 |
+
exp = int(self._ctx_frames)
|
361 |
+
if frames < exp:
|
362 |
+
# repeat last frame
|
363 |
+
pad = np.repeat(tokens[-1:, :], exp - frames, axis=0)
|
364 |
+
tokens = np.concatenate([pad, tokens], axis=0)
|
365 |
+
elif frames > exp:
|
366 |
+
tokens = tokens[-exp:, :]
|
367 |
+
return tokens
|
368 |
+
|
369 |
+
|
370 |
def _install_context_from_loop(self, loop: au.Waveform):
|
371 |
# Build exact-length, bar-locked context tokens
|
372 |
context_tokens = self._encode_exact_context_tokens(loop)
|
373 |
s = self.mrt.init_state()
|
374 |
s.context_tokens = context_tokens
|
375 |
self.state = s
|
376 |
+
self._original_context_tokens = np.copy(context_tokens)
|
377 |
|
378 |
def reseed_from_waveform(self, wav: au.Waveform):
|
379 |
"""Immediate reseed: replace context from provided wave (bar-locked, exact length)."""
|
|
|
383 |
s.context_tokens = context_tokens
|
384 |
self.state = s
|
385 |
self._model_stream = None # drop model-domain continuity so next chunk starts cleanly
|
386 |
+
self._original_context_tokens = np.copy(context_tokens)
|
|
|
387 |
|
388 |
def reseed_splice(self, recent_wav: au.Waveform, anchor_bars: float):
|
389 |
+
"""Queue a *seamless* reseed by token splicing instead of full restart.
|
390 |
+
We compute a fresh, bar-locked context token tensor of exact length
|
391 |
+
(e.g., 250 frames), then splice only the *tail* corresponding to
|
392 |
+
`anchor_bars` so generation continues smoothly without resetting state.
|
393 |
+
"""
|
394 |
new_ctx = self._encode_exact_context_tokens(recent_wav) # coerce to (F,D)
|
395 |
F, D = self._expected_token_shape()
|
396 |
|
|
|
419 |
"tokens": spliced,
|
420 |
"debug": {"F": F, "D": D, "splice_frames": splice_frames, "frames_per_bar": frames_per_bar}
|
421 |
}
|
422 |
+
|
423 |
|
424 |
+
|
425 |
+
def reseed_from_waveform(self, wav: au.Waveform):
|
426 |
+
"""Immediate reseed: replace context from provided wave (bar-aligned tail)."""
|
427 |
+
wav = wav.as_stereo().resample(self._model_sr)
|
428 |
+
tail = take_bar_aligned_tail(wav, self.params.bpm, self.params.beats_per_bar, self._ctx_seconds)
|
429 |
+
tokens_full = self.mrt.codec.encode(tail).astype(np.int32)
|
430 |
+
depth = int(self.mrt.config.decoder_codec_rvq_depth)
|
431 |
+
context_tokens = tokens_full[:, :depth]
|
432 |
+
|
433 |
+
s = self.mrt.init_state()
|
434 |
+
s.context_tokens = context_tokens
|
435 |
+
self.state = s
|
436 |
+
# reset model stream so next generate starts cleanly
|
437 |
+
self._model_stream = None
|
438 |
+
|
439 |
+
# optional loudness match will be applied per-chunk on emission
|
440 |
+
|
441 |
+
# also remember this as new "original"
|
442 |
+
self._original_context_tokens = np.copy(context_tokens)
|
443 |
+
|
444 |
+
# ---------- core streaming helpers ----------
|
445 |
|
446 |
def _append_model_chunk_and_spool(self, wav: au.Waveform) -> None:
|
447 |
"""
|
448 |
+
Conservative boundary fix:
|
449 |
+
- Emit body+tail immediately (target SR), unchanged from your original behavior.
|
450 |
+
- On *next* call, compute the mixed overlap (prev tail ⨉ cos + new head ⨉ sin),
|
451 |
+
resample it, and overwrite the last `_pending_tail_target_len` samples in the
|
452 |
+
target-SR spool with that mixed overlap. Then emit THIS chunk's body+tail and
|
453 |
+
remember THIS chunk's tail length at target SR for the next correction.
|
454 |
+
|
455 |
+
This keeps external timing and bar alignment identical, but removes the audible
|
456 |
+
fade-to-zero at chunk ends.
|
457 |
"""
|
458 |
+
|
459 |
+
# ---- unpack model-rate samples ----
|
460 |
s = wav.samples.astype(np.float32, copy=False)
|
461 |
if s.ndim == 1:
|
462 |
s = s[:, None]
|
|
|
464 |
if n_samps == 0:
|
465 |
return
|
466 |
|
467 |
+
# crossfade length in model samples
|
|
|
|
|
|
|
|
|
468 |
try:
|
469 |
xfade_s = float(self.mrt.config.crossfade_length)
|
470 |
except Exception:
|
471 |
xfade_s = 0.0
|
472 |
xfade_n = int(round(max(0.0, xfade_s) * float(self._model_sr)))
|
473 |
|
474 |
+
# helper: resample to target SR via your streaming resampler
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
475 |
def to_target(y: np.ndarray) -> np.ndarray:
|
476 |
return y if self._rs is None else self._rs.process(y, final=False)
|
477 |
|
478 |
+
# ------------------------------------------
|
479 |
+
# (A) If we have a pending model tail, fix the last emitted tail at target SR
|
480 |
+
# ------------------------------------------
|
481 |
+
if self._pending_tail_model is not None and self._pending_tail_model.shape[0] == xfade_n and xfade_n > 0 and n_samps >= xfade_n:
|
482 |
+
head = s[:xfade_n, :]
|
483 |
+
|
484 |
+
print(f"[model] head len={head.shape[0]} rms={_dbg_rms_dbfs_model(head):+.1f} dBFS")
|
485 |
+
|
486 |
+
t = np.linspace(0.0, np.pi/2.0, xfade_n, endpoint=False, dtype=np.float32)[:, None]
|
487 |
+
cosw = np.cos(t, dtype=np.float32)
|
488 |
+
sinw = np.sin(t, dtype=np.float32)
|
489 |
+
mixed_model = (self._pending_tail_model * cosw) + (head * sinw) # [xfade_n, C] at model SR
|
490 |
+
|
491 |
+
y_mixed = to_target(mixed_model.astype(np.float32))
|
492 |
+
Lcorr = int(y_mixed.shape[0]) # exact target-SR samples to write
|
493 |
+
|
494 |
+
# DEBUG: corrected overlap RMS (what we intend to hear at the boundary)
|
495 |
+
if y_mixed.size:
|
496 |
+
print(f"[append] mixedOverlap len={y_mixed.shape[0]} rms={_dbg_rms_dbfs(y_mixed):+.1f} dBFS")
|
497 |
+
|
498 |
+
# Overwrite the last `_pending_tail_target_len` samples of the spool with `y_mixed`.
|
499 |
+
# Use the *smaller* of the two lengths to be safe.
|
500 |
+
Lpop = min(self._pending_tail_target_len, self._spool.shape[0], Lcorr)
|
501 |
+
if Lpop > 0 and self._spool.size:
|
502 |
+
# Trim last Lpop samples
|
503 |
+
self._spool = self._spool[:-Lpop, :]
|
504 |
+
self._spool_written -= Lpop
|
505 |
+
# Append corrected overlap (trim/pad to Lpop to avoid drift)
|
506 |
+
if Lcorr != Lpop:
|
507 |
+
if Lcorr > Lpop:
|
508 |
+
y_m = y_mixed[-Lpop:, :]
|
509 |
+
else:
|
510 |
+
pad = np.zeros((Lpop - Lcorr, y_mixed.shape[1]), dtype=np.float32)
|
511 |
+
y_m = np.concatenate([y_mixed, pad], axis=0)
|
512 |
+
else:
|
513 |
+
y_m = y_mixed
|
514 |
+
self._spool = np.concatenate([self._spool, y_m], axis=0) if self._spool.size else y_m
|
515 |
+
self._spool_written += y_m.shape[0]
|
516 |
+
|
517 |
+
# For internal continuity, update _model_stream like before
|
518 |
+
if self._model_stream is None or self._model_stream.shape[0] < xfade_n:
|
519 |
+
self._model_stream = s[xfade_n:].copy()
|
520 |
+
else:
|
521 |
+
self._model_stream = np.concatenate([self._model_stream[:-xfade_n], mixed_model, s[xfade_n:]], axis=0)
|
522 |
else:
|
523 |
+
# First-ever call or too-short to mix: maintain _model_stream minimally
|
524 |
+
if xfade_n > 0 and n_samps > xfade_n:
|
525 |
+
self._model_stream = s[xfade_n:].copy() if self._model_stream is None else np.concatenate([self._model_stream, s[xfade_n:]], axis=0)
|
526 |
+
else:
|
527 |
+
self._model_stream = s.copy() if self._model_stream is None else np.concatenate([self._model_stream, s], axis=0)
|
528 |
+
|
529 |
+
# ------------------------------------------
|
530 |
+
# (B) Emit THIS chunk's body and tail (same external behavior)
|
531 |
+
# ------------------------------------------
|
532 |
+
if xfade_n > 0 and n_samps >= (2 * xfade_n):
|
533 |
+
body = s[xfade_n:-xfade_n, :]
|
534 |
+
print(f"[model] body len={body.shape[0]} rms={_dbg_rms_dbfs_model(body):+.1f} dBFS")
|
535 |
+
if body.size:
|
536 |
+
y_body = to_target(body.astype(np.float32))
|
537 |
+
if y_body.size:
|
538 |
+
# DEBUG: body RMS we are actually appending
|
539 |
+
print(f"[append] body len={y_body.shape[0]} rms={_dbg_rms_dbfs(y_body):+.1f} dBFS")
|
540 |
+
self._spool = np.concatenate([self._spool, y_body], axis=0) if self._spool.size else y_body
|
541 |
+
self._spool_written += y_body.shape[0]
|
542 |
+
else:
|
543 |
+
# If chunk too short for head+tail split, treat all (minus preroll) as body
|
544 |
+
if xfade_n > 0 and n_samps > xfade_n:
|
545 |
+
body = s[xfade_n:, :]
|
546 |
+
print(f"[model] body(S) len={body.shape[0]} rms={_dbg_rms_dbfs_model(body):+.1f} dBFS")
|
547 |
+
y_body = to_target(body.astype(np.float32))
|
548 |
+
if y_body.size:
|
549 |
+
# DEBUG: body RMS in short-chunk path
|
550 |
+
print(f"[append] body(len=short) len={y_body.shape[0]} rms={_dbg_rms_dbfs(y_body):+.1f} dBFS")
|
551 |
+
self._spool = np.concatenate([self._spool, y_body], axis=0) if self._spool.size else y_body
|
552 |
+
self._spool_written += y_body.shape[0]
|
553 |
+
# No tail to remember this round
|
554 |
+
self._pending_tail_model = None
|
555 |
+
self._pending_tail_target_len = 0
|
556 |
+
return
|
557 |
+
|
558 |
+
# Tail (always remember how many TARGET samples we append)
|
559 |
if xfade_n > 0 and n_samps >= xfade_n:
|
560 |
+
tail = s[-xfade_n:, :]
|
561 |
+
print(f"[model] tail len={tail.shape[0]} rms={_dbg_rms_dbfs_model(tail):+.1f} dBFS")
|
562 |
+
y_tail = to_target(tail.astype(np.float32))
|
563 |
+
Ltail = int(y_tail.shape[0])
|
564 |
+
if Ltail:
|
565 |
+
# DEBUG: tail RMS we are appending now (to be corrected next call)
|
566 |
+
print(f"[append] tail len={y_tail.shape[0]} rms={_dbg_rms_dbfs(y_tail):+.1f} dBFS")
|
567 |
+
self._spool = np.concatenate([self._spool, y_tail], axis=0) if self._spool.size else y_tail
|
568 |
+
self._spool_written += Ltail
|
569 |
+
self._pending_tail_model = tail.copy()
|
570 |
+
self._pending_tail_target_len = Ltail
|
571 |
+
else:
|
572 |
+
# Nothing appended (resampler returned nothing yet) — keep model tail but mark zero target len
|
573 |
+
self._pending_tail_model = tail.copy()
|
574 |
+
self._pending_tail_target_len = 0
|
575 |
else:
|
576 |
+
self._pending_tail_model = None
|
577 |
+
self._pending_tail_target_len = 0
|
578 |
+
|
579 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
580 |
|
581 |
def _should_generate_next_chunk(self) -> bool:
|
582 |
# Allow running ahead relative to whichever is larger: last *consumed*
|
|
|
613 |
"guidance_weight": float(self.params.guidance_weight),
|
614 |
"temperature": float(self.params.temperature),
|
615 |
"topk": int(self.params.topk),
|
|
|
616 |
}
|
617 |
chunk = JamChunk(index=self.idx, audio_base64=audio_b64, metadata=meta)
|
618 |
|
|
|
637 |
# inplace update (no reset)
|
638 |
self.state.context_tokens = spliced
|
639 |
self._pending_token_splice = None
|
|
|
640 |
except Exception:
|
641 |
# fallback: full reseed using spliced tokens
|
642 |
new_state = self.mrt.init_state()
|
|
|
644 |
self.state = new_state
|
645 |
self._model_stream = None
|
646 |
self._pending_token_splice = None
|
|
|
647 |
elif self._pending_reseed is not None:
|
648 |
ctx = self._coerce_tokens(self._pending_reseed["ctx"])
|
649 |
new_state = self.mrt.init_state()
|
|
|
651 |
self.state = new_state
|
652 |
self._model_stream = None
|
653 |
self._pending_reseed = None
|
|
|
654 |
|
655 |
# ---------- main loop ----------
|
656 |
|
|
|
687 |
self._emit_ready()
|
688 |
|
689 |
# finalize resampler (flush) — not strictly necessary here
|
690 |
+
tail = self._rs.process(np.zeros((0,2), np.float32), final=True)
|
691 |
+
if tail.size:
|
692 |
+
self._spool = np.concatenate([self._spool, tail], axis=0)
|
693 |
+
self._spool_written += tail.shape[0]
|
|
|
694 |
# one last emit attempt
|
695 |
+
self._emit_ready()
|