| from fastapi import FastAPI, HTTPException |
| from fastapi.middleware.cors import CORSMiddleware |
| from pydantic import BaseModel |
| import traceback |
| import numpy as np |
| import torch |
| import base64 |
| import io |
| import os |
| import logging |
| import whisper |
| import soundfile as sf |
| from inference import OmniInference |
| import tempfile |
|
|
# Module-level logging: one logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


app = FastAPI()


# Allow cross-origin requests from any origin. Convenient for a demo /
# local frontend, but too permissive for production — especially combined
# with allow_credentials=True.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
class AudioRequest(BaseModel):
    """Request payload for POST /api/v1/inference."""

    # Base64-encoded bytes of a serialized numpy array (.npy) holding the
    # input audio samples (decoded with np.load in the inference handler).
    audio_data: str
    # Sample rate of the submitted audio, in Hz. NOTE(review): the handler
    # currently assumes 16 kHz and only warns on other values.
    sample_rate: int
|
|
class AudioResponse(BaseModel):
    """Response payload for POST /api/v1/inference."""

    # Base64-encoded bytes of a serialized numpy array (.npy) holding the
    # generated audio samples.
    audio_data: str
    # Text produced alongside the audio; empty string when none was emitted.
    text: str = ""
|
|
| |
# Shared startup state, reported verbatim by the health endpoint.
INITIALIZATION_STATUS = {
    "model_loaded": False,  # set True only after the model is built AND warmed up
    "error": None  # last initialization error message, if any
}


# Global model handle; populated by initialize_model() at startup.
model = None
|
|
def initialize_model():
    """Build and warm up the global OmniInference model.

    Loads checkpoints from the local ``models`` directory, picks CUDA when
    available, and records the outcome in ``INITIALIZATION_STATUS``.

    Returns:
        bool: True on success; False on any failure (the error message is
        stored in ``INITIALIZATION_STATUS["error"]``).
    """
    global model, INITIALIZATION_STATUS
    try:
        target_device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Initializing OmniInference model on device: {target_device}")

        checkpoint_dir = os.path.abspath('models')
        logger.info(f"Loading models from: {checkpoint_dir}")

        # Fail fast with a clear message rather than letting the model
        # constructor surface an obscure file-not-found error.
        if not os.path.exists(checkpoint_dir):
            raise RuntimeError(f"Checkpoint path {checkpoint_dir} does not exist")

        # The global is assigned before warm_up() on purpose: even if warm-up
        # fails, the handle stays available for inspection via /health.
        model = OmniInference(checkpoint_dir, device=target_device)
        model.warm_up()

        INITIALIZATION_STATUS["model_loaded"] = True
        logger.info("OmniInference model initialized successfully")
        return True
    except Exception as exc:
        INITIALIZATION_STATUS["error"] = str(exc)
        logger.error(f"Failed to initialize model: {exc}\n{traceback.format_exc()}")
        return False
|
|
@app.on_event("startup")
async def startup_event():
    """Initialize the model when the server starts.

    Failures are swallowed by initialize_model() (it returns False instead of
    raising), so the server still comes up and /health reports the error.

    NOTE(review): @app.on_event is deprecated in recent FastAPI in favor of
    lifespan handlers; kept as-is to avoid restructuring app creation.
    """
    initialize_model()
|
|
@app.get("/api/v1/health")
def health_check():
    """Health check endpoint.

    Returns overall status plus the raw initialization record; when a model
    handle exists, also reports its device and load/warm-up flags.
    """
    status = {
        "status": "healthy" if INITIALIZATION_STATUS["model_loaded"] else "initializing",
        "initialization_status": INITIALIZATION_STATUS
    }

    if model is not None:
        # `model` is assigned before warm_up() in initialize_model(), so the
        # handle can be non-None even though initialization failed mid-warm-up.
        # Derive both flags from INITIALIZATION_STATUS instead of hardcoding
        # True, so this section never contradicts the "status" field above.
        loaded = INITIALIZATION_STATUS["model_loaded"]
        status.update({
            "device": str(model.device),
            "model_loaded": loaded,
            "warm_up_complete": loaded
        })

    return status
|
|
@app.post("/api/v1/inference")
async def inference(request: AudioRequest) -> AudioResponse:
    """Run inference with the OmniInference model.

    Decodes a base64-wrapped ``.npy`` buffer of audio samples, feeds it to the
    model's streaming API via a temp WAV file, and returns the generated audio
    (again as a base64-wrapped ``.npy`` buffer) plus the accumulated text.

    Raises:
        HTTPException: 503 when the model is not initialized yet; 500 for any
            failure during decoding or inference.
    """
    if not INITIALIZATION_STATUS["model_loaded"]:
        raise HTTPException(
            status_code=503,
            detail=f"Model not ready. Status: {INITIALIZATION_STATUS}"
        )

    try:
        logger.info(f"Received inference request with sample rate: {request.sample_rate}")

        # The payload is a serialized numpy array (.npy), not raw PCM bytes.
        audio_bytes = base64.b64decode(request.audio_data)
        audio_array = np.load(io.BytesIO(audio_bytes))

        with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as temp_wav:
            if request.sample_rate != 16000:
                # TODO(review): the input is written at a fixed 16 kHz below
                # regardless of the requested rate — actual resampling is
                # still missing.
                logger.warning("Sample rate conversion not implemented. Assuming 16kHz.")

            # pad_or_trim normalizes the clip to Whisper's fixed input length.
            audio_data = whisper.pad_or_trim(audio_array.flatten())

            sf.write(temp_wav.name, audio_data, 16000)

            final_text = ""

            with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as temp_wav_out:
                # Drain the streaming generator: the model writes audio to
                # save_path itself; text arrives incrementally and is
                # accumulated here.
                for audio_stream, text_stream in model.run_AT_batch_stream(
                    temp_wav.name,
                    stream_stride=4,
                    max_returned_tokens=2048,
                    save_path=temp_wav_out.name,
                    sample_rate=request.sample_rate
                ):
                    if text_stream:
                        final_text += text_stream
                final_audio, output_rate = sf.read(temp_wav_out.name)
                # Explicit check instead of `assert`: asserts are stripped
                # under `python -O`, and a bare AssertionError produced an
                # empty 500 detail for the client.
                if output_rate != request.sample_rate:
                    raise ValueError(
                        f"Output sample rate {output_rate} does not match "
                        f"requested sample rate {request.sample_rate}"
                    )

        # Return the audio the same way it arrived: a base64-wrapped .npy buffer.
        buffer = io.BytesIO()
        np.save(buffer, final_audio)
        audio_b64 = base64.b64encode(buffer.getvalue()).decode()

        return AudioResponse(
            audio_data=audio_b64,
            text=final_text.strip()
        )

    except Exception as e:
        logger.error(f"Inference failed: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail=str(e)
        )
|
|
if __name__ == "__main__":
    # Dev entry point: serve the app on all interfaces, port 8000.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
|
|