|
import json |
|
import os |
|
import gradio as gr |
|
from matplotlib import pyplot as plt |
|
|
|
from experiment_details import problem_topics, problems_per_topic, writing_skills_questions |
|
from data import problems |
|
from model_generate import chatbot_generate |
|
import random |
|
import re |
|
|
|
def process_markdown(prompt, question, *, data_dir='./data'):
    """
    Read a Markdown file from the experiment's data directory.

    Args:
        prompt: If truthy, read from ``<data_dir>/problems/`` (where the
            ``p_*.md`` prompt files are looked up); otherwise read from
            ``<data_dir>/instructions/``.
        question: File name (e.g. ``'p_email3.md'``) inside that directory.
        data_dir: Root data directory. Defaults to ``'./data'``, relative to
            the current working directory, which matches the original paths.

    Returns:
        The file's full contents as a string.

    Raises:
        FileNotFoundError: If the file does not exist.
    """
    subdir = 'problems' if prompt else 'instructions'
    path = os.path.join(data_dir, subdir, question)
    # Decode explicitly as UTF-8 (the Markdown files are assumed to be UTF-8).
    # The platform default, e.g. cp1252 on Windows, would garble or reject
    # non-ASCII characters such as curly quotes.
    with open(path, 'r', encoding='utf-8') as md_file:
        return md_file.read()
|
|
|
def strip_markdown(text):
    """
    Strip common Markdown formatting and ``<span>`` tags from a string.

    Formatting markers are removed but the text they wrap is kept, so the
    plain-text result still contains every word of the original. This covers
    bold, italic, strikethrough, code fences and backticks, heading hashes,
    blockquote markers and horizontal rules.

    Links are reduced to their visible text and images are dropped. Runs of
    blank lines collapse to a single newline.

    Args:
        text: Markdown source string.

    Returns:
        The plain-text version of ``text`` with surrounding whitespace stripped.
    """
    # (pattern, replacement, flags), applied in order.
    substitutions = [
        # HTML span tags first, so their '>' never interacts with later rules.
        (r'<span[^>]*>', '', 0),
        (r'</span>', '', 0),
        # Images are dropped entirely; links keep their visible text.
        (r'!\[[^\]]*\]\([^)]+\)', '', 0),
        (r'\[([^\]]*)\]\([^)]+\)', r'\1', 0),
        # Fenced code (optional language tag line) and inline code keep content.
        (r'```(?:[^\n`]*\n)?(.*?)```', r'\1', re.DOTALL),
        (r'`([^`\n]+)`', r'\1', 0),
        # Bold / italic / strikethrough keep the wrapped text. These never span
        # lines. The underscore forms require non-word neighbours so that
        # snake_case identifiers are left untouched. The star-italic form
        # requires non-space inner edges, so "* item" bullets and "2 * 3"
        # are left alone.
        (r'\*\*(.+?)\*\*', r'\1', 0),
        (r'(?<!\w)__(?!\s)(.+?)(?<!\s)__(?!\w)', r'\1', 0),
        (r'\*(?!\s)(.+?)(?<!\s)\*', r'\1', 0),
        (r'(?<!\w)_(?!\s)(.+?)(?<!\s)_(?!\w)', r'\1', 0),
        (r'~~(.+?)~~', r'\1', 0),
        # Line-level markers: heading hashes, blockquote '>', horizontal rules.
        (r'^#{1,6}(?:[ \t]+|$)', '', re.MULTILINE),
        (r'^>[ \t]?', '', re.MULTILINE),
        (r'^[ \t]*-{3,}[ \t]*$', '', re.MULTILINE),
        # Collapse blank lines (including ones left by the rules above) into a
        # single newline rather than gluing paragraphs together.
        (r'\n\s*\n', '\n', 0),
    ]

    clean_text = text
    for pattern, replacement, flags in substitutions:
        clean_text = re.sub(pattern, replacement, clean_text, flags=flags)

    return clean_text.strip()
