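"""Entry point for the TTS FastAPI server.

Wires up the API route groups (base, tts, ssml, google, openai, refiner),
optionally wraps audio generation in an LRU cache, and launches uvicorn.
"""
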
from functools import lru_cache
from typing import Callable

import torch

from modules import config
from modules import generate_audio as generate
from modules.api.Api import APIManager

from modules.api.impl import (
    base_api,
    tts_api,
    ssml_api,
    google_api,
    openai_api,
    refiner_api,
)

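# torch.compile settings: cap the dynamo cache, fall back to eager execution
# on compilation errors, and allow TF32 matmul ("high" precision) for speed.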
torch._dynamo.config.cache_size_limit = 64
torch._dynamo.config.suppress_errors = True
torch.set_float32_matmul_precision("high")


def create_api():
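    """Build an APIManager and register every route group on it."""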
    api = APIManager()

    base_api.setup(api)
    tts_api.setup(api)
    ssml_api.setup(api)
    google_api.setup(api)
    openai_api.setup(api)
    refiner_api.setup(api)

    return api


def conditional_cache(condition: Callable):
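    """Build a decorator that memoizes ``func`` only when ``condition`` is met.

    Calls for which ``condition(*args, **kwargs)`` is truthy are served from an
    unbounded ``lru_cache``; all other calls go straight to the wrapped function.
    """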
    def decorator(func):
        @lru_cache(None)
        def cached_func(*args, **kwargs):
            return func(*args, **kwargs)

        def wrapper(*args, **kwargs):
            if condition(*args, **kwargs):
                return cached_func(*args, **kwargs)
            else:
                return func(*args, **kwargs)

        return wrapper

    return decorator


if __name__ == "__main__":
    import argparse
    import uvicorn

    parser = argparse.ArgumentParser(
        description="Start the FastAPI server with command line arguments"
    )
    parser.add_argument(
        "--host", type=str, default="0.0.0.0", help="Host to run the server on"
    )
    parser.add_argument(
        "--port", type=int, default=8000, help="Port to run the server on"
    )
    parser.add_argument(
        "--reload", action="store_true", help="Enable auto-reload for development"
    )
    parser.add_argument("--compile", action="store_true", help="Enable model compile")
    parser.add_argument(
        "--lru_size",
        type=int,
        default=64,
        help="Set the size of the request cache pool, set it to 0 will disable lru_cache",
    )
    parser.add_argument(
        "--cors_origin",
        type=str,
        default="*",
        help="Allowed CORS origins. Use '*' to allow all origins.",
    )

    args = parser.parse_args()

    config.args = args

    if args.compile:
        print("Model compile is enabled")
        config.enable_model_compile = True

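    # Cache only deterministic requests: both the speaker seed and the
    # inference seed must be explicitly set (anything other than -1).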
    def should_cache(*args, **kwargs):
        spk_seed = kwargs.get("spk_seed", -1)
        infer_seed = kwargs.get("infer_seed", -1)
        return spk_seed != -1 and infer_seed != -1

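    # Wrap generate_audio in the conditional cache so repeated seeded requests
    # can be answered from memory.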
    if args.lru_size > 0:
        config.lru_size = args.lru_size
        generate.generate_audio = conditional_cache(should_cache)(
            generate.generate_audio
        )

    api = create_api()
    config.api = api

    if args.cors_origin:
        api.set_cors(allow_origins=[args.cors_origin])

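    # Note: uvicorn only applies reload=True when the app is passed as an
    # import string; with an app object the flag is ignored and a warning is logged.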
    uvicorn.run(api.app, host=args.host, port=args.port, reload=args.reload)