initial version
- app.py +94 -5
- constant.py +74 -0
- utils.py +522 -0
app.py
CHANGED
@@ -1,7 +1,96 @@
-import gradio as gr
import gradio as gr
from typing import List
from utils import get_base_answer, get_nudging_answer
from constant import js_code_label, custom_css, HEADER_MD, BASE_MODELS, NUDGING_MODELS
import datetime

addr_limit_counter = {}  # not used in this initial version
LAST_UPDATE_TIME = datetime.datetime.now()

base_models = BASE_MODELS
nudging_models = NUDGING_MODELS

def respond_base(
    system_prompt: str,
    message: str,
    max_tokens: int,
    base_model: str,
):
    return [(message, get_base_answer(base_model=base_model, system_prompt=system_prompt, question=message, max_tokens=max_tokens))]

def respond_nudging(
    system_prompt: str,
    message: str,
    # history: list[tuple[str, str]],
    max_tokens: int,
    nudging_thres: float,
    base_model: str,
    nudging_model: str,
    request: gr.Request
):
    all_info = get_nudging_answer(base_model=base_model, nudging_model=nudging_model, system_prompt=system_prompt, question=message, max_token_total=max_tokens, top_prob_thres=nudging_thres)
    all_completions = all_info["all_completions"]
    nudging_words = all_info["all_nudging_words"]
    formatted_response = format_response(all_completions, nudging_words)
    return [(message, formatted_response)]

def clear_fn():
    # mega_hist["base"] = []
    # mega_hist["aligned"] = []
    return None, None, None

def format_response(all_completions, nudging_words):
    # wrap each round's nudging tokens in <mark> so custom_css can highlight them in the Chatbot
    html_code = ""
    for all_completion, nudging_word in zip(all_completions, nudging_words):
        # each all_completion = nudging_word + base_completion
        base_completion = all_completion[len(nudging_word):]
        html_code += f"<mark>{nudging_word}</mark>{base_completion}"
    return html_code

with gr.Blocks(gr.themes.Soft(), js=js_code_label, css=custom_css) as demo:
    api_key = gr.Textbox(label="🔑 APIKey", placeholder="Enter your Together/Hyperbolic API Key. Leave it blank to use our key with limited usage.", type="password", elem_id="api_key", visible=False)

    gr.Markdown(HEADER_MD)

    with gr.Row():
        chat_a = gr.Chatbot(height=500, label="Nudging Answer", elem_id="chatbot")
        chat_b = gr.Chatbot(height=500, label="Base Answer")

    with gr.Group():
        with gr.Row():
            with gr.Column(scale=1.5):
                system_prompt = gr.Textbox(label="System Prompt", placeholder="Enter your system prompt here")
                message = gr.Textbox(label="Prompt", placeholder="Enter your message here")
        with gr.Row():
            with gr.Column(scale=2):
                with gr.Row():
                    base_model_choice = gr.Dropdown(label="Base Model", choices=base_models, interactive=True)
                    nudging_model_choice = gr.Dropdown(label="Nudging Model", choices=nudging_models, interactive=True)
                with gr.Accordion("Nudging Parameters", open=True):
                    with gr.Row():
                        max_tokens = gr.Slider(label="Max tokens", value=256, minimum=0, maximum=512, step=16, interactive=True, visible=True)
                        nudging_thres = gr.Slider(label="Nudging Threshold", step=0.1, minimum=0.1, maximum=0.9, value=0.4)
        with gr.Row():
            btn = gr.Button("Generate")
        with gr.Row():
            stop_btn = gr.Button("Stop")
            clear_btn = gr.Button("Clear")

    base_model_choice.value = "Llama-2-70B"
    nudging_model_choice.value = "Llama-2-13B-chat"
    system_prompt.value = "Answer the question by walking through the reasoning steps."
    message.value = "Question: There were 39 girls and 4 boys trying out for the schools basketball team. If only 26 of them got called back, how many students didn't make the cut?"

    model_type_left = gr.Textbox(visible=False, value="base")
    model_type_right = gr.Textbox(visible=False, value="aligned")

    # the same click runs the nudged and the plain base generation side by side
    go1 = btn.click(respond_nudging, [system_prompt, message, max_tokens, nudging_thres, base_model_choice, nudging_model_choice], chat_a)
    go2 = btn.click(respond_base, [system_prompt, message, max_tokens, base_model_choice], chat_b)

    stop_btn.click(None, None, None, cancels=[go1, go2])
    clear_btn.click(clear_fn, None, [message, chat_a, chat_b])

if __name__ == "__main__":
    demo.launch(show_api=False)
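For reference, a small editor's sketch (not part of the committed file) of what format_response produces: each round's nudging tokens are wrapped in <mark> so the .chatbot mark rule in constant.py can highlight them in the Chatbot. The values below are illustrative only.

    all_completions = [" Sure, let's count the students together.", " So the answer is 17."]
    nudging_words = [" Sure,", " So"]
    format_response(all_completions, nudging_words)
    # -> "<mark> Sure,</mark> let's count the students together.<mark> So</mark> the answer is 17."
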
constant.py
ADDED
@@ -0,0 +1,74 @@
HEADER_MD = """# Inference-time Alignment with Nudging
[📑 Paper](https://arxiv.org/abs/2410.09300) | [🛜 Website](https://fywalter.github.io/nudging/) | [💻 GitHub](https://github.com/fywalter/nudging) | [🐦 X](https://x.com/Walter_Fei/status/1848538273917898753) | 📮 Contact: [Yu Fei](https://fywalter.github.io/)

**By injecting a few nudging tokens at inference time, we can make base models follow user instructions helpfully and safely.**
- Our demo is powered by the [Together AI API](https://api.together.ai/). However, since only three base models are currently available in the serverless API, we offer three base models and three nudging models for demonstration.
"""

js_code_label = """
function addApiKeyLink() {
    // Select the div with id 'api_key'
    const apiKeyDiv = document.getElementById('api_key');
    // Find the span within that div with data-testid 'block-info'
    const blockInfoSpan = apiKeyDiv.querySelector('span[data-testid="block-info"]');
    // Create the new link element
    const newLink = document.createElement('a');
    newLink.href = 'https://api.together.ai/settings/api-keys';
    newLink.textContent = ' View your keys here.';
    newLink.target = '_blank'; // Open link in new tab
    newLink.style = 'color: #007bff; text-decoration: underline;';
    // Create the additional text
    const additionalText = document.createTextNode(' (new accounts will have free credits to use.)');
    // Append the line break, link, and additional text to the div
    if (blockInfoSpan) {
        apiKeyDiv.appendChild(document.createElement('br'));
        apiKeyDiv.appendChild(newLink);
        apiKeyDiv.appendChild(additionalText);
    } else {
        console.error('Span with data-testid "block-info" not found');
    }
}
"""

BASE_MODELS = [
    "Llama-2-70B",
    "Mistral-7B-v0.1",
    "Mixtral-8x7B-v0.1",
]

NUDGING_MODELS = [
    'Llama-2-13B-chat',
    'Gemma-2-2B-it',
    'Mistral-7B-v0.1-Instruct',
]


my_css = """
/* CSS for a link color that is visible on both black and white backgrounds */
a {
    color: #1E90FF; /* DodgerBlue */
    text-decoration: none; /* Optional: remove underline */
}
a:hover {
    color: #104E8B; /* Slightly darker blue for hover effect */
    text-decoration: underline; /* Optional: add underline on hover */
}
"""  # not currently imported by app.py
# import json
# with open("together_model_ids.json", "r") as f:
#     TOGETHER_MODEL_IDS = json.load(f)

# for _, model_id in MODEL_MAPPING.items():
#     if model_id not in TOGETHER_MODEL_IDS + HYPERBOLIC_MODELS:
#         print(model_id)

# Custom CSS for highlighting nudging words in the Chatbot
custom_css = """
.chatbot mark {
    background-color: yellow;
    color: orange;
    font-style: italic;
    font-weight: bold;
}
"""
utils.py
ADDED
@@ -0,0 +1,522 @@
import numpy as np
from openai import OpenAI
import os
import tiktoken
encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")  # use the gpt-3.5 tokenizer for token-count control, so we don't need to load the actual tokenizer for API models

NUM_LOGPROBS = {
    'top_prob': 1,
}

# display name -> Together AI model id; keys must match the names in constant.BASE_MODELS / NUDGING_MODELS
MODEL_MAPPING = {
    "Llama-2-70B": "meta-llama/Llama-2-70b-hf",
    "Mistral-7B-v0.1": "mistralai/Mistral-7B-v0.1",
    "Mixtral-8x7B-v0.1": "mistralai/Mixtral-8x7B-v0.1",
    # Nudging models below
    "Mistral-7B-v0.1-Instruct": "mistralai/Mistral-7B-Instruct-v0.1",
    "Llama-2-13B-chat": "meta-llama/Llama-2-13b-chat-hf",
    "Gemma-2-2B-it": "google/gemma-2b-it",
}

def apply_instruct_template(model_name, system_prompt, instruct_prompt, response_prompt, add_bos=False):
    model_name = model_name.lower()
    # print(model_name)
    if "chat" in model_name and "llama" in model_name and "2" in model_name:
        return llama_2_chat_template(system_prompt=system_prompt, instruct_prompt=instruct_prompt, response_prompt=response_prompt, add_bos=add_bos)
    elif "instruct" in model_name and "llama" in model_name and "3" in model_name:
        if "3.1" in model_name:  # for llama-3.1 models, add the knowledge cutoff to the system prompt
            return llama_3_instruct_template(system_prompt=system_prompt, instruct_prompt=instruct_prompt, response_prompt=response_prompt, add_bos=add_bos, add_knowledge_cut=True)
        else:
            return llama_3_instruct_template(system_prompt=system_prompt, instruct_prompt=instruct_prompt, response_prompt=response_prompt, add_bos=add_bos)
    elif "it" in model_name and "gemma" in model_name:
        return gemma_instruct_template(system_prompt=system_prompt, instruct_prompt=instruct_prompt, response_prompt=response_prompt, add_bos=add_bos)
    elif "instruct" in model_name and "olmo" in model_name:
        return olmo_instruct_template(system_prompt=system_prompt, instruct_prompt=instruct_prompt, response_prompt=response_prompt, add_bos=add_bos)
    elif "instruct" in model_name and "mistral" in model_name:
        return mistral_instruct_template(system_prompt=system_prompt, instruct_prompt=instruct_prompt, response_prompt=response_prompt, add_bos=True)
    else:
        return f"{system_prompt}\n{instruct_prompt}\n{response_prompt}"  # non-instruct model, or model with an unknown template

def mistral_instruct_template(system_prompt, instruct_prompt, response_prompt, add_bos=True):
    """
    Convert the input and output into the template used for training the Mistral instruct models.
    """
    prefix = "<s>" if add_bos else ""
    return prefix + f"[INST] {system_prompt}\n{instruct_prompt} [/INST] {response_prompt}"

def llama_2_chat_template(system_prompt, instruct_prompt, response_prompt, add_bos=False):
    """
    Convert the input and output into the template used for training the Llama-2 chat models.
    """
    prefix = "<s>" if add_bos else ""
    return prefix + f"[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n{instruct_prompt} [/INST] {response_prompt.lstrip()}"  # most servers add <s> automatically, so we don't add it here

def llama_3_instruct_template(system_prompt, instruct_prompt, response_prompt, add_bos=False, add_knowledge_cut=False):
    """
    Convert the input and output into the template used for training the Llama-3 instruct models.
    """
    # print("applying llama-3 instruct template")
    prefix = "<|begin_of_text|>" if add_bos else ""
    if add_knowledge_cut:
        system_prompt = f"Cutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\n" + system_prompt
    return prefix + f"<|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{instruct_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{response_prompt}"

def gemma_instruct_template(system_prompt, instruct_prompt, response_prompt, add_bos=False):
    """
    Convert the input and output into the template used for training the Gemma instruct models:
    <bos><start_of_turn>user
    Write a hello world program<end_of_turn>
    <start_of_turn>model
    """
    prefix = "<bos>" if add_bos else ""
    return prefix + f"<start_of_turn>user\n{system_prompt}\n{instruct_prompt}<end_of_turn>\n<start_of_turn>model\n{response_prompt}"

def olmo_instruct_template(system_prompt, instruct_prompt, response_prompt, add_bos=False):
    """
    Convert the input and output into the template used for training the OLMo instruct models.
    """
    return f"<|endoftext|><|user|>\n{system_prompt}\n{instruct_prompt}\n<|assistant|>\n{response_prompt}"

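# --- Illustrative example (editor's sketch, not part of the committed file) ---
# apply_instruct_template("meta-llama/Llama-2-13b-chat-hf", "Be concise.", "Question: 2+2?", "Answer: ")
#   -> '[INST] <<SYS>>\nBe concise.\n<</SYS>>\n\nQuestion: 2+2? [/INST] Answer: '
# apply_instruct_template("meta-llama/Llama-2-70b-hf", "Be concise.", "Question: 2+2?", "Answer: ")
#   -> 'Be concise.\nQuestion: 2+2?\nAnswer: '   (base model: plain newline-joined prompt)
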
def find_longest_repeated_suffix(s):

    # Helper function to check whether the suffix of the given length is immediately repeated
    def has_repeated(s, length):
        if length < 30:
            return False
        # Extract the suffix of length 'length'
        suffix = s[-length:]
        # Check whether the rest of the string also ends with that suffix
        # return s[:-length].find(suffix) != -1
        return s[:-length].endswith(suffix)

    left, right = 0, len(s)
    result = 0

    # Binary search for the longest repeated suffix
    while left <= right:
        mid = (left + right) // 2
        if has_repeated(s, mid):
            result = mid  # Store the longest length found
            left = mid + 1  # Try for a longer suffix
        else:
            right = mid - 1  # Try for a shorter suffix

    # Return the longest repeated suffix
    if result > 0:
        return s[-result:]
    return None  # Return None if no repetition is found

def remove_redundant_repetitions(s):
    s = s.strip()
    # Find the longest repeated suffix
    longest_repeated_suffix = find_longest_repeated_suffix(s)
    while longest_repeated_suffix:
        # Remove the longest repeated suffix
        s = s[:-len(longest_repeated_suffix)]
        # Find the longest repeated suffix again
        longest_repeated_suffix = find_longest_repeated_suffix(s)
    return s

def repetition_check(new_completion, full_prefix, subseq_len=5):
    words = new_completion.split(" ")
    if len(words) > subseq_len and new_completion in full_prefix:
        return True
    return False

def convert_token_logprobs_to_top_logprobs(token_logprobs, tokens):
    """
    Together AI now only returns token logprobs; this function converts them to the top-logprobs format: {token: logprob}.
    """
    top_logprobs = [{token: logprob} for token, logprob in zip(tokens, token_logprobs)]
    return top_logprobs

def check_need_nudging(nudging_method,
                       base_token_id,
                       current_base_info,
                       thresholds,
                       ):
    if nudging_method == 'top_prob':
        # check whether the token's top probability is below the threshold
        sorted_base_top_logprobs = {k: v for k, v in sorted(current_base_info["top_logprobs"][base_token_id].items(), key=lambda item: item[1], reverse=True)}
        base_top_prob = np.exp(list(sorted_base_top_logprobs.values())[0])
        need_nudging = base_top_prob < thresholds['top_prob']
    else:
        raise ValueError(f"Unknown nudging method {nudging_method}")
    return need_nudging

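# --- Worked example (editor's sketch, not part of the committed file), with thresholds = {'top_prob': 0.4} ---
# top logprob of the base model's candidate token = -1.61 -> prob = exp(-1.61) ≈ 0.20 < 0.4  -> need_nudging = True
# top logprob = -0.11                                     -> prob = exp(-0.11) ≈ 0.90 >= 0.4 -> need_nudging = False
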
def complete_with_base(nudging_method='top_prob',
                       base_model="davinci-002",
                       full_prefix_base="",
                       output="",
                       current_base_info=None,
                       max_completion_token=256,
                       completion_token_num=16,
                       client_base=None,
                       thresholds=None,
                       temperature=0.0,
                       top_p=0.9,
                       ):
    completion_base = "" if len(current_base_info["completion"]) == 0 else current_base_info["tokens"][0]  # accept the first token of the 1st round, which is the token already accepted in stage 1
    completion_all = "" if len(current_base_info["completion"]) == 0 else current_base_info["tokens"][0]  # completion_all records all tokens from the base model, including those not accepted in the last round, for debugging and visualization
    found_nudging_token = False
    response = None
    has_acc_token_stage_1 = True if len(current_base_info["completion"]) > 0 else False  # if current_base_info["completion"] is not empty, the first token of the base completion was already accepted in stage 1
    EMPTY_INFO_DICT = {
        "completion": "",
        "tokens": [],
        "top_logprobs": [],
        "stop_reason": None,
        "num_logprobs": NUM_LOGPROBS[nudging_method],
    }
    next_nudging_info = EMPTY_INFO_DICT  # for nudging methods that compute nudging info during base completion we could save it for the next round; currently not used for top_prob nudging
    while len(encoding.encode(completion_base)) < max_completion_token and not found_nudging_token:

        if current_base_info["completion"] == "":
            # complete the sentence using the base model
            response = client_base.completions.create(
                model=base_model,
                prompt=full_prefix_base + output + completion_base,
                max_tokens=completion_token_num,
                temperature=temperature,
                logprobs=current_base_info["num_logprobs"],
                top_p=top_p,
            )
            current_base_info["tokens"] = response.choices[0].logprobs.tokens
            current_base_info["top_logprobs"] = response.choices[0].logprobs.top_logprobs
            if current_base_info["top_logprobs"] is None:
                current_base_info["top_logprobs"] = convert_token_logprobs_to_top_logprobs(response.choices[0].logprobs.token_logprobs, current_base_info["tokens"])
            current_base_info["completion"] = response.choices[0].text

        if has_acc_token_stage_1:
            # pop the first token of the 1st round, since it was already accepted in stage 1
            current_base_info["tokens"] = current_base_info["tokens"][1:]
            current_base_info["top_logprobs"] = current_base_info["top_logprobs"][1:]
            current_base_info["completion"] = "".join(current_base_info["tokens"])
            has_acc_token_stage_1 = False

        completion = current_base_info["completion"]
        tokens = current_base_info["tokens"]

        if completion in completion_base:
            break  # repeated completion, break

        nudging_position = -1

        # find the first token that meets the nudging criteria (i.e., needs nudging)
        for base_idx in range(len(tokens)):
            found_nudging_token = check_need_nudging(nudging_method=nudging_method, base_token_id=base_idx, current_base_info=current_base_info, thresholds=thresholds)
            if found_nudging_token:
                nudging_position = base_idx
                break

        if nudging_position == -1:
            new_completion = "".join(tokens)
        else:
            new_completion = "".join(tokens[:nudging_position])  # include the last agreed token
        # avoid repetition in the answer
        if repetition_check(new_completion, output + completion_base):
            break
        else:
            completion_base += new_completion

        if found_nudging_token:  # if the nudging token is found, break the loop and concat the last base completion to completion_all
            completion_all += completion
        else:
            completion_all += new_completion

        next_nudging_info = EMPTY_INFO_DICT
        if response is not None and response.choices[0].finish_reason == "stop":
            break

        # reset the current_base_info
        current_base_info['completion'] = ""
        current_base_info['tokens'] = []
        current_base_info['top_logprobs'] = []

    return completion_base, completion_all, next_nudging_info

def completion_with_nudging(
    base_model="davinci-002",
    nudging_model="gpt-3.5-turbo",
    system_prompt_base="Answer the question by walking through the reasoning step by step.",
    system_prompt_nudging="Answer the question by walking through the reasoning step by step.",
    question="",
    context="",
    question_prompt="Question: ",
    answer_start_prompt_base="Answer: ",
    answer_start_prompt_nudging="Answer: ",
    completion_token_num=16,
    completion_token_num_nudging=16,
    max_token_total=256,
    print_intermediate_output=False,
    client=None,  # default client
    client_base=None,
    client_nudging=None,
    max_round=150,
    nudging_temperature=0.0,  # deterministic for the nudging model
    base_temperature=0.0,  # deterministic for the base model
    nudging_method='top_prob',
    top_prob_thres=0.3,
    top_p=0.9,
):
    if client_base is None:
        client_base = client
    if client_nudging is None:
        client_nudging = client

    if nudging_method not in NUM_LOGPROBS.keys():
        raise ValueError(f"nudging method {nudging_method} number of logprobs not defined")

    full_prefix_base = apply_instruct_template(base_model, system_prompt_base, context + question_prompt + question, answer_start_prompt_base)  # for base models this just joins the parts with newlines
    full_prefix_nudging = apply_instruct_template(nudging_model, system_prompt_nudging, context + question_prompt + question, answer_start_prompt_nudging)

    thresholds = {
        'top_prob': top_prob_thres,
    }

    output = ""
    nudging_round = 0
    all_nudging_words = []
    all_nudging_and_completions = []
    current_nudging_info = {
        "completion": "",
        "tokens": [],
        "top_logprobs": [],
        "stop_reason": None,
        "num_logprobs": NUM_LOGPROBS[nudging_method],
    }
    stop_reason = None
    repeat_nudging_word = 0
    last_nudging_word = ""
    while len(encoding.encode(output)) < max_token_total and nudging_round < max_round:  # use the gpt-3.5 token count to approximately control the length
        nudging_round += 1
        if current_nudging_info["completion"] == "":
            response = client_nudging.completions.create(
                model=nudging_model,
                prompt=full_prefix_nudging + output,
                max_tokens=completion_token_num_nudging,
                temperature=nudging_temperature,
                logprobs=current_nudging_info["num_logprobs"],
            )
            current_nudging_info["completion"] = response.choices[0].text
            current_nudging_info["tokens"] = response.choices[0].logprobs.tokens
            current_nudging_info["top_logprobs"] = response.choices[0].logprobs.top_logprobs
            if current_nudging_info["top_logprobs"] is None:
                current_nudging_info["top_logprobs"] = convert_token_logprobs_to_top_logprobs(response.choices[0].logprobs.token_logprobs, current_nudging_info["tokens"])
            current_nudging_info["stop_reason"] = response.choices[0].finish_reason

        # if finish_reason is stop, break the loop; this also handles nudging completions carried over from the previous round
        if current_nudging_info["stop_reason"] == "stop":
            stop_reason = "nudging_model_stop"
            if len(current_nudging_info["completion"]) > 0:
                all_nudging_words.append(current_nudging_info["completion"])
                all_nudging_and_completions.append(current_nudging_info["completion"])
                output += current_nudging_info["completion"]
            break

        # ===================================================================
        # Stage 1: walk through the nudging words with the base model and find the first position
        # where the base model's own next token no longer needs nudging
        # ===================================================================
        found_acc_token = False
        current_base_info = {  # will be passed to the next stage
            "completion": "",
            "tokens": [],
            "top_logprobs": [],
            "num_logprobs": NUM_LOGPROBS[nudging_method],
        }
        nudging_text = current_nudging_info["completion"]
        num_whitespaces = len(nudging_text) - len(nudging_text.lstrip(" "))
        space_prefix = " " * num_whitespaces
        current_nudging_words = nudging_text.lstrip(" ").split(" ")  # working at the token level leads to some unexpected behaviors, so we still use nudging words
        nudging_word_id = 0 if len(current_nudging_words) > 1 else 1  # if there is only one word, always accept it and go to the next round: the loop below is skipped and found_acc_token stays False
        while not found_acc_token and nudging_word_id < len(current_nudging_words) - 1:
            nudging_word_id += 1  # always accept the first word
            nudging_gen_prefix = space_prefix + " ".join(current_nudging_words[:nudging_word_id])
            current_nudging_word = " " + current_nudging_words[nudging_word_id]  # add a leading space to the current nudging word, since the nudging words are split by spaces
            if current_nudging_word == " ":  # skip multiple spaces
                continue
            prefix = full_prefix_base + output + nudging_gen_prefix
            response = client_base.completions.create(
                model=base_model,
                prompt=prefix,
                max_tokens=completion_token_num,
                temperature=base_temperature,
                logprobs=current_base_info["num_logprobs"],
                top_p=top_p,
            )
            current_base_info["tokens"] = response.choices[0].logprobs.tokens
            current_base_info["top_logprobs"] = response.choices[0].logprobs.top_logprobs
            if current_base_info["top_logprobs"] is None:
                current_base_info["top_logprobs"] = convert_token_logprobs_to_top_logprobs(response.choices[0].logprobs.token_logprobs, current_base_info["tokens"])
            current_base_info["completion"] = response.choices[0].text

            # look at the base model's first token after the accepted prefix
            first_base_token = current_base_info["tokens"][0]
            if current_nudging_word.startswith(first_base_token):  # the current nudging word is the same as, or starts with, the first base token
                found_acc_token = True
            else:
                found_acc_token = not check_need_nudging(nudging_method,  # the base token is confident enough, so no nudging is needed here
                                                         base_token_id=0,
                                                         current_base_info=current_base_info,
                                                         thresholds=thresholds)

        # at this point either the loop exhausted the nudging words without accepting a base token (the whole nudging completion is used),
        # or found_acc_token == True (the base model takes over, and the words so far become the nudging tokens)

        nudging_words = space_prefix + " ".join(current_nudging_words[:nudging_word_id])

        # Heuristic: if the nudging words are the same as the last ones for three rounds, break the loop
        if nudging_words == last_nudging_word:
            repeat_nudging_word += 1
            if repeat_nudging_word >= 3:
                stop_reason = "repeated_nudging_words"
                break
        else:
            last_nudging_word = nudging_words
            repeat_nudging_word = 0
        all_nudging_words.append(nudging_words)
        output += nudging_words

        if not found_acc_token:  # if no base token can be accepted, use the current nudging completion and go to the next round
            all_nudging_and_completions.append(nudging_words)
            # reset the current nudging info and continue to the next round
            current_nudging_info = {
                "completion": "",
                "tokens": [],
                "logprobs": [],
                "stop_reason": None,
                "num_logprobs": NUM_LOGPROBS[nudging_method],
            }
            continue
        if current_base_info["completion"] == "":  # the base model thinks the completion is done, go to the next round; make sure current_base_info["completion"] is not empty before proceeding to the next stage
            all_nudging_and_completions.append(nudging_words)
            current_nudging_info = {
                "completion": "",
                "tokens": [],
                "logprobs": [],
                "stop_reason": None,
                "num_logprobs": NUM_LOGPROBS[nudging_method],
            }
            continue

        # ===================================================================
        # Stage 2: let the base model complete until the first token that meets the nudging criteria (needs nudging)
        # ===================================================================
        max_completion_token = max_token_total - len(encoding.encode(output))
        completion_base, completion_base_all, current_nudging_info = complete_with_base(nudging_method=nudging_method,
                                                                                        base_model=base_model,
                                                                                        full_prefix_base=full_prefix_base,
                                                                                        output=output,
                                                                                        current_base_info=current_base_info,
                                                                                        max_completion_token=max_completion_token,
                                                                                        completion_token_num=completion_token_num,
                                                                                        client_base=client_base,
                                                                                        thresholds=thresholds,
                                                                                        temperature=base_temperature,
                                                                                        top_p=top_p,
                                                                                        )
        # print(f"next_nudging_info: {current_nudging_info}")  # debug

        output += completion_base
        all_nudging_and_completions.append(nudging_words + completion_base)  # the tokens generated in each round; concatenating all completions gives the final output
        if print_intermediate_output:
            print(f"************nudging round {nudging_round}************")
            print(f"****nudging words from {nudging_model}****: {nudging_words}")
            print(f"****nudging text****: {nudging_text}")
            print(f"****completion from {base_model}****: {completion_base}")
            print(f"****all completion from {base_model}****: {completion_base_all}")
            print(f"****output****: {output}")

    if nudging_round >= max_round and not stop_reason:
        stop_reason = "round"
    if len(encoding.encode(output)) >= max_token_total and not stop_reason:
        stop_reason = "length"
    output = remove_redundant_repetitions(output)
    if print_intermediate_output:
        print(f"************final output************")
        print(f"****output****: {output}")

    all_info = {
        "question": question,
        "context": context,
        "raw_answer": output,
        "all_nudging_words": all_nudging_words,
        "all_completions": all_nudging_and_completions,
        "stop_reason": stop_reason,
        "system_prompt_base": system_prompt_base,
        "system_prompt_nudging": system_prompt_nudging,
        "full_prefix_base": full_prefix_base,
        "full_prefix_nudging": full_prefix_nudging,
    }
    return all_info


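# --- Hypothetical round (editor's sketch, not part of the committed file), with top_prob_thres = 0.4 ---
# 1. The nudging model proposes " Sure, let's count the students who were not called back."
# 2. Stage 1 walks through these words until the base model's own next token is confident enough
#    (or matches the next nudging word); suppose that happens after " Sure,", which becomes the nudging words.
# 3. Stage 2 lets the base model continue from "... Sure," until some token's top probability drops
#    below 0.4; that partial completion is appended to the output and the next round begins.
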
def get_nudging_answer(base_model,
                       nudging_model,
                       system_prompt,
                       question,
                       context="",
                       question_prompt="",
                       answer_start_prompt_base="",
                       answer_start_prompt_nudging="",
                       completion_token_num=16,
                       completion_token_num_nudging=16,
                       max_token_total=256,
                       max_round=150,
                       nudging_temperature=0.0,
                       base_temperature=0.0,
                       nudging_method='top_prob',
                       top_prob_thres=0.3,
                       ):
    base_model = MODEL_MAPPING[base_model]
    nudging_model = MODEL_MAPPING[nudging_model]
    # with open('TOGETHER_KEY.txt', 'r') as f:
    #     togetherai_api_key = f.read().strip()
    togetherai_api_key = os.environ.get("TOGETHERAI_API_KEY")
    client = OpenAI(
        api_key=togetherai_api_key,
        base_url="https://api.together.xyz/v1",
    )
    return completion_with_nudging(
        base_model=base_model,
        nudging_model=nudging_model,
        system_prompt_base=system_prompt,
        system_prompt_nudging=system_prompt,
        question=question,
        context=context,
        question_prompt=question_prompt,
        answer_start_prompt_base=answer_start_prompt_base,
        answer_start_prompt_nudging=answer_start_prompt_nudging,
        completion_token_num=completion_token_num,
        completion_token_num_nudging=completion_token_num_nudging,
        max_token_total=max_token_total,
        print_intermediate_output=False,
        client_base=client,
        client_nudging=client,
        max_round=max_round,
        nudging_temperature=nudging_temperature,
        base_temperature=base_temperature,
        nudging_method=nudging_method,
        top_prob_thres=top_prob_thres,
    )

def get_base_answer(base_model,
                    system_prompt,
                    question,
                    max_tokens=256,):
    base_model = MODEL_MAPPING[base_model]
    # with open('TOGETHER_KEY.txt', 'r') as f:
    #     togetherai_api_key = f.read().strip()
    togetherai_api_key = os.environ.get("TOGETHERAI_API_KEY")
    client = OpenAI(
        api_key=togetherai_api_key,
        base_url="https://api.together.xyz/v1",
    )
    response = client.completions.create(
        model=base_model,
        prompt=system_prompt + "\n" + question,
        max_tokens=max_tokens,
        temperature=0.0,
        logprobs=1,
    )
    return response.choices[0].text
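For completeness, a minimal usage sketch of the two entry points above, outside the Gradio UI (editor's addition; it assumes TOGETHERAI_API_KEY is set in the environment and uses the display names defined in constant.py):

    from utils import get_base_answer, get_nudging_answer

    question = "Question: What is 12 * 7?"
    system_prompt = "Answer the question by walking through the reasoning steps."

    # Plain base-model completion.
    base_answer = get_base_answer(base_model="Llama-2-70B", system_prompt=system_prompt,
                                  question=question, max_tokens=128)

    # Nudged completion: Llama-2-13B-chat injects nudging tokens whenever the base model is uncertain.
    all_info = get_nudging_answer(
        base_model="Llama-2-70B",
        nudging_model="Llama-2-13B-chat",
        system_prompt=system_prompt,
        question=question,
        max_token_total=128,
        top_prob_thres=0.4,
    )
    print(base_answer)
    print(all_info["raw_answer"])         # final nudged answer
    print(all_info["all_nudging_words"])  # nudging tokens injected in each round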