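"""Gradio voice chat demo.

Records audio from the microphone, transcribes it with Whisper, and streams a
reply from the OpenAI Chat Completions API. A Pause/Resume button toggles a
shared flag that suspends streaming, and a Clear button resets the session.
"""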
import time
import gradio as gr
from transformers import pipeline
import numpy as np
from openai import OpenAI

transcriber = pipeline("automatic-speech-recognition", model="openai/whisper-base.en")
qa_model = pipeline("question-answering", model="distilbert-base-cased-distilled-squad")

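# Stream a chat completion for `message`, replaying prior turns from `history`
# and blocking whenever the shared pause flag is set.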
def predict(message, history, api_key, is_paused):
    client = OpenAI(api_key=api_key)
    history_openai_format = []
    for human, assistant in history:
        history_openai_format.append({"role": "user", "content": human})
        history_openai_format.append({"role": "assistant", "content": assistant})
    history_openai_format.append({"role": "user", "content": message})

    response = client.chat.completions.create(
        model='gpt-4o',
        messages=history_openai_format,
        temperature=1.0,
        stream=True
    )

    partial_message = ""
    for chunk in response:
        # Block while the pause flag is set; resume streaming once it is cleared.
        while is_paused[0]:
            time.sleep(0.1)
        if chunk.choices[0].delta.content:
            partial_message += chunk.choices[0].delta.content
            yield partial_message

def chat_with_api_key(api_key, message, history, is_paused):
    accumulated_message = ""
    for partial_message in predict(message, history, api_key, is_paused):
        if is_paused[0]:  # Check if paused
            break
        accumulated_message = partial_message
        # Show the prior conversation plus the in-progress reply.
        yield message, history + [[message, accumulated_message]]
    # Record the completed exchange in the shared history exactly once,
    # rather than appending a partial copy for every streamed chunk.
    if accumulated_message:
        history.append((message, accumulated_message))

def transcribe(audio):
    if audio is None:
        return "No audio recorded."
    sr, y = audio
    y = y.astype(np.float32)
    # Mix down stereo recordings and normalize, guarding against silent input.
    if y.ndim > 1:
        y = y.mean(axis=1)
    peak = np.max(np.abs(y))
    if peak > 0:
        y /= peak
    return transcriber({"sampling_rate": sr, "raw": y})["text"]

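# Local question-answering helpers (not currently wired into the Blocks UI below).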
def answer(transcription):
    context = "You are a chatbot answering general questions"
    result = qa_model(question=transcription, context=context)
    return result['answer']

def process_audio(audio):
    if audio is None:
        return "No audio recorded.", []
    transcription = transcribe(audio)
    answer_result = answer(transcription)
    return transcription, [[transcription, answer_result]]

def update_output(api_key, audio_input, state, is_paused):
    if is_paused[0]:  # Check if paused
        yield "", state  # Return current state without making changes
    else:
        message = transcribe(audio_input)
        responses = chat_with_api_key(api_key, message, state, is_paused)
        accumulated_response = ""
        for response, updated_state in responses:
            if is_paused[0]:  # Check if paused
                break
            accumulated_response = response
            yield accumulated_response, updated_state

def clear_all():
    # Reset the audio widget, transcription box, chat display, and stored history.
    return None, "", [], []

def toggle_pause(is_paused):
    is_paused[0] = not is_paused[0]
    return is_paused

def update_button_label(is_paused):
    return "Resume" if is_paused[0] else "Pause"

with gr.Blocks() as demo:
    answer_output = gr.Chatbot(label="Answer Result")
    with gr.Row():    
        audio_input = gr.Audio(label="Audio Input", sources=["microphone"], type="numpy")
        with gr.Column():
            api_key = gr.Textbox(label="API Key", placeholder="Enter your API key", type="password")
            transcription_output = gr.Textbox(label="Transcription")
            clear_button = gr.Button("Clear")
            pause_button = gr.Button("Pause")
    
    state = gr.State([])
    is_paused = gr.State([False])  # Using a list to hold the mutable pause state

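    # When a recording stops: transcribe it, then stream the chat reply into the UI.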
    audio_input.stop_recording(
        fn=update_output,
        inputs=[api_key, audio_input, state, is_paused],
        outputs=[transcription_output, answer_output]
    )
    
    clear_button.click(
        fn=clear_all,
        inputs=[],
        outputs=[audio_input, transcription_output, answer_output, state]
    )

    pause_button.click(
        fn=toggle_pause,
        inputs=[is_paused],
        outputs=[is_paused]
    ).then(
        fn=update_button_label,
        inputs=[is_paused],
        outputs=[pause_button]
    )

demo.launch()