File size: 527 Bytes
5ad45ff |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 |
from typing import Dict, List, Tuple
import numpy as np
import os
import torch
class PreTrainedPipeline:
    """Source-separation pipeline wrapping the Open-Unmix ``umxhq`` model.

    The model is fetched through ``torch.hub`` from the
    ``sigsep/open-unmix-pytorch`` repository when the pipeline is built, and
    calling the pipeline returns the first target estimate of the first
    batch item.
    """

    def __init__(self, path: str = ''):
        """Load the pretrained separator.

        Args:
            path: Unused. Kept so the constructor retains its
                ``PreTrainedPipeline(path)`` signature; the weights come
                from ``torch.hub``, not from this path.
        """
        self.model = torch.hub.load('sigsep/open-unmix-pytorch', 'umxhq')
        # Defensive: make sure any dropout / batch-norm layers run in
        # inference mode regardless of how the hub entry point left them.
        self.model.eval()
        # The model exposes its rate as a tensor; store it as a plain int.
        self.sampling_rate = int(self.model.sample_rate.item())

    def __call__(self, inputs) -> Tuple[np.ndarray, int, List[str]]:
        """Separate the first target source from a mixture waveform.

        Args:
            inputs: Mixture waveform as a NumPy array or torch tensor, either
                mono ``(T,)`` or channels-first ``(C, T)``. It is assumed to
                already be sampled at ``self.sampling_rate``; no resampling
                is performed here (NOTE(review): confirm with callers).

        Returns:
            A tuple ``(audio, sampling_rate, labels)``:
            ``audio`` is the first batch item's first-target estimate as a
            NumPy array indexed ``[channel, ...]``; ``sampling_rate`` is
            ``self.sampling_rate``; ``labels`` holds one ``"label_{i}"``
            string per row of ``audio``.
        """
        # Accept NumPy arrays as well as tensors; cast to float32 so the
        # input dtype matches typical float32 model weights.
        waveform = torch.as_tensor(inputs, dtype=torch.float32)
        if waveform.dim() == 1:
            # Mono (T,) -> (1, T) so a channel axis exists.
            waveform = waveform.unsqueeze(0)
        if waveform.shape[0] == 1:
            # NOTE(review): umxhq appears to be a stereo model; duplicating a
            # single channel mirrors open-unmix's own input preprocessing.
            # Confirm against the model's expected channel count.
            waveform = waveform.repeat(2, 1)

        # Inference only: skip building an autograd graph.
        with torch.no_grad():
            # Add the leading batch axis the model call expects.
            estimates = self.model(waveform.unsqueeze(0))

        # estimates[0][0] selects the first batch item and the first target;
        # presumably that target is "vocals" for umxhq — confirm.
        vocals = estimates[0][0].detach().cpu().numpy()
        n = vocals.shape[0]
        return vocals, self.sampling_rate, [f"label_{i}" for i in range(n)]
|