lewtun's picture
lewtun HF staff
commit files to HF hub
7cb5106
raw history blame
No virus
1.31 kB
from s3prl.downstream.runner import Runner
from typing import Dict
import torch
import os
class PreTrainedModel(Runner):
    """Inference wrapper that rebuilds an s3prl ``Runner`` from a saved checkpoint."""

    def __init__(self, path=""):
        """
        Restore the runner from the files found under ``path``.

        Args:
            path (:obj:`str`):
                Directory containing ``model.ckpt`` and ``char.dict``.
                Defaults to the empty string, i.e. paths relative to the
                current working directory.
        """
        checkpoint_path = os.path.join(path, "model.ckpt")
        # NOTE(review): torch.load unpickles the file, so this should only be
        # pointed at checkpoints from a trusted source.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')

        # Override the stored arguments so the runner starts from this
        # checkpoint, in inference mode, on CPU.
        args = checkpoint["Args"]
        args.init_ckpt = checkpoint_path
        args.mode = "inference"
        args.device = "cpu"

        # Redirect the dictionary path stored in the config to the copy that
        # sits next to the checkpoint.
        config = checkpoint["Config"]
        config["downstream_expert"]["datarc"]["dict_path"] = os.path.join(path, 'char.dict')

        Runner.__init__(self, args, config)

    def __call__(self, inputs) -> Dict[str, str]:
        """
        Transcribe a single audio waveform.

        Args:
            inputs (:obj:`np.array`):
                The raw waveform of the received audio. By default at 16KHz.
        Return:
            A :obj:`dict` of the form {"text": "XXX"} holding the text
            detected in the input audio.
        """
        # Put every registered entry's model into evaluation mode before
        # running the forward passes.
        for component in self.all_entries:
            component.model.eval()

        # The models are fed a list of waveforms; wrap the single input as a
        # one-element list.
        waveforms = [torch.FloatTensor(inputs)]

        with torch.no_grad():
            upstream_output = self.upstream.model(waveforms)
            featurized = self.featurizer.model(waveforms, upstream_output)
            # The second argument is passed as an empty list, as in the
            # original call; its meaning is defined by the downstream expert.
            transcripts = self.downstream.model.inference(featurized, [])

        # Only one waveform was submitted, so only the first result is returned.
        return {"text": transcripts[0]}