Spaces:
Running
Running
File size: 6,096 Bytes
675e88f d07b421 1e06da5 9504acd d07b421 9504acd d07b421 1e06da5 d07b421 1e06da5 7ff0270 1e06da5 d07b421 1e06da5 7ff0270 d07b421 1e06da5 d07b421 8c82e12 81bed86 8c82e12 d07b421 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
import gradio as gr
from utils import get_base_answer, get_nudging_answer
from constant import js_code_label, custom_css, HEADER_MD, BASE_MODELS, NUDGING_MODELS
import datetime
import logging
# add logging info to console
# Emit INFO-level logs to the console for request tracing.
logging.basicConfig(level=logging.INFO)
# Per-client-IP request counts for daily rate limiting; reset every 24 hours
# by the respond_* handlers below.
addr_limit_counter = {}
# Timestamp of the last counter reset (naive local time; compared with
# datetime.now() in the handlers).
LAST_UPDATE_TIME = datetime.datetime.now()
# Model choices shown in the two dropdowns (defined in constant.py).
base_models = BASE_MODELS
nudging_models = NUDGING_MODELS
def respond_base(
    system_prompt: str,
    message: str,
    max_tokens: int,
    base_model: str,
    request: gr.Request,
):
    """Answer *message* with the raw base model and return a one-turn chat history.

    Enforces a per-client-IP limit of 50 requests per 24 hours using the
    module-level ``addr_limit_counter`` / ``LAST_UPDATE_TIME`` globals.

    Raises:
        gr.Error: when the caller's IP has exhausted its daily quota.
    """
    global LAST_UPDATE_TIME, addr_limit_counter
    # Reset all per-address counters once 24 hours have elapsed.
    if datetime.datetime.now() - LAST_UPDATE_TIME > datetime.timedelta(days=1):
        addr_limit_counter = {}
        LAST_UPDATE_TIME = datetime.datetime.now()
    host_addr = request.client.host
    if host_addr not in addr_limit_counter:
        addr_limit_counter[host_addr] = 0
    # Use >= so exactly 50 requests succeed per day; the original `> 50`
    # check (combined with incrementing after the call) let 52 through
    # while the error message promised 50.
    if addr_limit_counter[host_addr] >= 50:
        raise gr.Error("You have reached the limit of 50 requests for today.", duration=10)
    base_answer = get_base_answer(base_model=base_model, system_prompt=system_prompt, question=message, max_tokens=max_tokens)
    addr_limit_counter[host_addr] += 1
    logging.info(f"Requesting chat completion from OpenAI API with model {base_model}")
    logging.info(f"addr_limit_counter: {addr_limit_counter}; Last update time: {LAST_UPDATE_TIME};")
    # Gradio Chatbot expects a list of (user, assistant) tuples.
    return [(message, base_answer)]
def respond_nudging(
    system_prompt: str,
    message: str,
    max_tokens: int,
    nudging_thres: float,
    base_model: str,
    nudging_model: str,
    request: gr.Request,
):
    """Answer *message* with the base model nudged by *nudging_model*.

    The nudging words injected by the aligned model are highlighted in the
    returned HTML (see ``format_response``). Enforces the same per-IP
    50-requests/day limit as ``respond_base`` via the shared module globals.

    Raises:
        gr.Error: when the caller's IP has exhausted its daily quota.
    """
    global LAST_UPDATE_TIME, addr_limit_counter
    # Reset all per-address counters once 24 hours have elapsed.
    if datetime.datetime.now() - LAST_UPDATE_TIME > datetime.timedelta(days=1):
        addr_limit_counter = {}
        LAST_UPDATE_TIME = datetime.datetime.now()
    host_addr = request.client.host
    if host_addr not in addr_limit_counter:
        addr_limit_counter[host_addr] = 0
    # Use >= so exactly 50 requests succeed per day; the original `> 50`
    # check (combined with incrementing after the call) let 52 through
    # while the error message promised 50.
    if addr_limit_counter[host_addr] >= 50:
        raise gr.Error("You have reached the limit of 50 requests for today.", duration=10)
    all_info = get_nudging_answer(base_model=base_model, nudging_model=nudging_model, system_prompt=system_prompt, question=message, max_token_total=max_tokens, top_prob_thres=nudging_thres)
    all_completions = all_info["all_completions"]
    nudging_words = all_info["all_nudging_words"]
    formatted_response = format_response(all_completions, nudging_words)
    addr_limit_counter[host_addr] += 1
    logging.info(f"Requesting chat completion from OpenAI API with model {base_model} and {nudging_model}")
    logging.info(f"addr_limit_counter: {addr_limit_counter}; Last update time: {LAST_UPDATE_TIME};")
    # Gradio Chatbot expects a list of (user, assistant) tuples.
    return [(message, formatted_response)]
