Upload folder using huggingface_hub
Browse files- .DS_Store +0 -0
- .gitattributes +1 -0
- Beethoven_WoO80_var27_8bars_3_15.wav +3 -0
- Dockerfile +57 -0
- __pycache__/constants.cpython-312.pyc +0 -0
- __pycache__/handler.cpython-312.pyc +0 -0
- checkpoints/.gitignore +1 -0
- checkpoints/fold0/best.ckpt +3 -0
- checkpoints/fold1/best.ckpt +3 -0
- checkpoints/fold2/best.ckpt +3 -0
- checkpoints/fold3/best.ckpt +3 -0
- checkpoints/fold_0/best.ckpt +3 -0
- checkpoints/fold_1/best.ckpt +3 -0
- checkpoints/fold_2/best.ckpt +3 -0
- checkpoints/fold_3/best.ckpt +3 -0
- constants.py +69 -0
- handler.py +247 -0
- models/__init__.py +27 -0
- models/__pycache__/__init__.cpython-312.pyc +0 -0
- models/__pycache__/inference.cpython-312.pyc +0 -0
- models/__pycache__/loader.cpython-312.pyc +0 -0
- models/calibration.py +119 -0
- models/inference.py +79 -0
- models/loader.py +189 -0
- preprocessing/__init__.py +15 -0
- preprocessing/__pycache__/__init__.cpython-312.pyc +0 -0
- preprocessing/__pycache__/audio.cpython-312.pyc +0 -0
- preprocessing/audio.py +130 -0
- requirements.txt +19 -0
- sync_checkpoints.sh +51 -0
.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
Beethoven_WoO80_var27_8bars_3_15.wav filter=lfs diff=lfs merge=lfs -text
|
Beethoven_WoO80_var27_8bars_3_15.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:46cbac4d34b53eb2cd7bc83e0d2e1f44dbecb02a5471b6e0ca444d0ff29251c2
|
| 3 |
+
size 2531508
|
Dockerfile
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# syntax=docker/dockerfile:1
# A1-Max MuQ L9-12 Inference Handler
# HuggingFace Inference Endpoints container for piano performance analysis

FROM nvidia/cuda:12.1.0-cudnn8-runtime-ubuntu22.04

# Use bash with pipefail so failures in `curl | sh` pipelines fail the build.
SHELL ["/bin/bash", "-o", "pipefail", "-c"]

