MusIre committed on
Commit
eb7f955
•
1 Parent(s): b309292

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -6
app.py CHANGED
@@ -4,9 +4,8 @@ subprocess.run(["pip", "install", "datasets"])
4
  subprocess.run(["pip", "install", "transformers"])
5
  subprocess.run(["pip", "install", "torch", "torchvision", "torchaudio", "-f", "https://download.pytorch.org/whl/torch_stable.html"])
6
 
7
- from transformers import WhisperProcessor, WhisperForConditionalGeneration
8
- from datasets import load_dataset
9
  import gradio as gr
 
10
 
11
  # Load model and processor
12
  processor = WhisperProcessor.from_pretrained("openai/whisper-large")
@@ -16,16 +15,21 @@ model.config.forced_decoder_ids = None
16
  # Function to perform ASR on audio data
17
  def transcribe_audio(audio_data):
18
  # Process audio data using the Whisper processor
19
- input_features = processor(audio_data, return_tensors="pt").input_features
20
 
21
  # Generate token ids
22
  predicted_ids = model.generate(input_features)
23
-
24
  # Decode token ids to text
25
  transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
26
-
27
  return transcription[0]
28
 
 
 
 
 
 
29
  # Create Gradio interface
30
- audio_input = gr.Audio(preprocessing_fn=None)
31
  gr.Interface(fn=transcribe_audio, inputs=audio_input, outputs="text").launch()
 
4
  subprocess.run(["pip", "install", "transformers"])
5
  subprocess.run(["pip", "install", "torch", "torchvision", "torchaudio", "-f", "https://download.pytorch.org/whl/torch_stable.html"])
6
 
 
 
7
  import gradio as gr
8
+ from transformers import WhisperProcessor, WhisperForConditionalGeneration
9
 
10
  # Load model and processor
11
  processor = WhisperProcessor.from_pretrained("openai/whisper-large")
 
15
  # Function to perform ASR on audio data
16
  def transcribe_audio(audio_data):
17
  # Process audio data using the Whisper processor
18
+ input_features = processor(audio_data, return_tensors="pt").input_features
19
 
20
  # Generate token ids
21
  predicted_ids = model.generate(input_features)
22
+
23
  # Decode token ids to text
24
  transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
25
+
26
  return transcription[0]
27
 
28
+ # Custom preprocessing function
29
+ def preprocess_audio(audio_data):
30
+ # Apply any custom preprocessing to the audio data here if needed
31
+ return audio_data
32
+
33
  # Create Gradio interface
34
+ audio_input = gr.Audio(preprocess=preprocess_audio)
35
  gr.Interface(fn=transcribe_audio, inputs=audio_input, outputs="text").launch()