Spaces:
Paused
Paused
| #!/usr/bin/env python3 | |
| """ | |
| FastAPI JASCO server (Postman-ready) with robust Hugging Face auth and | |
| safe handling of HF fast-transfer (hf_transfer). | |
| - If HF_HUB_ENABLE_HF_TRANSFER=1 but 'hf_transfer' isn't installed, we fall back | |
| to standard downloads and log a warning instead of failing. | |
| - POST /predict supports multipart (drums_file) and JSON (drums_b64). | |
| - GET /hf-status shows auth and model access. | |
| Run: | |
| export HUGGINGFACE_HUB_TOKEN=hf_xxx # or HF_TOKEN/HFTOKEN/HUGGINGFACEHUB_API_TOKEN | |
| uvicorn main:app --host 0.0.0.0 --port 7860 | |
| """ | |
| # ----------------------------- | |
| # Environment (HF Spaces-friendly) | |
| # ----------------------------- | |
| import os | |
| from pathlib import Path | |
| from requests import Request, Response | |
| from pydantic import BaseModel, Field | |
| import numpy as np | |
| from scipy.io import wavfile | |
| from fastapi.responses import FileResponse | |
def _pick_cache_dir() -> Path:
    """Return the first candidate cache directory that is creatable and writable.

    Candidates are tried in order: ``/data/cache``, ``/tmp/audiocraft_cache``
    and ``<cwd>/cache``. A candidate is accepted once it can be created and a
    throw-away marker file can be written to and removed from it.

    Returns:
        The accepted directory, or the current working directory when every
        candidate fails.
    """
    working_dir = Path.cwd()
    candidates = (
        Path("/data/cache"),
        Path("/tmp/audiocraft_cache"),
        working_dir / "cache",
    )
    for candidate in candidates:
        probe = candidate / ".w"
        try:
            candidate.mkdir(parents=True, exist_ok=True)
            # Creating and deleting a marker proves the directory is writable.
            probe.touch()
            probe.unlink()
        except Exception:
            # Best effort: any failure simply means "try the next candidate".
            continue
        return candidate
    return working_dir
# Resolve the writable cache root once at import time and create every
# sub-cache directory underneath it.
CACHE_DIR = _pick_cache_dir()
for sub in ["models", "huggingface", "transformers", "drum_cache", "cache"]:
    (CACHE_DIR / sub).mkdir(parents=True, exist_ok=True)

# Point the library caches at the writable root. These assignments sit above
# the torch / huggingface_hub / audiocraft imports on purpose, so the
# libraries see them when they are first imported.
os.environ["AUDIOCRAFT_CACHE_DIR"] = str(CACHE_DIR)
os.environ["XDG_CACHE_HOME"] = str(CACHE_DIR)
os.environ["TORCH_HOME"] = str(CACHE_DIR / "cache")
os.environ["HF_HOME"] = str(CACHE_DIR / "huggingface")
os.environ["TRANSFORMERS_CACHE"] = str(CACHE_DIR / "transformers")
os.environ["NUMBA_DISABLE_JIT"] = "1"

# Do NOT force-enable fast transfer. If the operator enabled it but the
# optional 'hf_transfer' package is missing, switch it off and warn instead of
# letting downloads fail later (this is the fallback the module docstring
# promises).
# NOTE(review): this must run before huggingface_hub is imported — that
# library appears to read the flag at import time; confirm for the pinned
# version.
import importlib.util

_hf_transfer_requested = (
    os.environ.get("HF_HUB_ENABLE_HF_TRANSFER", "").strip().upper()
    in {"1", "ON", "YES", "TRUE"}
)
if _hf_transfer_requested and importlib.util.find_spec("hf_transfer") is None:
    print(
        "[HF] HF_HUB_ENABLE_HF_TRANSFER is set but 'hf_transfer' is not installed; "
        "falling back to standard downloads."
    )
    os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"
