import ctypes
import sys
from ctypes import POINTER, Structure, c_bool, c_char_p, c_float, c_int32, c_int64, c_size_t, c_uint8, c_uint32
from pathlib import Path
from typing import NoReturn

import numpy as np
|
|
|
|
class LiteAsrError(RuntimeError):
    """Error raised when the liteasr native library cannot be located or a native call fails."""
|
|
|
|
class _F32Buffer(Structure):
    """ctypes mirror of the native ``float32`` output buffer.

    Holds a pointer to the elements plus four ``size_t`` counters. ``len`` is
    the number of elements behind ``data``; ``rows``/``cols`` are the matrix
    dimensions the Python side reshapes the data into. ``cap`` is presumably
    the native allocation capacity (only forwarded back to the free function).
    Instances filled by the library are released with
    ``liteasr_free_f32_buffer``.
    """

    _fields_ = [("data", POINTER(c_float))] + [
        (counter, c_size_t) for counter in ("len", "cap", "rows", "cols")
    ]
|
|
|
|
class _I64Buffer(Structure):
    """ctypes mirror of the native ``int64`` output buffer.

    ``data`` points at ``len`` elements; ``cap`` is presumably the native
    allocation capacity. Instances filled by the library are released with
    ``liteasr_free_i64_buffer``.
    """

    _fields_ = [("data", POINTER(c_int64))] + [
        (counter, c_size_t) for counter in ("len", "cap")
    ]
|
|
|
|
class _Utf8Buffer(Structure):
    """ctypes mirror of the native byte buffer holding UTF-8 text.

    ``data`` points at ``len`` bytes (decoded as UTF-8 by the caller); ``cap``
    is presumably the native allocation capacity. Instances filled by the
    library are released with ``liteasr_free_utf8_buffer``.
    """

    _fields_ = [("data", POINTER(c_uint8))] + [
        (counter, c_size_t) for counter in ("len", "cap")
    ]
|
|
|
|
def _default_dll_candidates() -> list[Path]:
    """Return the Cargo build outputs to probe for the native library.

    Paths are relative to the current working directory and are ordered
    debug first, then release. The file name follows the platform's
    shared-library naming for a Rust ``cdylib``: ``liteasr_ffi.dll`` on
    Windows, ``libliteasr_ffi.dylib`` on macOS and ``libliteasr_ffi.so``
    elsewhere. (Previously only the ``.dll`` name was probed, so the default
    lookup could never succeed outside Windows.)
    """
    if sys.platform == "win32":
        lib_name = "liteasr_ffi.dll"
    elif sys.platform == "darwin":
        lib_name = "libliteasr_ffi.dylib"
    else:
        lib_name = "libliteasr_ffi.so"
    return [Path("target", profile, lib_name) for profile in ("debug", "release")]
