Commit
·
30fdbbc
1
Parent(s):
384e4ac
updating docs a bit
Browse files- app.py +30 -179
- documentation.html +302 -43
app.py
CHANGED
|
@@ -77,45 +77,22 @@ from pydantic import BaseModel
|
|
| 77 |
from model_management import CheckpointManager, AssetManager, ModelSelector, ModelSelect
|
| 78 |
|
| 79 |
# ---- Finetune assets (mean & centroids) --------------------------------------
|
| 80 |
-
_FINETUNE_REPO_DEFAULT = os.getenv("MRT_ASSETS_REPO", "thepatch/magenta-ft")
|
| 81 |
_ASSETS_REPO_ID: str | None = None
|
| 82 |
_MEAN_EMBED: np.ndarray | None = None # shape (D,) dtype float32
|
| 83 |
_CENTROIDS: np.ndarray | None = None # shape (K, D) dtype float32
|
| 84 |
|
| 85 |
-
_STEP_RE = re.compile(r"(?:^|/)checkpoint_(\d+)(?:/|\.tar\.gz|\.tgz)?$")
|
| 86 |
|
| 87 |
# Create instances (these don't modify globals)
|
| 88 |
asset_manager = AssetManager()
|
| 89 |
model_selector = ModelSelector(CheckpointManager(), asset_manager)
|
| 90 |
|
| 91 |
# Sync asset manager with existing globals
|
| 92 |
-
def _sync_asset_manager():
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
# def _list_ckpt_steps(repo_id: str, revision: str = "main") -> list[int]:
|
| 98 |
-
# """
|
| 99 |
-
# List available checkpoint steps in a HF model repo without downloading all weights.
|
| 100 |
-
# Looks for:
|
| 101 |
-
# checkpoint_<step>/
|
| 102 |
-
# checkpoint_<step>.tgz | .tar.gz
|
| 103 |
-
# archives/checkpoint_<step>.tgz | .tar.gz
|
| 104 |
-
# """
|
| 105 |
-
# api = HfApi()
|
| 106 |
-
# files = api.list_repo_files(repo_id=repo_id, repo_type="model", revision=revision)
|
| 107 |
-
# steps = set()
|
| 108 |
-
# for f in files:
|
| 109 |
-
# m = _STEP_RE.search(f)
|
| 110 |
-
# if m:
|
| 111 |
-
# try:
|
| 112 |
-
# steps.add(int(m.group(1)))
|
| 113 |
-
# except:
|
| 114 |
-
# pass
|
| 115 |
-
# return sorted(steps)
|
| 116 |
-
|
| 117 |
-
# def _step_exists(repo_id: str, revision: str, step: int) -> bool:
|
| 118 |
-
# return step in _list_ckpt_steps(repo_id, revision)
|
| 119 |
|
| 120 |
def _any_jam_running() -> bool:
|
| 121 |
with jam_lock:
|
|
@@ -129,132 +106,6 @@ def _stop_all_jams(timeout: float = 5.0):
|
|
| 129 |
w.join(timeout=timeout)
|
| 130 |
jam_registry.pop(sid, None)
|
| 131 |
|
| 132 |
-
# def _load_finetune_assets_from_hf(repo_id: str | None) -> tuple[bool, str]:
|
| 133 |
-
# """
|
| 134 |
-
# Download & load mean_style_embed.npy and cluster_centroids.npy from a HF model repo.
|
| 135 |
-
# Safe to call multiple times; will overwrite globals if successful.
|
| 136 |
-
# """
|
| 137 |
-
# global _ASSETS_REPO_ID, _MEAN_EMBED, _CENTROIDS
|
| 138 |
-
# repo_id = repo_id or _FINETUNE_REPO_DEFAULT
|
| 139 |
-
# try:
|
| 140 |
-
# from huggingface_hub import hf_hub_download
|
| 141 |
-
# mean_path = None
|
| 142 |
-
# cent_path = None
|
| 143 |
-
# try:
|
| 144 |
-
# mean_path = hf_hub_download(repo_id, filename="mean_style_embed.npy", repo_type="model")
|
| 145 |
-
# except Exception:
|
| 146 |
-
# pass
|
| 147 |
-
# try:
|
| 148 |
-
# cent_path = hf_hub_download(repo_id, filename="cluster_centroids.npy", repo_type="model")
|
| 149 |
-
# except Exception:
|
| 150 |
-
# pass
|
| 151 |
-
|
| 152 |
-
# if mean_path is None and cent_path is None:
|
| 153 |
-
# return False, f"No finetune asset files found in repo {repo_id}"
|
| 154 |
-
|
| 155 |
-
# if mean_path is not None:
|
| 156 |
-
# m = np.load(mean_path)
|
| 157 |
-
# if m.ndim != 1:
|
| 158 |
-
# return False, f"mean_style_embed.npy must be 1-D (got {m.shape})"
|
| 159 |
-
# else:
|
| 160 |
-
# m = None
|
| 161 |
-
|
| 162 |
-
# if cent_path is not None:
|
| 163 |
-
# c = np.load(cent_path)
|
| 164 |
-
# if c.ndim != 2:
|
| 165 |
-
# return False, f"cluster_centroids.npy must be 2-D (got {c.shape})"
|
| 166 |
-
# else:
|
| 167 |
-
# c = None
|
| 168 |
-
|
| 169 |
-
# # Optional: shape check vs model embedding dim once model is alive
|
| 170 |
-
# try:
|
| 171 |
-
# d = int(get_mrt().style_model.config.embedding_dim)
|
| 172 |
-
# if m is not None and m.shape[0] != d:
|
| 173 |
-
# return False, f"mean_style_embed dim {m.shape[0]} != model dim {d}"
|
| 174 |
-
# if c is not None and c.shape[1] != d:
|
| 175 |
-
# return False, f"cluster_centroids dim {c.shape[1]} != model dim {d}"
|
| 176 |
-
# except Exception:
|
| 177 |
-
# # Model not built yet; we’ll trust the files and rely on runtime checks later
|
| 178 |
-
# pass
|
| 179 |
-
|
| 180 |
-
# _MEAN_EMBED = m.astype(np.float32, copy=False) if m is not None else None
|
| 181 |
-
# _CENTROIDS = c.astype(np.float32, copy=False) if c is not None else None
|
| 182 |
-
# _ASSETS_REPO_ID = repo_id
|
| 183 |
-
# logging.info("Loaded finetune assets from %s (mean=%s, centroids=%s)",
|
| 184 |
-
# repo_id,
|
| 185 |
-
# "yes" if _MEAN_EMBED is not None else "no",
|
| 186 |
-
# f"{_CENTROIDS.shape[0]}x{_CENTROIDS.shape[1]}" if _CENTROIDS is not None else "no")
|
| 187 |
-
# return True, "ok"
|
| 188 |
-
# except Exception as e:
|
| 189 |
-
# logging.exception("Failed to load finetune assets: %s", e)
|
| 190 |
-
# return False, str(e)
|
| 191 |
-
|
| 192 |
-
# def _ensure_assets_loaded():
|
| 193 |
-
# # Best-effort lazy load if nothing is loaded yet
|
| 194 |
-
# if _MEAN_EMBED is None and _CENTROIDS is None:
|
| 195 |
-
# _load_finetune_assets_from_hf(_ASSETS_REPO_ID or _FINETUNE_REPO_DEFAULT)
|
| 196 |
-
# ------------------------------------------------------------------------------
|
| 197 |
-
|
| 198 |
-
# def _resolve_checkpoint_dir() -> str | None:
|
| 199 |
-
# repo_id = os.getenv("MRT_CKPT_REPO")
|
| 200 |
-
# if not repo_id:
|
| 201 |
-
# return None
|
| 202 |
-
# step = os.getenv("MRT_CKPT_STEP") # e.g. "1863001"
|
| 203 |
-
|
| 204 |
-
# root = Path(snapshot_download(
|
| 205 |
-
# repo_id=repo_id,
|
| 206 |
-
# repo_type="model",
|
| 207 |
-
# revision=os.getenv("MRT_CKPT_REV", "main"),
|
| 208 |
-
# local_dir="/home/appuser/.cache/mrt_ckpt/repo",
|
| 209 |
-
# local_dir_use_symlinks=False,
|
| 210 |
-
# ))
|
| 211 |
-
|
| 212 |
-
# # Prefer an archive if present (more reliable for Zarr/T5X)
|
| 213 |
-
# arch_names = [
|
| 214 |
-
# f"checkpoint_{step}.tgz",
|
| 215 |
-
# f"checkpoint_{step}.tar.gz",
|
| 216 |
-
# f"archives/checkpoint_{step}.tgz",
|
| 217 |
-
# f"archives/checkpoint_{step}.tar.gz",
|
| 218 |
-
# ] if step else []
|
| 219 |
-
|
| 220 |
-
# cache_root = Path("/home/appuser/.cache/mrt_ckpt/extracted")
|
| 221 |
-
# cache_root.mkdir(parents=True, exist_ok=True)
|
| 222 |
-
# for name in arch_names:
|
| 223 |
-
# arch = root / name
|
| 224 |
-
# if arch.is_file():
|
| 225 |
-
# out_dir = cache_root / f"checkpoint_{step}"
|
| 226 |
-
# marker = out_dir.with_suffix(".ok")
|
| 227 |
-
# if not marker.exists():
|
| 228 |
-
# out_dir.mkdir(parents=True, exist_ok=True)
|
| 229 |
-
# with tarfile.open(arch, "r:*") as tf:
|
| 230 |
-
# tf.extractall(out_dir)
|
| 231 |
-
# marker.write_text("ok")
|
| 232 |
-
# # sanity: require .zarray to exist inside the extracted tree
|
| 233 |
-
# if not any(out_dir.rglob(".zarray")):
|
| 234 |
-
# raise RuntimeError(f"Extracted archive missing .zarray files: {out_dir}")
|
| 235 |
-
# return str(out_dir / f"checkpoint_{step}") if (out_dir / f"checkpoint_{step}").exists() else str(out_dir)
|
| 236 |
-
|
| 237 |
-
# # No archive; try raw folder from repo and sanity check.
|
| 238 |
-
# if step:
|
| 239 |
-
# raw = root / f"checkpoint_{step}"
|
| 240 |
-
# if raw.is_dir():
|
| 241 |
-
# if not any(raw.rglob(".zarray")):
|
| 242 |
-
# raise RuntimeError(
|
| 243 |
-
# f"Downloaded checkpoint_{step} appears incomplete (no .zarray). "
|
| 244 |
-
# "Upload as a .tgz or push via git from a Unix shell."
|
| 245 |
-
# )
|
| 246 |
-
# return str(raw)
|
| 247 |
-
|
| 248 |
-
# # Pick latest if no step
|
| 249 |
-
# step_dirs = [d for d in root.iterdir() if d.is_dir() and re.match(r"checkpoint_\\d+$", d.name)]
|
| 250 |
-
# if step_dirs:
|
| 251 |
-
# pick = max(step_dirs, key=lambda d: int(d.name.split('_')[-1]))
|
| 252 |
-
# if not any(pick.rglob(".zarray")):
|
| 253 |
-
# raise RuntimeError(f"Downloaded {pick} appears incomplete (no .zarray).")
|
| 254 |
-
# return str(pick)
|
| 255 |
-
|
| 256 |
-
# return None
|
| 257 |
-
|
| 258 |
|
| 259 |
async def send_json_safe(ws: WebSocket, obj) -> bool:
|
| 260 |
"""Try to send. Returns False if the socket is (or becomes) closed."""
|
|
@@ -328,19 +179,19 @@ try:
|
|
| 328 |
except Exception:
|
| 329 |
_HAS_LOUDNORM = False
|
| 330 |
|
| 331 |
-
def _combine_styles(mrt, styles_str: str = "", weights_str: str = ""):
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
|
| 345 |
def build_style_vector(
|
| 346 |
mrt,
|
|
@@ -518,6 +369,11 @@ def _mrt_warmup():
|
|
| 518 |
# Never crash on warmup errors; log and continue serving
|
| 519 |
logging.exception("MagentaRT warmup failed (continuing without warmup): %s", e)
|
| 520 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 521 |
# Kick it off in the background on server start
|
| 522 |
@app.on_event("startup")
|
| 523 |
def _kickoff_warmup():
|
|
@@ -640,17 +496,6 @@ def model_checkpoints(repo_id: str, revision: str = "main"):
|
|
| 640 |
steps = CheckpointManager.list_ckpt_steps(repo_id, revision)
|
| 641 |
return {"repo": repo_id, "revision": revision, "steps": steps, "latest": (steps[-1] if steps else None)}
|
| 642 |
|
| 643 |
-
# class ModelSelect(BaseModel):
|
| 644 |
-
# size: Optional[Literal["base","large"]] = None
|
| 645 |
-
# repo_id: Optional[str] = None
|
| 646 |
-
# revision: Optional[str] = "main"
|
| 647 |
-
# step: Optional[Union[int, str]] = None # allow "latest"
|
| 648 |
-
# assets_repo_id: Optional[str] = None # default: follow repo_id
|
| 649 |
-
# sync_assets: bool = True # load mean/centroids from repo
|
| 650 |
-
# prewarm: bool = False # call get_mrt() to build right away
|
| 651 |
-
# stop_active: bool = True # auto-stop jams; else 409
|
| 652 |
-
# dry_run: bool = False # validate only, don't swap
|
| 653 |
-
|
| 654 |
@app.post("/model/select")
|
| 655 |
def model_select(req: ModelSelect):
|
| 656 |
global _MRT, _MEAN_EMBED, _CENTROIDS, _ASSETS_REPO_ID
|
|
@@ -733,6 +578,12 @@ def model_select(req: ModelSelect):
|
|
| 733 |
except Exception:
|
| 734 |
pass
|
| 735 |
raise HTTPException(status_code=500, detail=f"Swap failed: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 736 |
|
| 737 |
|
| 738 |
|
|
|
|
| 77 |
from model_management import CheckpointManager, AssetManager, ModelSelector, ModelSelect
|
| 78 |
|
| 79 |
# ---- Finetune assets (mean & centroids) --------------------------------------
|
| 80 |
+
# _FINETUNE_REPO_DEFAULT = os.getenv("MRT_ASSETS_REPO", "thepatch/magenta-ft")
|
| 81 |
_ASSETS_REPO_ID: str | None = None
|
| 82 |
_MEAN_EMBED: np.ndarray | None = None # shape (D,) dtype float32
|
| 83 |
_CENTROIDS: np.ndarray | None = None # shape (K, D) dtype float32
|
| 84 |
|
| 85 |
+
# _STEP_RE = re.compile(r"(?:^|/)checkpoint_(\d+)(?:/|\.tar\.gz|\.tgz)?$")
|
| 86 |
|
| 87 |
# Create instances (these don't modify globals)
|
| 88 |
asset_manager = AssetManager()
|
| 89 |
model_selector = ModelSelector(CheckpointManager(), asset_manager)
|
| 90 |
|
| 91 |
# Sync asset manager with existing globals
|
| 92 |
+
# def _sync_asset_manager():
|
| 93 |
+
# asset_manager.mean_embed = _MEAN_EMBED
|
| 94 |
+
# asset_manager.centroids = _CENTROIDS
|
| 95 |
+
# asset_manager.assets_repo_id = _ASSETS_REPO_ID
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
|
| 97 |
def _any_jam_running() -> bool:
|
| 98 |
with jam_lock:
|
|
|
|
| 106 |
w.join(timeout=timeout)
|
| 107 |
jam_registry.pop(sid, None)
|
| 108 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
|
| 110 |
async def send_json_safe(ws: WebSocket, obj) -> bool:
|
| 111 |
"""Try to send. Returns False if the socket is (or becomes) closed."""
|
|
|
|
| 179 |
except Exception:
|
| 180 |
_HAS_LOUDNORM = False
|
| 181 |
|
| 182 |
+
# def _combine_styles(mrt, styles_str: str = "", weights_str: str = ""):
|
| 183 |
+
# extra = [s.strip() for s in (styles_str or "").split(",") if s.strip()]
|
| 184 |
+
# if not extra:
|
| 185 |
+
# return mrt.embed_style("warmup")
|
| 186 |
+
# sw = [float(x) for x in (weights_str or "").split(",") if x.strip()]
|
| 187 |
+
# embeds, weights = [], []
|
| 188 |
+
# for i, s in enumerate(extra):
|
| 189 |
+
# embeds.append(mrt.embed_style(s))
|
| 190 |
+
# weights.append(sw[i] if i < len(sw) else 1.0)
|
| 191 |
+
# wsum = sum(weights) or 1.0
|
| 192 |
+
# weights = [w/wsum for w in weights]
|
| 193 |
+
# import numpy as np
|
| 194 |
+
# return np.sum([w*e for w, e in zip(weights, embeds)], axis=0).astype(np.float32)
|
| 195 |
|
| 196 |
def build_style_vector(
|
| 197 |
mrt,
|
|
|
|
| 369 |
# Never crash on warmup errors; log and continue serving
|
| 370 |
logging.exception("MagentaRT warmup failed (continuing without warmup): %s", e)
|
| 371 |
|
| 372 |
+
|
| 373 |
+
# ----------------------------
|
| 374 |
+
# startup and model selection
|
| 375 |
+
# ----------------------------
|
| 376 |
+
|
| 377 |
# Kick it off in the background on server start
|
| 378 |
@app.on_event("startup")
|
| 379 |
def _kickoff_warmup():
|
|
|
|
| 496 |
steps = CheckpointManager.list_ckpt_steps(repo_id, revision)
|
| 497 |
return {"repo": repo_id, "revision": revision, "steps": steps, "latest": (steps[-1] if steps else None)}
|
| 498 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 499 |
@app.post("/model/select")
|
| 500 |
def model_select(req: ModelSelect):
|
| 501 |
global _MRT, _MEAN_EMBED, _CENTROIDS, _ASSETS_REPO_ID
|
|
|
|
| 578 |
except Exception:
|
| 579 |
pass
|
| 580 |
raise HTTPException(status_code=500, detail=f"Swap failed: {e}")
|
| 581 |
+
|
| 582 |
+
|
| 583 |
+
|
| 584 |
+
# ----------------------------
|
| 585 |
+
# one-shot generation
|
| 586 |
+
# ----------------------------
|
| 587 |
|
| 588 |
|
| 589 |
|
documentation.html
CHANGED
|
@@ -4,67 +4,326 @@
|
|
| 4 |
<meta charset="utf-8">
|
| 5 |
<title>MagentaRT Research API</title>
|
| 6 |
<style>
|
| 7 |
-
body {
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
ul { line-height: 1.8; }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
</style>
|
| 13 |
</head>
|
| 14 |
<body>
|
| 15 |
-
<
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
<
|
| 23 |
-
<
|
| 24 |
-
<
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
"type": "start",
|
| 33 |
"mode": "rt",
|
| 34 |
"binary_audio": false,
|
| 35 |
"params": {
|
| 36 |
-
"styles": "
|
|
|
|
| 37 |
"temperature": 1.1,
|
| 38 |
"topk": 40,
|
| 39 |
"guidance_weight": 1.1,
|
| 40 |
-
"pace": "realtime",
|
| 41 |
-
"
|
|
|
|
|
|
|
| 42 |
}
|
| 43 |
}</pre>
|
| 44 |
-
|
| 45 |
-
|
|
|
|
| 46 |
"type": "update",
|
| 47 |
"styles": "jazz, hiphop",
|
| 48 |
-
"style_weights": "1.0,0.8",
|
| 49 |
"temperature": 1.2,
|
| 50 |
"topk": 64,
|
| 51 |
"guidance_weight": 1.0,
|
| 52 |
-
"
|
| 53 |
-
"
|
| 54 |
}</pre>
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
<
|
| 62 |
-
|
| 63 |
-
<
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
</body>
|
| 70 |
</html>
|
|
|
|
| 4 |
<meta charset="utf-8">
|
| 5 |
<title>MagentaRT Research API</title>
|
| 6 |
<style>
|
| 7 |
+
body {
|
| 8 |
+
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
|
| 9 |
+
max-width: 900px;
|
| 10 |
+
margin: 48px auto;
|
| 11 |
+
padding: 0 24px;
|
| 12 |
+
color: #111;
|
| 13 |
+
line-height: 1.6;
|
| 14 |
+
}
|
| 15 |
+
.header { text-align: center; margin-bottom: 48px; }
|
| 16 |
+
.badge {
|
| 17 |
+
display: inline-block;
|
| 18 |
+
background: #ff6b35;
|
| 19 |
+
color: white;
|
| 20 |
+
padding: 4px 12px;
|
| 21 |
+
border-radius: 16px;
|
| 22 |
+
font-size: 0.85em;
|
| 23 |
+
font-weight: 500;
|
| 24 |
+
margin-left: 8px;
|
| 25 |
+
}
|
| 26 |
+
code, pre {
|
| 27 |
+
background: #f6f8fa;
|
| 28 |
+
border: 1px solid #eaecef;
|
| 29 |
+
border-radius: 6px;
|
| 30 |
+
font-family: 'SF Mono', Monaco, 'Cascadia Code', 'Roboto Mono', Consolas, monospace;
|
| 31 |
+
}
|
| 32 |
+
code { padding: 2px 6px; }
|
| 33 |
+
pre {
|
| 34 |
+
padding: 16px;
|
| 35 |
+
overflow-x: auto;
|
| 36 |
+
margin: 16px 0;
|
| 37 |
+
position: relative;
|
| 38 |
+
}
|
| 39 |
+
.copy-btn {
|
| 40 |
+
position: absolute;
|
| 41 |
+
top: 8px;
|
| 42 |
+
right: 8px;
|
| 43 |
+
background: #0969da;
|
| 44 |
+
color: white;
|
| 45 |
+
border: none;
|
| 46 |
+
border-radius: 4px;
|
| 47 |
+
padding: 4px 8px;
|
| 48 |
+
font-size: 12px;
|
| 49 |
+
cursor: pointer;
|
| 50 |
+
}
|
| 51 |
+
.copy-btn:hover { background: #0550ae; }
|
| 52 |
+
.muted { color: #656d76; }
|
| 53 |
+
.warning {
|
| 54 |
+
background: #fff8c5;
|
| 55 |
+
border: 1px solid #e3b341;
|
| 56 |
+
border-radius: 8px;
|
| 57 |
+
padding: 16px;
|
| 58 |
+
margin: 16px 0;
|
| 59 |
+
}
|
| 60 |
+
.info {
|
| 61 |
+
background: #dbeafe;
|
| 62 |
+
border: 1px solid #3b82f6;
|
| 63 |
+
border-radius: 8px;
|
| 64 |
+
padding: 16px;
|
| 65 |
+
margin: 16px 0;
|
| 66 |
+
}
|
| 67 |
ul { line-height: 1.8; }
|
| 68 |
+
.endpoint {
|
| 69 |
+
background: #f8f9fa;
|
| 70 |
+
border-left: 4px solid #0969da;
|
| 71 |
+
padding: 12px 16px;
|
| 72 |
+
margin: 12px 0;
|
| 73 |
+
}
|
| 74 |
+
.demo-placeholder {
|
| 75 |
+
background: #f6f8fa;
|
| 76 |
+
border: 2px dashed #d1d9e0;
|
| 77 |
+
border-radius: 8px;
|
| 78 |
+
padding: 48px;
|
| 79 |
+
text-align: center;
|
| 80 |
+
margin: 24px 0;
|
| 81 |
+
color: #656d76;
|
| 82 |
+
}
|
| 83 |
+
.grid {
|
| 84 |
+
display: grid;
|
| 85 |
+
grid-template-columns: 1fr 1fr;
|
| 86 |
+
gap: 24px;
|
| 87 |
+
margin: 24px 0;
|
| 88 |
+
}
|
| 89 |
+
.card {
|
| 90 |
+
background: #f8f9fa;
|
| 91 |
+
border: 1px solid #e1e8ed;
|
| 92 |
+
border-radius: 8px;
|
| 93 |
+
padding: 20px;
|
| 94 |
+
}
|
| 95 |
+
a { color: #0969da; text-decoration: none; }
|
| 96 |
+
a:hover { text-decoration: underline; }
|
| 97 |
+
.section { margin: 48px 0; }
|
| 98 |
</style>
|
| 99 |
</head>
|
| 100 |
<body>
|
| 101 |
+
<div class="header">
|
| 102 |
+
<h1>🎵 MagentaRT Research API</h1>
|
| 103 |
+
<p class="muted"><strong>AI Music Generation API</strong> • Real-time streaming • Custom fine-tuning support</p>
|
| 104 |
+
<span class="badge">Research Project</span>
|
| 105 |
+
</div>
|
| 106 |
+
|
| 107 |
+
<div class="demo-placeholder">
|
| 108 |
+
<h3>📱 App Demo Video</h3>
|
| 109 |
+
<p>Demo video will be embedded here<br>
|
| 110 |
+
<small>Showing the iPhone app generating music in real-time</small></p>
|
| 111 |
+
</div>
|
| 112 |
+
|
| 113 |
+
<div class="section">
|
| 114 |
+
<h2>Overview</h2>
|
| 115 |
+
<p>This API powers AI music generation using Google's MagentaRT, designed for real-time audio streaming and custom model fine-tuning. Built for iOS app integration with WebSocket streaming support.</p>
|
| 116 |
+
|
| 117 |
+
<div class="info">
|
| 118 |
+
<strong>Hardware Requirements:</strong> Optimal performance requires an L40S GPU (48GB VRAM) for real-time streaming. L4 24GB works but may not maintain real-time performance.
|
| 119 |
+
</div>
|
| 120 |
+
</div>
|
| 121 |
+
|
| 122 |
+
<div class="section">
|
| 123 |
+
<h2>Quick Start - WebSocket Streaming</h2>
|
| 124 |
+
<p>Connect to <code>wss://&lt;your-space&gt;/ws/jam</code> for real-time audio generation:</p>
|
| 125 |
+
|
| 126 |
+
<h3>Start Real-time Generation</h3>
|
| 127 |
+
<pre><button class="copy-btn" onclick="copyCode(this)">Copy</button>{
|
| 128 |
"type": "start",
|
| 129 |
"mode": "rt",
|
| 130 |
"binary_audio": false,
|
| 131 |
"params": {
|
| 132 |
+
"styles": "electronic, ambient",
|
| 133 |
+
"style_weights": "1.0, 0.8",
|
| 134 |
"temperature": 1.1,
|
| 135 |
"topk": 40,
|
| 136 |
"guidance_weight": 1.1,
|
| 137 |
+
"pace": "realtime",
|
| 138 |
+
"style_ramp_seconds": 8.0,
|
| 139 |
+
"mean": 0.0,
|
| 140 |
+
"centroid_weights": "0.0, 0.0, 0.0"
|
| 141 |
}
|
| 142 |
}</pre>
|
| 143 |
+
|
| 144 |
+
<h3>Update Parameters Live</h3>
|
| 145 |
+
<pre><button class="copy-btn" onclick="copyCode(this)">Copy</button>{
|
| 146 |
"type": "update",
|
| 147 |
"styles": "jazz, hiphop",
|
| 148 |
+
"style_weights": "1.0, 0.8",
|
| 149 |
"temperature": 1.2,
|
| 150 |
"topk": 64,
|
| 151 |
"guidance_weight": 1.0,
|
| 152 |
+
"mean": 0.2,
|
| 153 |
+
"centroid_weights": "0.1, 0.3, 0.0"
|
| 154 |
}</pre>
|
| 155 |
+
|
| 156 |
+
<h3>Stop Generation</h3>
|
| 157 |
+
<pre><button class="copy-btn" onclick="copyCode(this)">Copy</button>{"type": "stop"}</pre>
|
| 158 |
+
</div>
|
| 159 |
+
|
| 160 |
+
<div class="section">
|
| 161 |
+
<h2>API Endpoints</h2>
|
| 162 |
+
|
| 163 |
+
<div class="endpoint">
|
| 164 |
+
<strong>POST /generate</strong> - Generate 4–8 bars of music with input audio
|
| 165 |
+
</div>
|
| 166 |
+
|
| 167 |
+
<div class="endpoint">
|
| 168 |
+
<strong>POST /generate_style</strong> - Generate music from style prompts only (experimental)
|
| 169 |
+
</div>
|
| 170 |
+
|
| 171 |
+
<div class="endpoint">
|
| 172 |
+
<strong>POST /jam/start</strong> - Start continuous jamming session
|
| 173 |
+
</div>
|
| 174 |
+
|
| 175 |
+
<div class="endpoint">
|
| 176 |
+
<strong>GET /jam/next</strong> - Get next audio chunk from session
|
| 177 |
+
</div>
|
| 178 |
+
|
| 179 |
+
<div class="endpoint">
|
| 180 |
+
<strong>POST /jam/consume</strong> - Mark chunk as consumed
|
| 181 |
+
</div>
|
| 182 |
+
|
| 183 |
+
<div class="endpoint">
|
| 184 |
+
<strong>POST /jam/stop</strong> - End jamming session
|
| 185 |
+
</div>
|
| 186 |
+
|
| 187 |
+
<div class="endpoint">
|
| 188 |
+
<strong>WEBSOCKET /ws/jam</strong> - Real-time streaming interface
|
| 189 |
+
</div>
|
| 190 |
+
|
| 191 |
+
<div class="endpoint">
|
| 192 |
+
<strong>POST /model/select</strong> - Switch between base and fine-tuned models
|
| 193 |
+
</div>
|
| 194 |
+
</div>
|
| 195 |
+
|
| 196 |
+
<div class="section">
|
| 197 |
+
<h2>Custom Fine-Tuning</h2>
|
| 198 |
+
<p>Train your own MagentaRT models and use them with this API and the iOS app.</p>
|
| 199 |
+
|
| 200 |
+
<div class="grid">
|
| 201 |
+
<div class="card">
|
| 202 |
+
<h3>1. Train Your Model</h3>
|
| 203 |
+
<p>Use the official MagentaRT fine-tuning notebook:</p>
|
| 204 |
+
<p><a href="https://github.com/magenta-realtime/notebooks/blob/main/Magenta_RT_Finetune.ipynb" target="_blank">📓 MagentaRT Fine-tuning Colab</a></p>
|
| 205 |
+
<p>This will create checkpoint folders like:</p>
|
| 206 |
+
<ul>
|
| 207 |
+
<li><code>checkpoint_1861001/</code></li>
|
| 208 |
+
<li><code>checkpoint_1862001/</code></li>
|
| 209 |
+
<li>And steering assets: <code>cluster_centroids.npy</code>, <code>mean_style_embed.npy</code></li>
|
| 210 |
+
</ul>
|
| 211 |
+
</div>
|
| 212 |
+
|
| 213 |
+
<div class="card">
|
| 214 |
+
<h3>2. Package Checkpoints</h3>
|
| 215 |
+
<p>Checkpoints must be compressed as .tgz files to preserve .zarray files correctly.</p>
|
| 216 |
+
<div class="warning">
|
| 217 |
+
<strong>Important:</strong> Do not download checkpoint folders directly from Google Drive - the .zarray files won't transfer properly.
|
| 218 |
+
</div>
|
| 219 |
+
</div>
|
| 220 |
+
</div>
|
| 221 |
+
|
| 222 |
+
<h3>Checkpoint Packaging Script</h3>
|
| 223 |
+
<p>Use this in a Colab cell to properly package your checkpoints:</p>
|
| 224 |
+
<pre><button class="copy-btn" onclick="copyCode(this)">Copy</button># Mount Drive to access your trained checkpoints
|
| 225 |
+
from google.colab import drive
|
| 226 |
+
drive.mount('/content/drive')
|
| 227 |
+
|
| 228 |
+
# Set the path to your checkpoint folder
|
| 229 |
+
CKPT_SRC = '/content/drive/MyDrive/thepatch/checkpoint_1862001' # Adjust path
|
| 230 |
+
|
| 231 |
+
# Copy folder to local storage (preserves dotfiles)
|
| 232 |
+
!rm -rf /content/checkpoint_1862001
|
| 233 |
+
!cp -a "$CKPT_SRC" /content/
|
| 234 |
+
|
| 235 |
+
# Verify .zarray files are present
|
| 236 |
+
!find /content/checkpoint_1862001 -name .zarray | wc -l
|
| 237 |
+
|
| 238 |
+
# Create properly formatted .tgz archive
|
| 239 |
+
!tar -C /content -czf /content/checkpoint_1862001.tgz checkpoint_1862001
|
| 240 |
+
|
| 241 |
+
# Verify critical files are in the archive
|
| 242 |
+
!tar -tzf /content/checkpoint_1862001.tgz | grep -c '.zarray'
|
| 243 |
+
|
| 244 |
+
# Download the .tgz file
|
| 245 |
+
from google.colab import files
|
| 246 |
+
files.download('/content/checkpoint_1862001.tgz')</pre>
|
| 247 |
+
|
| 248 |
+
<h3>3. Upload to Hugging Face</h3>
|
| 249 |
+
<p>Create a model repository and upload:</p>
|
| 250 |
+
<ul>
|
| 251 |
+
<li>Your <code>.tgz</code> checkpoint files</li>
|
| 252 |
+
<li><code>cluster_centroids.npy</code> (for steering)</li>
|
| 253 |
+
<li><code>mean_style_embed.npy</code> (for steering)</li>
|
| 254 |
+
</ul>
|
| 255 |
+
|
| 256 |
+
<div class="info">
|
| 257 |
+
<strong>Example Repository:</strong> <a href="https://huggingface.co/thepatch/magenta-ft" target="_blank">thepatch/magenta-ft</a><br>
|
| 258 |
+
Shows the correct file structure with .tgz files and .npy steering assets in the root directory.
|
| 259 |
+
</div>
|
| 260 |
+
|
| 261 |
+
<h3>4. Use in the App</h3>
|
| 262 |
+
<p>In the iOS app's model selector, point to your Hugging Face repository URL. The app will automatically discover available checkpoints and allow switching between them.</p>
|
| 263 |
+
</div>
|
| 264 |
+
|
| 265 |
+
<div class="section">
|
| 266 |
+
<h2>Technical Specifications</h2>
|
| 267 |
+
<ul>
|
| 268 |
+
<li><strong>Audio Format:</strong> 48 kHz stereo, ~2.0s chunks with ~40ms crossfade</li>
|
| 269 |
+
<li><strong>Model Sizes:</strong> Base and Large variants available</li>
|
| 270 |
+
<li><strong>Steering:</strong> Support for text prompts, audio embeddings, and centroid-based fine-tune steering</li>
|
| 271 |
+
<li><strong>Real-time Performance:</strong> L40S recommended; L4 may experience slight delays</li>
|
| 272 |
+
<li><strong>Memory Requirements:</strong> ~40GB VRAM for sustained real-time streaming</li>
|
| 273 |
+
</ul>
|
| 274 |
+
|
| 275 |
+
<div class="warning">
|
| 276 |
+
<strong>Note:</strong> The <code>/generate_style</code> endpoint is experimental and may not properly adhere to BPM without additional context (a metronome-based context, rather than silence, is being considered).
|
| 277 |
+
</div>
|
| 278 |
+
</div>
|
| 279 |
+
|
| 280 |
+
<div class="section">
|
| 281 |
+
<h2>Integration with iOS App</h2>
|
| 282 |
+
<p>This API is designed to work seamlessly with our iOS music generation app:</p>
|
| 283 |
+
<ul>
|
| 284 |
+
<li>Real-time audio streaming via WebSockets</li>
|
| 285 |
+
<li>Dynamic model switching between base and fine-tuned models</li>
|
| 286 |
+
<li>Integration with stable-audio-open-small for combined input audio generation</li>
|
| 287 |
+
<li>Live parameter adjustment during generation</li>
|
| 288 |
+
</ul>
|
| 289 |
+
</div>
|
| 290 |
+
|
| 291 |
+
<div class="section">
|
| 292 |
+
<h2>Deployment</h2>
|
| 293 |
+
<p>To run your own instance:</p>
|
| 294 |
+
<ol>
|
| 295 |
+
<li>Duplicate this Hugging Face Space</li>
|
| 296 |
+
<li>Ensure you have access to an L40S GPU</li>
|
| 297 |
+
<li>Point your iOS app to the new space URL (e.g., <code>https://your-username-magenta-retry.hf.space</code>)</li>
|
| 298 |
+
<li>Upload your fine-tuned models as described above</li>
|
| 299 |
+
</ol>
|
| 300 |
+
</div>
|
| 301 |
+
|
| 302 |
+
<div class="section">
|
| 303 |
+
<h2>Support & Contact</h2>
|
| 304 |
+
<p>This is an active research project. For questions, technical support, or collaboration:</p>
|
| 305 |
+
<p><strong>Email:</strong> <a href="mailto:kev@thecollabagepatch.com">kev@thecollabagepatch.com</a></p>
|
| 306 |
+
|
| 307 |
+
<div class="info">
|
| 308 |
+
<strong>Research Status:</strong> This project is under active development. Features and API may change. We welcome feedback and contributions from the research community.
|
| 309 |
+
</div>
|
| 310 |
+
</div>
|
| 311 |
+
|
| 312 |
+
<div class="section">
|
| 313 |
+
<h2>Licensing</h2>
|
| 314 |
+
<p>Built on Google's MagentaRT (Apache 2.0 + CC-BY 4.0). Users are responsible for their generated outputs and ensuring compliance with applicable laws and platform policies.</p>
|
| 315 |
+
<p><a href="/docs">📚 API Reference Documentation</a></p>
|
| 316 |
+
</div>
|
| 317 |
+
|
| 318 |
+
<script>
|
| 319 |
+
function copyCode(button) {
|
| 320 |
+
const pre = button.parentElement;
|
| 321 |
+
const code = pre.textContent.replace('Copy', '').trim();
|
| 322 |
+
navigator.clipboard.writeText(code).then(() => {
|
| 323 |
+
button.textContent = 'Copied!';
|
| 324 |
+
setTimeout(() => button.textContent = 'Copy', 2000);
|
| 325 |
+
});
|
| 326 |
+
}
|
| 327 |
+
</script>
|
| 328 |
</body>
|
| 329 |
</html>
|