def clear_fn():
    """Reset the message box and both chat panels to empty."""
    return None, None, None
def format_response(all_completions, nudging_words):
    """Render completions as HTML, wrapping each nudging prefix in <mark>.

    Each entry of *all_completions* is assumed to begin with the
    corresponding entry of *nudging_words*; the prefix is highlighted and
    the remainder (the base model's continuation) is appended verbatim.
    Returns the concatenated HTML string ("" for empty input).
    """
    pieces = []
    for completion, nudging_word in zip(all_completions, nudging_words):
        # completion == nudging_word + base_completion; split off the prefix
        # so the nudging tokens can be visually distinguished.
        base_completion = completion[len(nudging_word):]
        pieces.append(f"<mark>{nudging_word}</mark>{base_completion}")
    return "".join(pieces)
# --- Gradio UI layout and event wiring ---
with gr.Blocks(gr.themes.Soft(), js=js_code_label, css=custom_css) as demo:
    # Hidden API-key field (currently unused/invisible, kept for future use).
    api_key = gr.Textbox(label="🔑 APIKey", placeholder="Enter your Together/Hyperbolic API Key. Leave it blank to use our key with limited usage.", type="password", elem_id="api_key", visible=False)
    gr.Markdown(HEADER_MD)
    with gr.Group():
        with gr.Row():
            with gr.Column(scale=1.5):
                system_prompt = gr.Textbox(label="System Prompt", placeholder="Enter your system prompt here")
                message = gr.Textbox(label="Prompt", placeholder="Enter your message here")
        with gr.Row():
            with gr.Column(scale=2):
                with gr.Row():
                    base_model_choice = gr.Dropdown(label="Base Model", choices=base_models, interactive=True)
                    nudging_model_choice = gr.Dropdown(label="Nudging Model", choices=nudging_models, interactive=True)
                with gr.Accordion("Nudging Parameters", open=True):
                    with gr.Row():
                        max_tokens = gr.Slider(label="Max tokens", value=256, minimum=0, maximum=512, step=16, interactive=True, visible=True)
                        nudging_thres = gr.Slider(label="Nudging Threshold", step=0.1, minimum=0.1, maximum=0.9, value=0.4)
    with gr.Row():
        btn = gr.Button("Generate")
        with gr.Row():
            stop_btn = gr.Button("Stop")
            clear_btn = gr.Button("Clear")
    # Side-by-side comparison: raw base model vs. nudged answer.
    with gr.Row():
        chat_b = gr.Chatbot(height=500, label="Base Answer")
        chat_a = gr.Chatbot(height=500, label="Nudging Answer", elem_id="chatbot")
    # Default selections and an example prompt shown on page load.
    base_model_choice.value = "Llama-2-70B"
    nudging_model_choice.value = "Llama-2-13B-chat"
    system_prompt.value = "Answer the question by walking through the reasoning steps."
    message.value = "Question: There were 39 girls and 4 boys trying out for the schools basketball team. If only 26 of them got called back, how many students didn't make the cut?"
    model_type_left = gr.Textbox(visible=False, value="base")
    model_type_right = gr.Textbox(visible=False, value="aligned")
    # "Generate" fires both handlers; "Stop" cancels whichever is running.
    go1 = btn.click(respond_nudging, [system_prompt, message, max_tokens, nudging_thres, base_model_choice, nudging_model_choice], chat_a)
    go2 = btn.click(respond_base, [system_prompt, message, max_tokens, base_model_choice], chat_b)
    stop_btn.click(None, None, None, cancels=[go1, go2])
    clear_btn.click(clear_fn, None, [message, chat_a, chat_b])
if __name__ == "__main__":
    # Removed a stray trailing " |" artifact that made this line a syntax error.
    demo.launch(show_api=False)