# utils.py
"""Audio and token helpers.

Covers loudness matching, crossfade stitching, trims and fades, bar-aligned
context selection, sample-rate conversion and WAV/base64 encoding.
"""
from __future__ import annotations

import base64
import io
import math
from math import gcd

import numpy as np
import soundfile as sf
from scipy.signal import resample_poly

# Magenta RT audio types
from magenta_rt import audio as au

# Optional loudness
try:
    import pyloudnorm as pyln
    _HAS_LOUDNORM = True
except Exception:
    # Deliberately broad: any import-time failure simply disables LUFS matching.
    _HAS_LOUDNORM = False


# ---------- Loudness ----------
def _measure_lufs(wav: au.Waveform) -> float:
    """Return the integrated loudness of *wav* as measured by pyloudnorm."""
    meter = pyln.Meter(wav.sample_rate)  # BS.1770-4
    return float(meter.integrated_loudness(wav.samples))


def _rms(x: np.ndarray) -> float:
    """Return the root-mean-square of *x*, or 0.0 for an empty array."""
    if x.size == 0:
        return 0.0
    return float(np.sqrt(np.mean(x**2)))


def match_loudness_to_reference(
    ref: au.Waveform,
    target: au.Waveform,
    method: str = "auto",  # "auto"|"lufs"|"rms"|"none"
    headroom_db: float = 1.0
) -> tuple[au.Waveform, dict]:
    """Scale *target* so its loudness matches *ref*, then limit the peak.

    NOTE: ``target.samples`` is replaced in place and the same ``target``
    object is returned.

    Args:
        ref: Waveform whose loudness is the goal.
        target: Waveform to be scaled.
        method: ``"none"`` returns *target* untouched. ``"auto"`` picks
            ``"lufs"`` when pyloudnorm imported successfully, else ``"rms"``.
            ``"lufs"`` without pyloudnorm, and any other value, use RMS.
        headroom_db: Peak ceiling in dB below full scale; the whole signal is
            scaled down if the matched peak exceeds it.

    Returns:
        ``(target, stats)``. ``stats`` always has ``"method"`` (the value
        passed in) and ``"applied_gain_db"``; the measurement keys and
        ``"post_peak_limited"`` are present only when a gain was applied.
    """
    stats = {"method": method, "applied_gain_db": 0.0}
    if method == "none":
        return target, stats
    if method == "auto":
        method = "lufs" if _HAS_LOUDNORM else "rms"

    y = None
    if method == "lufs" and _HAS_LOUDNORM:
        try:
            L_ref = _measure_lufs(ref)
            L_tgt = _measure_lufs(target)
        except ValueError:
            # pyloudnorm rejects some inputs (e.g. clips that are too short);
            # treat that as "no usable measurement" and fall through to RMS.
            L_ref = L_tgt = float("nan")
        delta_db = L_ref - L_tgt
        # Silence measures as -inf LUFS, which would turn the gain into
        # inf/nan and poison the samples. Only apply finite corrections.
        if math.isfinite(delta_db):
            gain = 10.0 ** (delta_db / 20.0)
            y = target.samples.astype(np.float32) * gain
            stats.update({
                "ref_lufs": L_ref,
                "tgt_lufs_before": L_tgt,
                "applied_gain_db": delta_db,
            })

    if y is None:
        ra = _rms(ref.samples)
        rb = _rms(target.samples)
        if rb <= 1e-12:
            # Target is effectively silent: no meaningful gain exists.
            return target, stats
        gain = ra / rb
        y = target.samples.astype(np.float32) * gain
        stats.update({
            "ref_rms": ra,
            "tgt_rms_before": rb,
            "applied_gain_db": 20.0 * math.log10(max(gain, 1e-12)),
        })

    # simple peak "limiter" to keep headroom
    limit = 10 ** (-headroom_db / 20.0)  # e.g., -1 dBFS
    peak = float(np.max(np.abs(y))) if y.size else 0.0
    limited = peak > limit
    if limited:
        y *= (limit / peak)
    stats["post_peak_limited"] = limited

    target.samples = y.astype(np.float32)
    return target, stats


