cindyangelira committed on
Commit
ce99676
·
verified ·
1 Parent(s): aba5bcc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -39
app.py CHANGED
@@ -7,28 +7,36 @@ from transformers import (
7
  pipeline,
8
  AutoProcessor,
9
  AutoModelForSpeechSeq2Seq,
10
- BitsAndBytesConfig
 
 
 
11
  )
12
  from datasets import load_dataset
13
  import numpy as np
14
- from transformers import AutoModelForTextToSpeech, SpeechT5HifiGan
15
  import torchaudio
16
 
17
  @spaces.GPU
18
- def dummy(): # just a dummy
19
  pass
20
 
21
- # Constants
22
- # DEVICE = "cpu"
23
  LANGUAGE_CODES = {
24
  "English": "en",
25
  "Chinese": "zh"
26
  }
27
 
28
- # Initialize components with efficient settings
 
 
 
 
 
 
 
 
 
29
  def initialize_components():
30
- # Use XVERSE-13B-Chat as the base model - good multilingual support and reasonable size
31
- # Load in 4-bit quantization to reduce memory usage
32
  bnb_config = BitsAndBytesConfig(
33
  load_in_4bit=True,
34
  bnb_4bit_quant_type="nf4",
@@ -42,43 +50,45 @@ def initialize_components():
42
  )
43
  tokenizer = AutoTokenizer.from_pretrained("xverse/XVERSE-13B-Chat")
44
 
45
- # Whisper model for STT (small for efficiency)
46
- processor = AutoProcessor.from_pretrained("openai/whisper-small")
47
  stt_model = AutoModelForSpeechSeq2Seq.from_pretrained(
48
  "openai/whisper-small",
49
  torch_dtype=torch.float32,
50
  low_cpu_mem_usage=True,
51
  )
52
 
53
- # VITS for TTS (supports both English and Chinese)
54
- tts_model = load_model("facebook/mms-tts-eng")
 
55
  vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
56
-
57
- return llm, tokenizer, processor, stt_model, tts_model, vocoder
58
-
59
- def load_model(model_name):
60
- """Helper function to load models with optimized settings"""
61
- return AutoModelForTextToSpeech.from_pretrained(
62
- model_name,
63
- torch_dtype=torch.float32,
64
- low_cpu_mem_usage=True,
65
- )
66
 
67
  class ConversationManager:
68
  def __init__(self):
69
  self.history = []
 
70
 
71
- def add_message(self, role, content, audio_path=None):
72
  self.history.append({
73
  "role": role,
74
- "content": content,
75
- "audio_path": audio_path
76
  })
77
 
78
  def get_formatted_history(self):
79
- return "\n".join([
 
80
  f"{msg['role']}: {msg['content']}" for msg in self.history
81
  ])
 
 
 
 
82
 
83
  def speech_to_text(audio, processor, model, target_language):
84
  """Convert speech to text using Whisper"""
@@ -113,15 +123,19 @@ def generate_response(prompt, llm, tokenizer):
113
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
114
  return response
115
 
116
- def text_to_speech(text, model, vocoder, language):
117
- """Convert text to speech using MMS-TTS"""
118
- inputs = processor(text, return_tensors="pt")
119
- speech = model.generate_speech(inputs["input_ids"], vocoder)
 
 
 
 
120
  return speech
121
 
122
  def create_gradio_interface():
123
  # Initialize components
124
- llm, tokenizer, processor, stt_model, tts_model, vocoder = initialize_components()
125
  conversation_manager = ConversationManager()
126
 
127
  with gr.Blocks() as interface:
@@ -133,7 +147,6 @@ def create_gradio_interface():
133
  )
134
 
135
  with gr.Row():
136
- # Audio input
137
  audio_input = gr.Audio(
138
  source="microphone",
139
  type="numpy",
@@ -141,7 +154,6 @@ def create_gradio_interface():
141
  )
142
 
143
  with gr.Row():
144
- # Chat history display
145
  chat_display = gr.Textbox(
146
  value="",
147
  label="Conversation History",
@@ -150,17 +162,18 @@ def create_gradio_interface():
150
  )
151
 
152
  with gr.Row():
153
- # Assistant's audio response
154
  audio_output = gr.Audio(
155
- label="Assistant's Response",
156
  type="numpy"
157
  )
158
 
159
  def process_conversation(audio, language):
 
 
160
  # Speech to text
161
  user_text = speech_to_text(
162
  audio,
163
- processor,
164
  stt_model,
165
  language
166
  )
@@ -169,14 +182,15 @@ def create_gradio_interface():
169
  # Generate LLM response
170
  context = conversation_manager.get_formatted_history()
171
  response = generate_response(context, llm, tokenizer)
172
- conversation_manager.add_message("Assistant", response)
173
 
174
  # Text to speech
175
  speech_output = text_to_speech(
176
  response,
 
177
  tts_model,
178
  vocoder,
179
- language
180
  )
181
 
