# Web-page header captured together with the file (not part of the module):
# hubert-sd / model.py
# Narsil's picture
# Narsil HF staff
# Update model.py
# af7cc76
# raw history blame
# No virus
# 1.44 kB
"""
This is just an example of what people would submit for inference.
"""
import os
from typing import Dict, List
import torch
from s3prl.downstream.runner import Runner
class PreTrainedModel(Runner):
    """Inference wrapper that builds an s3prl ``Runner`` from a saved checkpoint."""

    def __init__(self, path=""):
        """
        Restore the runner from the ``hubert_sd.ckpt`` file located in `path`.

        The checkpoint is read as a mapping with an "Args" and a "Config"
        entry. The stored args are patched in place (checkpoint location,
        inference mode, CPU device) before being passed to ``Runner``.
        """
        checkpoint_path = os.path.join(path, "hubert_sd.ckpt")
        # NOTE(review): torch.load unpickles arbitrary objects; only point
        # this at checkpoints from a trusted source.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")

        runner_args = checkpoint["Args"]
        runner_args.init_ckpt = checkpoint_path
        runner_args.mode = "inference"
        # Device is pinned to CPU (original author's note: just to try it
        # on their own computer).
        runner_args.device = "cpu"

        super().__init__(runner_args, checkpoint["Config"])

    def __call__(self, inputs) -> List[int]:
        """
        Run upstream -> featurizer -> downstream inference on one waveform.

        Args:
            inputs (:obj:`np.array`): The raw waveform of audio received. By
                default at 16KHz.

        Return:
            The first element of the downstream model's ``inference`` result
            (described by the original author as a list with logits).
        """
        # Put every registered model into evaluation mode before predicting.
        for entry in self.all_entries:
            entry.model.eval()

        # The models are fed a list of waveforms; wrap the single input.
        wavs = [torch.FloatTensor(inputs)]

        with torch.no_grad():
            upstream_out = self.upstream.model(wavs)
            features = self.featurizer.model(wavs, upstream_out)
            # Second argument is passed as an empty list; presumably
            # per-utterance metadata that is not needed here -- TODO confirm
            # against the s3prl downstream expert.
            predictions = self.downstream.model.inference(features, [])

        return predictions[0]
"""
import io
import soundfile as sf
from urllib.request import urlopen
model = PreTrainedModel()
url = "https://huggingface.co/datasets/lewtun/s3prl-sd-dummy/raw/main/audio.wav"
data, samplerate = sf.read(io.BytesIO(urlopen(url).read()))
print(model(data))
"""