nudging_align / app.py
fywalter's picture
initial version
d07b421
raw
history blame
4.46 kB
import gradio as gr
from typing import List
from utils import get_base_answer, get_nudging_answer
from constant import js_code_label, custom_css, HEADER_MD, BASE_MODELS, NUDGING_MODELS
import datetime
# Mutable per-module counter; presumably meant for per-address rate limiting.
# Nothing in this file reads or writes it after creation.
addr_limit_counter = dict()
# Naive local timestamp captured once, at import time.
LAST_UPDATE_TIME = datetime.datetime.now()
# Module-level aliases for the model lists provided by ``constant``; these
# feed the two model dropdowns in the UI.
base_models, nudging_models = BASE_MODELS, NUDGING_MODELS
def respond_base(
    system_prompt: str,
    message: str,
    max_tokens: int,
    base_model: str,
):
    """Answer ``message`` with the base model alone.

    Delegates to ``utils.get_base_answer`` and wraps the result as a one-turn
    chat history so it can be written straight into a ``gr.Chatbot``.

    Args:
        system_prompt: System prompt forwarded to the model.
        message: User question; also echoed as the user side of the turn.
        max_tokens: Generation budget forwarded as ``max_tokens``.
        base_model: Name of the base model to query.

    Returns:
        A list holding the single ``(message, answer)`` pair.
    """
    answer = get_base_answer(
        base_model=base_model,
        system_prompt=system_prompt,
        question=message,
        max_tokens=max_tokens,
    )
    return [(message, answer)]
def respond_nudging(
    system_prompt: str,
    message: str,
    max_tokens: int,
    nudging_thres: float,
    base_model: str,
    nudging_model: str,
    request: gr.Request,
):
    """Answer ``message`` using nudging and return it as highlighted chat HTML.

    Calls ``utils.get_nudging_answer`` and renders its chunks with
    ``format_response``, so the nudging-model words appear inside
    ``<mark>`` tags.

    Args:
        system_prompt: System prompt forwarded to both models.
        message: User question; also echoed as the user side of the turn.
        max_tokens: Total generation budget (``max_token_total``).
        nudging_thres: Forwarded as ``top_prob_thres``.
        base_model: Name of the base model.
        nudging_model: Name of the nudging model.
        request: Injected by Gradio because of the ``gr.Request`` annotation;
            not used here.

    Returns:
        A list holding the single ``(message, formatted_html)`` pair.
    """
    result = get_nudging_answer(
        base_model=base_model,
        nudging_model=nudging_model,
        system_prompt=system_prompt,
        question=message,
        max_token_total=max_tokens,
        top_prob_thres=nudging_thres,
    )
    highlighted = format_response(result["all_completions"], result["all_nudging_words"])
    return [(message, highlighted)]
def clear_fn():
    """Return empty values for the three components the Clear button resets.

    The Clear button's outputs are ``[message, chat_a, chat_b]``, so this
    yields one ``None`` for each: the prompt box, the nudging chat, and the
    base chat.
    """
    cleared_prompt = cleared_nudging_chat = cleared_base_chat = None
    return cleared_prompt, cleared_nudging_chat, cleared_base_chat
def format_response(all_completions, nudging_words):
    """Render nudged completion chunks as one HTML string.

    Each chunk is expected to be ``nudging_word + base_completion``: a word
    supplied by the nudging model followed by the text the base model
    continued with. The nudging word is wrapped in ``<mark>`` so it is
    highlighted in the chat pane, and the continuation follows as plain text.

    Args:
        all_completions: Iterable of chunk strings.
        nudging_words: Iterable of nudging-word strings, parallel to
            ``all_completions``.

    Returns:
        The concatenation of ``<mark>{word}</mark>{continuation}`` for every
        pair. Pairs are matched with ``zip``, so surplus items in the longer
        input are ignored. Empty inputs yield ``""``.
    """
    # NOTE(review): model text is interpolated into HTML without escaping, so
    # any "<" or "&" in a completion is interpreted as markup by the chat pane.
    return "".join(
        # The prefix is removed by length, not by content: the first
        # len(word) characters of each chunk are dropped unconditionally.
        f"<mark>{word}</mark>{chunk[len(word):]}"
        for chunk, word in zip(all_completions, nudging_words)
    )