| # ----------------------------- | |
| # Imports | |
| # ----------------------------- | |
| import io | |
| import re | |
| import json | |
| import wave | |
| import time | |
| import base64 | |
| import random | |
| import hashlib | |
| import zipfile | |
| from tempfile import NamedTemporaryFile | |
| from typing import Optional, List, Tuple, Union, Optional | |
| import numpy as np | |
| import torch | |
| from fastapi import FastAPI, Request, UploadFile, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import StreamingResponse, JSONResponse | |
| # Hugging Face auth helpers | |
| from huggingface_hub import login as hf_login, HfApi, HfFolder | |
| from huggingface_hub.utils import HfHubHTTPError | |
| # JASCO / AudioCraft | |
| from audiocraft.data.audio_utils import f32_pcm, normalize_audio | |
| from audiocraft.data.audio import audio_write | |
| from audiocraft.models import JASCO | |
| # ----------------------------- | |
| # App boilerplate | |
| # ----------------------------- | |
# The single FastAPI application object served by uvicorn.
app = FastAPI(title="JASCO /predict (HF auth)")

# CORS is wide open: every origin, method and header is accepted.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive — confirm this is intended for the deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
| # ----------------------------- | |
| # Hugging Face auth utilities | |
| # ----------------------------- | |
def _get_hf_token() -> Optional[str]:
    """Return the Hugging Face token from the environment, if any.

    The variables ``HUGGINGFACE_HUB_TOKEN``, ``HUGGINGFACEHUB_API_TOKEN``,
    ``HF_TOKEN`` and ``HFTOKEN`` are checked in that order. Values are stripped
    *before* being tested, so a variable that is set to whitespace only is
    skipped instead of shadowing a valid token in a later variable.

    Returns:
        The stripped token, or ``None`` when no variable holds a usable value.
    """
    for env_name in ("HUGGINGFACE_HUB_TOKEN", "HUGGINGFACEHUB_API_TOKEN", "HF_TOKEN", "HFTOKEN"):
        token = (os.getenv(env_name) or "").strip()
        if token:
            return token
    return None


# Token captured once at import time; ensure_hf_login() reads this value.
HF_TOKEN = _get_hf_token()
def ensure_hf_login():
    """Log in to the Hugging Face Hub with the token captured at import time.

    Without a token this only prints a notice and returns (gated models will
    then fail to load). With a token it logs in, stores the token via
    ``HfFolder`` and prints the account name. Every failure is caught and
    printed, so this function never raises.
    """
    token = HF_TOKEN
    if token:
        try:
            hf_login(token=token, add_to_git_credential=False)
            # Persist the token so later hub calls can pick it up
            # (presumably stored under HF_HOME — confirm).
            HfFolder.save_token(token)
            profile = HfApi().whoami(token=token)
            display_name = profile.get('name') or profile.get('email') or profile.get('username')
            print(f"[HF] Logged in as: {display_name}")
        except Exception as exc:
            print(f"[HF] Login failed: {exc}")
    else:
        print("[HF] No token found in env (HUGGINGFACE_HUB_TOKEN / HUGGINGFACEHUB_API_TOKEN / HF_TOKEN / HFTOKEN).")
@app.get("/hf-status")
def hf_status(model_id: str = "facebook/jasco-chords-drums-melody-400M"):
    """Report Hugging Face authentication state and access to a model repo.

    Registered as ``GET /hf-status`` (the route documented in the module
    docstring). ``model_id`` can be overridden as a query parameter and
    defaults to the JASCO checkpoint this server uses.

    Returns:
        A dict with ``token_present``; ``whoami`` (or ``whoami_error``);
        ``model_access``; and either ``model_private``/``gated`` on success or
        ``error`` on failure. Failures are reported in the payload, never
        raised.
    """
    token = _get_hf_token()
    out = {"token_present": bool(token)}

    # NOTE(review): the raw whoami payload is returned verbatim; it may contain
    # account details — confirm that exposing it on this endpoint is acceptable.
    try:
        out["whoami"] = HfApi().whoami(token=token) if token else None
    except Exception as e:
        out["whoami_error"] = str(e)

    try:
        api = HfApi()
        info = api.model_info(model_id, token=token) if token else api.model_info(model_id)
        out["model_access"] = True
        out["model_private"] = getattr(info, "private", None)
        out["gated"] = bool(getattr(info, "gated", False))
    except Exception as e:
        out["model_access"] = False
        out["error"] = str(e)
    return out