|
|
|
|
|
def save_answer(question_answers, q_num, q_prompt, q_text, q_assist, q_assist_history=None):
    """
    Record one submitted answer under the ``'q<q_num>'`` entry.

    The entry must already exist in ``question_answers`` (a missing key raises
    ``KeyError``). The response text is stored JSON-encoded via
    ``json.dumps``.

    Args:
        question_answers: Mapping of ``'q<n>'`` keys to per-question dicts.
        q_num: Question number used to build the key.
        q_prompt: Prompt identifier to store under ``'Prompt'``.
        q_text: Response to store (JSON-encoded) under ``'Response'``.
        q_assist: Value stored under ``'Assist'``.
        q_assist_history: Value stored under ``'AssistanceHistory'``.

    Returns:
        The same ``question_answers`` object, updated in place.
    """
    entry = question_answers[f"q{q_num}"]
    entry['Prompt'] = q_prompt
    entry['Response'] = json.dumps(q_text, indent=4)
    entry['Assist'] = q_assist
    entry['AssistanceHistory'] = q_assist_history
    return question_answers
|
|
|
|
|
|
|
def randomly_select_prompts():
    """
    Pick two distinct prompt IDs (1-10) for each of the four task types.

    Returns:
        A list of 8 question dicts, ordered short story x2, email x2,
        summary x2, title x2. Each dict has the keys ``'instruction'``
        (instruction file), ``'prompt_file'`` (``'p_<type><id>.md'``),
        ``'word_count'`` (maximum words) and ``'textfield_lines'``.
    """
    # (instruction file, prompt-file stem, max words, textbox lines),
    # listed in the order the questions are emitted.
    task_specs = (
        ('instr_shortstory.md', 'p_shortstory', 300, 10),
        ('instr_email.md', 'p_email', 300, 10),
        ('instr_summary.md', 'p_summary', 75, 5),
        ('instr_title.md', 'p_title', 10, 2),
    )

    # Draw every task type's IDs up front, in task order, so a seeded RNG
    # yields the same selection as drawing them one type at a time.
    drawn_ids = [random.sample(range(1, 11), 2) for _ in task_specs]

    return [
        {
            'instruction': instr_file,
            'prompt_file': f'{stem}{pid}.md',
            'word_count': max_words,
            'textfield_lines': box_lines,
        }
        for (instr_file, stem, max_words, box_lines), ids in zip(task_specs, drawn_ids)
        for pid in ids
    ]
|
|
|
|
|
def randomize_questions(questions):
    """
    Shuffle questions among even positions and among odd positions, in place.

    Items never move between an even and an odd index, so any alternating
    pattern by position is preserved. Even positions are shuffled first, then
    odd positions; for an 8-item list this consumes the RNG exactly as the
    previous fixed-index version did.

    Args:
        questions: A list; any length (including empty) is accepted.

    Returns:
        The same list object, shuffled in place.
    """
    for start in (0, 1):
        group = questions[start::2]
        random.shuffle(group)
        # Extended-slice assignment: group has the same length as the slice.
        questions[start::2] = group
    return questions
