File size: 527 Bytes
5ad45ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
from typing import Dict, List, Tuple
import numpy as np
import os
import torch

class PreTrainedPipeline:
    """Audio-to-audio pipeline wrapping the Open-Unmix ``umxhq`` separator.

    The model is fetched with ``torch.hub`` from
    ``sigsep/open-unmix-pytorch``. Calling the pipeline returns the first
    separated target of the input mixture, the model's sampling rate, and
    one label per row of the returned audio.
    """

    def __init__(self, path: str = ''):
        """Load the pretrained separator.

        Args:
            path: Accepted for interface compatibility with the pipeline
                loader; it is not used -- weights come from torch.hub.
        """
        self.model = torch.hub.load('sigsep/open-unmix-pytorch', 'umxhq')
        # Inference only: force evaluation mode regardless of how the hub
        # entry point returned the module.
        self.model.eval()
        self.sampling_rate = int(self.model.sample_rate.item())

    def __call__(self, inputs) -> Tuple[np.ndarray, int, List[str]]:
        """Separate ``inputs`` and return the first estimated target.

        Args:
            inputs: Mixture waveform as a ``torch.Tensor`` or ``np.ndarray``.
                A batch axis is prepended before calling the model, so this
                is presumably ``(nb_channels, nb_timesteps)`` sampled at
                ``self.sampling_rate`` -- TODO confirm against the caller.

        Returns:
            A tuple ``(audio, sampling_rate, labels)``:
            ``audio`` is a NumPy array taken from ``estimates[0][0]`` (first
            batch item, first target -- presumably vocals for ``umxhq``;
            confirm the model's target order). ``sampling_rate`` is the
            model rate in Hz. ``labels`` holds one ``"label_{i}"`` string
            per row of ``audio``.
        """
        # as_tensor accepts NumPy arrays (which have no ``unsqueeze``) as
        # well as tensors. It casts to float32, torch's default floating
        # dtype; float32 tensors pass through without a copy.
        mixture = torch.as_tensor(inputs, dtype=torch.float32)
        # no_grad avoids recording an autograd graph for the whole clip,
        # which would otherwise be built only to be discarded by detach().
        with torch.no_grad():
            estimates = self.model(mixture.unsqueeze(0))
        target = estimates[0][0].detach().cpu().numpy()
        n = target.shape[0]
        return target, self.sampling_rate, [f"label_{i}" for i in range(n)]