|  |  | 
					
						
import threading, time, base64, io, uuid
from dataclasses import dataclass, field
from threading import RLock
from typing import Any

import numpy as np
import soundfile as sf
from magenta_rt import audio as au

from utils import (
    match_loudness_to_reference, stitch_generated, hard_trim_seconds,
    apply_micro_fades, make_bar_aligned_context, take_bar_aligned_tail,
    resample_and_snap, wav_bytes_base64
)
					
						
						|  |  | 
					
						
						|  | @dataclass | 
					
						
						|  | class JamParams: | 
					
						
						|  | bpm: float | 
					
						
						|  | beats_per_bar: int | 
					
						
						|  | bars_per_chunk: int | 
					
						
						|  | target_sr: int | 
					
						
						|  | loudness_mode: str = "auto" | 
					
						
						|  | headroom_db: float = 1.0 | 
					
						
						|  | style_vec: np.ndarray | None = None | 
					
						
						|  | ref_loop: any = None | 
					
						
						|  | combined_loop: any = None | 
					
						
						|  | guidance_weight: float = 1.1 | 
					
						
						|  | temperature: float = 1.1 | 
					
						
						|  | topk: int = 40 | 
					
						
						|  |  | 
					
						
						|  | @dataclass | 
					
						
						|  | class JamChunk: | 
					
						
						|  | index: int | 
					
						
						|  | audio_base64: str | 
					
						
						|  | metadata: dict | 
					
						
						|  |  | 
					
						
						|  | class JamWorker(threading.Thread): | 
					
						
						|  | def __init__(self, mrt, params: JamParams): | 
					
						
						|  | super().__init__(daemon=True) | 
					
						
						|  | self.mrt = mrt | 
					
						
						|  | self.params = params | 
					
						
						|  | self.state = mrt.init_state() | 
					
						
						|  |  | 
					
						
						|  | if params.combined_loop is not None: | 
					
						
						|  | self._setup_context_from_combined_loop() | 
					
						
						|  |  | 
					
						
						|  | self.idx = 0 | 
					
						
						|  | self.outbox: list[JamChunk] = [] | 
					
						
						|  | self._stop_event = threading.Event() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | self._last_delivered_index = 0 | 
					
						
						|  | self._max_buffer_ahead = 5 | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | self.last_chunk_started_at = None | 
					
						
						|  | self.last_chunk_completed_at = None | 
					
						
						|  | self._lock = threading.Lock() | 
					
						
						|  |  | 
					
						
						|  | def _setup_context_from_combined_loop(self): | 
					
						
						|  | """Set up MRT context tokens from the combined loop audio""" | 
					
						
						|  | try: | 
					
						
						|  | from utils import make_bar_aligned_context, take_bar_aligned_tail | 
					
						
						|  |  | 
					
						
						|  | codec_fps = float(self.mrt.codec.frame_rate) | 
					
						
						|  | ctx_seconds = float(self.mrt.config.context_length_frames) / codec_fps | 
					
						
						|  |  | 
					
						
						|  | loop_for_context = take_bar_aligned_tail( | 
					
						
						|  | self.params.combined_loop, | 
					
						
						|  | self.params.bpm, | 
					
						
						|  | self.params.beats_per_bar, | 
					
						
						|  | ctx_seconds | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | tokens_full = self.mrt.codec.encode(loop_for_context).astype(np.int32) | 
					
						
						|  | tokens = tokens_full[:, :self.mrt.config.decoder_codec_rvq_depth] | 
					
						
						|  |  | 
					
						
						|  | context_tokens = make_bar_aligned_context( | 
					
						
						|  | tokens, | 
					
						
						|  | bpm=self.params.bpm, | 
					
						
						|  | fps=int(self.mrt.codec.frame_rate), | 
					
						
						|  | ctx_frames=self.mrt.config.context_length_frames, | 
					
						
						|  | beats_per_bar=self.params.beats_per_bar | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | self.state.context_tokens = context_tokens | 
					
						
						|  | print(f"β
 JamWorker: Set up fresh context from combined loop") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | with self._lock: | 
					
						
						|  | if not hasattr(self, "_original_context_tokens") or self._original_context_tokens is None: | 
					
						
						|  | self._original_context_tokens = np.copy(context_tokens) | 
					
						
						|  |  | 
					
						
						|  | except Exception as e: | 
					
						
						|  | print(f"β Failed to setup context from combined loop: {e}") | 
					
						
						|  |  | 
					
						
						|  | def stop(self): | 
					
						
						|  | self._stop_event.set() | 
					
						
						|  |  | 
					
						
						|  | def update_knobs(self, *, guidance_weight=None, temperature=None, topk=None): | 
					
						
						|  | with self._lock: | 
					
						
						|  | if guidance_weight is not None: self.params.guidance_weight = float(guidance_weight) | 
					
						
						|  | if temperature is not None:     self.params.temperature     = float(temperature) | 
					
						
						|  | if topk is not None:            self.params.topk            = int(topk) | 
					
						
						|  |  | 
					
						
						|  | def get_next_chunk(self) -> JamChunk | None: | 
					
						
						|  | """Get the next sequential chunk (blocks/waits if not ready)""" | 
					
						
						|  | target_index = self._last_delivered_index + 1 | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | max_wait = 30.0 | 
					
						
						|  | start_time = time.time() | 
					
						
						|  |  | 
					
						
						|  | while time.time() - start_time < max_wait and not self._stop_event.is_set(): | 
					
						
						|  | with self._lock: | 
					
						
						|  |  | 
					
						
						|  | for chunk in self.outbox: | 
					
						
						|  | if chunk.index == target_index: | 
					
						
						|  | self._last_delivered_index = target_index | 
					
						
						|  | print(f"π¦ Delivered chunk {target_index}") | 
					
						
						|  | return chunk | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | time.sleep(0.1) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | return None | 
					
						
						|  |  | 
					
						
						|  | def mark_chunk_consumed(self, chunk_index: int): | 
					
						
						|  | """Mark a chunk as consumed by the frontend""" | 
					
						
						|  | with self._lock: | 
					
						
						|  | self._last_delivered_index = max(self._last_delivered_index, chunk_index) | 
					
						
						|  | print(f"β
 Chunk {chunk_index} consumed") | 
					
						
						|  |  | 
					
						
						|  | def _should_generate_next_chunk(self) -> bool: | 
					
						
						|  | """Check if we should generate the next chunk (don't get too far ahead)""" | 
					
						
						|  | with self._lock: | 
					
						
						|  |  | 
					
						
						|  | if self.idx > self._last_delivered_index + self._max_buffer_ahead: | 
					
						
						|  | return False | 
					
						
						|  | return True | 
					
						
						|  |  | 
					
						
						|  | def _seconds_per_bar(self) -> float: | 
					
						
						|  | return self.params.beats_per_bar * (60.0 / self.params.bpm) | 
					
						
						|  |  | 
					
						
						|  | def _snap_and_encode(self, y, seconds, target_sr, bars): | 
					
						
						|  | cur_sr = int(self.mrt.sample_rate) | 
					
						
						|  | x = y.samples if y.samples.ndim == 2 else y.samples[:, None] | 
					
						
						|  | x = resample_and_snap(x, cur_sr=cur_sr, target_sr=target_sr, seconds=seconds) | 
					
						
						|  | b64, total_samples, channels = wav_bytes_base64(x, target_sr) | 
					
						
						|  | meta = { | 
					
						
						|  | "bpm": int(round(self.params.bpm)), | 
					
						
						|  | "bars": int(bars), | 
					
						
						|  | "beats_per_bar": int(self.params.beats_per_bar), | 
					
						
						|  | "sample_rate": int(target_sr), | 
					
						
						|  | "channels": channels, | 
					
						
						|  | "total_samples": total_samples, | 
					
						
						|  | "seconds_per_bar": self._seconds_per_bar(), | 
					
						
						|  | "loop_duration_seconds": bars * self._seconds_per_bar(), | 
					
						
						|  | "guidance_weight": self.params.guidance_weight, | 
					
						
						|  | "temperature": self.params.temperature, | 
					
						
						|  | "topk": self.params.topk, | 
					
						
						|  | } | 
					
						
						|  | return b64, meta | 
					
						
						|  |  | 
					
						
						|  | def _append_model_chunk_to_stream(self, wav): | 
					
						
						|  | """Incrementally append a model chunk with equal-power crossfade.""" | 
					
						
						|  | xfade_s = float(self.mrt.config.crossfade_length) | 
					
						
						|  | sr = int(self.mrt.sample_rate) | 
					
						
						|  | xfade_n = int(round(xfade_s * sr)) | 
					
						
						|  |  | 
					
						
						|  | s = wav.samples if wav.samples.ndim == 2 else wav.samples[:, None] | 
					
						
						|  |  | 
					
						
						|  | if getattr(self, "_stream", None) is None: | 
					
						
						|  |  | 
					
						
						|  | if s.shape[0] > xfade_n: | 
					
						
						|  | self._stream = s[xfade_n:].astype(np.float32, copy=True) | 
					
						
						|  | else: | 
					
						
						|  | self._stream = np.zeros((0, s.shape[1]), dtype=np.float32) | 
					
						
						|  | self._next_emit_start = 0 | 
					
						
						|  | return | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if s.shape[0] <= xfade_n or self._stream.shape[0] < xfade_n: | 
					
						
						|  |  | 
					
						
						|  | self._stream = np.concatenate([self._stream, s], axis=0) | 
					
						
						|  | return | 
					
						
						|  |  | 
					
						
						|  | tail = self._stream[-xfade_n:] | 
					
						
						|  | head = s[:xfade_n] | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | t = np.linspace(0, np.pi/2, xfade_n, endpoint=False, dtype=np.float32)[:, None] | 
					
						
						|  | eq_in, eq_out = np.sin(t), np.cos(t) | 
					
						
						|  | mixed = tail * eq_out + head * eq_in | 
					
						
						|  |  | 
					
						
						|  | self._stream = np.concatenate([self._stream[:-xfade_n], mixed, s[xfade_n:]], axis=0) | 
					
						
						|  |  | 
					
						
						|  | def reseed_from_waveform(self, wav): | 
					
						
						|  |  | 
					
						
						|  | new_state = self.mrt.init_state() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | codec_fps   = float(self.mrt.codec.frame_rate) | 
					
						
						|  | ctx_seconds = float(self.mrt.config.context_length_frames) / codec_fps | 
					
						
						|  | from utils import take_bar_aligned_tail, make_bar_aligned_context | 
					
						
						|  |  | 
					
						
						|  | tail = take_bar_aligned_tail(wav, self.params.bpm, self.params.beats_per_bar, ctx_seconds) | 
					
						
						|  | tokens_full = self.mrt.codec.encode(tail).astype(np.int32) | 
					
						
						|  | tokens = tokens_full[:, :self.mrt.config.decoder_codec_rvq_depth] | 
					
						
						|  | context_tokens = make_bar_aligned_context(tokens, | 
					
						
						|  | bpm=self.params.bpm, fps=int(self.mrt.codec.frame_rate), | 
					
						
						|  | ctx_frames=self.mrt.config.context_length_frames, | 
					
						
						|  | beats_per_bar=self.params.beats_per_bar | 
					
						
						|  | ) | 
					
						
						|  | new_state.context_tokens = context_tokens | 
					
						
						|  | self.state = new_state | 
					
						
						|  | self._prepare_stream_for_reseed_handoff() | 
					
						
						|  |  | 
					
						
						|  | def _frames_per_bar(self) -> int: | 
					
						
						|  |  | 
					
						
						|  | fps = float(self.mrt.codec.frame_rate) | 
					
						
						|  | sec_per_bar = (60.0 / float(self.params.bpm)) * float(self.params.beats_per_bar) | 
					
						
						|  | return int(round(fps * sec_per_bar)) | 
					
						
						|  |  | 
					
						
						|  | def _ctx_frames(self) -> int: | 
					
						
						|  |  | 
					
						
						|  | return int(self.mrt.config.context_length_frames) | 
					
						
						|  |  | 
					
						
						|  | def _make_recent_tokens_from_wave(self, wav) -> np.ndarray: | 
					
						
						|  | """ | 
					
						
						|  | Encode a waveform and produce a bar-aligned context token window (same shape/depth | 
					
						
						|  | as state.context_tokens). Uses your existing codec depth. | 
					
						
						|  | """ | 
					
						
						|  | tokens_full = self.mrt.codec.encode(wav).astype(np.int32) | 
					
						
						|  | tokens      = tokens_full[:, :self.mrt.config.decoder_codec_rvq_depth] | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | t = tokens.shape[0] | 
					
						
						|  | ctx = self._ctx_frames() | 
					
						
						|  | if t > ctx: | 
					
						
						|  | tokens = tokens[-ctx:] | 
					
						
						|  | return tokens | 
					
						
						|  |  | 
					
						
						|  | def _bar_aligned_tail(self, tokens: np.ndarray, bars: float) -> np.ndarray: | 
					
						
						|  | """ | 
					
						
						|  | Take a tail slice that is an integer number of codec frames corresponding to `bars`. | 
					
						
						|  | We round to nearest frame to stay phase-consistent with codec grid. | 
					
						
						|  | """ | 
					
						
						|  | frames_per_bar = self._frames_per_bar() | 
					
						
						|  | want = max(frames_per_bar * int(round(bars)), 0) | 
					
						
						|  | if want == 0: | 
					
						
						|  | return tokens[:0] | 
					
						
						|  | if tokens.shape[0] <= want: | 
					
						
						|  | return tokens | 
					
						
						|  | return tokens[-want:] | 
					
						
						|  |  | 
					
						
						|  | def _splice_context(self, original_tokens: np.ndarray, recent_tokens: np.ndarray, | 
					
						
						|  | anchor_bars: float) -> np.ndarray: | 
					
						
						|  | """ | 
					
						
						|  | Build new context by concatenating: | 
					
						
						|  | anchor = tail from originals (anchor_bars) | 
					
						
						|  | recent = tail from recent_tokens filling the remainder | 
					
						
						|  | Then clamp to ctx_frames from the tail (safety). | 
					
						
						|  | """ | 
					
						
						|  | ctx_frames = self._ctx_frames() | 
					
						
						|  | depth = original_tokens.shape[1] | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | anchor = self._bar_aligned_tail(original_tokens, anchor_bars) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | a = anchor.shape[0] | 
					
						
						|  | remain = max(ctx_frames - a, 0) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if remain > 0: | 
					
						
						|  |  | 
					
						
						|  | frames_per_bar = self._frames_per_bar() | 
					
						
						|  | recent_bars_fit = int(remain // frames_per_bar) | 
					
						
						|  |  | 
					
						
						|  | if recent_bars_fit >= 1: | 
					
						
						|  | want_recent_frames = recent_bars_fit * frames_per_bar | 
					
						
						|  | recent = recent_tokens[-want_recent_frames:] if recent_tokens.shape[0] > want_recent_frames else recent_tokens | 
					
						
						|  | else: | 
					
						
						|  | recent = recent_tokens[-remain:] if recent_tokens.shape[0] > remain else recent_tokens | 
					
						
						|  | else: | 
					
						
						|  | recent = recent_tokens[:0] | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | out = np.concatenate([anchor, recent], axis=0) if anchor.size or recent.size else recent_tokens[-ctx_frames:] | 
					
						
						|  | if out.shape[0] > ctx_frames: | 
					
						
						|  | out = out[-ctx_frames:] | 
					
						
						|  |  | 
					
						
						|  | if out.shape[1] != depth: | 
					
						
						|  | out = out[:, :depth] | 
					
						
						|  | return out | 
					
						
						|  |  | 
					
						
						|  | def _prepare_stream_for_reseed_handoff(self): | 
					
						
						|  | """ | 
					
						
						|  | Keep only a tiny tail to crossfade against the FIRST post-reseed chunk. | 
					
						
						|  | Reset the emit pointer so the next emitted window starts fresh. | 
					
						
						|  | """ | 
					
						
						|  | sr = int(self.mrt.sample_rate) | 
					
						
						|  | xfade_s = float(self.mrt.config.crossfade_length) | 
					
						
						|  | xfade_n = int(round(xfade_s * sr)) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if getattr(self, "_stream", None) is not None and self._stream.shape[0] > 0: | 
					
						
						|  | tail = self._stream[-xfade_n:] if self._stream.shape[0] > xfade_n else self._stream | 
					
						
						|  | self._stream = tail.copy() | 
					
						
						|  | else: | 
					
						
						|  | self._stream = None | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | self._next_emit_start = 0 | 
					
						
						|  |  | 
					
						
						|  | def reseed_splice(self, recent_wav, anchor_bars: float): | 
					
						
						|  | """ | 
					
						
						|  | Token-splice reseed: | 
					
						
						|  | - original = the context we captured when the jam started | 
					
						
						|  | - recent   = tokens from the provided recent waveform (usually Swift-combined mix) | 
					
						
						|  | - anchor_bars controls how much of the original vibe we re-inject | 
					
						
						|  | """ | 
					
						
						|  | with self._lock: | 
					
						
						|  | if not hasattr(self, "_original_context_tokens") or self._original_context_tokens is None: | 
					
						
						|  |  | 
					
						
						|  | self._original_context_tokens = np.copy(self.state.context_tokens) | 
					
						
						|  |  | 
					
						
						|  | recent_tokens = self._make_recent_tokens_from_wave(recent_wav) | 
					
						
						|  | new_ctx = self._splice_context(self._original_context_tokens, recent_tokens, anchor_bars) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | new_state = self.mrt.init_state() | 
					
						
						|  | new_state.context_tokens = new_ctx | 
					
						
						|  | self.state = new_state | 
					
						
						|  |  | 
					
						
						|  | self._prepare_stream_for_reseed_handoff() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | self._pending_drop_intro_bars = getattr(self, "_pending_drop_intro_bars", 0) + 1 | 
					
						
						|  |  | 
					
						
						|  | def run(self): | 
					
						
						|  | """Continuous stream + sliding 8-bar window emitter.""" | 
					
						
						|  | sr_model = int(self.mrt.sample_rate) | 
					
						
						|  | spb = self._seconds_per_bar() | 
					
						
						|  | chunk_secs = float(self.params.bars_per_chunk) * spb | 
					
						
						|  | chunk_n_model = int(round(chunk_secs * sr_model)) | 
					
						
						|  | xfade = self.mrt.config.crossfade_length | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | self._stream = None | 
					
						
						|  | self._next_emit_start = 0 | 
					
						
						|  |  | 
					
						
						|  | print("π JamWorker (streaming) started...") | 
					
						
						|  |  | 
					
						
						|  | while not self._stop_event.is_set(): | 
					
						
						|  |  | 
					
						
						|  | with self._lock: | 
					
						
						|  | if self.idx > self._last_delivered_index + self._max_buffer_ahead: | 
					
						
						|  | time.sleep(0.25) | 
					
						
						|  | continue | 
					
						
						|  | style_vec = self.params.style_vec | 
					
						
						|  | self.mrt.guidance_weight = self.params.guidance_weight | 
					
						
						|  | self.mrt.temperature     = self.params.temperature | 
					
						
						|  | self.mrt.topk            = self.params.topk | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | self.last_chunk_started_at = time.time() | 
					
						
						|  | wav, self.state = self.mrt.generate_chunk(state=self.state, style=style_vec) | 
					
						
						|  | self._append_model_chunk_to_stream(wav) | 
					
						
						|  | self.last_chunk_completed_at = time.time() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | while (getattr(self, "_stream", None) is not None and | 
					
						
						|  | self._stream.shape[0] - self._next_emit_start >= chunk_n_model and | 
					
						
						|  | not self._stop_event.is_set()): | 
					
						
						|  |  | 
					
						
						|  | seg = self._stream[self._next_emit_start:self._next_emit_start + chunk_n_model] | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | y = au.Waveform(seg.astype(np.float32, copy=False), sr_model).as_stereo() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | next_idx = self.idx + 1 | 
					
						
						|  | if next_idx == 1 and self.params.ref_loop is not None: | 
					
						
						|  | y, _ = match_loudness_to_reference( | 
					
						
						|  | self.params.ref_loop, y, | 
					
						
						|  | method=self.params.loudness_mode, | 
					
						
						|  | headroom_db=self.params.headroom_db | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | b64, meta = self._snap_and_encode( | 
					
						
						|  | y, seconds=chunk_secs, | 
					
						
						|  | target_sr=self.params.target_sr, | 
					
						
						|  | bars=self.params.bars_per_chunk | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | with self._lock: | 
					
						
						|  | self.idx = next_idx | 
					
						
						|  | self.outbox.append(JamChunk(index=next_idx, audio_base64=b64, metadata=meta)) | 
					
						
						|  |  | 
					
						
						|  | if len(self.outbox) > 10: | 
					
						
						|  | self.outbox = [ch for ch in self.outbox if ch.index > self._last_delivered_index - 5] | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | self._next_emit_start += chunk_n_model | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | keep_from = max(0, self._next_emit_start - chunk_n_model) | 
					
						
						|  | if keep_from > 0: | 
					
						
						|  | self._stream = self._stream[keep_from:] | 
					
						
						|  | self._next_emit_start -= keep_from | 
					
						
						|  |  | 
					
						
						|  | print("π JamWorker (streaming) stopped") |