ONNX version?

#3
by ankitguru123 - opened

I want to use this model with

https://github.com/groxaxo/parakeet-tdt-0.6b-v3-fastapi-openai

Can you provide any help? ONNX + int8 would make this much faster and better.

primeLine AI Services org

Here is an example of an OpenAI-compatible server with just-in-time conversion:

import asyncio
import gc
import io
import json
import logging
import os
import time
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional

import numpy as np
import soundfile as sf
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.responses import JSONResponse
from starlette.middleware import Middleware
from starlette.middleware.cors import CORSMiddleware

# Middleware stack passed to the FastAPI constructor below.
# NOTE(review): every origin, method and header is allowed while
# allow_credentials=True - confirm this permissive CORS policy is intended
# before exposing the server outside a trusted network.
middleware = [
    Middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
]


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run startup work before the app serves requests and cleanup afterwards.

    Startup tries to load the model (a failure is logged and startup
    continues) and then creates and starts the global batch processor.
    After the application has finished, the batch processor is stopped.
    """
    global batch_processor

    # --- startup ---
    logger.info("Starting ASR Proxy Server...")
    try:
        load_model()
    except Exception as startup_error:
        # Keep booting without a model; load_model() is attempted again by
        # the request handlers.
        logger.error(f"Failed to initialize model on startup: {startup_error}")

    batch_processor = BatchProcessor(BATCH_SIZE, BATCH_TIMEOUT_MS, MAX_QUEUE_SIZE)
    batch_processor.start()
    started_message = (
        f"Batch processor started (batch_size={BATCH_SIZE}, "
        f"timeout={BATCH_TIMEOUT_MS}ms, max_queue={MAX_QUEUE_SIZE})"
    )
    logger.info(started_message)

    yield

    # --- shutdown ---
    logger.info("Shutting down ASR Proxy Server...")
    if not batch_processor:
        return
    await batch_processor.stop()


# ASGI application; `lifespan` (defined above) handles startup and shutdown.
app = FastAPI(title="Openai ASR Server", middleware=middleware, lifespan=lifespan)

# Configure logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

# Silence httpx logs
logging.getLogger("httpx").setLevel(logging.WARNING)

# Environment variables
# Model name and optional local model directory handed to onnx_asr.load_model().
ASR_MODEL_NAME = os.getenv("ASR_MODEL_NAME", "primeline/parakeet-primeline")
ASR_MODEL_PATH = os.getenv("ASR_MODEL_PATH", None)
# `or None` turns an empty ASR_QUANTIZATION value into None.
ASR_QUANTIZATION = os.getenv("ASR_QUANTIZATION", None) or None
# Execution provider selector: "tensorrt", "cuda", anything else means CPU
# (see _build_providers).
ASR_PROVIDER = os.getenv("ASR_PROVIDER", "tensorrt")
# Boolean flags are enabled only by the literal string "true" (case-insensitive).
TRT_FP16_ENABLE = os.getenv("TRT_FP16_ENABLE", "true").lower() == "true"
TRT_MAX_WORKSPACE_GB = int(os.getenv("TRT_MAX_WORKSPACE_GB", "6"))
USE_VAD = os.getenv("USE_VAD", "true").lower() == "true"

# NeMo export settings (only used when ONNX files not cached)
NEMO_REPO_ID = os.getenv("NEMO_REPO_ID", "primeline/parakeet-primeline")
NEMO_FILENAME = os.getenv("NEMO_FILENAME", "2_95_WER.nemo")
ONNX_CACHE_DIR = Path(os.getenv("ONNX_CACHE_DIR", "/root/.cache/huggingface/onnx_export"))

# Batching configuration
# Max items drained per batch, how long to wait for more items after the
# first one (milliseconds), and the queue capacity before requests are refused.
BATCH_SIZE = int(os.getenv("BATCH_SIZE", "8"))
BATCH_TIMEOUT_MS = float(os.getenv("BATCH_TIMEOUT_MS", "100"))
MAX_QUEUE_SIZE = int(os.getenv("MAX_QUEUE_SIZE", "64"))

# Global state
# asr_model: the loaded recognizer, or None until load_model() succeeds.
# model_loading: True while a load_model() call is in progress.
# batch_processor: created in lifespan() at startup.
asr_model = None
model_loading = False
batch_processor: Optional["BatchProcessor"] = None


# ---------------------------------------------------------------------------
# Batch processor
# ---------------------------------------------------------------------------


@dataclass
class _BatchItem:
    """A single queued transcription request."""
    # Audio samples handed to the recognizer.
    waveform: np.ndarray
    # Sample rate forwarded to the recognizer together with the waveform.
    sample_rate: int
    # Resolved by the batch processor with the response dict or an exception.
    future: asyncio.Future
    # Upload name; only used in log messages.
    filename: str
    # Wall-clock creation time (time.time()); used to compute the queue wait.
    submit_time: float = field(default_factory=time.time)


class BatchProcessor:
    """Queue transcription requests and run them one at a time on the GPU.

    Incoming requests are queued; a background task drains up to
    *max_batch_size* items (or fewer once *batch_timeout_ms* has elapsed after
    the first item arrived) and runs inference for them sequentially on a
    single-thread executor, so the event loop stays responsive while only one
    inference call is in flight at a time.
    """

    def __init__(self, max_batch_size: int, batch_timeout_ms: float, max_queue_size: int):
        self.max_batch_size = max_batch_size
        self.batch_timeout_ms = batch_timeout_ms
        self._queue: "asyncio.Queue[_BatchItem]" = asyncio.Queue(maxsize=max_queue_size)
        self._gpu_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="gpu")
        self._task: Optional[asyncio.Task] = None
        # Items already taken off the queue whose futures may still be
        # pending. Kept on the instance so a shutdown can reject them instead
        # of leaving their callers waiting forever.
        self._inflight: "list[_BatchItem]" = []
        self._stats = {
            "batches_processed": 0,
            "items_processed": 0,
            "items_failed": 0,
        }

    # -- lifecycle ------------------------------------------------------------

    def start(self):
        """Launch the background processing task (needs a running event loop)."""
        self._task = asyncio.create_task(self._process_loop())

    async def stop(self):
        """Cancel the background task, wait for it to finish, release the executor."""
        if self._task:
            self._task.cancel()
            try:
                await self._task
            except asyncio.CancelledError:
                pass
        self._gpu_executor.shutdown(wait=False)

    # -- public API -----------------------------------------------------------

    @property
    def queue_size(self) -> int:
        """Number of requests currently waiting in the queue."""
        return self._queue.qsize()

    @property
    def stats(self) -> dict:
        """Copy of the processing counters plus the current queue depth."""
        return {**self._stats, "queue_depth": self._queue.qsize()}

    async def submit(self, waveform: np.ndarray, sample_rate: int, filename: str) -> dict:
        """Queue audio for transcription and wait for its response dict.

        Raises:
            HTTPException: status 429 when the queue is full; status 503 when
                the processor shuts down before the item was answered.
            Exception: whatever error inference raised for this item.
        """
        loop = asyncio.get_running_loop()
        future: asyncio.Future = loop.create_future()
        item = _BatchItem(waveform=waveform, sample_rate=sample_rate,
                          future=future, filename=filename)
        try:
            self._queue.put_nowait(item)
        except asyncio.QueueFull:
            raise HTTPException(
                status_code=429,
                detail=f"Transcription queue full ({self._queue.maxsize}). Try again later.",
            )
        return await future

    # -- internal -------------------------------------------------------------

    async def _collect_batch(self) -> "list[_BatchItem]":
        """Wait for at least one item, then collect up to max_batch_size
        within the timeout window.

        The returned list is also stored in ``self._inflight`` while it is
        being filled, so items stay reachable if this coroutine is cancelled.
        """
        loop = asyncio.get_running_loop()
        # Block until the first item arrives
        batch = self._inflight = [await self._queue.get()]

        deadline = loop.time() + self.batch_timeout_ms / 1000.0
        while len(batch) < self.max_batch_size:
            remaining = deadline - loop.time()
            if remaining <= 0:
                break
            try:
                batch.append(await asyncio.wait_for(self._queue.get(), timeout=remaining))
            except asyncio.TimeoutError:
                break
        return batch

    @staticmethod
    def _materialize(result):
        """Turn a one-shot iterable result into a list; pass anything else through.

        Strings and objects exposing ``text`` are returned unchanged, matching
        how the result helpers of this file tell single results from
        segment iterables.
        """
        if isinstance(result, str) or hasattr(result, "text"):
            return result
        if not hasattr(result, "__iter__"):
            return result
        return list(result)

    def _run_inference(self, waveform: np.ndarray, sample_rate: int):
        """Blocking inference - called inside the thread-pool executor.

        The result is materialised here because it is traversed twice
        afterwards (text extraction and segment building). A recognizer that
        returns a one-shot iterator (VAD-wrapped recognizers appear to do so -
        TODO confirm) would otherwise yield no segments on the second pass and
        would be consumed on the event loop instead of in this worker thread.
        """
        return self._materialize(asr_model.recognize(waveform, sample_rate=sample_rate))

    def _fail_unfinished(self):
        """Reject every request that has not been answered yet (shutdown path)."""
        pending = list(self._inflight)
        self._inflight = []
        while True:
            try:
                pending.append(self._queue.get_nowait())
            except asyncio.QueueEmpty:
                break
        for item in pending:
            if not item.future.done():
                item.future.set_exception(
                    HTTPException(status_code=503, detail="Server shutting down")
                )

    async def _process_item(self, item: "_BatchItem"):
        """Run inference for one item and resolve its future."""
        if item.future.done():
            # Nobody is waiting for this result any more (for example the
            # awaiting request was cancelled); do not spend GPU time on it.
            self._stats["items_failed"] += 1
            logger.info(f"Batch item '{item.filename}' skipped: request no longer waiting")
            return

        loop = asyncio.get_running_loop()
        try:
            start_time = time.time()
            result = await loop.run_in_executor(
                self._gpu_executor,
                self._run_inference,
                item.waveform,
                item.sample_rate,
            )
            elapsed = round(time.time() - start_time, 3)
            queue_wait = round(start_time - item.submit_time, 3)

            full_text = _extract_text(result)
            segments = _result_to_segments(result)

            # Prefer the last segment end; fall back to the raw audio length.
            total_duration = 0.0
            if segments:
                total_duration = max(seg["end"] for seg in segments)
            if total_duration == 0.0:
                total_duration = round(len(item.waveform) / item.sample_rate, 3)

            response = {
                "text": full_text,
                "segments": segments,
                "language": "en",
                "duration": total_duration,
                "transcription_time": elapsed,
                "queue_wait_time": queue_wait,
                "task": "transcribe",
            }
            # The waiter may have gone away while inference was running.
            if not item.future.done():
                item.future.set_result(response)
            self._stats["items_processed"] += 1
            logger.info(
                f"Batch item '{item.filename}': {elapsed}s inference, "
                f"{queue_wait}s queue wait, {len(full_text)} chars"
            )
        except Exception as e:
            if not item.future.done():
                item.future.set_exception(e)
            self._stats["items_failed"] += 1
            logger.error(f"Batch item '{item.filename}' failed: {e}")

    async def _process_loop(self):
        """Background loop: collect batches and process items.

        On cancellation every unanswered request (current batch and queue) is
        rejected with a 503 before the cancellation is re-raised.
        """
        try:
            while True:
                try:
                    batch = await self._collect_batch()
                    logger.info(f"Batch collected: {len(batch)} item(s)")
                    self._stats["batches_processed"] += 1

                    for item in batch:
                        await self._process_item(item)
                    self._inflight = []
                except Exception as e:
                    logger.error(f"Batch processing loop error: {e}", exc_info=True)
                    # Do not leave callers of the interrupted batch hanging.
                    for item in self._inflight:
                        if not item.future.done():
                            item.future.set_exception(e)
                    self._inflight = []
                    await asyncio.sleep(0.1)  # avoid tight error loop
        except asyncio.CancelledError:
            self._fail_unfinished()
            raise


def _build_providers():
    """Return the ONNX Runtime provider list selected by ASR_PROVIDER.

    "cuda" yields CUDA with a CPU fallback, "tensorrt" yields TensorRT
    (configured from the TRT_* settings) followed by CUDA and CPU, and any
    other value yields CPU only.
    """
    if ASR_PROVIDER == "cuda":
        return ["CUDAExecutionProvider", "CPUExecutionProvider"]
    if ASR_PROVIDER != "tensorrt":
        return ["CPUExecutionProvider"]

    # The import is attempted only for its side effects; a missing package
    # is reported but does not prevent requesting the TensorRT provider.
    try:
        import tensorrt_libs  # noqa: F401
    except ImportError:
        logger.warning("tensorrt_libs not available, will try TensorRT anyway")

    trt_options = {
        "trt_max_workspace_size": TRT_MAX_WORKSPACE_GB * 1024**3,
        "trt_fp16_enable": TRT_FP16_ENABLE,
    }
    return [
        ("TensorrtExecutionProvider", trt_options),
        "CUDAExecutionProvider",
        "CPUExecutionProvider",
    ]


def _ensure_onnx_export():
    """Export the .nemo model to ONNX unless a finished export is cached.

    The export directory lives under ONNX_CACHE_DIR and is named after
    NEMO_REPO_ID. ``config.json`` is written last and doubles as the
    "export finished" marker, so an interrupted export is redone on the next
    call.

    Returns:
        The export directory as a string.
    """
    onnx_dir = ONNX_CACHE_DIR / NEMO_REPO_ID.replace("/", "_")
    marker = onnx_dir / "config.json"

    if marker.exists():
        logger.info(f"ONNX export found at {onnx_dir}, skipping export.")
        return str(onnx_dir)

    logger.info(f"No ONNX export cached. Exporting {NEMO_REPO_ID}/{NEMO_FILENAME}...")
    onnx_dir.mkdir(parents=True, exist_ok=True)

    # Heavy dependencies are imported only when an export is really needed.
    from huggingface_hub import hf_hub_download
    from nemo.collections.asr.models import ASRModel

    # Download .nemo checkpoint
    nemo_path = hf_hub_download(repo_id=NEMO_REPO_ID, filename=NEMO_FILENAME)
    logger.info(f"Downloaded .nemo to {nemo_path}, loading model for export...")

    # Load on CPU to minimise GPU memory during export
    model = ASRModel.restore_from(nemo_path, map_location="cpu")
    model.eval()

    # Export to ONNX
    onnx_path = str(onnx_dir / "model.onnx")
    logger.info(f"Exporting to ONNX: {onnx_path}")
    model.export(onnx_path)

    # NeMo produces model_encoder.onnx + model_decoder_joint.onnx
    # onnx-asr expects encoder-model.onnx + decoder_joint-model.onnx
    renames = {
        "model_encoder.onnx": "encoder-model.onnx",
        "model_decoder_joint.onnx": "decoder_joint-model.onnx",
        "model_encoder.onnx.data": "encoder-model.onnx.data",
    }
    for src_name, dst_name in renames.items():
        src = onnx_dir / src_name
        if src.exists():
            src.rename(onnx_dir / dst_name)
            logger.info(f"Renamed {src_name} -> {dst_name}")

    # Write vocab.txt as "<token> <index>" lines with the blank token last.
    # The encoding is pinned to UTF-8 because tokens may contain non-ASCII
    # characters and the platform default encoding is not guaranteed to
    # represent them.
    vocab = [*model.tokenizer.vocab, "<blk>"]
    vocab_path = onnx_dir / "vocab.txt"
    with vocab_path.open("wt", encoding="utf-8") as f:
        for i, token in enumerate(vocab):
            f.write(f"{token} {i}\n")
    logger.info(f"Wrote vocab ({len(vocab)} tokens) to {vocab_path}")

    # Write config.json (written last - acts as completion marker)
    config = {
        "model_type": "nemo-conformer-tdt",
        "features_size": 128,
        "subsampling_factor": 8,
        "max_tokens_per_step": 10,
    }
    with marker.open("w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)
    logger.info(f"Wrote config.json to {marker}")

    # Free NeMo/torch memory before loading with onnx-asr
    del model
    gc.collect()
    try:
        import torch
        torch.cuda.empty_cache()
    except Exception:
        # Best effort only: the export is already complete at this point.
        pass

    logger.info(f"ONNX export complete at {onnx_dir}")
    return str(onnx_dir)


def load_model():
    """Load the ASR model into the global ``asr_model`` if not loaded yet.

    Returns immediately when a model is already loaded. When another load is
    flagged as in progress, waits up to 120 seconds for it and then returns
    without loading. Otherwise the model is loaded (exporting to ONNX first
    when ASR_MODEL_PATH is unset), optionally wrapped with VAD, published to
    the global and warmed up.

    The global is assigned only after the recognizer is fully configured, so
    a failure while setting up VAD does not leave a model without VAD in
    place while the error is reported.

    Raises:
        Exception: any error raised while exporting or loading the model
            (logged as critical and re-raised). Warmup errors are only
            logged.
    """
    global asr_model, model_loading

    if asr_model is not None:
        return

    if model_loading:
        max_wait = 120
        waited = 0
        while model_loading and waited < max_wait:
            time.sleep(0.5)
            waited += 0.5
        return

    model_loading = True
    try:
        import onnx_asr

        # If model path not set, ensure ONNX export exists
        model_path = ASR_MODEL_PATH or _ensure_onnx_export()

        providers = _build_providers()
        logger.info(
            f"Loading ASR model: {ASR_MODEL_NAME} from {model_path} "
            f"(quantization={ASR_QUANTIZATION}, providers={[p if isinstance(p, str) else p[0] for p in providers]})"
        )

        model = onnx_asr.load_model(
            ASR_MODEL_NAME,
            path=model_path,
            quantization=ASR_QUANTIZATION,
            providers=providers,
        )

        # Use timestamps adapter for segment-level results
        if USE_VAD:
            vad = onnx_asr.load_vad("silero", providers=["CPUExecutionProvider"])
            recognizer = model.with_vad(vad).with_timestamps()
            logger.info("VAD (Silero) enabled for long audio support.")
        else:
            recognizer = model.with_timestamps()

        # Publish only the fully configured recognizer.
        asr_model = recognizer
        logger.info("ASR model loaded successfully.")

        # Warmup
        warmup_audio_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), "flo.wav"
        )
        warmup_iterations = 3
        if os.path.exists(warmup_audio_path):
            logger.info(f"Performing {warmup_iterations} warmup transcriptions...")
            for i in range(warmup_iterations):
                try:
                    warmup_start = time.time()
                    # Extract the text on every pass so that a lazily
                    # evaluated result is consumed and the run really happens.
                    text = _extract_text(recognizer.recognize(warmup_audio_path))
                    warmup_time = time.time() - warmup_start
                    if i == 0 or i == warmup_iterations - 1:
                        logger.info(
                            f"Warmup {i+1}/{warmup_iterations}: {warmup_time:.2f}s - '{text[:80]}'"
                        )
                except Exception as e:
                    logger.warning(f"Warmup {i+1}/{warmup_iterations} failed (non-fatal): {e}")
        else:
            logger.warning(f"Warmup audio not found at {warmup_audio_path}, skipping.")

    except Exception as e:
        logger.critical(f"FATAL: Could not load ASR model. Error: {e}")
        raise
    finally:
        # Always clear the flag so waiting callers and later retries proceed.
        model_loading = False


def _extract_text(result):
    """Extract text from various onnx-asr result types."""
    if isinstance(result, str):
        return result
    if hasattr(result, "text"):
        return result.text
    # VAD iterator result
    parts = []
    try:
        for seg in result:
            if hasattr(seg, "text"):
                parts.append(seg.text)
            elif isinstance(seg, str):
                parts.append(seg)
    except TypeError:
        return str(result)
    return " ".join(parts)


def _result_to_segments(result):
    """Convert onnx-asr result to OpenAI-compatible segments list."""
    segments = []

    # Check if it's an iterator (VAD segments)
    items = []
    try:
        if hasattr(result, "__iter__") and not isinstance(result, str) and not hasattr(result, "text"):
            items = list(result)
        else:
            items = [result]
    except TypeError:
        items = [result]

    for idx, item in enumerate(items):
        if hasattr(item, "start") and hasattr(item, "end"):
            # SegmentResult / TimestampedSegmentResult from VAD
            segments.append({
                "id": idx,
                "start": round(item.start, 3),
                "end": round(item.end, 3),
                "text": item.text.strip() if hasattr(item, "text") else "",
                "seek": 0,
                "tokens": list(item.tokens) if hasattr(item, "tokens") and item.tokens else [],
                "temperature": 0.0,
                "avg_logprob": None,
                "compression_ratio": None,
                "no_speech_prob": None,
            })
        elif hasattr(item, "timestamps") and item.timestamps:
            # TimestampedResult without VAD β€” build segments from token timestamps
            segments.append({
                "id": idx,
                "start": round(item.timestamps[0], 3) if item.timestamps else 0.0,
                "end": round(item.timestamps[-1], 3) if item.timestamps else 0.0,
                "text": item.text.strip() if hasattr(item, "text") else "",
                "seek": 0,
                "tokens": list(item.tokens) if hasattr(item, "tokens") and item.tokens else [],
                "temperature": 0.0,
                "avg_logprob": None,
                "compression_ratio": None,
                "no_speech_prob": None,
            })
        elif hasattr(item, "text"):
            # Plain text result
            segments.append({
                "id": idx,
                "start": 0.0,
                "end": 0.0,
                "text": item.text.strip(),
                "seek": 0,
                "tokens": [],
                "temperature": 0.0,
                "avg_logprob": None,
                "compression_ratio": None,
                "no_speech_prob": None,
            })

    return segments


@app.get("/health")
async def health_check(deep: bool = False):
    """Report service health.

    The shallow check (default) returns the status dict, marked "healthy"
    when a model is loaded and "degraded" otherwise. With ``deep`` set, the
    model is loaded if needed and the bundled ``flo.wav`` is transcribed;
    a missing model, a transcript shorter than three characters or any
    exception yields a 503 response with status "unhealthy". A missing
    audio file yields status "degraded".
    """
    base_status = {
        "model_loaded": asr_model is not None,
        "model_name": ASR_MODEL_NAME,
        "provider": ASR_PROVIDER,
        "quantization": ASR_QUANTIZATION,
        "vad_enabled": USE_VAD,
        "batch": batch_processor.stats if batch_processor else None,
    }

    def _unhealthy(reason):
        # Shared 503 response for every failing branch of the deep check.
        base_status["status"] = "unhealthy"
        base_status["error"] = reason
        return JSONResponse(content=base_status, status_code=503)

    if not deep:
        if asr_model:
            base_status["status"] = "healthy"
        else:
            base_status["status"] = "degraded"
        return base_status

    try:
        if not asr_model:
            load_model()
        if not asr_model:
            return _unhealthy("Model not loaded")

        here = os.path.dirname(os.path.abspath(__file__))
        probe_audio = os.path.join(here, "flo.wav")
        if not os.path.exists(probe_audio):
            base_status["status"] = "degraded"
            base_status["error"] = "Health check audio file not found"
            return base_status

        began = time.time()
        result = asr_model.recognize(probe_audio)
        transcription_time = round(time.time() - began, 3)

        text = _extract_text(result)
        if len(text) < 3:
            return _unhealthy("Transcription returned empty or too short result")

        base_status["status"] = "healthy"
        base_status["transcription_test"] = {
            "success": True,
            "text_length": len(text),
            "transcription_time_seconds": transcription_time,
        }
        return base_status

    except Exception as e:
        logger.error(f"Deep health check failed: {e}")
        return _unhealthy(str(e))


@app.post("/v1/audio/transcriptions")
async def transcribe_rest(file: UploadFile = File(...)):
    """Handle audio transcription via REST API (OpenAI compatible).

    The upload is decoded with soundfile, mixed down to mono when it has two
    dimensions, and submitted to the batch processor; the resulting dict is
    returned as JSON.

    Raises:
        HTTPException: 503 when the model or batch processor is unavailable,
            400 when the upload cannot be decoded or contains no samples,
            429 (from the batch processor) when the queue is full,
            500 for any other failure.
    """
    if not asr_model:
        try:
            load_model()
        except Exception as e:
            # load_model() has already logged the failure; answer with the
            # same 503 used below rather than an unhandled server error.
            raise HTTPException(status_code=503, detail="ASR model not available.") from e
    if not asr_model:
        raise HTTPException(status_code=503, detail="ASR model not available.")
    if not batch_processor:
        raise HTTPException(status_code=503, detail="Batch processor not ready.")

    logger.info(f"transcribe_rest: Received request for file: {file.filename}")

    try:
        # Read audio bytes and decode with soundfile (handles wav, flac, ogg, etc.)
        audio_bytes = await file.read()
        try:
            waveform, sample_rate = sf.read(io.BytesIO(audio_bytes), dtype="float32")
        except (RuntimeError, ValueError, TypeError) as e:
            # An undecodable upload is a client error, not a server fault.
            logger.warning(f"transcribe_rest: Could not decode '{file.filename}': {e}")
            raise HTTPException(
                status_code=400, detail=f"Could not decode audio file: {e}"
            ) from e

        # Convert stereo to mono if needed
        if waveform.ndim == 2:
            waveform = waveform.mean(axis=1)

        if len(waveform) == 0:
            raise HTTPException(status_code=400, detail="Audio file contains no samples.")

        logger.info(
            f"transcribe_rest: Audio loaded, {len(waveform)} samples, {sample_rate}Hz, "
            f"{len(waveform) / sample_rate:.1f}s — submitting to batch queue "
            f"(depth={batch_processor.queue_size})"
        )

        response = await batch_processor.submit(
            waveform, sample_rate, file.filename or "unknown"
        )

        logger.info(
            f"transcribe_rest: Completed for '{file.filename}', "
            f"{response['transcription_time']}s inference, "
            f"{response.get('queue_wait_time', 0)}s queued"
        )

        return JSONResponse(content=response)

    except HTTPException:
        # Deliberate HTTP errors (400/429/503) pass through unchanged.
        raise
    except Exception as e:
        logger.error(f"transcribe_rest: Unhandled exception: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


def main():
    """Run the app with uvicorn, configured from PORT and LOG_LEVEL.

    PORT defaults to 8000 and LOG_LEVEL to "warning"; the server binds to
    all interfaces.
    """
    import uvicorn

    server_options = {
        "host": "0.0.0.0",
        "port": int(os.getenv("PORT", "8000")),
        "log_level": os.getenv("LOG_LEVEL", "warning"),
        "workers": None,
        "forwarded_allow_ips": "*",
        "proxy_headers": True,
        "timeout_keep_alive": 900,
        "reload": False,
    }
    uvicorn.run(app, **server_options)


# Start the server only when this file is executed directly, not on import.
if __name__ == "__main__":
    main()

Sign up or log in to comment