Spaces:
Sleeping
Sleeping
| """FastAPI backend for the piano tutorial transcription pipeline.""" | |
| import json | |
| import shutil | |
| import sys | |
| import tempfile | |
| import threading | |
| import traceback | |
| import uuid | |
| from pathlib import Path | |
| import pretty_midi | |
| from fastapi import FastAPI, UploadFile, File, HTTPException | |
| from fastapi.responses import FileResponse, JSONResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.middleware.cors import CORSMiddleware | |
# Make the sibling ``transcriber`` directory importable; the request
# handlers import its modules (transcribe, optimize, separate, drums,
# tabs, ...) lazily at call time.  Inserted at position 0 so it takes
# precedence over anything else on sys.path.
TRANSCRIBER_DIR = Path(__file__).resolve().parent.parent.joinpath("transcriber")
sys.path.insert(0, str(TRANSCRIBER_DIR))
app = FastAPI(title="Piano Tutorial API")

# CORS is fully open: any origin, any method and any header is accepted.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
# Scratch root for all jobs: each job stores its upload and outputs in a
# sub-directory named after its job id.
WORK_DIR = Path(tempfile.gettempdir()).joinpath("piano-tutorial")
WORK_DIR.mkdir(exist_ok=True)
async def transcribe(
    file: UploadFile = File(...),
):
    """Transcribe an uploaded audio file to MIDI.

    Accepts a file upload (MP3, M4A, WAV, OGG, FLAC).
    Returns JSON with a job_id, MIDI download URL, and chord data.

    The transcription and optimization steps are synchronous and CPU-heavy,
    so they are run in FastAPI's worker thread pool rather than directly on
    the event loop; otherwise every other request (including status polls
    for full-song jobs) would stall until this one finished.

    Raises:
        HTTPException(500): if any pipeline step fails or the optimizer
            produces no output file.  The job directory is deleted in that
            case, since its id is never returned to the client.
    """
    from fastapi.concurrency import run_in_threadpool

    job_id = str(uuid.uuid4())[:8]
    job_dir = WORK_DIR / job_id
    job_dir.mkdir(exist_ok=True)
    try:
        # UploadFile.filename may be None; fall back to the default suffix.
        suffix = Path(file.filename or "").suffix or ".m4a"
        audio_path = job_dir / f"upload{suffix}"
        audio_path.write_bytes(await file.read())

        # Run transcription
        from transcribe import transcribe as run_transcribe
        raw_midi_path = job_dir / "transcription_raw.mid"
        await run_in_threadpool(run_transcribe, str(audio_path), str(raw_midi_path))

        # Run optimization (also runs chord detection as Step 10)
        from optimize import optimize
        optimized_path = job_dir / "transcription.mid"
        await run_in_threadpool(
            optimize, str(audio_path), str(raw_midi_path), str(optimized_path)
        )
        if not optimized_path.exists():
            raise HTTPException(500, "Optimization failed to produce output")

        # Load chord data if the optimizer wrote it next to its output
        chords_path = job_dir / "transcription_chords.json"
        chord_data = None
        if chords_path.exists():
            with open(chords_path) as f:
                chord_data = json.load(f)

        return JSONResponse({
            "job_id": job_id,
            "midi_url": f"/api/jobs/{job_id}/midi",
            "chords_url": f"/api/jobs/{job_id}/chords",
            "audio_url": f"/api/jobs/{job_id}/audio",
            "chords": chord_data,
        })
    except Exception as e:
        # On failure the client never learns this job id, so nothing could
        # ever fetch the partial files; remove them.
        shutil.rmtree(job_dir, ignore_errors=True)
        if isinstance(e, HTTPException):
            raise
        raise HTTPException(500, f"Transcription failed: {str(e)}") from e
async def get_midi(job_id: str):
    """Download the optimized MIDI file for a completed job.

    ``job_id`` comes from the request URL, so it must name a direct child
    of WORK_DIR.  Values such as ``".."`` that would resolve elsewhere get
    the same 404 as an unknown job.
    """
    job_dir = WORK_DIR / job_id
    midi_path = job_dir / "transcription.mid"
    if job_dir.resolve().parent != WORK_DIR.resolve() or not midi_path.exists():
        raise HTTPException(404, f"No MIDI file found for job {job_id}")
    return FileResponse(
        midi_path,
        media_type="audio/midi",
        filename="transcription.mid",
    )
