Spaces:
Runtime error
Runtime error
cindyangelira
commited on
Update app.py
Browse files
app.py
CHANGED
@@ -7,28 +7,36 @@ from transformers import (
|
|
7 |
pipeline,
|
8 |
AutoProcessor,
|
9 |
AutoModelForSpeechSeq2Seq,
|
10 |
-
BitsAndBytesConfig
|
|
|
|
|
|
|
11 |
)
|
12 |
from datasets import load_dataset
|
13 |
import numpy as np
|
14 |
-
from transformers import AutoModelForTextToSpeech, SpeechT5HifiGan
|
15 |
import torchaudio
|
16 |
|
17 |
@spaces.GPU
|
18 |
-
def dummy():
|
19 |
pass
|
20 |
|
21 |
-
# Constants
|
22 |
-
# DEVICE = "cpu"
|
23 |
LANGUAGE_CODES = {
|
24 |
"English": "en",
|
25 |
"Chinese": "zh"
|
26 |
}
|
27 |
|
28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
def initialize_components():
|
30 |
-
#
|
31 |
-
# Load in 4-bit quantization to reduce memory usage
|
32 |
bnb_config = BitsAndBytesConfig(
|
33 |
load_in_4bit=True,
|
34 |
bnb_4bit_quant_type="nf4",
|
@@ -42,43 +50,45 @@ def initialize_components():
|
|
42 |
)
|
43 |
tokenizer = AutoTokenizer.from_pretrained("xverse/XVERSE-13B-Chat")
|
44 |
|
45 |
-
#
|
46 |
-
|
47 |
stt_model = AutoModelForSpeechSeq2Seq.from_pretrained(
|
48 |
"openai/whisper-small",
|
49 |
torch_dtype=torch.float32,
|
50 |
low_cpu_mem_usage=True,
|
51 |
)
|
52 |
|
53 |
-
#
|
54 |
-
|
|
|
55 |
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
return
|
62 |
-
model_name,
|
63 |
-
torch_dtype=torch.float32,
|
64 |
-
low_cpu_mem_usage=True,
|
65 |
-
)
|
66 |
|
67 |
class ConversationManager:
|
68 |
def __init__(self):
|
69 |
self.history = []
|
|
|
70 |
|
71 |
-
def add_message(self, role, content
|
72 |
self.history.append({
|
73 |
"role": role,
|
74 |
-
"content": content
|
75 |
-
"audio_path": audio_path
|
76 |
})
|
77 |
|
78 |
def get_formatted_history(self):
|
79 |
-
|
|
|
80 |
f"{msg['role']}: {msg['content']}" for msg in self.history
|
81 |
])
|
|
|
|
|
|
|
|
|
82 |
|
83 |
def speech_to_text(audio, processor, model, target_language):
|
84 |
"""Convert speech to text using Whisper"""
|
@@ -113,15 +123,19 @@ def generate_response(prompt, llm, tokenizer):
|
|
113 |
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
114 |
return response
|
115 |
|
116 |
-
def text_to_speech(text, model, vocoder,
|
117 |
-
"""Convert text to speech using
|
118 |
-
inputs = processor(text, return_tensors="pt")
|
119 |
-
speech = model.generate_speech(
|
|
|
|
|
|
|
|
|
120 |
return speech
|
121 |
|
122 |
def create_gradio_interface():
|
123 |
# Initialize components
|
124 |
-
llm, tokenizer,
|
125 |
conversation_manager = ConversationManager()
|
126 |
|
127 |
with gr.Blocks() as interface:
|
@@ -133,7 +147,6 @@ def create_gradio_interface():
|
|
133 |
)
|
134 |
|
135 |
with gr.Row():
|
136 |
-
# Audio input
|
137 |
audio_input = gr.Audio(
|
138 |
source="microphone",
|
139 |
type="numpy",
|
@@ -141,7 +154,6 @@ def create_gradio_interface():
|
|
141 |
)
|
142 |
|
143 |
with gr.Row():
|
144 |
-
# Chat history display
|
145 |
chat_display = gr.Textbox(
|
146 |
value="",
|
147 |
label="Conversation History",
|
@@ -150,17 +162,18 @@ def create_gradio_interface():
|
|
150 |
)
|
151 |
|
152 |
with gr.Row():
|
153 |
-
# Assistant's audio response
|
154 |
audio_output = gr.Audio(
|
155 |
-
label="
|
156 |
type="numpy"
|
157 |
)
|
158 |
|
159 |
def process_conversation(audio, language):
|
|
|
|
|
160 |
# Speech to text
|
161 |
user_text = speech_to_text(
|
162 |
audio,
|
163 |
-
|
164 |
stt_model,
|
165 |
language
|
166 |
)
|
@@ -169,14 +182,15 @@ def create_gradio_interface():
|
|
169 |
# Generate LLM response
|
170 |
context = conversation_manager.get_formatted_history()
|
171 |
response = generate_response(context, llm, tokenizer)
|
172 |
-
conversation_manager.add_message("
|
173 |
|
174 |
# Text to speech
|
175 |
speech_output = text_to_speech(
|
176 |
response,
|
|
|
177 |
tts_model,
|
178 |
vocoder,
|
179 |
-
|
180 |
)
|
181 |
|
182 |
return (
|
@@ -192,7 +206,6 @@ def create_gradio_interface():
|
|
192 |
|
193 |
return interface
|
194 |
|
195 |
-
# Launch the application
|
196 |
if __name__ == "__main__":
|
197 |
interface = create_gradio_interface()
|
198 |
interface.launch()
|
|
|
7 |
pipeline,
|
8 |
AutoProcessor,
|
9 |
AutoModelForSpeechSeq2Seq,
|
10 |
+
BitsAndBytesConfig,
|
11 |
+
SpeechT5Processor,
|
12 |
+
SpeechT5ForTextToSpeech,
|
13 |
+
SpeechT5HifiGan
|
14 |
)
|
15 |
from datasets import load_dataset
|
16 |
import numpy as np
|
|
|
17 |
import torchaudio
|
18 |
|
19 |
@spaces.GPU
|
20 |
+
def dummy(): # just a dummy
|
21 |
pass
|
22 |
|
|
|
|
|
23 |
LANGUAGE_CODES = {
|
24 |
"English": "en",
|
25 |
"Chinese": "zh"
|
26 |
}
|
27 |
|
28 |
+
def get_system_prompt(language):
|
29 |
+
if language == "Chinese":
|
30 |
+
return """你是Lin Yi(林意),一个友好的AI助手。你是我的好朋友,说话亲切自然。
|
31 |
+
请用中文回答,语气要自然友好。如果我用英文问你问题,你也要用中文回答。
|
32 |
+
记住你要像朋友一样交谈,不要太正式。"""
|
33 |
+
else:
|
34 |
+
return """You are Lin Yi, a friendly AI assistant and my good friend (hao pengyou).
|
35 |
+
Speak naturally and warmly. If I speak in Chinese, respond in English.
|
36 |
+
Remember to converse like a friend, not too formal."""
|
37 |
+
|
38 |
def initialize_components():
|
39 |
+
# LLM initialization
|
|
|
40 |
bnb_config = BitsAndBytesConfig(
|
41 |
load_in_4bit=True,
|
42 |
bnb_4bit_quant_type="nf4",
|
|
|
50 |
)
|
51 |
tokenizer = AutoTokenizer.from_pretrained("xverse/XVERSE-13B-Chat")
|
52 |
|
53 |
+
# Speech-to-text
|
54 |
+
whisper_processor = AutoProcessor.from_pretrained("openai/whisper-small")
|
55 |
stt_model = AutoModelForSpeechSeq2Seq.from_pretrained(
|
56 |
"openai/whisper-small",
|
57 |
torch_dtype=torch.float32,
|
58 |
low_cpu_mem_usage=True,
|
59 |
)
|
60 |
|
61 |
+
# Text-to-speech
|
62 |
+
tts_processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
|
63 |
+
tts_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
|
64 |
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
|
65 |
+
|
66 |
+
# Load speaker embedding
|
67 |
+
embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
|
68 |
+
speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)
|
69 |
+
|
70 |
+
return llm, tokenizer, whisper_processor, stt_model, tts_processor, tts_model, vocoder, speaker_embeddings
|
|
|
|
|
|
|
|
|
71 |
|
72 |
class ConversationManager:
|
73 |
def __init__(self):
|
74 |
self.history = []
|
75 |
+
self.current_language = "English"
|
76 |
|
77 |
+
def add_message(self, role, content):
|
78 |
self.history.append({
|
79 |
"role": role,
|
80 |
+
"content": content
|
|
|
81 |
})
|
82 |
|
83 |
def get_formatted_history(self):
|
84 |
+
system_prompt = get_system_prompt(self.current_language)
|
85 |
+
history_text = "\n".join([
|
86 |
f"{msg['role']}: {msg['content']}" for msg in self.history
|
87 |
])
|
88 |
+
return f"{system_prompt}\n\n{history_text}"
|
89 |
+
|
90 |
+
def set_language(self, language):
|
91 |
+
self.current_language = language
|
92 |
|
93 |
def speech_to_text(audio, processor, model, target_language):
|
94 |
"""Convert speech to text using Whisper"""
|
|
|
123 |
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
124 |
return response
|
125 |
|
126 |
+
def text_to_speech(text, processor, model, vocoder, speaker_embeddings):
|
127 |
+
"""Convert text to speech using SpeechT5"""
|
128 |
+
inputs = processor(text=text, return_tensors="pt")
|
129 |
+
speech = model.generate_speech(
|
130 |
+
inputs["input_ids"],
|
131 |
+
speaker_embeddings,
|
132 |
+
vocoder=vocoder
|
133 |
+
)
|
134 |
return speech
|
135 |
|
136 |
def create_gradio_interface():
|
137 |
# Initialize components
|
138 |
+
llm, tokenizer, whisper_processor, stt_model, tts_processor, tts_model, vocoder, speaker_embeddings = initialize_components()
|
139 |
conversation_manager = ConversationManager()
|
140 |
|
141 |
with gr.Blocks() as interface:
|
|
|
147 |
)
|
148 |
|
149 |
with gr.Row():
|
|
|
150 |
audio_input = gr.Audio(
|
151 |
source="microphone",
|
152 |
type="numpy",
|
|
|
154 |
)
|
155 |
|
156 |
with gr.Row():
|
|
|
157 |
chat_display = gr.Textbox(
|
158 |
value="",
|
159 |
label="Conversation History",
|
|
|
162 |
)
|
163 |
|
164 |
with gr.Row():
|
|
|
165 |
audio_output = gr.Audio(
|
166 |
+
label="Lin Yi's Response",
|
167 |
type="numpy"
|
168 |
)
|
169 |
|
170 |
def process_conversation(audio, language):
|
171 |
+
conversation_manager.set_language(language)
|
172 |
+
|
173 |
# Speech to text
|
174 |
user_text = speech_to_text(
|
175 |
audio,
|
176 |
+
whisper_processor,
|
177 |
stt_model,
|
178 |
language
|
179 |
)
|
|
|
182 |
# Generate LLM response
|
183 |
context = conversation_manager.get_formatted_history()
|
184 |
response = generate_response(context, llm, tokenizer)
|
185 |
+
conversation_manager.add_message("Lin Yi", response)
|
186 |
|
187 |
# Text to speech
|
188 |
speech_output = text_to_speech(
|
189 |
response,
|
190 |
+
tts_processor,
|
191 |
tts_model,
|
192 |
vocoder,
|
193 |
+
speaker_embeddings
|
194 |
)
|
195 |
|
196 |
return (
|
|
|
206 |
|
207 |
return interface
|
208 |
|
|
|
209 |
if __name__ == "__main__":
|
210 |
interface = create_gradio_interface()
|
211 |
interface.launch()
|