| # ----------------------------- | |
| # Chords helpers | |
| # ----------------------------- | |
def _default_chord_map():
    """Build the built-in chord-symbol -> index table.

    Index 0 is ``"N"``; it is followed by the roots C, D, E and F, each in the
    qualities major, minor (``m``), ``7``, ``maj7`` and ``m7``, numbered
    consecutively (21 entries in total).

    NOTE(review): the table only covers roots C-F and has no ``"UNK"`` key —
    confirm it matches the chord vocabulary the model expects.
    """
    qualities = ("", "m", "7", "maj7", "m7")
    labels = ["N"]
    for root in "CDEF":
        labels.extend(root + quality for quality in qualities)
    return dict(zip(labels, range(len(labels))))
def _validate_chord(ch: str, mapping: dict) -> str:
    """Return *ch* when it is a key of *mapping*, otherwise the placeholder ``"UNK"``."""
    if ch not in mapping:
        return "UNK"
    return ch
def chords_string_to_list(chords: str):
    """Parse a chord progression such as ``"(C, 0.0), (Dm, 2.0)"`` into pairs.

    Square brackets and spaces are removed first; every ``(symbol,time)`` group
    then becomes a ``(symbol, float(time))`` tuple, where symbols missing from
    the default chord map are replaced by ``"UNK"``.

    Returns:
        The list of ``(chord, start_time)`` tuples. An empty/blank input, or
        any error while parsing, yields an empty list.
    """
    if not chords or not chords.strip():
        return []
    try:
        compact = chords
        for junk in ("[", "]", " "):
            compact = compact.replace(junk, "")
        groups = re.findall(r"\(([^,]+),([^)]+)\)", compact)
        known = _default_chord_map()
        timeline = []
        for symbol, onset in groups:
            timeline.append((_validate_chord(symbol.strip(), known), float(onset.strip())))
        return timeline
    except Exception:
        # Best effort: a malformed progression disables chord conditioning.
        return []
| # ----------------------------- | |
| # Audio decoding (WAV stdlib) | |
| # ----------------------------- | |
def _pcm_to_float32(buf: bytes, sample_width: int) -> Optional[np.ndarray]:
    """Convert interleaved little-endian integer PCM bytes to float32 in [-1, 1).

    Returns ``None`` for a sample width other than 1, 2, 3 or 4 bytes.
    """
    if sample_width == 1:
        # 8-bit WAV samples are unsigned with a bias of 128.
        return (np.frombuffer(buf, dtype=np.uint8).astype(np.float32) - 128.0) / 128.0
    if sample_width == 2:
        return np.frombuffer(buf, dtype="<i2").astype(np.float32) / 32768.0
    if sample_width == 3:
        # No native 24-bit dtype: assemble the three bytes, then sign-extend.
        parts = np.frombuffer(buf, dtype=np.uint8).reshape(-1, 3).astype(np.int32)
        values = parts[:, 0] | (parts[:, 1] << 8) | (parts[:, 2] << 16)
        values = (values ^ 0x800000) - 0x800000
        return values.astype(np.float32) / 8388608.0
    if sample_width == 4:
        # The stdlib wave reader yields integer PCM, so 4-byte samples are
        # int32 values, not IEEE floats.
        return np.frombuffer(buf, dtype="<i4").astype(np.float32) / 2147483648.0
    return None


def _read_wav_bytes(raw: Optional[bytes]) -> Tuple[int, Optional[torch.Tensor]]:
    """Decode WAV bytes into a loudness-normalised, channels-first drum tensor.

    Args:
        raw: Complete contents of a PCM WAV file, or ``None``/empty.

    Returns:
        ``(sample_rate, tensor)`` where ``tensor`` has shape
        ``(channels, samples)``. When there is no input, the sample width is
        unsupported, or decoding fails, returns ``(32000, None)``; failures are
        printed rather than raised so generation can continue without drums.
    """
    if not raw:
        return 32000, None
    try:
        with wave.open(io.BytesIO(raw), "rb") as wf:
            sr = wf.getframerate()
            ch = wf.getnchannels()
            sw = wf.getsampwidth()
            buf = wf.readframes(wf.getnframes())

        data = _pcm_to_float32(buf, sw)
        if data is None:
            print(f"[audio] Unsupported WAV sample width: {sw} bytes")
            return 32000, None

        # De-interleave frames into (channels, samples). The array is already
        # channels-first here, so it must NOT be transposed a second time.
        data = np.ascontiguousarray(data.reshape(-1, ch).T)
        drums = f32_pcm(torch.from_numpy(data))
        drums = normalize_audio(drums, "loudness", loudness_headroom_db=16, sample_rate=sr)
        return sr, drums
    except Exception as e:
        print(f"[audio] WAV decode failed: {e}")
        return 32000, None
