# app.py — "Lin Yi" bilingual (English/Chinese) voice-chat assistant for a Hugging Face Space.
import spaces
import gradio as gr
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
pipeline,
AutoProcessor,
AutoModelForSpeechSeq2Seq,
BitsAndBytesConfig,
SpeechT5Processor,
SpeechT5ForTextToSpeech,
SpeechT5HifiGan
)
from datasets import load_dataset
import numpy as np
import torchaudio
# NOTE(review): presumably a placeholder so the `spaces` ZeroGPU runtime
# detects at least one @spaces.GPU-decorated function — confirm against the
# huggingface `spaces` package docs; the function itself is never called.
@spaces.GPU
def dummy(): # just a dummy
    pass
# Maps the UI language name shown in the dropdown to the Whisper language code.
LANGUAGE_CODES = dict(
    English="en",
    Chinese="zh",
)
def get_system_prompt(language):
    """Return the "Lin Yi" persona system prompt for the given language.

    Args:
        language: a LANGUAGE_CODES key; anything other than "Chinese"
            falls back to the English persona.

    Returns:
        The system-prompt string to prepend to the conversation.
    """
    chinese_prompt = """你是Lin Yi(林意),一个友好的AI助手。你是我的好朋友,说话亲切自然。
请用中文回答,语气要自然友好。如果我用英文问你问题,你也要用中文回答。
记住你要像朋友一样交谈,不要太正式。"""
    english_prompt = """You are Lin Yi, a friendly AI assistant and my good friend (hao pengyou).
Speak naturally and warmly. If I speak in Chinese, respond in English.
Remember to converse like a friend, not too formal."""
    return chinese_prompt if language == "Chinese" else english_prompt
def initialize_components():
    """Load every model the app needs and return them as one tuple.

    Returns an 8-tuple:
        (llm, tokenizer, whisper_processor, stt_model,
         tts_processor, tts_model, vocoder, speaker_embeddings)

    NOTE(review): downloads several large checkpoints from the Hugging Face
    Hub on first call — expect a long cold start and significant disk use.
    """
    # LLM initialization — 4-bit NF4 quantization so the 13B chat model fits in memory.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
    )
    llm = AutoModelForCausalLM.from_pretrained(
        "xverse/XVERSE-13B-Chat",
        quantization_config=bnb_config,
        device_map="auto"  # let accelerate place layers on available devices
    )
    tokenizer = AutoTokenizer.from_pretrained("xverse/XVERSE-13B-Chat")
    # Speech-to-text: Whisper small, kept in float32.
    whisper_processor = AutoProcessor.from_pretrained("openai/whisper-small")
    stt_model = AutoModelForSpeechSeq2Seq.from_pretrained(
        "openai/whisper-small",
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
    )
    # Text-to-speech: SpeechT5 with the matching HiFi-GAN vocoder.
    tts_processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
    tts_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
    vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
    # Load speaker embedding — a single x-vector from the CMU Arctic set;
    # index 7306 presumably picked for voice quality (TODO confirm).
    embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
    speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)
    return llm, tokenizer, whisper_processor, stt_model, tts_processor, tts_model, vocoder, speaker_embeddings
class ConversationManager:
    """Tracks the chat transcript and the currently selected language."""

    def __init__(self):
        # Each entry is {"role": ..., "content": ...}.
        self.history = []
        # Controls which system prompt get_formatted_history uses.
        self.current_language = "English"

    def add_message(self, role, content):
        """Append one conversation turn to the transcript."""
        self.history.append({"role": role, "content": content})

    def get_formatted_history(self):
        """Render the system prompt plus every turn as one plain-text block."""
        turns = [f"{msg['role']}: {msg['content']}" for msg in self.history]
        prompt = get_system_prompt(self.current_language)
        return f"{prompt}\n\n" + "\n".join(turns)

    def set_language(self, language):
        """Switch the language used for the system prompt."""
        self.current_language = language
def speech_to_text(audio, processor, model, target_language):
    """Transcribe audio with Whisper.

    Args:
        audio: waveform samples; fed to the processor as 16 kHz audio.
        processor: Whisper processor (feature extractor + tokenizer).
        model: Whisper seq2seq model with a `.generate()` method.
        target_language: LANGUAGE_CODES key ("English" or "Chinese").

    Returns:
        Transcription string for the first (only) batch item.
    """
    features = processor(
        audio,
        sampling_rate=16000,
        return_tensors="pt",
    ).input_features
    language_code = LANGUAGE_CODES[target_language]
    token_ids = model.generate(features, language=language_code)
    decoded = processor.batch_decode(token_ids, skip_special_tokens=True)
    return decoded[0]