# ---------- Stitch / fades / trims ----------
def stitch_generated(chunks, sr: int, xfade_s: float, drop_first_pre_roll: bool = True):
    """Join generated chunks with an equal-power crossfade.

    The first ``round(xfade_s * sr)`` samples of every chunk after the first
    are blended (sin/cos curves) with the tail of the audio stitched so far.

    Args:
        chunks: Sequence of waveforms exposing ``.samples``.
        sr: Sample rate used for the crossfade length and the result.
        xfade_s: Crossfade duration in seconds; ``<= 0`` means plain
            concatenation.
        drop_first_pre_roll: When true, the crossfade-length prefix of the
            first chunk is discarded instead of kept.

    Returns:
        A new ``au.Waveform`` at sample rate *sr*.

    Raises:
        ValueError: If *chunks* is empty or the first chunk is shorter than
            the crossfade length.
    """
    if not chunks:
        raise ValueError("no chunks")

    xfade_n = int(round(xfade_s * sr))
    if xfade_n <= 0:
        return au.Waveform(np.concatenate([c.samples for c in chunks], axis=0), sr)

    first = chunks[0].samples
    if first.shape[0] < xfade_n:
        raise ValueError("chunk shorter than crossfade prefix")

    # Equal-power curves, shaped to broadcast over any trailing (channel)
    # axes so both (S, C) and 1-D mono sample arrays work.
    t = np.linspace(0, np.pi / 2, xfade_n, endpoint=False, dtype=np.float32)
    curve_shape = (xfade_n,) + (1,) * (first.ndim - 1)
    eq_in = np.sin(t).reshape(curve_shape)
    eq_out = np.cos(t).reshape(curve_shape)

    # Optionally drop the first chunk's crossfade-length prefix (pre-roll).
    out = first[xfade_n:].copy() if drop_first_pre_roll else first.copy()

    for chunk in chunks[1:]:
        cur = chunk.samples
        if cur.shape[0] < xfade_n:
            # Too short to carry a crossfade prefix: skipped (best effort).
            continue
        head, tail = cur[:xfade_n], cur[xfade_n:]
        mixed = out[-xfade_n:] * eq_out + head * eq_in
        out = np.concatenate([out[:-xfade_n], mixed, tail], axis=0)

    return au.Waveform(out, sr)


def hard_trim_seconds(wav: au.Waveform, seconds: float) -> au.Waveform:
    """Return a new waveform holding at most the first *seconds* of *wav*."""
    # Clamp at zero: a negative count would slice from the end instead.
    n = max(0, int(round(seconds * wav.sample_rate)))
    return au.Waveform(wav.samples[:n], wav.sample_rate)


def apply_micro_fades(wav: au.Waveform, ms: int = 5) -> None:
    """Apply a linear fade-in and fade-out of *ms* milliseconds, in place.

    Nothing happens when the fade is zero samples long or the waveform is not
    longer than two fades.
    """
    n = int(wav.sample_rate * ms / 1000.0)
    if n > 0 and wav.samples.shape[0] > 2 * n:
        env = np.linspace(0.0, 1.0, n, dtype=np.float32)
        # Broadcast over trailing (channel) axes; also handles 1-D mono.
        env = env.reshape((n,) + (1,) * (wav.samples.ndim - 1))
        wav.samples[:n] *= env
        wav.samples[-n:] *= env[::-1]


# ---------- Token context helpers ----------
def make_bar_aligned_context(tokens, bpm, fps=25.0, ctx_frames=250, beats_per_bar=4):
    """
    Return a ctx_frames-long slice of `tokens` whose **end** lands on a
    whole-bar boundary in codec-frame space, even when frames_per_bar is
    fractional.

    `tokens` is tiled (repeated) until it is comfortably longer than the
    window; the end index is the last whole-bar boundary that fits in the
    tiled sequence, rounded to the nearest frame.

    tokens: np.ndarray of shape (T, D) or (T,) where T = codec frames
            (a 1-D input is promoted to (T, 1), and the result is 2-D)
    bpm: float
    fps: float (codec frames per second; keep this as float)
    ctx_frames: int (length of context window in codec frames)
    beats_per_bar: int

    Raises ValueError if `tokens` is None. An input with T == 0 is returned
    as-is.
    """
    if tokens is None:
        raise ValueError("tokens is None")
    tokens = np.asarray(tokens)
    if tokens.ndim == 1:
        tokens = tokens[:, None]  # promote to (T, 1) for uniform tiling

    T = tokens.shape[0]
    if T == 0:
        return tokens

    fps = float(fps)
    frames_per_bar_f = (beats_per_bar * 60.0 / float(bpm)) * fps  # float frames per bar

    # Tile a little more than we need so we can always snap the END to a bar boundary
    reps = int(np.ceil((ctx_frames + T) / float(T))) + 1
    tiled = np.tile(tokens, (reps, 1))
    total = tiled.shape[0]

    # How many whole bars fit?
    k_bars = int(np.floor(total / frames_per_bar_f))
    if k_bars <= 0:
        # Fallback: just take the last ctx_frames
        window = tiled[-ctx_frames:]
        return window

    # Snap END index to the nearest integer frame at a whole-bar boundary
    end_idx = int(round(k_bars * frames_per_bar_f))
    end_idx = min(max(end_idx, ctx_frames), total)
    start_idx = end_idx - ctx_frames
    if start_idx < 0:
        start_idx = 0
        end_idx = ctx_frames

    window = tiled[start_idx:end_idx]

    # Guard against rare off-by-one due to rounding
    if window.shape[0] < ctx_frames:
        pad = np.tile(tokens, (int(np.ceil((ctx_frames - window.shape[0]) / T)), 1))
        window = np.vstack([window, pad])[:ctx_frames]
    elif window.shape[0] > ctx_frames:
        window = window[-ctx_frames:]

    return window


