"""
This is just an example of what people would submit for inference.
"""
import os
from typing import Dict

import torch
from s3prl.downstream.runner import Runner


class PreTrainedModel(Runner):
    def __init__(self, path=""):
        """
        Initialize the downstream model from the bundled checkpoint.
        """
        ckp_file = os.path.join(path, "hubert_sd.ckpt")
        ckp = torch.load(ckp_file, map_location="cpu")
        ckp["Args"].init_ckpt = ckp_file
        ckp["Args"].mode = "inference"
        ckp["Args"].device = "cpu"  # Force CPU for local testing
        Runner.__init__(self, ckp["Args"], ckp["Config"])
    def __call__(self, inputs) -> Dict[str, str]:
        """
        Args:
            inputs (:obj:`np.array`):
                The raw waveform of the received audio, sampled at 16 kHz by
                default.
        Returns:
            A :obj:`dict` of the form {"frames": "XXX"}, indicating the frames
            where one, both, or neither of the speakers is speaking.
        """
        # Put every module (upstream, featurizer, downstream) in eval mode.
        for entry in self.all_entries:
            entry.model.eval()

        inputs = [torch.FloatTensor(inputs)]

        with torch.no_grad():
            # Extract upstream representations, weight them with the
            # featurizer, then run the downstream diarization head.
            features = self.upstream.model(inputs)
            features = self.featurizer.model(inputs, features)
            preds = self.downstream.model.inference(features, [])
        return preds[0]
"""
import io
import soundfile as sf
from urllib.request import urlopen
model = PreTrainedModel()
url = "https://huggingface.co/datasets/lewtun/s3prl-sd-dummy/raw/main/audio.wav"
data, samplerate = sf.read(io.BytesIO(urlopen(url).read()))
print(model(data))
"""