def _read_uploadfile_to_bytes(file: Optional[UploadFile]) -> Optional[bytes]:
    """Return the full contents of an uploaded file.

    Returns ``None`` when *file* is ``None`` or when reading its underlying
    ``.file`` object fails for any reason.
    """
    payload: Optional[bytes] = None
    if file is not None:
        try:
            payload = file.file.read()
        except Exception:
            payload = None
    return payload
def _read_b64_to_bytes(b64str: Optional[str]) -> Optional[bytes]:
    """Decode a base64 string, optionally wrapped as a ``data:`` URI.

    Surrounding whitespace is stripped. For input starting with ``data:`` only
    the part after the first comma is decoded.

    Returns:
        The decoded bytes, or ``None`` when the input is empty or cannot be
        decoded.
    """
    if not b64str:
        return None
    try:
        payload = b64str.strip()
        if payload.startswith("data:"):
            _, comma, payload = payload.partition(",")
            if not comma:
                # A data URI without a comma carries no payload.
                return None
        return base64.b64decode(payload)
    except Exception:
        return None
| # ----------------------------- | |
| # Model | |
| # ----------------------------- | |
# Process-wide cache of the loaded model; assigned by load_model().
MODEL = None


def _ensure_mapping_file() -> Path:
    """Return the path of the pickled chord mapping, creating it if missing.

    The file lives at ``CACHE_DIR / "chord_to_index_mapping.pkl"``. When it does
    not exist yet it is written from ``_default_chord_map()``; an existing file
    is left untouched.
    """
    import pickle

    target = CACHE_DIR / "chord_to_index_mapping.pkl"
    if target.exists():
        return target
    with target.open("wb") as handle:
        pickle.dump(_default_chord_map(), handle)
    return target
def load_model(name: str):
    """Load (or reuse) the JASCO model called *name* on CPU.

    The previously loaded model is returned directly when its ``name`` matches.
    Otherwise the function logs in to the Hub, checks repo access up front so
    auth problems produce a readable error, redirects the AudioCraft /
    Transformers caches to a per-model directory and loads the checkpoint.

    Args:
        name: Hugging Face repo id of the JASCO checkpoint.

    Returns:
        The loaded model, also stored in the module-level ``MODEL``.

    Raises:
        HTTPException: 401 when the Hub rejects access to the repo, 500 for
            any other failure while loading the model.
    """
    global MODEL
    if MODEL is not None and getattr(MODEL, "name", None) == name:
        return MODEL

    # Ensure HF login
    ensure_hf_login()

    # Preflight access for clearer errors
    try:
        api = HfApi()
        token = _get_hf_token()
        _ = api.model_info(name, token=token) if token else api.model_info(name)
    except HfHubHTTPError as e:
        msg = (
            f"Cannot access model '{name}'. This repo may be gated or private.\n"
            f"- Ensure your token has access and terms are accepted.\n"
            f"- Provide token via HUGGINGFACE_HUB_TOKEN (or HF_TOKEN/HFTOKEN/HUGGINGFACEHUB_API_TOKEN).\n"
            f"Hugging Face error: {e}"
        )
        raise HTTPException(status_code=401, detail=msg) from e

    # NOTE(review): `name` can come straight from the request body; consider
    # validating it against an allow-list before using it for paths/downloads.
    cache_path = CACHE_DIR / name.replace("/", "_")
    cache_path.mkdir(parents=True, exist_ok=True)
    os.environ["AUDIOCRAFT_CACHE_DIR"] = str(cache_path)
    os.environ["TRANSFORMERS_CACHE"] = str(cache_path / "transformers")

    mapping_file = _ensure_mapping_file()
    try:
        model = JASCO.get_pretrained(name, device="cpu", chords_mapping_path=str(mapping_file))
        model.name = name
        if not hasattr(model, "chord_to_index"):
            import pickle

            # The pickle is a file this server wrote itself, not external input.
            with open(mapping_file, "rb") as f:
                model.chord_to_index = pickle.load(f)
    except HfHubHTTPError as e:
        raise HTTPException(status_code=401, detail=f"Model load failed due to HF auth/access: {e}") from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Model load failed: {e}") from e

    MODEL = model
    return MODEL