# NOTE(review): the container nesting below had to be inferred; confirm the
# panels render in the intended arrangement.
with gr.Blocks(gr.themes.Soft(), js=js_code_label, css=custom_css) as demo:
    # Hidden API-key field (visible=False); no event handler in this file
    # reads it.
    api_key = gr.Textbox(
        label="🔑 APIKey",
        placeholder="Enter your Together/Hyperbolic API Key. Leave it blank to use our key with limited usage.",
        type="password",
        elem_id="api_key",
        visible=False,
    )
    gr.Markdown(HEADER_MD)

    # Side-by-side outputs: nudging answer on the left, base answer on the right.
    with gr.Row():
        chat_a = gr.Chatbot(height=500, label="Nudging Answer", elem_id="chatbot")
        chat_b = gr.Chatbot(height=500, label="Base Answer")

    with gr.Group():
        with gr.Row():
            # Gradio expects an integer `scale`; the former 1.5 produced a
            # warning. This column is the only child of its row, so its
            # relative width has no visible effect.
            with gr.Column():
                # Initial values are given to the constructors rather than
                # assigned to `.value` after the components were built.
                system_prompt = gr.Textbox(
                    label="System Prompt",
                    placeholder="Enter your system prompt here",
                    value="Answer the question by walking through the reasoning steps.",
                )
                message = gr.Textbox(
                    label="Prompt",
                    placeholder="Enter your message here",
                    value="Question: There were 39 girls and 4 boys trying out for the schools basketball team. If only 26 of them got called back, how many students didn't make the cut?",
                )
                with gr.Row():
                    with gr.Column(scale=2):
                        with gr.Row():
                            base_model_choice = gr.Dropdown(
                                label="Base Model",
                                choices=base_models,
                                value="Llama-2-70B",
                                interactive=True,
                            )
                            nudging_model_choice = gr.Dropdown(
                                label="Nudging Model",
                                choices=nudging_models,
                                value="Llama-2-13B-chat",
                                interactive=True,
                            )
                        with gr.Accordion("Nudging Parameters", open=True):
                            with gr.Row():
                                max_tokens = gr.Slider(
                                    label="Max tokens",
                                    value=256,
                                    minimum=0,
                                    maximum=512,
                                    step=16,
                                    interactive=True,
                                    visible=True,
                                )
                                nudging_thres = gr.Slider(
                                    label="Nudging Threshold",
                                    step=0.1,
                                    minimum=0.1,
                                    maximum=0.9,
                                    value=0.4,
                                )
                        with gr.Row():
                            btn = gr.Button("Generate")
                        with gr.Row():
                            stop_btn = gr.Button("Stop")
                            clear_btn = gr.Button("Clear")

    # Hidden constant fields; nothing in this file reads them.
    model_type_left = gr.Textbox(visible=False, value="base")
    model_type_right = gr.Textbox(visible=False, value="aligned")

    # One click starts both generations as two independent events.
    go1 = btn.click(
        respond_nudging,
        [system_prompt, message, max_tokens, nudging_thres, base_model_choice, nudging_model_choice],
        chat_a,
    )
    go2 = btn.click(
        respond_base,
        [system_prompt, message, max_tokens, base_model_choice],
        chat_b,
    )
    # Stop cancels both pending/running generation events.
    stop_btn.click(None, None, None, cancels=[go1, go2])
    # Clear resets the prompt box and both chat panes (clear_fn returns three Nones).
    clear_btn.click(clear_fn, None, [message, chat_a, chat_b])
if __name__ == "__main__":
    # Start serving the Blocks app, with the footer link to the
    # auto-generated API docs turned off.
    demo.launch(
        show_api=False,
    )