ChrisPreston's picture
Upload 95 files
93f4bab
raw
history blame
No virus
2.27 kB
import os.path
from io import BytesIO
from pathlib import Path
import numpy as np
import onnxruntime as ort
import torch
from modules.hubert.cn_hubert import load_cn_model, get_cn_hubert_units
from modules.hubert.hubert_model import hubert_soft, get_units
from modules.hubert.hubert_onnx import get_onnx_units
from utils.hparams import hparams
class HubertEncoder:
def __init__(self, pt_path='checkpoints/hubert/hubert_soft.pt', hubert_mode='', onnx=False):
self.hubert_mode = hubert_mode
self.onnx = onnx
if 'use_cn_hubert' not in hparams.keys():
hparams['use_cn_hubert'] = False
if hparams['use_cn_hubert'] or self.hubert_mode == 'cn_hubert':
pt_path = "checkpoints/cn_hubert/chinese-hubert-base-fairseq-ckpt.pt"
self.dev = torch.device("cuda")
self.hbt_model = load_cn_model(pt_path)
else:
if onnx:
self.hbt_model = ort.InferenceSession("onnx/hubert_soft.onnx",
providers=['CUDAExecutionProvider', 'CPUExecutionProvider', ])
else:
pt_path = list(Path(pt_path).parent.rglob('*.pt'))[0]
if 'hubert_gpu' in hparams.keys():
self.use_gpu = hparams['hubert_gpu']
else:
self.use_gpu = True
self.dev = torch.device("cuda" if self.use_gpu and torch.cuda.is_available() else "cpu")
self.hbt_model = hubert_soft(str(pt_path)).to(self.dev)
print(f"| load 'model' from '{pt_path}'")
def encode(self, wav_path):
if isinstance(wav_path, BytesIO):
npy_path = ""
wav_path.seek(0)
else:
npy_path = Path(wav_path).with_suffix('.npy')
if os.path.exists(npy_path):
units = np.load(str(npy_path))
elif self.onnx:
units = get_onnx_units(self.hbt_model, wav_path).squeeze(0)
elif hparams['use_cn_hubert'] or self.hubert_mode == 'cn_hubert':
units = get_cn_hubert_units(self.hbt_model, wav_path, self.dev).cpu().numpy()[0]
else:
units = get_units(self.hbt_model, wav_path, self.dev).cpu().numpy()[0]
return units # [T,256]