import spaces
import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    pipeline,
    AutoProcessor,
    AutoModelForSpeechSeq2Seq,
    BitsAndBytesConfig,
    SpeechT5Processor,
    SpeechT5ForTextToSpeech,
    SpeechT5HifiGan
)
from datasets import load_dataset
import numpy as np
import torchaudio
def dummy():  # just a dummy
    pass
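# Note: `import spaces` suggests this app targets Hugging Face ZeroGPU hardware.
# On ZeroGPU, GPU-bound entry points are typically wrapped with the @spaces.GPU
# decorator (for example, on the function that runs inference); whether that is
# required here depends on the Space's hardware configuration.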
LANGUAGE_CODES = {
    "English": "en",
    "Chinese": "zh"
}
def get_system_prompt(language):
    if language == "Chinese":
        # Chinese system prompt, roughly: "You are Lin Yi, a friendly AI assistant and my
        # good friend. Speak warmly and naturally, answer in Chinese even if asked in
        # English, and talk like a friend rather than formally."
        return """你是Lin Yi(林意),一个友好的AI助手。你是我的好朋友,说话亲切自然。
请用中文回答,语气要自然友好。如果我用英文问你问题,你也要用中文回答。
记住你要像朋友一样交谈,不要太正式。"""
    else:
        return """You are Lin Yi, a friendly AI assistant and my good friend (hao pengyou).
Speak naturally and warmly. If I speak in Chinese, respond in English.
Remember to converse like a friend, not too formal."""
def initialize_components():
    # LLM initialization (4-bit NF4 quantization so XVERSE-13B fits on a single GPU)
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
    )
    llm = AutoModelForCausalLM.from_pretrained(
        "xverse/XVERSE-13B-Chat",
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,  # XVERSE checkpoints ship custom modeling code
    )
    tokenizer = AutoTokenizer.from_pretrained(
        "xverse/XVERSE-13B-Chat",
        trust_remote_code=True,
    )
    # Speech-to-text
    whisper_processor = AutoProcessor.from_pretrained("openai/whisper-small")
    stt_model = AutoModelForSpeechSeq2Seq.from_pretrained(
        "openai/whisper-small",
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
    )
    # Text-to-speech
    tts_processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
    tts_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
    vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
    # Load a speaker embedding (x-vector) used to condition the SpeechT5 voice
    embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
    speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)
    return llm, tokenizer, whisper_processor, stt_model, tts_processor, tts_model, vocoder, speaker_embeddings
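# Note: loading in 4-bit via BitsAndBytesConfig requires the bitsandbytes package and a
# CUDA GPU, and device_map="auto" relies on accelerate being installed; both would need
# to be available in the Space environment for initialize_components() to succeed.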
class ConversationManager:
    def __init__(self):
        self.history = []
        self.current_language = "English"

    def add_message(self, role, content):
        self.history.append({
            "role": role,
            "content": content
        })

    def get_formatted_history(self):
        system_prompt = get_system_prompt(self.current_language)
        history_text = "\n".join([
            f"{msg['role']}: {msg['content']}" for msg in self.history
        ])
        return f"{system_prompt}\n\n{history_text}"

    def set_language(self, language):
        self.current_language = language
def speech_to_text(audio, processor, model, target_language):
    """Convert speech to text using Whisper."""
    # gr.Audio(type="numpy") delivers a (sample_rate, waveform) tuple of int16 samples.
    sampling_rate, waveform = audio
    waveform = waveform.astype(np.float32)
    if waveform.ndim > 1:
        waveform = waveform.mean(axis=1)  # downmix stereo to mono
    waveform /= max(np.abs(waveform).max(), 1e-9)  # normalize to [-1, 1]
    if sampling_rate != 16000:
        # Whisper expects 16 kHz input
        waveform = torchaudio.functional.resample(
            torch.from_numpy(waveform), sampling_rate, 16000
        ).numpy()
    input_features = processor(
        waveform,
        sampling_rate=16000,
        return_tensors="pt"
    ).input_features
    predicted_ids = model.generate(
        input_features,
        language=LANGUAGE_CODES[target_language],
        task="transcribe"
    )
    transcription = processor.batch_decode(
        predicted_ids,
        skip_special_tokens=True
    )[0]
    return transcription
def generate_response(prompt, llm, tokenizer):
    """Generate an LLM response with moderate sampling settings."""
    inputs = tokenizer(prompt, return_tensors="pt").to(llm.device)
    outputs = llm.generate(
        **inputs,
        max_new_tokens=256,
        num_return_sequences=1,
        temperature=0.7,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id
    )
    # Decode only the newly generated tokens so the prompt/history is not echoed back.
    new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)
    return response.strip()
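# Note: chat-tuned models generally expect their own chat template; if XVERSE-13B-Chat
# defines one (e.g. via its custom remote code), applying it instead of a plain
# concatenated "role: content" prompt would likely improve response quality.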
def text_to_speech(text, processor, model, vocoder, speaker_embeddings):
    """Convert text to speech using SpeechT5 (returns a 16 kHz waveform tensor)."""
    inputs = processor(text=text, return_tensors="pt")
    speech = model.generate_speech(
        inputs["input_ids"],
        speaker_embeddings,
        vocoder=vocoder
    )
    return speech
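# Note: microsoft/speecht5_tts is trained on English speech, so Chinese replies from the
# LLM are unlikely to synthesize intelligibly through this TTS path; a multilingual TTS
# model would be needed for proper Chinese audio output.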
def create_gradio_interface():
    # Initialize components
    llm, tokenizer, whisper_processor, stt_model, tts_processor, tts_model, vocoder, speaker_embeddings = initialize_components()
    conversation_manager = ConversationManager()

    with gr.Blocks() as interface:
        with gr.Row():
            language_selector = gr.Dropdown(
                choices=list(LANGUAGE_CODES.keys()),
                value="English",
                label="Select Language"
            )
        with gr.Row():
            audio_input = gr.Audio(
                sources=["microphone"],  # Gradio 4.x; older releases used source="microphone"
                type="numpy",
                label="Speak"
            )
        with gr.Row():
            chat_display = gr.Textbox(
                value="",
                label="Conversation History",
                lines=10,
                interactive=False  # display-only; gr.Textbox has no `readonly` argument
            )
        with gr.Row():
            audio_output = gr.Audio(
                label="Lin Yi's Response",
                type="numpy"
            )
        def process_conversation(audio, language):
            # Ignore events fired when the audio component is cleared.
            if audio is None:
                return conversation_manager.get_formatted_history(), None

            conversation_manager.set_language(language)
            # Speech to text
            user_text = speech_to_text(
                audio,
                whisper_processor,
                stt_model,
                language
            )
            conversation_manager.add_message("User", user_text)
            # Generate LLM response
            context = conversation_manager.get_formatted_history()
            response = generate_response(context, llm, tokenizer)
            conversation_manager.add_message("Lin Yi", response)
            # Text to speech
            speech_output = text_to_speech(
                response,
                tts_processor,
                tts_model,
                vocoder,
                speaker_embeddings
            )
            return (
                conversation_manager.get_formatted_history(),
                (16000, speech_output.numpy())
            )

        audio_input.change(
            process_conversation,
            inputs=[audio_input, language_selector],
            outputs=[chat_display, audio_output]
        )

    return interface
if __name__ == "__main__":
    interface = create_gradio_interface()
    interface.launch()