nateraw commited on
Commit
5ad45ff
1 Parent(s): 570b37f

Upload pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +16 -0
pipeline.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Tuple
2
+ import numpy as np
3
+ import os
4
+ import torch
5
+
6
+ class PreTrainedPipeline():
7
+ def __init__(self, path=''):
8
+ self.model = torch.hub.load('sigsep/open-unmix-pytorch', 'umxhq')
9
+ self.sampling_rate = int(self.model.sample_rate.item())
10
+
11
+ def __call__(self, inputs):
12
+ estimates = self.model(inputs.unsqueeze(0))
13
+ vocals = estimates[0][0].detach().numpy()
14
+ n = vocals.shape[0]
15
+ return vocals, self.sampling_rate, [f"label_{i}" for i in range(n)]
16
+