# vits-simple-api: bert_vits2/text/japanese_bert.py
# Author: Artrajz — commit 14e19a5 ("update"), 1.71 kB
import os
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
import config
from logger import logger
from utils.download import download_and_verify
from config import DEVICE as device
# Download source(s) for the bert-base-japanese-v3 weights. Only
# pytorch_model.bin is fetched here; the tokenizer/config files that
# from_pretrained() also reads from BERT_PATH are presumably provided by the
# repository itself — NOTE(review): confirm.
URLS = [
    "https://huggingface.co/cl-tohoku/bert-base-japanese-v3/resolve/main/pytorch_model.bin",
]
# Local directory holding the model; built once so every path below agrees.
BERT_PATH = os.path.join(config.ABS_PATH, "bert_vits2/bert/bert-base-japanese-v3")
TARGET_PATH = os.path.join(BERT_PATH, "pytorch_model.bin")
# No checksum is pinned; None is passed through to download_and_verify.
EXPECTED_MD5 = None

if not os.path.exists(TARGET_PATH):
    success, message = download_and_verify(URLS, TARGET_PATH, EXPECTED_MD5)
    if not success:
        # Surface the failure instead of discarding it; loading below will
        # then fail and log its own hint as well.
        logger.error(f"Failed to download {TARGET_PATH}: {message}")

# Best-effort load at import time: on failure the error is logged and the
# module-level `tokenizer` / `model` names are left undefined.
try:
    logger.info("Loading bert-base-japanese-v3...")
    tokenizer = AutoTokenizer.from_pretrained(BERT_PATH)
    model = AutoModelForMaskedLM.from_pretrained(BERT_PATH).to(device)
    logger.info("Loading finished.")
except Exception as e:
    logger.error(e)
    logger.error("Please download pytorch_model.bin from cl-tohoku/bert-base-japanese-v3.")
def get_bert_feature(text, word2ph, device=config.DEVICE):
    """Return phone-level BERT features for ``text``.

    The text is tokenized with the module-level ``tokenizer`` and run through
    the module-level ``model``. The third-from-last hidden layer is then taken,
    and each token's vector is repeated ``word2ph[i]`` times so the features
    line up with the phone sequence.

    Args:
        text: Input text passed to the module-level ``tokenizer``.
        word2ph: Phone count per token. It must have exactly one entry per
            position of ``input_ids``, which includes any special tokens the
            tokenizer adds.
        device: Device the tokenized inputs are moved to before inference.

    Returns:
        A CPU tensor of shape ``(hidden_size, sum(word2ph))``.

    Raises:
        ValueError: If ``len(word2ph)`` differs from the token count.
    """
    with torch.no_grad():
        inputs = tokenizer(text, return_tensors="pt")
        n_tokens = inputs["input_ids"].shape[-1]
        # Check before the (costly) forward pass. A plain `assert` would be
        # stripped under `python -O`, silently misaligning the features.
        if n_tokens != len(word2ph):
            raise ValueError(
                f"word2ph has {len(word2ph)} entries but the tokenizer "
                f"produced {n_tokens} tokens"
            )
        inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
        res = model(**inputs, output_hidden_states=True)
        # Third-from-last hidden layer, first batch item, moved to CPU:
        # shape (n_tokens, hidden_size).
        hidden = res["hidden_states"][-3][0].cpu()

    # Repeat row i word2ph[i] times. This is equivalent to concatenating
    # hidden[i].repeat(word2ph[i], 1) over i.
    repeats = torch.as_tensor(word2ph, dtype=torch.long)
    phone_level_feature = torch.repeat_interleave(hidden, repeats, dim=0)
    return phone_level_feature.T