async def get_chords(job_id: str):
    """Get the detected chord data for a completed job.

    Returns the parsed contents of ``transcription_chords.json``.  ``job_id``
    comes from the request URL, so it must name a direct child of WORK_DIR.
    Values such as ``".."`` get the same 404 as a job without chord data.
    """
    job_dir = WORK_DIR / job_id
    chords_path = job_dir / "transcription_chords.json"
    if job_dir.resolve().parent != WORK_DIR.resolve() or not chords_path.exists():
        raise HTTPException(404, f"No chord data found for job {job_id}")
    with open(chords_path) as f:
        chord_data = json.load(f)
    return JSONResponse(chord_data)
# ── Full-song mode (Demucs source separation) ──────────────────────────
# In-memory status for async full-song jobs, keyed by job id.  Each value
# is a dict with at least "step", "label" and "done".  It is written by the
# background worker thread and read by the status endpoint.
# NOTE(review): nothing here evicts finished jobs, so this grows for the
# lifetime of the process.
job_status = {}
def merge_stems(piano_midi_path, bass_midi_path, output_path):
    """Write a two-track MIDI file combining the piano and bass transcriptions.

    Every note from every instrument in ``piano_midi_path`` is collected into
    one track named "Piano" (program 0).  Every note from ``bass_midi_path``
    is collected into a second track named "Bass" (program 33).  Only notes
    are carried over.  The merged file is written to ``output_path``.
    """
    merged = pretty_midi.PrettyMIDI()
    # (source file, program, track name), in output track order.
    layout = (
        (piano_midi_path, 0, "Piano"),
        (bass_midi_path, 33, "Bass"),
    )
    sources = [pretty_midi.PrettyMIDI(str(src)) for src, _, _ in layout]
    for source, (_, program, track_name) in zip(sources, layout):
        track = pretty_midi.Instrument(program=program, name=track_name)
        for instrument in source.instruments:
            track.notes.extend(instrument.notes)
        merged.instruments.append(track)
    merged.write(str(output_path))
