aiola committed
Commit 385e91b · verified · 1 Parent(s): 45820bd

Update readme

Files changed (1):
  1. README.md +74 -11
README.md CHANGED
@@ -24,18 +24,81 @@ During training, the model is prompted with NER labels and optimized to output t
 ---------

 ## Training Details
-
-
-
+`aiola/whisper-ner-v1` was trained on the NuNER dataset to perform joint audio transcription and named entity recognition (NER), in English only.

 ---------

 ## Usage
-
-
-
-
-
-
-
-
+To use `whisper-ner-v1`, install the [`whisper-ner`](https://github.com/aiola-lab/whisper-ner) repo by following its README instructions.
+
+Inference can be done using the following code:
+```python
+import logging
+import argparse
+import torch
+from transformers import WhisperProcessor, WhisperForConditionalGeneration
+from experiments.utils import set_logger, get_device, remove_suppress_tokens
+from experiments.utils.utils import UNSUPPRESS_TOKEN
+import torchaudio
+import numpy as np
+
+set_logger()
+
+
+@torch.no_grad()
+def main(model_path, audio_file_path, prompt, max_new_tokens, language, device):
+    # load model and processor from pre-trained
+    processor = WhisperProcessor.from_pretrained("openai/whisper-large-v2")
+    model = WhisperForConditionalGeneration.from_pretrained(model_path)
+    remove_suppress_tokens(model)
+    logging.info(f"removed suppress tokens: {UNSUPPRESS_TOKEN}")
+
+    model = model.to(device)
+
+    # load audio file: the user is responsible for loading the audio files themselves
+    target_sample_rate = 16000
+    signal, sampling_rate = torchaudio.load(audio_file_path)
+    resampler = torchaudio.transforms.Resample(sampling_rate, target_sample_rate)
+    signal = resampler(signal)
+
+    # convert to mono or remove first dim if needed
+    if signal.ndim == 2:
+        signal = torch.mean(signal, dim=0)
+
+    # pre-process to get the input features
+    input_features = processor(signal, sampling_rate=target_sample_rate, return_tensors="pt").input_features
+    input_features = input_features.to(device)
+
+    prompt = prompt.lower()  # lowercase the prompt, to align with training
+
+    prompt_ids = processor.get_prompt_ids(prompt, return_tensors="pt")
+    prompt_ids = prompt_ids.to(device)
+
+    # generate token ids by running the model forward sequentially
+    logging.info(f"Inference with prompt: '{prompt}'.")
+    predicted_ids = model.generate(
+        input_features, max_new_tokens=max_new_tokens, language=language,
+        prompt_ids=prompt_ids, generation_config=model.generation_config
+    )
+
+    # post-process token ids to text
+    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=False)
+    print(transcription)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Transcribe audio using the WhisperNER model.")
+    parser.add_argument('--model-path', type=str,
+                        default="aiola/whisper-ner-v1",
+                        help='Path to the pre-trained model components.')
+    parser.add_argument('--audio-file-path',
+                        type=str,
+                        required=True,
+                        help='Path to the audio file to transcribe.')
+    parser.add_argument('--prompt', type=str, default='father', help='Prompt (entity labels) to guide the transcription.')
+    parser.add_argument('--max-new-tokens', type=int, default=256, help='Maximum number of new tokens to generate.')
+    parser.add_argument('--language', type=str, default='en', help='Language code for the transcription.')
+
+    args = parser.parse_args()
+    device = get_device()
+    main(args.model_path, args.audio_file_path, args.prompt, args.max_new_tokens, args.language, device)
+```
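
For a quick sanity check without the CLI wrapper, the `main` entry point above can also be called directly from Python. This is a minimal sketch: the module name `infer_whisper_ner`, the audio path, and the entity-label prompt are illustrative placeholders rather than values taken from the repository.

```python
# Minimal sketch, assuming the script above was saved as infer_whisper_ner.py
# and that an audio file exists at the placeholder path below.
from experiments.utils import get_device
from infer_whisper_ner import main

device = get_device()
main(
    model_path="aiola/whisper-ner-v1",
    audio_file_path="sample.wav",              # placeholder audio file
    prompt="person, location, organization",   # placeholder entity labels to tag
    max_new_tokens=256,
    language="en",
    device=device,
)
```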