| | from fastapi import FastAPI, HTTPException |
| | from fastapi.middleware.cors import CORSMiddleware |
| | from pydantic import BaseModel |
| | import traceback |
| | import numpy as np |
| | import torch |
| | import base64 |
| | import io |
| | import os |
| | import logging |
| | import whisper |
| | import soundfile as sf |
| | from inference import OmniInference |
| | import tempfile |
| |
|
# Configure root logging at INFO and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The ASGI application object that all routes below are registered on.
app = FastAPI()

# CORS: accept any origin, method and header, with credentials enabled.
# NOTE(review): FastAPI's CORS documentation says wildcard origins/methods/
# headers should not be combined with allow_credentials=True -- confirm
# whether credentials are actually needed, or list the origins explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
| |
|
class AudioRequest(BaseModel):
    """Request body accepted by the inference endpoint."""

    # Base64 text; the inference handler base64-decodes it and reads the
    # resulting bytes with ``np.load``.
    audio_data: str
    # Sample rate supplied by the client; the handler logs it and forwards
    # it to the model call.
    sample_rate: int
| |
|
class AudioResponse(BaseModel):
    """Response body returned by the inference endpoint."""

    # Base64 text of the generated audio array as serialised by ``np.save``.
    audio_data: str
    # Text collected while the model streamed its answer; empty by default.
    text: str = ""
| |
|
| | |
# Shared start-up state: ``model_loaded`` flips to True once the model has
# been loaded and warmed up; ``error`` holds the failure message, if any.
INITIALIZATION_STATUS = dict(model_loaded=False, error=None)

# Global model instance; stays None until initialize_model() assigns it.
model = None
| |
|
def initialize_model():
    """Load and warm up the global OmniInference model.

    Picks CUDA when ``torch.cuda.is_available()`` and CPU otherwise, loads
    the checkpoints from the ``models`` directory (resolved against the
    current working directory) and runs the model's warm-up.

    The global ``model`` is only replaced after warm-up has succeeded, so a
    failed attempt never publishes a half-initialised instance. The outcome
    is recorded in ``INITIALIZATION_STATUS``.

    Returns:
        bool: True when the model was loaded and warmed up, False when any
        step raised (the error text is stored under
        ``INITIALIZATION_STATUS["error"]``).
    """
    global model
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info("Initializing OmniInference model on device: %s", device)

        ckpt_path = os.path.abspath('models')
        logger.info("Loading models from: %s", ckpt_path)

        if not os.path.exists(ckpt_path):
            raise RuntimeError(f"Checkpoint path {ckpt_path} does not exist")

        # Build and warm up on a local name first; publishing before
        # warm_up() would expose an unusable model if warm-up fails.
        candidate = OmniInference(ckpt_path, device=device)
        candidate.warm_up()
    except Exception as e:
        # Top-level boundary: record the failure for the health endpoint
        # instead of crashing application start-up.
        INITIALIZATION_STATUS["error"] = str(e)
        logger.exception("Failed to initialize model: %s", e)
        return False

    model = candidate
    INITIALIZATION_STATUS["model_loaded"] = True
    # Clear any error left over from an earlier failed attempt.
    INITIALIZATION_STATUS["error"] = None
    logger.info("OmniInference model initialized successfully")
    return True
| |
|
@app.on_event("startup")
async def startup_event():
    """Startup hook: load the model once when the application boots.

    The boolean result of ``initialize_model`` is deliberately ignored here;
    that function records success or failure in ``INITIALIZATION_STATUS``.
    """
    _ = initialize_model()
| |
|
@app.get("/api/v1/health")
def health_check():
    """Report service health and model initialisation state.

    Returns:
        dict: ``status`` is ``"healthy"`` once the model is loaded and
        ``"initializing"`` otherwise; ``initialization_status`` is a snapshot
        of ``INITIALIZATION_STATUS`` (including any recorded error). When a
        model object exists, ``device``, ``model_loaded`` and
        ``warm_up_complete`` are included as well.
    """
    loaded = bool(INITIALIZATION_STATUS["model_loaded"])
    status = {
        "status": "healthy" if loaded else "initializing",
        # Snapshot, so the response never aliases the mutable module global.
        "initialization_status": dict(INITIALIZATION_STATUS),
    }

    if model is not None:
        status.update({
            "device": str(model.device),
            # The model object can exist before (or without) a successful
            # warm-up, so take both flags from the initialisation status
            # rather than from the mere presence of the object.
            "model_loaded": loaded,
            "warm_up_complete": loaded,
        })

    return status
