|
import re |
|
from threading import Lock |
|
|
|
import pyrootutils |
|
import uvicorn |
|
from kui.asgi import ( |
|
Depends, |
|
FactoryClass, |
|
HTTPException, |
|
HttpRoute, |
|
Kui, |
|
OpenAPI, |
|
Routes, |
|
) |
|
from kui.cors import CORSConfig |
|
from kui.openapi.specification import Info |
|
from kui.security import bearer_auth |
|
from loguru import logger |
|
from typing_extensions import Annotated |
|
|
|
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) |
|
|
|
from tools.server.api_utils import MsgPackRequest, parse_args |
|
from tools.server.exception_handler import ExceptionHandler |
|
from tools.server.model_manager import ModelManager |
|
from tools.server.views import routes |
|
|
|
|
|
class API(ExceptionHandler):
    """Build and configure the Kui ASGI application for the speech server.

    The ``http_exception_handler`` and ``other_exception_handler`` callables
    registered on the app are not defined in this class; they are expected to
    be inherited from ``ExceptionHandler``.
    """

    def __init__(self):
        """Parse CLI arguments, assemble the routes and create the Kui app."""
        # Function-scope import keeps this block self-contained.
        import hmac

        self.args = parse_args()

        def api_auth(endpoint):
            """HTTP middleware: require a bearer token when an API key is set.

            Returns ``verify`` (token-checking wrapper) when ``args.api_key``
            is configured, otherwise ``passthrough`` which calls the endpoint
            unchanged.
            """

            async def verify(token: Annotated[str, Depends(bearer_auth)]):
                # Constant-time comparison so response timing does not leak
                # how much of the key a guess got right.
                expected = self.args.api_key.encode("utf-8")
                if not hmac.compare_digest(token.encode("utf-8"), expected):
                    raise HTTPException(401, None, "Invalid token")
                return await endpoint()

            async def passthrough():
                return await endpoint()

            if self.args.api_key is not None:
                return verify
            return passthrough

        # Previously ``api_auth`` was defined but never attached, so the
        # configured API key was silently ignored. Wrapping the view routes
        # with the middleware actually enforces it.
        self.routes = Routes(routes, http_middlewares=[api_auth])

        self.openapi = OpenAPI(
            Info(
                {
                    "title": "Fish Speech API",
                    "version": "1.5.0",
                }
            ),
        ).routes

        # The first OpenAPI route is skipped (presumably the default index
        # route -- confirm against the Kui OpenAPI docs). The OpenAPI routes
        # are added outside the authenticated group, as before.
        self.app = Kui(
            routes=self.routes + self.openapi[1:],
            exception_handlers={
                HTTPException: self.http_exception_handler,
                Exception: self.other_exception_handler,
            },
            factory_class=FactoryClass(http=MsgPackRequest),
            cors_config=CORSConfig(),
        )

        # Shared application state stored on ``app.state``.
        self.app.state.lock = Lock()
        self.app.state.device = self.args.device
        self.app.state.max_text_length = self.args.max_text_length

        # Build the model manager when the application starts up.
        self.app.on_startup(self.initialize_app)

    async def initialize_app(self, app: Kui):
        """Startup hook: create the ``ModelManager`` and store it on ``app.state``.

        Args:
            app: The Kui application whose ``state`` receives ``model_manager``.
        """
        app.state.model_manager = ModelManager(
            mode=self.args.mode,
            device=self.args.device,
            half=self.args.half,
            compile=self.args.compile,
            asr_enabled=self.args.load_asr_model,
            llama_checkpoint_path=self.args.llama_checkpoint_path,
            decoder_checkpoint_path=self.args.decoder_checkpoint_path,
            decoder_config_name=self.args.decoder_config_name,
        )

        logger.info(f"Startup done, listening server at http://{self.args.listen}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_listen_address(listen):
    """Split a listen string into a ``(host, port)`` pair.

    Accepts ``host:port`` and the bracketed IPv6 form ``[addr]:port``.

    Args:
        listen: Address string, e.g. ``"127.0.0.1:8080"`` or ``"[::1]:8080"``.

    Returns:
        Tuple ``(host, port)`` with ``port`` converted to ``int``.

    Raises:
        ValueError: If no ``:`` separator is present or the port is not an
            integer.
    """
    # IPv6 literals contain colons themselves, so they are written [addr]:port.
    bracketed = re.search(r"\[([^\]]+)\]:(\d+)$", listen)
    if bracketed:
        host, port = bracketed.groups()
    else:
        # Split on the LAST colon only: a plain split(":") fails to unpack
        # whenever the host part contains a colon of its own.
        host, separator, port = listen.rpartition(":")
        if not separator:
            raise ValueError(
                f"Invalid listen address {listen!r}: "
                "expected 'host:port' or '[ipv6]:port'"
            )
    return host, int(port)


if __name__ == "__main__":
    api = API()

    host, port = parse_listen_address(api.args.listen)

    # NOTE(review): uvicorn documents that ``workers`` requires the app to be
    # passed as an import string; with an app instance, values > 1 may be
    # rejected -- confirm against the uvicorn settings docs.
    uvicorn.run(
        api.app,
        host=host,
        port=port,
        workers=api.args.workers,
        log_level="info",
    )
|
|