|
|
|
|
|
def generate_unassisted_question(question_prompt, next_q_btn, q_num, question_answers):
    """
    Build the (unrendered) Gradio layout for one unassisted writing question.

    The returned column shows:
      - the question title and the unassisted-task instructions;
      - the task instruction and prompt (both read via ``process_markdown``);
      - a live word counter, a response textbox and a submit button.

    Submitting stores the response through ``save_answer`` (with
    ``Assist=False``) and switches to a read-only confirmation view that has a
    "Return to previous question" button; that button restores the editable
    view.

    Args:
        question_prompt: Dict with the keys ``'instruction'``,
            ``'prompt_file'``, ``'word_count'`` (maximum words allowed) and
            ``'textfield_lines'``, e.g. an entry built by
            ``randomly_select_prompts``.
        next_q_btn: Existing component whose visibility is toggled here. It is
            shown after submit and hidden after going back; presumably this is
            the caller's "next question" button.
        q_num: Question number. Used in the title and, via ``save_answer``, as
            the ``'q<num>'`` key in ``question_answers``.
        question_answers: Dict passed to ``save_answer`` on submit. It must
            already contain the ``'q<num>'`` entry.

    Returns:
        The ``gr.Column`` container, created with ``render=False``; the caller
        is responsible for rendering it.
    """
    # Components appear in the order they are created inside these `with`
    # blocks, so statement order here defines the page layout.
    # NOTE(review): gr.Column receives a positional string here. Confirm which
    # Column parameter (if any) this binds to in the installed Gradio version;
    # it reads like a label meant for gr.Tab.
    with gr.Column("Unassisted Writing Task", render=False) as q_unassisted:
        q_title_text = "#\n# Question " + str(q_num)
        # q_title is not referenced again; the Markdown is created only so it
        # is added to the layout.
        q_title = gr.Markdown(value=q_title_text)

        unassisted_instr = process_markdown(False, 'instr_unassisted.md')
        unassisted_instr_md = gr.Markdown(value=unassisted_instr)

        instruction = question_prompt['instruction']
        prompt = question_prompt['prompt_file']
        max_word_count = question_prompt['word_count']
        textfield_lines = question_prompt['textfield_lines']

        # NOTE(review): the `instructions` handle is never used; the column
        # only groups the instruction text, the prompt and the word counter.
        with gr.Column() as instructions:
            gen_instr_text1 = process_markdown(False, instruction)
            prompt_text1 = process_markdown(True, prompt)
            gen_instr1 = gr.Markdown(value=gen_instr_text1)
            with gr.Row():
                with gr.Column(scale=4):
                    prompt1 = gr.Markdown(value=prompt_text1)
                with gr.Column(scale=1):
                    # Read-only display, updated by count_words below.
                    word_count = gr.Textbox(
                        label='Word Count',
                        interactive=False,
                        lines=1,
                        max_lines=1,
                        autoscroll=False,
                        autofocus=False,
                    )

        def count_words(x):
            """
            Update the word counter and the submit button for response text ``x``.

            Words are counted with ``str.split()``. Submit is enabled only when
            the count is between 1 and ``max_word_count`` inclusive; otherwise
            the counter shows a hint. ``text_button1`` is created further down,
            and this closure looks it up when the event fires.
            """
            num_words = len(x.split())
            if num_words > max_word_count:
                overflow = num_words-max_word_count
                string_num = str(num_words) + " (REMOVE " + str(overflow) + " WORDS to submit your response)"
                return {
                    text_button1: gr.Button(interactive=False),
                    word_count: gr.Textbox(string_num)
                }
            elif num_words < 1:
                string_num = str(num_words) + " (Please enter your response)"
                return {
                    text_button1: gr.Button(interactive=False),
                    word_count: gr.Textbox(string_num)
                }
            else:
                string_num = str(num_words)
                return {
                    text_button1: gr.Button(interactive=True),
                    word_count: gr.Textbox(string_num)
                }

        # Confirmation text, hidden until the response is submitted.
        success_submit_instr = process_markdown(False, 'instr_submitsuccess.md')
        success_submitted = gr.Markdown(value=success_submit_instr, visible=False)

        tab_text1 = gr.Textbox(
            lines=textfield_lines,
            interactive=True,
            show_copy_button=True,
            container=True,
            autoscroll=True,
            autofocus=True,
            label="Write your response here:")

        # Starts disabled; count_words enables it once the length is valid.
        text_button1 = gr.Button("Submit Response", variant="primary", interactive=False)
        tab_text1.input(count_words, tab_text1, [text_button1, word_count], show_progress="hidden")

        # Post-submit view: edit hint, back button and "proceed" text (all hidden initially).
        edit_response_instr = process_markdown(False, 'instr_editresponse.md')
        edit_response = gr.Markdown(value=edit_response_instr, visible=False)

        back_btn = gr.Button("Return to previous question", visible=False)
        proceed_instr = process_markdown(False, 'instr_proceed.md')
        proceed_to_next = gr.Markdown(value=proceed_instr, visible=False)

        def click_back_btn():
            """Hide the post-submit view and restore the editable question view."""
            return {
                success_submitted: gr.update(visible=False),
                edit_response: gr.update(visible=False),
                proceed_to_next: gr.update(visible=False),
                back_btn: gr.update(visible=False),
                gen_instr1: gr.update(visible=True),
                prompt1: gr.update(visible=True),
                tab_text1: gr.update(visible=True, interactive=True, show_label=True, show_copy_button=True,
                                     container=True),
                word_count: gr.update(visible=True),
                text_button1: gr.update(visible=True),
                next_q_btn: gr.update(visible=False),
                unassisted_instr_md: gr.update(visible=True)
            }

        back_btn.click(
            fn=click_back_btn,
            inputs=[],
            outputs=[success_submitted, edit_response, proceed_to_next, back_btn, gen_instr1, prompt1, tab_text1, word_count, text_button1, next_q_btn,
                     unassisted_instr_md]
        )

        def submit_question(submission_text):
            """
            Save the response, then switch to the post-submit view.

            The response is saved with ``Assist=False`` and the prompt file
            name as ``'Prompt'``. In the post-submit view the textbox stays
            visible but is read-only.
            """
            save_answer(question_answers, q_num, prompt, submission_text, False)
            return {
                success_submitted: gr.update(visible=True),
                edit_response: gr.update(visible=True),
                proceed_to_next: gr.update(visible=True),
                back_btn: gr.update(visible=True),
                gen_instr1: gr.update(visible=False),
                prompt1: gr.update(visible=False),
                tab_text1: gr.update(visible=True, interactive=False, show_label=False, show_copy_button=False, container=False),
                word_count: gr.update(visible=False),
                text_button1: gr.update(visible=False),
                next_q_btn: gr.update(visible=True),
                unassisted_instr_md: gr.update(visible=False)
            }

        text_button1.click(
            fn=submit_question,
            inputs=[tab_text1],
            outputs=[success_submitted, edit_response, proceed_to_next, back_btn, gen_instr1, prompt1, tab_text1, word_count,
                     text_button1, next_q_btn, unassisted_instr_md]
        )

    return q_unassisted
