import logging
import os
import base64

import torch
from pyannote.audio import Pipeline
from transformers import pipeline, AutoModelForCausalLM
from huggingface_hub import HfApi
from pydantic import ValidationError
from starlette.exceptions import HTTPException

from config import model_settings, InferenceConfig

logger = logging.getLogger(__name__)


class EndpointHandler:
    """Inference-endpoint handler for automatic speech recognition (ASR).

    Loads a transformers ASR pipeline once at construction and serves
    requests via ``__call__``, which expects a payload of the form::

        {"inputs": "<base64-encoded audio bytes>",
         "parameters": {...}}   # optional, validated against InferenceConfig

    Validation errors and recoverable inference errors are surfaced as
    ``HTTPException(400)``; unexpected failures as ``HTTPException(500)``.
    """

    def __init__(self, path=""):
        """Build the ASR pipeline on the best available device.

        Args:
            path: Unused; kept for compatibility with the Hugging Face
                custom-handler calling convention.
        """
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        logger.info("Using device: %s", device.type)
        # fp16 halves memory/compute on GPU; CPU kernels generally need fp32.
        torch_dtype = torch.float32 if device.type == "cpu" else torch.float16
        self.asr_pipeline = pipeline(
            "automatic-speech-recognition",
            model=model_settings.asr_model,
            torch_dtype=torch_dtype,
            device=device,
        )

    def __call__(self, inputs):
        """Run ASR on one request payload.

        Args:
            inputs: Request dict; ``inputs["inputs"]`` is base64-encoded
                audio, ``inputs["parameters"]`` (optional) is a dict of
                InferenceConfig fields. Note: both keys are popped, so the
                caller's dict is mutated.

        Returns:
            dict with ``"text"`` (full transcript) and ``"chunks"``
            (timestamped segments from the pipeline).

        Raises:
            HTTPException: 400 on invalid parameters or an ASR RuntimeError;
                500 on any other inference failure.
        """
        # Decode the payload up front; keep the raw b64 string out of scope.
        audio_bytes = base64.b64decode(inputs.pop("inputs"))

        raw_parameters = inputs.pop("parameters", {})
        try:
            parameters = InferenceConfig(**raw_parameters)
        except ValidationError as e:
            logger.error("Error validating parameters: %s", e)
            # Chain the cause explicitly so the pydantic error survives in
            # tracebacks (PEP 3134).
            raise HTTPException(
                status_code=400, detail=f"Error validating parameters: {e}"
            ) from e

        logger.info("inference parameters: %s", parameters)

        generate_kwargs = {
            "task": parameters.task,
            "language": parameters.language,
        }

        try:
            asr_outputs = self.asr_pipeline(
                audio_bytes,
                chunk_length_s=parameters.chunk_length_s,
                batch_size=parameters.batch_size,
                generate_kwargs=generate_kwargs,
                return_timestamps=True,
            )
        except RuntimeError as e:
            # Typically bad/undecodable audio or an unsupported parameter
            # combination — treated as a client error here.
            logger.error("ASR inference error: %s", e)
            raise HTTPException(
                status_code=400, detail=f"ASR inference error: {str(e)}"
            ) from e
        except Exception as e:
            logger.error("Unknown error during ASR inference: %s", e)
            raise HTTPException(
                status_code=500, detail=f"Unknown error during ASR inference: {str(e)}"
            ) from e

        return {
            "chunks": asr_outputs["chunks"],
            "text": asr_outputs["text"],
        }