def run_full_transcription(job_id, audio_path, job_dir):
    """Background worker for full-song transcription with Demucs.

    Pipeline steps:
      - Separate ``audio_path`` into stems.
      - Transcribe and optimize the melodic ("other") and bass stems.
      - Transcribe the drum stem.
      - Write guitar and bass tab JSON.
      - Merge piano + bass into ``job_dir/transcription.mid``.

    Progress is published to ``job_status[job_id]`` before each step.  The
    final entry has ``done=True`` plus either a ``result`` payload (success)
    or ``step=-1`` with an ``error`` message truncated to 200 characters
    (failure).

    Exceptions are recorded rather than raised, because this runs on a
    thread with no caller to receive them.  The stems directory and the
    intermediate MIDI files are deleted whether the job succeeds or fails,
    so a failed job does not leave the large separated stems on disk.
    """
    stems_dir = job_dir / "stems"
    piano_raw = job_dir / "piano_raw.mid"
    bass_raw = job_dir / "bass_raw.mid"
    piano_tmp = job_dir / "transcription.tmp.mid"
    piano_opt = job_dir / "piano_optimized.mid"
    bass_opt = job_dir / "bass_optimized.mid"
    chords_path = job_dir / "transcription_chords.json"

    def set_step(step, label):
        job_status[job_id] = {"step": step, "label": label, "done": False}

    def discard_intermediates():
        # Best-effort: failing to delete scratch files must not change the
        # outcome recorded for the job.
        shutil.rmtree(stems_dir, ignore_errors=True)
        for scratch in (piano_raw, bass_raw, piano_tmp, piano_opt, bass_opt):
            try:
                scratch.unlink(missing_ok=True)
            except OSError:
                pass

    try:
        # Step 1: Demucs separation
        set_step(1, "Separating instruments with AI...")
        from separate import separate
        stems = separate(str(audio_path), str(stems_dir))

        # Step 2: Transcribe melodic + bass stems
        set_step(2, "Transcribing instruments...")
        from transcribe import transcribe as run_transcribe
        run_transcribe(stems["other"], str(piano_raw))
        run_transcribe(stems["bass"], str(bass_raw))

        # Step 3: Optimize transcriptions.
        # Use the full solo piano optimizer for the melodic stem - it produces
        # much better rhythm, playability, and note accuracy. It also runs
        # chord detection and spectral analysis internally.
        set_step(3, "Optimizing note accuracy...")
        from optimize import optimize as optimize_piano
        from optimize_bass import optimize_bass
        optimize_piano(stems["other"], str(piano_raw), str(piano_tmp))
        # The solo optimizer writes chords to {stem}_chords.json next to its
        # output; move them to the name the chords endpoint reads.
        auto_chords = job_dir / "transcription.tmp_chords.json"
        if auto_chords.exists():
            auto_chords.rename(chords_path)
        piano_tmp.rename(piano_opt)
        optimize_bass(stems["bass"], str(bass_raw), str(bass_opt))

        # Load chord data
        chord_data = None
        if chords_path.exists():
            with open(chords_path) as f:
                chord_data = json.load(f)

        # Step 4: Transcribe drums
        set_step(4, "Transcribing drums...")
        from drums import transcribe_drums
        transcribe_drums(stems["drums"], str(job_dir / "drum_tab.json"))

        # Step 5: Generate guitar and bass tabs
        set_step(5, "Generating tabs...")
        from tabs import midi_to_guitar_tab, midi_to_bass_tab
        guitar_tab = midi_to_guitar_tab(str(piano_opt), str(chords_path))
        with open(job_dir / "guitar_tab.json", 'w') as f:
            json.dump(guitar_tab, f)
        bass_tab = midi_to_bass_tab(str(bass_opt))
        with open(job_dir / "bass_tab.json", 'w') as f:
            json.dump(bass_tab, f)

        # Step 6: Merge melodic + bass into final MIDI
        set_step(6, "Assembling final result...")
        merge_stems(str(piano_opt), str(bass_opt), str(job_dir / "transcription.mid"))

        # Clean up large stem files and intermediates before reporting done
        discard_intermediates()
        job_status[job_id] = {
            "step": 7, "label": "Done!", "done": True,
            "result": {
                "job_id": job_id,
                "midi_url": f"/api/jobs/{job_id}/midi",
                "chords_url": f"/api/jobs/{job_id}/chords",
                "audio_url": f"/api/jobs/{job_id}/audio",
                "guitar_tab_url": f"/api/jobs/{job_id}/guitar-tab",
                "bass_tab_url": f"/api/jobs/{job_id}/bass-tab",
                "drum_tab_url": f"/api/jobs/{job_id}/drum-tab",
                "chords": chord_data,
                "mode": "full",
            },
        }
    except Exception as e:
        traceback.print_exc()
        message = str(e)[:200]
        job_status[job_id] = {
            "step": -1, "label": message, "done": True, "error": message,
        }
        # Previously skipped on failure, leaving the separated stems behind.
        discard_intermediates()
async def transcribe_full(file: UploadFile = File(...)):
    """Start full-song transcription with Demucs source separation.

    Returns immediately with a job_id. Poll /api/jobs/{job_id}/status.

    The upload is saved as ``WORK_DIR/<job_id>/upload<suffix>``.  Processing
    is then handed to ``run_full_transcription`` on a daemon thread, which
    reports progress through ``job_status``.
    """
    job_id = str(uuid.uuid4())[:8]
    job_dir = WORK_DIR / job_id
    job_dir.mkdir(exist_ok=True)
    # UploadFile.filename may be None; fall back to the default suffix.
    suffix = Path(file.filename or "").suffix or ".m4a"
    audio_path = job_dir / f"upload{suffix}"
    try:
        audio_path.write_bytes(await file.read())
    except Exception:
        # No job has been registered yet, so nothing could ever reach this
        # directory; remove it before propagating the error.
        shutil.rmtree(job_dir, ignore_errors=True)
        raise
    job_status[job_id] = {"step": 0, "label": "Starting...", "done": False}
    thread = threading.Thread(
        target=run_full_transcription,
        args=(job_id, audio_path, job_dir),
        daemon=True,
    )
    thread.start()
    return JSONResponse({"job_id": job_id})
async def get_job_status(job_id: str):
    """Return the latest progress entry for a full-song transcription job.

    Responds with 404 when no job with ``job_id`` has been started in this
    process (status is kept in memory only).
    """
    try:
        entry = job_status[job_id]
    except KeyError:
        raise HTTPException(404, f"No job found with id {job_id}") from None
    return JSONResponse(entry)
