Spaces:
Running
Running
import gradio as gr | |
from typing import List | |
from utils import get_base_answer, get_nudging_answer | |
from constant import js_code_label, custom_css, HEADER_MD, BASE_MODELS, NUDGING_MODELS | |
import datetime | |
# Presumably a per-client-address usage/rate-limit counter (guessed from the
# name); it is not read or written anywhere in the visible code.
addr_limit_counter = {}
# Naive local timestamp captured once, at import time.
LAST_UPDATE_TIME = datetime.datetime.now()
# Dropdown choices for the base and nudging model pickers, aliased from the
# `constant` module.
base_models, nudging_models = BASE_MODELS, NUDGING_MODELS
def respond_base(
    system_prompt: str,
    message: str,
    max_tokens: int,
    base_model: str,
):
    """Answer *message* with the base model alone (no nudging).

    Args:
        system_prompt: System prompt forwarded to the model.
        message: The user's question.
        max_tokens: Generation budget forwarded as ``max_tokens``.
        base_model: Name of the base model to query.

    Returns:
        A one-element list holding the ``(message, answer)`` pair, i.e. a
        fresh single-exchange history for the output chatbot.
    """
    answer = get_base_answer(
        base_model=base_model,
        system_prompt=system_prompt,
        question=message,
        max_tokens=max_tokens,
    )
    history = [(message, answer)]
    return history
def respond_nudging(
    system_prompt: str,
    message: str,
    max_tokens: int,
    nudging_thres: float,
    base_model: str,
    nudging_model: str,
    request: gr.Request,
):
    """Answer *message* with a base model guided by a nudging model.

    Args:
        system_prompt: System prompt forwarded to the models.
        message: The user's question.
        max_tokens: Total generation budget, forwarded as ``max_token_total``.
        nudging_thres: Forwarded as ``top_prob_thres`` to the nudging helper.
        base_model: Name of the base model.
        nudging_model: Name of the nudging model.
        request: Not used here. Gradio fills it in because of the
            ``gr.Request`` annotation (it is not one of the click inputs),
            so the parameter must stay.

    Returns:
        A one-element list holding ``(message, html)``, where ``html`` marks
        the nudging-model words inside the combined answer.
    """
    info = get_nudging_answer(
        base_model=base_model,
        nudging_model=nudging_model,
        system_prompt=system_prompt,
        question=message,
        max_token_total=max_tokens,
        top_prob_thres=nudging_thres,
    )
    rendered = format_response(info["all_completions"], info["all_nudging_words"])
    return [(message, rendered)]
def clear_fn():
    """Reset the outputs wired to the Clear button.

    Returns:
        A 3-tuple of ``None`` values, one per cleared component.
    """
    return (None,) * 3
def format_response(all_completions, nudging_words):
    """Render a nudged answer as HTML, highlighting the nudging-model words.

    Each element of *all_completions* is expected to be the matching element
    of *nudging_words* followed by the base model's continuation. The nudging
    prefix is wrapped in ``<mark>`` and the continuation is appended verbatim.

    Args:
        all_completions: Sequence of completion strings, one per step.
        nudging_words: Sequence of nudging prefixes, parallel to
            *all_completions*.

    Returns:
        The concatenated HTML string. It is ``""`` when either input is empty.
        Extra items in the longer sequence are ignored (``zip`` semantics).
    """
    segments = []
    for completion, nudging_word in zip(all_completions, nudging_words):
        # Drop the nudging prefix by length. This assumes the completion
        # really starts with it.
        base_completion = completion[len(nudging_word):]
        segments.append(f"<mark>{nudging_word}</mark>{base_completion}")
    # A single join avoids quadratic repeated string concatenation.
    return "".join(segments)
# Two-pane comparison UI: the nudging answer (chat_a) and the base-model
# answer (chat_b) are generated from the same prompt and settings.
with gr.Blocks(gr.themes.Soft(), js=js_code_label, css=custom_css) as demo:
    # Hidden API-key field; it is not wired to any event in this block.
    api_key = gr.Textbox(label="π APIKey", placeholder="Enter your Together/Hyperbolic API Key. Leave it blank to use our key with limited usage.", type="password", elem_id="api_key", visible=False)
    gr.Markdown(HEADER_MD)
    with gr.Row():
        chat_a = gr.Chatbot(height=500, label="Nudging Answer", elem_id="chatbot")
        chat_b = gr.Chatbot(height=500, label="Base Answer")
    with gr.Group():
        with gr.Row():
            with gr.Column(scale=1.5):
                # Initial values are passed to the constructors rather than
                # assigned to `.value` afterwards, so Gradio processes them
                # like any other initial value.
                system_prompt = gr.Textbox(
                    label="System Prompt",
                    placeholder="Enter your system prompt here",
                    value="Answer the question by walking through the reasoning steps.",
                )
                message = gr.Textbox(
                    label="Prompt",
                    placeholder="Enter your message here",
                    value="Question: There were 39 girls and 4 boys trying out for the schools basketball team. If only 26 of them got called back, how many students didn't make the cut?",
                )
        with gr.Row():
            with gr.Column(scale=2):
                with gr.Row():
                    base_model_choice = gr.Dropdown(label="Base Model", choices=base_models, value="Llama-2-70B", interactive=True)
                    nudging_model_choice = gr.Dropdown(label="Nudging Model", choices=nudging_models, value="Llama-2-13B-chat", interactive=True)
            with gr.Accordion("Nudging Parameters", open=True):
                with gr.Row():
                    max_tokens = gr.Slider(label="Max tokens", value=256, minimum=0, maximum=512, step=16, interactive=True, visible=True)
                    nudging_thres = gr.Slider(label="Nudging Threshold", step=0.1, minimum=0.1, maximum=0.9, value=0.4)
    with gr.Row():
        btn = gr.Button("Generate")
    with gr.Row():
        stop_btn = gr.Button("Stop")
        clear_btn = gr.Button("Clear")
    # NOTE(review): these hidden textboxes are not referenced by any event
    # in the visible code; kept in case external JS/CSS relies on them.
    model_type_left = gr.Textbox(visible=False, value="base")
    model_type_right = gr.Textbox(visible=False, value="aligned")
    # One click starts both generations. `respond_nudging` also receives a
    # gr.Request that Gradio injects, which is why it takes no 7th input here.
    go1 = btn.click(respond_nudging, [system_prompt, message, max_tokens, nudging_thres, base_model_choice, nudging_model_choice], chat_a)
    go2 = btn.click(respond_base, [system_prompt, message, max_tokens, base_model_choice], chat_b)
    # Stop cancels both in-flight generation events.
    stop_btn.click(None, None, None, cancels=[go1, go2])
    clear_btn.click(clear_fn, None, [message, chat_a, chat_b])
if __name__ == "__main__":
    # Serve the UI when run as a script; the auto-generated API page is hidden.
    demo.launch(show_api=False)