from typing import List, Tuple

import numpy as np

# The calls below (BaseModel.from_pretrained, separate.numpy_separate) assume the
# asteroid source-separation library; adjust these imports if another backend is used.
from asteroid import separate
from asteroid.models import BaseModel


class PreTrainedPipeline:
    def __init__(self, path: str = ""):
        # Preload everything needed at inference time (model, processors, tokenizer, ...).
        # This is only called once, so do all the heavy I/O here.
        self.model = BaseModel.from_pretrained(path)
        self.sampling_rate = self.model.sample_rate

    def __call__(self, inputs: np.ndarray) -> Tuple[np.ndarray, int, List[str]]:
"""
Args:
inputs (:obj:`np.array`):
The raw waveform of audio received. By default sampled at `self.sampling_rate`.
The shape of this array is `T`, where `T` is the time axis
Return:
A :obj:`tuple` containing:
- :obj:`np.array`:
The return shape of the array must be `C'`x`T'`
- a :obj:`int`: the sampling rate as an int in Hz.
- a :obj:`List[str]`: the annotation for each out channel.
This can be the name of the instruments for audio source separation
or some annotation for speech enhancement. The length must be `C'`.
"""
        # numpy_separate expects a (batch, channels, time) array; reshape the 1-D waveform.
        separated = separate.numpy_separate(self.model, inputs.reshape((1, 1, -1)))
        # FIXME: how to deal with multiple sources?
        out = separated[0]  # (n_src, time): one output channel per separated source
        n = out.shape[0]
        labels = [f"label_{i}" for i in range(n)]
        return out, int(self.model.sample_rate), labels
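

# Minimal local usage sketch (an illustration, not part of the Inference API contract):
# it builds the pipeline from a model directory and runs a short dummy waveform through
# __call__ to show the (array, sampling rate, labels) return tuple described in the
# docstring. The model_path value is a hypothetical placeholder.
if __name__ == "__main__":
    model_path = ""  # hypothetical: point this at a pretrained model repo or local directory
    pipeline = PreTrainedPipeline(model_path)

    # One second of silence at the model's sampling rate, shaped (T,) as the docstring expects.
    dummy_waveform = np.zeros(int(pipeline.sampling_rate), dtype=np.float32)

    sources, rate, labels = pipeline(dummy_waveform)
    print(f"separated shape: {sources.shape}, rate: {rate} Hz, labels: {labels}")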