import requests
import gradio as gr
import os
from pydub import AudioSegment
from io import BytesIO
import time

# Hugging Face API URLs
API_URL_ROBERTA = "https://api-inference.huggingface.co/models/deepset/roberta-base-squad2"
API_URL_TTS = "https://api-inference.huggingface.co/models/espnet/english_male_ryanspeech_tacotron"
API_URL_WHISPER = "https://api-inference.huggingface.co/models/openai/whisper-large-v2"

# Retry settings
MAX_RETRIES = 5
RETRY_DELAY = 1  # seconds
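# The serverless Inference API can return 503 while a model is still loading,
# so each request is retried a few times before giving up.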

# Function to query the Whisper model for audio transcription
def query_whisper(api_token, audio_path):
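    """Transcribe an audio file with the hosted Whisper model.

    Sends the raw audio bytes to the Inference API and returns the JSON
    response (expected to contain a "text" field), or {"error": ...} after
    MAX_RETRIES failed attempts.
    """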
    headers = {"Authorization": f"Bearer {api_token}"}
    for attempt in range(MAX_RETRIES):
        try:
            if not audio_path:
                raise ValueError("Audio file path is None")
            if not os.path.exists(audio_path):
                raise FileNotFoundError(f"Audio file does not exist: {audio_path}")
            
            with open(audio_path, "rb") as f:
                data = f.read()
            
            # The Inference API expects the raw audio bytes as the request body.
            response = requests.post(API_URL_WHISPER, headers=headers, data=data)
            response.raise_for_status()
            return response.json()
        
        except Exception as e:
            print(f"Whisper model query failed: {e}")
            if attempt < MAX_RETRIES - 1:
                print(f"Retrying Whisper model query ({attempt + 1}/{MAX_RETRIES})...")
                time.sleep(RETRY_DELAY)
            else:
                return {"error": str(e)}

# Function to query the RoBERTa model
def query_roberta(api_token, prompt, context):
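    """Answer a question with the RoBERTa SQuAD2 extractive QA model.

    Pairs the transcribed question with the user-supplied context and returns
    the JSON response (expected to contain an "answer" field), or
    {"error": ...} after MAX_RETRIES failed attempts.
    """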
    headers = {"Authorization": f"Bearer {api_token}"}
    payload = {"inputs": {"question": prompt, "context": context}}
    
    for attempt in range(MAX_RETRIES):
        try:
            response = requests.post(API_URL_ROBERTA, headers=headers, json=payload)
            response.raise_for_status()
            return response.json()
        except Exception as e:
            print(f"RoBERTa model query failed: {e}")
            if attempt < MAX_RETRIES - 1:
                print(f"Retrying RoBERTa model query ({attempt + 1}/{MAX_RETRIES})...")
                time.sleep(RETRY_DELAY)
            else:
                return {"error": str(e)}

# Function to generate speech from text using ESPnet TTS
def generate_speech(api_token, answer):
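    """Synthesize speech for the answer text with the ESPnet TTS model.

    The API returns raw audio bytes (decoded here as FLAC), which are
    re-encoded to WAV at /tmp/answer.wav so Gradio can play them back.
    Returns the file path, or {"error": ...} after MAX_RETRIES failed attempts.
    """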
    headers = {"Authorization": f"Bearer {api_token}"}
    payload = {"inputs": answer}
    
    for attempt in range(MAX_RETRIES):
        try:
            response = requests.post(API_URL_TTS, headers=headers, json=payload)
            response.raise_for_status()
            audio = response.content
            
            audio_segment = AudioSegment.from_file(BytesIO(audio), format="flac")
            audio_file_path = "/tmp/answer.wav"
            audio_segment.export(audio_file_path, format="wav")
            return audio_file_path
        except Exception as e:
            print(f"ESPnet TTS query failed: {e}")
            if attempt < MAX_RETRIES - 1:
                print(f"Retrying ESPnet TTS query ({attempt + 1}/{MAX_RETRIES})...")
                time.sleep(RETRY_DELAY)
            else:
                return {"error": str(e)}

# Function to handle the entire process
def handle_all(api_token, context, audio):
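    """Run the full pipeline: transcribe the recording, answer the question, voice the answer.

    If any step reports an error, the whole pipeline is retried; after
    MAX_RETRIES failures the error message is returned as text with no audio.
    """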
    for attempt in range(MAX_RETRIES):
        try:
            # Step 1: Transcribe audio
            transcription = query_whisper(api_token, audio)
            if 'error' in transcription:
                raise Exception(transcription['error'])
            
            question = transcription.get("text", "No transcription found")

            # Step 2: Get answer from RoBERTa
            answer = query_roberta(api_token, question, context)
            if 'error' in answer:
                raise Exception(answer['error'])

            answer_text = answer.get('answer', 'No answer found')

            # Step 3: Generate speech from answer
            audio_file_path = generate_speech(api_token, answer_text)
            if isinstance(audio_file_path, dict) and 'error' in audio_file_path:
                raise Exception(audio_file_path['error'])

            return answer_text, audio_file_path

        except Exception as e:
            print(f"Process failed: {e}")
            if attempt < MAX_RETRIES - 1:
                print(f"Retrying entire process ({attempt + 1}/{MAX_RETRIES})...")
                time.sleep(RETRY_DELAY)
            else:
                return str(e), None

# Define the Gradio interface
iface = gr.Interface(
    fn=handle_all,
    inputs=[
        gr.Textbox(lines=1, label="Hugging Face API Token", type="password", placeholder="Enter your Hugging Face API token..."),
        gr.Textbox(lines=2, label="Context", placeholder="Enter the context here..."),
        gr.Audio(type="filepath", label="Record your voice")
    ],
    outputs=[
        gr.Textbox(label="Answer"),
        gr.Audio(label="Answer as Speech", type="filepath")
    ],
    title="Chat with Roberta with Voice",
    description="Record your voice, get the transcription, use it as a question for the Roberta model, and hear the response via text-to-speech."
)

# Launch the Gradio app
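# share=True requests a temporary public URL in addition to the local server.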
iface.launch(share=True)