|
import sys |
|
|
|
import torch |
|
from transformers import ( |
|
AutoModelForMaskedLM, |
|
AutoTokenizer, |
|
DebertaV2Model, |
|
DebertaV2Tokenizer, |
|
ClapModel, |
|
ClapProcessor, |
|
) |
|
|
|
from config import config |
|
from text.japanese import text2sep_kata |
|
|
|
|
|
class BertFeature:
    """Extract phone-level BERT hidden-state features for TTS conditioning.

    Loads a tokenizer and a masked-LM checkpoint (DeBERTa-v2 for English)
    once at construction time and keeps the model resident on the selected
    device in eval mode.
    """

    def __init__(self, model_path, language="ZH"):
        """Store configuration and eagerly load the model.

        :param model_path: path or hub id of the pretrained checkpoint
        :param language: "ZH" (default), "EN" (DeBERTa-v2), or "JP"
        """
        self.model_path = model_path
        self.language = language
        self.tokenizer = None
        self.model = None
        self.device = None

        self._prepare()

    def _get_device(self, device=config.bert_gen_config.device):
        # NOTE(review): the default is captured once, at class-definition
        # time, from config.bert_gen_config.device — later config changes
        # are not picked up.
        # An explicit "cpu" request on macOS is upgraded to "mps" when
        # Metal acceleration is available.
        if (
            sys.platform == "darwin"
            and torch.backends.mps.is_available()
            and device == "cpu"
        ):
            device = "mps"
        # Empty/None device falls back to CUDA.
        if not device:
            device = "cuda"
        return device

    def _prepare(self):
        """Load tokenizer and model onto the target device and set eval mode."""
        self.device = self._get_device()

        if self.language == "EN":
            self.tokenizer = DebertaV2Tokenizer.from_pretrained(self.model_path)
            self.model = DebertaV2Model.from_pretrained(self.model_path).to(self.device)
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
            self.model = AutoModelForMaskedLM.from_pretrained(self.model_path).to(
                self.device
            )
        self.model.eval()

    def get_bert_feature(self, text, word2ph):
        """Return a (hidden_dim, n_phones) tensor of per-phone BERT features.

        Each token-level hidden state is repeated ``word2ph[i]`` times so
        that all phones belonging to one token share a feature vector.
        Assumes ``len(word2ph)`` matches the tokenized sequence length
        (including special tokens) — not validated here; a shorter `res`
        would raise an IndexError.
        """
        if self.language == "JP":
            # Japanese text is normalized to katakana before tokenization.
            text = "".join(text2sep_kata(text)[0])
        with torch.no_grad():
            encoded = self.tokenizer(text, return_tensors="pt")
            # Move every input tensor to the model's device without
            # mutating the tokenizer's BatchEncoding in place.
            inputs = {name: tensor.to(self.device) for name, tensor in encoded.items()}
            res = self.model(**inputs, output_hidden_states=True)
            # [-3:-2] selects exactly one hidden layer (third from last);
            # torch.cat over the single-element slice is shape-preserving
            # and kept for consistency with sibling feature extractors.
            res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()

        # Repeat each token feature once per phone it covers, then stack
        # along the sequence axis.
        phone_level_feature = torch.cat(
            [res[i].repeat(n_phones, 1) for i, n_phones in enumerate(word2ph)],
            dim=0,
        )

        return phone_level_feature.T
|
|
|
|
|
class ClapFeature:
    """Thin wrapper around a pretrained CLAP checkpoint.

    Exposes audio and text embedding extraction; the processor and model
    are loaded once at construction and kept on the selected device in
    eval mode.
    """

    def __init__(self, model_path):
        """Record the checkpoint location and load the model immediately."""
        self.model_path = model_path
        self.processor = None
        self.model = None
        self.device = None

        self._prepare()

    def _get_device(self, device=config.bert_gen_config.device):
        # Upgrade an explicit "cpu" request to "mps" on macOS when Metal
        # acceleration is available; an empty device falls back to CUDA.
        prefer_mps = (
            sys.platform == "darwin"
            and torch.backends.mps.is_available()
            and device == "cpu"
        )
        if prefer_mps:
            return "mps"
        return device if device else "cuda"

    def _prepare(self):
        """Instantiate the processor and model, then freeze for inference."""
        self.device = self._get_device()

        self.processor = ClapProcessor.from_pretrained(self.model_path)
        self.model = ClapModel.from_pretrained(self.model_path).to(self.device)
        self.model.eval()

    def get_clap_audio_feature(self, audio_data):
        """Embed a 48 kHz audio signal and return the transposed embedding."""
        with torch.no_grad():
            batch = self.processor(
                audios=audio_data, return_tensors="pt", sampling_rate=48000
            ).to(self.device)
            embedding = self.model.get_audio_features(**batch)
        return embedding.T

    def get_clap_text_feature(self, text):
        """Embed a text prompt and return the transposed embedding."""
        with torch.no_grad():
            batch = self.processor(text=text, return_tensors="pt").to(self.device)
            embedding = self.model.get_text_features(**batch)
        return embedding.T
|
|