hubert-sd / model.py
osanseviero's picture
osanseviero HF staff
Add initial model
4fec958
raw history blame
No virus
1.43 kB
"""
Example of the model file a user would submit for inference.
"""
import os
from typing import Dict
import torch
from s3prl.downstream.runner import Runner
class PreTrainedModel(Runner):
    """Inference wrapper that rebuilds an s3prl ``Runner`` from ``hubert_sd.ckpt``.

    The ``Args`` and ``Config`` stored inside the checkpoint are reused to
    construct the runner. ``__call__`` then runs the upstream -> featurizer ->
    downstream pipeline on a single waveform.
    """

    def __init__(self, path="", device="cpu"):
        """
        Initialize downstream model.

        Args:
            path: Directory containing ``hubert_sd.ckpt``. The default ``""``
                resolves the file relative to the current working directory.
            device: Torch device string written into the checkpoint's
                ``Args.device`` before the runner is built (default ``"cpu"``,
                matching the previous hard-coded value). Inputs given to
                ``__call__`` are placed on this same device.
        """
        ckp_file = os.path.join(path, "hubert_sd.ckpt")
        # NOTE(review): torch.load unpickles arbitrary Python objects (the
        # checkpoint carries an ``Args`` object, not just tensors), so only
        # load trusted checkpoints. On torch versions where ``weights_only``
        # defaults to True this load may be rejected -- confirm torch version.
        ckp = torch.load(ckp_file, map_location="cpu")
        args = ckp["Args"]
        # Point the runner back at this same file, presumably so it restores
        # the trained weights from it -- Runner internals not visible here.
        args.init_ckpt = ckp_file
        args.mode = "inference"
        args.device = device
        super().__init__(args, ckp["Config"])
        self._device = device

    def __call__(self, inputs):
        """
        Run the upstream -> featurizer -> downstream pipeline on one waveform.

        Args:
            inputs: Raw audio waveform (array-like, e.g. ``np.ndarray``),
                by default sampled at 16 kHz. Presumably mono / 1-D --
                TODO confirm behaviour for multi-channel audio.

        Returns:
            The first element of ``self.downstream.model.inference(...)``,
            i.e. the prediction for this single waveform (described by the
            original author as a list with logits).
        """
        for entry in self.all_entries:
            entry.model.eval()
        # A batch of one float32 waveform, on the same device as the models.
        wavs = [torch.as_tensor(inputs, dtype=torch.float32, device=self._device)]
        with torch.no_grad():
            features = self.upstream.model(wavs)
            features = self.featurizer.model(wavs, features)
            # The second argument is passed as an empty list, exactly as in
            # the original example; its meaning is defined by the downstream.
            preds = self.downstream.model.inference(features, [])
        return preds[0]
"""
import io
import soundfile as sf
from urllib.request import urlopen
model = PreTrainedModel()
url = "https://huggingface.co/datasets/lewtun/s3prl-sd-dummy/raw/main/audio.wav"
data, samplerate = sf.read(io.BytesIO(urlopen(url).read()))
print(model(data))
"""