File size: 1,440 Bytes
4fec958
 
 
 
af7cc76
4fec958
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
af7cc76
4fec958
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
"""
This is just an example of what people would submit for inference.
"""
import os
from typing import Dict, List

import torch
from s3prl.downstream.runner import Runner

class PreTrainedModel(Runner):
    """Inference wrapper that configures a ``Runner`` from a saved checkpoint.

    The checkpoint file ``hubert_sd.ckpt`` must provide an ``"Args"`` entry and
    a ``"Config"`` entry; both are handed to ``Runner.__init__``.
    """

    def __init__(self, path=""):
        """
        Load the checkpoint located under ``path`` and initialize the runner.

        Args: path (:obj:`str`): Directory that contains ``hubert_sd.ckpt``.
            With the default ``""`` the file name is used as a relative path.
        """
        checkpoint_path = os.path.join(path, "hubert_sd.ckpt")
        # NOTE(review): torch.load unpickles the file, so only checkpoints from
        # a trusted source should be loaded here.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")

        # Override the stored arguments so the runner restores from this very
        # checkpoint and is configured for inference.
        runner_args = checkpoint["Args"]
        runner_args.init_ckpt = checkpoint_path
        runner_args.mode = "inference"
        # Hard-coded to CPU; the original author set this for local testing.
        runner_args.device = "cpu"

        super().__init__(runner_args, checkpoint["Config"])

    def __call__(self, inputs) -> List[int]:
        """
        Run the upstream, featurizer and downstream models on one waveform.

        Args: inputs (:obj:`np.array`): The raw waveform of audio received. By
            default at 16KHz.
        Return: The first element of the downstream model's inference output
            (described by the original author as a list with logits).
        """
        # Switch every component to evaluation mode before running it.
        for component in self.all_entries:
            component.model.eval()

        # The models are called with a list of tensors, so the single waveform
        # is wrapped as a one-item list.
        wavs = [torch.FloatTensor(inputs)]

        with torch.no_grad():
            upstream_features = self.upstream.model(wavs)
            features = self.featurizer.model(wavs, upstream_features)
            # The second argument of inference() is passed as an empty list.
            predictions = self.downstream.model.inference(features, [])

        return predictions[0]

"""
import io
import soundfile as sf
from urllib.request import urlopen
model = PreTrainedModel()
url = "https://huggingface.co/datasets/lewtun/s3prl-sd-dummy/raw/main/audio.wav"
data, samplerate = sf.read(io.BytesIO(urlopen(url).read()))
print(model(data))
"""