| """ |
| Style-Bert-VITS2-Editor用のサーバー。 |
| 次のリポジトリ |
| https://github.com/litagin02/Style-Bert-VITS2-Editor |
| をビルドしてできあがったファイルをWebフォルダに入れて実行する。 |
| |
| TODO: リファクタリングやドキュメンテーションやAPI整理、辞書周りの改善などが必要。 |
| """ |
|
|
| import argparse |
| import io |
| import shutil |
| import sys |
| import webbrowser |
| import zipfile |
| from datetime import datetime |
| from io import BytesIO |
| from pathlib import Path |
| from typing import Optional |
|
|
| import numpy as np |
| import requests |
| import torch |
| import uvicorn |
| from fastapi import APIRouter, FastAPI, HTTPException, status |
| from fastapi.middleware.cors import CORSMiddleware |
| from fastapi.responses import JSONResponse, Response |
| from fastapi.staticfiles import StaticFiles |
| from pydantic import BaseModel |
| from scipy.io import wavfile |
|
|
| from config import get_path_config |
| from initialize import download_default_models |
| from style_bert_vits2.constants import ( |
| DEFAULT_ASSIST_TEXT_WEIGHT, |
| DEFAULT_NOISE, |
| DEFAULT_NOISEW, |
| DEFAULT_SDP_RATIO, |
| DEFAULT_STYLE, |
| DEFAULT_STYLE_WEIGHT, |
| VERSION, |
| Languages, |
| ) |
| from style_bert_vits2.logging import logger |
| from style_bert_vits2.nlp import bert_models |
| from style_bert_vits2.nlp.japanese import pyopenjtalk_worker as pyopenjtalk |
| from style_bert_vits2.nlp.japanese.g2p_utils import g2kata_tone, kata_tone2phone_tone |
| from style_bert_vits2.nlp.japanese.normalizer import normalize_text |
| from style_bert_vits2.nlp.japanese.user_dict import ( |
| apply_word, |
| delete_word, |
| read_dict, |
| rewrite_word, |
| update_dict, |
| ) |
| from style_bert_vits2.tts_model import TTSModelHolder, TTSModelInfo |
|
|
|
|
| |
|
|
| |
# Local directory into which the editor's prebuilt frontend is unpacked and
# from which it is served.
STATIC_DIR = Path("static")
# Marker file holding the `published_at` value of the last release downloaded.
LAST_DOWNLOAD_FILE = STATIC_DIR.joinpath("last_download.txt")
|
|
|
|
def download_static_files(user, repo, asset_name):
    """Download and unpack the latest Style-Bert-VITS2 Editor build into STATIC_DIR.

    Queries the latest GitHub release of ``user/repo``. If it is newer than the
    last recorded download, the release asset named ``asset_name`` (a zip) is
    fetched and extracted into a temporary directory. Only after that succeeds
    is STATIC_DIR replaced, so a network error or corrupt archive leaves the
    currently installed frontend untouched instead of deleting it and crashing.

    Args:
        user: GitHub user/organization that owns the repository.
        repo: GitHub repository name.
        asset_name: File name of the release asset to download (e.g. "out.zip").
    """
    import tempfile

    logger.info("Checking for new release...")
    latest_release = get_latest_release(user, repo)
    if latest_release is None:
        logger.warning(
            "Failed to fetch the latest release. Proceeding without static files."
        )
        return

    if not new_release_available(latest_release):
        logger.info("No new release available. Proceeding with existing static files.")
        return

    logger.info("New release available. Downloading static files...")
    asset_url = get_asset_url(latest_release, asset_name)
    if not asset_url:
        logger.warning("Asset not found. Proceeding without static files.")
        return

    try:
        # Stage the extraction in a scratch directory so that a failure during
        # download/extraction does not destroy the existing static files.
        with tempfile.TemporaryDirectory() as tmp:
            staged_dir = Path(tmp) / STATIC_DIR.name
            staged_dir.mkdir()
            download_and_extract(asset_url, staged_dir)
            if STATIC_DIR.exists():
                shutil.rmtree(STATIC_DIR)
            STATIC_DIR.parent.mkdir(parents=True, exist_ok=True)
            # shutil.move also works when the temp dir is on another filesystem.
            shutil.move(str(staged_dir), str(STATIC_DIR))
    except (requests.RequestException, zipfile.BadZipFile, OSError) as e:
        logger.warning(
            f"Failed to download static files ({e}). "
            "Proceeding with existing static files."
        )
        return

    # Only record the release once the new files are actually in place.
    save_last_download(latest_release)