def take_bar_aligned_tail(
    wav: au.Waveform,
    bpm: float,
    beats_per_bar: int,
    ctx_seconds: float,
    max_bars=None
) -> au.Waveform:
    """
    Take a tail whose length is an integer number of bars, with the END
    aligned to the end of the audio.

    Uses ceil for bars_needed so we never under-fill the context. When
    `max_bars` is given it caps the bar count, but at least one bar is always
    requested. If the audio is not longer than the requested tail, `wav`
    itself is returned unchanged.
    """
    # seconds per bar
    spb = (60.0 / float(bpm)) * float(beats_per_bar)

    # Pick enough whole bars to cover ctx_seconds (avoid underfilling on round-down).
    # The small epsilon avoids an extra bar due to FP jitter when ctx_seconds ~= k * spb.
    eps = 1e-9
    bars_needed = max(1, int(math.ceil((float(ctx_seconds) - eps) / spb)))
    if max_bars is not None:
        # Keep at least one bar: a cap <= 0 would otherwise yield an empty tail.
        bars_needed = max(1, min(bars_needed, int(max_bars)))

    # Convert bars -> samples (do rounding once at the end for stability)
    samples_per_bar_f = spb * float(wav.sample_rate)
    n = int(round(bars_needed * samples_per_bar_f))

    total = int(wav.samples.shape[0])
    if n >= total:
        # Not enough audio to take that many bars: return as-is (current behavior).
        return wav

    start = total - n
    return au.Waveform(wav.samples[start:], wav.sample_rate)


# ---------- SR normalize + snap ----------
def _ratio(out_sr: int, in_sr: int) -> tuple[int, int]:
    """Return ``out_sr / in_sr`` reduced to lowest terms as ``(up, down)``."""
    g = gcd(int(out_sr), int(in_sr))
    return int(out_sr) // g, int(in_sr) // g


def resample_and_snap(x: np.ndarray, cur_sr: int, target_sr: int, seconds: float) -> np.ndarray:
    """
    Resample `x` to `target_sr` and force an exact length.

    x: np.ndarray shape (S, C), float32 (a 1-D input is promoted to (S, 1))
    Returns: exact-length float32 array (round(seconds*target_sr), C); the
    result is zero-padded at the end or truncated as needed.
    """
    if x.ndim == 1:
        x = x[:, None]
    # Nothing to resample for empty input; the padding below fills the length.
    if cur_sr != target_sr and x.shape[0] > 0:
        up, down = _ratio(target_sr, cur_sr)
        x = resample_poly(x, up, down, axis=0)

    expected_len = int(round(seconds * target_sr))
    if x.shape[0] < expected_len:
        pad = np.zeros((expected_len - x.shape[0], x.shape[1]), dtype=x.dtype)
        x = np.vstack([x, pad])
    elif x.shape[0] > expected_len:
        x = x[:expected_len, :]
    return x.astype(np.float32, copy=False)


# ---------- WAV encode ----------
def wav_bytes_base64(x: np.ndarray, sr: int) -> tuple[str, int, int]:
    """
    Encode `x` as a 32-bit float WAV and return it base64-encoded.

    x: np.ndarray shape (S, C); a 1-D (S,) array is reported as 1 channel
    returns: (base64_wav, total_samples, channels)
    """
    buf = io.BytesIO()
    sf.write(buf, x, sr, subtype="FLOAT", format="WAV")
    b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
    channels = int(x.shape[1]) if x.ndim > 1 else 1
    return b64, int(x.shape[0]), channels