def generate_response(prompt, llm, tokenizer):
    """Generate an LLM completion for *prompt* and return only the new text.

    Args:
        prompt: full conversation context (system prompt + history).
        llm: causal LM exposing a Hugging Face-style `.generate()`.
        tokenizer: matching tokenizer.

    Returns:
        The decoded completion with the echoed prompt removed and
        surrounding whitespace stripped.
    """
    inputs = tokenizer(prompt, return_tensors="pt")
    prompt_len = inputs["input_ids"].shape[1]
    outputs = llm.generate(
        **inputs,
        # max_new_tokens bounds only the reply; max_length counted the prompt
        # too, so a long history could leave zero room for a response.
        max_new_tokens=512,
        num_return_sequences=1,
        temperature=0.7,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id
    )
    # Decoder-only models echo the prompt in the output ids. Slice it off so
    # the caller gets only the new turn — previously the whole prompt/history
    # was returned and appended back into the conversation history.
    response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return response.strip()
def text_to_speech(text, processor, model, vocoder, speaker_embeddings):
    """Synthesize *text* into a waveform with SpeechT5.

    Args:
        text: the text to speak.
        processor: SpeechT5 processor used to tokenize the text.
        model: SpeechT5 TTS model exposing `generate_speech`.
        vocoder: HiFi-GAN vocoder converting spectrogram to waveform.
        speaker_embeddings: x-vector tensor selecting the voice.

    Returns:
        The waveform produced by `generate_speech`.
    """
    encoded = processor(text=text, return_tensors="pt")
    waveform = model.generate_speech(
        encoded["input_ids"],
        speaker_embeddings,
        vocoder=vocoder,
    )
    return waveform
def create_gradio_interface():
    """Build the Gradio Blocks UI wiring mic input -> STT -> LLM -> TTS.

    NOTE(review): `source=` and `readonly=` are Gradio 3.x keyword names;
    Gradio 4.x renamed them (`sources=[...]`, `interactive=False`) — confirm
    the pinned gradio version before upgrading.
    """
    # Initialize components — loads every model up front (slow cold start).
    llm, tokenizer, whisper_processor, stt_model, tts_processor, tts_model, vocoder, speaker_embeddings = initialize_components()
    # Shared, module-lifetime conversation state: one history for all users/sessions.
    conversation_manager = ConversationManager()
    with gr.Blocks() as interface:
        with gr.Row():
            language_selector = gr.Dropdown(
                choices=list(LANGUAGE_CODES.keys()),
                value="English",
                label="Select Language"
            )
        with gr.Row():
            audio_input = gr.Audio(
                source="microphone",
                type="numpy",
                label="Speak"
            )
        with gr.Row():
            chat_display = gr.Textbox(
                value="",
                label="Conversation History",
                lines=10,
                readonly=True
            )
        with gr.Row():
            audio_output = gr.Audio(
                label="Lin Yi's Response",
                type="numpy"
            )
        def process_conversation(audio, language):
            # Runs on every mic recording: transcribe, generate, synthesize.
            conversation_manager.set_language(language)
            # Speech to text.
            # NOTE(review): with type="numpy" Gradio delivers audio as a
            # (sample_rate, ndarray) tuple, but speech_to_text feeds its
            # argument to the Whisper processor as if it were a bare 16 kHz
            # waveform — verify unpacking/resampling against the Gradio docs.
            user_text = speech_to_text(
                audio,
                whisper_processor,
                stt_model,
                language
            )
            conversation_manager.add_message("User", user_text)
            # Generate LLM response from the full formatted history.
            context = conversation_manager.get_formatted_history()
            response = generate_response(context, llm, tokenizer)
            conversation_manager.add_message("Lin Yi", response)
            # Text to speech.
            speech_output = text_to_speech(
                response,
                tts_processor,
                tts_model,
                vocoder,
                speaker_embeddings
            )
            # Return the updated transcript and a (rate, samples) pair for
            # gr.Audio; 16000 assumes SpeechT5's output rate — TODO confirm.
            return (
                conversation_manager.get_formatted_history(),
                (16000, speech_output.numpy())
            )
        # Fire the pipeline whenever the recorded audio value changes.
        audio_input.change(
            process_conversation,
            inputs=[audio_input, language_selector],
            outputs=[chat_display, audio_output]
        )
    return interface
if __name__ == "__main__":
    # Build the UI and start the (blocking) Gradio server.
    create_gradio_interface().launch()