osanseviero HF staff commited on
Commit
2302c8e
1 Parent(s): 41bb0ab

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +3 -2
model.py CHANGED
@@ -6,14 +6,15 @@ inference.
6
  from s3prl.downstream.runner import Runner
7
  from typing import Dict
8
  import torch
 
9
 
10
 
11
  class PreTrainedModel(Runner):
12
- def __init__(self):
13
  """
14
  Loads model and tokenizer from local directory
15
  """
16
- ckp_file = "hubert_asr.ckpt"
17
  ckp = torch.load(ckp_file, map_location='cpu')
18
  ckp["Args"].init_ckpt = ckp_file
19
  ckp["Args"].mode = "inference"
6
  from s3prl.downstream.runner import Runner
7
  from typing import Dict
8
  import torch
9
+ import os
10
 
11
 
12
  class PreTrainedModel(Runner):
13
+ def __init__(self, path=""):
14
  """
15
  Loads model and tokenizer from local directory
16
  """
17
+ ckp_file = os.path.join(path, "hubert_asr.ckpt")
18
  ckp = torch.load(ckp_file, map_location='cpu')
19
  ckp["Args"].init_ckpt = ckp_file
20
  ckp["Args"].mode = "inference"