ahricat committed
Commit 71b2ec4
1 Parent(s): 4222d8d

Update app.py

Files changed (1)
  1. app.py +38 -44
app.py CHANGED
@@ -1,53 +1,47 @@
-@app.function(
+import gradio as gr
+
 class InteractiveChat:
-    def __init__(self, model_name="openai/whisper-large", tts_choice="OpenVoice", **kwargs):
-        self.whisper_processor = WhisperProcessor.from_pretrained(model_name)
-        self.whisper_model = WhisperForConditionalGeneration.from_pretrained(model_name)
+
+    whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-large")
+    whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large")
+
+    def __init__(self):
         self.zephyr_tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
         self.zephyr_model = AutoModelForCausalLM.from_pretrained("HuggingFaceH4/zephyr-7b-beta", device_map="auto")
-        self.zephyr_pipeline = pipeline("text-generation", model=self.zephyr_model, tokenizer=self.zephyr_tokenizer)
-        self.tts_choice = tts_choice
 
     def generate_response(self, input_data):
-        input_features = self.whisper_processor(input_data, sampling_rate=16_000, return_tensors="pt").input_features
+        input_features = self.whisper_processor(input_data)
         predicted_ids = self.whisper_model.generate(input_features)
-        transcription = self.whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
-
-        response = self.zephyr_pipeline(transcription, max_length=1000)[0]["generated_text"]
-        return transcription, response
+        transcription = self.whisper_processor.batch_decode(predicted_ids)
+
+        response = self.get_zephyr_response(transcription)
+
+        self.speak(response)
+        return response
+
+    def get_zephyr_response(self, transcription):
+        zephyr_pipeline = pipeline("text-generation")
+        response = zephyr_pipeline(transcription)[0]["generated_text"]
+        return response
 
     def speak(self, text):
-        try:
-            if self.tts_choice == "OpenVoice":
-                model_path = snapshot_download("facebook/mms-tts-eng")
-                pipe = pipeline("text-to-speech", model=model_path)
-                audio_array = pipe(text).audio
-                pygame.mixer.init()
-                sound = pygame.sndarray.make_sound(audio_array)
-                sound.play()
-                pygame.time.delay(int(sound.get_length() * 1000))
-            else:  # gTTS
-                tts = gTTS(text=text, lang='en')
-                tts.save("response.mp3")
-                pygame.mixer.init()
-                pygame.mixer.music.load("response.mp3")
-                pygame.mixer.music.play()
-                while pygame.mixer.music.get_busy():
-                    pygame.time.Clock().tick(10)
-        except Exception as e:
-            print("Error occurred during speech generation:", e)
-
-
-with gr.Blocks() as demo:
-    model_choice = gr.Dropdown(["openai/whisper-large"], label="Whisper Model", value="openai/whisper-large")
-    tts_choice = gr.Radio(["OpenVoice", "gTTS"], label="TTS Engine", value="OpenVoice")
-    input_data = gr.Audio(source="microphone", type="numpy", label="Speak Your Message")
-    output_text = gr.Textbox(label="Transcription and Response")
-
-    model_choice.change(lambda x, y: InteractiveChat(x, y), inputs=[model_choice, tts_choice], outputs=None)
-    tts_choice.change(lambda x, y: InteractiveChat(y, x), inputs=[tts_choice, model_choice], outputs=None)
-    input_data.change(lambda x, model: model.generate_response(x), inputs=[input_data, model_choice],
-                      outputs=output_text)
-    input_data.change(lambda x, model: model.speak(x[1]), inputs=[output_text, model_choice],
-                      outputs=None))  # Speak the response
+        speech_client = SpeechClient()
+        speech_client.synthesize(text)
+
+    def generate_response(self, input):
+
+        # get transcription from Whisper
+
+        response = self.get_zephyr_response(transcription)
+
+        self.speak(response)
+
+        return response
+
+interface = gr.Interface(
+    gr.Audio(type="microphone"),
+    gr.Textbox(),
+    self.generate_response
+)
+
+interface.launch()
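
As committed, the new app.py will not run: generate_response is defined twice (the second definition shadows the first and reads a transcription variable that is never assigned), SpeechClient is never imported or defined, self.generate_response is referenced at module scope where no self exists, and the transformers imports are missing. Below is a minimal runnable sketch of what this revision appears to be aiming for, not the committed code. It assumes Gradio 4's sources=["microphone"] argument (the previous revision's source="microphone" is Gradio 3 syntax) and substitutes the previous revision's gTTS path for the undefined SpeechClient.

import gradio as gr
from gtts import gTTS
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    WhisperForConditionalGeneration,
    WhisperProcessor,
    pipeline,
)


class InteractiveChat:
    # Load Whisper once at class level, as the new revision does.
    whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-large")
    whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large")

    def __init__(self):
        tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
        model = AutoModelForCausalLM.from_pretrained(
            "HuggingFaceH4/zephyr-7b-beta", device_map="auto"
        )
        # Reuse the loaded model; a bare pipeline("text-generation") would
        # pull down an unrelated default model instead of Zephyr.
        self.zephyr_pipeline = pipeline(
            "text-generation", model=model, tokenizer=tokenizer
        )

    def generate_response(self, audio):
        # With type="numpy", Gradio passes microphone input as
        # (sample_rate, numpy_array). Whisper's feature extractor expects
        # 16 kHz mono float audio; resampling is omitted here for brevity.
        sample_rate, waveform = audio
        input_features = self.whisper_processor(
            waveform.astype("float32"), sampling_rate=16_000, return_tensors="pt"
        ).input_features
        predicted_ids = self.whisper_model.generate(input_features)
        transcription = self.whisper_processor.batch_decode(
            predicted_ids, skip_special_tokens=True
        )[0]

        response = self.get_zephyr_response(transcription)
        self.speak(response)
        return response

    def get_zephyr_response(self, transcription):
        return self.zephyr_pipeline(transcription, max_length=1000)[0]["generated_text"]

    def speak(self, text):
        # The diff's SpeechClient is undefined; fall back to the gTTS path
        # the previous revision used. Playback is left to the caller.
        gTTS(text=text, lang="en").save("response.mp3")


chat = InteractiveChat()

interface = gr.Interface(
    fn=chat.generate_response,
    inputs=gr.Audio(sources=["microphone"], type="numpy"),
    outputs=gr.Textbox(),
)

interface.launch()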