"""Gradio demo that compares a base model's answer with a "nudging" answer.

The "Generate" button sends the same system prompt and question to
``respond_base`` (base model only, via ``utils.get_base_answer``) and to
``respond_nudging`` (base model plus a nudging model, via
``utils.get_nudging_answer``). The two answers are shown side by side.

Each client IP address may make at most ``DAILY_REQUEST_LIMIT`` successful
requests per window. All counters are reset together once more than
``RATE_LIMIT_WINDOW`` has passed since the last reset.
"""
import datetime
import logging

import gradio as gr

from utils import get_base_answer, get_nudging_answer
from constant import js_code_label, custom_css, HEADER_MD, BASE_MODELS, NUDGING_MODELS

# add logging info to console
logging.basicConfig(level=logging.INFO)

# Maximum number of successful requests a single client address may make per window.
DAILY_REQUEST_LIMIT = 50
# Length of the window after which all per-address counters are cleared.
RATE_LIMIT_WINDOW = datetime.timedelta(days=1)

# host address -> number of successful requests in the current window
addr_limit_counter = {}
# time at which addr_limit_counter was last cleared
LAST_UPDATE_TIME = datetime.datetime.now()

base_models = BASE_MODELS
nudging_models = NUDGING_MODELS


def _check_rate_limit(request: gr.Request) -> str:
    """Reset expired counters, then reject the client if it is over its limit.

    Args:
        request: The Gradio request; its client host is used as the rate-limit key.

    Returns:
        The host key, to be passed to ``_record_request`` once the request succeeds.

    Raises:
        gr.Error: If this host already made ``DAILY_REQUEST_LIMIT`` requests
            in the current window.
    """
    global LAST_UPDATE_TIME, addr_limit_counter
    now = datetime.datetime.now()
    if now - LAST_UPDATE_TIME > RATE_LIMIT_WINDOW:
        addr_limit_counter = {}
        LAST_UPDATE_TIME = now

    # Starlette's Request.client can be None (e.g. some proxies / test clients).
    host_addr = request.client.host if request.client is not None else "unknown"
    addr_limit_counter.setdefault(host_addr, 0)
    # ">=" so exactly DAILY_REQUEST_LIMIT requests succeed; ">" allowed one extra.
    if addr_limit_counter[host_addr] >= DAILY_REQUEST_LIMIT:
        raise gr.Error(
            f"You have reached the limit of {DAILY_REQUEST_LIMIT} requests for today.",
            duration=10,
        )
    return host_addr


def _record_request(host_addr: str) -> None:
    """Count one successful request for ``host_addr`` and log the new total."""
    # .get(): the counters may have been reset by another request while the
    # (slow) model call was running, which would drop this host's key.
    addr_limit_counter[host_addr] = addr_limit_counter.get(host_addr, 0) + 1
    logging.info(
        "Requests from %s in current window: %d; tracked hosts: %d; Last update time: %s",
        host_addr,
        addr_limit_counter[host_addr],
        len(addr_limit_counter),
        LAST_UPDATE_TIME,
    )


def respond_base(
    system_prompt: str,
    message: str,
    max_tokens: int,
    base_model: str,
    request: gr.Request,
):
    """Answer ``message`` with the base model only.

    Returns:
        A single-turn chat history ``[(message, answer)]`` for the Chatbot.

    Raises:
        gr.Error: If the client exceeded its request limit.
    """
    host_addr = _check_rate_limit(request)
    logging.info("Requesting chat completion from OpenAI API with model %s", base_model)
    base_answer = get_base_answer(
        base_model=base_model,
        system_prompt=system_prompt,
        question=message,
        max_tokens=max_tokens,
    )
    # Only successful calls count against the limit.
    _record_request(host_addr)
    return [(message, base_answer)]


def respond_nudging(
    system_prompt: str,
    message: str,
    max_tokens: int,
    nudging_thres: float,
    base_model: str,
    nudging_model: str,
    request: gr.Request,
):
    """Answer ``message`` with the base model guided by ``nudging_model``.

    ``nudging_thres`` is forwarded to ``get_nudging_answer`` as ``top_prob_thres``.

    Returns:
        A single-turn chat history ``[(message, formatted_answer)]`` for the Chatbot.

    Raises:
        gr.Error: If the client exceeded its request limit.
    """
    host_addr = _check_rate_limit(request)
    logging.info(
        "Requesting chat completion from OpenAI API with model %s and %s",
        base_model,
        nudging_model,
    )
    all_info = get_nudging_answer(
        base_model=base_model,
        nudging_model=nudging_model,
        system_prompt=system_prompt,
        question=message,
        max_token_total=max_tokens,
        top_prob_thres=nudging_thres,
    )
    formatted_response = format_response(all_info["all_completions"], all_info["all_nudging_words"])
    # Only successful calls count against the limit.
    _record_request(host_addr)
    return [(message, formatted_response)]