def set_gen_params(model, **kwargs):
    """Forward generation parameters to ``model.set_generation_params``.

    When the model exposes ``get_generation_params()`` and that call succeeds,
    its keys act as an allow-list: keyword arguments outside of it are dropped
    and reported. Otherwise every keyword argument is forwarded unchanged.
    The request (and, when filtering, the applied/unknown split) is printed.
    """
    allowed = None
    getter = getattr(model, "get_generation_params", None)
    if getter is not None:
        try:
            allowed = set(getter().keys())
        except Exception:
            # Introspection is optional; without it nothing is filtered.
            allowed = None

    if allowed is None:
        accepted, rejected = dict(kwargs), []
    else:
        accepted = {key: value for key, value in kwargs.items() if key in allowed}
        rejected = [key for key in kwargs if key not in allowed]

    print(f"[gen] request={kwargs}")
    if allowed is not None:
        print(f"[gen] applied={accepted} unknown={rejected}")
    model.set_generation_params(**accepted)
def _tensor_fp(t: Optional[torch.Tensor]) -> str:
    """Return a short fingerprint of a tensor for log lines.

    Returns ``"NONE"`` for ``None``, the first 10 hex digits of the SHA-1 of
    the tensor's float32 bytes otherwise, or ``"ERR"`` if hashing fails.
    """
    if t is None:
        return "NONE"
    try:
        digest = hashlib.sha1()
        digest.update(t.detach().cpu().contiguous().float().numpy().tobytes())
        return digest.hexdigest()[:10]
    except Exception:
        return "ERR"
| # ----------------------------- | |
| # Endpoints | |
| # ----------------------------- | |
# NOTE(review): the handler was never attached to `app`; "/health" is a
# conventional path chosen here — adjust if clients expect a different one.
@app.get("/health")
def health_check():
    """Liveness probe: report whether a model is loaded and which cache root is in use."""
    return {
        "status": "healthy",
        "model_loaded": MODEL is not None,
        "cache_dir": str(CACHE_DIR),
    }
# Accepted aliases for the text prompt, in priority order: predict() uses the
# first alias that carries a non-empty value in the request.
_TEXT_KEYS = [
    "text",
    "prompt",
    "description",
    "query",
    "message",
    "input",
    "content",
]
class PredictRequest(BaseModel):
    """Schema describing the fields of a prediction request.

    NOTE(review): predict() in this file parses the request by hand and does
    not reference this model; its ``n_samples`` default (2) also differs from
    the handler's default (1) — confirm which one is intended.
    """

    # Hugging Face repo id of the JASCO checkpoint.
    model: str = "facebook/jasco-chords-drums-melody-400M"
    # Text prompt and chord progression string.
    text: str = Field(default="")
    chords_sym: str = Field(default="")
    n_samples: int = 2
    seed: Optional[int] = None
    # Guidance coefficients and ODE solver settings.
    cfg_coef_all: float = 1.25
    cfg_coef_txt: float = Field(default=2.5)
    ode_rtol: float = Field(default=1e-4)
    ode_atol: float = Field(default=1e-4)
    ode_solver: str = Field(default="euler")
    ode_steps: int = Field(default=10)
    # Optional drum track, either base64-encoded or as an uploaded file.
    drums_b64: Optional[str] = Field(default=None)
    drums_upload: Optional[UploadFile] = Field(default=None)
class PredictResponse(BaseModel):
    """Generic response envelope: a status string plus optional message and data.

    NOTE(review): not referenced by any handler visible in this file.
    """

    status: str = Field(default="success")
    message: Optional[str] = Field(default=None)
    data: Optional[dict] = Field(default=None)
