vits_yz / server /main.py
byzp's picture
Update server/main.py
61415c3
import os
#maxlen[0]=int(input("Maximum text length:"))
# Best-effort switch to the TCMalloc allocator when the shared library is
# installed at its usual Debian/Ubuntu location. Any failure is reported on
# stdout and start-up continues with the default allocator.
_tcmalloc_path = "/usr/lib/x86_64-linux-gnu/libtcmalloc.so"
if not os.path.exists(_tcmalloc_path):
    print("Cannot locate TCMalloc.")
else:
    try:
        # NOTE(review): LD_PRELOAD is normally consumed by the dynamic loader
        # at process start, so this assignment presumably only matters for
        # child processes; the ctypes load below is what pulls the library
        # into the current process.
        os.environ["LD_PRELOAD"] = _tcmalloc_path
        import ctypes
        ctypes.CDLL("libtcmalloc.so", mode=ctypes.RTLD_GLOBAL)
        print("tcmalloc.so loaded.")
    except Exception as load_error:
        print(load_error)
        print("Failed to load tcmalloc.so.")
from fastapi import FastAPI,Body,Request
from fastapi.responses import JSONResponse,Response,StreamingResponse
from starlette.responses import FileResponse
import uvicorn
import logging
from pydantic import BaseModel
import vits
import torch
import re
import threading
import cmd
# Runtime configuration, read once at import time from the working directory.
#   blacklist.txt -> list of client hosts that are refused (one per line,
#                    surrounding whitespace stripped).
#   maxlen.txt    -> raw file content stored as the single element of
#                    ``maxlen``; the request handler converts it with int().
with open('blacklist.txt', 'r') as banned_file:
    blacklist = [entry.strip() for entry in banned_file]
maxlen = []
with open('maxlen.txt', 'r') as limit_file:
    maxlen.append(limit_file.read())
# Pick the synthesis backend module once at start-up: ``gpu`` is 1 when CUDA
# is available (``run_old`` is imported) and 0 otherwise (``run_new`` is
# imported). Only the selected module is imported, so the other name stays
# undefined for the lifetime of the process.
gpu = 1 if torch.cuda.is_available() else 0
if gpu == 0:
    print("Use CPU.")
    import run_new
else:
    import run_old
# ASGI application object; the POST "/" route defined below is registered on it.
app = FastAPI()
# Configure the root logger so that only WARNING and above are emitted.
logging.basicConfig(level=logging.WARNING)
# Request body schema for POST "/": a JSON object with one string field.
# (Documented with comments rather than a docstring so the generated schema
# is left untouched.)
class item(BaseModel):
    # Command line sent by the client. The handler inspects it for a leading
    # "python", the --text= and --character= options, and a "./vits/" or
    # "./vits_bh3/" path marker.
    command: str
def _json_error(status_code, message):
    """Return a JSON error response of the form ``{"message": message}``."""
    return JSONResponse(status_code=status_code, content={"message": message})


@app.post("/")
def getwav(command: item, request: Request):
    """Synthesize speech for the text described by a pseudo command line.

    The client sends ``{"command": "python ./vits/... --text=... --character=N"}``.
    The command is never executed; it is only parsed for:

    * a leading ``python``,
    * ``--text=<run of non-whitespace characters>``,
    * ``--character=<digits>``,
    * a ``./vits/`` or ``./vits_bh3/`` marker selecting the ``ys`` or ``bh3``
      entry point of the backend module chosen at start-up.

    Args:
        command: Parsed request body carrying the command string.
        request: Incoming request, used for the client host blacklist check.

    Returns:
        On success, a streaming ``application/octet-stream`` response sent as
        an attachment named ``example.wav`` whose body is
        ``result.getvalue()`` (the backend result is presumably an in-memory
        bytes buffer -- confirm against run_old/run_new).
        Otherwise a JSON error response:

        * 403 -- client host is blacklisted, or the text exceeds ``maxlen[0]``;
        * 400 -- the command does not start with ``python``;
        * 404 -- ``--text=``, ``--character=`` or the model path is missing;
        * 500 -- the backend raised while synthesizing.
    """
    # request.client can be None when the server cannot determine the peer;
    # such a request cannot match the blacklist.
    client = request.client
    if client is not None and client.host in blacklist:
        return _json_error(403, "IP banned.")

    # print() renders the model as ``command='...'``, same as the old log line.
    print(command)

    # Read the field directly instead of slicing the model's repr, which
    # escapes quotes, backslashes and control characters inside the text.
    s = command.command
    if not s.startswith("python"):
        # This case used to fall through without a well-defined response.
        return _json_error(400, "unsupported command.")

    text_match = re.search(r"--text=(\S+)", s)
    if not text_match:
        return _json_error(404, "missing text.")
    text = text_match.group(1)
    if len(text) > int(maxlen[0]):
        return _json_error(403, "The text is too long.")

    character_match = re.search(r"--character=(\d+)", s)
    if not character_match:
        return _json_error(404, "missing character.")
    character = int(character_match.group(1))

    try:
        # Only the module matching ``gpu`` was imported at start-up, so the
        # other name must not be evaluated.
        backend = run_old if gpu == 1 else run_new
        if "./vits/" in s:
            result = backend.ys(text, character)
        elif "./vits_bh3/" in s:
            result = backend.bh3(text, character)
        else:
            return _json_error(404, "missing py")
    except Exception as e:
        # Boundary handler: log the failure and hide details from the client.
        print(e)
        return _json_error(500, "Internal Server Error.")

    response = StreamingResponse(iter([result.getvalue()]), media_type="application/octet-stream")
    response.headers["Content-Disposition"] = "attachment; filename=example.wav"
    return response
#uvicorn.run(app=app, host="0.0.0.0", port=7860, log_level="debug")