|
|
|
|
|
def generate_assisted_question(question_prompt, next_q_btn, q_num, question_answers):
    """
    Build the (unrendered) Gradio layout for one AI-assisted writing question.

    The layout matches the unassisted question, with a chatbot panel
    ("Writing Helper") and a chat input added. Chat turns are sent to
    ``chatbot_generate`` together with a fixed initial message that embeds
    the plain-text (``strip_markdown``) task instruction and prompt.

    Submitting stores the response through ``save_answer`` (with
    ``Assist=True``), using the chat ``state`` as the assistance history.
    It then switches to a read-only confirmation view with a
    "Return to question" button, which restores the editable view.

    Args:
        question_prompt: Dict with the keys ``'instruction'``,
            ``'prompt_file'``, ``'word_count'`` (maximum words allowed) and
            ``'textfield_lines'``, e.g. an entry built by
            ``randomly_select_prompts``.
        next_q_btn: Existing component whose visibility is toggled here. It is
            shown after submit and hidden after going back; presumably this is
            the caller's "next question" button.
        q_num: Question number. Used in the title and, via ``save_answer``, as
            the ``'q<num>'`` key in ``question_answers``.
        question_answers: Dict passed to ``save_answer`` on submit. It must
            already contain the ``'q<num>'`` entry.

    Returns:
        The ``gr.Column`` container, created with ``render=False``; the caller
        is responsible for rendering it.
    """
    # Components appear in the order they are created inside these `with`
    # blocks, so statement order here defines the page layout.
    # NOTE(review): gr.Column receives a positional string here. Confirm which
    # Column parameter (if any) this binds to in the installed Gradio version;
    # it reads like a label meant for gr.Tab.
    with gr.Column("Assisted Writing Task", render=False) as q_assisted:
        q_title_text = "#\n# Question " + str(q_num)
        # q_title is not referenced again; the Markdown is created only so it
        # is added to the layout.
        q_title = gr.Markdown(value=q_title_text)

        assisted_instr = process_markdown(False, 'instr_assisted.md')
        assisted_instr_md = gr.Markdown(value=assisted_instr)

        instruction = question_prompt['instruction']
        prompt = question_prompt['prompt_file']
        max_word_count = question_prompt['word_count']
        textfield_lines = question_prompt['textfield_lines']

        gen_instr_text2 = process_markdown(False, instruction)
        prompt_text2 = process_markdown(True, prompt)

        # Plain-text copies, used only to build the assistant's initial message below.
        instruction_txt = strip_markdown(gen_instr_text2)
        prompt_txt = strip_markdown(prompt_text2)

        gen_instr2 = gr.Markdown(value=gen_instr_text2)
        with gr.Row():
            with gr.Column(scale=4):
                prompt2 = gr.Markdown(value=prompt_text2)
            with gr.Column(scale=1):
                # Read-only display, updated by count_words below.
                word_count = gr.Textbox(
                    label='Word Count',
                    interactive=False,
                    lines=1,
                    max_lines=1,
                    autoscroll=False,
                    autofocus=False,
                )

        def count_words(x):
            """
            Update the word counter and the submit button for response text ``x``.

            Words are counted with ``str.split()``. Submit is enabled only when
            the count is between 1 and ``max_word_count`` inclusive; otherwise
            the counter shows a hint. ``text_button2`` is created further down,
            and this closure looks it up when the event fires.
            """
            num_words = len(x.split())
            if num_words > max_word_count:
                overflow = num_words-max_word_count
                string_num = str(num_words) + " (REMOVE " + str(overflow) + " WORDS to submit your response)"
                return {
                    text_button2: gr.Button(interactive=False),
                    word_count: gr.Textbox(string_num)
                }
            elif num_words < 1:
                string_num = str(num_words) + " (Please enter your response)"
                return {
                    text_button2: gr.Button(interactive=False),
                    word_count: gr.Textbox(string_num)
                }
            else:
                string_num = str(num_words)
                return {
                    text_button2: gr.Button(interactive=True),
                    word_count: gr.Textbox(string_num)
                }

        # Fixed initial message handed to chatbot_generate with every chat turn.
        # NOTE(review): "Here are the task-specific " expects instruction_txt to
        # continue the sentence (e.g. to start with "instructions"); confirm
        # against the instr_*.md files.
        initial_user_message = "You are a helpful writing assistant. You provide useful responses to writers’ questions. " \
                               "Here, writers will ask you questions about a specific writing task. " \
                               "You may only provide at most three sentences if the writer asks you to write an answer to the entire task for them. " \
                               "Your goal is to assist the writer but not do all the work. " \
                               "Here are the task-specific "+ instruction_txt + " " + prompt_txt

        conversations_list = []
        with gr.Column() as chatbot_col:
            chatbot = gr.Chatbot(conversations_list, height=350, label="Writing Helper")

        # Conversation history. It is an input/output of chatbot_generate, and
        # it is saved as 'AssistanceHistory' on submit.
        state = gr.State(conversations_list)
        initial_usr_msg_state = gr.State(initial_user_message)

        # Model identifier passed to chatbot_generate; its meaning is defined in model_generate.
        model_state = gr.State("chatgpt")
        with gr.Column() as chat_feature:
            with gr.Group():
                with gr.Row():
                    txt = gr.Textbox(
                        value="",
                        show_label=False,
                        placeholder="Enter text and press the Interact button. Your current answer will not be sent to the assistant. If you want the assistant to know your current answer, paste it into the chat window.",
                        lines=2,
                        container=False,
                        scale=4)
                    submit_button = gr.Button("Interact", variant="primary", scale=1, size="sm")

        # Confirmation text, hidden until the response is submitted.
        success_submit_instr = process_markdown(False, 'instr_submitsuccess.md')
        success_submitted = gr.Markdown(value=success_submit_instr, visible=False)

        tab_text2 = gr.Textbox(
            lines=textfield_lines,
            interactive=True,
            show_copy_button=True,
            container=True,
            autoscroll=True,
            autofocus=True,
            label="Write your response here:")

        # chatbot_generate's outputs are mapped to (chat display, history
        # state, chat input, Interact button); the exact return values are
        # defined in model_generate.
        submit_button.click(chatbot_generate, [txt, state, model_state, initial_usr_msg_state], [chatbot, state, txt, submit_button])

        # Starts disabled; count_words enables it once the length is valid.
        text_button2 = gr.Button("Submit Response", variant="primary", interactive=False)
        tab_text2.input(count_words, tab_text2, [text_button2, word_count], show_progress="hidden")

        # Post-submit view: edit hint, back button and "proceed" text (all hidden initially).
        edit_response_instr = process_markdown(False, 'instr_editresponse.md')
        edit_response = gr.Markdown(value=edit_response_instr, visible=False)

        back_btn = gr.Button("Return to question", visible=False)
        proceed_instr = process_markdown(False, 'instr_proceed.md')
        proceed_to_next = gr.Markdown(value=proceed_instr, visible=False)

        def click_back_btn():
            """Hide the post-submit view and restore the editable question and chat panels."""
            return {
                success_submitted: gr.update(visible=False),
                edit_response: gr.update(visible=False),
                proceed_to_next: gr.update(visible=False),
                back_btn: gr.update(visible=False),
                gen_instr2: gr.update(visible=True),
                prompt2: gr.update(visible=True),
                tab_text2: gr.update(visible=True, interactive=True, show_label=True, show_copy_button=True,
                                     container=True),
                word_count: gr.update(visible=True),
                text_button2: gr.update(visible=True),
                chatbot_col: gr.update(visible=True),
                chat_feature: gr.update(visible=True),
                next_q_btn: gr.update(visible=False),
                assisted_instr_md: gr.update(visible=True)
            }

        back_btn.click(
            fn=click_back_btn,
            inputs=[],
            outputs=[success_submitted, edit_response, proceed_to_next, back_btn, gen_instr2, prompt2, tab_text2, word_count,
                     chatbot_col, chat_feature, text_button2, next_q_btn, assisted_instr_md]
        )

        # NOTE(review): `assistance_history: None` is an annotation, not a
        # default, so the parameter is still required. It is always supplied
        # here because inputs=[tab_text2, state]. `=None` was probably intended.
        def submit_question(submission_text, assistance_history: None):
            """
            Save the response and chat history, then switch to the post-submit view.

            The response is saved with ``Assist=True``, the prompt file name
            as ``'Prompt'`` and ``assistance_history`` (the chat ``state``) as
            ``'AssistanceHistory'``. The chat panels are hidden, and the
            textbox stays visible but read-only.
            """
            save_answer(question_answers, q_num, prompt, submission_text, True, assistance_history)
            return {
                success_submitted: gr.update(visible=True),
                edit_response: gr.update(visible=True),
                proceed_to_next: gr.update(visible=True),
                back_btn: gr.update(visible=True),
                gen_instr2: gr.update(visible=False),
                prompt2: gr.update(visible=False),
                tab_text2: gr.update(visible=True, interactive=False, show_label=False, show_copy_button=False,
                                     container=False),
                word_count: gr.update(visible=False),
                text_button2: gr.update(visible=False),
                chatbot_col: gr.update(visible=False),
                chat_feature: gr.update(visible=False),
                next_q_btn: gr.update(visible=True),
                assisted_instr_md: gr.update(visible=False)
            }

        text_button2.click(
            fn=submit_question,
            inputs=[tab_text2, state],
            outputs=[success_submitted, edit_response, proceed_to_next, back_btn, gen_instr2, prompt2, tab_text2, word_count,
                     text_button2, chatbot_col, chat_feature, next_q_btn, assisted_instr_md]
        )

    return q_assisted
|
|