# vgap / app.py
# Author: mkfallah
# Commit message: "Update app.py"
# Commit: 66951ec (verified)
import gradio as gr
import soundfile as sf
import torch
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    SpeechT5ForTextToSpeech,
    SpeechT5HifiGan,
    SpeechT5Processor,
    pipeline,
)
# --------------------------
# 1. ASR (speech to text)
# --------------------------
# Whisper-small transcription pipeline; device=-1 runs it on the CPU.
asr = pipeline(
    "automatic-speech-recognition",
    device=-1,
    model="openai/whisper-small",
)
# --------------------------
# 2. Language Model (LLM)
# --------------------------
# Flan-T5-Base is loaded as an encoder-decoder (seq2seq) model. The tokenizer
# and the weights come from the same checkpoint, and the model is kept on CPU.
llm_model_id = "google/flan-t5-base"
_llm_device = torch.device("cpu")
tokenizer = AutoTokenizer.from_pretrained(llm_model_id)
llm_model = AutoModelForSeq2SeqLM.from_pretrained(llm_model_id)
llm_model = llm_model.to(_llm_device)
def ask_llm(prompt, max_new_tokens=200, *, do_sample=True, top_k=50, top_p=0.95):
    """Generate a reply to *prompt* with the module-level Flan-T5 model.

    Args:
        prompt: Input text for the model.
        max_new_tokens: Upper bound on the number of generated tokens.
        do_sample: Sample from the output distribution (True, the default)
            or decode greedily (False). With sampling, repeated calls on the
            same prompt can return different replies.
        top_k: Top-k filter applied when sampling.
        top_p: Nucleus (top-p) threshold applied when sampling.

    Returns:
        The decoded reply with special tokens removed.
    """
    # Truncate to the tokenizer's model_max_length so an arbitrarily long
    # transcript cannot inflate encoder memory and latency without bound.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(llm_model.device)

    generate_kwargs = {"max_new_tokens": max_new_tokens, "do_sample": do_sample}
    if do_sample:
        # top_k / top_p only apply to sampling; passing them for greedy
        # decoding would just trigger generation-config warnings.
        generate_kwargs.update(top_k=top_k, top_p=top_p)

    with torch.no_grad():
        outputs = llm_model.generate(**inputs, **generate_kwargs)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# --------------------------
# 3. TTS (text-to-speech) using SpeechT5
# --------------------------
# NOTE(review): microsoft/speecht5_tts appears to be an English-trained
# checkpoint, so Persian replies may be tokenized/pronounced poorly. Confirm,
# or swap in a Persian-capable TTS model.
processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
tts_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
# Without a vocoder, generate_speech() returns a mel spectrogram rather than
# audio. The HiFi-GAN vocoder converts that spectrogram into a waveform.
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
# No real speaker x-vector is loaded, so a random 512-d vector stands in for
# the speaker. Seeding it keeps the synthetic voice the same across restarts.
# NOTE(review): a random embedding may not sound like a natural speaker.
SPEAKER_SEED = 0
speaker_embedding = torch.randn(
    1, 512, generator=torch.Generator().manual_seed(SPEAKER_SEED)
)


def text_to_speech(text, out_path="output.wav"):
    """Synthesize *text* into a WAV file and return its path.

    Args:
        text: Text to speak.
        out_path: Destination WAV path; an existing file is overwritten.

    Returns:
        ``out_path``.
    """
    inputs = processor(text=text, return_tensors="pt")
    waveform = tts_model.generate_speech(
        inputs["input_ids"], speaker_embedding, vocoder=vocoder
    )
    sf.write(out_path, waveform.cpu().numpy(), 16000)
    return out_path
# --------------------------
# 4. Full pipeline function
# --------------------------
def full_pipeline(audio_file):
    """Run one voice turn: transcribe the audio, ask the LLM, speak the reply.

    Args:
        audio_file: Path to the recorded/uploaded audio file, or a falsy value
            when no audio was provided.

    Returns:
        A ``(message, audio_path)`` tuple for the two UI outputs. When any
        stage fails or yields nothing usable, ``audio_path`` is ``None`` and
        ``message`` explains what happened.
    """
    if not audio_file:
        return "No audio input detected.", None

    # Each stage is guarded separately so the UI can report which step failed
    # instead of surfacing a raw traceback.
    try:
        result = asr(audio_file, chunk_length_s=30, stride_length_s=[5, 5])
    except Exception as e:
        return f"ASR error: {e}", None

    user_text = (result.get("text") or "").strip()
    # Silence or unintelligible audio yields an empty transcript. Do not send
    # an instruction-only prompt to the LLM in that case.
    if not user_text:
        return "No speech detected in the audio.", None

    try:
        # Persian prompt prefix: "Answer in simple language: <user text>".
        llm_response = ask_llm(f"پاسخ بده به زبان ساده: {user_text}").strip()
    except Exception as e:
        return f"Assistant generation error: {e}", None

    # An empty reply has nothing to synthesize.
    if not llm_response:
        return f"User said: {user_text}\nAssistant returned an empty response.", None

    try:
        audio_path = text_to_speech(llm_response, "response.wav")
    except Exception as e:
        return f"TTS error: {e}", None

    return f"User said: {user_text}\nAssistant: {llm_response}", audio_path
# --------------------------
# 5. Gradio Interface
# --------------------------
def _build_interface():
    """Assemble the Gradio UI that wires recorded/uploaded audio to full_pipeline.

    The single input is passed to the pipeline as a file path. The two outputs
    are the conversation text and the synthesized reply audio.
    """
    audio_in = gr.Audio(type="filepath", label="Record or upload audio")
    text_out = gr.Textbox(label="Conversation")
    audio_out = gr.Audio(label="TTS Response")
    return gr.Interface(
        fn=full_pipeline,
        inputs=audio_in,
        outputs=[text_out, audio_out],
        title="Persian Voice Assistant (Reliable LLM)",
        description="ASR → Flan-T5-Base → TTS",
    )


iface = _build_interface()

if __name__ == "__main__":
    iface.launch()