ariankhalfani commited on
Commit
03fe636
1 Parent(s): 1789b0c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -14
app.py CHANGED
@@ -10,16 +10,13 @@ API_URL_ROBERTA = "https://api-inference.huggingface.co/models/deepset/roberta-b
10
  API_URL_TTS = "https://api-inference.huggingface.co/models/espnet/english_male_ryanspeech_tacotron"
11
  API_URL_WHISPER = "https://api-inference.huggingface.co/models/openai/whisper-large-v2"
12
 
13
- # Hugging Face API Token
14
- API_TOKEN = os.getenv("HF_API_KEY") # Ensure you have set this environment variable
15
- HEADERS = {"Authorization": f"Bearer {API_TOKEN}"}
16
-
17
  # Retry settings
18
  MAX_RETRIES = 5
19
  RETRY_DELAY = 1 # seconds
20
 
21
  # Function to query the Whisper model for audio transcription
22
- def query_whisper(audio_path):
 
23
  for attempt in range(MAX_RETRIES):
24
  try:
25
  if not audio_path:
@@ -30,7 +27,7 @@ def query_whisper(audio_path):
30
  with open(audio_path, "rb") as f:
31
  data = f.read()
32
 
33
- response = requests.post(API_URL_WHISPER, headers=HEADERS, files={"file": data})
34
  response.raise_for_status()
35
  return response.json()
36
 
@@ -43,12 +40,13 @@ def query_whisper(audio_path):
43
  return {"error": str(e)}
44
 
45
  # Function to query the RoBERTa model
46
- def query_roberta(prompt, context):
 
47
  payload = {"inputs": {"question": prompt, "context": context}}
48
 
49
  for attempt in range(MAX_RETRIES):
50
  try:
51
- response = requests.post(API_URL_ROBERTA, headers=HEADERS, json=payload)
52
  response.raise_for_status()
53
  return response.json()
54
  except Exception as e:
@@ -60,12 +58,13 @@ def query_roberta(prompt, context):
60
  return {"error": str(e)}
61
 
62
  # Function to generate speech from text using ESPnet TTS
63
- def generate_speech(answer):
 
64
  payload = {"inputs": answer}
65
 
66
  for attempt in range(MAX_RETRIES):
67
  try:
68
- response = requests.post(API_URL_TTS, headers=HEADERS, json=payload)
69
  response.raise_for_status()
70
  audio = response.content
71
 
@@ -82,25 +81,25 @@ def generate_speech(answer):
82
  return {"error": str(e)}
83
 
84
  # Function to handle the entire process
85
- def handle_all(context, audio):
86
  for attempt in range(MAX_RETRIES):
87
  try:
88
  # Step 1: Transcribe audio
89
- transcription = query_whisper(audio)
90
  if 'error' in transcription:
91
  raise Exception(transcription['error'])
92
 
93
  question = transcription.get("text", "No transcription found")
94
 
95
  # Step 2: Get answer from RoBERTa
96
- answer = query_roberta(question, context)
97
  if 'error' in answer:
98
  raise Exception(answer['error'])
99
 
100
  answer_text = answer.get('answer', 'No answer found')
101
 
102
  # Step 3: Generate speech from answer
103
- audio_file_path = generate_speech(answer_text)
104
  if 'error' in audio_file_path:
105
  raise Exception(audio_file_path['error'])
106
 
@@ -118,6 +117,7 @@ def handle_all(context, audio):
118
  iface = gr.Interface(
119
  fn=handle_all,
120
  inputs=[
 
121
  gr.Textbox(lines=2, label="Context", placeholder="Enter the context here..."),
122
  gr.Audio(type="filepath", label="Record your voice")
123
  ],
 
10
  API_URL_TTS = "https://api-inference.huggingface.co/models/espnet/english_male_ryanspeech_tacotron"
11
  API_URL_WHISPER = "https://api-inference.huggingface.co/models/openai/whisper-large-v2"
12
 
 
 
 
 
13
  # Retry settings
14
  MAX_RETRIES = 5
15
  RETRY_DELAY = 1 # seconds
16
 
17
  # Function to query the Whisper model for audio transcription
18
+ def query_whisper(api_token, audio_path):
19
+ headers = {"Authorization": f"Bearer {api_token}"}
20
  for attempt in range(MAX_RETRIES):
21
  try:
22
  if not audio_path:
 
27
  with open(audio_path, "rb") as f:
28
  data = f.read()
29
 
30
+ response = requests.post(API_URL_WHISPER, headers=headers, files={"file": data})
31
  response.raise_for_status()
32
  return response.json()
33
 
 
40
  return {"error": str(e)}
41
 
42
  # Function to query the RoBERTa model
43
+ def query_roberta(api_token, prompt, context):
44
+ headers = {"Authorization": f"Bearer {api_token}"}
45
  payload = {"inputs": {"question": prompt, "context": context}}
46
 
47
  for attempt in range(MAX_RETRIES):
48
  try:
49
+ response = requests.post(API_URL_ROBERTA, headers=headers, json=payload)
50
  response.raise_for_status()
51
  return response.json()
52
  except Exception as e:
 
58
  return {"error": str(e)}
59
 
60
  # Function to generate speech from text using ESPnet TTS
61
+ def generate_speech(api_token, answer):
62
+ headers = {"Authorization": f"Bearer {api_token}"}
63
  payload = {"inputs": answer}
64
 
65
  for attempt in range(MAX_RETRIES):
66
  try:
67
+ response = requests.post(API_URL_TTS, headers=headers, json=payload)
68
  response.raise_for_status()
69
  audio = response.content
70
 
 
81
  return {"error": str(e)}
82
 
83
  # Function to handle the entire process
84
+ def handle_all(api_token, context, audio):
85
  for attempt in range(MAX_RETRIES):
86
  try:
87
  # Step 1: Transcribe audio
88
+ transcription = query_whisper(api_token, audio)
89
  if 'error' in transcription:
90
  raise Exception(transcription['error'])
91
 
92
  question = transcription.get("text", "No transcription found")
93
 
94
  # Step 2: Get answer from RoBERTa
95
+ answer = query_roberta(api_token, question, context)
96
  if 'error' in answer:
97
  raise Exception(answer['error'])
98
 
99
  answer_text = answer.get('answer', 'No answer found')
100
 
101
  # Step 3: Generate speech from answer
102
+ audio_file_path = generate_speech(api_token, answer_text)
103
  if 'error' in audio_file_path:
104
  raise Exception(audio_file_path['error'])
105
 
 
117
  iface = gr.Interface(
118
  fn=handle_all,
119
  inputs=[
120
+ gr.Textbox(lines=1, label="Hugging Face API Token", type="password", placeholder="Enter your Hugging Face API token..."),
121
  gr.Textbox(lines=2, label="Context", placeholder="Enter the context here..."),
122
  gr.Audio(type="filepath", label="Record your voice")
123
  ],