def tensor_to_wav_scipy(tensor, sample_rate, filename):
    """Write an audio tensor to *filename* as a 16-bit PCM WAV file.

    Args:
        tensor: 1-D ``(samples,)`` or 2-D audio tensor with values nominally in
            ``[-1, 1]``. A 2-D tensor with fewer rows than columns is treated
            as channels-first ``(channels, samples)`` and transposed, because
            SciPy expects ``(samples, channels)``; other 2-D tensors are
            written as given.
        sample_rate: Sample rate in Hz.
        filename: Destination path or file object accepted by
            ``scipy.io.wavfile.write``.
    """
    audio_data = tensor.detach().cpu().numpy()

    # Without this, a (channels, samples) tensor would be written as a file
    # with `samples` channels and only `channels` frames.
    if audio_data.ndim == 2 and audio_data.shape[0] < audio_data.shape[1]:
        audio_data = audio_data.T

    # Clamp to the valid range, then scale to the signed 16-bit range.
    audio_data = np.clip(audio_data, -1.0, 1.0)
    audio_data = np.ascontiguousarray((audio_data * 32767).astype(np.int16))

    wavfile.write(filename, sample_rate, audio_data)
class FileCleaner:
    """Track generated files and delete them once they exceed a lifetime.

    Entries are kept in insertion order as ``(added_at, path)`` tuples. Expired
    files are only removed when a new file is registered via ``add``.
    """

    def __init__(self, file_lifetime: float = 3600):
        # Maximum age in seconds before a tracked file is deleted.
        self.file_lifetime = file_lifetime
        # Oldest entry first.
        self.files = []

    def add(self, path: Union[str, Path]):
        """Purge expired files, then start tracking *path* from now."""
        self._cleanup()
        entry = (time.time(), Path(path))
        self.files.append(entry)

    def _cleanup(self):
        """Delete tracked files older than the lifetime, oldest first."""
        now = time.time()
        while self.files:
            added_at, tracked = self.files[0]
            if now - added_at <= self.file_lifetime:
                # Entries are ordered by age, so everything after this one is
                # still fresh as well.
                break
            if tracked.exists():
                tracked.unlink()
            self.files.pop(0)


# Shared cleaner instance for temporary output files.
file_cleaner = FileCleaner()
def _coerce_predict_params(src: dict, params: dict) -> dict:
    """Overlay the request fields found in *src* onto *params*, coercing each to its type.

    Raises ``ValueError``/``TypeError`` when a value cannot be converted; the
    caller turns that into an HTTP 400.
    """
    for key in _TEXT_KEYS:
        if src.get(key):
            params["text"] = str(src[key])
            break
    params["model"] = str(src.get("model") or params["model"])
    params["chords_sym"] = str(src.get("chords_sym") or params["chords_sym"])
    params["n_samples"] = int(src.get("n_samples", params["n_samples"]))
    # Parse the seed here so a malformed value is reported as a 400, not a 500.
    seed = src.get("seed")
    params["seed"] = None if seed in (None, "") else int(seed)
    for name in ("cfg_coef_all", "cfg_coef_txt", "ode_rtol", "ode_atol"):
        params[name] = float(src.get(name, params[name]))
    params["ode_solver"] = str(src.get("ode_solver", params["ode_solver"])).lower()
    params["ode_steps"] = int(src.get("ode_steps", params["ode_steps"]))
    params["drums_b64"] = src.get("drums_b64")
    return params


async def _parse_predict_request(request: Request) -> Tuple[str, dict, Optional[bytes]]:
    """Read a JSON or form request and return ``(content_type, params, raw_drums)``."""
    ct = (request.headers.get("content-type") or "application/json").lower()
    params = {
        "model": "facebook/jasco-chords-drums-melody-400M",
        "text": "",
        "chords_sym": "",
        "n_samples": 1,
        "seed": None,
        "cfg_coef_all": 1.25,
        "cfg_coef_txt": 2.5,
        "ode_rtol": 1e-4,
        "ode_atol": 1e-4,
        "ode_solver": "euler",
        "ode_steps": 10,
        "drums_b64": None,
    }
    # Compare the media type only, so "application/json; charset=utf-8" is
    # still handled as JSON.
    if ct.split(";", 1)[0].strip() == "application/json":
        data = await request.json()
        if not isinstance(data, dict):
            raise HTTPException(status_code=400, detail="JSON body must be an object")
        _coerce_predict_params(data, params)
        raw_drums = _read_b64_to_bytes(params["drums_b64"])
    else:
        form = await request.form()
        fd = {k: form.get(k) for k in form.keys() if form.get(k)}
        _coerce_predict_params(fd, params)
        drums_upload = form.get("drums_file")
        if drums_upload:
            raw_drums = _read_uploadfile_to_bytes(drums_upload)
        else:
            raw_drums = _read_b64_to_bytes(params["drums_b64"])
    return ct, params, raw_drums


