Spaces:
Sleeping
Sleeping
File size: 857 Bytes
960cd20 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 |
import torch
import logging
logging.getLogger("numba").setLevel(logging.WARNING)
from transformers import (
Wav2Vec2FeatureExtractor,
HubertModel,
)
import torch.nn as nn
class CNHubert(nn.Module):
    """Bundle a pretrained ``HubertModel`` with its ``Wav2Vec2FeatureExtractor``.

    Both components are loaded from the same ``cnhubert_base_path`` and are
    exposed as ``self.model`` and ``self.feature_extractor``.
    """

    def __init__(self, cnhubert_base_path):
        """Load the model and the feature extractor from ``cnhubert_base_path``."""
        super().__init__()
        self.model = HubertModel.from_pretrained(cnhubert_base_path)
        self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
            cnhubert_base_path
        )

    def forward(self, x):
        """Return the model's ``last_hidden_state`` for the waveform ``x``.

        ``x`` is first run through the feature extractor (declared sampling
        rate 16000, PyTorch tensors requested); the resulting
        ``input_values`` are moved to ``x.device`` before being fed to the
        model.

        NOTE(review): ``x`` must expose ``.device``, so it is presumably a
        torch tensor holding 16 kHz audio -- confirm the expected shape
        against the callers.
        """
        extracted = self.feature_extractor(
            x, return_tensors="pt", sampling_rate=16000
        )
        # The extractor's output is placed on the same device as the input
        # so it matches wherever the caller put ``x``.
        input_values = extracted.input_values.to(x.device)
        outputs = self.model(input_values)
        return outputs["last_hidden_state"]
def get_content(hmodel, wav_16k_tensor):
    """Run ``hmodel`` on ``wav_16k_tensor`` without gradients and swap dims 1 and 2.

    ``hmodel`` is called inside ``torch.no_grad()``; its output then has
    dimensions 1 and 2 exchanged before being returned.

    NOTE(review): presumably this turns (batch, frames, features) into
    (batch, features, frames) -- confirm against the model's output layout.
    """
    with torch.no_grad():
        hidden = hmodel(wav_16k_tensor)
    return torch.transpose(hidden, 1, 2)
|