aiola committed on
Commit
b12baee
·
verified ·
1 Parent(s): 385e91b

update readme

Browse files
Files changed (1) hide show
  1. README.md +51 -16
README.md CHANGED
@@ -41,6 +41,7 @@ from experiments.utils import set_logger, get_device, remove_suppress_tokens
41
  from experiments.utils.utils import UNSUPPRESS_TOKEN
42
  import torchaudio
43
  import numpy as np
 
44
  set_logger()
45
 
46
 
@@ -63,7 +64,9 @@ def main(model_path, audio_file_path, prompt, max_new_tokens, language, device):
63
  if signal.ndim == 2:
64
  signal = torch.mean(signal, dim=0)
65
  # pre-process to get the input features
66
- input_features = processor(signal, sampling_rate=target_sample_rate, return_tensors="pt").input_features
 
 
67
  input_features = input_features.to(device)
68
 
69
  prompt = prompt.lower() # lowercase the prompt, to align with training
@@ -74,8 +77,11 @@ def main(model_path, audio_file_path, prompt, max_new_tokens, language, device):
74
  # generate token ids by running model forward sequentially
75
  logging.info(f"Inference with prompt: '{prompt}'.")
76
  predicted_ids = model.generate(
77
- input_features, max_new_tokens=max_new_tokens, language=language,
78
- prompt_ids=prompt_ids, generation_config=model.generation_config
 
 
 
79
  )
80
 
81
  # post-process token ids to text
@@ -84,21 +90,50 @@ def main(model_path, audio_file_path, prompt, max_new_tokens, language, device):
84
 
85
 
86
  if __name__ == "__main__":
87
- parser = argparse.ArgumentParser(description="Transcribe audio using Whisper model.")
88
- parser.add_argument('--model-path', type=str,
89
- required=True,
90
- default="aiola/whisper-ner-v1",
91
- help='Path to the pre-trained model components.')
92
- parser.add_argument('--audio-file-path',
93
- type=str,
94
- required=True,
95
- help='Path to the audio file to transcribecd.')
96
- parser.add_argument('--prompt', type=str, default='father', help='Prompt text to guide the transcription.')
97
- parser.add_argument('--max-new-tokens', type=int, default=256, help='Maximum number of new tokens to generate.')
98
- parser.add_argument('--language', type=str, default='en', help='Language code for the transcription.')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
 
100
  args = parser.parse_args()
101
  device = get_device()
102
- main(args.model_path, args.audio_file_path, args.prompt, args.max_new_tokens, args.language, device)
 
 
 
 
 
 
 
103
 
104
  ```
 
41
  from experiments.utils.utils import UNSUPPRESS_TOKEN
42
  import torchaudio
43
  import numpy as np
44
+
45
  set_logger()
46
 
47
 
 
64
  if signal.ndim == 2:
65
  signal = torch.mean(signal, dim=0)
66
  # pre-process to get the input features
67
+ input_features = processor(
68
+ signal, sampling_rate=target_sample_rate, return_tensors="pt"
69
+ ).input_features
70
  input_features = input_features.to(device)
71
 
72
  prompt = prompt.lower() # lowercase the prompt, to align with training
 
77
  # generate token ids by running model forward sequentially
78
  logging.info(f"Inference with prompt: '{prompt}'.")
79
  predicted_ids = model.generate(
80
+ input_features,
81
+ max_new_tokens=max_new_tokens,
82
+ language=language,
83
+ prompt_ids=prompt_ids,
84
+ generation_config=model.generation_config,
85
  )
86
 
87
  # post-process token ids to text
 
90
 
91
 
92
  if __name__ == "__main__":
93
+ parser = argparse.ArgumentParser(
94
+ description="Transcribe audio using Whisper model."
95
+ )
96
+ parser.add_argument(
97
+ "--model-path",
98
+ type=str,
99
+ required=True,
100
+ default="aiola/whisper-ner-v1",
101
+ help="Path to the pre-trained model components.",
102
+ )
103
+ parser.add_argument(
104
+ "--audio-file-path",
105
+ type=str,
106
+ required=True,
107
+ help="Path to the audio file (wav) to transcribe.",
108
+ )
109
+ parser.add_argument(
110
+ "--prompt",
111
+ type=str,
112
+ default="father",
113
+ help="Prompt text to guide the transcription.",
114
+ )
115
+ parser.add_argument(
116
+ "--max-new-tokens",
117
+ type=int,
118
+ default=256,
119
+ help="Maximum number of new tokens to generate.",
120
+ )
121
+ parser.add_argument(
122
+ "--language",
123
+ type=str,
124
+ default="en",
125
+ help="Language code for the transcription.",
126
+ )
127
 
128
  args = parser.parse_args()
129
  device = get_device()
130
+ main(
131
+ args.model_path,
132
+ args.audio_file_path,
133
+ args.prompt,
134
+ args.max_new_tokens,
135
+ args.language,
136
+ device,
137
+ )
138
 
139
  ```