import torch

from modules import config
from modules import generate_audio as generate

from functools import lru_cache, wraps
from typing import Callable, Optional

from modules.api.Api import APIManager

from modules.api.impl import (
    base_api,
    tts_api,
    ssml_api,
    google_api,
    openai_api,
    refiner_api,
)

# Process-wide torch settings, applied as soon as this module is executed.
torch._dynamo.config.cache_size_limit = 64
torch._dynamo.config.suppress_errors = True
torch.set_float32_matmul_precision("high")


def create_api():
    """Build an ``APIManager`` and register every API implementation on it.

    Returns:
        The ``APIManager`` instance after ``setup`` has been called for the
        base, TTS, SSML, Google, OpenAI and refiner API modules (in that
        order).
    """
    api = APIManager()

    base_api.setup(api)
    tts_api.setup(api)
    ssml_api.setup(api)
    google_api.setup(api)
    openai_api.setup(api)
    refiner_api.setup(api)

    return api


def conditional_cache(condition: Callable, maxsize: Optional[int] = None):
    """Create a decorator that memoizes calls only when ``condition`` holds.

    Args:
        condition: Called with the same ``*args, **kwargs`` as the decorated
            function. A truthy result routes the call through an
            ``lru_cache``; a falsy result calls the function directly.
        maxsize: Upper bound on the number of cached results, forwarded to
            ``functools.lru_cache``. ``None`` (the default) keeps the cache
            unbounded, which is the historical behaviour of this helper.

    Returns:
        A decorator. The decorated function has the same call signature as
        the original one.

    Note:
        Cached calls require every argument to be hashable, as with any
        ``lru_cache``-wrapped function.
    """

    def decorator(func):
        @lru_cache(maxsize=maxsize)
        def cached_func(*args, **kwargs):
            return func(*args, **kwargs)

        @wraps(func)
        def wrapper(*args, **kwargs):
            if condition(*args, **kwargs):
                return cached_func(*args, **kwargs)
            return func(*args, **kwargs)

        return wrapper

    return decorator


if __name__ == "__main__":
    import argparse
    import uvicorn

    parser = argparse.ArgumentParser(
        description="Start the FastAPI server with command line arguments"
    )
    parser.add_argument(
        "--host", type=str, default="0.0.0.0", help="Host to run the server on"
    )
    parser.add_argument(
        "--port", type=int, default=8000, help="Port to run the server on"
    )
    parser.add_argument(
        "--reload", action="store_true", help="Enable auto-reload for development"
    )
    parser.add_argument("--compile", action="store_true", help="Enable model compile")
    parser.add_argument(
        "--lru_size",
        type=int,
        default=64,
        help="Set the size of the request cache pool, set it to 0 will disable lru_cache",
    )
    parser.add_argument(
        "--cors_origin",
        type=str,
        default="*",
        help="Allowed CORS origins. Use '*' to allow all origins.",
    )

    args = parser.parse_args()

    # Expose the parsed CLI arguments to the rest of the project.
    config.args = args

    if args.compile:
        print("Model compile is enabled")
        config.enable_model_compile = True

    def should_cache(*args, **kwargs):
        """Return True only when both seeds are pinned (i.e. not ``-1``).

        Only keyword arguments are inspected; a seed passed positionally is
        treated as missing and therefore as ``-1``.
        """
        spk_seed = kwargs.get("spk_seed", -1)
        infer_seed = kwargs.get("infer_seed", -1)
        return spk_seed != -1 and infer_seed != -1

    if args.lru_size > 0:
        config.lru_size = args.lru_size
        # Bound the cache by --lru_size so memoized results cannot grow
        # without limit.
        # NOTE(review): this rebinds the attribute on the module object only;
        # code that did `from ... import generate_audio` before this point
        # would keep the uncached function - confirm against the API modules.
        generate.generate_audio = conditional_cache(
            should_cache, maxsize=args.lru_size
        )(generate.generate_audio)

    api = create_api()
    config.api = api

    if args.cors_origin:
        api.set_cors(allow_origins=[args.cors_origin])

    # NOTE(review): uvicorn documents that `reload` needs the app passed as
    # an import string; with an app object --reload may not work - confirm.
    uvicorn.run(api.app, host=args.host, port=args.port, reload=args.reload)