Spaces:
Sleeping
Sleeping
File size: 4,152 Bytes
68ba2e8 941d470 e6868fd a7368c8 e6868fd a7368c8 e6868fd a7368c8 68ba2e8 e6868fd 68ba2e8 e6868fd 68ba2e8 e6868fd 68ba2e8 e6868fd 68ba2e8 a7368c8 e6868fd a7368c8 e6868fd a7368c8 e6868fd 68ba2e8 a7368c8 e6868fd a7368c8 68ba2e8 a7368c8 e6868fd 68ba2e8 e6868fd 68ba2e8 a7368c8 68ba2e8 e6868fd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
import time
import gradio as gr
from transformers import pipeline
import numpy as np
from openai import OpenAI
# Speech-to-text model used by transcribe(); loaded once at import time.
transcriber = pipeline(task="automatic-speech-recognition", model="openai/whisper-base.en")
# Extractive question-answering model used by answer(); also loaded at import.
# NOTE(review): in this file answer() is only reached via process_audio, which
# no UI event calls — this load may be unnecessary.
qa_model = pipeline(task="question-answering", model="distilbert-base-cased-distilled-squad")
def predict(message, history, api_key, is_paused, *, model="gpt-4o",
            temperature=1.0, poll_interval=0.1, client=None):
    """Stream a chat completion for *message*, yielding the growing reply.

    Args:
        message: The new user message.
        history: Prior exchanges as ``(user, assistant)`` pairs; converted to
            alternating user/assistant chat messages ahead of *message*.
        api_key: OpenAI API key; used only when *client* is None.
        is_paused: One-element list used as a shared, mutable pause flag.
            While ``is_paused[0]`` is truthy, consumption of the stream waits.
        model: Chat model name passed to the completions API.
        temperature: Sampling temperature passed to the completions API.
        poll_interval: Seconds to sleep between checks of the pause flag.
        client: Optional pre-built OpenAI-compatible client. When None, a new
            ``OpenAI(api_key=api_key)`` client is created for this call.

    Yields:
        The accumulated assistant text after each non-empty content delta.
    """
    if client is None:
        client = OpenAI(api_key=api_key)

    messages = []
    for human, assistant in history:
        messages.append({"role": "user", "content": human})
        messages.append({"role": "assistant", "content": assistant})
    messages.append({"role": "user", "content": message})

    response = client.chat.completions.create(
        model=model,
        messages=messages,
        temperature=temperature,
        stream=True,
    )

    partial_message = ""
    for chunk in response:
        # Wait quietly while paused (the old per-poll prints flooded the log).
        while is_paused[0]:
            time.sleep(poll_interval)
        # Some stream chunks carry no choices; indexing [0] would raise.
        if not chunk.choices:
            continue
        content = chunk.choices[0].delta.content
        if content:
            partial_message += content
            yield partial_message
def chat_with_api_key(api_key, message, history, is_paused):
    """Stream a GPT reply to *message* and record the exchange in *history*.

    Args:
        api_key: OpenAI API key forwarded to ``predict``.
        message: The user's message (in this app, the audio transcription).
        history: Mutable list of ``(user, assistant)`` pairs. The finished (or
            pause-truncated) exchange is appended to it in place exactly once.
        is_paused: One-element list used as a shared pause flag; when it is
            truthy after an update arrives, streaming stops.

    Yields:
        ``(message, [[message, partial_reply]])`` after each streamed update.
    """
    accumulated_message = ""
    try:
        for partial_message in predict(message, history, api_key, is_paused):
            if is_paused[0]:
                # Pause requested mid-stream: stop and keep what we have so far.
                break
            accumulated_message = partial_message
            yield message, [[message, accumulated_message]]
    finally:
        # Record the turn once, after streaming ends. Doing it inside the loop
        # stored every partial reply as its own turn and resent them all to the
        # model. `finally` also covers a consumer that stops iterating early.
        if accumulated_message:
            history.append((message, accumulated_message))
def transcribe(audio):
    """Transcribe a numpy audio tuple with the Whisper ASR pipeline.

    Args:
        audio: ``(sample_rate, samples)`` as delivered by
            ``gr.Audio(type="numpy")``, or None when nothing was recorded.

    Returns:
        The transcribed text, or ``"No audio recorded."`` when *audio* is None.
    """
    if audio is None:
        return "No audio recorded."
    sr, y = audio
    y = np.asarray(y, dtype=np.float32)
    if y.ndim > 1:
        # Multi-channel input: down-mix to mono. Assumes a (samples, channels)
        # layout — TODO confirm against the gr.Audio version in use.
        y = y.mean(axis=1)
    peak = float(np.max(np.abs(y))) if y.size else 0.0
    if peak > 0:
        # Peak-normalise to [-1, 1]. Skip for silent or empty input, where
        # dividing by a zero peak would fill the signal with NaN.
        y = y / peak
    return transcriber({"sampling_rate": sr, "raw": y})["text"]
