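"""Speech-recognition pipeline: a wav2vec2 CTC acoustic model whose logits are
decoded with a pyctcdecode beam search boosted by a KenLM language model."""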
from typing import Dict

import numpy as np
import pyctcdecode
import torch
from transformers import (
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
    Wav2Vec2ProcessorWithLM,
)


class PreTrainedPipeline:
    """Transcribes raw waveforms with a wav2vec2 CTC model, rescoring the
    beam search with a KenLM language model."""

    def __init__(self, model_path: str, language_model_fp: str):
        self.language_model_fp = language_model_fp

        # Prefer GPU when available; fall back to CPU.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = Wav2Vec2ForCTC.from_pretrained(model_path)
        self.model.to(self.device)

        processor = Wav2Vec2Processor.from_pretrained(model_path)
        self.sampling_rate = processor.feature_extractor.sampling_rate

        # Sort tokens by vocabulary index so the decoder's labels line up
        # with the columns of the model's logit matrix.
        vocab = processor.tokenizer.get_vocab()
        sorted_vocab = sorted(vocab.items(), key=lambda item: item[1])

        # Beam-search CTC decoder with KenLM language-model fusion.
        self.decoder = pyctcdecode.build_ctcdecoder(
            labels=[token for token, _ in sorted_vocab],
            kenlm_model_path=self.language_model_fp,
        )

        # Bundle the feature extractor, tokenizer, and decoder so one object
        # handles both featurization and LM-boosted decoding.
        self.processor_with_lm = Wav2Vec2ProcessorWithLM(
            feature_extractor=processor.feature_extractor,
            tokenizer=processor.tokenizer,
            decoder=self.decoder,
        )

    def __call__(self, inputs: np.ndarray) -> Dict[str, str]:
        """
        Args:
            inputs (:obj:`np.ndarray`):
                The raw waveform of the received audio, sampled at 16 kHz by default.
        Return:
            A :obj:`dict` like ``{"text": "XXX"}`` containing the text
            detected in the input audio.
        """

        # Featurize (normalize and batch) the raw waveform.
        input_values = self.processor_with_lm(
            inputs, return_tensors="pt",
            sampling_rate=self.sampling_rate
        )['input_values']

        input_values = input_values.to(self.device)

        with torch.no_grad():
            # input_values is a 2D tensor of shape (batch, sequence_length)
            # by now; each audio channel ends up as one batch item.
            model_outs = self.model(input_values)
        logits = model_outs.logits.cpu().numpy()

        # batch_decode returns one transcription per batch item; this pipeline
        # decodes a single utterance, so take the first entry to return a string.
        text_predicted = self.processor_with_lm.batch_decode(logits)['text'][0]

        return {
            "text": text_predicted
        }
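

# A minimal usage sketch (assumptions: the checkpoint and KenLM paths below
# are hypothetical placeholders; substitute a fine-tuned wav2vec2 model
# directory and a .arpa/.bin language-model file of your own).
if __name__ == "__main__":
    pipeline = PreTrainedPipeline(
        model_path="path/to/wav2vec2-checkpoint",  # hypothetical path
        language_model_fp="path/to/5gram.bin",     # hypothetical path
    )
    # One second of silence at 16 kHz stands in for real audio here; in
    # practice, load a mono waveform, e.g. with soundfile.read().
    waveform = np.zeros(16000, dtype=np.float32)
    print(pipeline(waveform))  # -> {"text": "..."}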