@app.post("/predict")
async def predict(request: Request):
    """
    Generate music with JASCO and return the first generated sample as a WAV file.

    Accepts:
    - multipart/form-data (fields + optional drums_file)
    - application/json (fields + optional drums_b64)

    Returns:
        A ``FileResponse`` (``audio/wav``, filename ``jasco_output.wav``).

    Raises:
        HTTPException: 400 for a malformed request, 401/500 from model loading,
            500 when generation fails.

    NOTE(review): ``n_samples`` controls how many samples are generated, but
    only the first one is returned. Generation also runs synchronously inside
    this coroutine and blocks the event loop while it runs.
    """
    try:
        ct, params, raw_drums = await _parse_predict_request(request)
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Bad request: {e}") from e

    print(json.dumps({
        "ct": ct,
        "text_len": len(params["text"] or ""),
        "text_preview": (params["text"] or "")[:120],
        "model": params["model"],
        "n_samples": params["n_samples"],
        "has_drums_bytes": raw_drums is not None,
    }))

    model = load_model(params["model"])  # may raise HTTPException(401/500)
    drums_sr, drums_tensor = _read_wav_bytes(raw_drums)
    print(f"[predict] drums_present={drums_tensor is not None} sr={drums_sr} drums_fp={_tensor_fp(drums_tensor)}")

    # Seed every RNG. The mask keeps the value inside the 32-bit range that
    # np.random.seed accepts; seeds already in range are unchanged.
    if params["seed"] is not None:
        base_seed = params["seed"] & 0xFFFFFFFF
    else:
        base_seed = int(time.time() * 1000) & 0xFFFFFFFF
    random.seed(base_seed)
    np.random.seed(base_seed)
    torch.manual_seed(base_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(base_seed)

    set_gen_params(
        model,
        cfg_coef_all=params["cfg_coef_all"],
        cfg_coef_txt=params["cfg_coef_txt"],
        ode_rtol=params["ode_rtol"],
        ode_atol=params["ode_atol"],
        euler=(params["ode_solver"] == "euler"),
        euler_steps=params["ode_steps"],
    )

    texts = [params["text"]] * max(1, params["n_samples"])
    chords_list = chords_string_to_list(params["chords_sym"])
    # Log shapes rather than whole tensors to keep the logs readable.
    drums_shape = None if drums_tensor is None else tuple(drums_tensor.shape)
    print(f"[predictdebug] chords_list={chords_list} drums_shape={drums_shape} drums_sr={drums_sr} n_texts={len(texts)}")

    try:
        outputs = model.generate_music(
            descriptions=texts,
            chords=chords_list,
            drums_wav=drums_tensor,
            melody_salience_matrix=None,
            drums_sample_rate=drums_sr,
            progress=False,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Generation failed: {e}") from e

    # Move the result to a CPU float tensor before writing it out.
    outputs = outputs.detach().cpu().float()
    print(f"[predictdebug] outputs_shape={tuple(outputs.shape)}")

    with NamedTemporaryFile("wb", suffix=".wav", delete=False) as f:
        tmp_path = f.name
    # delete=False keeps the file for the response; register it so the shared
    # cleaner removes it later instead of leaking one file per request.
    file_cleaner.add(tmp_path)

    audio_write(
        tmp_path,
        outputs[0],
        model.sample_rate,
        strategy="loudness",
        loudness_headroom_db=16,
        loudness_compressor=True,
        add_suffix=False,
    )
    return FileResponse(
        path=tmp_path,
        media_type="audio/wav",
        filename="jasco_output.wav",
    )
if __name__ == "__main__":
    # Script entry point: authenticate against the Hub first, then serve the
    # app on all interfaces, port 7860.
    ensure_hf_login()

    import uvicorn

    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
    )
| # outputs = [1,2] | |
| # outputs[1] = [name, wav] | |
| # wav =[0.39203242, ] | |