def clear_fn():
    """Clear the prompt box and both chat panels."""
    return None, None, None


def format_response(all_completions, nudging_words):
    """Join the nudged completion segments into one response string.

    Each segment is expected to be ``nudging_word + base_completion``. The
    nudging word is split off by length and re-emitted in front of the rest.

    Args:
        all_completions: Completion segments, one per nudging step.
        nudging_words: The nudging word that starts each segment. Extra items
            in either sequence are ignored (``zip`` semantics).

    Returns:
        The concatenated segments.
    """
    # NOTE(review): the original accumulator was named ``html_code``, which
    # suggests the nudging words were meant to be wrapped in highlight markup.
    # They are currently emitted as plain text.
    pieces = []
    for completion, nudging_word in zip(all_completions, nudging_words):
        base_completion = completion[len(nudging_word):]
        pieces.append(f"{nudging_word}{base_completion}")
    return "".join(pieces)


with gr.Blocks(theme=gr.themes.Soft(), js=js_code_label, css=custom_css) as demo:
    # Hidden, and not passed to any handler in this file.
    api_key = gr.Textbox(
        label="🔑 APIKey",
        placeholder="Enter your Together/Hyperbolic API Key. Leave it blank to use our key with limited usage.",
        type="password",
        elem_id="api_key",
        visible=False,
    )
    gr.Markdown(HEADER_MD)
    with gr.Group():
        with gr.Row():
            with gr.Column(scale=1.5):
                system_prompt = gr.Textbox(
                    label="System Prompt",
                    placeholder="Enter your system prompt here",
                    value="Answer the question by walking through the reasoning steps.",
                )
                message = gr.Textbox(
                    label="Prompt",
                    placeholder="Enter your message here",
                    value="Question: There were 39 girls and 4 boys trying out for the schools basketball team. If only 26 of them got called back, how many students didn't make the cut?",
                )
        with gr.Row():
            with gr.Column(scale=2):
                with gr.Row():
                    base_model_choice = gr.Dropdown(
                        label="Base Model", choices=base_models, value="Llama-2-70B", interactive=True
                    )
                    nudging_model_choice = gr.Dropdown(
                        label="Nudging Model", choices=nudging_models, value="Llama-2-13B-chat", interactive=True
                    )
                with gr.Accordion("Nudging Parameters", open=True):
                    with gr.Row():
                        # Minimum is one step (16) rather than 0: a 0-token request produces no answer.
                        max_tokens = gr.Slider(
                            label="Max tokens",
                            value=256,
                            minimum=16,
                            maximum=512,
                            step=16,
                            interactive=True,
                            visible=True,
                        )
                        nudging_thres = gr.Slider(
                            label="Nudging Threshold", step=0.1, minimum=0.1, maximum=0.9, value=0.4
                        )
    with gr.Row():
        btn = gr.Button("Generate")
    with gr.Row():
        stop_btn = gr.Button("Stop")
        clear_btn = gr.Button("Clear")
    with gr.Row():
        chat_b = gr.Chatbot(height=500, label="Base Answer")
        chat_a = gr.Chatbot(height=500, label="Nudging Answer", elem_id="chatbot")

    # Hidden, and not referenced by any event in this file.
    model_type_left = gr.Textbox(visible=False, value="base")
    model_type_right = gr.Textbox(visible=False, value="aligned")

    # One click runs both generations; "Stop" cancels whichever is still running.
    nudging_event = btn.click(
        respond_nudging,
        [system_prompt, message, max_tokens, nudging_thres, base_model_choice, nudging_model_choice],
        chat_a,
    )
    base_event = btn.click(respond_base, [system_prompt, message, max_tokens, base_model_choice], chat_b)
    stop_btn.click(None, None, None, cancels=[nudging_event, base_event])
    clear_btn.click(clear_fn, None, [message, chat_a, chat_b])

if __name__ == "__main__":
    demo.launch(show_api=False)