File size: 1,017 Bytes
13c43fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import os
import json

import numpy as np
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC


class TranscriberModel:
    """Speech-to-text wrapper around a pretrained Wav2Vec2 CTC model.

    The checkpoint name is read from ``config.json`` located next to this
    module, under ``config['language_model_names'][lang]``.
    """

    def __init__(self, lang: str = 'en'):
        """Load the processor and model configured for *lang*.

        Args:
            lang: Key into the ``language_model_names`` mapping of
                ``config.json``.

        Raises:
            OSError: If ``config.json`` cannot be opened.
            KeyError: If ``language_model_names`` or *lang* is missing from
                the config.
        """
        config_path = os.path.join(os.path.dirname(__file__), 'config.json')
        with open(config_path, encoding='utf-8') as f:
            config = json.load(f)
        model_name = config['language_model_names'][lang]

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.processor = Wav2Vec2Processor.from_pretrained(model_name)
        self.model = Wav2Vec2ForCTC.from_pretrained(model_name)
        # Actually place the weights on the selected device; previously the
        # device was computed but never used, so inference always ran on CPU.
        self.model.to(self.device)

    def forward(self, speech_array: np.ndarray, sampling_rate: int = 16000) -> list:
        """Transcribe audio and return the decoded text.

        Args:
            speech_array: Audio samples handed to the processor as-is.
            sampling_rate: Sampling rate of *speech_array*, forwarded to the
                processor.

        Returns:
            The list produced by ``processor.batch_decode`` for the
            per-step argmax token ids (one entry per batch item).
        """
        model_input = self.processor(
            speech_array,
            sampling_rate=sampling_rate,
            return_tensors="pt",
            padding=True,
        )
        input_values = model_input.input_values.to(self.device)
        # Not every processor configuration emits an attention mask; fall
        # back to None (the model's default) instead of raising.
        attention_mask = getattr(model_input, "attention_mask", None)
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.device)

        with torch.no_grad():
            logits = self.model(input_values, attention_mask=attention_mask).logits
            # Pick the highest-scoring token id at each step.
            predicted_ids = torch.argmax(logits, dim=-1)
        return self.processor.batch_decode(predicted_ids)