| |
|
def _decode_audio_payload(encoded):
    """Decode a base64-encoded ``.npy`` payload into a NumPy array.

    Args:
        encoded: Base64 text taken from the request body.

    Returns:
        numpy.ndarray: The array stored in the payload.

    Raises:
        HTTPException: 400 when the text is not valid base64 or the decoded
        bytes are not a single array in ``.npy`` format.
    """
    try:
        raw = base64.b64decode(encoded)
        # allow_pickle=False: request data must never be unpickled.
        decoded = np.load(io.BytesIO(raw), allow_pickle=False)
    except (ValueError, EOFError, OSError) as e:
        # binascii.Error (bad base64) is a ValueError subclass.
        raise HTTPException(
            status_code=400,
            detail=f"Invalid audio_data payload: {e}"
        ) from e

    if not isinstance(decoded, np.ndarray):
        # e.g. an .npz archive, which np.load returns as a mapping object.
        raise HTTPException(
            status_code=400,
            detail="audio_data must contain a single NumPy array in .npy format"
        )
    return decoded


@app.post("/api/v1/inference")
async def inference(request: AudioRequest) -> AudioResponse:
    """Run inference with OmniInference model.

    The request audio is decoded, flattened, passed through
    ``whisper.pad_or_trim`` and written to a temporary WAV file at 16000 Hz.
    The model then streams its answer; streamed text chunks are collected
    and the audio the model saved is read back and returned as a
    base64-encoded ``.npy`` array.

    Args:
        request: Audio payload and its sample rate.

    Returns:
        AudioResponse: Generated audio (base64 ``.npy``) and stripped text.

    Raises:
        HTTPException: 503 when the model is not loaded, 400 when the audio
        payload cannot be decoded, 500 for any other failure.
    """
    if not INITIALIZATION_STATUS["model_loaded"]:
        raise HTTPException(
            status_code=503,
            detail=f"Model not ready. Status: {INITIALIZATION_STATUS}"
        )

    try:
        logger.info("Received inference request with sample rate: %s", request.sample_rate)

        audio_array = _decode_audio_payload(request.audio_data)

        if request.sample_rate != 16000:
            # No resampling is performed; the samples are written as 16 kHz.
            logger.warning("Sample rate conversion not implemented. Assuming 16kHz.")

        audio_data = whisper.pad_or_trim(audio_array.flatten())

        # A private temporary directory instead of open NamedTemporaryFile
        # handles: re-opening such a file by name while it is still open is
        # not portable (it fails on Windows).
        with tempfile.TemporaryDirectory() as work_dir:
            input_path = os.path.join(work_dir, "input.wav")
            output_path = os.path.join(work_dir, "output.wav")
            sf.write(input_path, audio_data, 16000)

            text_chunks = []
            # NOTE(review): the input WAV is always written at 16000 Hz but
            # the client's sample_rate is what gets forwarded -- confirm
            # this is what the model expects.
            for _audio_stream, text_stream in model.run_AT_batch_stream(
                input_path,
                stream_stride=4,
                max_returned_tokens=2048,
                save_path=output_path,
                sample_rate=request.sample_rate
            ):
                if text_stream:
                    text_chunks.append(text_stream)

            # Read the result before the directory is cleaned up.
            final_audio, sample_rate = sf.read(output_path)

        # Explicit check rather than ``assert``, which is removed under -O.
        if sample_rate != request.sample_rate:
            raise RuntimeError(
                f"Model produced audio at {sample_rate} Hz, "
                f"expected {request.sample_rate} Hz"
            )

        buffer = io.BytesIO()
        np.save(buffer, final_audio)
        audio_b64 = base64.b64encode(buffer.getvalue()).decode()

        return AudioResponse(
            audio_data=audio_b64,
            text="".join(text_chunks).strip()
        )

    except HTTPException:
        # Keep deliberate HTTP errors (e.g. the 400 above) untouched.
        raise
    except Exception as e:
        logger.exception("Inference failed: %s", e)
        raise HTTPException(
            status_code=500,
            detail=str(e)
        ) from e
| |
|
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on every interface,
    # port 8000. uvicorn is imported here so it is only needed when the
    # module is run directly.
    import uvicorn

    uvicorn.run(app, port=8000, host="0.0.0.0")
| |
|