lewtun's picture
lewtun HF staff
commit files to HF hub
7cb5106
raw
history blame
1.31 kB
from s3prl.downstream.runner import Runner
from typing import Dict
import torch
import os
class PreTrainedModel(Runner):
def __init__(self, path=""):
"""
Initialize downstream model.
"""
ckp_file = os.path.join(path, "model.ckpt")
ckp = torch.load(ckp_file, map_location='cpu')
ckp["Args"].init_ckpt = ckp_file
ckp["Args"].mode = "inference"
ckp["Args"].device = "cpu"
ckp["Config"]["downstream_expert"]["datarc"]["dict_path"] = os.path.join(path,'char.dict')
Runner.__init__(self, ckp["Args"], ckp["Config"])
def __call__(self, inputs)-> Dict[str, str]:
"""
Args:
inputs (:obj:`np.array`):
The raw waveform of audio received. By default at 16KHz.
Return:
A :obj:`dict`:. The object return should be liked {"text": "XXX"} containing
the detected text from the input audio.
"""
for entry in self.all_entries:
entry.model.eval()
inputs = [torch.FloatTensor(inputs)]
with torch.no_grad():
features = self.upstream.model(inputs)
features = self.featurizer.model(inputs, features)
preds = self.downstream.model.inference(features, [])
return {"text": preds[0]}