|
|
import gradio as gr
import soundfile as sf
import torch
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    SpeechT5ForTextToSpeech,
    SpeechT5HifiGan,
    SpeechT5Processor,
    pipeline,
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Speech-to-text stage: a Whisper "small" checkpoint wrapped in a Transformers
# pipeline. device=-1 keeps inference on the CPU.
asr = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-small",
    device=-1,
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Text-generation stage: tokenizer and seq2seq model are loaded once at import
# time from the same checkpoint id, and the model is placed on the CPU.
llm_model_id = "google/flan-t5-base"
tokenizer = AutoTokenizer.from_pretrained(llm_model_id)
llm_model = AutoModelForSeq2SeqLM.from_pretrained(llm_model_id)
llm_model = llm_model.to("cpu")
|
|
|
|
|
def ask_llm(prompt, max_new_tokens=200):
    """Generate a reply to *prompt* with the module-level seq2seq model.

    Args:
        prompt: Text handed to the tokenizer as the model input.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The first generated sequence decoded to a string, with special
        tokens removed. Sampling is enabled, so repeated calls with the
        same prompt may return different text.
    """
    # Tokenize and move the tensors to wherever the model lives.
    encoded = tokenizer(prompt, return_tensors="pt").to(llm_model.device)

    # Sampling settings: top-k / nucleus sampling instead of greedy decoding.
    sampling = {"do_sample": True, "top_k": 50, "top_p": 0.95}

    # No gradients are needed for inference.
    with torch.no_grad():
        generated = llm_model.generate(
            **encoded,
            max_new_tokens=max_new_tokens,
            **sampling,
        )

    first_sequence = generated[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Text-to-speech stage: SpeechT5 processor and acoustic model, loaded once at
# import time.
processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
tts_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")

# HiFi-GAN vocoder that turns the spectrogram predicted by SpeechT5 into a
# waveform. Without it, generate_speech returns a mel spectrogram, which is not
# playable audio.
tts_vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")

# NOTE(review): this is a random vector, not a real speaker x-vector. The
# SpeechT5 documentation feeds embeddings produced by a speaker-verification
# model; a random one may yield a poor or unstable voice -- consider replacing.
speaker_embedding = torch.randn(1, 512)


def text_to_speech(text, out_path="output.wav"):
    """Synthesize *text* to a WAV file and return the file path.

    Args:
        text: Text to speak.
        out_path: Destination path of the WAV file; overwritten if it exists.

    Returns:
        ``out_path``, after the audio has been written at 16 kHz.
    """
    inputs = processor(text=text, return_tensors="pt")

    # Cap the token sequence at the model's text position limit so that a long
    # reply is truncated instead of failing inside the encoder.
    input_ids = inputs["input_ids"][..., : tts_model.config.max_text_positions]

    with torch.no_grad():
        # Passing the vocoder makes generate_speech return a waveform rather
        # than a spectrogram.
        speech = tts_model.generate_speech(
            input_ids,
            speaker_embedding,
            vocoder=tts_vocoder,
        )

    sf.write(out_path, speech.detach().cpu().numpy(), 16000)
    return out_path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def full_pipeline(audio_file):
    """Run speech recognition, LLM reply generation and speech synthesis.

    NOTE(review): the prompt sent to the LLM is Persian; confirm that the
    configured LLM and TTS checkpoints actually support Persian.

    Args:
        audio_file: Path to the recorded or uploaded audio, or a falsy value
            when nothing was provided.

    Returns:
        A ``(message, audio_path)`` tuple. ``message`` is either the
        conversation transcript or a human-readable error/status string.
        ``audio_path`` is the path of the synthesized reply, or ``None`` when
        no audio was produced.
    """
    if not audio_file:
        return "No audio input detected.", None

    # Each stage is guarded separately so the UI can report which one failed.
    try:
        result = asr(audio_file, chunk_length_s=30, stride_length_s=[5, 5])
    except Exception as e:
        return f"ASR error: {e}", None

    # Tolerate a missing/None "text" entry and drop surrounding whitespace.
    user_text = (result.get("text") or "").strip()
    if not user_text:
        # Nothing intelligible was transcribed; do not query the LLM with an
        # empty question.
        return "No speech recognized in the audio.", None

    try:
        llm_response = ask_llm(f"پاسخ بده به زبان ساده: {user_text}")
    except Exception as e:
        return f"Assistant generation error: {e}", None

    conversation = f"User said: {user_text}\nAssistant: {llm_response}"

    if not llm_response.strip():
        # There is nothing to speak; return the transcript without audio.
        return conversation, None

    try:
        audio_path = text_to_speech(llm_response, "response.wav")
    except Exception as e:
        return f"TTS error: {e}", None

    return conversation, audio_path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Gradio UI: one audio input (delivered to the callback as a file path) and
# two outputs -- the conversation text and the synthesized reply audio.
iface = gr.Interface(
    fn=full_pipeline,
    inputs=gr.Audio(type="filepath", label="Record or upload audio"),
    outputs=[
        gr.Textbox(label="Conversation"),
        gr.Audio(label="TTS Response"),
    ],
    title="Persian Voice Assistant (Reliable LLM)",
    description="ASR → Flan-T5-Base → TTS",
)


# Start the web app only when the file is run as a script.
if __name__ == "__main__":
    iface.launch()
|
|
|