# asr/handler.py
import logging
import torch
import os
import base64

from pyannote.audio import Pipeline
from transformers import pipeline, AutoModelForCausalLM
from huggingface_hub import HfApi
from pydantic import ValidationError
from starlette.exceptions import HTTPException

from config import model_settings, InferenceConfig

logger = logging.getLogger(__name__)

class EndpointHandler():
    def __init__(self, path=""):
        # Run on GPU when available; fall back to fp32 on CPU, where fp16 is poorly supported.
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        logger.info(f"Using device: {device.type}")
        torch_dtype = torch.float32 if device.type == "cpu" else torch.float16

        # Load the ASR model configured in config.py once, at endpoint start-up.
        self.asr_pipeline = pipeline(
            "automatic-speech-recognition",
            model=model_settings.asr_model,
            torch_dtype=torch_dtype,
            device=device,
        )
    def __call__(self, inputs):
        # The request carries the audio as a base64-encoded payload under "inputs".
        file = inputs.pop("inputs")
        file = base64.b64decode(file)

        # Validate optional inference parameters against the pydantic schema.
        parameters = inputs.pop("parameters", {})
        try:
            parameters = InferenceConfig(**parameters)
        except ValidationError as e:
            logger.error(f"Error validating parameters: {e}")
            raise HTTPException(status_code=400, detail=f"Error validating parameters: {e}")

        logger.info(f"inference parameters: {parameters}")

        generate_kwargs = {
            "task": parameters.task,
            "language": parameters.language,
        }
        try:
            asr_outputs = self.asr_pipeline(
                file,
                chunk_length_s=parameters.chunk_length_s,
                batch_size=parameters.batch_size,
                generate_kwargs=generate_kwargs,
                return_timestamps=True,
            )
        except RuntimeError as e:
            # Runtime errors typically stem from malformed audio or unsupported parameters.
            logger.error(f"ASR inference error: {str(e)}")
            raise HTTPException(status_code=400, detail=f"ASR inference error: {str(e)}")
        except Exception as e:
            logger.error(f"Unknown error during ASR inference: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Unknown error during ASR inference: {str(e)}")
        # Return the full transcript along with per-chunk timestamps.
        return {
            "chunks": asr_outputs["chunks"],
            "text": asr_outputs["text"],
        }