mikefish committed on
Commit
a02905a
·
verified ·
1 Parent(s): 7620eca

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -12
app.py CHANGED
@@ -1,23 +1,18 @@
1
  import gradio as gr
2
- from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer, WhisperProcessor, WhisperForConditionalGeneration
3
  import torch
4
  import edge_tts
5
  import asyncio
6
- import soundfile as sf
7
- import io
8
  import numpy as np
9
 
10
  class FrenchLearningApp:
11
  def __init__(self):
12
  # Initialize models
13
  self.conversation_model = pipeline("text-generation", model="gpt2")
14
- self.translation_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
15
- self.translation_tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
16
 
17
  # Initialize Whisper model
18
  self.whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3")
19
  self.whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3")
20
- self.whisper_model.config.forced_decoder_ids = self.whisper_processor.get_decoder_prompt_ids(language="french", task="transcribe")
21
 
22
  self.context = "Start a conversation in French"
23
  self.learning_goals = []
@@ -44,15 +39,18 @@ class FrenchLearningApp:
44
  return (24000, audio_float), french_text # 24000 is the default sample rate for edge-tts
45
 
46
  def process_user_response(self, audio):
47
- # Transcribe audio to text using Whisper
48
  input_features = self.whisper_processor(audio, sampling_rate=16000, return_tensors="pt").input_features
 
 
 
49
  predicted_ids = self.whisper_model.generate(input_features)
50
  french_text = self.whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
51
 
52
- # Translate French to English
53
- inputs = self.translation_tokenizer(french_text, return_tensors="pt")
54
- translated = self.translation_model.generate(**inputs)
55
- english_text = self.translation_tokenizer.decode(translated[0], skip_special_tokens=True)
56
 
57
  # Analyze response (simplified)
58
  analysis = self.analyze_response(english_text)
@@ -103,4 +101,4 @@ def launch_app():
103
  interface.launch()
104
 
105
  if __name__ == "__main__":
106
- launch_app()
 
1
  import gradio as gr
2
+ from transformers import pipeline, WhisperProcessor, WhisperForConditionalGeneration
3
  import torch
4
  import edge_tts
5
  import asyncio
 
 
6
  import numpy as np
7
 
8
  class FrenchLearningApp:
9
  def __init__(self):
10
  # Initialize models
11
  self.conversation_model = pipeline("text-generation", model="gpt2")
 
 
12
 
13
  # Initialize Whisper model
14
  self.whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3")
15
  self.whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3")
 
16
 
17
  self.context = "Start a conversation in French"
18
  self.learning_goals = []
 
39
  return (24000, audio_float), french_text # 24000 is the default sample rate for edge-tts
40
 
41
  def process_user_response(self, audio):
42
+ # Transcribe audio to French text using Whisper
43
  input_features = self.whisper_processor(audio, sampling_rate=16000, return_tensors="pt").input_features
44
+
45
+ # Generate French transcription
46
+ self.whisper_model.config.forced_decoder_ids = self.whisper_processor.get_decoder_prompt_ids(language="french", task="transcribe")
47
  predicted_ids = self.whisper_model.generate(input_features)
48
  french_text = self.whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
49
 
50
+ # Translate French to English using Whisper
51
+ self.whisper_model.config.forced_decoder_ids = self.whisper_processor.get_decoder_prompt_ids(language="french", task="translate")
52
+ predicted_ids = self.whisper_model.generate(input_features)
53
+ english_text = self.whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
54
 
55
  # Analyze response (simplified)
56
  analysis = self.analyze_response(english_text)
 
101
  interface.launch()
102
 
103
  if __name__ == "__main__":
104
+ launch_app()