# pyctcdecode_asr / model.py
# Uploaded by osanseviero (HF staff) in commit 4445395 ("Add model").
# File size: 1.29 kB. Hub security scan: no virus detected.
from typing import Dict

import numpy as np
import torch
from pyctcdecode import Alphabet, BeamSearchDecoderCTC
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

class PreTrainedModel():
    """Wav2Vec2 CTC acoustic model with pyctcdecode beam-search decoding.

    The decoder is built without a language model: only the alphabet
    derived from the tokenizer vocabulary is passed to it.
    """

    # Sampling rate the raw waveform is expected to have.
    SAMPLING_RATE = 16000

    def __init__(self, path):
        """
        Initialize model.

        Args:
            path: Hub id or local directory accepted by
                ``Wav2Vec2Processor.from_pretrained`` /
                ``Wav2Vec2ForCTC.from_pretrained``.
        """
        # Keep both objects on the instance; __call__ needs them for inference.
        self.processor = Wav2Vec2Processor.from_pretrained(path)
        self.model = Wav2Vec2ForCTC.from_pretrained(path)
        vocab_list = self._build_vocab_list(self.processor.tokenizer.get_vocab())
        alphabet = Alphabet.build_alphabet(vocab_list, ctc_token_idx=0)
        self.decoder = BeamSearchDecoderCTC(alphabet)

    @staticmethod
    def _build_vocab_list(vocab: Dict[str, int]) -> list:
        """Turn a token->id mapping into decoder labels indexed by token id.

        Index 0 becomes the CTC blank (""), indices 1-3 are special tokens
        rendered as "⁇", and index 4 becomes the space character. This
        assumes the usual wav2vec2 CTC vocab layout (e.g. <pad>, <s>, </s>,
        <unk>, "|") -- TODO confirm for custom vocabularies.

        Args:
            vocab: mapping from token string to token id.

        Returns:
            List of label strings whose position i is the label for id i.

        Raises:
            ValueError: if the vocab has fewer than the 5 entries that the
                fixed substitutions below require.
        """
        if len(vocab) < 5:
            raise ValueError(
                f"expected at least 5 vocab entries, got {len(vocab)}"
            )
        # get_vocab() is a token->id mapping whose iteration order is not
        # guaranteed to follow the ids, but label i must match logit column i.
        labels = [token for token, _ in sorted(vocab.items(), key=lambda kv: kv[1])]
        # convert ctc blank character representation
        labels[0] = ""
        # replace special characters
        labels[1] = "⁇"
        labels[2] = "⁇"
        labels[3] = "⁇"
        # convert space character representation
        labels[4] = " "
        return labels

    def __call__(self, inputs) -> Dict[str, str]:
        """
        Args:
            inputs (:obj:`np.array`):
                The raw waveform of audio received, sampled at 16 kHz.
        Return:
            A :obj:`dict`: of the form {"text": "XXX"} containing the text
            detected in the input audio.
        """
        input_values = self.processor(
            inputs, sampling_rate=self.SAMPLING_RATE, return_tensors="pt"
        ).input_values
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            logits = self.model(input_values).logits
        # A single utterance is processed, so drop the batch axis.
        # pyctcdecode expects a 2-D (time, vocab) numpy array.
        frame_logits = logits[0].cpu().numpy()
        return {
            "text": self.decoder.decode(frame_logits)
        }