|
""" |
|
This is just an example of what people would submit for inference.
|
""" |
|
import os |
|
from typing import Dict, List |
|
|
|
import torch |
|
from s3prl.downstream.runner import Runner |
|
|
|
class PreTrainedModel(Runner):
    """Inference-only ``Runner`` rebuilt from a saved ``hubert_sd.ckpt`` checkpoint."""

    def __init__(self, path=""):
        """
        Load the checkpoint and initialize the underlying ``Runner`` on CPU.

        Args:
            path: Directory that contains ``hubert_sd.ckpt``. With the default
                empty string the file name is used as-is.
        """
        checkpoint_path = os.path.join(path, "hubert_sd.ckpt")
        # NOTE(review): torch.load unpickles the file, so only point this at
        # checkpoints from a trusted source.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")

        # The stored arguments are patched so the runner reloads this same
        # checkpoint, in inference mode, on the CPU.
        runner_args = checkpoint["Args"]
        runner_args.init_ckpt = checkpoint_path
        runner_args.mode = "inference"
        runner_args.device = "cpu"

        Runner.__init__(self, runner_args, checkpoint["Config"])

    def __call__(self, inputs) -> List[int]:
        """
        Run the upstream -> featurizer -> downstream pipeline on one waveform.

        Args:
            inputs (:obj:`np.array`): The raw waveform of audio received. By
                default at 16KHz.

        Return:
            The first element of the downstream model's ``inference`` result
            (described by the original author as a list with logits).
        """
        # Put every registered model into evaluation mode before predicting.
        for component in self.all_entries:
            component.model.eval()

        # The models are called with a list of waveforms; wrap the single
        # input as a one-element list.
        wavs = [torch.FloatTensor(inputs)]

        with torch.no_grad():
            hidden = self.upstream.model(wavs)
            features = self.featurizer.model(wavs, hidden)
            predictions = self.downstream.model.inference(features, [])
        return predictions[0]
|
|
|
""" |
|
import io
import soundfile as sf
from urllib.request import urlopen

model = PreTrainedModel()
url = "https://huggingface.co/datasets/lewtun/s3prl-sd-dummy/raw/main/audio.wav"
data, samplerate = sf.read(io.BytesIO(urlopen(url).read()))
print(model(data))
|
""" |