|
""" |
|
Style-Bert-VITS2-Editor用のサーバー。 |
|
次のリポジトリ |
|
https://github.com/litagin02/Style-Bert-VITS2-Editor |
|
をビルドしてできあがったファイルをstaticフォルダに入れて実行する（`--skip_static_files` を指定しなければ、起動時に新しいリリースがあるか確認し、自動でダウンロード・展開する）。
|
|
|
TODO: リファクタリングやドキュメンテーションやAPI整理、辞書周りの改善などが必要。 |
|
""" |
|
|
|
import argparse |
|
import io |
|
import shutil |
|
import sys |
|
import webbrowser |
|
import zipfile |
|
from datetime import datetime |
|
from io import BytesIO |
|
from pathlib import Path |
|
from typing import Optional |
|
|
|
import numpy as np |
|
import requests |
|
import torch |
|
import uvicorn |
|
from fastapi import APIRouter, FastAPI, HTTPException, status |
|
from fastapi.middleware.cors import CORSMiddleware |
|
from fastapi.responses import JSONResponse, Response |
|
from fastapi.staticfiles import StaticFiles |
|
from pydantic import BaseModel |
|
from scipy.io import wavfile |
|
|
|
from config import get_path_config |
|
from initialize import download_default_models |
|
from style_bert_vits2.constants import ( |
|
DEFAULT_ASSIST_TEXT_WEIGHT, |
|
DEFAULT_NOISE, |
|
DEFAULT_NOISEW, |
|
DEFAULT_SDP_RATIO, |
|
DEFAULT_STYLE, |
|
DEFAULT_STYLE_WEIGHT, |
|
VERSION, |
|
Languages, |
|
) |
|
from style_bert_vits2.logging import logger |
|
from style_bert_vits2.nlp import bert_models |
|
from style_bert_vits2.nlp.japanese import pyopenjtalk_worker as pyopenjtalk |
|
from style_bert_vits2.nlp.japanese.g2p_utils import g2kata_tone, kata_tone2phone_tone |
|
from style_bert_vits2.nlp.japanese.normalizer import normalize_text |
|
from style_bert_vits2.nlp.japanese.user_dict import ( |
|
apply_word, |
|
delete_word, |
|
read_dict, |
|
rewrite_word, |
|
update_dict, |
|
) |
|
from style_bert_vits2.tts_model import TTSModelHolder, TTSModelInfo |
|
|
|
|
|
|
|
|
|
|
|
# Directory served at "/" holding the built Style-Bert-VITS2-Editor frontend.
STATIC_DIR: Path = Path() / "static"

# Marker file storing the `published_at` timestamp of the last downloaded
# editor release; it lives inside STATIC_DIR so it is wiped together with it.
LAST_DOWNLOAD_FILE: Path = STATIC_DIR.joinpath("last_download.txt")
|
|
|
|
|
def download_static_files(user, repo, asset_name):
    """Download and unpack the latest Style-Bert-VITS2-Editor build into STATIC_DIR.

    The release is only downloaded when it is newer than the one recorded in
    LAST_DOWNLOAD_FILE. The archive is first extracted into a temporary
    directory and only swapped into STATIC_DIR once extraction succeeded, so a
    failed download keeps the previously installed files instead of leaving
    the server without a frontend.

    Args:
        user: GitHub owner of the editor repository.
        repo: GitHub repository name of the editor.
        asset_name: File name of the release asset (zip) to download.
    """
    import tempfile

    logger.info("Checking for new release...")
    latest_release = get_latest_release(user, repo)
    if latest_release is None:
        logger.warning(
            "Failed to fetch the latest release. Proceeding without static files."
        )
        return

    if not new_release_available(latest_release):
        logger.info("No new release available. Proceeding with existing static files.")
        return

    logger.info("New release available. Downloading static files...")
    asset_url = get_asset_url(latest_release, asset_name)
    if not asset_url:
        logger.warning("Asset not found. Proceeding without static files.")
        return

    # Extract next to STATIC_DIR (same filesystem) into a throwaway directory.
    STATIC_DIR.parent.mkdir(parents=True, exist_ok=True)
    with tempfile.TemporaryDirectory(dir=STATIC_DIR.parent) as tmp:
        staging_dir = Path(tmp)
        try:
            download_and_extract(asset_url, staging_dir)
        except (requests.RequestException, zipfile.BadZipFile, OSError) as e:
            # Previously this propagated and crashed startup after the old
            # static files had already been deleted.
            logger.warning(
                f"Failed to download static files: {e}. "
                "Proceeding with existing static files."
            )
            return
        # Only replace the installed files once the new build is fully on disk.
        if STATIC_DIR.exists():
            shutil.rmtree(STATIC_DIR)
        shutil.copytree(staging_dir, STATIC_DIR)

    # Written last so an interrupted update is retried on the next start.
    save_last_download(latest_release)
