| import io |
| import os |
| import time |
| from http import HTTPStatus |
|
|
| import numpy as np |
| import ormsgpack |
| import soundfile as sf |
| import torch |
| from kui.asgi import ( |
| Body, |
| HTTPException, |
| HttpView, |
| JSONResponse, |
| Routes, |
| StreamResponse, |
| request, |
| ) |
| from loguru import logger |
| from typing_extensions import Annotated |
|
|
| from fish_speech.utils.schema import ( |
| ServeASRRequest, |
| ServeASRResponse, |
| ServeChatRequest, |
| ServeTTSRequest, |
| ServeVQGANDecodeRequest, |
| ServeVQGANDecodeResponse, |
| ServeVQGANEncodeRequest, |
| ServeVQGANEncodeResponse, |
| ) |
| from tools.server.agent import get_response_generator |
| from tools.server.api_utils import ( |
| buffer_to_async_generator, |
| get_content_type, |
| inference_async, |
| ) |
| from tools.server.inference import inference_wrapper as inference |
| from tools.server.model_manager import ModelManager |
| from tools.server.model_utils import ( |
| batch_asr, |
| batch_vqgan_decode, |
| cached_vqgan_batch_encode, |
| ) |
|
|
# Upper bound for ServeChatRequest.num_samples in /v1/chat; configured via the
# NUM_SAMPLES environment variable (defaults to 1).
# NOTE(review): the env var name lacks the MAX_ prefix of the constant —
# confirm the name is intended before renaming either side.
MAX_NUM_SAMPLES = int(os.getenv("NUM_SAMPLES", 1))


# Shared route table; every endpoint defined below registers itself here.
routes = Routes()
|
|
|
|
@routes.http("/v1/health")
class Health(HttpView):
    """Liveness endpoint: GET and POST both answer with a static OK payload."""

    @staticmethod
    def _ok():
        # Single place that builds the canonical health response.
        return JSONResponse({"status": "ok"})

    @classmethod
    async def get(cls):
        return cls._ok()

    @classmethod
    async def post(cls):
        return cls._ok()
|
|
|
|
@routes.http.post("/v1/vqgan/encode")
async def vqgan_encode(req: Annotated[ServeVQGANEncodeRequest, Body(exclusive=True)]):
    """Encode a batch of audio clips into VQGAN token sequences.

    Returns a msgpack-encoded ServeVQGANEncodeResponse whose ``tokens`` field
    holds one plain int list per input clip.
    """
    # Shared decoder model lives on the application state.
    manager: ModelManager = request.app.state.model_manager
    decoder = manager.decoder_model

    # Encode the whole batch, logging wall-clock latency in milliseconds.
    start_time = time.time()
    batch_tokens = cached_vqgan_batch_encode(decoder, req.audios)
    logger.info(f"[EXEC] VQGAN encode time: {(time.time() - start_time) * 1000:.2f}ms")

    # msgpack handles the pydantic model natively via OPT_SERIALIZE_PYDANTIC.
    return ormsgpack.packb(
        ServeVQGANEncodeResponse(tokens=[seq.tolist() for seq in batch_tokens]),
        option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
    )
|
|
|
|
@routes.http.post("/v1/vqgan/decode")
async def vqgan_decode(req: Annotated[ServeVQGANDecodeRequest, Body(exclusive=True)]):
    """Decode VQGAN token sequences back into raw float16 audio bytes.

    Returns a msgpack-encoded ServeVQGANDecodeResponse with one bytes payload
    per token sequence.
    """
    # Shared decoder model lives on the application state.
    manager: ModelManager = request.app.state.model_manager
    decoder = manager.decoder_model

    # Tokens arrive as plain int lists; turn each into an int tensor.
    token_tensors = [torch.tensor(seq, dtype=torch.int) for seq in req.tokens]

    # Decode the whole batch, logging wall-clock latency in milliseconds.
    start_time = time.time()
    decoded = batch_vqgan_decode(decoder, token_tensors)
    logger.info(f"[EXEC] VQGAN decode time: {(time.time() - start_time) * 1000:.2f}ms")

    # Ship each clip as raw float16 bytes to keep the payload compact.
    payloads = [clip.astype(np.float16).tobytes() for clip in decoded]

    return ormsgpack.packb(
        ServeVQGANDecodeResponse(audios=payloads),
        option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
    )