def _load_job_json(job_id, filename, what):
    """Return a JSONResponse with the parsed ``WORK_DIR/<job_id>/<filename>``.

    Raises HTTPException(404) with ``"No {what} data for job {job_id}"`` when
    the file is missing, or when ``job_id`` (taken from the URL) does not name
    a direct child of WORK_DIR, e.g. ``".."``.
    """
    job_dir = WORK_DIR / job_id
    data_path = job_dir / filename
    if job_dir.resolve().parent != WORK_DIR.resolve() or not data_path.exists():
        raise HTTPException(404, f"No {what} data for job {job_id}")
    with open(data_path) as f:
        return JSONResponse(json.load(f))


async def get_guitar_tab(job_id: str):
    """Get the guitar tab data for a completed full-song job."""
    return _load_job_json(job_id, "guitar_tab.json", "guitar tab")


async def get_bass_tab(job_id: str):
    """Get the bass tab data for a completed full-song job."""
    return _load_job_json(job_id, "bass_tab.json", "bass tab")


async def get_drum_tab(job_id: str):
    """Get the drum tab data for a completed full-song job."""
    return _load_job_json(job_id, "drum_tab.json", "drum tab")
async def get_audio(job_id: str):
    """Serve the original uploaded audio file back for playback.

    Looks for the ``upload*`` file saved by the upload handlers and serves it
    as ``original<suffix>``.  The media type is taken from the suffix and
    defaults to ``audio/mpeg`` for unknown extensions.

    Raises:
        HTTPException(404): unknown job, a ``job_id`` that does not name a
            direct child of WORK_DIR, or no upload file in the job directory.
    """
    job_dir = WORK_DIR / job_id
    # job_id comes from the URL.  Without this check, "..", for example,
    # would make the scan below run over the system temp directory and serve
    # any file there whose name starts with "upload".
    if job_dir.resolve().parent != WORK_DIR.resolve() or not job_dir.is_dir():
        raise HTTPException(404, f"No job found with id {job_id}")
    # Find the upload file (upload.mp3, upload.m4a, upload.wav, etc.)
    media_types = {
        ".mp3": "audio/mpeg", ".m4a": "audio/mp4", ".wav": "audio/wav",
        ".ogg": "audio/ogg", ".flac": "audio/flac",
    }
    for f in job_dir.iterdir():
        if f.name.startswith("upload") and f.is_file():
            mt = media_types.get(f.suffix.lower(), "audio/mpeg")
            return FileResponse(f, media_type=mt, filename=f"original{f.suffix}")
    raise HTTPException(404, f"No audio file found for job {job_id}")
async def health():
    """Liveness check; always answers ``{"status": "ok"}``."""
    return dict(status="ok")
# Serve the built React frontend (in production)
DIST_DIR = Path(__file__).resolve().parent.parent / "app" / "dist"
if DIST_DIR.exists():
    # Serve static assets.  StaticFiles refuses to start when its directory
    # is missing, so only mount it if the build actually produced one.
    if (DIST_DIR / "assets").is_dir():
        app.mount("/assets", StaticFiles(directory=str(DIST_DIR / "assets")), name="assets")
    # Serve MIDI files if they exist
    midi_dir = DIST_DIR / "midi"
    if midi_dir.exists():
        app.mount("/midi", StaticFiles(directory=str(midi_dir)), name="midi")

    # Catch-all: serve index.html for SPA routing
    async def serve_spa(path: str):
        """Serve a file from the frontend build, falling back to index.html.

        ``path`` comes straight from the request URL.  It is resolved and
        must stay inside DIST_DIR; otherwise ``..`` segments or an absolute
        path would let a client read any file on the server.  Paths that
        escape DIST_DIR, or that are not regular files, get index.html.
        """
        index_html = DIST_DIR / "index.html"
        candidate = (DIST_DIR / path).resolve()
        try:
            candidate.relative_to(DIST_DIR.resolve())
        except ValueError:
            # Outside DIST_DIR (via "..", an absolute path, or a symlink).
            return FileResponse(index_html)
        if candidate.is_file():
            return FileResponse(candidate)
        return FileResponse(index_html)