File size: 4,461 Bytes
d07b421
 
 
 
 
9504acd
d07b421
 
9504acd
d07b421
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import gradio as gr 
from typing import List
from utils import get_base_answer, get_nudging_answer
from constant import js_code_label, custom_css, HEADER_MD, BASE_MODELS, NUDGING_MODELS
import datetime

# Choices offered by the two model dropdowns (aliases of the constants module).
base_models, nudging_models = BASE_MODELS, NUDGING_MODELS

# Module-level state. Neither name is read by the code in this file;
# NOTE(review): presumably consumed by usage-limiting logic elsewhere — confirm.
addr_limit_counter = dict()
LAST_UPDATE_TIME = datetime.datetime.now()

def respond_base(
    system_prompt: str,
    message: str,
    max_tokens: int,
    base_model: str,
):
    """Answer ``message`` with the base model alone, for display in a Chatbot.

    Args:
        system_prompt: System prompt forwarded to the model.
        message: The user's question.
        max_tokens: Generation budget forwarded to ``get_base_answer``.
        base_model: Name of the base model selected in the UI.

    Returns:
        A one-element list holding the ``(message, answer)`` pair, which is
        what the ``gr.Chatbot`` output of the Generate button receives.
    """
    answer = get_base_answer(
        base_model=base_model,
        system_prompt=system_prompt,
        question=message,
        max_tokens=max_tokens,
    )
    exchange = (message, answer)
    return [exchange]

def respond_nudging(
    system_prompt: str,
    message: str,
    max_tokens: int,
    nudging_thres: float,
    base_model: str,
    nudging_model: str,
    request: gr.Request
):
    """Answer ``message`` with nudging and highlight the nudging tokens.

    Args:
        system_prompt: System prompt forwarded to the models.
        message: The user's question.
        max_tokens: Total token budget, forwarded as ``max_token_total``.
        nudging_thres: Forwarded to ``get_nudging_answer`` as
            ``top_prob_thres``.
        base_model: Name of the base model selected in the UI.
        nudging_model: Name of the nudging model selected in the UI.
        request: Not among the click inputs; supplied by Gradio because of
            its ``gr.Request`` annotation. Unused here.

    Returns:
        A one-element list holding ``(message, html)``, where ``html`` is
        the answer rendered by ``format_response``.
    """
    info = get_nudging_answer(
        base_model=base_model,
        nudging_model=nudging_model,
        system_prompt=system_prompt,
        question=message,
        max_token_total=max_tokens,
        top_prob_thres=nudging_thres,
    )
    highlighted = format_response(info["all_completions"], info["all_nudging_words"])
    return [(message, highlighted)]

def clear_fn():
    """Reset the prompt box and both chat panes.

    Returns:
        ``(None, None, None)``: one ``None`` per output wired to the Clear
        button (message textbox, nudging chatbot, base chatbot).
    """
    return (None,) * 3

def format_response(all_completions, nudging_words):
    """Render a nudged answer as HTML with each nudging prefix highlighted.

    Each element of ``all_completions`` is expected to be its matching
    nudging word followed by the base model's continuation
    (``completion = nudging_word + base_completion``). The first
    ``len(nudging_word)`` characters are replaced by the nudging word wrapped
    in ``<mark>``, and the remainder follows unmarked.

    Args:
        all_completions: Completion chunks, in order.
        nudging_words: Nudging prefixes, parallel to ``all_completions``.
            Surplus items in the longer sequence are ignored (``zip``
            semantics).

    Returns:
        The concatenated HTML string, or ``""`` when either input is empty.
    """
    # NOTE(review): model text is inserted without HTML escaping, so any
    # markup inside a completion is passed through to the chatbot as-is —
    # confirm this is intended.
    return "".join(
        f"<mark>{word}</mark>{completion[len(word):]}"
        for completion, word in zip(all_completions, nudging_words)
    )
        
# Build the UI: a nudged answer (left) and a base-model-only answer (right)
# are generated from the same system prompt and question.
with gr.Blocks(theme=gr.themes.Soft(), js=js_code_label, css=custom_css) as demo:
    # Hidden API-key field; it is not passed to any handler in this file.
    api_key = gr.Textbox(
        label="🔑 APIKey",
        placeholder="Enter your Together/Hyperbolic API Key. Leave it blank to use our key with limited usage.",
        type="password",
        elem_id="api_key",
        visible=False,
    )

    gr.Markdown(HEADER_MD)

    with gr.Row():
        chat_a = gr.Chatbot(height=500, label="Nudging Answer", elem_id="chatbot")
        chat_b = gr.Chatbot(height=500, label="Base Answer")

    with gr.Group():
        with gr.Row():
            # Gradio expects an integer `scale` (a float triggers a warning);
            # this is the row's only column, so 1 lays out the same as 1.5.
            with gr.Column(scale=1):
                # Default values are passed to the constructors rather than
                # assigned to `.value` after the components are built.
                system_prompt = gr.Textbox(
                    label="System Prompt",
                    placeholder="Enter your system prompt here",
                    value="Answer the question by walking through the reasoning steps.",
                )
                message = gr.Textbox(
                    label="Prompt",
                    placeholder="Enter your message here",
                    value="Question: There were 39 girls and 4 boys trying out for the schools basketball team. If only 26 of them got called back, how many students didn't make the cut?",
                )
                with gr.Row():
                    with gr.Column(scale=2):
                        with gr.Row():
                            base_model_choice = gr.Dropdown(
                                label="Base Model",
                                choices=base_models,
                                value="Llama-2-70B",
                                interactive=True,
                            )
                            nudging_model_choice = gr.Dropdown(
                                label="Nudging Model",
                                choices=nudging_models,
                                value="Llama-2-13B-chat",
                                interactive=True,
                            )
                        with gr.Accordion("Nudging Parameters", open=True):
                            with gr.Row():
                                max_tokens = gr.Slider(label="Max tokens", value=256, minimum=0, maximum=512, step=16, interactive=True, visible=True)
                                nudging_thres = gr.Slider(label="Nudging Threshold", step=0.1, minimum=0.1, maximum=0.9, value=0.4)
                        with gr.Row():
                            btn = gr.Button("Generate")
                        with gr.Row():
                            stop_btn = gr.Button("Stop")
                            clear_btn = gr.Button("Clear")

    # Hidden textboxes; not wired to any event in this file.
    model_type_left = gr.Textbox(visible=False, value="base")
    model_type_right = gr.Textbox(visible=False, value="aligned")

    # Two independent events on the same Generate button; Stop cancels both.
    go1 = btn.click(
        respond_nudging,
        [system_prompt, message, max_tokens, nudging_thres, base_model_choice, nudging_model_choice],
        chat_a,
    )
    go2 = btn.click(respond_base, [system_prompt, message, max_tokens, base_model_choice], chat_b)

    stop_btn.click(None, None, None, cancels=[go1, go2])
    # clear_fn returns one None per output listed here.
    clear_btn.click(clear_fn, None, [message, chat_a, chat_b])

if __name__ == "__main__":
    demo.launch(show_api=False)