File size: 2,265 Bytes
93f4bab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import os.path
from io import BytesIO
from pathlib import Path

import numpy as np
import onnxruntime as ort
import torch

from modules.hubert.cn_hubert import load_cn_model, get_cn_hubert_units
from modules.hubert.hubert_model import hubert_soft, get_units
from modules.hubert.hubert_onnx import get_onnx_units
from utils.hparams import hparams


class HubertEncoder:
    """Extract HuBERT unit features from audio, with an optional ``.npy`` cache.

    One of three backends is chosen once, at construction time:

    * Chinese HuBERT, when ``hparams['use_cn_hubert']`` is truthy or
      ``hubert_mode == 'cn_hubert'``; loaded via ``load_cn_model`` from a
      fixed checkpoint path;
    * ONNX ``hubert_soft``, when ``onnx=True``; run through onnxruntime;
    * PyTorch ``hubert_soft`` otherwise.

    The Chinese HuBERT backend takes priority over ``onnx``, and ``encode``
    follows the same priority, so the loaded model always matches the
    inference function it is handed to.
    """

    def __init__(self, pt_path='checkpoints/hubert/hubert_soft.pt', hubert_mode='', onnx=False):
        """Load the selected HuBERT backend.

        Args:
            pt_path: hubert_soft checkpoint for the PyTorch backend. If this
                file does not exist, the first ``*.pt`` (in sorted order) found
                recursively under its parent directory is used instead.
                Ignored by the cn_hubert and ONNX backends.
            hubert_mode: ``'cn_hubert'`` forces the Chinese HuBERT backend.
            onnx: use the ONNX hubert_soft model at ``onnx/hubert_soft.onnx``.

        Raises:
            FileNotFoundError: the PyTorch backend was selected and no ``.pt``
                checkpoint could be located.
        """
        self.hubert_mode = hubert_mode
        self.onnx = onnx
        # Keep publishing the default into the shared hparams, as before, in
        # case other code reads hparams['use_cn_hubert'] directly.
        if 'use_cn_hubert' not in hparams.keys():
            hparams['use_cn_hubert'] = False
        # Decide the backend once; encode() reuses this flag so a later change
        # to hparams cannot route audio to a model that was never loaded.
        self.use_cn_hubert = bool(hparams['use_cn_hubert'] or self.hubert_mode == 'cn_hubert')
        if self.use_cn_hubert:
            pt_path = "checkpoints/cn_hubert/chinese-hubert-base-fairseq-ckpt.pt"
            # Previously hard-coded to CUDA, which fails on CPU-only hosts and
            # ignored hparams['hubert_gpu']. Assumes load_cn_model returns a
            # torch.nn.Module, so .to() aligns the model with self.dev.
            self.use_gpu, self.dev = self._select_device()
            self.hbt_model = load_cn_model(pt_path).to(self.dev)
        elif onnx:
            # Rebind pt_path so the log line below names the file actually loaded.
            pt_path = "onnx/hubert_soft.onnx"
            self.hbt_model = ort.InferenceSession(pt_path,
                                                  providers=['CUDAExecutionProvider', 'CPUExecutionProvider', ])
        else:
            pt_path = self._resolve_pt_path(pt_path)
            self.use_gpu, self.dev = self._select_device()
            self.hbt_model = hubert_soft(str(pt_path)).to(self.dev)
        print(f"| load 'model' from '{pt_path}'")

    @staticmethod
    def _select_device():
        """Return ``(use_gpu, device)`` from ``hparams['hubert_gpu']`` (default True) and CUDA availability."""
        use_gpu = hparams['hubert_gpu'] if 'hubert_gpu' in hparams.keys() else True
        return use_gpu, torch.device("cuda" if use_gpu and torch.cuda.is_available() else "cpu")

    @staticmethod
    def _resolve_pt_path(pt_path):
        """Return ``pt_path`` if it is a file, else the first sorted ``*.pt`` under its parent directory."""
        pt_path = Path(pt_path)
        if pt_path.is_file():
            return pt_path
        # Sorted so the fallback choice does not depend on filesystem order.
        candidates = sorted(pt_path.parent.rglob('*.pt'))
        if not candidates:
            raise FileNotFoundError(f"no .pt checkpoint found for '{pt_path}' under '{pt_path.parent}'")
        return candidates[0]

    def encode(self, wav_path):
        """Return HuBERT units for one audio input.

        Args:
            wav_path: path to an audio file, or an in-memory ``BytesIO``
                (rewound to the start before use). For a path, if a sibling
                file with the same stem and a ``.npy`` suffix exists, it is
                loaded with ``np.load`` instead of running the model.

        Returns:
            A numpy array with the leading batch dimension removed. The
            original code annotates this as ``[T, 256]``; the cn_hubert
            backend's feature width may differ — confirm.
        """
        if isinstance(wav_path, BytesIO):
            npy_path = None  # in-memory audio has no cache file
            wav_path.seek(0)
        else:
            npy_path = Path(wav_path).with_suffix('.npy')
        if npy_path is not None and os.path.exists(npy_path):
            return np.load(str(npy_path))
        # Same priority as __init__: cn_hubert first, then ONNX, then PyTorch.
        if self.use_cn_hubert:
            return get_cn_hubert_units(self.hbt_model, wav_path, self.dev).cpu().numpy()[0]
        if self.onnx:
            return get_onnx_units(self.hbt_model, wav_path).squeeze(0)
        return get_units(self.hbt_model, wav_path, self.dev).cpu().numpy()[0]  # [T,256]