|
|
|
|
|
def get_latest_release(user, repo, timeout: float = 10.0):
    """Fetch metadata of the latest GitHub release of ``user/repo``.

    Args:
        user: GitHub account or organization name.
        repo: Repository name.
        timeout: Seconds to wait for the GitHub API before giving up. Without
            a timeout, ``requests`` may wait indefinitely and block startup.

    Returns:
        The decoded JSON body of the ``releases/latest`` endpoint, or ``None``
        if the request fails, times out, returns an error status, or the body
        is not valid JSON.
    """
    url = f"https://api.github.com/repos/{user}/{repo}/releases/latest"
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        return response.json()
    except (requests.RequestException, ValueError):
        # ValueError covers JSON decode errors on older requests versions;
        # callers treat None as "no release information available".
        return None
|
|
|
|
|
def get_asset_url(release, asset_name):
    """Return the download URL of the asset named *asset_name* in *release*.

    *release* is a GitHub release JSON object. The first asset whose ``name``
    matches wins; ``None`` is returned when no asset matches.
    """
    matching_urls = (
        asset["browser_download_url"]
        for asset in release["assets"]
        if asset["name"] == asset_name
    )
    return next(matching_urls, None)
|
|
|
|
|
def download_and_extract(url, extract_to: Path, timeout: float = 60.0):
    """Download a zip archive from *url* and extract it into *extract_to*.

    If, after extraction, *extract_to* contains exactly one entry and it is a
    directory (a zip wrapping everything in a top-level folder), that folder's
    contents are moved up one level and the folder is removed. This check
    looks at everything in *extract_to*, so it assumes the directory was empty
    beforehand.

    Args:
        url: Direct download URL of the zip archive.
        extract_to: Existing directory to extract into.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout, ``requests`` may wait indefinitely.

    Raises:
        requests.RequestException: On network errors, timeouts or an HTTP
            error status.
        zipfile.BadZipFile: If the payload is not a valid zip archive.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    with zipfile.ZipFile(io.BytesIO(response.content)) as zip_ref:
        # extractall strips absolute paths and ".." components from member
        # names, so members cannot be written outside extract_to.
        zip_ref.extractall(extract_to)

    # Flatten a single wrapping directory so index.html ends up at the top.
    extracted_entries = list(extract_to.iterdir())
    if len(extracted_entries) == 1 and extracted_entries[0].is_dir():
        wrapper_dir = extracted_entries[0]
        for child in wrapper_dir.iterdir():
            child.rename(extract_to / child.name)
        wrapper_dir.rmdir()

    if not (extract_to / "index.html").exists():
        logger.warning("index.html not found in the extracted files.")
|
|
|
|
|
def _parse_github_timestamp(value: str) -> datetime:
    """Parse a GitHub-style ISO-8601 timestamp (e.g. ``2024-05-01T12:00:00Z``).

    A trailing ``Z`` is rewritten to ``+00:00`` because
    ``datetime.fromisoformat`` only accepts ``Z`` from Python 3.11 on. Naive
    results are assumed to be UTC so they can always be compared with aware
    ones.
    """
    from datetime import timezone

    parsed = datetime.fromisoformat(value.strip().replace("Z", "+00:00"))
    if parsed.tzinfo is None:
        parsed = parsed.replace(tzinfo=timezone.utc)
    return parsed


def new_release_available(latest_release, last_download_file: Optional[Path] = None):
    """Return True if *latest_release* is newer than the last downloaded release.

    Args:
        latest_release: GitHub release JSON; its ``published_at`` is compared.
        last_download_file: Marker file written by ``save_last_download``.
            Defaults to LAST_DOWNLOAD_FILE.

    Returns:
        True when no marker file exists, when it cannot be read or parsed (so
        a corrupted marker triggers a fresh download instead of crashing
        startup), or when the release was published after the recorded time.
        False otherwise, including when both timestamps are equal.
    """
    marker = (
        LAST_DOWNLOAD_FILE if last_download_file is None else Path(last_download_file)
    )
    if not marker.exists():
        return True

    try:
        last_download = _parse_github_timestamp(marker.read_text())
    except (OSError, ValueError) as e:
        logger.warning(f"Could not read {marker} ({e}). Treating the release as new.")
        return True

    return _parse_github_timestamp(latest_release["published_at"]) > last_download
|
|
|
|
|
def save_last_download(latest_release):
    """Overwrite LAST_DOWNLOAD_FILE with the release's ``published_at`` value.

    ``new_release_available`` later compares against this timestamp.
    """
    with LAST_DOWNLOAD_FILE.open("w") as marker:
        marker.write(latest_release["published_at"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Module-level initialization: runs on import, before argument parsing below.
# NOTE(review): presumably starts the pyopenjtalk worker used by the Japanese
# G2P / normalization helpers — confirm in
# style_bert_vits2.nlp.japanese.pyopenjtalk_worker.
pyopenjtalk.initialize_worker()

# Run update_dict() once at startup. The /user_dict_word endpoints call it again
# after every dictionary edit (presumably to apply the saved user dictionary —
# confirm in style_bert_vits2.nlp.japanese.user_dict).
update_dict()

# Load the Japanese BERT model and tokenizer up front rather than on first request.
bert_models.load_model(Languages.JP)
bert_models.load_tokenizer(Languages.JP)
|
|
|
|
|
class AudioResponse(Response):
    """``Response`` subclass whose default media type is WAV audio.

    Used as ``response_class`` of the synthesis routes so the generated API
    docs advertise ``audio/wav``.
    """

    media_type = "audio/wav"
|
|
|
|
|
# Origins allowed by the CORS middleware: localhost and 127.0.0.1 on ports
# 3000 and 8000 (host-major order).
origins = [
    f"http://{host}:{cors_port}"
    for host in ("localhost", "127.0.0.1")
    for cors_port in (3000, 8000)
]
|
|
|
# Command-line options. parse_args() runs at import time, so merely importing
# this module consumes sys.argv.
path_config = get_path_config()
parser = argparse.ArgumentParser()
parser.add_argument("--model_dir", type=str, default=path_config.assets_root)
parser.add_argument("--device", type=str, default="cuda")
parser.add_argument("--port", type=int, default=8000)
# Open the editor in a web browser once the server is about to start.
parser.add_argument("--inbrowser", action="store_true")
# Maximum characters per synthesized line; None means no limit.
parser.add_argument("--line_length", type=int, default=None)
# Maximum number of lines per /multi_synthesis request; None means no limit.
parser.add_argument("--line_count", type=int, default=None)
# Skip download_default_models() at startup.
parser.add_argument("--skip_default_models", action="store_true")
# Skip checking for / downloading the editor frontend in __main__.
parser.add_argument("--skip_static_files", action="store_true")
args = parser.parse_args()
device = args.device
# Fall back to CPU only when exactly "cuda" was requested and is unavailable;
# other device strings (e.g. "cuda:1") are passed through unchecked.
if device == "cuda" and not torch.cuda.is_available():
    device = "cpu"
model_dir = Path(args.model_dir)
# argparse already converts --port to int; int() here is a no-op safeguard.
port = int(args.port)
if not args.skip_default_models:
    download_default_models()
skip_static_files = bool(args.skip_static_files)

# Discover models under model_dir and abort startup if none are found.
model_holder = TTSModelHolder(model_dir, device)
if len(model_holder.model_names) == 0:
    logger.error(f"Models not found in {model_dir}.")
    sys.exit(1)
|
|
|
|
|
app = FastAPI()

# Allow cross-origin requests (with credentials) from the origins listed above.
# NOTE(review): presumably for a separately served frontend dev server on
# :3000 — the bundled frontend is served same-origin and does not need this.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Every API endpoint below is registered on this router, which is later
# included into the app under the "/api" prefix.
router = APIRouter()
|
|
|
|
|
@router.get("/version")
def version() -> str:
    """Return the Style-Bert-VITS2 version string."""
    current_version: str = VERSION
    return current_version
|
|
|
|
|
class MoraTone(BaseModel):
    # One (mora, tone) pair exchanged with the editor: /g2p returns a list of
    # these, and SynthesisRequest.moraToneList sends them back for synthesis.
    # Documented with `#` comments (no class docstring) so the OpenAPI schema
    # generated from this model stays unchanged.
    mora: str  # katakana mora, as produced by g2kata_tone
    tone: int  # tone value paired with this mora
|
|
|
|
|
class TextRequest(BaseModel):
    # Request body carrying a single text string (used by /g2p and /normalize).
    text: str
|
|
|
|
|
@router.post("/g2p")
async def read_item(item: TextRequest):
    """Convert text into the (mora, tone) list edited in the accent editor.

    The text is normalized with ``normalize_text`` and then converted with
    ``g2kata_tone``.

    Raises:
        HTTPException: 400 if normalization or conversion fails. The original
            exception is logged and chained as ``__cause__`` so the real
            failure is not lost.
    """
    try:
        text = normalize_text(item.text)
        kata_tone_list = g2kata_tone(text)
    except Exception as e:
        logger.error(e)
        raise HTTPException(
            status_code=400,
            detail=f"Failed to convert {item.text} to katakana and tone, {e}",
        ) from e
    return [MoraTone(mora=kata, tone=tone) for kata, tone in kata_tone_list]
|
|
|
|
|
@router.post("/normalize")
async def normalize(item: TextRequest):
    """Return ``item.text`` after applying ``normalize_text``."""
    normalized = normalize_text(item.text)
    return normalized
|
|
|
|
|
@router.get("/models_info", response_model=list[TTSModelInfo])
def models_info():
    """List the metadata of every model held by ``model_holder``."""
    info = model_holder.models_info
    return info
|
|
|
|
|
class SynthesisRequest(BaseModel):
    # Parameters for synthesizing one line. Used directly by /synthesis and as
    # each element of MultiSynthesisRequest.lines. Documented with `#`
    # comments (no class docstring) so the generated OpenAPI schema is unchanged.
    model: str  # model name, passed to model_holder.get_model(model_name=...)
    modelFile: str  # model file path, passed as model_path_str
    text: str  # text to synthesize; length is checked against --line_length
    moraToneList: list[MoraTone]  # per-mora tones, expanded to per-phone given_tone
    style: str = DEFAULT_STYLE
    styleWeight: float = DEFAULT_STYLE_WEIGHT
    assistText: str = ""  # a non-empty value also sets use_assist_text=True
    assistTextWeight: float = DEFAULT_ASSIST_TEXT_WEIGHT
    speed: float = 1.0  # passed to model.infer as length=1 / speed
    noise: float = DEFAULT_NOISE
    noisew: float = DEFAULT_NOISEW
    sdpRatio: float = DEFAULT_SDP_RATIO
    language: Languages = Languages.JP
    # Seconds of silence inserted after this line by /multi_synthesis
    # (not after the last line); not used by /synthesis.
    silenceAfter: float = 0.5
    pitchScale: float = 1.0
    intonationScale: float = 1.0
    # Speaker name looked up in model.spk2id by the synthesis handler;
    # None selects speaker id 0.
    speaker: Optional[str] = None
|
|
|
|
|
def _check_line_length(text: str) -> None:
    """Raise HTTP 400 if *text* is longer than ``--line_length`` (when set)."""
    if args.line_length is not None and len(text) > args.line_length:
        raise HTTPException(
            status_code=400,
            detail=f"1行の文字数は{args.line_length}文字以下にしてください。",
        )


def _synthesize_line(req: SynthesisRequest) -> tuple[int, np.ndarray]:
    """Synthesize one line and return ``(sampling_rate, audio)`` from ``model.infer``.

    Shared by ``/synthesis`` and ``/multi_synthesis`` so both endpoints
    validate and interpret a ``SynthesisRequest`` identically (previously
    ``/multi_synthesis`` silently ignored ``speaker``).

    Raises:
        HTTPException: 400 for an over-long line, a non-positive ``speed`` or
            an unknown speaker; 500 if the model cannot be loaded.
    """
    _check_line_length(req.text)
    # model.infer receives length=1 / speed: 0 would raise ZeroDivisionError
    # (an opaque 500) and a negative speed would give a negative length.
    if req.speed <= 0:
        raise HTTPException(
            status_code=400,
            detail=f"speed must be positive, got {req.speed}",
        )
    try:
        model = model_holder.get_model(
            model_name=req.model, model_path_str=req.modelFile
        )
    except Exception as e:
        logger.error(e)
        raise HTTPException(
            status_code=500,
            detail=f"Failed to load model {req.model} from {req.modelFile}, {e}",
        ) from e

    kata_tone_list = [
        (mora_tone.mora, mora_tone.tone) for mora_tone in req.moraToneList
    ]
    # given_tone takes one tone per phone: expand the (mora, tone) pairs into
    # (phone, tone) pairs and keep only the tones.
    tone = [t for _, t in kata_tone2phone_tone(kata_tone_list)]

    try:
        sid = 0 if req.speaker is None else model.spk2id[req.speaker]
    except KeyError:
        raise HTTPException(
            status_code=400,
            detail=f"Speaker {req.speaker} not found in {model.spk2id}",
        ) from None

    return model.infer(
        text=req.text,
        language=req.language,
        sdp_ratio=req.sdpRatio,
        noise=req.noise,
        noise_w=req.noisew,
        length=1 / req.speed,
        given_tone=tone,
        style=req.style,
        style_weight=req.styleWeight,
        assist_text=req.assistText,
        assist_text_weight=req.assistTextWeight,
        use_assist_text=bool(req.assistText),
        line_split=False,
        pitch_scale=req.pitchScale,
        intonation_scale=req.intonationScale,
        speaker_id=sid,
    )


@router.post("/synthesis", response_class=AudioResponse)
def synthesis(request: SynthesisRequest):
    """Synthesize a single line and return it as a WAV file."""
    sr, audio = _synthesize_line(request)
    with BytesIO() as wav_buffer:
        wavfile.write(wav_buffer, sr, audio)
        return Response(content=wav_buffer.getvalue(), media_type="audio/wav")


class MultiSynthesisRequest(BaseModel):
    # Lines to synthesize in order and join into a single WAV file.
    lines: list[SynthesisRequest]


@router.post("/multi_synthesis", response_class=AudioResponse)
def multi_synthesis(request: MultiSynthesisRequest):
    """Synthesize several lines and return them concatenated as one WAV file.

    After every line except the last, ``silenceAfter`` seconds of silence are
    inserted. All request-level limits are checked before any synthesis runs.

    Raises:
        HTTPException: 400 if there are too many lines, no lines, or any line
            fails validation; 500 if a model cannot be loaded.
    """
    lines = request.lines
    if args.line_count is not None and len(lines) > args.line_count:
        raise HTTPException(
            status_code=400,
            detail=f"行数は{args.line_count}行以下にしてください。",
        )
    # np.concatenate([]) would otherwise fail with an opaque 500.
    if not lines:
        raise HTTPException(status_code=400, detail="No lines to synthesize.")
    # Reject over-long lines before spending time synthesizing earlier ones.
    for req in lines:
        _check_line_length(req.text)

    audios = []
    for i, req in enumerate(lines):
        sr, audio = _synthesize_line(req)
        audios.append(audio)
        if i < len(lines) - 1:
            # Clamp so a negative silenceAfter cannot make np.zeros raise.
            silence = max(0, int(sr * req.silenceAfter))
            # NOTE(review): int16 silence presumably matches the dtype returned
            # by model.infer — confirm.
            audios.append(np.zeros(silence, dtype=np.int16))
    audio = np.concatenate(audios)

    # NOTE(review): the WAV header uses the sampling rate of the last line;
    # this assumes every line's model shares the same rate — confirm.
    with BytesIO() as wav_buffer:
        wavfile.write(wav_buffer, sr, audio)
        return Response(content=wav_buffer.getvalue(), media_type="audio/wav")
|
|
|
|
|
class UserDictWordRequest(BaseModel):
    # Request body for adding (POST) or rewriting (PUT) a user-dictionary word.
    # Fields are passed unchanged to apply_word / rewrite_word. Documented with
    # `#` comments (no class docstring) so the OpenAPI schema stays unchanged.
    surface: str
    pronunciation: str
    accent_type: int
    priority: int = 5  # used when the client omits a priority
|
|
|
|
|
@router.get("/user_dict")
def get_user_dict():
    """Return the current user dictionary as produced by ``read_dict()``."""
    user_dict = read_dict()
    return user_dict
|
|
|
|
|
@router.post("/user_dict_word")
def add_user_dict_word(request: UserDictWordRequest):
    """Add a word to the user dictionary and respond 201 with its new uuid.

    ``update_dict()`` is re-run after the change, as at server startup.
    """
    word_fields = dict(
        surface=request.surface,
        pronunciation=request.pronunciation,
        accent_type=request.accent_type,
        priority=request.priority,
    )
    new_uuid = apply_word(**word_fields)
    update_dict()

    response_body = {"uuid": new_uuid}
    return JSONResponse(content=response_body, status_code=status.HTTP_201_CREATED)
|
|
|
|
|
@router.put("/user_dict_word/{uuid}")
def update_user_dict_word(uuid: str, request: UserDictWordRequest):
    """Rewrite the word identified by *uuid* and respond 200 echoing the uuid.

    ``update_dict()`` is re-run after the change.
    """
    word_fields = dict(
        surface=request.surface,
        pronunciation=request.pronunciation,
        accent_type=request.accent_type,
        priority=request.priority,
    )
    rewrite_word(word_uuid=uuid, **word_fields)
    update_dict()

    response_body = {"uuid": uuid}
    return JSONResponse(content=response_body, status_code=status.HTTP_200_OK)
|
|
|
|
|
@router.delete("/user_dict_word/{uuid}")
def delete_user_dict_word(uuid: str):
    """Delete the word identified by *uuid* and respond 200 echoing the uuid.

    ``update_dict()`` is re-run after the deletion.
    """
    delete_word(uuid)
    update_dict()

    response_body = {"uuid": uuid}
    return JSONResponse(content=response_body, status_code=status.HTTP_200_OK)
|
|
|
|
|
app.include_router(router, prefix="/api")

if __name__ == "__main__":
    if not skip_static_files:
        download_static_files("litagin02", "Style-Bert-VITS2-Editor", "out.zip")

    # StaticFiles raises at construction when the directory is missing (e.g.
    # --skip_static_files on a fresh install, or a failed first download), so
    # only mount the frontend when it is actually present.
    if STATIC_DIR.is_dir():
        # Mounted after include_router so the /api routes are matched first.
        app.mount("/", StaticFiles(directory=STATIC_DIR, html=True), name="static")
    else:
        logger.warning(
            f"Static directory {STATIC_DIR} not found. Serving only the API."
        )

    if args.inbrowser:
        webbrowser.open(f"http://localhost:{port}")
    uvicorn.run(app, host="0.0.0.0", port=port)
|
|