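# RobertaSpeak / app.py
# Last commit: 03fe636 ("Update app.py", verified), author: ariankhalfani.
# NOTE: the original header lines here were file-viewer page residue (avatar
# caption, raw/history/blame/contribute/delete links, "No virus" badge,
# size 5.17 kB) and are not part of the program.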
import requests
import gradio as gr
import os
from pydub import AudioSegment
from io import BytesIO
import time
# Hugging Face API URLs
# Question-answering endpoint; posted to by query_roberta.
API_URL_ROBERTA = "https://api-inference.huggingface.co/models/deepset/roberta-base-squad2"
# Text-to-speech endpoint; posted to by generate_speech.
API_URL_TTS = "https://api-inference.huggingface.co/models/espnet/english_male_ryanspeech_tacotron"
# Speech-to-text endpoint; posted to by query_whisper.
API_URL_WHISPER = "https://api-inference.huggingface.co/models/openai/whisper-large-v2"
# Retry settings
# MAX_RETRIES is the total number of attempts made by each retry loop
# (the loops iterate over range(MAX_RETRIES)), not the number of extra retries.
MAX_RETRIES = 5
# Pause between two consecutive attempts, passed to time.sleep().
RETRY_DELAY = 1 # seconds
# Function to query the Whisper model for audio transcription
def query_whisper(api_token, audio_path):
    """Send an audio file to the Whisper endpoint and return its JSON reply.

    Args:
        api_token: Hugging Face API token, sent as a Bearer credential.
        audio_path: Path of the audio file to upload.

    Returns:
        The decoded JSON body of the response on success, or
        ``{"error": <message>}`` when the input is unusable or when every
        attempt has failed. Failures are reported through the return value;
        nothing is raised to the caller.
    """
    headers = {"Authorization": f"Bearer {api_token}"}

    # A missing path or file cannot be fixed by trying again, so these cases
    # are reported immediately instead of being retried with sleeps.
    if not audio_path:
        message = "Audio file path is None"
        print(f"Whisper model query failed: {message}")
        return {"error": message}
    if not os.path.exists(audio_path):
        message = f"Audio file does not exist: {audio_path}"
        print(f"Whisper model query failed: {message}")
        return {"error": message}

    # Read the file once; only the network call below is worth retrying.
    try:
        with open(audio_path, "rb") as f:
            data = f.read()
    except OSError as e:
        print(f"Whisper model query failed: {e}")
        return {"error": str(e)}

    for attempt in range(MAX_RETRIES):
        try:
            # NOTE(review): the audio is uploaded as a multipart form field
            # named "file"; confirm the endpoint accepts this rather than a
            # raw binary request body.
            response = requests.post(API_URL_WHISPER, headers=headers, files={"file": data})
            response.raise_for_status()
            return response.json()
        except Exception as e:
            # Broad on purpose: connection errors, HTTP error statuses and
            # undecodable bodies are all turned into the error dict.
            print(f"Whisper model query failed: {e}")
            if attempt < MAX_RETRIES - 1:
                print(f"Retrying Whisper model query ({attempt + 1}/{MAX_RETRIES})...")
                time.sleep(RETRY_DELAY)
            else:
                return {"error": str(e)}
# Function to query the RoBERTa model
def query_roberta(api_token, prompt, context):
    """Ask the RoBERTa question-answering endpoint a question about a context.

    Args:
        api_token: Hugging Face API token, sent as a Bearer credential.
        prompt: The question text.
        context: The passage the question is asked against.

    Returns:
        The decoded JSON body of the response on success, or
        ``{"error": <message>}`` once the last attempt has failed.
    """
    auth_headers = {"Authorization": f"Bearer {api_token}"}
    request_body = {"inputs": {"question": prompt, "context": context}}

    for attempt in range(MAX_RETRIES):
        try:
            reply = requests.post(API_URL_ROBERTA, headers=auth_headers, json=request_body)
            reply.raise_for_status()
            return reply.json()
        except Exception as err:
            print(f"RoBERTa model query failed: {err}")
            # Out of attempts: hand the failure back as data, not an exception.
            if not attempt < MAX_RETRIES - 1:
                return {"error": str(err)}
            print(f"Retrying RoBERTa model query ({attempt + 1}/{MAX_RETRIES})...")
            time.sleep(RETRY_DELAY)
# Function to generate speech from text using ESPnet TTS
def generate_speech(api_token, answer):
    """Turn text into a WAV file using the ESPnet TTS endpoint.

    The response body is decoded as FLAC and re-exported as WAV to the fixed
    path ``/tmp/answer.wav``, so each successful call overwrites the file
    written by the previous one.

    Args:
        api_token: Hugging Face API token, sent as a Bearer credential.
        answer: The text to synthesise.

    Returns:
        The path of the written WAV file (a ``str``) on success, or
        ``{"error": <message>}`` once the last attempt has failed.
    """
    auth_headers = {"Authorization": f"Bearer {api_token}"}
    request_body = {"inputs": answer}

    for attempt in range(MAX_RETRIES):
        try:
            reply = requests.post(API_URL_TTS, headers=auth_headers, json=request_body)
            reply.raise_for_status()
            flac_stream = BytesIO(reply.content)
            speech = AudioSegment.from_file(flac_stream, format="flac")
            wav_path = "/tmp/answer.wav"
            speech.export(wav_path, format="wav")
            return wav_path
        except Exception as err:
            print(f"ESPnet TTS query failed: {err}")
            # Out of attempts: hand the failure back as data, not an exception.
            if not attempt < MAX_RETRIES - 1:
                return {"error": str(err)}
            print(f"Retrying ESPnet TTS query ({attempt + 1}/{MAX_RETRIES})...")
            time.sleep(RETRY_DELAY)
# Function to handle the entire process
def handle_all(api_token, context, audio):
    """Run the full pipeline: transcribe, answer, then synthesise the answer.

    Args:
        api_token: Hugging Face API token forwarded to every model query.
        context: Passage the transcribed question is answered against.
        audio: Path of the recorded audio file; may be None or empty when
            nothing was recorded.

    Returns:
        ``(answer_text, audio_file_path)`` on success, or
        ``(error_message, None)`` when no audio was supplied or when the last
        attempt of the whole process has failed.
    """
    # Without a recording every attempt fails the same way. Return the message
    # query_whisper would eventually produce, without going through the
    # nested retry loops and their sleeps first.
    if not audio:
        message = "Audio file path is None"
        print(f"Process failed: {message}")
        return message, None

    for attempt in range(MAX_RETRIES):
        try:
            # Step 1: Transcribe audio
            transcription = query_whisper(api_token, audio)
            if isinstance(transcription, dict) and 'error' in transcription:
                raise Exception(transcription['error'])
            question = transcription.get("text", "No transcription found")

            # Step 2: Get answer from RoBERTa
            answer = query_roberta(api_token, question, context)
            if isinstance(answer, dict) and 'error' in answer:
                raise Exception(answer['error'])
            answer_text = answer.get('answer', 'No answer found')

            # Step 3: Generate speech from answer
            audio_file_path = generate_speech(api_token, answer_text)
            # generate_speech returns a str path on success and a dict on
            # failure; check the type so that `in` is never a substring test
            # against the path.
            if isinstance(audio_file_path, dict) and 'error' in audio_file_path:
                raise Exception(audio_file_path['error'])

            return answer_text, audio_file_path
        except Exception as e:
            print(f"Process failed: {e}")
            if attempt < MAX_RETRIES - 1:
                print(f"Retrying entire process ({attempt + 1}/{MAX_RETRIES})...")
                time.sleep(RETRY_DELAY)
            else:
                return str(e), None
# Define the Gradio interface
# Input widgets, listed in the same order as handle_all's parameters
# (api_token, context, audio).
input_components = [
    gr.Textbox(
        lines=1,
        label="Hugging Face API Token",
        type="password",
        placeholder="Enter your Hugging Face API token...",
    ),
    gr.Textbox(
        lines=2,
        label="Context",
        placeholder="Enter the context here...",
    ),
    gr.Audio(type="filepath", label="Record your voice"),
]

# Output widgets, listed in the same order as handle_all's return tuple
# (answer text, path of the spoken answer).
output_components = [
    gr.Textbox(label="Answer"),
    gr.Audio(label="Answer as Speech", type="filepath"),
]

iface = gr.Interface(
    fn=handle_all,
    inputs=input_components,
    outputs=output_components,
    title="Chat with Roberta with Voice",
    description="Record your voice, get the transcription, use it as a question for the Roberta model, and hear the response via text-to-speech.",
)

# Launch the Gradio app
iface.launch(share=True)