File size: 2,760 Bytes
01e655b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 |
import torch
from modules import config
from modules import generate_audio as generate
from functools import lru_cache
from typing import Callable
from modules.api.Api import APIManager
from modules.api.impl import (
base_api,
tts_api,
ssml_api,
google_api,
openai_api,
refiner_api,
)
# Process-wide PyTorch settings, applied as a side effect of importing this module.
# Set TorchDynamo's ``cache_size_limit`` option to 64.
torch._dynamo.config.cache_size_limit = 64
# Turn on TorchDynamo's ``suppress_errors`` option.
# NOTE(review): presumably this makes compile failures fall back to eager
# execution instead of raising -- confirm against the PyTorch docs.
torch._dynamo.config.suppress_errors = True
# Select the "high" float32 matmul precision mode.
# NOTE(review): presumably trades some float32 matmul precision for speed on
# supporting hardware -- confirm against the PyTorch docs.
torch.set_float32_matmul_precision("high")
def create_api():
    """Build an ``APIManager`` and register every API implementation on it.

    Each implementation module's ``setup`` function is called with the
    manager, in this fixed order: base, tts, ssml, google, openai, refiner.

    Returns:
        The ``APIManager`` instance after all ``setup`` calls have run.
    """
    manager = APIManager()
    # Order is preserved from the original sequence of explicit calls.
    impl_modules = (
        base_api,
        tts_api,
        ssml_api,
        google_api,
        openai_api,
        refiner_api,
    )
    for impl in impl_modules:
        impl.setup(manager)
    return manager
def conditional_cache(condition: Callable, maxsize=None):
    """Build a decorator that memoizes a function only when ``condition`` allows it.

    The decorated function is called with its original arguments. If
    ``condition(*args, **kwargs)`` is truthy, the call is routed through an
    ``functools.lru_cache``-backed copy of the function; otherwise the
    function is invoked directly and the cache is neither read nor written.

    Args:
        condition: Callable receiving the same positional and keyword
            arguments as the decorated function; a truthy result enables
            caching for that call.
        maxsize: Maximum number of cached results, forwarded to
            ``lru_cache``. ``None`` (the default) keeps the cache unbounded,
            which matches the previous behavior; pass a positive int to
            bound memory use.

    Returns:
        A decorator. The wrapper it produces keeps the metadata of the
        decorated function and exposes ``cache_info`` / ``cache_clear`` of
        the underlying LRU cache.

    Raises:
        TypeError: From ``lru_cache`` at call time when caching is selected
            and an argument is unhashable.
    """
    # Function-scope import so this block does not depend on the file's
    # import header being edited.
    from functools import wraps

    def decorator(func):
        @lru_cache(maxsize=maxsize)
        def cached_func(*args, **kwargs):
            return func(*args, **kwargs)

        @wraps(func)  # keep __name__/__doc__ of the function being replaced
        def wrapper(*args, **kwargs):
            if condition(*args, **kwargs):
                return cached_func(*args, **kwargs)
            return func(*args, **kwargs)

        # Surface cache introspection/reset so callers can monitor or flush it.
        wrapper.cache_info = cached_func.cache_info
        wrapper.cache_clear = cached_func.cache_clear
        return wrapper

    return decorator
if __name__ == "__main__":
    # Script entry point: parse CLI options, publish them on ``config``,
    # optionally wrap ``generate.generate_audio`` in a conditional cache,
    # build the API and hand its app to uvicorn.
    import argparse

    import uvicorn

    parser = argparse.ArgumentParser(
        description="Start the FastAPI server with command line arguments"
    )

    # (flag, add_argument keyword options), registered in this order.
    option_specs = (
        (
            "--host",
            {"type": str, "default": "0.0.0.0", "help": "Host to run the server on"},
        ),
        (
            "--port",
            {"type": int, "default": 8000, "help": "Port to run the server on"},
        ),
        (
            "--reload",
            {"action": "store_true", "help": "Enable auto-reload for development"},
        ),
        (
            "--compile",
            {"action": "store_true", "help": "Enable model compile"},
        ),
        (
            "--lru_size",
            {
                "type": int,
                "default": 64,
                "help": "Set the size of the request cache pool, set it to 0 will disable lru_cache",
            },
        ),
        (
            "--cors_origin",
            {
                "type": str,
                "default": "*",
                "help": "Allowed CORS origins. Use '*' to allow all origins.",
            },
        ),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    args = parser.parse_args()
    # Make the parsed options available to the rest of the project.
    config.args = args

    if args.compile:
        print("Model compile is enabled")
        config.enable_model_compile = True

    def should_cache(*args, **kwargs):
        """Return True only when both seeds were passed as keywords and neither is -1.

        Only keyword arguments are inspected; a missing ``spk_seed`` or
        ``infer_seed`` keyword counts as -1, so such calls are not cached.
        """
        return kwargs.get("spk_seed", -1) != -1 and kwargs.get("infer_seed", -1) != -1

    if args.lru_size > 0:
        config.lru_size = args.lru_size
        # NOTE(review): no size is passed to conditional_cache here, so
        # ``lru_size`` only acts as an on/off switch for the cache in this
        # block; the value itself is just recorded on ``config``.
        cache_when_seeded = conditional_cache(should_cache)
        generate.generate_audio = cache_when_seeded(generate.generate_audio)

    api = create_api()
    config.api = api

    if args.cors_origin:
        api.set_cors(allow_origins=[args.cors_origin])

    # NOTE(review): an app object (not an import string) is passed together
    # with ``reload`` -- confirm against the uvicorn docs that --reload is
    # effective in this form.
    uvicorn.run(api.app, host=args.host, port=args.port, reload=args.reload)
|