# Install system dependencies.
# DEBIAN_FRONTEND is set per-RUN (not via ENV) so it does not leak into the
# runtime environment of the container.
RUN DEBIAN_FRONTEND=noninteractive apt-get update && \
    apt-get install -y --no-install-recommends \
        curl \
        ffmpeg \
        git \
        libsndfile1 \
        python3.11 \
        python3.11-venv \
    && rm -rf /var/lib/apt/lists/*

# Set Python 3.11 as default
RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.11 1 && \
    update-alternatives --install /usr/bin/python python /usr/bin/python3.11 1

# Install uv (fast Python package installer)
RUN curl -LsSf https://astral.sh/uv/install.sh | sh
ENV PATH="/root/.local/bin:$PATH"

WORKDIR /app

# Install Python dependencies with uv.
# requirements.txt is copied alone so this layer stays cached until the
# dependency list itself changes.
COPY requirements.txt .
RUN uv pip install --system --no-cache -r requirements.txt

# Cache locations MUST be set before the model pre-download below, otherwise
# the weights land in /root/.cache at build time while the runtime reads
# /app/.cache and re-downloads them on every cold start.
# TRANSFORMERS_CACHE is deprecated in favor of HF_HOME but kept for
# compatibility with older transformers versions.
ENV PYTHONUNBUFFERED=1 \
    HF_HOME=/app/.cache/huggingface \
    TRANSFORMERS_CACHE=/app/.cache/huggingface

# Pre-download HuggingFace models (cached in image) - MuQ only
RUN python3 -c "\
print('Downloading MuQ-large-msd-iter...'); \
from muq import MuQ; \
MuQ.from_pretrained('OpenMuQ/MuQ-large-msd-iter'); \
print('Done!'); \
"

# Copy application code after dependencies so source edits do not invalidate
# the (expensive) dependency and model-download layers.
COPY constants.py .
COPY handler.py .
COPY models/ ./models/
COPY preprocessing/ ./preprocessing/

# Create checkpoints directory structure (populated at deploy time)
RUN mkdir -p /app/checkpoints/fold0 /app/checkpoints/fold1 /app/checkpoints/fold2 /app/checkpoints/fold3

# HuggingFace Inference Endpoints expects handler.py
# The EndpointHandler class will be automatically detected
|
__pycache__/constants.cpython-312.pyc
ADDED
|
Binary file (1.63 kB). View file
|
|
|
__pycache__/handler.cpython-312.pyc
ADDED
|
Binary file (9.12 kB). View file
|
|
|
checkpoints/.gitignore
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
*
|
checkpoints/fold0/best.ckpt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3b0fe082f7928a1d180313643d21ec83c82f392c01a4e8a2595cb55607306084
|
| 3 |
+
size 15869013
|
checkpoints/fold1/best.ckpt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ee6e94e1367da7999d4ee9f68de8e31724b7957e0571d421a8087d42cec8a9c8
|
| 3 |
+
size 15869013
|
checkpoints/fold2/best.ckpt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:551924ff5d2692c129b86f750798d9532a288bc3e1d56dca2b897fe73672c717
|
| 3 |
+
size 15869013
|
checkpoints/fold3/best.ckpt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:68188c86e5190e9a7663a69a013e7e9cc5cb793d177b8f1fb6252dae46607f24
|
| 3 |
+
size 15869013
|
checkpoints/fold_0/best.ckpt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:edeed3bb972341076e06f2bacebfb40fd9ed5cba5e9f8de8f9bfe4e12c586d47
|
| 3 |
+
size 637877835
|
checkpoints/fold_1/best.ckpt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b53550d16677918954191bf2228a3210986f5ddfcd48010f0272e981eab838cb
|
| 3 |
+
size 637877899
|
checkpoints/fold_2/best.ckpt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:027592ad5a7a31ad1ae1628944842e3b5c4da96bb7f16a955f266a4b41a7b20b
|
| 3 |
+
size 637877899
|
checkpoints/fold_3/best.ckpt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:eb0de4d0aa81cae2ee8fe0b6a40a634624e75567443b343602b55d21fd80c49f
|
| 3 |
+
size 637877899
|
constants.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Constants for A1-Max MuQ LoRA inference handler."""
|
| 2 |
+
|
| 3 |
+
PERCEPIANO_DIMENSIONS = [
|
| 4 |
+
"dynamics",
|
| 5 |
+
"timing",
|
| 6 |
+
"pedaling",
|
| 7 |
+
"articulation",
|
| 8 |
+
"phrasing",
|
| 9 |
+
"interpretation",
|
| 10 |
+
]
|
| 11 |
+
|
| 12 |
+
# A1-Max model configuration
|
| 13 |
+
# MuQ embeddings (1024 dim) with attention pooling -> encoder -> regression head
|
| 14 |
+
MODEL_CONFIG = {
|
| 15 |
+
# MuQ configuration (layers to average)
|
| 16 |
+
"muq_layer_start": 9,
|
| 17 |
+
"muq_layer_end": 13, # Exclusive (layers 9, 10, 11, 12)
|
| 18 |
+
"muq_dim": 1024, # Per-layer hidden size (= input_dim)
|
| 19 |
+
# Head configuration
|
| 20 |
+
"input_dim": 1024,
|
| 21 |
+
"hidden_dim": 512,
|
| 22 |
+
"num_labels": 6,
|
| 23 |
+
"dropout": 0.2,
|
| 24 |
+
# Audio processing
|
| 25 |
+
"target_sr": 24000,
|
| 26 |
+
"max_frames": 1000,
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
# Model info for response
|
| 30 |
+
MODEL_INFO = {
|
| 31 |
+
"name": "A1-Max MuQ LoRA",
|
| 32 |
+
"type": "audio-muq-lora",
|
| 33 |
+
"pairwise": 0.7872,
|
| 34 |
+
"description": "A1-Max: MuQ + LoRA with ListMLE, CCC, mixup, hard negative mining",
|
| 35 |
+
"architecture": "MuQLoRAMaxModel (MuQ L9-12 avg -> attn pool -> encoder -> 6-dim regression)",
|
| 36 |
+
"best_config": "A1max_r32_L7-12_ls0.1",
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
# Number of folds for ensemble
|
| 40 |
+
N_FOLDS = 4
|
| 41 |
+
|
| 42 |
+
# MAESTRO calibration stats: per-dimension distribution over 24,321 professional segments.
|
| 43 |
+
# Computed by model/scripts/compute_maestro_calibration.py using A1-Max 4-fold ensemble.
|
| 44 |
+
MAESTRO_CALIBRATION = {
|
| 45 |
+
"dynamics": {
|
| 46 |
+
"mean": 0.560947, "std": 0.021063,
|
| 47 |
+
"p5": 0.526612, "p25": 0.546136, "p50": 0.560859, "p75": 0.575372, "p95": 0.59573,
|
| 48 |
+
},
|
| 49 |
+
"timing": {
|
| 50 |
+
"mean": 0.531883, "std": 0.028791,
|
| 51 |
+
"p5": 0.480467, "p25": 0.512976, "p50": 0.534302, "p75": 0.552652, "p95": 0.575376,
|
| 52 |
+
},
|
| 53 |
+
"pedaling": {
|
| 54 |
+
"mean": 0.590465, "std": 0.030438,
|
| 55 |
+
"p5": 0.534399, "p25": 0.572243, "p50": 0.593731, "p75": 0.611854, "p95": 0.635053,
|
| 56 |
+
},
|
| 57 |
+
"articulation": {
|
| 58 |
+
"mean": 0.553624, "std": 0.014287,
|
| 59 |
+
"p5": 0.53023, "p25": 0.543792, "p50": 0.553554, "p75": 0.563275, "p95": 0.577426,
|
| 60 |
+
},
|
| 61 |
+
"phrasing": {
|
| 62 |
+
"mean": 0.550866, "std": 0.013717,
|
| 63 |
+
"p5": 0.528466, "p25": 0.541541, "p50": 0.550801, "p75": 0.560116, "p95": 0.573567,
|
| 64 |
+
},
|
| 65 |
+
"interpretation": {
|
| 66 |
+
"mean": 0.564377, "std": 0.023457,
|
| 67 |
+
"p5": 0.522434, "p25": 0.549302, "p50": 0.566195, "p75": 0.580981, "p95": 0.599733,
|
| 68 |
+
},
|
| 69 |
+
}
|
handler.py
ADDED
|
@@ -0,0 +1,247 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""HuggingFace Inference Endpoints handler for piano performance analysis.
|
| 2 |
+
|
| 3 |
+
A1-Max MuQ LoRA model using MuQ layers 9-12 with attention pooling.
|
| 4 |
+
Returns 6-dimension performance evaluation scores:
|
| 5 |
+
dynamics, timing, pedaling, articulation, phrasing, interpretation.
|
| 6 |
+
|
| 7 |
+
Compatible with HuggingFace Inference Endpoints custom handler pattern.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import base64
|
| 11 |
+
import time
|
| 12 |
+
import traceback
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
from typing import Any, Dict, Union
|
| 15 |
+
|
| 16 |
+
import numpy as np
|
| 17 |
+
|
| 18 |
+
from constants import MODEL_INFO, PERCEPIANO_DIMENSIONS
|
| 19 |
+
from models.loader import get_model_cache
|
| 20 |
+
from models.inference import (
|
| 21 |
+
extract_muq_embeddings,
|
| 22 |
+
predict_with_ensemble,
|
| 23 |
+
)
|
| 24 |
+
from preprocessing.audio import (
|
| 25 |
+
AudioDownloadError,
|
| 26 |
+
AudioProcessingError,
|
| 27 |
+
download_and_preprocess_audio,
|
| 28 |
+
preprocess_audio_from_bytes,
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class EndpointHandler:
    """HuggingFace Inference Endpoints handler for piano performance analysis."""

    def __init__(self, path: str = ""):
        """Initialize MuQ model and prediction heads.

        Called once when the endpoint container starts.

        Args:
            path: Path to the model repository (provided by HF Inference Endpoints).
                Contains the checkpoints/ directory with model weights.
        """
        print(f"Initializing A1-Max EndpointHandler with path: {path}")

        # Determine checkpoint directory.
        # HF Inference Endpoints mount the repo at the provided path;
        # fall back to /repository (HF default) or current dir for local testing.
        if path:
            model_path = Path(path)
        else:
            model_path = Path("/repository")
            if not model_path.exists():
                model_path = Path(".")

        checkpoint_dir = model_path / "checkpoints"
        if not checkpoint_dir.exists():
            # Try /app/checkpoints for backward compatibility
            checkpoint_dir = Path("/app/checkpoints")

        print(f"Using checkpoint directory: {checkpoint_dir}")

        # Initialize model cache (loads MuQ and the per-fold prediction heads)
        self._cache = get_model_cache()
        self._cache.initialize(device="cuda", checkpoint_dir=checkpoint_dir)

        print("A1-Max EndpointHandler initialization complete!")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Process inference request.

        Args:
            data: Request payload. Supports two formats:

                HuggingFace format:
                {
                    "inputs": "<base64-audio>" or {"audio_url": "..."},
                    "parameters": {
                        "max_duration_seconds": 300
                    }
                }

                Legacy RunPod format (for backward compatibility):
                {
                    "input": {
                        "audio_url": "https://...",
                        "options": {...}
                    }
                }

        Returns:
            Prediction results:
            {
                "predictions": {"timing": 0.85, ...},
                "model_info": {"name": "A1-Max MuQ LoRA", "pairwise": 0.7872, ...},
                "audio_duration_seconds": 180.5,
                "processing_time_ms": 1234
            }

            Or error:
            {
                "error": {"code": "...", "message": "..."}
            }
        """
        start_time = time.time()

        try:
            # Parse input - support both HF and legacy RunPod formats
            inputs, parameters = self._parse_request(data)

            # Extract parameters
            max_duration = parameters.get("max_duration_seconds", 300)

            # Load and preprocess audio
            audio, duration = self._load_audio(inputs, max_duration)
            print(f"Audio loaded: {duration:.1f}s")

            # Verify models are loaded. Use an explicit `is None` check:
            # truthiness of a model object is not a reliable loaded/unloaded signal.
            if self._cache.muq_model is None:
                return {
                    "error": {
                        "code": "MODEL_NOT_LOADED",
                        "message": "MuQ model not initialized",
                    }
                }

            # Extract MuQ embeddings (averaged layers 9-12)
            print("Extracting MuQ embeddings (layers 9-12)...")
            embeddings = extract_muq_embeddings(audio, self._cache)
            print(f"MuQ embeddings shape: {embeddings.shape}")

            # Get ensemble predictions (4-fold A1-Max)
            print("Running A1-Max ensemble inference...")
            predictions = predict_with_ensemble(embeddings, self._cache)

            # Build response
            processing_time_ms = int((time.time() - start_time) * 1000)

            result = {
                "predictions": self._predictions_to_dict(predictions),
                "model_info": {
                    "name": MODEL_INFO["name"],
                    "type": MODEL_INFO["type"],
                    "pairwise": MODEL_INFO["pairwise"],
                    "architecture": MODEL_INFO["architecture"],
                    "ensemble_folds": len(self._cache.muq_heads),
                },
                "audio_duration_seconds": duration,
                "processing_time_ms": processing_time_ms,
            }

            print(f"Inference complete in {processing_time_ms}ms")
            return result

        except AudioDownloadError as e:
            return {
                "error": {
                    "code": "AUDIO_DOWNLOAD_FAILED",
                    "message": str(e),
                }
            }

        except AudioProcessingError as e:
            return {
                "error": {
                    "code": "AUDIO_PROCESSING_FAILED",
                    "message": str(e),
                }
            }

        except Exception as e:
            return {
                "error": {
                    "code": "INFERENCE_ERROR",
                    "message": str(e),
                    "traceback": traceback.format_exc(),
                }
            }

    def _parse_request(self, data: Dict[str, Any]) -> tuple:
        """Parse request data supporting both HF and legacy formats.

        Returns:
            Tuple of (inputs, parameters)
        """
        # HF format: {"inputs": ..., "parameters": ...}
        if "inputs" in data:
            return data["inputs"], data.get("parameters", {})

        # Legacy RunPod format: {"input": {"audio_url": ..., "options": ...}}
        if "input" in data:
            job_input = data["input"]
            inputs = {
                "audio_url": job_input.get("audio_url"),
                "performance_id": job_input.get("performance_id", "unknown"),
            }
            # Copy so we never mutate the caller's options dict in place.
            parameters = dict(job_input.get("options") or {})
            parameters["performance_id"] = inputs["performance_id"]
            return inputs, parameters

        # Fallback: treat entire data as inputs
        return data, {}

    def _load_audio(
        self, inputs: Union[str, bytes, Dict[str, Any]], max_duration: int
    ) -> tuple:
        """Load audio from various input formats.

        Args:
            inputs: One of:
                - str: Base64-encoded audio bytes, or an http(s) URL
                - bytes: Raw audio bytes
                - dict: {"audio_url": "..."} for URL-based loading
            max_duration: Maximum audio duration in seconds to keep.

        Returns:
            Tuple of (audio_array, duration_seconds)

        Raises:
            AudioProcessingError: If the input cannot be interpreted as audio.
            AudioDownloadError: If a URL download fails.
        """
        if isinstance(inputs, str):
            # Route URL strings first so a genuine download/processing failure
            # surfaces its real error message instead of being masked by the
            # base64-decode fallback (the previous except-all did exactly that).
            if inputs.startswith("http"):
                return download_and_preprocess_audio(inputs, max_duration=max_duration)
            try:
                # binascii.Error (raised on malformed base64) is a ValueError subclass.
                audio_bytes = base64.b64decode(inputs)
            except ValueError as e:
                raise AudioProcessingError("Invalid input string: not base64 or URL") from e
            return preprocess_audio_from_bytes(audio_bytes, max_duration=max_duration)

        elif isinstance(inputs, bytes):
            # Raw bytes
            return preprocess_audio_from_bytes(inputs, max_duration=max_duration)

        elif isinstance(inputs, dict):
            # URL-based input
            audio_url = inputs.get("audio_url")
            if not audio_url:
                raise AudioProcessingError("No audio_url provided in inputs")
            return download_and_preprocess_audio(audio_url, max_duration=max_duration)

        else:
            raise AudioProcessingError(f"Unsupported input type: {type(inputs)}")

    def _predictions_to_dict(self, preds: np.ndarray) -> Dict[str, float]:
        """Map the 6-element prediction array onto PERCEPIANO_DIMENSIONS by index."""
        return {dim: float(preds[i]) for i, dim in enumerate(PERCEPIANO_DIMENSIONS)}
|
models/__init__.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Model loading and inference for A1-Max MuQ LoRA."""
|
| 2 |
+
|
| 3 |
+
from models.loader import (
|
| 4 |
+
A1MaxInferenceHead,
|
| 5 |
+
ModelCache,
|
| 6 |
+
get_model_cache,
|
| 7 |
+
)
|
| 8 |
+
from models.inference import (
|
| 9 |
+
extract_muq_embeddings,
|
| 10 |
+
predict_with_ensemble,
|
| 11 |
+
)
|
| 12 |
+
from models.calibration import (
|
| 13 |
+
calibrate_predictions,
|
| 14 |
+
predictions_to_calibrated_dict,
|
| 15 |
+
get_calibration_context,
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
__all__ = [
|
| 19 |
+
"A1MaxInferenceHead",
|
| 20 |
+
"ModelCache",
|
| 21 |
+
"get_model_cache",
|
| 22 |
+
"extract_muq_embeddings",
|
| 23 |
+
"predict_with_ensemble",
|
| 24 |
+
"calibrate_predictions",
|
| 25 |
+
"predictions_to_calibrated_dict",
|
| 26 |
+
"get_calibration_context",
|
| 27 |
+
]
|
models/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (688 Bytes). View file
|
|
|
models/__pycache__/inference.cpython-312.pyc
ADDED
|
Binary file (4.87 kB). View file
|
|
|
models/__pycache__/loader.cpython-312.pyc
ADDED
|
Binary file (8.84 kB). View file
|
|
|
models/calibration.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""MAESTRO-based calibration for performance predictions.
|
| 2 |
+
|
| 3 |
+
Normalizes raw model predictions relative to professional MAESTRO recordings,
|
| 4 |
+
making scores more interpretable for end users.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
from typing import Dict
|
| 9 |
+
|
| 10 |
+
from constants import MAESTRO_CALIBRATION, PERCEPIANO_DIMENSIONS
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def calibrate_predictions(
    raw_predictions: np.ndarray,
    method: str = "percentile",
) -> np.ndarray:
    """Calibrate raw predictions against MAESTRO professional benchmarks.

    Args:
        raw_predictions: Raw model outputs [6] in range ~[0, 1], ordered
            as PERCEPIANO_DIMENSIONS.
        method: Calibration method:
            - "percentile": scale so the MAESTRO 5th percentile maps to 0 and
              the 95th percentile maps to 1; values may fall outside [0, 1]
              for exceptional or below-average performances.
            - "zscore": express each score as standard deviations from the
              MAESTRO mean.
            Any other value passes raw scores through unchanged.

    Returns:
        Calibrated predictions [6]. With "percentile", ~0.5 corresponds to
        an average MAESTRO professional performance.
    """
    out = np.zeros_like(raw_predictions)

    for idx, dimension in enumerate(PERCEPIANO_DIMENSIONS):
        score = raw_predictions[idx]
        stats = MAESTRO_CALIBRATION.get(dimension)

        # Missing stats (shouldn't happen with well-formed calibration data):
        # pass the raw score through unchanged.
        if stats is None:
            out[idx] = score
            continue

        if method == "percentile":
            # 5th percentile -> 0, 95th percentile -> 1.
            span = stats["p95"] - stats["p5"]
            out[idx] = (score - stats["p5"]) / span if span > 0 else 0.5
        elif method == "zscore":
            # Standard deviations from the professional mean.
            std = stats["std"]
            out[idx] = (score - stats["mean"]) / std if std > 0 else 0.0
        else:
            out[idx] = score

    return out
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def predictions_to_calibrated_dict(
    raw_predictions: np.ndarray,
) -> Dict[str, Dict[str, float]]:
    """Build a per-dimension dict containing raw and calibrated scores.

    Args:
        raw_predictions: Raw model outputs [6], ordered as PERCEPIANO_DIMENSIONS.

    Returns:
        Dict with structure:
        {
            "timing": {"raw": 0.65, "calibrated": 0.42, "percentile_rank": 42},
            ...
        }
    """
    calibrated = calibrate_predictions(raw_predictions, method="percentile")

    def _entry(raw: float, cal: float) -> Dict[str, float]:
        # Percentile rank is clamped to [0, 100] for display; the calibrated
        # score itself is clamped to [0, 1].
        return {
            "raw": round(raw, 4),
            "calibrated": round(min(1.0, max(0.0, cal)), 4),
            "percentile_rank": int(min(100, max(0, cal * 100))),
        }

    return {
        dim: _entry(float(raw_predictions[i]), float(calibrated[i]))
        for i, dim in enumerate(PERCEPIANO_DIMENSIONS)
    }
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def get_calibration_context() -> str:
    """Return a human-readable explanation of the calibrated score scale.

    Intended to be injected as context for an LLM (or human reviewer)
    interpreting the calibrated scores.
    """
    return """Score Interpretation (calibrated relative to 500 professional MAESTRO recordings):
- 0.0 = Performance at the 5th percentile of professionals (lower end)
- 0.5 = Performance at the 50th percentile of professionals (average professional level)
- 1.0 = Performance at the 95th percentile of professionals (exceptional)
- Scores can exceed [0, 1] for truly exceptional or below-average performances

Note: These scores compare against competition-level professional pianists.
A calibrated score of 0.5 represents professional-level competency."""
|
models/inference.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""A1-Max MuQ inference - MuQ embedding extraction and prediction."""
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from constants import MODEL_CONFIG
|
| 7 |
+
from models.loader import ModelCache
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@torch.no_grad()
def extract_muq_embeddings(
    audio: np.ndarray,
    cache: ModelCache,
    layer_start: int = None,
    layer_end: int = None,
    max_frames: int = None,
) -> torch.Tensor:
    """Extract MuQ embeddings from an audio waveform.

    Averages hidden states from layers 9-12 by default (best performing range).

    Args:
        audio: Audio waveform at 24kHz
        cache: Model cache with loaded MuQ model
        layer_start: Start layer (inclusive), default MODEL_CONFIG["muq_layer_start"]
        layer_end: End layer (exclusive), default MODEL_CONFIG["muq_layer_end"]
        max_frames: Maximum frames to keep, default MODEL_CONFIG["max_frames"]

    Returns:
        Embeddings tensor [T, 1024] where T is number of frames
    """
    # Use explicit `is None` checks: the previous `x or default` idiom would
    # silently replace a caller-supplied 0 (e.g. layer 0) with the default.
    if layer_start is None:
        layer_start = MODEL_CONFIG["muq_layer_start"]
    if layer_end is None:
        layer_end = MODEL_CONFIG["muq_layer_end"]
    if max_frames is None:
        max_frames = MODEL_CONFIG["max_frames"]

    # MuQ expects a [B, samples] tensor on the model's device.
    wavs = torch.tensor(audio).unsqueeze(0).to(cache.device)

    # Get hidden states from all layers
    outputs = cache.muq_model(wavs, output_hidden_states=True)

    # hidden_states is a tuple of [B, T, D] tensors, one per layer.
    # Average the selected layer range, then drop the batch dimension.
    hidden_states = outputs.hidden_states[layer_start:layer_end]
    embeddings = torch.stack(hidden_states, dim=0).mean(dim=0).squeeze(0)

    # Truncate long clips to bound downstream memory/compute.
    if embeddings.shape[0] > max_frames:
        embeddings = embeddings[:max_frames]

    return embeddings
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
@torch.no_grad()
def predict_with_ensemble(
    embeddings: torch.Tensor,
    cache: ModelCache,
) -> np.ndarray:
    """Average predictions over the 4-fold ensemble of A1-Max heads.

    Each head applies attention pooling over the frame-level embeddings,
    then an encoder + regression head producing 6-dim scores.

    Args:
        embeddings: Frame embeddings [T, D] from MuQ
        cache: Model cache with loaded heads

    Returns:
        Averaged predictions [6] across all folds

    Raises:
        RuntimeError: If no prediction heads are loaded in the cache.
    """
    if not cache.muq_heads:
        raise RuntimeError("No A1-Max heads loaded in cache")

    # One forward pass per fold head; average across folds.
    fold_outputs = [head(embeddings).cpu().numpy() for head in cache.muq_heads]
    return np.mean(fold_outputs, axis=0)
|
models/loader.py
ADDED
|
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Model loading and caching for A1-Max MuQ LoRA inference."""
|
| 2 |
+
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from typing import List, Optional
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
|
| 9 |
+
from constants import MODEL_CONFIG, N_FOLDS
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class A1MaxInferenceHead(nn.Module):
    """Inference-only score-prediction head mirroring MuQLoRAMaxModel.

    Reproduces exactly the modules exercised by the score-prediction path:
      * attention pooling:  [B, T, D] -> [B, D]
      * shared encoder:     [B, D] -> [B, hidden_dim] (2-layer MLP)
      * regression head:    [B, hidden_dim] -> [B, num_labels], sigmoid output

    Training-only components (ranking / contrastive / comparator) are
    intentionally omitted. Submodule names and Sequential layouts match the
    training model so Lightning checkpoints load with strict=True.
    """

    def __init__(
        self,
        input_dim: int = 1024,
        hidden_dim: int = 512,
        num_labels: int = 6,
        dropout: float = 0.2,
    ):
        super().__init__()
        self.num_labels = num_labels

        # Attention pooling (key-compatible with MuQLoRAModel.attn).
        self.attn = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.Tanh(),
            nn.Linear(256, 1),
        )

        # Shared encoder (key-compatible with MuQLoRAModel.encoder).
        # Two identical Linear->LayerNorm->GELU->Dropout stages; only the
        # first stage's input width differs.
        encoder_layers = []
        for stage_in in (input_dim, hidden_dim):
            encoder_layers += [
                nn.Linear(stage_in, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.GELU(),
                nn.Dropout(dropout),
            ]
        self.encoder = nn.Sequential(*encoder_layers)

        # Regression head (key-compatible with MuQLoRAModel.regression_head).
        self.regression_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, num_labels),
            nn.Sigmoid(),
        )

    def forward(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Predict quality scores from frame embeddings.

        Args:
            embeddings: Frame embeddings [B, T, D], or [T, D] for a single
                unbatched clip.

        Returns:
            Scores in [0, 1]: [B, num_labels], or [num_labels] when the
            input was unbatched.
        """
        was_unbatched = embeddings.dim() == 2
        if was_unbatched:
            embeddings = embeddings.unsqueeze(0)

        # Attention-weighted pooling over the time axis.
        attn_logits = self.attn(embeddings).squeeze(-1)             # [B, T]
        weights = torch.softmax(attn_logits, dim=-1).unsqueeze(-1)  # [B, T, 1]
        pooled = (embeddings * weights).sum(1)                      # [B, D]

        encoded = self.encoder(pooled)                              # [B, hidden_dim]
        preds = self.regression_head(encoded)                       # [B, num_labels]

        return preds.squeeze(0) if was_unbatched else preds
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class ModelCache:
    """Singleton cache for loaded models.

    Keeps the (large) MuQ backbone and the per-fold A1-Max heads resident
    in the process so they are loaded once and reused across requests.
    """

    # Class-level singleton slot; __new__ always hands back this instance.
    _instance: Optional["ModelCache"] = None

    def __new__(cls) -> "ModelCache":
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            # Flag lets __init__ run its body only on first construction.
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        # __init__ runs on every ModelCache() call even though __new__
        # returns the cached instance, so guard against re-initialization.
        if self._initialized:
            return
        self.muq_model = None  # MuQ backbone; populated by initialize()
        self.muq_heads: List[A1MaxInferenceHead] = []  # one head per loaded fold
        self.device = None  # torch.device chosen in initialize()
        self._initialized = True

    def initialize(self, device: str = "cuda", checkpoint_dir: Optional[Path] = None):
        """Load MuQ model and A1-Max prediction heads. Called once on container start."""
        # Idempotent: a second call is a no-op once the backbone exists.
        if self.muq_model is not None:
            return

        # Fall back to CPU when CUDA is unavailable, regardless of `device`.
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        print(f"Initializing A1-Max models on {self.device}...")

        # Load MuQ from HuggingFace
        print("Loading MuQ-large-msd-iter...")
        try:
            # Imported lazily so this module can be imported without muq installed.
            from muq import MuQ
            self.muq_model = MuQ.from_pretrained("OpenMuQ/MuQ-large-msd-iter")
            self.muq_model = self.muq_model.to(self.device)
            self.muq_model.eval()
            print("MuQ loaded successfully")
        except ImportError as e:
            raise ImportError(
                "MuQ library not found. Install with: pip install muq"
            ) from e

        # Load A1-Max prediction heads (4 folds)
        print("Loading A1-Max prediction heads...")
        # /repository is the HF Inference Endpoints mount; /app is the Docker path.
        checkpoint_dir = checkpoint_dir or Path("/repository/checkpoints")
        if not checkpoint_dir.exists():
            checkpoint_dir = Path("/app/checkpoints")

        for fold in range(N_FOLDS):
            ckpt_path = checkpoint_dir / f"fold_{fold}" / "best.ckpt"
            # Also try the epoch-based naming from sweep
            if not ckpt_path.exists():
                fold_dir = checkpoint_dir / f"fold_{fold}"
                if fold_dir.exists():
                    # Picks the lexicographically first .ckpt; assumes one
                    # (or consistently named) checkpoint per fold — TODO confirm.
                    ckpts = sorted(fold_dir.glob("*.ckpt"))
                    if ckpts:
                        ckpt_path = ckpts[0]
            if ckpt_path.exists():
                head = self._load_a1max_head(ckpt_path)
                self.muq_heads.append(head)
                print(f" Loaded fold {fold} from {ckpt_path}")
            else:
                # Missing folds are skipped; the ensemble degrades gracefully.
                print(f" Warning: No checkpoint found for fold {fold}")

        print(f"Initialization complete. {len(self.muq_heads)} heads loaded.")

    def _load_a1max_head(self, ckpt_path: Path) -> A1MaxInferenceHead:
        """Load an A1MaxInferenceHead from PyTorch Lightning checkpoint."""
        # NOTE(review): weights_only=False unpickles arbitrary Python objects;
        # only load checkpoints from trusted sources.
        checkpoint = torch.load(ckpt_path, map_location=self.device, weights_only=False)

        # Hyperparameters saved by Lightning; missing keys fall back to MODEL_CONFIG.
        hparams = checkpoint.get("hyper_parameters", {})

        head = A1MaxInferenceHead(
            input_dim=hparams.get("input_dim", MODEL_CONFIG["input_dim"]),
            hidden_dim=hparams.get("hidden_dim", MODEL_CONFIG["hidden_dim"]),
            num_labels=hparams.get("num_labels", MODEL_CONFIG["num_labels"]),
            dropout=hparams.get("dropout", MODEL_CONFIG["dropout"]),
        )

        # Load state dict from Lightning checkpoint
        state_dict = checkpoint["state_dict"]

        # Map Lightning keys to inference head keys
        # Lightning saves as: attn.0.weight, encoder.0.weight, regression_head.0.weight, etc.
        # Training-only submodules (ranking/contrastive/...) are dropped here.
        head_state = {}
        for key, value in state_dict.items():
            if key.startswith("attn.") or key.startswith("encoder.") or key.startswith("regression_head."):
                head_state[key] = value

        head.load_state_dict(head_state, strict=True)

        head.to(self.device)
        head.eval()
        return head
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
# Module-level handle to the process-wide ModelCache instance.
_cache: Optional[ModelCache] = None


def get_model_cache() -> ModelCache:
    """Return the global ModelCache, creating it on first use."""
    global _cache
    if _cache is not None:
        return _cache
    _cache = ModelCache()
    return _cache
|
preprocessing/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Audio preprocessing modules."""
|
| 2 |
+
|
| 3 |
+
from preprocessing.audio import (
|
| 4 |
+
download_and_preprocess_audio,
|
| 5 |
+
preprocess_audio_from_bytes,
|
| 6 |
+
AudioDownloadError,
|
| 7 |
+
AudioProcessingError,
|
| 8 |
+
)
|
| 9 |
+
|
| 10 |
+
__all__ = [
|
| 11 |
+
"download_and_preprocess_audio",
|
| 12 |
+
"preprocess_audio_from_bytes",
|
| 13 |
+
"AudioDownloadError",
|
| 14 |
+
"AudioProcessingError",
|
| 15 |
+
]
|
preprocessing/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (541 Bytes). View file
|
|
|
preprocessing/__pycache__/audio.cpython-312.pyc
ADDED
|
Binary file (5.8 kB). View file
|
|
|
preprocessing/audio.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Audio download and preprocessing for D9c inference."""
|
| 2 |
+
|
| 3 |
+
import tempfile
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Tuple
|
| 6 |
+
|
| 7 |
+
import librosa
|
| 8 |
+
import numpy as np
|
| 9 |
+
import requests
|
| 10 |
+
|
| 11 |
+
# Default sample rate for MERT/MuQ (hardcoded to avoid import issues)
|
| 12 |
+
TARGET_SR = 24000
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class AudioDownloadError(Exception):
|
| 16 |
+
"""Raised when audio download fails."""
|
| 17 |
+
pass
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class AudioProcessingError(Exception):
|
| 21 |
+
"""Raised when audio processing fails."""
|
| 22 |
+
pass
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def download_and_preprocess_audio(
    audio_url: str,
    target_sr: int = TARGET_SR,
    max_duration: int = 300,
    timeout: int = 60,
) -> Tuple[np.ndarray, float]:
    """Download audio from URL and preprocess for MERT/MuQ.

    Args:
        audio_url: URL to download audio from
        target_sr: Target sample rate (24kHz for MERT/MuQ)
        max_duration: Maximum audio duration in seconds
        timeout: Download timeout in seconds

    Returns:
        Tuple of (audio_array, duration_seconds); audio is mono at target_sr

    Raises:
        AudioDownloadError: If download fails (including mid-stream)
        AudioProcessingError: If audio processing fails
    """
    try:
        response = requests.get(audio_url, timeout=timeout, stream=True)
        response.raise_for_status()
    except requests.RequestException as e:
        raise AudioDownloadError(f"Failed to download audio: {e}") from e

    # Determine file extension from content-type or URL so the decoder
    # backend picks the right format.
    content_type = response.headers.get("content-type", "")
    if "mpeg" in content_type or audio_url.endswith(".mp3"):
        suffix = ".mp3"
    elif "wav" in content_type or audio_url.endswith(".wav"):
        suffix = ".wav"
    elif "flac" in content_type or audio_url.endswith(".flac"):
        suffix = ".flac"
    else:
        suffix = ".mp3"

    # Bind temp_path before any write so the finally-block can always clean
    # up; previously a mid-stream failure left temp_path unbound and leaked
    # the temp file.
    temp_path = None
    try:
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as f:
            temp_path = Path(f.name)
            try:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
            except requests.RequestException as e:
                # Connection dropped mid-stream; surface as a download error
                # (matches the contract documented above).
                raise AudioDownloadError(f"Failed to download audio: {e}") from e

        # Resample to target_sr and downmix to mono.
        audio, sr = librosa.load(temp_path, sr=target_sr, mono=True)
        duration = len(audio) / sr

        if duration > max_duration:
            raise AudioProcessingError(
                f"Audio too long: {duration:.1f}s > {max_duration}s limit"
            )

        if duration < 1.0:
            raise AudioProcessingError(
                f"Audio too short: {duration:.1f}s < 1.0s minimum"
            )

        return audio, duration

    except (AudioDownloadError, AudioProcessingError):
        raise
    except Exception as e:
        raise AudioProcessingError(f"Failed to process audio: {e}") from e

    finally:
        # Always remove the temp file, even when download or decode failed.
        if temp_path is not None:
            temp_path.unlink(missing_ok=True)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def load_audio_from_file(
    audio_path: Path,
    target_sr: int = TARGET_SR,
) -> Tuple[np.ndarray, float]:
    """Load audio from local file.

    Resamples to target_sr and downmixes to mono; returns the waveform and
    its duration in seconds. No duration limits are enforced here.
    """
    waveform, rate = librosa.load(audio_path, sr=target_sr, mono=True)
    return waveform, len(waveform) / rate
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def preprocess_audio_from_bytes(
    audio_bytes: bytes,
    target_sr: int = TARGET_SR,
    max_duration: int = 300,
) -> Tuple[np.ndarray, float]:
    """Preprocess audio from raw bytes (e.g., base64 decoded).

    Args:
        audio_bytes: Encoded audio (mp3/wav/flac/...) as raw bytes.
        target_sr: Target sample rate (24kHz for MERT/MuQ).
        max_duration: Maximum audio duration in seconds.

    Returns:
        Tuple of (audio_array, duration_seconds); audio is mono at target_sr.

    Raises:
        AudioProcessingError: If the payload is empty, decoding fails, or
            the duration is outside [1.0, max_duration] seconds.
    """
    import io

    # Fail fast with a clear message instead of a cryptic decoder error.
    if not audio_bytes:
        raise AudioProcessingError("Empty audio payload")

    try:
        # Resample to target_sr and downmix to mono.
        audio, sr = librosa.load(io.BytesIO(audio_bytes), sr=target_sr, mono=True)
        duration = len(audio) / sr

        if duration > max_duration:
            raise AudioProcessingError(
                f"Audio too long: {duration:.1f}s > {max_duration}s limit"
            )

        if duration < 1.0:
            raise AudioProcessingError(
                f"Audio too short: {duration:.1f}s < 1.0s minimum"
            )

        return audio, duration

    except AudioProcessingError:
        raise
    except Exception as e:
        # Chain the original decoder error for debuggability.
        raise AudioProcessingError(f"Failed to process audio bytes: {e}") from e
|
requirements.txt
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# A1-Max MuQ LoRA - HuggingFace Inference Endpoints dependencies
|
| 2 |
+
# This file is read by HF Endpoints to install Python packages
|
| 3 |
+
|
| 4 |
+
# PyTorch and ML
|
| 5 |
+
torch>=2.0.0
|
| 6 |
+
transformers>=4.30.0
|
| 7 |
+
pytorch-lightning>=2.0.0
|
| 8 |
+
|
| 9 |
+
# Audio embedding models
|
| 10 |
+
muq # MuQ - Music Understanding Quantized from ByteDance/OpenMuQ
|
| 11 |
+
|
| 12 |
+
# Audio processing
|
| 13 |
+
librosa>=0.10.0
|
| 14 |
+
soundfile>=0.12.0
|
| 15 |
+
|
| 16 |
+
# Utilities
|
| 17 |
+
numpy>=1.24.0
|
| 18 |
+
scipy>=1.10.0
|
| 19 |
+
requests>=2.28.0
|
sync_checkpoints.sh
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Sync A1-Max MuQ LoRA checkpoints from Google Drive
# Run this before building the Docker image or uploading to HuggingFace
#
# Requires rclone with a configured "gdrive" remote.

# -e: abort on error; -u: unset variables are errors; pipefail: a failing
# command anywhere in a pipeline fails the pipeline.
set -euo pipefail

CHECKPOINT_DIR="./checkpoints"
GDRIVE_PATH="gdrive:crescendai_data/checkpoints/a1_max_sweep/A1max_r32_L7-12_ls0.1"

echo "A1-Max MuQ LoRA Checkpoint Sync"
echo "================================"
echo ""

echo "Creating checkpoint directories..."
for fold in 0 1 2 3; do
    mkdir -p "$CHECKPOINT_DIR/fold_$fold"
done

echo ""
echo "Syncing A1-Max checkpoints (4-fold ensemble, 80.8% pairwise)..."
echo "Source: $GDRIVE_PATH"
echo ""

# Sync each fold's best checkpoint
for fold in 0 1 2 3; do
    echo "Syncing fold_$fold..."
    rclone copyto "$GDRIVE_PATH/fold_${fold}/best.ckpt" "$CHECKPOINT_DIR/fold_$fold/best.ckpt" --progress
done

echo ""
echo "Checkpoint sync complete!"
echo ""
echo "Directory structure:"
ls -la "$CHECKPOINT_DIR"
echo ""

for fold in 0 1 2 3; do
    echo "fold_$fold:"
    ls -la "$CHECKPOINT_DIR/fold_$fold"
done

echo ""
echo "Expected HuggingFace repository structure:"
echo " checkpoints/"
echo " fold_0/best.ckpt"
echo " fold_1/best.ckpt"
echo " fold_2/best.ckpt"
echo " fold_3/best.ckpt"
echo ""
echo "Model: A1-Max MuQ LoRA r32 L7-12 (6-dim, 80.8% pairwise, R2=0.50)"
|