from typing import Dict, List, Tuple
import numpy as np
import os
import torch


class PreTrainedPipeline:
    """Audio source-separation pipeline around the Open-Unmix ``umxhq`` model.

    The model is fetched with ``torch.hub.load`` from the
    ``sigsep/open-unmix-pytorch`` repository when an instance is created, so
    construction may need network access (or a populated hub cache).
    """

    def __init__(self, path: str = ''):
        """Load the pretrained model and cache its sampling rate.

        Args:
            path: Accepted for interface compatibility; not used here
                (presumably the hosting loader passes a model directory --
                verify against the caller).
        """
        self.model = torch.hub.load('sigsep/open-unmix-pytorch', 'umxhq')
        # This wrapper only does inference: force evaluation mode so layers
        # such as dropout/batch-norm behave deterministically regardless of
        # the mode the hub entry point returned the model in.
        self.model.eval()
        # The model exposes sample_rate as a tensor; store it as a plain int.
        self.sampling_rate = int(self.model.sample_rate.item())

    def __call__(self, inputs) -> Tuple[np.ndarray, int, List[str]]:
        """Run the model on ``inputs`` and return the first estimated source.

        Args:
            inputs: A single waveform without a batch axis, either as a
                ``torch.Tensor`` or any array-like (e.g. a NumPy array).
                Non-tensor inputs are converted to a float32 tensor.
                Assumed layout is (channels, samples) -- TODO confirm
                against the model's expected input.

        Returns:
            A tuple ``(vocals, sampling_rate, labels)``:
            ``vocals`` is the NumPy array at index ``[0][0]`` of the model
            output (first batch item, first entry along the next axis --
            treated as the vocals target by this pipeline);
            ``sampling_rate`` is the model's sampling rate as an int;
            ``labels`` holds one ``"label_{i}"`` string per entry along the
            first axis of ``vocals``.
        """
        if not isinstance(inputs, torch.Tensor):
            inputs = torch.as_tensor(inputs, dtype=torch.float32)

        # Pure inference: don't build an autograd graph (saves memory/time).
        with torch.no_grad():
            estimates = self.model(inputs.unsqueeze(0))  # add a batch axis

        # .cpu() so that .numpy() also works when the tensors live on a GPU.
        vocals = estimates[0][0].detach().cpu().numpy()
        labels = [f"label_{i}" for i in range(vocals.shape[0])]
        return vocals, self.sampling_rate, labels