|
|
|
|
def get_latest_release(user, repo, timeout: float = 10.0):
    """Return the JSON payload of the latest GitHub release of ``user/repo``.

    Args:
        user: GitHub user/organization that owns the repository.
        repo: GitHub repository name.
        timeout: Connect/read timeout in seconds for the GitHub API call.
            Without a timeout ``requests`` can block indefinitely, which would
            hang server startup when the network is unreachable.

    Returns:
        The decoded release JSON (callers read its ``assets`` and
        ``published_at`` keys), or ``None`` if the request failed, returned an
        error status, or the body was not valid JSON.
    """
    url = f"https://api.github.com/repos/{user}/{repo}/releases/latest"
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        return response.json()
    except (requests.RequestException, ValueError):
        # ValueError covers JSON decode errors on older `requests` versions,
        # where they are not subclasses of RequestException.
        return None
|
|
|
|
def get_asset_url(release, asset_name):
    """Find the download URL of a named asset in a GitHub release payload.

    Args:
        release: Release JSON whose ``assets`` entries carry ``name`` and
            ``browser_download_url`` keys.
        asset_name: Exact asset file name to look for.

    Returns:
        The ``browser_download_url`` of the first matching asset, or ``None``
        when no asset has that name.
    """
    matching_urls = (
        candidate["browser_download_url"]
        for candidate in release["assets"]
        if candidate["name"] == asset_name
    )
    return next(matching_urls, None)