182
  return (
@@ -192,7 +206,6 @@ def create_gradio_interface():
192
 
193
  return interface
194
 
195
- # Launch the application
196
  if __name__ == "__main__":
197
  interface = create_gradio_interface()
198
  interface.launch()
 
7
  pipeline,
8
  AutoProcessor,
9
  AutoModelForSpeechSeq2Seq,
10
+ BitsAndBytesConfig,
11
+ SpeechT5Processor,
12
+ SpeechT5ForTextToSpeech,
13
+ SpeechT5HifiGan
14
  )
15
  from datasets import load_dataset
16
  import numpy as np
 
17
  import torchaudio
18
 
19
  @spaces.GPU
20
+ def dummy(): # just a dummy
21
  pass
22
 
 
 
23
  LANGUAGE_CODES = {
24
  "English": "en",
25
  "Chinese": "zh"
26
  }
27
 
28
+ def get_system_prompt(language):
29
+ if language == "Chinese":
30
+ return """你是Lin Yi(林意),一个友好的AI助手。你是我的好朋友,说话亲切自然。
31
+ 请用中文回答,语气要自然友好。如果我用英文问你问题,你也要用中文回答。
32
+ 记住你要像朋友一样交谈,不要太正式。"""
33
+ else:
34
+ return """You are Lin Yi, a friendly AI assistant and my good friend (hao pengyou).
35
+ Speak naturally and warmly. If I speak in Chinese, respond in English.
36
+ Remember to converse like a friend, not too formal."""
37
+
38
  def initialize_components():
39
+ # LLM initialization
 
40
  bnb_config = BitsAndBytesConfig(
41
  load_in_4bit=True,
42
  bnb_4bit_quant_type="nf4",
 
50
  )
51
  tokenizer = AutoTokenizer.from_pretrained("xverse/XVERSE-13B-Chat")
52
 
53
+ # Speech-to-text
54
+ whisper_processor = AutoProcessor.from_pretrained("openai/whisper-small")
55
  stt_model = AutoModelForSpeechSeq2Seq.from_pretrained(
56
  "openai/whisper-small",
57
  torch_dtype=torch.float32,
58
  low_cpu_mem_usage=True,
59
  )
60
 
61
+ # Text-to-speech
62
+ tts_processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
63
+ tts_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
64
  vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
65
+
66
+ # Load speaker embedding
67
+ embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
68
+ speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)
69
+
70
+ return llm, tokenizer, whisper_processor, stt_model, tts_processor, tts_model, vocoder, speaker_embeddings
 
 
 
 
71
 
72
  class ConversationManager:
73
  def __init__(self):
74
  self.history = []
75
+ self.current_language = "English"
76
 
77
+ def add_message(self, role, content):
78
  self.history.append({
79
  "role": role,
80
+ "content": content
 
81
  })
82
 
83
  def get_formatted_history(self):
84
+ system_prompt = get_system_prompt(self.current_language)
85
+ history_text = "\n".join([
86
  f"{msg['role']}: {msg['content']}" for msg in self.history
87
  ])
88
+ return f"{system_prompt}\n\n{history_text}"
89
+
90
+ def set_language(self, language):
91
+ self.current_language = language
92
 
93
  def speech_to_text(audio, processor, model, target_language):
94
  """Convert speech to text using Whisper"""
 
123
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
124
  return response
125
 
126
+ def text_to_speech(text, processor, model, vocoder, speaker_embeddings):
127
+ """Convert text to speech using SpeechT5"""
128
+ inputs = processor(text=text, return_tensors="pt")
129
+ speech = model.generate_speech(
130
+ inputs["input_ids"],
131
+ speaker_embeddings,
132
+ vocoder=vocoder
133
+ )
134
  return speech
135
 
136
  def create_gradio_interface():
137
  # Initialize components
138
+ llm, tokenizer, whisper_processor, stt_model, tts_processor, tts_model, vocoder, speaker_embeddings = initialize_components()
139
  conversation_manager = ConversationManager()
140
 
141
  with gr.Blocks() as interface:
 
147
  )
148
 
149
  with gr.Row():
 
150
  audio_input = gr.Audio(
151
  source="microphone",
152
  type="numpy",
 
154
  )
155
 
156
  with gr.Row():
 
157
  chat_display = gr.Textbox(
158
  value="",
159
  label="Conversation History",
 
162
  )
163
 
164
  with gr.Row():
 
165
  audio_output = gr.Audio(
166
+ label="Lin Yi's Response",
167
  type="numpy"
168
  )
169
 
170
  def process_conversation(audio, language):
171
+ conversation_manager.set_language(language)
172
+
173
  # Speech to text
174
  user_text = speech_to_text(
175
  audio,
176
+ whisper_processor,
177
  stt_model,
178
  language
179
  )
 
182
  # Generate LLM response
183
  context = conversation_manager.get_formatted_history()
184
  response = generate_response(context, llm, tokenizer)
185
+ conversation_manager.add_message("Lin Yi", response)
186
 
187
  # Text to speech
188
  speech_output = text_to_speech(
189
  response,
190
+ tts_processor,
191
  tts_model,
192
  vocoder,
193
+ speaker_embeddings
194
  )
195
 
196
  return (
 
206
 
207
  return interface
208
 
 
209
  if __name__ == "__main__":
210
  interface = create_gradio_interface()
211
  interface.launch()