|
|
|
|
@routes.http.post("/v1/asr")
async def asr(req: Annotated[ServeASRRequest, Body(exclusive=True)]):
    """Transcribe a batch of float16-encoded audio clips.

    Each entry of ``req.audios`` is raw float16 PCM at ``req.sample_rate``;
    clips of 30 seconds or longer are rejected with 400. Returns a
    msgpack-encoded ServeASRResponse.
    """
    # Shared ASR model plus the lock that serializes access to it.
    model_manager: ModelManager = request.app.state.model_manager
    asr_model = model_manager.asr_model
    lock = request.app.state.lock

    start_time = time.time()
    # Decode the raw float16 payloads into float32 torch tensors.
    audios = [np.frombuffer(audio, dtype=np.float16) for audio in req.audios]
    audios = [torch.from_numpy(audio).float() for audio in audios]

    # Reject clips at or beyond the 30-second limit. The loop variable is
    # `audio` (singular) — the original shadowed `audios` inside its own
    # generator, which worked but was misleading.
    if any(audio.shape[-1] >= 30 * req.sample_rate for audio in audios):
        # Positional HTTPStatus for consistency with the other handlers
        # (previously `status_code=400`).
        raise HTTPException(HTTPStatus.BAD_REQUEST, content="Audio length is too long")

    transcriptions = batch_asr(
        asr_model, lock, audios=audios, sr=req.sample_rate, language=req.language
    )
    logger.info(f"[EXEC] ASR time: {(time.time() - start_time) * 1000:.2f}ms")

    return ormsgpack.packb(
        ServeASRResponse(transcriptions=transcriptions),
        option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
    )
|
|
|
|
@routes.http.post("/v1/tts")
async def tts(req: Annotated[ServeTTSRequest, Body(exclusive=True)]):
    """Synthesize speech for the requested text.

    Streams chunks as they are produced when ``req.streaming`` is set
    (WAV only); otherwise renders the full clip in memory and returns it
    as a single attachment in ``req.format``.
    """
    state = request.app.state
    manager: ModelManager = state.model_manager
    engine = manager.tts_inference_engine
    sample_rate = engine.decoder_model.spec_transform.sample_rate

    # Guard: enforce the configured text-length cap (a cap of 0 disables it).
    max_len = state.max_text_length
    if 0 < max_len < len(req.text):
        raise HTTPException(
            HTTPStatus.BAD_REQUEST,
            content=f"Text is too long, max length is {max_len}",
        )

    # Guard: chunked streaming is only implemented for WAV output.
    if req.streaming and req.format != "wav":
        raise HTTPException(
            HTTPStatus.BAD_REQUEST,
            content="Streaming only supports WAV format",
        )

    headers = {
        "Content-Disposition": f"attachment; filename=audio.{req.format}",
    }

    if req.streaming:
        # Hand chunks to the client as the engine yields them.
        return StreamResponse(
            iterable=inference_async(req, engine),
            headers=headers,
            content_type=get_content_type(req.format),
        )

    # Non-streaming: run one full inference pass and encode it in memory.
    fake_audios = next(inference(req, engine))
    buffer = io.BytesIO()
    sf.write(buffer, fake_audios, sample_rate, format=req.format)

    return StreamResponse(
        iterable=buffer_to_async_generator(buffer.getvalue()),
        headers=headers,
        content_type=get_content_type(req.format),
    )
|
|
|
|
@routes.http.post("/v1/chat")
async def chat(req: Annotated[ServeChatRequest, Body(exclusive=True)]):
    """Run the chat LLM for the request, streamed or in one shot.

    The response encoding mirrors the request Content-Type: JSON in -> JSON
    out, otherwise msgpack. Streaming responses go out as server-sent events.
    """
    # Clamp the sample count to the server-configured maximum.
    if req.num_samples < 1 or req.num_samples > MAX_NUM_SAMPLES:
        raise HTTPException(
            HTTPStatus.BAD_REQUEST,
            content=f"Number of samples must be between 1 and {MAX_NUM_SAMPLES}",
        )

    # Mirror the client's serialization preference.
    content_type = request.headers.get("Content-Type", "application/json")
    json_mode = "application/json" in content_type

    # Shared LLM resources from application state.
    model_manager: ModelManager = request.app.state.model_manager
    llama_queue = model_manager.llama_queue
    tokenizer = model_manager.tokenizer
    config = model_manager.config
    device = request.app.state.device

    response_generator = get_response_generator(
        llama_queue, tokenizer, config, req, device, json_mode
    )

    # Idiom fix: truthiness test instead of `req.streaming is False` (PEP 8
    # forbids comparing booleans with `is`).
    if not req.streaming:
        result = response_generator()
        if json_mode:
            return JSONResponse(result.model_dump())
        return ormsgpack.packb(result, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)

    return StreamResponse(
        iterable=response_generator(), content_type="text/event-stream"
    )
|
|