hubert_s3prl / model.py
"""
This is just an example of what people would submit for
inference.
"""

from s3prl.downstream.runner import Runner
from typing import Dict
import torch
import os


class PreTrainedModel(Runner):
    def __init__(self, path=""):
        """
        Initialize the downstream model from the checkpoint directory ``path``.
        """
        # The directory is expected to contain the fine-tuned downstream
        # checkpoint (hubert_asr.ckpt) and the character dictionary (char.dict).
        ckp_file = os.path.join(path, "hubert_asr.ckpt")
        ckp = torch.load(ckp_file, map_location="cpu")
        ckp["Args"].init_ckpt = ckp_file
        ckp["Args"].mode = "inference"
        ckp["Args"].device = "cpu"  # Just to try on my computer
        ckp["Config"]["downstream_expert"]["datarc"]["dict_path"] = os.path.join(path, "char.dict")

        Runner.__init__(self, ckp["Args"], ckp["Config"])

    def __call__(self, inputs) -> Dict[str, str]:
        """
        Args:
            inputs (:obj:`np.array`):
                The raw waveform of the received audio, sampled at 16kHz by default.
        Return:
            A :obj:`dict` like {"text": "XXX"} containing the text detected
            in the input audio.
        """
        # Put every wrapped module (upstream, featurizer, downstream) in eval mode.
        for entry in self.all_entries:
            entry.model.eval()

        inputs = [torch.FloatTensor(inputs)]

        with torch.no_grad():
            # Extract upstream (HuBERT) features, weight them with the
            # featurizer, then decode text with the downstream ASR expert.
            features = self.upstream.model(inputs)
            features = self.featurizer.model(inputs, features)
            preds = self.downstream.model.inference(features, [])
        return {"text": preds[0]}
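

# A minimal local sanity check, kept as a comment so this module stays
# import-only for the Inference API. It assumes `soundfile` is installed and
# that "sample.wav" is a hypothetical 16kHz mono recording; swap in any real
# 16kHz audio file.
"""
import soundfile as sf

model = PreTrainedModel()
waveform, sample_rate = sf.read("sample.wav", dtype="float32")
assert sample_rate == 16000  # __call__ expects 16kHz audio
print(model(waveform))
"""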


"""
48
import subprocess
49
import numpy as np
50
from datasets import load_dataset
51
# This is already done in the Inference API
52
def ffmpeg_read(bpayload: bytes, sampling_rate: int) -> np.array:
53
    ar = f"{sampling_rate}"
54
    ac = "1"
55
    format_for_conversion = "f32le"
56
    ffmpeg_command = [
57
        "ffmpeg",
58
        "-i",
59
        "pipe:0",
60
        "-ac",
61
        ac,
62
        "-ar",
63
        ar,
64
        "-f",
65
        format_for_conversion,
66
        "-hide_banner",
67
        "-loglevel",
68
        "quiet",
69
        "pipe:1",
70
    ]
71
72
    ffmpeg_process = subprocess.Popen(
73
        ffmpeg_command, stdin=subprocess.PIPE, stdout=subprocess.PIPE
74
    )
75
    output_stream = ffmpeg_process.communicate(bpayload)
76
    out_bytes = output_stream[0]
77
78
    audio = np.frombuffer(out_bytes, np.float32).copy()
79
    if audio.shape[0] == 0:
80
        raise ValueError("Malformed soundfile")
81
    return audio
82
83
84
model = PreTrainedModel()
85
ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
86
filename = ds[0]["file"]
87
with open(filename, "rb") as f:
88
    data = ffmpeg_read(f.read(), 16000)
89
    print(model(data)) 
90
"""