|
|
|
|
class LiteAsrFfi:
    """Thin ctypes wrapper around the ``liteasr_ffi`` native library.

    Every wrapped call follows the same contract:

    - the native function returns a ``c_int32`` status code, where ``0`` means
      success;
    - on failure, the text from ``liteasr_last_error_message`` (or a fallback)
      is raised as :class:`LiteAsrError`;
    - any output buffer the library fills is copied into Python-owned memory
      and then released with the matching ``liteasr_free_*_buffer`` function.
    """

    def __init__(self, dll_path: str | Path | None = None) -> None:
        """Load the native library and declare its C signatures.

        Args:
            dll_path: Explicit path to the shared library. When ``None``, the
                paths from ``_default_dll_candidates()`` are tried in order and
                the first one that exists is used.

        Raises:
            LiteAsrError: If no library file can be found.
            OSError: Propagated from ``ctypes.CDLL`` if the file exists but
                cannot be loaded.
        """
        if dll_path is None:
            candidates = _default_dll_candidates()
            dll_path = next((c for c in candidates if c.exists()), None)
            if dll_path is None:
                # Name what was searched instead of hard-coding a ".dll" file name,
                # since the candidate names are platform dependent.
                searched = ", ".join(str(c) for c in candidates)
                raise LiteAsrError(
                    f"liteasr_ffi library not found (searched: {searched}). "
                    "Build with `cargo build`."
                )

        self._dll_path = Path(dll_path)
        if not self._dll_path.exists():
            raise LiteAsrError(f"DLL not found: {self._dll_path}")

        self._lib = ctypes.CDLL(str(self._dll_path))
        self._configure_signatures()

    @property
    def dll_path(self) -> Path:
        """Path of the loaded shared library, as resolved at construction."""
        return self._dll_path

    def _configure_signatures(self) -> None:
        """Declare ``argtypes``/``restype`` for every native function used here.

        Without explicit declarations ctypes assumes C ``int`` for arguments and
        return values, which would truncate 64-bit pointers and sizes.
        """
        lib = self._lib

        lib.liteasr_last_error_message.argtypes = []
        lib.liteasr_last_error_message.restype = c_char_p

        lib.liteasr_preprocess_wav.argtypes = [
            c_char_p,  # UTF-8 WAV path
            c_uint32,  # target sample rate
            c_size_t,  # number of mel bins
            POINTER(_F32Buffer),  # output buffer
        ]
        lib.liteasr_preprocess_wav.restype = c_int32

        lib.liteasr_build_prompt_ids.argtypes = [
            c_char_p,  # UTF-8 tokenizer.json path
            c_char_p,  # language
            c_char_p,  # task
            c_bool,  # with_timestamps
            c_bool,  # omit_language_token
            c_bool,  # omit_notimestamps_token
            POINTER(_I64Buffer),  # output buffer
        ]
        lib.liteasr_build_prompt_ids.restype = c_int32

        lib.liteasr_decode_tokens.argtypes = [
            c_char_p,  # UTF-8 tokenizer.json path
            POINTER(c_int64),  # token ids
            c_size_t,  # number of token ids
            c_bool,  # skip_special_tokens
            POINTER(_Utf8Buffer),  # output buffer
        ]
        lib.liteasr_decode_tokens.restype = c_int32

        lib.liteasr_apply_suppression.argtypes = [
            POINTER(c_float),  # logits (modified in place)
            c_size_t,  # number of logits
            POINTER(c_int64),  # suppress ids (may be NULL when empty)
            c_size_t,  # number of suppress ids
            POINTER(c_int64),  # begin-suppress ids (may be NULL when empty)
            c_size_t,  # number of begin-suppress ids
            c_size_t,  # decoding step
        ]
        lib.liteasr_apply_suppression.restype = c_int32

        lib.liteasr_free_f32_buffer.argtypes = [POINTER(_F32Buffer)]
        lib.liteasr_free_f32_buffer.restype = None
        lib.liteasr_free_i64_buffer.argtypes = [POINTER(_I64Buffer)]
        lib.liteasr_free_i64_buffer.restype = None
        lib.liteasr_free_utf8_buffer.argtypes = [POINTER(_Utf8Buffer)]
        lib.liteasr_free_utf8_buffer.restype = None

    def _raise_last_error(self, fallback: str) -> NoReturn:
        """Raise the library's last error message, or *fallback* if it reports none.

        Raises:
            LiteAsrError: Always.
        """
        message = self._lib.liteasr_last_error_message()
        if message:
            raise LiteAsrError(message.decode("utf-8", errors="replace"))
        raise LiteAsrError(fallback)

    def preprocess_wav(self, wav_path: str | Path, target_sr: int, n_mels: int) -> np.ndarray:
        """Run the native WAV preprocessing and return its feature matrix.

        Args:
            wav_path: Path to the WAV file.
            target_sr: Target sample rate, passed through to the library.
            n_mels: Number of mel bins, passed through to the library.

        Returns:
            A Python-owned ``float32`` array of shape ``(rows, cols)`` as
            reported by the library.

        Raises:
            LiteAsrError: If the native call fails, or if the returned buffer's
                ``rows * cols`` does not match its length.
        """
        out = _F32Buffer()
        rc = self._lib.liteasr_preprocess_wav(
            str(Path(wav_path)).encode("utf-8"),
            int(target_sr),
            int(n_mels),
            ctypes.byref(out),
        )
        if rc != 0:
            self._raise_last_error("liteasr_preprocess_wav failed")

        try:
            rows, cols, length = out.rows, out.cols, out.len
            if rows * cols != length:
                raise LiteAsrError(
                    f"liteasr_preprocess_wav returned {length} values for a {rows}x{cols} matrix"
                )
            if length == 0:
                # Nothing here guarantees a non-NULL data pointer for an empty
                # buffer, and numpy cannot wrap a NULL pointer.
                return np.zeros((rows, cols), dtype=np.float32)
            flat = np.ctypeslib.as_array(out.data, shape=(length,))
            # Copy before the finally block releases the native memory.
            return flat.reshape((rows, cols)).copy()
        finally:
            self._lib.liteasr_free_f32_buffer(ctypes.byref(out))

    def build_prompt_ids(
        self,
        tokenizer_json_path: str | Path,
        language: str,
        task: str,
        with_timestamps: bool,
        omit_language_token: bool,
        omit_notimestamps_token: bool,
    ) -> list[int]:
        """Build the decoder prompt token ids natively.

        All arguments are forwarded unchanged (strings as UTF-8) to
        ``liteasr_build_prompt_ids``.

        Returns:
            The prompt token ids as a list of Python ints (empty if the
            library returns no ids).

        Raises:
            LiteAsrError: If the native call fails.
        """
        out = _I64Buffer()
        rc = self._lib.liteasr_build_prompt_ids(
            str(Path(tokenizer_json_path)).encode("utf-8"),
            language.encode("utf-8"),
            task.encode("utf-8"),
            bool(with_timestamps),
            bool(omit_language_token),
            bool(omit_notimestamps_token),
            ctypes.byref(out),
        )
        if rc != 0:
            self._raise_last_error("liteasr_build_prompt_ids failed")

        try:
            if out.len == 0:
                # Avoid wrapping a possibly-NULL pointer for an empty result.
                return []
            # tolist() already yields Python ints and copies out of native memory.
            return np.ctypeslib.as_array(out.data, shape=(out.len,)).tolist()
        finally:
            self._lib.liteasr_free_i64_buffer(ctypes.byref(out))

    def decode_tokens(
        self,
        tokenizer_json_path: str | Path,
        token_ids: list[int],
        skip_special_tokens: bool = True,
    ) -> str:
        """Decode token ids to text using the native tokenizer.

        Args:
            tokenizer_json_path: Path to the tokenizer JSON file.
            token_ids: Token ids to decode.
            skip_special_tokens: Forwarded to the library.

        Returns:
            The decoded text; invalid UTF-8 bytes are replaced rather than
            raising.

        Raises:
            LiteAsrError: If the native call fails.
        """
        out = _Utf8Buffer()
        # Keep `token_np` bound for the duration of the call: the library only
        # receives a raw pointer into its memory.
        token_np = np.array(token_ids, dtype=np.int64)
        rc = self._lib.liteasr_decode_tokens(
            str(Path(tokenizer_json_path)).encode("utf-8"),
            token_np.ctypes.data_as(POINTER(c_int64)),
            int(token_np.shape[0]),
            bool(skip_special_tokens),
            ctypes.byref(out),
        )
        if rc != 0:
            self._raise_last_error("liteasr_decode_tokens failed")

        try:
            if out.len == 0:
                return ""
            return ctypes.string_at(out.data, out.len).decode("utf-8", errors="replace")
        finally:
            self._lib.liteasr_free_utf8_buffer(ctypes.byref(out))

    @staticmethod
    def _id_array(ids: list[int] | np.ndarray | None) -> np.ndarray:
        """Return *ids* as a contiguous 1-D ``int64`` array (empty for ``None``).

        An explicit ``None`` check is used instead of ``ids or []`` because the
        truth value of a multi-element NumPy array is ambiguous and raises.
        """
        if ids is None:
            return np.empty(0, dtype=np.int64)
        return np.ascontiguousarray(ids, dtype=np.int64).reshape(-1)

    @staticmethod
    def _id_pointer(ids: np.ndarray):
        """Return an ``int64*`` into *ids*, or a NULL pointer when *ids* is empty."""
        if ids.size == 0:
            return ctypes.cast(0, POINTER(c_int64))
        return ids.ctypes.data_as(POINTER(c_int64))

    def apply_suppression(
        self,
        logits: np.ndarray,
        suppress_ids: list[int],
        begin_suppress_ids: list[int],
        step: int,
    ) -> np.ndarray:
        """Apply token suppression to *logits* via the native library.

        The native function writes into the buffer it receives. When *logits*
        is already a writeable, C-contiguous ``float32`` ndarray, that exact
        object is modified in place and returned. Otherwise a converted copy is
        modified and returned, and the caller's object is left untouched.

        Args:
            logits: Logit scores. Only ``logits.shape[0]`` is passed as the
                length. NOTE(review): this assumes a 1-D vocabulary vector;
                confirm the intended behavior for batched input.
            suppress_ids: Token ids to suppress (``None`` or empty for none).
            begin_suppress_ids: Additional ids to suppress (``None`` or empty
                for none). Presumably only applied at the first step; check
                against the native implementation.
            step: Decoding step, forwarded to the library.

        Returns:
            The array the library modified (see above).

        Raises:
            LiteAsrError: If the native call fails.
        """
        arr = logits if isinstance(logits, np.ndarray) else np.asarray(logits)
        if arr.dtype != np.float32 or not arr.flags["C_CONTIGUOUS"]:
            arr = np.ascontiguousarray(arr, dtype=np.float32)
        if not arr.flags["WRITEABLE"]:
            # The library writes through the pointer; never hand it read-only memory.
            arr = arr.copy()

        # Bound to locals so the id buffers stay alive for the whole native call.
        suppress_np = self._id_array(suppress_ids)
        begin_np = self._id_array(begin_suppress_ids)

        rc = self._lib.liteasr_apply_suppression(
            arr.ctypes.data_as(POINTER(c_float)),
            int(arr.shape[0]),
            self._id_pointer(suppress_np),
            int(suppress_np.size),
            self._id_pointer(begin_np),
            int(begin_np.size),
            int(step),
        )
        if rc != 0:
            self._raise_last_error("liteasr_apply_suppression failed")
        return arr
|
|