File size: 1,314 Bytes
9843a53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from s3prl.downstream.runner import Runner
from typing import Dict
import torch
import os


class PreTrainedModel(Runner):
    def __init__(self, path=""):
        """
        Initialize downstream model.
        """
        ckp_file = os.path.join(path, "model.ckpt")
        ckp = torch.load(ckp_file, map_location='cpu')
        ckp["Args"].init_ckpt = ckp_file
        ckp["Args"].mode = "inference"
        ckp["Args"].device = "cpu"
        ckp["Config"]["downstream_expert"]["datarc"]["dict_path"] = os.path.join(path,'char.dict')

        Runner.__init__(self, ckp["Args"], ckp["Config"])

    def __call__(self, inputs)-> Dict[str, str]:
        """
        Args:
            inputs (:obj:`np.array`):
                The raw waveform of audio received. By default at 16KHz.
        Return:
            A :obj:`dict`:. The object return should be liked {"text": "XXX"} containing
            the detected text from the input audio.
        """
        for entry in self.all_entries:
            entry.model.eval()

        inputs = [torch.FloatTensor(inputs)]

        with torch.no_grad():
            features = self.upstream.model(inputs)
            features = self.featurizer.model(inputs, features)
            preds = self.downstream.model.inference(features, [])
        return {"text": preds[0]}