File size: 6,096 Bytes
675e88f
d07b421
 
 
1e06da5
 
 
9504acd
d07b421
 
9504acd
d07b421
 
 
 
 
 
 
 
1e06da5
d07b421
1e06da5
 
 
 
 
 
 
 
7ff0270
 
1e06da5
 
 
 
 
 
d07b421
 
 
 
 
 
 
 
 
 
 
1e06da5
 
 
 
 
 
 
 
7ff0270
 
d07b421
 
 
 
1e06da5
 
 
d07b421
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8c82e12
 
81bed86
8c82e12
d07b421
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import gradio as gr
from utils import get_base_answer, get_nudging_answer
from constant import js_code_label, custom_css, HEADER_MD, BASE_MODELS, NUDGING_MODELS
import datetime
import logging
# Route INFO-and-above records from the root logger to the console.
logging.basicConfig(level=logging.INFO)

# Requests served per client address (``request.client.host``). The respond_*
# handlers empty it once more than a day has passed since LAST_UPDATE_TIME.
addr_limit_counter = {}
LAST_UPDATE_TIME = datetime.datetime.now()

# Model names offered by the "Base Model" / "Nudging Model" dropdowns.
base_models, nudging_models = BASE_MODELS, NUDGING_MODELS

def respond_base(
    system_prompt: str,
    message: str,
    max_tokens: int,
    base_model: str,
    request: gr.Request
):
    """Answer ``message`` with the base model alone, subject to a per-client daily quota.

    Args:
        system_prompt: System prompt forwarded to ``get_base_answer``.
        message: User question.
        max_tokens: Maximum number of tokens to generate.
        base_model: Name of the base model to query.
        request: Gradio request; ``request.client.host`` identifies the caller
            for rate limiting.

    Returns:
        A single-turn chat history ``[(message, base_answer)]``.

    Raises:
        gr.Error: If this client has already made ``daily_limit`` successful
            requests since the last counter reset.
    """
    global LAST_UPDATE_TIME, addr_limit_counter
    daily_limit = 50

    # Start a fresh quota window once more than 24 hours have passed since the last reset.
    now = datetime.datetime.now()
    if now - LAST_UPDATE_TIME > datetime.timedelta(days=1):
        addr_limit_counter = {}
        LAST_UPDATE_TIME = now

    host_addr = request.client.host
    used = addr_limit_counter.setdefault(host_addr, 0)
    # ``>=`` rather than ``>``: counts start at 0, so ``>`` admitted one extra request.
    if used >= daily_limit:
        raise gr.Error(f"You have reached the limit of {daily_limit} requests for today.", duration=10)

    logging.info(f"Requesting chat completion from OpenAI API with model {base_model}")
    base_answer = get_base_answer(base_model=base_model, system_prompt=system_prompt, question=message, max_tokens=max_tokens)

    # Only successful generations are counted. Another request may have replaced the
    # global dict during the model call, so the host key might not exist any more.
    addr_limit_counter[host_addr] = addr_limit_counter.get(host_addr, 0) + 1
    # Log only this client's count instead of every tracked address.
    logging.info(f"addr_limit_counter[{host_addr}]: {addr_limit_counter[host_addr]}; Last update time: {LAST_UPDATE_TIME};")
    return [(message, base_answer)]

def respond_nudging(
    system_prompt: str,
    message: str,
    max_tokens: int,
    nudging_thres: float,
    base_model: str,
    nudging_model: str,
    request: gr.Request
):
    """Answer ``message`` with a base model guided by a nudging model.

    Counts against the shared per-client ``addr_limit_counter`` daily quota.

    Args:
        system_prompt: System prompt forwarded to ``get_nudging_answer``.
        message: User question.
        max_tokens: Token budget, passed as ``max_token_total``.
        nudging_thres: Passed to ``get_nudging_answer`` as ``top_prob_thres``.
        base_model: Name of the base model.
        nudging_model: Name of the nudging model.
        request: Gradio request; ``request.client.host`` identifies the caller
            for rate limiting.

    Returns:
        A single-turn chat history ``[(message, html)]``, where ``html`` comes from
        ``format_response`` and wraps each nudging segment in ``<mark>``.

    Raises:
        gr.Error: If this client has already made ``daily_limit`` successful
            requests since the last counter reset.
    """
    global LAST_UPDATE_TIME, addr_limit_counter
    daily_limit = 50

    # Start a fresh quota window once more than 24 hours have passed since the last reset.
    now = datetime.datetime.now()
    if now - LAST_UPDATE_TIME > datetime.timedelta(days=1):
        addr_limit_counter = {}
        LAST_UPDATE_TIME = now

    host_addr = request.client.host
    used = addr_limit_counter.setdefault(host_addr, 0)
    # ``>=`` rather than ``>``: counts start at 0, so ``>`` admitted one extra request.
    if used >= daily_limit:
        raise gr.Error(f"You have reached the limit of {daily_limit} requests for today.", duration=10)

    logging.info(f"Requesting chat completion from OpenAI API with model {base_model} and {nudging_model}")
    all_info = get_nudging_answer(base_model=base_model, nudging_model=nudging_model, system_prompt=system_prompt, question=message, max_token_total=max_tokens, top_prob_thres=nudging_thres)
    formatted_response = format_response(all_info["all_completions"], all_info["all_nudging_words"])

    # Only successful generations are counted. Another request may have replaced the
    # global dict during the model call, so the host key might not exist any more.
    addr_limit_counter[host_addr] = addr_limit_counter.get(host_addr, 0) + 1
    # Log only this client's count instead of every tracked address.
    logging.info(f"addr_limit_counter[{host_addr}]: {addr_limit_counter[host_addr]}; Last update time: {LAST_UPDATE_TIME};")
    return [(message, formatted_response)]

