"""
HuggingFace Inference Endpoint handler for Kurdish/Persian Whisper ASR.
Accepts audio (binary, base64, or filepath) and returns transcribed text.
Default model: whisper-largev3 full fine-tune.
"""
import base64
import gc
import io
import logging
from pathlib import Path
import numpy as np
import torch
import torchaudio
from transformers import WhisperForConditionalGeneration, WhisperProcessor
log = logging.getLogger(__name__)
SAMPLE_RATE = 16_000
CHUNK_SECONDS = 30
CHUNK_SAMPLES = CHUNK_SECONDS * SAMPLE_RATE
MODELS = {
    "small": Path(__file__).parent / "models" / "whisper-small-peft-kurdish-on-persian-converted",
    "full": Path(__file__).parent / "models" / "whisper-largev3-on-persian-centralkurdish-full",
}
DEFAULT_MODEL = "full"
# ---------------------------------------------------------------------------
# Audio helpers
# ---------------------------------------------------------------------------
def _audio_bytes_to_numpy(raw: bytes) -> np.ndarray:
    """Convert raw audio bytes to float32 mono 16 kHz numpy array.

    Uses torchaudio (in-memory) instead of shelling out to ffmpeg.
    """
    buf = io.BytesIO(raw)
    waveform, sr = torchaudio.load(buf)  # (channels, samples)
    # Mix to mono.
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    # Resample if needed.
    if sr != SAMPLE_RATE:
        waveform = torchaudio.functional.resample(waveform, sr, SAMPLE_RATE)
    return waveform.squeeze(0).numpy()


def _chunk(audio: np.ndarray) -> list[np.ndarray]:
    if len(audio) <= CHUNK_SAMPLES:
        return [audio]
    return [audio[i : i + CHUNK_SAMPLES] for i in range(0, len(audio), CHUNK_SAMPLES)]
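

# Worked example of the chunking above (illustrative numbers, not from the
# original file): a 75-second clip at 16 kHz is 1_200_000 samples, so _chunk()
# yields slices of 480_000, 480_000 and 240_000 samples; the short final slice
# is zero-padded back to a full 30 s window in EndpointHandler._transcribe_batched.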
# ---------------------------------------------------------------------------
# Endpoint handler
# ---------------------------------------------------------------------------
class EndpointHandler:
    """
    HuggingFace Inference Endpoint handler.

    Request format:
        {
            "inputs": <base64-encoded audio OR raw bytes>,
            "parameters": {
                "model": "full" | "small",  # default: "full"
                "language": "fa"            # default: "fa"
            }
        }

    Response format:
        {"text": "transcribed text here"}
    """

    def __init__(self, path: str = ""):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._model: WhisperForConditionalGeneration | None = None
        self._processor: WhisperProcessor | None = None
        self._loaded_name: str | None = None
        self._dtype = torch.float32
        # If HF Inference Endpoint provides a path with model files, use it.
        if path and (Path(path) / "config.json").exists():
            MODELS["full"] = Path(path)
        self._load(DEFAULT_MODEL)

    def __call__(self, data: dict) -> dict:
        inputs = data.get("inputs")
        params = data.get("parameters", {}) or {}
        model_name = params.get("model", DEFAULT_MODEL)
        language = params.get("language", "fa")
        if not inputs:
            return {"error": "No audio provided in 'inputs'."}
        if model_name != self._loaded_name:
            self._load(model_name)
        audio = self._resolve_audio(inputs)
        text = self._transcribe(audio, language)
        return {"text": text}

    # ------------------------------------------------------------------
    # Model lifecycle
    # ------------------------------------------------------------------
    def _load(self, name: str):
        if name not in MODELS:
            raise ValueError(f"Unknown model '{name}'. Choose from: {list(MODELS.keys())}")
        if name == self._loaded_name:
            return
        self._unload()
        model_path = str(MODELS[name])
        is_cuda = self.device.type == "cuda"
        self._processor = WhisperProcessor.from_pretrained(model_path)  # type: ignore[assignment]
        # Try optimal load: flash attention 2 + float16 on CUDA.
        model = self._load_model(model_path, is_cuda)
        model.config.use_cache = True
        model.generation_config.forced_decoder_ids = None
        if not is_cuda and next(model.parameters()).device.type != "cpu":
            model.to(self.device)  # type: ignore[arg-type]
        model.eval()
        # BetterTransformer fallback when Flash Attention is unavailable.
        if is_cuda and getattr(model.config, "_attn_implementation", None) != "flash_attention_2":
            try:
                model = model.to_bettertransformer()  # type: ignore[assignment]
                log.info("Using BetterTransformer (SDPA kernels).")
            except Exception:
                log.info("BetterTransformer unavailable, using default attention.")
        # torch.compile for graph-level optimization (warmup on first call).
        if is_cuda and hasattr(torch, "compile"):
            try:
                model = torch.compile(model, mode="reduce-overhead")  # type: ignore[assignment]
                log.info("Model compiled with torch.compile (reduce-overhead).")
            except Exception:
                log.info("torch.compile unavailable, skipping.")
        self._model = model
        self._dtype = torch.float16 if is_cuda else torch.float32
        self._loaded_name = name

    def _load_model(
        self, model_path: str, is_cuda: bool,
    ) -> WhisperForConditionalGeneration:
        """Load model with best available acceleration, falling back gracefully."""
        # Attempt 1: Flash Attention 2 + float16 (requires Ampere / sm_80+).
        can_flash = (
            is_cuda
            and torch.cuda.get_device_capability()[0] >= 8
        )
        if can_flash:
            try:
                return WhisperForConditionalGeneration.from_pretrained(
                    model_path,
                    torch_dtype=torch.float16,
                    attn_implementation="flash_attention_2",
                    device_map="auto",
                )
            except (ImportError, ValueError, RuntimeError) as exc:
                log.info("Flash Attention 2 unavailable (%s), trying standard load.", exc)
        # Attempt 2: Standard CUDA load (float16, auto device map).
        if is_cuda:
            try:
                return WhisperForConditionalGeneration.from_pretrained(
                    model_path,
                    torch_dtype=torch.float16,
                    device_map="auto",
                )
            except (ImportError, ValueError, RuntimeError) as exc:
                log.info("Auto device_map failed (%s), falling back to manual.", exc)
        # Attempt 3: Manual load (CPU or CUDA without device_map).
        dtype = torch.float16 if is_cuda else torch.float32
        model = WhisperForConditionalGeneration.from_pretrained(
            model_path,
            quantization_config=None,
            torch_dtype=dtype,
            low_cpu_mem_usage=True,
        )
        model.to(self.device)  # type: ignore[arg-type]
        return model

    def _unload(self):
        del self._model, self._processor
        self._model = None
        self._processor = None
        self._loaded_name = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # ------------------------------------------------------------------
    # Audio resolution
    # ------------------------------------------------------------------
    def _resolve_audio(self, inputs) -> np.ndarray:  # type: ignore[override]
        """Accept base64 string or raw bytes."""
        if isinstance(inputs, str):
            raw = base64.b64decode(inputs)
        elif isinstance(inputs, bytes):
            raw = inputs
        else:
            raise ValueError("'inputs' must be base64-encoded string or raw bytes.")
        return _audio_bytes_to_numpy(raw)

    # ------------------------------------------------------------------
    # Inference
    # ------------------------------------------------------------------
    def _transcribe(self, audio: np.ndarray, language: str) -> str:
        assert self._model is not None and self._processor is not None
        chunks = _chunk(audio)
        # Batch all chunks into a single forward pass.
        if len(chunks) > 1:
            return self._transcribe_batched(chunks, language)
        return self._transcribe_single(chunks[0], language)

    def _transcribe_single(self, audio: np.ndarray, language: str) -> str:
        assert self._model is not None and self._processor is not None
        features = self._processor(  # type: ignore[operator]
            audio, sampling_rate=SAMPLE_RATE, return_tensors="pt",
        )
        input_features = features.input_features.to(self.device, dtype=self._dtype)
        with torch.no_grad(), torch.autocast(
            self.device.type, dtype=torch.float16, enabled=self.device.type == "cuda",
        ):
            ids = self._model.generate(
                input_features,
                language=language,
                task="transcribe",
                max_new_tokens=440,
            )
        return self._processor.batch_decode(  # type: ignore[union-attr]
            ids, skip_special_tokens=True,
        )[0].strip()

    def _transcribe_batched(self, chunks: list[np.ndarray], language: str) -> str:
        assert self._model is not None and self._processor is not None
        # Pad shorter chunks to 30s so mel features align for stacking.
        padded = []
        for c in chunks:
            if len(c) < CHUNK_SAMPLES:
                c = np.pad(c, (0, CHUNK_SAMPLES - len(c)))
            padded.append(c)
        features = self._processor(  # type: ignore[operator]
            padded, sampling_rate=SAMPLE_RATE, return_tensors="pt", padding=True,
        )
        input_features = features.input_features.to(self.device, dtype=self._dtype)
        with torch.no_grad(), torch.autocast(
            self.device.type, dtype=torch.float16, enabled=self.device.type == "cuda",
        ):
            ids = self._model.generate(
                input_features,
                language=language,
                task="transcribe",
                max_new_tokens=440,
            )
        texts = self._processor.batch_decode(  # type: ignore[union-attr]
            ids, skip_special_tokens=True,
        )
        return " ".join(t.strip() for t in texts if t.strip())