thecollabagepatch commited on
Commit
d1afbc8
·
1 Parent(s): 2f6eca9

initial commit

Browse files
Files changed (4) hide show
  1. Dockerfile +138 -0
  2. app.py +436 -0
  3. jam_worker.py +231 -0
  4. utils.py +168 -0
Dockerfile ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # thecollabagepatch/magenta:latest
2
+ FROM nvidia/cuda:12.6.2-cudnn-runtime-ubuntu22.04
3
+
4
+ # CUDA libs present + on loader path
5
+ RUN apt-get update && apt-get install -y --no-install-recommends \
6
+ cuda-libraries-12-4 && rm -rf /var/lib/apt/lists/*
7
+ ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64:/usr/local/cuda-12.4/lib64:/usr/local/cuda-12.4/compat:/usr/local/cuda/targets/x86_64-linux/lib:${LD_LIBRARY_PATH}
8
+ RUN ln -sf /usr/local/cuda/targets/x86_64-linux/lib /usr/local/cuda/lib64 || true
9
+
10
+ # Ensure the NVIDIA repo key is present (non-interactive) and install cuDNN 9.8
11
+ RUN set -eux; \
12
+ apt-get update && apt-get install -y --no-install-recommends gnupg ca-certificates curl; \
13
+ install -d -m 0755 /usr/share/keyrings; \
14
+ # Refresh the *same* keyring the base source uses (no second source file)
15
+ curl -fsSL https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/3bf863cc.pub \
16
+ | gpg --batch --yes --dearmor -o /usr/share/keyrings/cuda-archive-keyring.gpg; \
17
+ apt-get update; \
18
+ # If libcudnn is "held", unhold it so we can move to 9.8
19
+ apt-mark unhold libcudnn9-cuda-12 || true; \
20
+ # Install cuDNN 9.8 for CUDA 12 (correct dev package name!)
21
+ apt-get install -y --no-install-recommends \
22
+ 'libcudnn9-cuda-12=9.8.*' \
23
+ 'libcudnn9-dev-cuda-12=9.8.*' \
24
+ --allow-downgrades --allow-change-held-packages; \
25
+ apt-mark hold libcudnn9-cuda-12 || true; \
26
+ ldconfig; \
27
+ rm -rf /var/lib/apt/lists/*
28
+
29
+ # (optional) preload workaround if still needed
30
+ ENV LD_PRELOAD=/usr/local/cuda/lib64/libcusparse.so.12:/usr/local/cuda/lib64/libcublas.so.12:/usr/local/cuda/lib64/libcublasLt.so.12:/usr/local/cuda/lib64/libcufft.so.11:/usr/local/cuda/lib64/libcusolver.so.11
31
+
32
+ ENV DEBIAN_FRONTEND=noninteractive \
33
+ PYTHONUNBUFFERED=1 \
34
+ PIP_NO_CACHE_DIR=1 \
35
+ TF_FORCE_GPU_ALLOW_GROWTH=true \
36
+ XLA_PYTHON_CLIENT_PREALLOCATE=false
37
+
38
+ ENV JAX_PLATFORMS=""
39
+
40
+ # --- OS deps ---
41
+ RUN apt-get update && apt-get install -y --no-install-recommends \
42
+ software-properties-common curl ca-certificates git \
43
+ libsndfile1 ffmpeg \
44
+ build-essential pkg-config \
45
+ && add-apt-repository ppa:deadsnakes/ppa -y \
46
+ && apt-get update && apt-get install -y --no-install-recommends \
47
+ python3.11 python3.11-venv python3.11-distutils python3-pip \
48
+ && rm -rf /var/lib/apt/lists/*
49
+
50
+ # Make python3 => 3.11 for convenience
51
+ RUN ln -sf /usr/bin/python3.11 /usr/bin/python && python -m pip install --upgrade pip
52
+
53
+ # --- Python deps (pin order matters!) ---
54
+ # 1) JAX CUDA pins
55
+ RUN python -m pip install "jax[cuda12]==0.6.2" "jaxlib==0.6.2"
56
+
57
+ # 2) Lock seqio early to avoid backtracking madness
58
+ RUN python -m pip install "seqio==0.0.11"
59
+
60
+ # 3) Install Magenta RT *without* deps so we control pins
61
+ RUN python -m pip install --no-deps 'git+https://github.com/magenta/magenta-realtime#egg=magenta_rt[gpu]'
62
+
63
+ # 4) TF nightlies (MATCH DATES!)
64
+ RUN python -m pip install \
65
+ "tf_nightly==2.20.0.dev20250619" \
66
+ "tensorflow-text-nightly==2.20.0.dev20250316" \
67
+ "tf-hub-nightly"
68
+
69
+ # 5) tf2jax pinned alongside tf_nightly so pip doesn’t drag stable TF
70
+ RUN python -m pip install tf2jax "tf_nightly==2.20.0.dev20250619"
71
+
72
+ # 6) The rest of MRT deps + API runtime deps
73
+ RUN python -m pip install \
74
+ gin-config librosa resampy soundfile \
75
+ google-auth google-auth-oauthlib google-auth-httplib2 \
76
+ google-api-core googleapis-common-protos google-resumable-media \
77
+ google-cloud-storage requests tqdm typing-extensions numpy==2.1.3 \
78
+ fastapi uvicorn[standard] python-multipart pyloudnorm
79
+
80
+ # 7) Exact commits for T5X/Flaxformer as in pyproject
81
+ RUN python -m pip install \
82
+ "t5x @ git+https://github.com/google-research/t5x.git@92c5b46" \
83
+ "flaxformer @ git+https://github.com/google/flaxformer@399ea3a"
84
+
85
+ # ---- FINAL: enforce TF nightlies and clean any stable TF ----
86
+ RUN python - <<'PY'
87
+ import sys, sysconfig, glob, os, shutil
88
+ # Find a writable site dir (site-packages OR dist-packages)
89
+ cands = [sysconfig.get_paths().get('purelib'), sysconfig.get_paths().get('platlib')]
90
+ cands += [p for p in sys.path if p and p.endswith(('site-packages','dist-packages'))]
91
+ site = next(p for p in cands if p and os.path.isdir(p))
92
+
93
+ patterns = [
94
+ "tensorflow", "tensorflow-*.dist-info", "tensorflow-*.egg-info",
95
+ "tf-nightly-*.dist-info", "tf_nightly-*.dist-info",
96
+ "tensorflow_text", "tensorflow_text-*.dist-info",
97
+ "tf-hub-nightly-*.dist-info", "tf_hub_nightly-*.dist-info",
98
+ "tf_keras-nightly-*.dist-info", "tf_keras_nightly-*.dist-info",
99
+ "tensorboard*", "tb-nightly-*.dist-info",
100
+ "keras*", # remove stray keras
101
+ "tensorflow_hub*", "tensorflow_io*",
102
+ ]
103
+ for pat in patterns:
104
+ for path in glob.glob(os.path.join(site, pat)):
105
+ if os.path.isdir(path): shutil.rmtree(path, ignore_errors=True)
106
+ else:
107
+ try: os.remove(path)
108
+ except FileNotFoundError: pass
109
+
110
+ print("TF/Hub/Text cleared in:", site)
111
+ PY
112
+
113
+ # Reinstall pinned nightlies in ONE transaction
114
+ RUN python -m pip install --no-cache-dir --force-reinstall \
115
+ "tf-nightly==2.20.0.dev20250619" \
116
+ "tensorflow-text-nightly==2.20.0.dev20250316" \
117
+ "tf-hub-nightly"
118
+
119
+ RUN python -m pip install huggingface_hub
120
+
121
+ RUN python -m pip install --no-cache-dir --force-reinstall "protobuf==4.25.3"
122
+
123
+ # Switch to Spaces’ preferred user
124
+ # Switch to Spaces’ preferred user
125
+ RUN useradd -m -u 1000 appuser
126
+ WORKDIR /home/appuser/app
127
+
128
+ # Copy from *build context* into image, owned by appuser
129
+ COPY --chown=appuser:appuser app.py /home/appuser/app/app.py
130
+
131
+ # NEW: shared utils + worker
132
+ COPY --chown=appuser:appuser utils.py /home/appuser/app/utils.py
133
+ COPY --chown=appuser:appuser jam_worker.py /home/appuser/app/jam_worker.py
134
+
135
+ USER appuser
136
+
137
+ EXPOSE 7860
138
+ CMD ["bash", "-lc", "python -m uvicorn app:app --host 0.0.0.0 --port ${PORT:-7860}"]
app.py ADDED
@@ -0,0 +1,436 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from magenta_rt import system, audio as au
2
+ import numpy as np
3
+ from fastapi import FastAPI, UploadFile, File, Form, Body, HTTPException, Response
4
+ import tempfile, io, base64, math, threading
5
+ from fastapi.middleware.cors import CORSMiddleware
6
+ from contextlib import contextmanager
7
+ import soundfile as sf
8
+ import numpy as np
9
+ from math import gcd
10
+ from scipy.signal import resample_poly
11
+ from utils import (
12
+ match_loudness_to_reference, stitch_generated, hard_trim_seconds,
13
+ apply_micro_fades, make_bar_aligned_context, take_bar_aligned_tail,
14
+ resample_and_snap, wav_bytes_base64
15
+ )
16
+
17
+ from jam_worker import JamWorker, JamParams, JamChunk
18
+ import uuid, threading
19
+
20
+ jam_registry: dict[str, JamWorker] = {}
21
+ jam_lock = threading.Lock()
22
+
23
+ @contextmanager
24
+ def mrt_overrides(mrt, **kwargs):
25
+ """Temporarily set attributes on MRT if they exist; restore after."""
26
+ old = {}
27
+ try:
28
+ for k, v in kwargs.items():
29
+ if hasattr(mrt, k):
30
+ old[k] = getattr(mrt, k)
31
+ setattr(mrt, k, v)
32
+ yield
33
+ finally:
34
+ for k, v in old.items():
35
+ setattr(mrt, k, v)
36
+
37
+ # loudness utils
38
+ try:
39
+ import pyloudnorm as pyln
40
+ _HAS_LOUDNORM = True
41
+ except Exception:
42
+ _HAS_LOUDNORM = False
43
+
44
+ # ----------------------------
45
+ # Main generation (single combined style vector)
46
+ # ----------------------------
47
+ def generate_loop_continuation_with_mrt(
48
+ mrt,
49
+ input_wav_path: str,
50
+ bpm: float,
51
+ extra_styles=None,
52
+ style_weights=None,
53
+ bars: int = 8,
54
+ beats_per_bar: int = 4,
55
+ loop_weight: float = 1.0,
56
+ loudness_mode: str = "auto",
57
+ loudness_headroom_db: float = 1.0,
58
+ intro_bars_to_drop: int = 0, # <— NEW
59
+ ):
60
+ # Load & prep (unchanged)
61
+ loop = au.Waveform.from_file(input_wav_path).resample(mrt.sample_rate).as_stereo()
62
+
63
+ # Use tail for context (your recent change)
64
+ codec_fps = float(mrt.codec.frame_rate)
65
+ ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
66
+ loop_for_context = take_bar_aligned_tail(loop, bpm, beats_per_bar, ctx_seconds)
67
+
68
+ tokens_full = mrt.codec.encode(loop_for_context).astype(np.int32)
69
+ tokens = tokens_full[:, :mrt.config.decoder_codec_rvq_depth]
70
+
71
+ # Bar-aligned token window (unchanged)
72
+ context_tokens = make_bar_aligned_context(
73
+ tokens, bpm=bpm, fps=int(mrt.codec.frame_rate),
74
+ ctx_frames=mrt.config.context_length_frames, beats_per_bar=beats_per_bar
75
+ )
76
+ state = mrt.init_state()
77
+ state.context_tokens = context_tokens
78
+
79
+ # STYLE embed (optional: switch to loop_for_context if you want stronger “recent” bias)
80
+ loop_embed = mrt.embed_style(loop_for_context)
81
+ embeds, weights = [loop_embed], [float(loop_weight)]
82
+ if extra_styles:
83
+ for i, s in enumerate(extra_styles):
84
+ if s.strip():
85
+ embeds.append(mrt.embed_style(s.strip()))
86
+ w = style_weights[i] if (style_weights and i < len(style_weights)) else 1.0
87
+ weights.append(float(w))
88
+ wsum = float(sum(weights)) or 1.0
89
+ weights = [w / wsum for w in weights]
90
+ combined_style = np.sum([w * e for w, e in zip(weights, embeds)], axis=0).astype(loop_embed.dtype)
91
+
92
+ # --- Length math ---
93
+ seconds_per_bar = beats_per_bar * (60.0 / bpm)
94
+ total_secs = bars * seconds_per_bar
95
+ drop_bars = max(0, int(intro_bars_to_drop))
96
+ drop_secs = min(drop_bars, bars) * seconds_per_bar # clamp to <= bars
97
+ gen_total_secs = total_secs + drop_secs # generate extra
98
+
99
+ # Chunk scheduling to cover gen_total_secs
100
+ chunk_secs = mrt.config.chunk_length_frames * mrt.config.frame_length_samples / mrt.sample_rate # ~2.0
101
+ steps = int(math.ceil(gen_total_secs / chunk_secs)) + 1 # pad then trim
102
+
103
+ # Generate
104
+ chunks = []
105
+ for _ in range(steps):
106
+ wav, state = mrt.generate_chunk(state=state, style=combined_style)
107
+ chunks.append(wav)
108
+
109
+ # Stitch continuous audio
110
+ stitched = stitch_generated(chunks, mrt.sample_rate, mrt.config.crossfade_length).as_stereo()
111
+
112
+ # Trim to generated length (bars + dropped bars)
113
+ stitched = hard_trim_seconds(stitched, gen_total_secs)
114
+
115
+ # 👉 Drop the intro bars
116
+ if drop_secs > 0:
117
+ n_drop = int(round(drop_secs * stitched.sample_rate))
118
+ stitched = au.Waveform(stitched.samples[n_drop:], stitched.sample_rate)
119
+
120
+ # Final exact-length trim to requested bars
121
+ out = hard_trim_seconds(stitched, total_secs)
122
+
123
+ # Final polish AFTER drop
124
+ out = out.peak_normalize(0.95)
125
+ apply_micro_fades(out, 5)
126
+
127
+ # Loudness match to input (after drop) so bar 1 sits right
128
+ out, loud_stats = match_loudness_to_reference(
129
+ ref=loop, target=out,
130
+ method=loudness_mode, headroom_db=loudness_headroom_db
131
+ )
132
+
133
+ return out, loud_stats
134
+
135
+
136
+
137
+ # ----------------------------
138
+ # FastAPI app with lazy, thread-safe model init
139
+ # ----------------------------
140
+ app = FastAPI()
141
+
142
+ app.add_middleware(
143
+ CORSMiddleware,
144
+ allow_origins=["*"], # or lock to your domain(s)
145
+ allow_credentials=True,
146
+ allow_methods=["*"],
147
+ allow_headers=["*"],
148
+ )
149
+
150
+ _MRT = None
151
+ _MRT_LOCK = threading.Lock()
152
+
153
+ def get_mrt():
154
+ global _MRT
155
+ if _MRT is None:
156
+ with _MRT_LOCK:
157
+ if _MRT is None:
158
+ _MRT = system.MagentaRT(tag="base", guidance_weight=1.0, device="gpu", lazy=False)
159
+ return _MRT
160
+
161
+ @app.post("/generate")
162
+ def generate(
163
+ loop_audio: UploadFile = File(...),
164
+ bpm: float = Form(...),
165
+ bars: int = Form(8),
166
+ beats_per_bar: int = Form(4),
167
+ styles: str = Form("acid house"),
168
+ style_weights: str = Form(""),
169
+ loop_weight: float = Form(1.0),
170
+ loudness_mode: str = Form("auto"),
171
+ loudness_headroom_db: float = Form(1.0),
172
+ guidance_weight: float = Form(5.0),
173
+ temperature: float = Form(1.1),
174
+ topk: int = Form(40),
175
+ target_sample_rate: int | None = Form(None),
176
+ intro_bars_to_drop: int = Form(0), # <— NEW
177
+ ):
178
+ # Read file
179
+ data = loop_audio.file.read()
180
+ if not data:
181
+ return {"error": "Empty file"}
182
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
183
+ tmp.write(data)
184
+ tmp_path = tmp.name
185
+
186
+ # Parse styles + weights
187
+ extra_styles = [s for s in (styles.split(",") if styles else []) if s.strip()]
188
+ weights = [float(x) for x in style_weights.split(",")] if style_weights else None
189
+
190
+ mrt = get_mrt() # warm once, in this worker thread
191
+ # Temporarily override MRT inference knobs for this request
192
+ with mrt_overrides(mrt,
193
+ guidance_weight=guidance_weight,
194
+ temperature=temperature,
195
+ topk=topk):
196
+ wav, loud_stats = generate_loop_continuation_with_mrt(
197
+ mrt,
198
+ input_wav_path=tmp_path,
199
+ bpm=bpm,
200
+ extra_styles=extra_styles,
201
+ style_weights=weights,
202
+ bars=bars,
203
+ beats_per_bar=beats_per_bar,
204
+ loop_weight=loop_weight,
205
+ loudness_mode=loudness_mode,
206
+ loudness_headroom_db=loudness_headroom_db,
207
+ intro_bars_to_drop=intro_bars_to_drop, # <— pass through
208
+ )
209
+
210
+ # 1) Figure out the desired SR
211
+ inp_info = sf.info(tmp_path)
212
+ input_sr = int(inp_info.samplerate)
213
+ target_sr = int(target_sample_rate or input_sr)
214
+
215
+ # 2) Convert to target SR + snap to exact bars
216
+ cur_sr = int(mrt.sample_rate)
217
+ x = wav.samples if wav.samples.ndim == 2 else wav.samples[:, None]
218
+ seconds_per_bar = (60.0 / float(bpm)) * int(beats_per_bar)
219
+ expected_secs = float(bars) * seconds_per_bar
220
+ x = resample_and_snap(x, cur_sr=cur_sr, target_sr=target_sr, seconds=expected_secs)
221
+
222
+ # 3) Encode WAV once (no extra write)
223
+ audio_b64, total_samples, channels = wav_bytes_base64(x, target_sr)
224
+ loop_duration_seconds = total_samples / float(target_sr)
225
+
226
+ # 4) Metadata
227
+ metadata = {
228
+ "bpm": int(round(bpm)),
229
+ "bars": int(bars),
230
+ "beats_per_bar": int(beats_per_bar),
231
+ "styles": extra_styles,
232
+ "style_weights": weights,
233
+ "loop_weight": loop_weight,
234
+ "loudness": loud_stats,
235
+ "sample_rate": int(target_sr),
236
+ "channels": int(channels),
237
+ "crossfade_seconds": mrt.config.crossfade_length,
238
+ "total_samples": int(total_samples),
239
+ "seconds_per_bar": seconds_per_bar,
240
+ "loop_duration_seconds": loop_duration_seconds,
241
+ "guidance_weight": guidance_weight,
242
+ "temperature": temperature,
243
+ "topk": topk,
244
+ }
245
+ return {"audio_base64": audio_b64, "metadata": metadata}
246
+
247
+ # ----------------------------
248
+ # the 'keep jamming' button
249
+ # ----------------------------
250
+
251
+ @app.post("/jam/start")
252
+ def jam_start(
253
+ loop_audio: UploadFile = File(...),
254
+ bpm: float = Form(...),
255
+ bars_per_chunk: int = Form(4),
256
+ beats_per_bar: int = Form(4),
257
+ styles: str = Form(""),
258
+ style_weights: str = Form(""),
259
+ loop_weight: float = Form(1.0),
260
+ loudness_mode: str = Form("auto"),
261
+ loudness_headroom_db: float = Form(1.0),
262
+ guidance_weight: float = Form(1.1),
263
+ temperature: float = Form(1.1),
264
+ topk: int = Form(40),
265
+ target_sample_rate: int | None = Form(None),
266
+ ):
267
+ # enforce single active jam per GPU
268
+ with jam_lock:
269
+ for sid, w in list(jam_registry.items()):
270
+ if w.is_alive():
271
+ raise HTTPException(status_code=429, detail="A jam is already running. Try again later.")
272
+
273
+ # read input + prep context/style (reuse your existing code)
274
+ data = loop_audio.file.read()
275
+ if not data: raise HTTPException(status_code=400, detail="Empty file")
276
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
277
+ tmp.write(data); tmp_path = tmp.name
278
+
279
+ mrt = get_mrt()
280
+ loop = au.Waveform.from_file(tmp_path).resample(mrt.sample_rate).as_stereo()
281
+
282
+ # build tail context + style vec (tail-biased)
283
+ codec_fps = float(mrt.codec.frame_rate)
284
+ ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
285
+ loop_tail = take_bar_aligned_tail(loop, bpm, beats_per_bar, ctx_seconds)
286
+
287
+ # style vec = normalized mix of loop_tail + extra styles
288
+ embeds, weights = [mrt.embed_style(loop_tail)], [float(loop_weight)]
289
+ extra = [s for s in (styles.split(",") if styles else []) if s.strip()]
290
+ sw = [float(x) for x in style_weights.split(",")] if style_weights else []
291
+ for i, s in enumerate(extra):
292
+ embeds.append(mrt.embed_style(s.strip()))
293
+ weights.append(sw[i] if i < len(sw) else 1.0)
294
+ wsum = sum(weights) or 1.0
295
+ weights = [w / wsum for w in weights]
296
+ style_vec = np.sum([w * e for w, e in zip(weights, embeds)], axis=0).astype(embeds[0].dtype)
297
+
298
+ # target SR (default input SR)
299
+ inp_info = sf.info(tmp_path)
300
+ input_sr = int(inp_info.samplerate)
301
+ target_sr = int(target_sample_rate or input_sr)
302
+
303
+ params = JamParams(
304
+ bpm=bpm,
305
+ beats_per_bar=beats_per_bar,
306
+ bars_per_chunk=bars_per_chunk,
307
+ target_sr=target_sr,
308
+ loudness_mode=loudness_mode,
309
+ headroom_db=loudness_headroom_db,
310
+ style_vec=style_vec,
311
+ ref_loop=loop_tail, # For loudness matching
312
+ combined_loop=loop, # NEW: Full loop for context setup
313
+ guidance_weight=guidance_weight,
314
+ temperature=temperature,
315
+ topk=topk
316
+ )
317
+
318
+ worker = JamWorker(mrt, params)
319
+ sid = str(uuid.uuid4())
320
+ with jam_lock:
321
+ jam_registry[sid] = worker
322
+ worker.start()
323
+
324
+ return {"session_id": sid}
325
+
326
+ @app.get("/jam/next")
327
+ def jam_next(session_id: str):
328
+ """
329
+ Get the next sequential chunk in the jam session.
330
+ This ensures chunks are delivered in order without gaps.
331
+ """
332
+ with jam_lock:
333
+ worker = jam_registry.get(session_id)
334
+ if worker is None or not worker.is_alive():
335
+ raise HTTPException(status_code=404, detail="Session not found")
336
+
337
+ # Get the next sequential chunk (this blocks until ready)
338
+ chunk = worker.get_next_chunk()
339
+
340
+ if chunk is None:
341
+ raise HTTPException(status_code=408, detail="Chunk not ready within timeout")
342
+
343
+ return {
344
+ "chunk": {
345
+ "index": chunk.index,
346
+ "audio_base64": chunk.audio_base64,
347
+ "metadata": chunk.metadata
348
+ }
349
+ }
350
+
351
+ @app.post("/jam/consume")
352
+ def jam_consume(session_id: str = Form(...), chunk_index: int = Form(...)):
353
+ """
354
+ Mark a chunk as consumed by the frontend.
355
+ This helps the worker manage its buffer and generation flow.
356
+ """
357
+ with jam_lock:
358
+ worker = jam_registry.get(session_id)
359
+ if worker is None or not worker.is_alive():
360
+ raise HTTPException(status_code=404, detail="Session not found")
361
+
362
+ worker.mark_chunk_consumed(chunk_index)
363
+
364
+ return {"consumed": chunk_index}
365
+
366
+
367
+
368
+ @app.post("/jam/stop")
369
+ def jam_stop(session_id: str = Body(..., embed=True)):
370
+ with jam_lock:
371
+ worker = jam_registry.get(session_id)
372
+ if worker is None:
373
+ raise HTTPException(status_code=404, detail="Session not found")
374
+
375
+ worker.stop()
376
+ worker.join(timeout=5.0)
377
+ if worker.is_alive():
378
+ # It’s daemon=True, so it won’t block process exit, but report it
379
+ print(f"⚠️ JamWorker {session_id} did not stop within timeout")
380
+
381
+ with jam_lock:
382
+ jam_registry.pop(session_id, None)
383
+ return {"stopped": True}
384
+
385
+ @app.post("/jam/update")
386
+ def jam_update(session_id: str = Form(...),
387
+ guidance_weight: float | None = Form(None),
388
+ temperature: float | None = Form(None),
389
+ topk: int | None = Form(None)):
390
+ with jam_lock:
391
+ worker = jam_registry.get(session_id)
392
+ if worker is None or not worker.is_alive():
393
+ raise HTTPException(status_code=404, detail="Session not found")
394
+ worker.update_knobs(guidance_weight=guidance_weight, temperature=temperature, topk=topk)
395
+ return {"ok": True}
396
+
397
+ @app.get("/jam/status")
398
+ def jam_status(session_id: str):
399
+ with jam_lock:
400
+ worker = jam_registry.get(session_id)
401
+
402
+ if worker is None:
403
+ raise HTTPException(status_code=404, detail="Session not found")
404
+
405
+ running = worker.is_alive()
406
+
407
+ # Snapshot safely
408
+ with worker._lock:
409
+ last_generated = int(worker.idx)
410
+ last_delivered = int(worker._last_delivered_index)
411
+ queued = len(worker.outbox)
412
+ buffer_ahead = last_generated - last_delivered
413
+ p = worker.params
414
+ spb = p.beats_per_bar * (60.0 / p.bpm)
415
+ chunk_secs = p.bars_per_chunk * spb
416
+
417
+ return {
418
+ "running": running,
419
+ "last_generated_index": last_generated, # Last chunk that finished generating
420
+ "last_delivered_index": last_delivered, # Last chunk sent to frontend
421
+ "buffer_ahead": buffer_ahead, # How many chunks ahead we are
422
+ "queued_chunks": queued, # Total chunks in outbox
423
+ "bpm": p.bpm,
424
+ "beats_per_bar": p.beats_per_bar,
425
+ "bars_per_chunk": p.bars_per_chunk,
426
+ "seconds_per_bar": spb,
427
+ "chunk_duration_seconds": chunk_secs,
428
+ "target_sample_rate": p.target_sr,
429
+ "last_chunk_started_at": worker.last_chunk_started_at,
430
+ "last_chunk_completed_at": worker.last_chunk_completed_at,
431
+ }
432
+
433
+
434
+ @app.get("/health")
435
+ def health():
436
+ return {"ok": True}
jam_worker.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # jam_worker.py - SIMPLE FIX VERSION
2
+ import threading, time, base64, io, uuid
3
+ from dataclasses import dataclass, field
4
+ import numpy as np
5
+ import soundfile as sf
6
+
7
+ from utils import (
8
+ match_loudness_to_reference, stitch_generated, hard_trim_seconds,
9
+ apply_micro_fades, make_bar_aligned_context, take_bar_aligned_tail,
10
+ resample_and_snap, wav_bytes_base64
11
+ )
12
+
13
+ @dataclass
14
+ class JamParams:
15
+ bpm: float
16
+ beats_per_bar: int
17
+ bars_per_chunk: int
18
+ target_sr: int
19
+ loudness_mode: str = "auto"
20
+ headroom_db: float = 1.0
21
+ style_vec: np.ndarray | None = None
22
+ ref_loop: any = None
23
+ combined_loop: any = None
24
+ guidance_weight: float = 1.1
25
+ temperature: float = 1.1
26
+ topk: int = 40
27
+
28
+ @dataclass
29
+ class JamChunk:
30
+ index: int
31
+ audio_base64: str
32
+ metadata: dict
33
+
34
+ class JamWorker(threading.Thread):
35
+ def __init__(self, mrt, params: JamParams):
36
+ super().__init__(daemon=True)
37
+ self.mrt = mrt
38
+ self.params = params
39
+ self.state = mrt.init_state()
40
+
41
+ if params.combined_loop is not None:
42
+ self._setup_context_from_combined_loop()
43
+
44
+ self.idx = 0
45
+ self.outbox: list[JamChunk] = []
46
+ self._stop_event = threading.Event()
47
+
48
+ # NEW: Track delivery state
49
+ self._last_delivered_index = 0
50
+ self._max_buffer_ahead = 5 # Don't generate more than 3 chunks ahead
51
+
52
+ # Timing info
53
+ self.last_chunk_started_at = None
54
+ self.last_chunk_completed_at = None
55
+ self._lock = threading.Lock()
56
+
57
+ def _setup_context_from_combined_loop(self):
58
+ """Set up MRT context tokens from the combined loop audio"""
59
+ try:
60
+ from utils import make_bar_aligned_context, take_bar_aligned_tail
61
+
62
+ codec_fps = float(self.mrt.codec.frame_rate)
63
+ ctx_seconds = float(self.mrt.config.context_length_frames) / codec_fps
64
+
65
+ loop_for_context = take_bar_aligned_tail(
66
+ self.params.combined_loop,
67
+ self.params.bpm,
68
+ self.params.beats_per_bar,
69
+ ctx_seconds
70
+ )
71
+
72
+ tokens_full = self.mrt.codec.encode(loop_for_context).astype(np.int32)
73
+ tokens = tokens_full[:, :self.mrt.config.decoder_codec_rvq_depth]
74
+
75
+ context_tokens = make_bar_aligned_context(
76
+ tokens,
77
+ bpm=self.params.bpm,
78
+ fps=int(self.mrt.codec.frame_rate),
79
+ ctx_frames=self.mrt.config.context_length_frames,
80
+ beats_per_bar=self.params.beats_per_bar
81
+ )
82
+
83
+ self.state.context_tokens = context_tokens
84
+ print(f"✅ JamWorker: Set up fresh context from combined loop")
85
+
86
+ except Exception as e:
87
+ print(f"❌ Failed to setup context from combined loop: {e}")
88
+
89
+ def stop(self):
90
+ self._stop_event.set()
91
+
92
+ def update_knobs(self, *, guidance_weight=None, temperature=None, topk=None):
93
+ with self._lock:
94
+ if guidance_weight is not None: self.params.guidance_weight = float(guidance_weight)
95
+ if temperature is not None: self.params.temperature = float(temperature)
96
+ if topk is not None: self.params.topk = int(topk)
97
+
98
+ def get_next_chunk(self) -> JamChunk | None:
99
+ """Get the next sequential chunk (blocks/waits if not ready)"""
100
+ target_index = self._last_delivered_index + 1
101
+
102
+ # Wait for the target chunk to be ready (with timeout)
103
+ max_wait = 30.0 # seconds
104
+ start_time = time.time()
105
+
106
+ while time.time() - start_time < max_wait and not self._stop_event.is_set():
107
+ with self._lock:
108
+ # Look for the exact chunk we need
109
+ for chunk in self.outbox:
110
+ if chunk.index == target_index:
111
+ self._last_delivered_index = target_index
112
+ print(f"📦 Delivered chunk {target_index}")
113
+ return chunk
114
+
115
+ # Not ready yet, wait a bit
116
+ time.sleep(0.1)
117
+
118
+ # Timeout or stopped
119
+ return None
120
+
121
+ def mark_chunk_consumed(self, chunk_index: int):
122
+ """Mark a chunk as consumed by the frontend"""
123
+ with self._lock:
124
+ self._last_delivered_index = max(self._last_delivered_index, chunk_index)
125
+ print(f"✅ Chunk {chunk_index} consumed")
126
+
127
+ def _should_generate_next_chunk(self) -> bool:
128
+ """Check if we should generate the next chunk (don't get too far ahead)"""
129
+ with self._lock:
130
+ # Don't generate if we're already too far ahead
131
+ if self.idx > self._last_delivered_index + self._max_buffer_ahead:
132
+ return False
133
+ return True
134
+
135
+ def _seconds_per_bar(self) -> float:
136
+ return self.params.beats_per_bar * (60.0 / self.params.bpm)
137
+
138
+ def _snap_and_encode(self, y, seconds, target_sr, bars):
139
+ cur_sr = int(self.mrt.sample_rate)
140
+ x = y.samples if y.samples.ndim == 2 else y.samples[:, None]
141
+ x = resample_and_snap(x, cur_sr=cur_sr, target_sr=target_sr, seconds=seconds)
142
+ b64, total_samples, channels = wav_bytes_base64(x, target_sr)
143
+ meta = {
144
+ "bpm": int(round(self.params.bpm)),
145
+ "bars": int(bars),
146
+ "beats_per_bar": int(self.params.beats_per_bar),
147
+ "sample_rate": int(target_sr),
148
+ "channels": channels,
149
+ "total_samples": total_samples,
150
+ "seconds_per_bar": self._seconds_per_bar(),
151
+ "loop_duration_seconds": bars * self._seconds_per_bar(),
152
+ "guidance_weight": self.params.guidance_weight,
153
+ "temperature": self.params.temperature,
154
+ "topk": self.params.topk,
155
+ }
156
+ return b64, meta
157
+
158
+ def run(self):
159
+ """Main worker loop - generate chunks continuously but don't get too far ahead"""
160
+ spb = self._seconds_per_bar()
161
+ chunk_secs = self.params.bars_per_chunk * spb
162
+ xfade = self.mrt.config.crossfade_length
163
+
164
+ print("🚀 JamWorker started with flow control...")
165
+
166
+ while not self._stop_event.is_set():
167
+ # Check if we should generate the next chunk
168
+ if not self._should_generate_next_chunk():
169
+ # We're ahead enough, wait a bit for frontend to catch up
170
+ print(f"⏸️ Buffer full, waiting for consumption...")
171
+ time.sleep(0.5)
172
+ continue
173
+
174
+ # Generate the next chunk
175
+ with self._lock:
176
+ style_vec = self.params.style_vec
177
+ self.mrt.guidance_weight = self.params.guidance_weight
178
+ self.mrt.temperature = self.params.temperature
179
+ self.mrt.topk = self.params.topk
180
+ next_idx = self.idx + 1
181
+
182
+ print(f"🎹 Generating chunk {next_idx}...")
183
+
184
+ # Generate enough model chunks to cover chunk_secs
185
+ need = chunk_secs
186
+ chunks = []
187
+ self.last_chunk_started_at = time.time()
188
+
189
+ while need > 0 and not self._stop_event.is_set():
190
+ wav, self.state = self.mrt.generate_chunk(state=self.state, style=style_vec)
191
+ chunks.append(wav)
192
+ need -= (wav.samples.shape[0] / float(self.mrt.sample_rate))
193
+
194
+ if self._stop_event.is_set():
195
+ break
196
+
197
+ # Stitch and trim to exact seconds at model SR
198
+ y = stitch_generated(chunks, self.mrt.sample_rate, xfade).as_stereo()
199
+ y = hard_trim_seconds(y, chunk_secs)
200
+
201
+ # Post-process
202
+ if next_idx == 1 and self.params.ref_loop is not None:
203
+ y, _ = match_loudness_to_reference(
204
+ self.params.ref_loop, y,
205
+ method=self.params.loudness_mode,
206
+ headroom_db=self.params.headroom_db
207
+ )
208
+ else:
209
+ apply_micro_fades(y, 3)
210
+
211
+ # Resample + snap + b64
212
+ b64, meta = self._snap_and_encode(
213
+ y, seconds=chunk_secs,
214
+ target_sr=self.params.target_sr,
215
+ bars=self.params.bars_per_chunk
216
+ )
217
+
218
+ # Store the completed chunk
219
+ with self._lock:
220
+ self.idx = next_idx
221
+ self.outbox.append(JamChunk(index=next_idx, audio_base64=b64, metadata=meta))
222
+
223
+ # Keep outbox bounded (remove old chunks)
224
+ if len(self.outbox) > 10:
225
+ # Remove chunks that are way behind the delivery point
226
+ self.outbox = [ch for ch in self.outbox if ch.index > self._last_delivered_index - 5]
227
+
228
+ self.last_chunk_completed_at = time.time()
229
+ print(f"✅ Completed chunk {next_idx}")
230
+
231
+ print("🛑 JamWorker stopped")
utils.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # utils.py
2
+ from __future__ import annotations
3
+ import io, base64, math
4
+ from math import gcd
5
+ import numpy as np
6
+ import soundfile as sf
7
+ from scipy.signal import resample_poly
8
+
9
+ # Magenta RT audio types
10
+ from magenta_rt import audio as au
11
+
12
+ # Optional loudness
13
+ try:
14
+ import pyloudnorm as pyln
15
+ _HAS_LOUDNORM = True
16
+ except Exception:
17
+ _HAS_LOUDNORM = False
18
+
19
+
20
+ # ---------- Loudness ----------
21
+ def _measure_lufs(wav: au.Waveform) -> float:
22
+ meter = pyln.Meter(wav.sample_rate) # BS.1770-4
23
+ return float(meter.integrated_loudness(wav.samples))
24
+
25
+ def _rms(x: np.ndarray) -> float:
26
+ if x.size == 0: return 0.0
27
+ return float(np.sqrt(np.mean(x**2)))
28
+
29
+ def match_loudness_to_reference(
30
+ ref: au.Waveform,
31
+ target: au.Waveform,
32
+ method: str = "auto", # "auto"|"lufs"|"rms"|"none"
33
+ headroom_db: float = 1.0
34
+ ) -> tuple[au.Waveform, dict]:
35
+ stats = {"method": method, "applied_gain_db": 0.0}
36
+ if method == "none":
37
+ return target, stats
38
+
39
+ if method == "auto":
40
+ method = "lufs" if _HAS_LOUDNORM else "rms"
41
+
42
+ if method == "lufs" and _HAS_LOUDNORM:
43
+ L_ref = _measure_lufs(ref)
44
+ L_tgt = _measure_lufs(target)
45
+ delta_db = L_ref - L_tgt
46
+ gain = 10.0 ** (delta_db / 20.0)
47
+ y = target.samples.astype(np.float32) * gain
48
+ stats.update({"ref_lufs": L_ref, "tgt_lufs_before": L_tgt, "applied_gain_db": delta_db})
49
+ else:
50
+ ra = _rms(ref.samples)
51
+ rb = _rms(target.samples)
52
+ if rb <= 1e-12:
53
+ return target, stats
54
+ gain = ra / rb
55
+ y = target.samples.astype(np.float32) * gain
56
+ stats.update({"ref_rms": ra, "tgt_rms_before": rb, "applied_gain_db": 20*np.log10(max(gain,1e-12))})
57
+
58
+ # simple peak “limiter” to keep headroom
59
+ limit = 10 ** (-headroom_db / 20.0) # e.g., -1 dBFS
60
+ peak = float(np.max(np.abs(y))) if y.size else 0.0
61
+ if peak > limit:
62
+ y *= (limit / peak)
63
+ stats["post_peak_limited"] = True
64
+ else:
65
+ stats["post_peak_limited"] = False
66
+
67
+ target.samples = y.astype(np.float32)
68
+ return target, stats
69
+
70
+
71
+ # ---------- Stitch / fades / trims ----------
72
+ def stitch_generated(chunks, sr: int, xfade_s: float) -> au.Waveform:
73
+ if not chunks:
74
+ raise ValueError("no chunks")
75
+ xfade_n = int(round(xfade_s * sr))
76
+ if xfade_n <= 0:
77
+ return au.Waveform(np.concatenate([c.samples for c in chunks], axis=0), sr)
78
+
79
+ t = np.linspace(0, np.pi/2, xfade_n, endpoint=False, dtype=np.float32)
80
+ eq_in, eq_out = np.sin(t)[:, None], np.cos(t)[:, None]
81
+
82
+ first = chunks[0].samples
83
+ if first.shape[0] < xfade_n:
84
+ raise ValueError("chunk shorter than crossfade prefix")
85
+ out = first[xfade_n:].copy() # drop model pre-roll
86
+
87
+ for i in range(1, len(chunks)):
88
+ cur = chunks[i].samples
89
+ if cur.shape[0] < xfade_n:
90
+ continue
91
+ head, tail = cur[:xfade_n], cur[xfade_n:]
92
+ mixed = out[-xfade_n:] * eq_out + head * eq_in
93
+ out = np.concatenate([out[:-xfade_n], mixed, tail], axis=0)
94
+
95
+ return au.Waveform(out, sr)
96
+
97
+ def hard_trim_seconds(wav: au.Waveform, seconds: float) -> au.Waveform:
98
+ n = int(round(seconds * wav.sample_rate))
99
+ return au.Waveform(wav.samples[:n], wav.sample_rate)
100
+
101
+ def apply_micro_fades(wav: au.Waveform, ms: int = 5) -> None:
102
+ n = int(wav.sample_rate * ms / 1000.0)
103
+ if n > 0 and wav.samples.shape[0] > 2*n:
104
+ env = np.linspace(0.0, 1.0, n, dtype=np.float32)[:, None]
105
+ wav.samples[:n] *= env
106
+ wav.samples[-n:] *= env[::-1]
107
+
108
+
109
+ # ---------- Token context helpers ----------
110
+ def make_bar_aligned_context(tokens, bpm, fps=25, ctx_frames=250, beats_per_bar=4):
111
+ frames_per_bar_f = (beats_per_bar * 60.0 / bpm) * fps
112
+ frames_per_bar = int(round(frames_per_bar_f))
113
+ if abs(frames_per_bar - frames_per_bar_f) > 1e-3:
114
+ reps = int(np.ceil(ctx_frames / len(tokens)))
115
+ return np.tile(tokens, (reps, 1))[-ctx_frames:]
116
+ reps = int(np.ceil(ctx_frames / len(tokens)))
117
+ tiled = np.tile(tokens, (reps, 1))
118
+ end = (len(tiled) // frames_per_bar) * frames_per_bar
119
+ if end < ctx_frames:
120
+ return tiled[-ctx_frames:]
121
+ start = end - ctx_frames
122
+ return tiled[start:end]
123
+
124
+ def take_bar_aligned_tail(wav: au.Waveform, bpm: float, beats_per_bar: int, ctx_seconds: float, max_bars=None) -> au.Waveform:
125
+ spb = (60.0 / bpm) * beats_per_bar
126
+ bars_needed = max(1, int(round(ctx_seconds / spb)))
127
+ if max_bars is not None:
128
+ bars_needed = min(bars_needed, max_bars)
129
+ tail_seconds = bars_needed * spb
130
+ n = int(round(tail_seconds * wav.sample_rate))
131
+ if n >= wav.samples.shape[0]:
132
+ return wav
133
+ return au.Waveform(wav.samples[-n:], wav.sample_rate)
134
+
135
+
136
+ # ---------- SR normalize + snap ----------
137
+ def resample_and_snap(x: np.ndarray, cur_sr: int, target_sr: int, seconds: float) -> np.ndarray:
138
+ """
139
+ x: np.ndarray shape (S, C), float32
140
+ Returns: exact-length array (round(seconds*target_sr), C)
141
+ """
142
+ if x.ndim == 1:
143
+ x = x[:, None]
144
+ if cur_sr != target_sr:
145
+ g = gcd(cur_sr, target_sr)
146
+ up, down = target_sr // g, cur_sr // g
147
+ x = resample_poly(x, up, down, axis=0)
148
+
149
+ expected_len = int(round(seconds * target_sr))
150
+ if x.shape[0] < expected_len:
151
+ pad = np.zeros((expected_len - x.shape[0], x.shape[1]), dtype=x.dtype)
152
+ x = np.vstack([x, pad])
153
+ elif x.shape[0] > expected_len:
154
+ x = x[:expected_len, :]
155
+ return x.astype(np.float32, copy=False)
156
+
157
+
158
+ # ---------- WAV encode ----------
159
+ def wav_bytes_base64(x: np.ndarray, sr: int) -> tuple[str, int, int]:
160
+ """
161
+ x: np.ndarray shape (S, C)
162
+ returns: (base64_wav, total_samples, channels)
163
+ """
164
+ buf = io.BytesIO()
165
+ sf.write(buf, x, sr, subtype="FLOAT", format="WAV")
166
+ buf.seek(0)
167
+ b64 = base64.b64encode(buf.read()).decode("utf-8")
168
+ return b64, int(x.shape[0]), int(x.shape[1])