def clear_fn():
    """Produce the values used by the Clear button.

    Returns:
        A 3-tuple of ``None``, one for each output the Clear button is wired to
        (the prompt box and both chat panels).
    """
    return (None,) * 3

def format_response(all_completions, nudging_words):
    """Render a nudged answer as HTML, highlighting the nudging-model segments.

    Each entry of ``all_completions`` is expected to be the matching entry of
    ``nudging_words`` followed by the base model's continuation. The first
    ``len(nudging_word)`` characters are replaced by the nudging word wrapped in
    ``<mark>``, and the remaining text is appended unmarked.

    Args:
        all_completions: Completion segments, each ``nudging_word + base_completion``.
        nudging_words: Nudging-model prefixes, aligned index-by-index with
            ``all_completions``.

    Returns:
        The concatenated HTML string, or ``""`` when there are no segments.
        Pairs are formed with ``zip``, so surplus items in the longer list are ignored.
    """
    # NOTE(review): model text is inserted without HTML escaping.
    return "".join(
        f"<mark>{nudging_word}</mark>{completion[len(nudging_word):]}"
        for completion, nudging_word in zip(all_completions, nudging_words)
    )
        
with gr.Blocks(gr.themes.Soft(), js=js_code_label, css=custom_css) as demo:
    # Hidden, and not wired to any event handler in this file.
    api_key = gr.Textbox(label="🔑 APIKey", placeholder="Enter your Together/Hyperbolic API Key. Leave it blank to use our key with limited usage.", type="password", elem_id="api_key", visible=False)

    gr.Markdown(HEADER_MD)

    with gr.Group():
        with gr.Row():
            # Gradio expects an integer ``scale``. This is the only column in the row,
            # so the exact value does not affect the layout.
            with gr.Column(scale=2):
                # Defaults are passed to the constructors instead of mutating ``.value``
                # after creation, so Gradio processes them like any other initial value.
                system_prompt = gr.Textbox(
                    label="System Prompt",
                    placeholder="Enter your system prompt here",
                    value="Answer the question by walking through the reasoning steps.",
                )
                message = gr.Textbox(
                    label="Prompt",
                    placeholder="Enter your message here",
                    value="Question: There were 39 girls and 4 boys trying out for the schools basketball team. If only 26 of them got called back, how many students didn't make the cut?",
                )
                with gr.Row():
                    with gr.Column(scale=2):
                        with gr.Row():
                            base_model_choice = gr.Dropdown(label="Base Model", choices=base_models, value="Llama-2-70B", interactive=True)
                            nudging_model_choice = gr.Dropdown(label="Nudging Model", choices=nudging_models, value="Llama-2-13B-chat", interactive=True)
                        with gr.Accordion("Nudging Parameters", open=True):
                            with gr.Row():
                                max_tokens = gr.Slider(label="Max tokens", value=256, minimum=0, maximum=512, step=16, interactive=True, visible=True)
                                nudging_thres = gr.Slider(label="Nudging Threshold", step=0.1, minimum=0.1, maximum=0.9, value=0.4)
                        with gr.Row():
                            btn = gr.Button("Generate")
                        with gr.Row():
                            stop_btn = gr.Button("Stop")
                            clear_btn = gr.Button("Clear")
    with gr.Row():
        chat_b = gr.Chatbot(height=500, label="Base Answer")
        chat_a = gr.Chatbot(height=500, label="Nudging Answer", elem_id="chatbot")

    # One Generate click triggers both handlers, and each fills its own Chatbot.
    go1 = btn.click(respond_nudging, [system_prompt, message, max_tokens, nudging_thres, base_model_choice, nudging_model_choice], chat_a)
    go2 = btn.click(respond_base, [system_prompt, message, max_tokens, base_model_choice], chat_b)

    # Stop cancels both generation events.
    stop_btn.click(None, None, None, cancels=[go1, go2])
    clear_btn.click(clear_fn, None, [message, chat_a, chat_b])

if __name__ == "__main__":
    demo.launch(show_api=False)