HoneyTian commited on
Commit
3147eb6
1 Parent(s): 234de07
Files changed (2) hide show
  1. examples/wenet/downaload_model.py +16 -3
  2. main.py +16 -3
examples/wenet/downaload_model.py CHANGED
@@ -51,15 +51,28 @@ def main():
51
  pretrained_model_dir = Path(args.pretrained_model_dir)
52
  pretrained_model_dir.mkdir(exist_ok=True)
53
 
54
- model_dir = pretrained_model_dir / "huggingface" / args.repo_id
55
- model_dir.mkdir(parents=True, exist_ok=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
  print("download model")
58
  model_filename = huggingface_hub.hf_hub_download(
59
  repo_id=args.repo_id,
60
  filename=args.model_filename,
61
  subfolder=args.model_sub_folder,
62
- local_dir="E:/Users/tianx/HuggingSpaces/asr/pretrained_models/huggingface/luomingshuang/transducer_stateless2",
63
  )
64
  print(model_filename)
65
  exit(0)
 
51
  pretrained_model_dir = Path(args.pretrained_model_dir)
52
  pretrained_model_dir.mkdir(exist_ok=True)
53
 
54
+ repo_id: Path = Path(args.repo_id)
55
+ if len(repo_id.parts) == 1:
56
+ repo_name = repo_id.parts[-1]
57
+ repo_name = repo_name[:40]
58
+ folder = repo_name
59
+ elif len(repo_id.parts) == 2:
60
+ repo_supplier = repo_id.parts[-2]
61
+ repo_name = repo_id.parts[-1]
62
+ repo_name = repo_name[:40]
63
+ folder = "{}/{}".format(repo_supplier, repo_name)
64
+ else:
65
+ raise AssertionError("repo_id parts count invalid: {}".format(len(repo_id.parts)))
66
+
67
+ local_model_dir = pretrained_model_dir / "huggingface" / folder
68
+ local_model_dir.mkdir(parents=True, exist_ok=True)
69
 
70
  print("download model")
71
  model_filename = huggingface_hub.hf_hub_download(
72
  repo_id=args.repo_id,
73
  filename=args.model_filename,
74
  subfolder=args.model_sub_folder,
75
+ local_dir=local_model_dir.as_posix(),
76
  )
77
  print(model_filename)
78
  exit(0)
main.py CHANGED
@@ -99,10 +99,23 @@ def process(
99
  if m_dict is None:
100
  raise AssertionError("repo_id invalid: {}".format(repo_id))
101
 
102
- # load recognizer
103
- local_model_dir = pretrained_model_dir / "huggingface" / repo_id
104
- # local_model_dir.mkdir(parents=True, exist_ok=True)
 
 
 
 
 
 
 
 
 
 
 
 
105
 
 
106
  recognizer = nn_models.load_recognizer(
107
  local_model_dir=local_model_dir,
108
  decoding_method=decoding_method,
 
99
  if m_dict is None:
100
  raise AssertionError("repo_id invalid: {}".format(repo_id))
101
 
102
+ # local_model_dir
103
+ repo_id: Path = Path(repo_id)
104
+ if len(repo_id.parts) == 1:
105
+ repo_name = repo_id.parts[-1]
106
+ repo_name = repo_name[:40]
107
+ folder = repo_name
108
+ elif len(repo_id.parts) == 2:
109
+ repo_supplier = repo_id.parts[-2]
110
+ repo_name = repo_id.parts[-1]
111
+ repo_name = repo_name[:40]
112
+ folder = "{}/{}".format(repo_supplier, repo_name)
113
+ else:
114
+ raise AssertionError("repo_id parts count invalid: {}".format(len(repo_id.parts)))
115
+
116
+ local_model_dir = pretrained_model_dir / "huggingface" / folder
117
 
118
+ # load recognizer
119
  recognizer = nn_models.load_recognizer(
120
  local_model_dir=local_model_dir,
121
  decoding_method=decoding_method,