| """ |
| SALMONN FastAPI Server |
| HTTP API for audio understanding and transcription. |
| """ |
|
|
| import os |
| import tempfile |
| import shutil |
| from pathlib import Path |
| from typing import Optional |
|
|
| import yaml |
| import uvicorn |
| from fastapi import FastAPI, File, UploadFile, Form, HTTPException |
| from fastapi.responses import JSONResponse |
| from pydantic import BaseModel |
| from omegaconf import OmegaConf |
|
|
| from inference import SALMONNInference |
|
|
| |
# Path to the YAML configuration file; override with the SALMONN_CONFIG env var.
CONFIG_PATH = os.environ.get("SALMONN_CONFIG", "config.yaml")

# Parsed once at import time; the __main__ block reads config.server.* from it.
# The file is opened in binary mode so PyYAML detects the encoding itself
# (UTF-8 / UTF-16 via BOM) instead of relying on the platform locale.
with open(CONFIG_PATH, "rb") as f:
    _raw_config = yaml.safe_load(f)

# safe_load returns None for an empty document; use an empty mapping so the
# result is an ordinary (empty) DictConfig rather than a None-valued one.
config = OmegaConf.create({} if _raw_config is None else _raw_config)
|
|
| |
# ASGI application; this metadata is shown in the generated OpenAPI docs.
app = FastAPI(
    version="1.0.0",
    title="SALMONN API",
    description=(
        "Audio Language Model for Speech, Audio Events, and Music Understanding"
    ),
)

# Process-wide inference handle. It stays None until the startup hook
# assigns a SALMONNInference instance.
model: Optional[SALMONNInference] = None
|
|
|
|
class TranscribeResponse(BaseModel):
    """Response body returned by POST /transcribe."""

    text: str  # value produced by model.transcribe() for the uploaded file
    status: str = "success"
|
|
|
|
class ChatResponse(BaseModel):
    """Response body returned by POST /chat."""

    question: str  # the question from the request form, echoed back
    answer: str  # value produced by model.chat() for the audio + question
    status: str = "success"
|
|
|
|
class HealthResponse(BaseModel):
    """Response body returned by GET /health."""

    status: str  # "healthy" once the model is loaded, otherwise "loading"
    model_loaded: bool
    device: str  # str(model.device), or "unknown" when no model exists yet
|
|
|
|
@app.on_event("startup")
async def startup_event():
    """Create the global SALMONN inference object and load its weights.

    The instance is published to the module-level ``model`` before
    ``load()`` is called, matching the order the health check relies on.
    """
    global model
    print("Starting SALMONN server...")
    instance = SALMONNInference(CONFIG_PATH)
    model = instance
    instance.load()
    print("Server ready!")
|
|
|
|
@app.get("/", response_model=dict)
async def root():
    """Return the API name, version, and an index of the available endpoints."""
    return {
        "name": "SALMONN API",
        "version": "1.0.0",
        "endpoints": {
            "/health": "Health check",
            "/transcribe": "Transcribe audio (POST)",
            "/chat": "Ask questions about audio (POST)",
            # Previously missing even though the route is registered below.
            "/describe": "Describe audio content (POST)",
        },
    }
|
|
|
|
@app.get("/health", response_model=HealthResponse)
async def health():
    """Report server liveness and whether the SALMONN model has finished loading."""
    if not model:
        # No inference object has been created yet.
        return HealthResponse(status="loading", model_loaded=False, device="unknown")

    is_loaded = model._loaded
    return HealthResponse(
        status="healthy" if is_loaded else "loading",
        model_loaded=is_loaded,
        device=str(model.device),
    )
|
|
|
|
def _require_model() -> SALMONNInference:
    """Return the global model, raising HTTP 503 until it exists and is loaded."""
    if not model or not model._loaded:
        raise HTTPException(status_code=503, detail="Model not loaded yet")
    return model


def _temp_suffix(filename: Optional[str]) -> str:
    """Pick a temp-file extension from the client-supplied filename.

    The filename is untrusted, so its extension is kept only when it is a
    non-empty run of ASCII letters/digits. Otherwise ".wav" is used: when
    there is no filename, no extension (previously this produced an empty
    suffix), or unusual characters.
    """
    suffix = Path(filename).suffix if filename else ""
    ext = suffix[1:]
    if ext and ext.isascii() and ext.isalnum():
        return suffix
    return ".wav"


def _run_on_upload(audio: UploadFile, infer):
    """Save *audio* to a temporary file, return ``infer(path)``, then delete the file.

    Any exception raised while writing the upload or inside ``infer`` is
    reported to the client as HTTP 500 with the exception text as ``detail``.
    The original exception is chained for server-side tracebacks. The temp
    file is removed in every case, including a failed copy.

    NOTE(review): the endpoints are ``async def`` and this runs synchronously,
    so inference blocks the event loop (and other requests) while it runs.
    Moving it to a threadpool would also need serialization if
    SALMONNInference is not thread-safe; confirm before changing.
    """
    tmp = tempfile.NamedTemporaryFile(
        delete=False, suffix=_temp_suffix(audio.filename)
    )
    tmp_path = tmp.name
    try:
        # Closing the handle flushes all bytes before the model opens the path.
        with tmp:
            shutil.copyfileobj(audio.file, tmp)
        return infer(tmp_path)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        os.unlink(tmp_path)


@app.post("/transcribe", response_model=TranscribeResponse)
async def transcribe(
    audio: UploadFile = File(..., description="Audio file (wav, mp3, etc.)"),
):
    """
    Transcribe an audio file to text.

    - **audio**: Audio file to transcribe

    Returns transcribed text.
    """
    ready = _require_model()
    return _run_on_upload(
        audio, lambda path: TranscribeResponse(text=ready.transcribe(path))
    )


@app.post("/chat", response_model=ChatResponse)
async def chat(
    audio: UploadFile = File(..., description="Audio file (wav, mp3, etc.)"),
    question: str = Form(..., description="Question about the audio"),
):
    """
    Ask a question about an audio file.

    - **audio**: Audio file to analyze
    - **question**: Question about the audio content

    Returns the model's answer.
    """
    ready = _require_model()
    return _run_on_upload(
        audio,
        lambda path: ChatResponse(
            question=question, answer=ready.chat(path, question)
        ),
    )


class DescribeResponse(BaseModel):
    """Response body returned by POST /describe (same JSON shape as before)."""

    # Assumes model.describe() returns a string, like transcribe(); TODO confirm.
    description: str
    status: str = "success"


@app.post("/describe", response_model=DescribeResponse)
async def describe(
    audio: UploadFile = File(..., description="Audio file (wav, mp3, etc.)"),
):
    """
    Get a detailed description of the audio content.

    - **audio**: Audio file to describe

    Returns description of the audio.
    """
    ready = _require_model()
    return _run_on_upload(
        audio, lambda path: DescribeResponse(description=ready.describe(path))
    )
|
|
|
|
if __name__ == "__main__":
    # uvicorn needs a "module:attribute" import string (mandatory when reload is on).
    # Derive the module name from how this file was launched instead of hard-coding
    # "server", so the entry point still works after a rename or when started with
    # `python -m package.server` (__spec__ is None when run as a plain script).
    module_name = __spec__.name if __spec__ is not None else Path(__file__).stem
    uvicorn.run(
        f"{module_name}:app",
        host=config.server.host,
        port=config.server.port,
        reload=config.server.get("reload", False),
    )
|
|