def answer(transcription):
    """Return the QA pipeline's answer to *transcription*.

    The context passage is a fixed sentence, so an extractive QA model can
    only return a span of that sentence. NOTE(review): in this file the only
    caller is process_audio, which no UI event uses.
    """
    prediction = qa_model(
        context="You are a chatbot answering general questions",
        question=transcription,
    )
    return prediction['answer']
def process_audio(audio):
    """Transcribe *audio* and pair the text with the QA model's answer.

    Returns:
        ``(transcription, [[transcription, answer]])`` for a textbox and a
        chatbot, or ``("No audio recorded.", [])`` when *audio* is None.
    """
    if audio is not None:
        text = transcribe(audio)
        reply = answer(text)
        return text, [[text, reply]]
    return "No audio recorded.", []
def update_output(api_key, audio_input, state, is_paused):
    """Transcribe a finished recording and stream the GPT reply to the UI.

    Args:
        api_key: OpenAI API key.
        audio_input: Numpy audio tuple from the microphone, or None.
        state: Mutable conversation history of ``(user, assistant)`` pairs.
        is_paused: One-element list used as a shared pause flag.

    Yields:
        ``(transcription_text, chatbot_value)`` pairs for the transcription
        textbox and the chatbot.
    """
    if is_paused[0]:
        # Paused: leave the conversation untouched.
        yield "", state
        return
    if audio_input is None:
        # Nothing was recorded. Don't send the placeholder text to the model
        # as if the user had said it.
        yield "No audio recorded.", state
        return
    message = transcribe(audio_input)
    for transcription, chat_value in chat_with_api_key(api_key, message, state, is_paused):
        if is_paused[0]:
            break
        yield transcription, chat_value
def clear_all():
    """Return reset values for the audio input, transcription box and chatbot."""
    cleared_audio, cleared_text, cleared_chat = None, "", []
    return cleared_audio, cleared_text, cleared_chat
def toggle_pause(is_paused):
    """Flip the shared pause flag in place and return the same list object."""
    is_paused[0] = False if is_paused[0] else True
    return is_paused
def update_button_label(is_paused):
    """Return the pause button's label for the current pause flag."""
    if is_paused[0]:
        return "Resume"
    return "Pause"
# Build the interface and wire its events, then start the server.
# NOTE(review): this file's original indentation was lost, so the Row/Column
# nesting below is a reconstruction — confirm which components belong where.
with gr.Blocks() as demo:
    # Shows the exchange(s) yielded by update_output.
    answer_output = gr.Chatbot(label="Answer Result")
    with gr.Row():
        # Microphone recording, delivered to handlers as a numpy tuple.
        audio_input = gr.Audio(label="Audio Input", sources=["microphone"], type="numpy")
        with gr.Column():
            api_key = gr.Textbox(label="API Key", placeholder="Enter your API key", type="password")
            transcription_output = gr.Textbox(label="Transcription")
            clear_button = gr.Button("Clear")
            pause_button = gr.Button("Pause")
    # Conversation history of (user, assistant) pairs. It is appended to in
    # place by chat_with_api_key and sent back to the model by predict.
    state = gr.State([])
    is_paused = gr.State([False]) # Using a list to hold the mutable pause state
    # When recording stops: transcribe, then stream the GPT reply.
    audio_input.stop_recording(
        fn=update_output,
        inputs=[api_key, audio_input, state, is_paused],
        outputs=[transcription_output, answer_output]
    )
    # Reset the visible widgets. NOTE(review): `state` is not among the
    # outputs, so the history sent to the model is NOT cleared — confirm intent.
    clear_button.click(
        fn=clear_all,
        inputs=[],
        outputs=[audio_input, transcription_output, answer_output]
    )
    # Flip the pause flag, then relabel the button to match the new value.
    pause_button.click(
        fn=toggle_pause,
        inputs=[is_paused],
        outputs=[is_paused]
    ).then(
        fn=update_button_label,
        inputs=[is_paused],
        outputs=[pause_button]
    )
demo.launch()
|