mygyasir's picture
Duplicate from konverner/deep-voice-cloning
13c43fe
raw
history blame contribute delete
No virus
1.02 kB
import os
import json
import numpy as np
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
class TranscriberModel:
    """Speech-to-text wrapper around a pretrained Wav2Vec2 CTC model.

    The checkpoint name is looked up per language in the ``config.json``
    file located next to this module, under ``language_model_names``.
    """

    def __init__(self, lang: str = 'en'):
        """Load the processor and the model configured for *lang*.

        Args:
            lang: Key into ``config['language_model_names']``.

        Raises:
            KeyError: If ``language_model_names`` has no entry for *lang*.
        """
        config_path = os.path.join(os.path.dirname(__file__), 'config.json')
        with open(config_path, encoding='utf-8') as f:
            config = json.load(f)

        model_names = config['language_model_names']
        if lang not in model_names:
            # Same exception type as a plain lookup, but with a usable message.
            raise KeyError(
                f"no transcription model configured for language {lang!r}; "
                f"available: {sorted(model_names)}"
            )
        model_name = model_names[lang]

        self.processor = Wav2Vec2Processor.from_pretrained(model_name)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Place the weights on the selected device; previously the device was
        # computed but never applied, so inference always ran on CPU.
        self.model = Wav2Vec2ForCTC.from_pretrained(model_name).to(self.device)

    def forward(self, speech_array: np.ndarray, sampling_rate: int = 16000) -> list:
        """Transcribe audio with greedy (argmax) CTC decoding.

        Args:
            speech_array: Raw audio samples handed to the processor as-is.
            sampling_rate: Sampling rate of *speech_array*, forwarded to the
                processor.

        Returns:
            The result of ``processor.batch_decode`` on the predicted ids
            (a list of transcriptions, one per batch item).
        """
        model_input = self.processor(
            speech_array,
            sampling_rate=sampling_rate,
            return_tensors="pt",
            padding=True,
        ).to(self.device)  # inputs must live on the same device as the model

        with torch.no_grad():
            logits = self.model(
                model_input["input_values"],
                # Not every feature extractor emits an attention mask; fall
                # back to None instead of failing on the missing attribute.
                attention_mask=model_input.get("attention_mask"),
            ).logits

        predicted_ids = torch.argmax(logits, dim=-1)
        return self.processor.batch_decode(predicted_ids)