|  |  | 
					
						
						|  | from __future__ import annotations | 
					
						
						|  |  | 
					
						
						|  | import threading, time | 
					
						
						|  | from dataclasses import dataclass | 
					
						
						|  | from fractions import Fraction | 
					
						
						|  | from typing import Optional, Dict, Tuple, List | 
					
						
						|  |  | 
					
						
						|  | import numpy as np | 
					
						
						|  | from magenta_rt import audio as au | 
					
						
						|  |  | 
					
						
						|  | from utils import ( | 
					
						
						|  | StreamingResampler, | 
					
						
						|  | match_loudness_to_reference, | 
					
						
						|  | make_bar_aligned_context, | 
					
						
						|  | take_bar_aligned_tail, | 
					
						
						|  | wav_bytes_base64, | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | @dataclass | 
					
						
						|  | class JamParams: | 
					
						
						|  | bpm: float | 
					
						
						|  | beats_per_bar: int | 
					
						
						|  | bars_per_chunk: int | 
					
						
						|  | target_sr: int | 
					
						
						|  | loudness_mode: str = "auto" | 
					
						
						|  | headroom_db: float = 1.0 | 
					
						
						|  | style_vec: Optional[np.ndarray] = None | 
					
						
						|  | ref_loop: Optional[au.Waveform] = None | 
					
						
						|  | combined_loop: Optional[au.Waveform] = None | 
					
						
						|  | guidance_weight: float = 1.1 | 
					
						
						|  | temperature: float = 1.1 | 
					
						
						|  | topk: int = 40 | 
					
						
						|  | style_ramp_seconds: float = 0.0 | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | @dataclass | 
					
						
						|  | class JamChunk: | 
					
						
						|  | index: int | 
					
						
						|  | audio_base64: str | 
					
						
						|  | metadata: dict | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class BarClock: | 
					
						
						|  | """Sample-domain bar clock with drift-free absolute boundaries.""" | 
					
						
						|  | def __init__(self, target_sr: int, bpm: float, beats_per_bar: int, base_offset_samples: int = 0): | 
					
						
						|  | self.sr = int(target_sr) | 
					
						
						|  | self.bpm = Fraction(str(bpm)) | 
					
						
						|  | self.beats_per_bar = int(beats_per_bar) | 
					
						
						|  | self.bar_samps = Fraction(self.sr * 60 * self.beats_per_bar, 1) / self.bpm | 
					
						
						|  | self.base = int(base_offset_samples) | 
					
						
						|  |  | 
					
						
						|  | def bounds_for_chunk(self, chunk_index: int, bars_per_chunk: int) -> Tuple[int, int]: | 
					
						
						|  | start_f = self.base + self.bar_samps * (chunk_index * bars_per_chunk) | 
					
						
						|  | end_f   = self.base + self.bar_samps * ((chunk_index + 1) * bars_per_chunk) | 
					
						
						|  | return int(round(start_f)), int(round(end_f)) | 
					
						
						|  |  | 
					
						
						|  | def seconds_per_bar(self) -> float: | 
					
						
						|  | return float(self.beats_per_bar) * (60.0 / float(self.bpm)) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class JamWorker(threading.Thread): | 
					
						
						|  | FRAMES_PER_SECOND: float | None = None | 
					
						
						|  | """Generates continuous audio with MagentaRT, spools it at target SR, | 
					
						
						|  | and emits *sample-accurate*, bar-aligned chunks (no FPS drift).""" | 
					
						
						|  |  | 
					
						
						|  | def __init__(self, mrt, params: JamParams): | 
					
						
						|  | super().__init__(daemon=True) | 
					
						
						|  | self.mrt = mrt | 
					
						
						|  | self.params = params | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | self._lock = threading.RLock() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | self.state = self.mrt.init_state() | 
					
						
						|  | self.mrt.guidance_weight = float(self.params.guidance_weight) | 
					
						
						|  | self.mrt.temperature     = float(self.params.temperature) | 
					
						
						|  | self.mrt.topk            = int(self.params.topk) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | self._codec_fps = float(self.mrt.codec.frame_rate) | 
					
						
						|  | JamWorker.FRAMES_PER_SECOND = self._codec_fps | 
					
						
						|  | self._ctx_frames = int(self.mrt.config.context_length_frames) | 
					
						
						|  | self._ctx_seconds = self._ctx_frames / self._codec_fps | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | self._model_stream: Optional[np.ndarray] = None | 
					
						
						|  | self._model_sr = int(self.mrt.sample_rate) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | self._style_vec = (None if self.params.style_vec is None | 
					
						
						|  | else np.array(self.params.style_vec, dtype=np.float32, copy=True)) | 
					
						
						|  | self._chunk_secs = ( | 
					
						
						|  | self.mrt.config.chunk_length_frames * self.mrt.config.frame_length_samples | 
					
						
						|  | ) / float(self._model_sr) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if int(self.params.target_sr) != int(self._model_sr): | 
					
						
						|  | self._rs = StreamingResampler(self._model_sr, int(self.params.target_sr), channels=2) | 
					
						
						|  | else: | 
					
						
						|  | self._rs = None | 
					
						
						|  | self._spool = np.zeros((0, 2), dtype=np.float32) | 
					
						
						|  | self._spool_written = 0 | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | self._bar_clock = BarClock(self.params.target_sr, self.params.bpm, self.params.beats_per_bar, base_offset_samples=0) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | self.idx = 0 | 
					
						
						|  | self._next_to_deliver = 0 | 
					
						
						|  | self._last_consumed_index = -1 | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | self._outbox: Dict[int, JamChunk] = {} | 
					
						
						|  | self._cv = threading.Condition() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | self._stop_event = threading.Event() | 
					
						
						|  | self._max_buffer_ahead = 1 | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | self._pending_reseed: Optional[dict] = None | 
					
						
						|  | self._pending_token_splice: Optional[dict] = None | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if self.params.combined_loop is not None: | 
					
						
						|  | self._install_context_from_loop(self.params.combined_loop) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def set_buffer_seconds(self, seconds: float): | 
					
						
						|  | """Clamp how far ahead we allow, in *seconds* of audio.""" | 
					
						
						|  | chunk_secs = float(self.params.bars_per_chunk) * self._bar_clock.seconds_per_bar() | 
					
						
						|  | max_chunks = max(0, int(round(seconds / max(chunk_secs, 1e-6)))) | 
					
						
						|  | with self._cv: | 
					
						
						|  | self._max_buffer_ahead = max_chunks | 
					
						
						|  |  | 
					
						
						|  | def set_buffer_chunks(self, k: int): | 
					
						
						|  | with self._cv: | 
					
						
						|  | self._max_buffer_ahead = max(0, int(k)) | 
					
						
						|  |  | 
					
						
						|  | def stop(self): | 
					
						
						|  | self._stop_event.set() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def get_next_chunk(self, timeout: float = 30.0) -> Optional[JamChunk]: | 
					
						
						|  | deadline = time.time() + timeout | 
					
						
						|  | with self._cv: | 
					
						
						|  | while True: | 
					
						
						|  | c = self._outbox.get(self._next_to_deliver) | 
					
						
						|  | if c is not None: | 
					
						
						|  | self._next_to_deliver += 1 | 
					
						
						|  | return c | 
					
						
						|  | remaining = deadline - time.time() | 
					
						
						|  | if remaining <= 0: | 
					
						
						|  | return None | 
					
						
						|  | self._cv.wait(timeout=min(0.25, remaining)) | 
					
						
						|  |  | 
					
						
						|  | def mark_chunk_consumed(self, chunk_index: int): | 
					
						
						|  |  | 
					
						
						|  | with self._cv: | 
					
						
						|  | self._last_consumed_index = max(self._last_consumed_index, int(chunk_index)) | 
					
						
						|  |  | 
					
						
						|  | for k in list(self._outbox.keys()): | 
					
						
						|  | if k < self._last_consumed_index - 1: | 
					
						
						|  | self._outbox.pop(k, None) | 
					
						
						|  |  | 
					
						
						|  | def update_knobs(self, *, guidance_weight=None, temperature=None, topk=None): | 
					
						
						|  | with self._lock: | 
					
						
						|  | if guidance_weight is not None: | 
					
						
						|  | self.params.guidance_weight = float(guidance_weight) | 
					
						
						|  | if temperature is not None: | 
					
						
						|  | self.params.temperature = float(temperature) | 
					
						
						|  | if topk is not None: | 
					
						
						|  | self.params.topk = int(topk) | 
					
						
						|  |  | 
					
						
						|  | self.mrt.guidance_weight = float(self.params.guidance_weight) | 
					
						
						|  | self.mrt.temperature     = float(self.params.temperature) | 
					
						
						|  | self.mrt.topk            = int(self.params.topk) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def _expected_token_shape(self) -> Tuple[int, int]: | 
					
						
						|  | F = int(self._ctx_frames) | 
					
						
						|  | D = int(self.mrt.config.decoder_codec_rvq_depth) | 
					
						
						|  | return F, D | 
					
						
						|  |  | 
					
						
						|  | def _coerce_tokens(self, toks: np.ndarray) -> np.ndarray: | 
					
						
						|  | """Force tokens to (context_length_frames, rvq_depth), padding/trimming as needed. | 
					
						
						|  | Pads missing frames by repeating the last frame (safer than zeros for RVQ stacks).""" | 
					
						
						|  | F, D = self._expected_token_shape() | 
					
						
						|  | if toks.ndim != 2: | 
					
						
						|  | toks = np.atleast_2d(toks) | 
					
						
						|  |  | 
					
						
						|  | if toks.shape[1] > D: | 
					
						
						|  | toks = toks[:, :D] | 
					
						
						|  | elif toks.shape[1] < D: | 
					
						
						|  | pad_cols = np.tile(toks[:, -1:], (1, D - toks.shape[1])) | 
					
						
						|  | toks = np.concatenate([toks, pad_cols], axis=1) | 
					
						
						|  |  | 
					
						
						|  | if toks.shape[0] < F: | 
					
						
						|  | if toks.shape[0] == 0: | 
					
						
						|  | toks = np.zeros((1, D), dtype=np.int32) | 
					
						
						|  | pad = np.repeat(toks[-1:, :], F - toks.shape[0], axis=0) | 
					
						
						|  | toks = np.concatenate([pad, toks], axis=0) | 
					
						
						|  | elif toks.shape[0] > F: | 
					
						
						|  | toks = toks[-F:, :] | 
					
						
						|  | if toks.dtype != np.int32: | 
					
						
						|  | toks = toks.astype(np.int32, copy=False) | 
					
						
						|  | return toks | 
					
						
						|  |  | 
					
						
						|  | def _encode_exact_context_tokens(self, loop: au.Waveform) -> np.ndarray: | 
					
						
						|  | """Build *exactly* context_length_frames worth of tokens (e.g., 250 @ 25fps), | 
					
						
						|  | while ensuring the *end* of the audio lands on a bar boundary. | 
					
						
						|  | Strategy: take the largest integer number of bars <= ctx_seconds as the tail, | 
					
						
						|  | then left-fill from just before that tail (wrapping if needed) to reach exactly | 
					
						
						|  | ctx_seconds; finally, pad/trim to exact samples and, as a last resort, pad/trim | 
					
						
						|  | tokens to the expected frame count. | 
					
						
						|  | """ | 
					
						
						|  | wav = loop.as_stereo().resample(self._model_sr) | 
					
						
						|  | data = wav.samples.astype(np.float32, copy=False) | 
					
						
						|  | if data.ndim == 1: | 
					
						
						|  | data = data[:, None] | 
					
						
						|  |  | 
					
						
						|  | spb = self._bar_clock.seconds_per_bar() | 
					
						
						|  | ctx_sec = float(self._ctx_seconds) | 
					
						
						|  | sr = int(self._model_sr) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | bars_fit = max(1, int(ctx_sec // spb)) | 
					
						
						|  | tail_len_samps = int(round(bars_fit * spb * sr)) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | need = int(round(ctx_sec * sr)) + tail_len_samps | 
					
						
						|  | if data.shape[0] == 0: | 
					
						
						|  | data = np.zeros((1, 2), dtype=np.float32) | 
					
						
						|  | reps = int(np.ceil(need / float(data.shape[0]))) | 
					
						
						|  | tiled = np.tile(data, (reps, 1)) | 
					
						
						|  |  | 
					
						
						|  | end = tiled.shape[0] | 
					
						
						|  | tail = tiled[end - tail_len_samps:end] | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | ctx_samps = int(round(ctx_sec * sr)) | 
					
						
						|  | pad_len = ctx_samps - tail.shape[0] | 
					
						
						|  | if pad_len > 0: | 
					
						
						|  | pre = tiled[end - tail_len_samps - pad_len:end - tail_len_samps] | 
					
						
						|  | ctx = np.concatenate([pre, tail], axis=0) | 
					
						
						|  | else: | 
					
						
						|  | ctx = tail[-ctx_samps:] | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if ctx.shape[0] < ctx_samps: | 
					
						
						|  | pad = np.zeros((ctx_samps - ctx.shape[0], ctx.shape[1]), dtype=np.float32) | 
					
						
						|  | ctx = np.concatenate([pad, ctx], axis=0) | 
					
						
						|  | elif ctx.shape[0] > ctx_samps: | 
					
						
						|  | ctx = ctx[-ctx_samps:] | 
					
						
						|  |  | 
					
						
						|  | exact = au.Waveform(ctx, sr) | 
					
						
						|  | tokens_full = self.mrt.codec.encode(exact).astype(np.int32) | 
					
						
						|  | depth = int(self.mrt.config.decoder_codec_rvq_depth) | 
					
						
						|  | tokens = tokens_full[:, :depth] | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | tokens = self._coerce_tokens(tokens) | 
					
						
						|  | return tokens | 
					
						
						|  |  | 
					
						
						|  | def _encode_exact_context_tokens(self, loop: au.Waveform) -> np.ndarray: | 
					
						
						|  | """Build *exactly* context_length_frames worth of tokens (e.g., 250 @ 25fps), | 
					
						
						|  | while ensuring the *end* of the audio lands on a bar boundary. | 
					
						
						|  | Strategy: take the largest integer number of bars <= ctx_seconds as the tail, | 
					
						
						|  | then left-fill from just before that tail (wrapping if needed) to reach exactly | 
					
						
						|  | ctx_seconds; finally, pad/trim to exact samples and, as a last resort, pad/trim | 
					
						
						|  | tokens to the expected frame count. | 
					
						
						|  | """ | 
					
						
						|  | wav = loop.as_stereo().resample(self._model_sr) | 
					
						
						|  | data = wav.samples.astype(np.float32, copy=False) | 
					
						
						|  | if data.ndim == 1: | 
					
						
						|  | data = data[:, None] | 
					
						
						|  |  | 
					
						
						|  | spb = self._bar_clock.seconds_per_bar() | 
					
						
						|  | ctx_sec = float(self._ctx_seconds) | 
					
						
						|  | sr = int(self._model_sr) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | bars_fit = max(1, int(ctx_sec // spb)) | 
					
						
						|  | tail_len_samps = int(round(bars_fit * spb * sr)) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | need = int(round(ctx_sec * sr)) + tail_len_samps | 
					
						
						|  | if data.shape[0] == 0: | 
					
						
						|  | data = np.zeros((1, 2), dtype=np.float32) | 
					
						
						|  | reps = int(np.ceil(need / float(data.shape[0]))) | 
					
						
						|  | tiled = np.tile(data, (reps, 1)) | 
					
						
						|  |  | 
					
						
						|  | end = tiled.shape[0] | 
					
						
						|  | tail = tiled[end - tail_len_samps:end] | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | ctx_samps = int(round(ctx_sec * sr)) | 
					
						
						|  | pad_len = ctx_samps - tail.shape[0] | 
					
						
						|  | if pad_len > 0: | 
					
						
						|  | pre = tiled[end - tail_len_samps - pad_len:end - tail_len_samps] | 
					
						
						|  | ctx = np.concatenate([pre, tail], axis=0) | 
					
						
						|  | else: | 
					
						
						|  | ctx = tail[-ctx_samps:] | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if ctx.shape[0] < ctx_samps: | 
					
						
						|  | pad = np.zeros((ctx_samps - ctx.shape[0], ctx.shape[1]), dtype=np.float32) | 
					
						
						|  | ctx = np.concatenate([pad, ctx], axis=0) | 
					
						
						|  | elif ctx.shape[0] > ctx_samps: | 
					
						
						|  | ctx = ctx[-ctx_samps:] | 
					
						
						|  |  | 
					
						
						|  | exact = au.Waveform(ctx, sr) | 
					
						
						|  | tokens_full = self.mrt.codec.encode(exact).astype(np.int32) | 
					
						
						|  | depth = int(self.mrt.config.decoder_codec_rvq_depth) | 
					
						
						|  | tokens = tokens_full[:, :depth] | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | frames = tokens.shape[0] | 
					
						
						|  | exp = int(self._ctx_frames) | 
					
						
						|  | if frames < exp: | 
					
						
						|  |  | 
					
						
						|  | pad = np.repeat(tokens[-1:, :], exp - frames, axis=0) | 
					
						
						|  | tokens = np.concatenate([pad, tokens], axis=0) | 
					
						
						|  | elif frames > exp: | 
					
						
						|  | tokens = tokens[-exp:, :] | 
					
						
						|  | return tokens | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def _install_context_from_loop(self, loop: au.Waveform): | 
					
						
						|  |  | 
					
						
						|  | context_tokens = self._encode_exact_context_tokens(loop) | 
					
						
						|  | s = self.mrt.init_state() | 
					
						
						|  | s.context_tokens = context_tokens | 
					
						
						|  | self.state = s | 
					
						
						|  | self._original_context_tokens = np.copy(context_tokens) | 
					
						
						|  |  | 
					
						
						|  | def reseed_from_waveform(self, wav: au.Waveform): | 
					
						
						|  | """Immediate reseed: replace context from provided wave (bar-locked, exact length).""" | 
					
						
						|  | context_tokens = self._encode_exact_context_tokens(wav) | 
					
						
						|  | with self._lock: | 
					
						
						|  | s = self.mrt.init_state() | 
					
						
						|  | s.context_tokens = context_tokens | 
					
						
						|  | self.state = s | 
					
						
						|  | self._model_stream = None | 
					
						
						|  | self._original_context_tokens = np.copy(context_tokens) | 
					
						
						|  |  | 
					
						
						|  | def reseed_splice(self, recent_wav: au.Waveform, anchor_bars: float): | 
					
						
						|  | """Queue a *seamless* reseed by token splicing instead of full restart. | 
					
						
						|  | We compute a fresh, bar-locked context token tensor of exact length | 
					
						
						|  | (e.g., 250 frames), then splice only the *tail* corresponding to | 
					
						
						|  | `anchor_bars` so generation continues smoothly without resetting state. | 
					
						
						|  | """ | 
					
						
						|  | new_ctx = self._encode_exact_context_tokens(recent_wav) | 
					
						
						|  | F, D = self._expected_token_shape() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | spb = self._bar_clock.seconds_per_bar() | 
					
						
						|  | frames_per_bar = max(1, int(round(self._codec_fps * spb))) | 
					
						
						|  | splice_frames = max(1, min(int(round(max(1.0, float(anchor_bars)) * frames_per_bar)), F)) | 
					
						
						|  |  | 
					
						
						|  | with self._lock: | 
					
						
						|  |  | 
					
						
						|  | cur = getattr(self.state, "context_tokens", None) | 
					
						
						|  | if cur is None: | 
					
						
						|  |  | 
					
						
						|  | self._pending_reseed = {"ctx": new_ctx} | 
					
						
						|  | return | 
					
						
						|  | cur = self._coerce_tokens(cur) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | left = cur[:F - splice_frames, :] | 
					
						
						|  | right = new_ctx[F - splice_frames:, :] | 
					
						
						|  | spliced = np.concatenate([left, right], axis=0) | 
					
						
						|  | spliced = self._coerce_tokens(spliced) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | self._pending_token_splice = { | 
					
						
						|  | "tokens": spliced, | 
					
						
						|  | "debug": {"F": F, "D": D, "splice_frames": splice_frames, "frames_per_bar": frames_per_bar} | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def reseed_from_waveform(self, wav: au.Waveform): | 
					
						
						|  | """Immediate reseed: replace context from provided wave (bar-aligned tail).""" | 
					
						
						|  | wav = wav.as_stereo().resample(self._model_sr) | 
					
						
						|  | tail = take_bar_aligned_tail(wav, self.params.bpm, self.params.beats_per_bar, self._ctx_seconds) | 
					
						
						|  | tokens_full = self.mrt.codec.encode(tail).astype(np.int32) | 
					
						
						|  | depth = int(self.mrt.config.decoder_codec_rvq_depth) | 
					
						
						|  | context_tokens = tokens_full[:, :depth] | 
					
						
						|  |  | 
					
						
						|  | s = self.mrt.init_state() | 
					
						
						|  | s.context_tokens = context_tokens | 
					
						
						|  | self.state = s | 
					
						
						|  |  | 
					
						
						|  | self._model_stream = None | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | self._original_context_tokens = np.copy(context_tokens) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def _append_model_chunk_and_spool(self, wav: au.Waveform): | 
					
						
						|  | """Crossfade into the model-rate stream and write the *non-overlapped* | 
					
						
						|  | tail to the target-SR spool.""" | 
					
						
						|  | s = wav.samples.astype(np.float32, copy=False) | 
					
						
						|  | if s.ndim == 1: | 
					
						
						|  | s = s[:, None] | 
					
						
						|  | sr = self._model_sr | 
					
						
						|  | xfade_s = float(self.mrt.config.crossfade_length) | 
					
						
						|  | xfade_n = int(round(max(0.0, xfade_s) * sr)) | 
					
						
						|  |  | 
					
						
						|  | if self._model_stream is None: | 
					
						
						|  |  | 
					
						
						|  | new_part = s[xfade_n:] if xfade_n < s.shape[0] else s[:0] | 
					
						
						|  | self._model_stream = new_part.copy() | 
					
						
						|  | if new_part.size: | 
					
						
						|  | y = (new_part.astype(np.float32, copy=False) | 
					
						
						|  | if self._rs is None else | 
					
						
						|  | self._rs.process(new_part.astype(np.float32, copy=False), final=False)) | 
					
						
						|  | self._spool = np.concatenate([self._spool, y], axis=0) | 
					
						
						|  | self._spool_written += y.shape[0] | 
					
						
						|  | return | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if xfade_n > 0 and self._model_stream.shape[0] >= xfade_n and s.shape[0] >= xfade_n: | 
					
						
						|  | tail = self._model_stream[-xfade_n:] | 
					
						
						|  | head = s[:xfade_n] | 
					
						
						|  | t = np.linspace(0, np.pi/2, xfade_n, endpoint=False, dtype=np.float32)[:, None] | 
					
						
						|  | mixed = tail * np.cos(t) + head * np.sin(t) | 
					
						
						|  | self._model_stream = np.concatenate([self._model_stream[:-xfade_n], mixed, s[xfade_n:]], axis=0) | 
					
						
						|  | new_part = s[xfade_n:] | 
					
						
						|  | else: | 
					
						
						|  | self._model_stream = np.concatenate([self._model_stream, s], axis=0) | 
					
						
						|  | new_part = s | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if new_part.size: | 
					
						
						|  | y = (new_part.astype(np.float32, copy=False) | 
					
						
						|  | if self._rs is None else | 
					
						
						|  | self._rs.process(new_part.astype(np.float32, copy=False), final=False)) | 
					
						
						|  | if y.size: | 
					
						
						|  | self._spool = np.concatenate([self._spool, y], axis=0) | 
					
						
						|  | self._spool_written += y.shape[0] | 
					
						
						|  |  | 
					
						
						|  | def _should_generate_next_chunk(self) -> bool: | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | implicit_consumed = self._next_to_deliver - 1 | 
					
						
						|  | horizon_anchor = max(self._last_consumed_index, implicit_consumed) | 
					
						
						|  | return self.idx <= (horizon_anchor + self._max_buffer_ahead) | 
					
						
						|  |  | 
					
						
						|  | def _emit_ready(self): | 
					
						
						|  | """Emit next chunk(s) if the spool has enough samples.""" | 
					
						
						|  | while True: | 
					
						
						|  | start, end = self._bar_clock.bounds_for_chunk(self.idx, self.params.bars_per_chunk) | 
					
						
						|  | if end > self._spool_written: | 
					
						
						|  | break | 
					
						
						|  | loop = self._spool[start:end] | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if self.params.ref_loop is not None and self.params.loudness_mode != "none": | 
					
						
						|  | ref = self.params.ref_loop.as_stereo().resample(self.params.target_sr) | 
					
						
						|  | wav = au.Waveform(loop.copy(), int(self.params.target_sr)) | 
					
						
						|  | matched, _ = match_loudness_to_reference(ref, wav, method=self.params.loudness_mode, headroom_db=self.params.headroom_db) | 
					
						
						|  | loop = matched.samples | 
					
						
						|  |  | 
					
						
						|  | audio_b64, total_samples, channels = wav_bytes_base64(loop, int(self.params.target_sr)) | 
					
						
						|  | meta = { | 
					
						
						|  | "bpm": float(self.params.bpm), | 
					
						
						|  | "bars": int(self.params.bars_per_chunk), | 
					
						
						|  | "beats_per_bar": int(self.params.beats_per_bar), | 
					
						
						|  | "sample_rate": int(self.params.target_sr), | 
					
						
						|  | "channels": int(channels), | 
					
						
						|  | "total_samples": int(total_samples), | 
					
						
						|  | "seconds_per_bar": self._bar_clock.seconds_per_bar(), | 
					
						
						|  | "loop_duration_seconds": self.params.bars_per_chunk * self._bar_clock.seconds_per_bar(), | 
					
						
						|  | "guidance_weight": float(self.params.guidance_weight), | 
					
						
						|  | "temperature": float(self.params.temperature), | 
					
						
						|  | "topk": int(self.params.topk), | 
					
						
						|  | } | 
					
						
						|  | chunk = JamChunk(index=self.idx, audio_base64=audio_b64, metadata=meta) | 
					
						
						|  |  | 
					
						
						|  | with self._cv: | 
					
						
						|  | self._outbox[self.idx] = chunk | 
					
						
						|  | self._cv.notify_all() | 
					
						
						|  | self.idx += 1 | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | with self._lock: | 
					
						
						|  |  | 
					
						
						|  | if self._pending_token_splice is not None: | 
					
						
						|  | spliced = self._coerce_tokens(self._pending_token_splice["tokens"]) | 
					
						
						|  | try: | 
					
						
						|  |  | 
					
						
						|  | self.state.context_tokens = spliced | 
					
						
						|  | self._pending_token_splice = None | 
					
						
						|  | except Exception: | 
					
						
						|  |  | 
					
						
						|  | new_state = self.mrt.init_state() | 
					
						
						|  | new_state.context_tokens = spliced | 
					
						
						|  | self.state = new_state | 
					
						
						|  | self._model_stream = None | 
					
						
						|  | self._pending_token_splice = None | 
					
						
						|  | elif self._pending_reseed is not None: | 
					
						
						|  | ctx = self._coerce_tokens(self._pending_reseed["ctx"]) | 
					
						
						|  | new_state = self.mrt.init_state() | 
					
						
						|  | new_state.context_tokens = ctx | 
					
						
						|  | self.state = new_state | 
					
						
						|  | self._model_stream = None | 
					
						
						|  | self._pending_reseed = None | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def run(self): | 
					
						
						|  |  | 
					
						
						|  | while not self._stop_event.is_set(): | 
					
						
						|  |  | 
					
						
						|  | if not self._should_generate_next_chunk(): | 
					
						
						|  |  | 
					
						
						|  | self._emit_ready() | 
					
						
						|  | time.sleep(0.01) | 
					
						
						|  | continue | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | with self._lock: | 
					
						
						|  | target = self.params.style_vec | 
					
						
						|  | if target is None: | 
					
						
						|  | style_to_use = None | 
					
						
						|  | else: | 
					
						
						|  | if self._style_vec is None: | 
					
						
						|  | self._style_vec = np.array(target, dtype=np.float32, copy=True) | 
					
						
						|  | else: | 
					
						
						|  | ramp = float(self.params.style_ramp_seconds or 0.0) | 
					
						
						|  | step = 1.0 if ramp <= 0.0 else min(1.0, self._chunk_secs / ramp) | 
					
						
						|  |  | 
					
						
						|  | self._style_vec += step * (target.astype(np.float32, copy=False) - self._style_vec) | 
					
						
						|  | style_to_use = self._style_vec | 
					
						
						|  |  | 
					
						
						|  | wav, self.state = self.mrt.generate_chunk(state=self.state, style=style_to_use) | 
					
						
						|  |  | 
					
						
						|  | self._append_model_chunk_and_spool(wav) | 
					
						
						|  |  | 
					
						
						|  | self._emit_ready() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | tail = self._rs.process(np.zeros((0,2), np.float32), final=True) | 
					
						
						|  | if tail.size: | 
					
						
						|  | self._spool = np.concatenate([self._spool, tail], axis=0) | 
					
						
						|  | self._spool_written += tail.shape[0] | 
					
						
						|  |  | 
					
						
						|  | self._emit_ready() | 
					
						
						|  |  |