carmi committed on
Commit
889465b
1 Parent(s): 5e4a493

Upload inference.py

Browse files
Files changed (1) hide show
  1. inference.py +36 -0
inference.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import torch
3
+ from transformers import WhisperProcessor, WhisperForConditionalGeneration
4
+ from transformers import WhisperTokenizer
5
+ import torchaudio
6
+
7
+
8
def transcribe_audio(files_dir_path, target_sampling_rate=16000):
    """
    Transcribe every ``*.wav`` file in a directory with the Whisper model.

    Relies on the module-level globals ``processor``, ``model`` and
    ``device`` being initialised first (see the ``__main__`` block).

    Args:
        files_dir_path (str): Path to the directory containing .wav files.
        target_sampling_rate (int): Sampling rate the Whisper feature
            extractor expects. Defaults to 16 kHz, Whisper's required rate.

    Returns:
        list[str]: One transcription per processed file, in sorted
        file-name order. Each transcription is also printed to stdout,
        preserving the original behavior.
    """
    transcriptions = []
    # sorted() makes the processing order deterministic; raw glob order
    # is filesystem-dependent.
    for file_path in sorted(glob.glob(files_dir_path + '/*.wav')):
        audio_input, samplerate = torchaudio.load(file_path)
        # Whisper's feature extractor requires 16 kHz audio. Resample
        # files recorded at any other rate instead of failing on them.
        if samplerate != target_sampling_rate:
            audio_input = torchaudio.functional.resample(
                audio_input, samplerate, target_sampling_rate
            )
            samplerate = target_sampling_rate
        inputs = processor(
            audio_input.squeeze(), return_tensors="pt", sampling_rate=samplerate
        )
        # Inference only -- no gradients needed.
        with torch.no_grad():
            predicted_ids = model.generate(inputs["input_features"].to(device))
        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
        print(transcription[0])
        transcriptions.append(transcription[0])
    return transcriptions
21
+
22
+
23
if __name__ == '__main__':
    # Hard-coded local paths -- adjust for your environment.
    wav_dir_path = '/home/user/Desktop/arb_stt/test/'
    checkpoint_path = '/home/user/Desktop/arb_stt/best_models/medium/checkpoint-3300'
    # Use the GPU when available, otherwise fall back to CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Initialize model and processor
    # NOTE(review): `tokenizer` is loaded but never referenced again in this
    # file; decoding below goes through `processor.batch_decode`. Kept as-is
    # to preserve behavior -- confirm whether it can be removed.
    tokenizer = WhisperTokenizer.from_pretrained(f'{checkpoint_path}/tokenizer', language="Arabic", task="transcribe")
    processor = WhisperProcessor.from_pretrained(f'{checkpoint_path}/processor', language="Arabic", task="transcribe")
    model = WhisperForConditionalGeneration.from_pretrained(checkpoint_path).to(device)
    # Pin the generation config to Arabic transcription so generate() does
    # not rely on auto language detection.
    model.generation_config.language = "arabic"
    model.generation_config.task = "transcribe"
    # Switch off training-mode behavior (dropout etc.) for inference.
    model.eval()

    # transcribe_audio() reads `processor`, `model` and `device` as module
    # globals, so those assignments above must run before this call.
    transcribe_audio(wav_dir_path)