star-gpt-sovits / pure_api.py
nekoaoxiang
Add api support
d290960
# Extend sys.path first so the project-local packages below can be imported
import os, sys
import importlib
now_dir = os.getcwd()
sys.path.append(now_dir)
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from src.config_manager import Inference_Config
from src.config_manager import __version__ as frontend_version
inference_config = Inference_Config()
import soundfile as sf
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import JSONResponse, FileResponse, StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
import tempfile
import uvicorn
import json
# Synthesizer base types (importable thanks to the sys.path entries added above)
from Synthesizers.base import Base_TTS_Task, Base_TTS_Synthesizer
# Module-wide synthesizer instance used by the route handlers. It starts as
# None and is assigned either through set_tts_synthesizer() or by the
# __main__ block at the bottom of this file.
tts_synthesizer:Base_TTS_Synthesizer = None
def set_tts_synthesizer(synthesizer: Base_TTS_Synthesizer):
    """Install ``synthesizer`` as the module-wide ``tts_synthesizer``.

    The route handlers in this module read the global, so this must be
    called before they serve requests when the module is not run as a script.
    """
    global tts_synthesizer
    tts_synthesizer = synthesizer
# In-memory cache mapping a task's md5 to the path of its generated audio
# file. The tts handler consults it when task.save_temp is set; nothing in
# this file removes entries.
temp_files = {}
async def character_list(request: Request):
    """Return ``tts_synthesizer.get_characters()`` wrapped in a JSON response.

    ``request`` is accepted for the route signature but is not read.
    """
    return JSONResponse(tts_synthesizer.get_characters())
async def tts(request: Request):
    """Handle a TTS request and respond with synthesized audio.

    Parameters are read from the query string for GET requests and from the
    JSON body for every other method, then passed to
    ``tts_synthesizer.params_parser`` to build the task.

    For non-streaming tasks the synthesizer is asked for a file path and the
    file is sent back as a ``FileResponse``. When ``task.save_temp`` is set,
    the path is remembered in ``temp_files`` under ``task.md5`` and reused for
    later requests with the same md5. For streaming tasks the result of
    ``generate(..., return_type="numpy")`` is wrapped in a
    ``StreamingResponse``.

    Raises:
        HTTPException: 400 when the text/SSML payload is blank; 500 when
            file generation raises.
    """
    from time import time as tt

    t1 = tt()
    print(f"Request Time: {t1}")

    # GET carries parameters in the query string; anything else is expected
    # to carry a JSON body.
    if request.method == "GET":
        data = request.query_params
    else:
        data = await request.json()

    task: Base_TTS_Task = tts_synthesizer.params_parser(data)

    # HTTPException has to be *raised*: FastAPI only turns it into an error
    # response through its exception handler. Returning the instance would
    # bypass that and never yield the intended status code.
    if task.task_type == "text" and task.text.strip() == "":
        raise HTTPException(status_code=400, detail="Text is empty")
    if task.task_type == "ssml" and task.ssml.strip() == "":
        raise HTTPException(status_code=400, detail="SSML is empty")

    md5_value = task.md5

    # The explicit comparison is kept on purpose: the type of task.stream is
    # defined outside this file, and `not task.stream` would treat None
    # differently.
    if task.stream == False:  # noqa: E712
        # TODO: use SQL instead of dict
        if task.save_temp:
            cached_path = temp_files.get(md5_value)
            # Only reuse the cached file if it is still on disk; otherwise
            # fall through and regenerate it.
            if cached_path and os.path.exists(cached_path):
                return FileResponse(path=cached_path, media_type=f'audio/{task.format}')

        try:
            save_path = tts_synthesizer.generate(task, return_type="filepath")
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e)) from e

        if task.save_temp:
            temp_files[md5_value] = save_path

        t2 = tt()
        print(f"total time: {t2-t1}")

        # FileResponse takes care of sending the file to the client.
        return FileResponse(save_path, media_type=f"audio/{task.format}", filename=os.path.basename(save_path))

    gen = tts_synthesizer.generate(task, return_type="numpy")
    return StreamingResponse(gen, media_type='audio/wav')
if __name__ == "__main__":
    # Import the configured synthesizer module dynamically. This could also be
    # written statically as:
    #   from Synthesizers.xxx import TTS_Synthesizer, TTS_Task
    from importlib import import_module
    from src.api_utils import get_localhost_ipv4_address

    synthesizer_name = inference_config.synthesizer
    synthesizer_module = import_module(f"Synthesizers.{synthesizer_name}")
    TTS_Synthesizer = synthesizer_module.TTS_Synthesizer
    TTS_Task = synthesizer_module.TTS_Task

    # Assigned at module level, so this is the same global the route
    # handlers read.
    tts_synthesizer = TTS_Synthesizer(debug_mode=True)

    # The version is imported at the top of the file under the alias
    # `frontend_version`; the bare name `__version__` is not bound in this
    # module and would raise NameError here.
    # NOTE(review): the label says "Backend" while the only version in scope
    # is the one from src.config_manager - confirm the intended wording.
    print(f"Backend Version: {frontend_version}")

    tts_host = inference_config.tts_host
    tts_port = inference_config.tts_port
    ipv4_address = get_localhost_ipv4_address(tts_host)
    ipv4_link = f"http://{ipv4_address}:{tts_port}"
    print(f"INFO: Local Network URL: {ipv4_link}")

    app = FastAPI()

    # Configure CORS.
    # NOTE(review): every origin, method and header is allowed together with
    # credentials - confirm this is intended outside local/dev use.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    app.add_api_route('/tts', tts, methods=["GET", "POST"])
    app.add_api_route('/character_list', character_list, methods=["GET"])

    uvicorn.run(app, host=tts_host, port=tts_port)