muhh-b commited on
Commit
306eb78
·
1 Parent(s): 624999a

Update transcription.py

Browse files
Files changed (1) hide show
  1. transcription.py +32 -15
transcription.py CHANGED
@@ -2,15 +2,24 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
2
  import torch
3
  import whisper
4
 
5
-
6
-
7
  tokenizer = AutoTokenizer.from_pretrained("Bhuvana/t5-base-spellchecker")
8
-
9
  model = AutoModelForSeq2SeqLM.from_pretrained("Bhuvana/t5-base-spellchecker")
10
 
11
-
12
  def correct(inputs):
13
- input_ids = tokenizer.encode(inputs,return_tensors='pt')
 
 
 
 
 
 
 
 
 
 
 
14
  sample_output = model.generate(
15
  input_ids,
16
  do_sample=True,
@@ -18,11 +27,24 @@ def correct(inputs):
18
  top_p=0.99,
19
  num_return_sequences=1
20
  )
 
 
21
  res = tokenizer.decode(sample_output[0], skip_special_tokens=True)
22
  return res
23
 
 
24
  whisper_model = whisper.load_model("base")
 
 
25
  def transcribe(audio_file):
 
 
 
 
 
 
 
 
26
  # Load audio and pad/trim it to fit 30 seconds
27
  audio = whisper.load_audio(audio_file)
28
  audio = whisper.pad_or_trim(audio)
@@ -33,21 +55,16 @@ def transcribe(audio_file):
33
  # Make log-Mel spectrogram and move to the same device as the model
34
  mel = whisper.log_mel_spectrogram(mel).to(model.device)
35
 
36
- # Detect the spoken language
37
  _, probs = whisper_model.detect_language(mel)
38
 
39
- # Decode the audio
40
  options = whisper.DecodingOptions(fp16=False)
41
  result = whisper.decode(whisper_model, mel, options)
42
  result_text = result.text
43
 
44
- print('result_text:'+result_text)
 
45
 
 
46
  return correct(result_text)
47
-
48
-
49
-
50
-
51
-
52
-
53
-
 
2
  import torch
3
  import whisper
4
 
5
+ # Initialize tokenizer and model for spell checking
 
6
  tokenizer = AutoTokenizer.from_pretrained("Bhuvana/t5-base-spellchecker")
 
7
  model = AutoModelForSeq2SeqLM.from_pretrained("Bhuvana/t5-base-spellchecker")
8
 
9
+ # Function to correct spelling errors in a given input text
10
  def correct(inputs):
11
+ '''Corrects spelling errors in the input text using the spell checker model.
12
+
13
+ Args:
14
+ inputs (str): The input text to be spell-checked.
15
+
16
+ Returns:
17
+ str: The corrected version of the input text.
18
+ '''
19
+ # Encode the input text using the tokenizer
20
+ input_ids = tokenizer.encode(inputs, return_tensors='pt')
21
+
22
+ # Generate corrected output using the spell checker model
23
  sample_output = model.generate(
24
  input_ids,
25
  do_sample=True,
 
27
  top_p=0.99,
28
  num_return_sequences=1
29
  )
30
+
31
+ # Decode the corrected output and remove special tokens
32
  res = tokenizer.decode(sample_output[0], skip_special_tokens=True)
33
  return res
34
 
35
+ # Load the whisper model for audio transcription
36
  whisper_model = whisper.load_model("base")
37
+
38
+ # Function to transcribe audio file
39
  def transcribe(audio_file):
40
+ '''Transcribes the content of an audio file.
41
+
42
+ Args:
43
+ audio_file (str): The path to the audio file.
44
+
45
+ Returns:
46
+ str: The transcribed text from the audio file, with spelling errors corrected.
47
+ '''
48
  # Load audio and pad/trim it to fit 30 seconds
49
  audio = whisper.load_audio(audio_file)
50
  audio = whisper.pad_or_trim(audio)
 
55
  # Make log-Mel spectrogram and move to the same device as the model
56
  mel = whisper.log_mel_spectrogram(mel).to(model.device)
57
 
58
+ # Detect the spoken language using the whisper model
59
  _, probs = whisper_model.detect_language(mel)
60
 
61
+ # Decode the audio using the whisper model
62
  options = whisper.DecodingOptions(fp16=False)
63
  result = whisper.decode(whisper_model, mel, options)
64
  result_text = result.text
65
 
66
+ # Print the transcribed text
67
+ print('result_text:' + result_text)
68
 
69
+ # Correct any spelling errors in the transcribed text
70
  return correct(result_text)