Spaces:
Sleeping
Sleeping
import gradio as gr | |
from utils import * | |
from save_data import add_new_data, get_sheet_service | |
class SessionManager:
    """In-memory store of the human/AI cooperation sessions.

    Sessions are plain dicts kept in a list. A session is addressed by its
    position in that list: the index returned by ``add_session`` and accepted
    by every other method.
    """

    # Session fields, in the column order used for the saved sheet row.
    COLUMNS = (
        "task",
        "human_input",
        "cooperate_style",
        "ai_output",
        "merged_output",
        "evaluation",
    )

    def __init__(self):
        self.sessions = []

    @staticmethod
    def _to_index(index):
        """Return *index* as an int usable as a list index.

        Gradio's ``gr.Number`` may hand numbers back as floats (e.g. ``0.0``),
        and lists only accept integer indices.

        Raises:
            ValueError: if *index* is None (no session has been created yet).
        """
        if index is None:
            raise ValueError("No session has been created yet.")
        return int(index)

    def add_session(self, task, human_input, cooperate_style):
        """Register a new session and return its index.

        All output fields (``ai_output``, ``merged_output``, ``evaluation``)
        start as None and are filled in later through ``update_output``.
        """
        session = dict.fromkeys(self.COLUMNS)  # keeps COLUMNS order
        session.update(
            task=task,
            human_input=human_input,
            cooperate_style=cooperate_style,
        )
        self.sessions.append(session)
        return len(self.sessions) - 1

    def update_output(self, index, output, output_type='merged_output'):
        """Store *output* in field *output_type* of session *index*.

        Raises:
            KeyError: if *output_type* is not one of ``COLUMNS``; a typo would
                otherwise create a field that never reaches the sheet.
            ValueError: if *index* is None.
        """
        if output_type not in self.COLUMNS:
            raise KeyError(f"Unknown session field: {output_type!r}")
        self.sessions[self._to_index(index)][output_type] = output

    def get_session(self, index):
        """Return the session dict at *index* (the stored object, not a copy)."""
        return self.sessions[self._to_index(index)]

    def save_session_to_sheet(self, index, service, SHEET_ID):
        """Send session *index* to the sheet as one row via ``add_new_data``.

        Values are written in ``COLUMNS`` order. The column count is derived
        from the row, so it cannot drift from the session schema.
        """
        session = self.get_session(index)
        new_row = [session[column] for column in self.COLUMNS]
        add_new_data(new_row, service, SHEET_ID, num_of_columns=len(new_row))
def handle_interaction(task, human_input, cooperate_style, session_manager, api_key):
    """Run one human/AI cooperation round and record it as a new session.

    Args:
        task: Task description entered by the user.
        human_input: The human's contribution.
        cooperate_style: ``"sequential"``; any other value, including None
            when no radio option is selected, is handled as the parallel style.
        session_manager: ``SessionManager`` the new session is registered in.
        api_key: Key forwarded to the text-generation helpers.

    Returns:
        ``(ai_output, merged_output, session_index)``. In the sequential style
        the single generated text serves as both outputs.
    """
    session_index = session_manager.add_session(task, human_input, cooperate_style)

    if cooperate_style == "sequential":
        merged = merge_texts_sequential(task, human_input, api_key)
        ai_output = merged
        session_manager.update_output(session_index, ai_output, 'ai_output')
    else:
        ai_output = generate_text_with_gpt(task, api_key)
        # Stored before merging so the AI text is kept even if merging fails.
        session_manager.update_output(session_index, ai_output, 'ai_output')
        merged = merge_texts_parallel(task, human_input, ai_output, api_key)

    # Record the merged text for BOTH styles: evaluate_interaction reads
    # session['merged_output'], which would otherwise stay None.
    session_manager.update_output(session_index, merged, 'merged_output')
    return ai_output, merged, session_index
def evaluate_interaction(session_index, session_manager, api_key):
    """Evaluate a session's merged output and store the result on the session.

    Args:
        session_index: Index of the session to evaluate. It arrives from the
            hidden ``gr.Number`` and may be a float, or None before any
            session was created.
        session_manager: ``SessionManager`` holding the session.
        api_key: Key forwarded to ``get_evaluation_with_gpt``.

    Returns:
        The evaluation produced by ``get_evaluation_with_gpt``.

    Raises:
        gr.Error: if no session has been created yet, shown to the user.
    """
    if session_index is None:
        raise gr.Error("No session yet: click 'Create' before evaluating.")
    # gr.Number may return the index as a float; list indices must be ints.
    index = int(session_index)
    session = session_manager.get_session(index)
    evaluation = get_evaluation_with_gpt(session['task'], session['merged_output'], api_key)
    # Go through the manager's API rather than mutating the dict directly.
    session_manager.update_output(index, evaluation, 'evaluation')
    return evaluation
def save_data(session_index, session_manager, service, SHEET_ID):
    """Save one session to the Google Sheet and return a status message.

    Args:
        session_index: Index of the session to save. It arrives from the
            hidden ``gr.Number`` and may be a float, or None before any
            session was created.
        session_manager: ``SessionManager`` holding the session.
        service: Sheets client passed through to the manager.
        SHEET_ID: Target sheet identifier passed through to the manager.

    Raises:
        gr.Error: if no session has been created yet, shown to the user.
    """
    if session_index is None:
        raise gr.Error("No session to save yet: click 'Create' first.")
    # gr.Number may return the index as a float; list indices must be ints.
    session_manager.save_session_to_sheet(int(session_index), service, SHEET_ID)
    return "Data has been saved to Google Sheets."
if __name__ == "__main__":
    # Fetched once at startup; the event handlers below capture them.
    api_key = get_api_key(local=False)
    service, SHEET_ID = get_sheet_service(local=False)
    # One manager for the whole app: every request appends to the same list.
    session_manager = SessionManager()

    with gr.Blocks() as app:
        with gr.Row():
            task = gr.Textbox(label="Task Description")
            human_input = gr.Textbox(label="Human Input")
        with gr.Row():
            cooperate_style = gr.Radio(choices=['sequential', 'parallel'], label="Cooperation Style")
            submit_btn = gr.Button("Create")
        with gr.Row():
            ai_output = gr.Textbox(label="AI Output (if it is sequential)")
            merged_output = gr.Textbox(label="Merged Output given with cooperation")

        # Hidden holder for the current session's index. precision=0 makes
        # Gradio pass it back as an int; without it the value may arrive as a
        # float, which cannot index the session list.
        session_index = gr.Number(label="Session Index", visible=False, precision=0)

        submit_btn.click(
            fn=lambda task, human_input, cooperate_style: handle_interaction(task, human_input, cooperate_style, session_manager, api_key),
            inputs=[task, human_input, cooperate_style],
            outputs=[ai_output, merged_output, session_index]
        )

        evaluate_btn = gr.Button("Evaluate")
        evaluation_result = gr.Textbox(label="Evaluation Result")
        evaluate_btn.click(
            fn=lambda session_index: evaluate_interaction(session_index, session_manager, api_key),
            inputs=[session_index],
            outputs=[evaluation_result]
        )

        save_btn = gr.Button("Save Data")
        save_result = gr.Label()
        save_btn.click(
            fn=lambda session_index: save_data(session_index, session_manager, service, SHEET_ID),
            inputs=[session_index],
            outputs=[save_result]
        )

    app.launch(share=True)