|
|
|
|
def download_and_extract(url, extract_to: Path, timeout: float = 60.0):
    """Download a zip archive from ``url`` and extract it into ``extract_to``.

    If the archive unpacks to a single top-level directory, that directory's
    contents are moved up one level so ``extract_to`` holds the files directly.
    A warning is logged when no ``index.html`` ends up at the top level.

    Args:
        url: Direct download URL of the zip archive.
        extract_to: Existing directory to extract into.
        timeout: Connect/read timeout in seconds for the download; without it
            ``requests`` may wait forever on a stalled connection.

    Raises:
        requests.RequestException: If the download fails or returns an error status.
        zipfile.BadZipFile: If the downloaded content is not a valid zip archive.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    with zipfile.ZipFile(io.BytesIO(response.content)) as zip_ref:
        zip_ref.extractall(extract_to)

    # Flatten a single wrapper directory (e.g. "out/") into extract_to.
    extracted = list(extract_to.iterdir())
    if len(extracted) == 1 and extracted[0].is_dir():
        wrapper = extracted[0]
        # Rename the wrapper first so a child with the same name as the wrapper
        # (e.g. "out/out") can be moved up without colliding with it.
        tmp_wrapper = wrapper.with_name(wrapper.name + ".__flatten_tmp__")
        wrapper.rename(tmp_wrapper)
        for child in tmp_wrapper.iterdir():
            child.rename(extract_to / child.name)
        tmp_wrapper.rmdir()

    if not (extract_to / "index.html").exists():
        logger.warning("index.html not found in the extracted files.")
|
|
|
|
def new_release_available(latest_release, last_download_file: Optional[Path] = None):
    """Return True if ``latest_release`` is newer than the last downloaded release.

    The last download's timestamp is read from ``last_download_file``, which
    defaults to LAST_DOWNLOAD_FILE. A missing marker file means nothing was
    downloaded yet. An unreadable or unparsable marker, or timestamps that
    cannot be compared, are also treated as "new release available", so
    startup re-downloads instead of crashing.

    Args:
        latest_release: GitHub release JSON; its ``published_at`` ISO-8601
            string (which may end in "Z") is compared.
        last_download_file: Optional override of the marker file path.

    Returns:
        True when a (re-)download should happen, False otherwise.
    """
    path = LAST_DOWNLOAD_FILE if last_download_file is None else last_download_file
    if not path.exists():
        return True
    try:
        last_download_str = path.read_text(encoding="utf-8").strip()
        # datetime.fromisoformat() only accepts a trailing "Z" from Python 3.11,
        # so normalize it to an explicit UTC offset.
        last_download = datetime.fromisoformat(last_download_str.replace("Z", "+00:00"))
        published_at = datetime.fromisoformat(
            latest_release["published_at"].replace("Z", "+00:00")
        )
        return published_at > last_download
    except (ValueError, TypeError):
        # Corrupt/empty marker (ValueError, incl. UnicodeDecodeError) or a
        # naive-vs-aware comparison (TypeError): re-download to be safe.
        return True
|
|
|
|
def save_last_download(latest_release):
    """Record the release's ``published_at`` timestamp in LAST_DOWNLOAD_FILE.

    The file is overwritten on each call; ``new_release_available`` compares
    future releases against this stored value.
    """
    with LAST_DOWNLOAD_FILE.open("w") as marker:
        marker.write(latest_release["published_at"])
|
|
|
|
| |
| |
|
|
| |
| |
# Start the pyopenjtalk worker before any Japanese text processing (g2p and
# normalization below depend on it).
# NOTE(review): presumably a separate worker process/client — confirm in
# style_bert_vits2.nlp.japanese.pyopenjtalk_worker.
pyopenjtalk.initialize_worker()


# Apply the current user dictionary. The user_dict endpoints call this again
# after every add/update/delete so that changes are picked up.
update_dict()



# Load the Japanese BERT model and tokenizer at startup, presumably so the
# first synthesis request does not have to load them.
bert_models.load_model(Languages.JP)
bert_models.load_tokenizer(Languages.JP)
|
|
|
|
class AudioResponse(Response):
    """Response class whose media type is ``audio/wav``.

    The synthesis routes declare it as their ``response_class``.
    """

    media_type: str = "audio/wav"
|
|
|
|
# Browser origins allowed to call the API cross-origin: localhost and
# 127.0.0.1 on ports 3000 (presumably a frontend dev server) and 8000.
origins = [
    f"http://{host}:{origin_port}"
    for host in ("localhost", "127.0.0.1")
    for origin_port in (3000, 8000)
]
|
|
path_config = get_path_config()
parser = argparse.ArgumentParser(description="Server for Style-Bert-VITS2-Editor")
parser.add_argument(
    "--model_dir",
    type=str,
    default=path_config.assets_root,
    help="Directory containing the TTS models",
)
parser.add_argument(
    "--device",
    type=str,
    default="cuda",
    help="Inference device ('cuda' falls back to 'cpu' if CUDA is unavailable)",
)
parser.add_argument("--port", type=int, default=8000, help="Port to listen on")
parser.add_argument(
    "--inbrowser", action="store_true", help="Open the editor in a browser on startup"
)
parser.add_argument(
    "--line_length",
    type=int,
    default=None,
    help="Maximum number of characters per line (no limit if omitted)",
)
parser.add_argument(
    "--line_count",
    type=int,
    default=None,
    help="Maximum number of lines per multi_synthesis request (no limit if omitted)",
)
parser.add_argument(
    "--skip_default_models",
    action="store_true",
    help="Do not download the default models",
)
parser.add_argument(
    "--skip_static_files",
    action="store_true",
    help="Do not download/update the editor's static files",
)
args = parser.parse_args()

device = args.device
if device == "cuda" and not torch.cuda.is_available():
    # Make the fallback visible; silently running on CPU is much slower.
    logger.warning("CUDA is not available; falling back to CPU.")
    device = "cpu"
model_dir = Path(args.model_dir)
port = args.port  # already an int thanks to type=int
if not args.skip_default_models:
    download_default_models()
skip_static_files = args.skip_static_files  # store_true -> bool


model_holder = TTSModelHolder(model_dir, device)
if not model_holder.model_names:
    logger.error(f"Models not found in {model_dir}.")
    sys.exit(1)
|
|
|
|
# Router holding every API endpoint defined below; it is attached to the app
# under the "/api" prefix once all routes have been registered.
router = APIRouter()

app = FastAPI()

# Let the configured local origins call the API from a browser, with
# credentials and any HTTP method/header.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
@router.get("/version")
def version() -> str:
    """Return the Style-Bert-VITS2 version string."""
    current_version: str = VERSION
    return current_version
|
|
|
|
class MoraTone(BaseModel):
    """One mora and its tone, as produced by ``g2kata_tone``."""

    mora: str  # katakana mora
    tone: int  # tone value paired with the mora (presumably 0 = low, 1 = high — confirm)
|
|
|
|
class TextRequest(BaseModel):
    """Request body carrying a single piece of input text."""

    text: str
|
|
|
|
@router.post("/g2p")
async def read_item(item: TextRequest):
    """Convert text into a list of (mora, tone) pairs.

    The text is first normalized with ``normalize_text`` and then converted
    with ``g2kata_tone``. Any failure in either step is reported as HTTP 400.
    """
    try:
        mora_tone_pairs = g2kata_tone(normalize_text(item.text))
    except Exception as e:
        raise HTTPException(
            status_code=400,
            detail=f"Failed to convert {item.text} to katakana and tone, {e}",
        )
    return [MoraTone(mora=mora, tone=tone_value) for mora, tone_value in mora_tone_pairs]
|
|
|
|
@router.post("/normalize")
async def normalize(item: TextRequest):
    """Return the text after ``normalize_text`` has been applied."""
    normalized = normalize_text(item.text)
    return normalized
|
|
|
|
@router.get("/models_info", response_model=list[TTSModelInfo])
def models_info():
    """Return metadata for every model known to the model holder."""
    info = model_holder.models_info
    return info
|
|
|
|
class SynthesisRequest(BaseModel):
    """Parameters for synthesizing one line of text.

    Field names are camelCase because they are the JSON keys of the request
    body (presumably sent by the editor frontend); renaming them breaks clients.
    """

    model: str  # model name, passed to model_holder.get_model(model_name=...)
    modelFile: str  # model file path, passed to get_model as model_path_str
    text: str  # text to synthesize; length limited by --line_length if set
    moraToneList: list[MoraTone]  # per-mora tones; converted via kata_tone2phone_tone and passed as given_tone
    style: str = DEFAULT_STYLE
    styleWeight: float = DEFAULT_STYLE_WEIGHT
    assistText: str = ""  # assist text is enabled only when this is non-empty
    assistTextWeight: float = DEFAULT_ASSIST_TEXT_WEIGHT
    speed: float = 1.0  # passed to the model as length = 1 / speed
    noise: float = DEFAULT_NOISE
    noisew: float = DEFAULT_NOISEW  # passed to the model as noise_w
    sdpRatio: float = DEFAULT_SDP_RATIO
    language: Languages = Languages.JP
    silenceAfter: float = 0.5  # seconds of silence after this line in multi_synthesis (not after the last line)
    pitchScale: float = 1.0
    intonationScale: float = 1.0
    speaker: Optional[str] = None  # speaker name looked up in model.spk2id by /synthesis; None -> speaker id 0
|
|
|
|
@router.post("/synthesis", response_class=AudioResponse)
def synthesis(request: SynthesisRequest):
    """Synthesize a single line of text and return it as a WAV file.

    Tones are taken from ``request.moraToneList`` (converted with
    ``kata_tone2phone_tone``) and passed to the model as ``given_tone``.

    Raises:
        HTTPException(400): if the text exceeds ``--line_length``, ``speed`` is
            not a positive number, or ``speaker`` is not in the model's spk2id.
        HTTPException(500): if the model cannot be loaded.
    """
    if args.line_length is not None and len(request.text) > args.line_length:
        raise HTTPException(
            status_code=400,
            detail=f"1行の文字数は{args.line_length}文字以下にしてください。",
        )
    # The model receives length = 1 / speed: 0 would raise ZeroDivisionError
    # (surfacing as a 500) and negative/NaN values are meaningless.
    if not request.speed > 0:
        raise HTTPException(
            status_code=400,
            detail=f"speed must be a positive number, got {request.speed}",
        )
    try:
        model = model_holder.get_model(
            model_name=request.model, model_path_str=request.modelFile
        )
    except Exception as e:
        logger.error(e)
        raise HTTPException(
            status_code=500,
            detail=f"Failed to load model {request.model} from {request.modelFile}, {e}",
        ) from e

    kata_tone_list = [
        (mora_tone.mora, mora_tone.tone) for mora_tone in request.moraToneList
    ]
    phone_tone = kata_tone2phone_tone(kata_tone_list)
    tone = [t for _, t in phone_tone]

    try:
        sid = 0 if request.speaker is None else model.spk2id[request.speaker]
    except KeyError:
        raise HTTPException(
            status_code=400,
            detail=f"Speaker {request.speaker} not found in {model.spk2id}",
        ) from None

    sr, audio = model.infer(
        text=request.text,
        language=request.language,
        sdp_ratio=request.sdpRatio,
        noise=request.noise,
        noise_w=request.noisew,
        length=1 / request.speed,
        given_tone=tone,
        style=request.style,
        style_weight=request.styleWeight,
        assist_text=request.assistText,
        assist_text_weight=request.assistTextWeight,
        use_assist_text=bool(request.assistText),
        line_split=False,
        pitch_scale=request.pitchScale,
        intonation_scale=request.intonationScale,
        speaker_id=sid,
    )

    with BytesIO() as wav_buffer:
        wavfile.write(wav_buffer, sr, audio)
        return Response(content=wav_buffer.getvalue(), media_type="audio/wav")
|
|
|
|
class MultiSynthesisRequest(BaseModel):
    """Request body for /multi_synthesis: lines synthesized in order and concatenated."""

    lines: list[SynthesisRequest]  # one entry per line; count limited by --line_count if set
|
|
|
|
@router.post("/multi_synthesis", response_class=AudioResponse)
def multi_synthesis(request: MultiSynthesisRequest):
    """Synthesize several lines and concatenate them into one WAV file.

    Each line is synthesized like ``/synthesis``, including its ``speaker``.
    Every line except the last is followed by ``silenceAfter`` seconds of
    silence.

    Raises:
        HTTPException(400): if there are no lines, more than ``--line_count``
            lines, a line longer than ``--line_length``, a non-positive
            ``speed``, an unknown ``speaker``, or lines whose models produce
            different sampling rates.
        HTTPException(500): if a model cannot be loaded.
    """
    lines = request.lines
    if not lines:
        # np.concatenate([]) would otherwise fail with a 500.
        raise HTTPException(status_code=400, detail="No lines to synthesize.")
    if args.line_count is not None and len(lines) > args.line_count:
        raise HTTPException(
            status_code=400,
            detail=f"行数は{args.line_count}行以下にしてください。",
        )

    audios = []
    sr = None
    for i, req in enumerate(lines):
        if args.line_length is not None and len(req.text) > args.line_length:
            raise HTTPException(
                status_code=400,
                detail=f"1行の文字数は{args.line_length}文字以下にしてください。",
            )
        # length = 1 / speed below; reject 0, negative and NaN speeds.
        if not req.speed > 0:
            raise HTTPException(
                status_code=400,
                detail=f"speed must be a positive number, got {req.speed}",
            )
        try:
            model = model_holder.get_model(
                model_name=req.model, model_path_str=req.modelFile
            )
        except Exception as e:
            logger.error(e)
            raise HTTPException(
                status_code=500,
                detail=f"Failed to load model {req.model} from {req.modelFile}, {e}",
            ) from e

        kata_tone_list = [
            (mora_tone.mora, mora_tone.tone) for mora_tone in req.moraToneList
        ]
        phone_tone = kata_tone2phone_tone(kata_tone_list)
        tone = [t for _, t in phone_tone]

        # Honor the per-line speaker exactly like /synthesis does; previously
        # it was ignored and the default speaker was always used.
        try:
            sid = 0 if req.speaker is None else model.spk2id[req.speaker]
        except KeyError:
            raise HTTPException(
                status_code=400,
                detail=f"Speaker {req.speaker} not found in {model.spk2id}",
            ) from None

        line_sr, audio = model.infer(
            text=req.text,
            language=req.language,
            sdp_ratio=req.sdpRatio,
            noise=req.noise,
            noise_w=req.noisew,
            length=1 / req.speed,
            given_tone=tone,
            style=req.style,
            style_weight=req.styleWeight,
            assist_text=req.assistText,
            assist_text_weight=req.assistTextWeight,
            use_assist_text=bool(req.assistText),
            line_split=False,
            pitch_scale=req.pitchScale,
            intonation_scale=req.intonationScale,
            speaker_id=sid,
        )
        # Concatenating audio at different sampling rates would produce a WAV
        # in which some lines play at the wrong speed and pitch.
        if sr is not None and line_sr != sr:
            raise HTTPException(
                status_code=400,
                detail=(
                    f"Sampling rate mismatch: line {i} is {line_sr} Hz "
                    f"but previous lines are {sr} Hz."
                ),
            )
        sr = line_sr

        audios.append(audio)
        if i < len(lines) - 1:
            # Clamp so a negative silenceAfter does not crash np.zeros.
            silence = max(0, int(sr * req.silenceAfter))
            audios.append(np.zeros(silence, dtype=audio.dtype))

    audio = np.concatenate(audios)

    with BytesIO() as wav_buffer:
        wavfile.write(wav_buffer, sr, audio)
        return Response(content=wav_buffer.getvalue(), media_type="audio/wav")
|
|
|
|
class UserDictWordRequest(BaseModel):
    """A user-dictionary entry to add or overwrite."""

    surface: str  # written form of the word
    pronunciation: str  # reading of the word (presumably katakana — confirm against user_dict.apply_word)
    accent_type: int  # accent type as accepted by apply_word/rewrite_word (presumably the accent nucleus position — confirm)
    priority: int = 5  # passed through to apply_word/rewrite_word; valid range is not shown here
|
|
|
|
@router.get("/user_dict")
def get_user_dict():
    """Return the user dictionary exactly as ``read_dict`` reports it."""
    user_dict = read_dict()
    return user_dict
|
|
|
|
@router.post("/user_dict_word")
def add_user_dict_word(request: UserDictWordRequest):
    """Add a word to the user dictionary and re-apply the dictionary.

    Returns:
        HTTP 201 with ``{"uuid": <id assigned by apply_word>}``.
    """
    new_word_uuid = apply_word(
        surface=request.surface,
        pronunciation=request.pronunciation,
        accent_type=request.accent_type,
        priority=request.priority,
    )
    # Re-apply the dictionary so the new word is picked up.
    update_dict()
    payload = {"uuid": new_word_uuid}
    return JSONResponse(content=payload, status_code=status.HTTP_201_CREATED)
|
|
|
|
@router.put("/user_dict_word/{uuid}")
def update_user_dict_word(uuid: str, request: UserDictWordRequest):
    """Overwrite the user-dictionary entry ``uuid`` and re-apply the dictionary.

    Returns:
        HTTP 200 with ``{"uuid": uuid}``.
    """
    rewrite_word(
        word_uuid=uuid,
        surface=request.surface,
        pronunciation=request.pronunciation,
        accent_type=request.accent_type,
        priority=request.priority,
    )
    update_dict()
    payload = {"uuid": uuid}
    return JSONResponse(content=payload, status_code=status.HTTP_200_OK)
|
|
|
|
@router.delete("/user_dict_word/{uuid}")
def delete_user_dict_word(uuid: str):
    """Remove the user-dictionary entry ``uuid`` and re-apply the dictionary.

    Returns:
        HTTP 200 with ``{"uuid": uuid}``.
    """
    delete_word(uuid)
    update_dict()
    payload = {"uuid": uuid}
    return JSONResponse(content=payload, status_code=status.HTTP_200_OK)
|
|
|
|
# All API endpoints live under /api; "/" is reserved for the static frontend.
app.include_router(router, prefix="/api")


if __name__ == "__main__":
    if not skip_static_files:
        download_static_files("litagin02", "Style-Bert-VITS2-Editor", "out.zip")
    # StaticFiles raises at construction time when the directory does not
    # exist (e.g. first run with --skip_static_files or without network), even
    # though the download step "proceeds without static files". Mount the
    # frontend only when it is present so the API server still starts.
    if STATIC_DIR.is_dir():
        app.mount("/", StaticFiles(directory=STATIC_DIR, html=True), name="static")
    else:
        logger.warning(
            f"Static files directory '{STATIC_DIR}' not found; serving the API only."
        )
    if args.inbrowser:
        webbrowser.open(f"http://localhost:{port}")
    uvicorn.run(app, host="0.0.0.0", port=port)
|
|