osanseviero HF staff commited on
Commit
b00ee09
1 Parent(s): 37a3159

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +2 -1
pipeline.py CHANGED
@@ -3,6 +3,7 @@ from typing import Dict, List, Tuple
3
  import numpy as np
4
  from asteroid import separate
5
  from asteroid.models import BaseModel
 
6
 
7
 
8
  class PreTrainedPipeline():
@@ -11,7 +12,7 @@ class PreTrainedPipeline():
11
  # Preload all the elements you are going to need at inference.
12
  # For instance your model, processors, tokenizer that might be needed.
13
  # This function is only called once, so do all the heavy processing I/O here"""
14
- self.model = BaseModel.from_pretrained(path)
15
  self.sampling_rate = self.model.sample_rate
16
 
17
  def __call__(self, inputs: np.array) -> Tuple[np.array, int, List[str]]:
3
  import numpy as np
4
  from asteroid import separate
5
  from asteroid.models import BaseModel
6
+ import os
7
 
8
 
9
  class PreTrainedPipeline():
12
  # Preload all the elements you are going to need at inference.
13
  # For instance your model, processors, tokenizer that might be needed.
14
  # This function is only called once, so do all the heavy processing I/O here"""
15
+ self.model = BaseModel.from_pretrained(os.path.join(path, "pytorch_model.bin"))
16
  self.sampling_rate = self.model.sample_rate
17
 
18
  def __call__(self, inputs: np.array) -> Tuple[np.array, int, List[str]]: