# CooperativeBot / app.py
# Naisong Zhou
# revise to add google slides api
# 9863223
import gradio as gr
from utils import *
from save_data import add_new_data, get_sheet_service
class SessionManager:
    """Hold every cooperation session of this process in memory.

    Each session is a dict with the keys task, human_input, cooperate_style,
    ai_output, merged_output and evaluation (in that insertion order); its
    position in ``self.sessions`` is the session index handed back to the UI.

    Indices may come from a Gradio ``gr.Number`` component, which can deliver
    them as floats, so every index is normalised with ``int()`` before use.
    """

    def __init__(self):
        # One dict per session, in creation order.
        self.sessions = []

    @staticmethod
    def _as_index(index):
        """Return *index* as an int so a float such as 0.0 can address the list."""
        return int(index)

    def add_session(self, task, human_input, cooperate_style):
        """Store a new session with empty outputs and return its index."""
        session = {
            "task": task,
            "human_input": human_input,
            "cooperate_style": cooperate_style,
            "ai_output": None,
            "merged_output": None,
            "evaluation": None,
        }
        self.sessions.append(session)
        return len(self.sessions) - 1

    def update_output(self, index, output, output_type='merged_output'):
        """Set field *output_type* of session *index* to *output*.

        Raises:
            KeyError: if *output_type* is not an existing session field. Adding
                a new key would silently change the width of the saved row.
        """
        session = self.sessions[self._as_index(index)]
        if output_type not in session:
            raise KeyError(f"Unknown session field: {output_type!r}")
        session[output_type] = output

    def get_session(self, index):
        """Return the (mutable) session dict stored at *index*."""
        return self.sessions[self._as_index(index)]

    def save_session_to_sheet(self, index, service, SHEET_ID):
        """Append session *index* as one row via ``add_new_data``.

        The row holds the session's values in key insertion order (see
        ``add_session``).
        """
        new_row = list(self.get_session(index).values())
        # Derive the column count from the row instead of hard-coding 6.
        # NOTE(review): assumes add_new_data's num_of_columns is the row width.
        add_new_data(new_row, service, SHEET_ID, num_of_columns=len(new_row))
def handle_interaction(task, human_input, cooperate_style, session_manager, api_key):
    """Run one human/AI cooperation round and record it in *session_manager*.

    With cooperate_style == "sequential", ``merge_texts_sequential`` is called
    once and its result is reported as both the AI output and the merged
    output. Any other value (including None) takes the parallel path:
    ``generate_text_with_gpt`` drafts a text for *task*, which
    ``merge_texts_parallel`` then combines with *human_input*.

    Returns:
        Tuple (ai_output, merged_output, session_index).
    """
    idx = session_manager.add_session(task, human_input, cooperate_style)

    if cooperate_style != "sequential":
        ai_text = generate_text_with_gpt(task, api_key)
        session_manager.update_output(idx, ai_text, 'ai_output')
        merged_text = merge_texts_parallel(task, human_input, ai_text, api_key)
    else:
        merged_text = merge_texts_sequential(task, human_input, api_key)
        ai_text = merged_text
        session_manager.update_output(idx, ai_text, 'ai_output')

    # Both paths record their final text as the merged output.
    session_manager.update_output(idx, merged_text, 'merged_output')
    return ai_text, merged_text, idx
def evaluate_interaction(session_index, session_manager, api_key):
    """Evaluate the merged output of a stored session and record the result.

    Args:
        session_index: Index returned by ``handle_interaction``. It comes back
            through a hidden ``gr.Number``, so it may be a float, or None when
            no session has been created yet.
        session_manager: The ``SessionManager`` holding the session.
        api_key: Key passed through to ``get_evaluation_with_gpt``.

    Returns:
        The value returned by ``get_evaluation_with_gpt``, or a hint message
        when there is no session to evaluate.
    """
    missing = "No session found. Please click 'Create' first."
    if session_index is None:
        return missing
    # Normalise a float index such as 0.0 so it can address the session list.
    index = int(session_index)
    try:
        session = session_manager.get_session(index)
    except IndexError:
        return missing
    evaluation = get_evaluation_with_gpt(session['task'], session['merged_output'], api_key)
    # Go through the manager, as handle_interaction does, instead of mutating the dict.
    session_manager.update_output(index, evaluation, 'evaluation')
    return evaluation
def save_data(session_index, session_manager, service, SHEET_ID):
    """Write session *session_index* to the Google Sheet and return a status text.

    *session_index* comes from a hidden ``gr.Number``, so it may be a float,
    or None when no session has been created yet. The index is validated
    before any call to the sheet service is made.
    """
    missing = "No session to save. Please click 'Create' first."
    if session_index is None:
        return missing
    index = int(session_index)
    try:
        # Lookup only: detect an unknown index before touching the sheet.
        session_manager.get_session(index)
    except IndexError:
        return missing
    session_manager.save_session_to_sheet(index, service, SHEET_ID)
    return "Data has been saved to Google Sheets."
if __name__ == "__main__":
    # Credentials and the sheet client are fetched once at start-up and
    # shared by every event handler below.
    api_key = get_api_key(local=False)
    service, SHEET_ID = get_sheet_service(local=False)
    # One in-memory manager serves every visitor of the (shared) app; each
    # page tracks its own session through the hidden session_index field.
    session_manager = SessionManager()

    with gr.Blocks() as app:
        with gr.Row():
            task = gr.Textbox(label="Task Description")
            human_input = gr.Textbox(label="Human Input")
        with gr.Row():
            cooperate_style = gr.Radio(choices=['sequential', 'parallel'], label="Cooperation Style")
            submit_btn = gr.Button("Create")
        with gr.Row():
            # NOTE(review): handle_interaction yields a separate AI draft only on
            # the parallel path (sequential reuses the merged text); confirm
            # this label says what is intended.
            ai_output = gr.Textbox(label="AI Output (if it is sequential)")
            merged_output = gr.Textbox(label="Merged Output given with cooperation")
        # precision=0 makes Gradio round and return the index as an int; with
        # the default precision it can come back as a float, which cannot index
        # a Python list.
        session_index = gr.Number(label="Session Index", visible=False, precision=0)

        submit_btn.click(
            fn=lambda task, human_input, cooperate_style: handle_interaction(task, human_input, cooperate_style, session_manager, api_key),
            inputs=[task, human_input, cooperate_style],
            outputs=[ai_output, merged_output, session_index]
        )

        evaluate_btn = gr.Button("Evaluate")
        evaluation_result = gr.Textbox(label="Evaluation Result")
        evaluate_btn.click(
            fn=lambda session_index: evaluate_interaction(session_index, session_manager, api_key),
            inputs=[session_index],
            outputs=[evaluation_result]
        )

        save_btn = gr.Button("Save Data")
        save_result = gr.Label()
        save_btn.click(
            fn=lambda session_index: save_data(session_index, session_manager, service, SHEET_ID),
            inputs=[session_index],
            outputs=[save_result]
        )

    app.launch(share=True)