class StreamingResampler:
    """
    Stateful streaming resampler.

    Prefers soxr (best), then libsamplerate; final fallback is block
    resample_poly. Always pass float32 arrays shaped (S, C).

    The chosen backend name ("soxr", "samplerate" or "scipy") is stored in
    ``self._backend``.
    """

    def __init__(self, in_sr: int, out_sr: int, channels: int = 2, quality: str = "VHQ"):
        self.in_sr = int(in_sr)
        self.out_sr = int(out_sr)
        self.channels = int(channels)
        self.quality = quality
        self._backend = None

        # Try soxr first
        try:
            import soxr  # pip install soxr
            # python-soxr's streaming class is ResampleStream; dtype float32
            # keeps things consistent with the rest of the code.
            self._rs = soxr.ResampleStream(
                self.in_sr,
                self.out_sr,
                self.channels,
                dtype="float32",
                quality=self.quality,  # e.g. "HQ", "VHQ"
            )
            self._backend = "soxr"
        except Exception:
            # Try libsamplerate
            try:
                import samplerate  # pip install samplerate
                # sinc_best == highest quality; 'sinc_medium' trades quality for speed
                self._rs = samplerate.Resampler(converter_type="sinc_best", channels=self.channels)
                self._backend = "samplerate"
            except Exception:
                # Last resort: block resample (not truly streaming)
                self._backend = "scipy"
                self._resample_poly = resample_poly
                self._L, self._M = _ratio(self.out_sr, self.in_sr)
                # Keep a tiny tail to help transitions (still not perfect vs true streaming)
                self._hist = np.zeros((0, self.channels), dtype=np.float32)

    def _empty(self) -> np.ndarray:
        """Return a zero-length float32 chunk with this stream's channel count."""
        return np.zeros((0, self.channels), dtype=np.float32)

    def process(self, x: np.ndarray, final: bool = False) -> np.ndarray:
        """Feed a chunk (S, C) and get resampled chunk (S', C). Keep calling in order.

        Pass ``final=True`` on the last chunk so the backend can flush its tail.
        """
        if x.size == 0 and not final:
            # nothing to do
            return self._empty()

        if self._backend == "soxr":
            # The stream was created with dtype float32 and expects matching input.
            return self._rs.resample_chunk(np.asarray(x, dtype=np.float32), last=final)

        if self._backend == "samplerate":
            ratio = float(self.out_sr) / float(self.in_sr)
            # end_of_input=True flushes tail on the last call
            y = self._rs.process(x, ratio, end_of_input=final)
            # libsamplerate returns (S', C)
            return y.astype(np.float32, copy=False)

        # --- scipy fallback (block, not truly streaming) ---
        if x.size == 0:
            # Final call without new audio: the history was already emitted
            # with the previous chunk, so there is nothing left to output.
            self._hist = self._empty()
            return self._empty()

        # We concatenate a short history to reduce block edge artifacts
        x_ext = x if self._hist.size == 0 else np.vstack([self._hist, x])
        y = self._resample_poly(x_ext, up=self._L, down=self._M, axis=0).astype(np.float32, copy=False)

        # Heuristic: drop the portion corresponding roughly to the history to avoid duplicate content
        # (Not perfect, but helps a lot when chunks are reasonably sized.)
        drop = int(round(self._hist.shape[0] * self.out_sr / self.in_sr))
        y = y[drop:] if drop < y.shape[0] else self._empty()

        # Keep a small input tail for the next call (say ~ 4 ms at in_sr).
        # Copy it so later writes to the caller's buffer cannot alter our history.
        tail_samples = max(int(0.004 * self.in_sr), 1)
        self._hist = x[-tail_samples:].copy()
        if final:
            self._hist = self._empty()
        return y

    def flush(self) -> np.ndarray:
        """Drain converter tail (call at stop)."""
        if self._backend == "soxr":
            return self._rs.resample_chunk(self._empty(), last=True)
        if self._backend == "samplerate":
            ratio = float(self.out_sr) / float(self.in_sr)
            y = self._rs.process(self._empty(), ratio, end_of_input=True)
            return y.astype(np.float32, copy=False)
        # nothing meaningful to flush in scipy fallback
        return self._empty()