fywalter committed
Commit
d07b421
1 Parent(s): 9504acd

initial version

Files changed (3)
  1. app.py +94 -5
  2. constant.py +74 -0
  3. utils.py +522 -0
app.py CHANGED
@@ -1,7 +1,96 @@
- import gradio as gr

- def greet(name):
-     return "Hello " + name + "!!"

- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
- demo.launch()
+ import gradio as gr
+ from typing import List
+ from utils import get_base_answer, get_nudging_answer
+ from constant import js_code_label, custom_css, HEADER_MD, BASE_MODELS, NUDGING_MODELS
+ import datetime
+
+ addr_limit_counter = {}
+ LAST_UPDATE_TIME = datetime.datetime.now()
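+ # NOTE: addr_limit_counter and LAST_UPDATE_TIME are not used anywhere in this commit;
+ # they appear to be placeholders for per-address rate limiting.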
+
+ base_models = BASE_MODELS
+ nudging_models = NUDGING_MODELS
+
+ def respond_base(
+     system_prompt: str,
+     message: str,
+     max_tokens: int,
+     base_model: str,
+ ):
+     return [(message, get_base_answer(base_model=base_model, system_prompt=system_prompt, question=message, max_tokens=max_tokens))]
+
+ def respond_nudging(
+     system_prompt: str,
+     message: str,
+     max_tokens: int,
+     nudging_thres: float,
+     base_model: str,
+     nudging_model: str,
+     request: gr.Request,
+ ):
+     all_info = get_nudging_answer(base_model=base_model, nudging_model=nudging_model, system_prompt=system_prompt, question=message, max_token_total=max_tokens, top_prob_thres=nudging_thres)
+     all_completions = all_info["all_completions"]
+     nudging_words = all_info["all_nudging_words"]
+     formatted_response = format_response(all_completions, nudging_words)
+     return [(message, formatted_response)]
+
+ def clear_fn():
+     return None, None, None
+
+ def format_response(all_completions, nudging_words):
+     html_code = ""
+     for all_completion, nudging_word in zip(all_completions, nudging_words):
+         # each all_completion = nudging_word + base_completion
+         base_completion = all_completion[len(nudging_word):]
+         html_code += f"<mark>{nudging_word}</mark>{base_completion}"
+     return html_code
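+ # The <mark> tags emitted above are styled by custom_css (constant.py) so that the
+ # nudging words are highlighted inside the Chatbot output.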
+
+ with gr.Blocks(gr.themes.Soft(), js=js_code_label, css=custom_css) as demo:
+     api_key = gr.Textbox(label="🔑 API Key", placeholder="Enter your Together/Hyperbolic API Key. Leave it blank to use our key with limited usage.", type="password", elem_id="api_key", visible=False)
+
+     gr.Markdown(HEADER_MD)
+
+     with gr.Row():
+         chat_a = gr.Chatbot(height=500, label="Nudging Answer", elem_id="chatbot")
+         chat_b = gr.Chatbot(height=500, label="Base Answer")
+
+     with gr.Group():
+         with gr.Row():
+             with gr.Column(scale=1.5):
+                 system_prompt = gr.Textbox(label="System Prompt", placeholder="Enter your system prompt here")
+                 message = gr.Textbox(label="Prompt", placeholder="Enter your message here")
+         with gr.Row():
+             with gr.Column(scale=2):
+                 with gr.Row():
+                     base_model_choice = gr.Dropdown(label="Base Model", choices=base_models, interactive=True)
+                     nudging_model_choice = gr.Dropdown(label="Nudging Model", choices=nudging_models, interactive=True)
+                 with gr.Accordion("Nudging Parameters", open=True):
+                     with gr.Row():
+                         max_tokens = gr.Slider(label="Max tokens", value=256, minimum=0, maximum=512, step=16, interactive=True, visible=True)
+                         nudging_thres = gr.Slider(label="Nudging Threshold", step=0.1, minimum=0.1, maximum=0.9, value=0.4)
+         with gr.Row():
+             btn = gr.Button("Generate")
+         with gr.Row():
+             stop_btn = gr.Button("Stop")
+             clear_btn = gr.Button("Clear")
+
+     base_model_choice.value = "Llama-2-70B"
+     nudging_model_choice.value = "Llama-2-13B-chat"
+     system_prompt.value = "Answer the question by walking through the reasoning steps."
+     message.value = "Question: There were 39 girls and 4 boys trying out for the school's basketball team. If only 26 of them got called back, how many students didn't make the cut?"
+
+     model_type_left = gr.Textbox(visible=False, value="base")
+     model_type_right = gr.Textbox(visible=False, value="aligned")
+
+     go1 = btn.click(respond_nudging, [system_prompt, message, max_tokens, nudging_thres, base_model_choice, nudging_model_choice], chat_a)
+     go2 = btn.click(respond_base, [system_prompt, message, max_tokens, base_model_choice], chat_b)
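+     # Both click handlers fire on the same button press, so the nudging answer and the
+     # plain base answer are generated side by side for comparison.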
91
+
92
+ stop_btn.click(None, None, None, cancels=[go1, go2])
93
+ clear_btn.click(clear_fn, None, [message, chat_a, chat_b])
94
+
95
+ if __name__ == "__main__":
96
+ demo.launch(show_api=False)
constant.py ADDED
@@ -0,0 +1,74 @@
+ HEADER_MD = """# Inference-time Alignment with Nudging
+ [📑 Paper](https://arxiv.org/abs/2410.09300) | [🛜 Website](https://fywalter.github.io/nudging/) | [💻 GitHub](https://github.com/fywalter/nudging) | [🐦 X](https://x.com/Walter_Fei/status/1848538273917898753) | 📮 Contact: [Yu Fei](https://fywalter.github.io/)
+
+ **By injecting a few nudging tokens at inference time, we can make base models follow user instructions helpfully and safely.**
+ - Our demo is powered by the [Together AI API](https://api.together.ai/). Since only a few base models are still available through the serverless API, we demonstrate with three base models and three nudging models.
+ """
+
+ js_code_label = """
+ function addApiKeyLink() {
+     // Select the div with id 'api_key'
+     const apiKeyDiv = document.getElementById('api_key');
+     // Find the span within that div with data-testid 'block-info'
+     const blockInfoSpan = apiKeyDiv.querySelector('span[data-testid="block-info"]');
+     // Create the new link element
+     const newLink = document.createElement('a');
+     newLink.href = 'https://api.together.ai/settings/api-keys';
+     newLink.textContent = ' View your keys here.';
+     newLink.target = '_blank'; // Open link in new tab
+     newLink.style = 'color: #007bff; text-decoration: underline;';
+     // Create the additional text
+     const additionalText = document.createTextNode(' (new accounts come with free credits.)');
+     // Append a line break, the link, and the additional text to the api_key div
+     if (blockInfoSpan) {
+         apiKeyDiv.appendChild(document.createElement('br'));
+         apiKeyDiv.appendChild(newLink);
+         apiKeyDiv.appendChild(additionalText);
+     } else {
+         console.error('Span with data-testid "block-info" not found');
+     }
+ }
+ """
+
+ BASE_MODELS = [
+     "Llama-2-70B",
+     "Mistral-7B-v0.1",
+     "Mixtral-8x7B-v0.1",
+ ]
+
+ NUDGING_MODELS = [
+     "Llama-2-13B-chat",
+     "Gemma-2-2B-it",
+     "Mistral-7B-v0.1-Instruct",
+ ]
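+ # NOTE: these display names must match the keys of MODEL_MAPPING in utils.py,
+ # which resolves them to Together AI model ids.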
+
+ my_css = """
+ /* Link color that stays visible on both dark and light backgrounds */
+ a {
+     color: #1E90FF; /* DodgerBlue */
+     text-decoration: none; /* Optional: remove underline */
+ }
+ a:hover {
+     color: #104E8B; /* Slightly darker blue for hover effect */
+     text-decoration: underline; /* Optional: add underline on hover */
+ }
+ """
+
+ # Custom CSS for highlighting nudging words in the Chatbot
+ custom_css = """
+ .chatbot mark {
+     background-color: yellow;
+     color: orange;
+     font-style: italic;
+     font-weight: bold;
+ }
+ """
utils.py ADDED
@@ -0,0 +1,522 @@
+ import numpy as np
+ from openai import OpenAI
+ import os
+ import tiktoken
+
+ # use the gpt-3.5 tokenizer to approximate token counts, so we don't need to load the
+ # actual tokenizer for each API model
+ encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
+
+ NUM_LOGPROBS = {
+     'top_prob': 1,
+ }
+
+ MODEL_MAPPING = {
+     "Llama-2-70B": "meta-llama/Llama-2-70b-hf",
+     "Mistral-7B-v0.1": "mistralai/Mistral-7B-v0.1",
+     "Mixtral-8x7B-v0.1": "mistralai/Mixtral-8x7B-v0.1",
+     # Nudging models below
+     "Mistral-7B-v0.1-Instruct": "mistralai/Mistral-7B-Instruct-v0.1",
+     "Llama-2-13B-chat": "meta-llama/Llama-2-13b-chat-hf",
+     "Gemma-2-2B-it": "google/gemma-2b-it",
+ }
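+ # MODEL_MAPPING keys are the display names used by the Gradio dropdowns
+ # (BASE_MODELS / NUDGING_MODELS in constant.py); values are Together AI model ids.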
+
+ def apply_instruct_template(model_name, system_prompt, instruct_prompt, response_prompt, add_bos=False):
+     model_name = model_name.lower()
+     if "chat" in model_name and "llama" in model_name and "2" in model_name:
+         return llama_2_chat_template(system_prompt=system_prompt, instruct_prompt=instruct_prompt, response_prompt=response_prompt, add_bos=add_bos)
+     elif "instruct" in model_name and "llama" in model_name and "3" in model_name:
+         if "3.1" in model_name:  # for llama-3.1 models, add the knowledge cutoff to the system prompt
+             return llama_3_instruct_template(system_prompt=system_prompt, instruct_prompt=instruct_prompt, response_prompt=response_prompt, add_bos=add_bos, add_knowledge_cut=True)
+         else:
+             return llama_3_instruct_template(system_prompt=system_prompt, instruct_prompt=instruct_prompt, response_prompt=response_prompt, add_bos=add_bos)
+     elif "it" in model_name and "gemma" in model_name:
+         return gemma_instruct_template(system_prompt=system_prompt, instruct_prompt=instruct_prompt, response_prompt=response_prompt, add_bos=add_bos)
+     elif "instruct" in model_name and "olmo" in model_name:
+         return olmo_instruct_template(system_prompt=system_prompt, instruct_prompt=instruct_prompt, response_prompt=response_prompt, add_bos=add_bos)
+     elif "instruct" in model_name and "mistral" in model_name:
+         return mistral_instruct_template(system_prompt=system_prompt, instruct_prompt=instruct_prompt, response_prompt=response_prompt, add_bos=True)
+     else:
+         # non-instruct model, or a model with an unknown template
+         return f"{system_prompt}\n{instruct_prompt}\n{response_prompt}"
+
+ def mistral_instruct_template(system_prompt, instruct_prompt, response_prompt, add_bos=True):
+     """
+     Convert the input and output into the template used to train the Mistral instruct models.
+     """
+     prefix = "<s>" if add_bos else ""
+     return prefix + f"[INST] {system_prompt}\n{instruct_prompt} [/INST] {response_prompt}"
+
+ def llama_2_chat_template(system_prompt, instruct_prompt, response_prompt, add_bos=False):
+     """
+     Convert the input and output into the template used to train the Llama-2 chat models.
+     """
+     prefix = "<s>" if add_bos else ""
+     # most servers add <s> automatically, so we don't add it here by default
+     return prefix + f"[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n{instruct_prompt} [/INST] {response_prompt.lstrip()}"
+
+ def llama_3_instruct_template(system_prompt, instruct_prompt, response_prompt, add_bos=False, add_knowledge_cut=False):
+     """
+     Convert the input and output into the template used to train the Llama-3 instruct models.
+     """
+     prefix = "<|begin_of_text|>" if add_bos else ""
+     if add_knowledge_cut:
+         system_prompt = "Cutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\n" + system_prompt
+     return prefix + f"<|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{instruct_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{response_prompt}"
+
+ def gemma_instruct_template(system_prompt, instruct_prompt, response_prompt, add_bos=False):
+     """
+     Convert the input and output into the template used to train the Gemma instruct models, e.g.:
+     <bos><start_of_turn>user
+     Write a hello world program<end_of_turn>
+     <start_of_turn>model
+     """
+     prefix = "<bos>" if add_bos else ""
+     return prefix + f"<start_of_turn>user\n{system_prompt}\n{instruct_prompt}<end_of_turn>\n<start_of_turn>model\n{response_prompt}"
+
+ def olmo_instruct_template(system_prompt, instruct_prompt, response_prompt, add_bos=False):
+     """
+     Convert the input and output into the template used to train the OLMo instruct models.
+     """
+     return f"<|endoftext|><|user|>\n{system_prompt}\n{instruct_prompt}\n<|assistant|>\n{response_prompt}"
+
+ def find_longest_repeated_suffix(s):
+     # Helper: check whether the length-`length` suffix of s immediately repeats
+     def has_repeated(s, length):
+         if length < 30:
+             return False
+         # Extract the suffix of length 'length'
+         suffix = s[-length:]
+         # Check whether the text just before it is another copy of the suffix
+         return s[:-length].endswith(suffix)
+
+     left, right = 0, len(s)
+     result = 0
+
+     # Binary search for the longest repeated suffix
+     while left <= right:
+         mid = (left + right) // 2
+         if has_repeated(s, mid):
+             result = mid  # Store the longest length found
+             left = mid + 1  # Try a longer suffix
+         else:
+             right = mid - 1  # Try a shorter suffix
+
+     # Return the longest repeated suffix, or None if no repetition was found
+     if result > 0:
+         return s[-result:]
+     return None
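+ # Caveat: has_repeated is not strictly monotone in `length` (a doubled suffix of length L
+ # does not guarantee one of length L-1), so the binary search is a fast heuristic rather
+ # than an exact search for the longest repeated suffix.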
+
+ def remove_redundant_repetitions(s):
+     s = s.strip()
+     # Repeatedly strip the longest repeated suffix until none remains
+     longest_repeated_suffix = find_longest_repeated_suffix(s)
+     while longest_repeated_suffix:
+         s = s[:-len(longest_repeated_suffix)]
+         longest_repeated_suffix = find_longest_repeated_suffix(s)
+     return s
+
+ def repetition_check(new_completion, full_prefix, subseq_len=5):
+     words = new_completion.split(" ")
+     if len(words) > subseq_len and new_completion in full_prefix:
+         return True
+     return False
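+ # repetition_check flags a new completion of more than subseq_len words that already
+ # appears verbatim in the generated prefix, which signals the model is looping.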
+
+ def convert_token_logprobs_to_top_logprobs(token_logprobs, tokens):
+     """
+     Together AI now returns only per-token logprobs; this function converts them to the
+     top-logprobs format: [{token: logprob}, ...]
+     """
+     top_logprobs = [{token: logprob} for token, logprob in zip(tokens, token_logprobs)]
+     return top_logprobs
+
+ def check_need_nudging(nudging_method,
+                        base_token_id,
+                        current_base_info,
+                        thresholds,
+                        ):
+     if nudging_method == 'top_prob':
+         # check if the top token probability is below the threshold
+         sorted_base_top_logprobs = {k: v for k, v in sorted(current_base_info["top_logprobs"][base_token_id].items(), key=lambda item: item[1], reverse=True)}
+         base_top_prob = np.exp(list(sorted_base_top_logprobs.values())[0])
+         need_nudging = base_top_prob < thresholds['top_prob']
+     else:
+         raise ValueError(f"Unknown nudging method {nudging_method}")
+     return need_nudging
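+ # The 'top_prob' criterion treats a low top-token probability as base-model uncertainty:
+ # when the most likely next token falls below the threshold, control is handed back to
+ # the nudging model.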
+
+ def complete_with_base(nudging_method='top_prob',
+                        base_model="davinci-002",
+                        full_prefix_base="",
+                        output="",
+                        current_base_info=None,
+                        max_completion_token=256,
+                        completion_token_num=16,
+                        client_base=None,
+                        thresholds=None,
+                        temperature=0.0,
+                        top_p=0.9,
+                        ):
+     # accept the first token of the first round, which is the token already accepted in stage 1
+     completion_base = "" if len(current_base_info["completion"]) == 0 else current_base_info["tokens"][0]
+     # completion_all records all base-model tokens, including ones not accepted in the last round,
+     # for debugging and visualization
+     completion_all = "" if len(current_base_info["completion"]) == 0 else current_base_info["tokens"][0]
+     found_nudging_token = False
+     response = None
+     # a non-empty current_base_info["completion"] means its first token was already accepted in stage 1
+     has_acc_token_stage_1 = len(current_base_info["completion"]) > 0
+     EMPTY_INFO_DICT = {
+         "completion": "",
+         "tokens": [],
+         "top_logprobs": [],
+         "stop_reason": None,
+         "num_logprobs": NUM_LOGPROBS[nudging_method],
+     }
+     # for nudging methods that compute nudging info during base completion we could save it for
+     # the next round; currently unused for top_prob nudging
+     next_nudging_info = EMPTY_INFO_DICT
+     while len(encoding.encode(completion_base)) < max_completion_token and not found_nudging_token:
+
+         if current_base_info["completion"] == "":
+             # complete the sentence using the base model
+             response = client_base.completions.create(
+                 model=base_model,
+                 prompt=full_prefix_base + output + completion_base,
+                 max_tokens=completion_token_num,
+                 temperature=temperature,
+                 logprobs=current_base_info["num_logprobs"],
+                 top_p=top_p,
+             )
+             current_base_info["tokens"] = response.choices[0].logprobs.tokens
+             current_base_info["top_logprobs"] = response.choices[0].logprobs.top_logprobs
+             if current_base_info["top_logprobs"] is None:
+                 current_base_info["top_logprobs"] = convert_token_logprobs_to_top_logprobs(response.choices[0].logprobs.token_logprobs, current_base_info["tokens"])
+             current_base_info["completion"] = response.choices[0].text
+
+         if has_acc_token_stage_1:
+             # pop the first token of the first round, as it was already accepted in stage 1
+             current_base_info["tokens"] = current_base_info["tokens"][1:]
+             current_base_info["top_logprobs"] = current_base_info["top_logprobs"][1:]
+             current_base_info["completion"] = "".join(current_base_info["tokens"])
+             has_acc_token_stage_1 = False
+
+         completion = current_base_info["completion"]
+         tokens = current_base_info["tokens"]
+
+         if completion in completion_base:
+             break  # repeated completion, break
+
+         nudging_position = -1
+
+         # find the first token that needs nudging
+         for base_idx in range(len(tokens)):
+             found_nudging_token = check_need_nudging(nudging_method=nudging_method, base_token_id=base_idx, current_base_info=current_base_info, thresholds=thresholds)
+             if found_nudging_token:
+                 nudging_position = base_idx
+                 break
+
+         if nudging_position == -1:
+             new_completion = "".join(tokens)
+         else:
+             new_completion = "".join(tokens[:nudging_position])  # up to, but excluding, the nudging token
+         # avoid repetition in the answer
+         if repetition_check(new_completion, output + completion_base):
+             break
+         else:
+             completion_base += new_completion
+
+         if found_nudging_token:
+             # found the nudging token: record the full last base completion in completion_all, then exit
+             completion_all += completion
+         else:
+             completion_all += new_completion
+
+         next_nudging_info = EMPTY_INFO_DICT
+         if response is not None and response.choices[0].finish_reason == "stop":
+             break
+
+         # reset current_base_info for the next round
+         current_base_info['completion'] = ""
+         current_base_info['tokens'] = []
+         current_base_info['top_logprobs'] = []
+
+     return completion_base, completion_all, next_nudging_info
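+ # complete_with_base keeps requesting short base-model continuations, accepting tokens
+ # until one falls below the confidence threshold (the next nudging point), the model
+ # stops, a repetition is detected, or the token budget runs out.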
+
+ def completion_with_nudging(
+     base_model="davinci-002",
+     nudging_model="gpt-3.5-turbo",
+     system_prompt_base="Answer the question by walking through the reasoning step by step.",
+     system_prompt_nudging="Answer the question by walking through the reasoning step by step.",
+     question="",
+     context="",
+     question_prompt="Question: ",
+     answer_start_prompt_base="Answer: ",
+     answer_start_prompt_nudging="Answer: ",
+     completion_token_num=16,
+     completion_token_num_nudging=16,
+     max_token_total=256,
+     print_intermediate_output=False,
+     client=None,  # default client
+     client_base=None,
+     client_nudging=None,
+     max_round=150,
+     nudging_temperature=0.0,  # deterministic for the nudging model
+     base_temperature=0.0,  # deterministic for the base model
+     nudging_method='top_prob',
+     top_prob_thres=0.3,
+     top_p=0.9,
+ ):
+     if client_base is None:
+         client_base = client
+     if client_nudging is None:
+         client_nudging = client
+
+     if nudging_method not in NUM_LOGPROBS.keys():
+         raise ValueError(f"nudging method {nudging_method} number of logprobs not defined")
+
+     # for base models apply_instruct_template just joins the prompts with newlines
+     full_prefix_base = apply_instruct_template(base_model, system_prompt_base, context + question_prompt + question, answer_start_prompt_base)
+     full_prefix_nudging = apply_instruct_template(nudging_model, system_prompt_nudging, context + question_prompt + question, answer_start_prompt_nudging)
+
+     thresholds = {
+         'top_prob': top_prob_thres,
+     }
+
+     output = ""
+     nudging_round = 0
+     all_nudging_words = []
+     all_nudging_and_completions = []
+     current_nudging_info = {
+         "completion": "",
+         "tokens": [],
+         "top_logprobs": [],
+         "stop_reason": None,
+         "num_logprobs": NUM_LOGPROBS[nudging_method],
+     }
+     stop_reason = None
+     repeat_nudging_word = 0
+     last_nudging_word = ""
+     # use the gpt-3.5 token count to approximately control the output length
+     while len(encoding.encode(output)) < max_token_total and nudging_round < max_round:
+         nudging_round += 1
+         if current_nudging_info["completion"] == "":
+             response = client_nudging.completions.create(
+                 model=nudging_model,
+                 prompt=full_prefix_nudging + output,
+                 max_tokens=completion_token_num_nudging,
+                 temperature=nudging_temperature,
+                 logprobs=current_nudging_info["num_logprobs"],
+             )
+             current_nudging_info["completion"] = response.choices[0].text
+             current_nudging_info["tokens"] = response.choices[0].logprobs.tokens
+             current_nudging_info["top_logprobs"] = response.choices[0].logprobs.top_logprobs
+             if current_nudging_info["top_logprobs"] is None:
+                 current_nudging_info["top_logprobs"] = convert_token_logprobs_to_top_logprobs(response.choices[0].logprobs.token_logprobs, current_nudging_info["tokens"])
+             current_nudging_info["stop_reason"] = response.choices[0].finish_reason
+
+         # if finish_reason is "stop", break the loop; this also handles nudging completions
+         # carried over from the previous round
+         if current_nudging_info["stop_reason"] == "stop":
+             stop_reason = "nudging_model_stop"
+             if len(current_nudging_info["completion"]) > 0:
+                 all_nudging_words.append(current_nudging_info["completion"])
+                 all_nudging_and_completions.append(current_nudging_info["completion"])
+                 output += current_nudging_info["completion"]
+             break
+
+         # ===================================================================
+         # Stage 1: walk through the nudging words with the base model to find the first
+         # position where the base model can take over (no need to nudge)
+         # ===================================================================
+         found_acc_token = False
+         current_base_info = {  # will be passed to the next stage
+             "completion": "",
+             "tokens": [],
+             "top_logprobs": [],
+             "num_logprobs": NUM_LOGPROBS[nudging_method],
+         }
+         nudging_text = current_nudging_info["completion"]
+         num_whitespaces = len(nudging_text) - len(nudging_text.lstrip(" "))
+         space_prefix = " " * num_whitespaces
+         # splitting by tokens leads to unexpected behavior, so we work with whitespace-separated words
+         current_nudging_words = nudging_text.lstrip(" ").split(" ")
+         # if there is only one word, always accept it and go to the next round: the loop below
+         # is skipped and found_acc_token stays False
+         nudging_word_id = 0 if len(current_nudging_words) > 1 else 1
+         while not found_acc_token and nudging_word_id < len(current_nudging_words) - 1:
+             nudging_word_id += 1  # always accept the first word
+             nudging_gen_prefix = space_prefix + " ".join(current_nudging_words[:nudging_word_id])
+             # add a leading space to the current nudging word, since the nudging words are split by spaces
+             current_nudging_word = " " + current_nudging_words[nudging_word_id]
+             if current_nudging_word == " ":  # skip consecutive spaces
+                 continue
+             prefix = full_prefix_base + output + nudging_gen_prefix
+             response = client_base.completions.create(
+                 model=base_model,
+                 prompt=prefix,
+                 max_tokens=completion_token_num,
+                 temperature=base_temperature,
+                 logprobs=current_base_info["num_logprobs"],
+                 top_p=top_p,
+             )
+             current_base_info["tokens"] = response.choices[0].logprobs.tokens
+             current_base_info["top_logprobs"] = response.choices[0].logprobs.top_logprobs
+             if current_base_info["top_logprobs"] is None:
+                 current_base_info["top_logprobs"] = convert_token_logprobs_to_top_logprobs(response.choices[0].logprobs.token_logprobs, current_base_info["tokens"])
+             current_base_info["completion"] = response.choices[0].text
+
+             # check whether the base model's first token can be accepted
+             first_base_token = current_base_info["tokens"][0]
+             if current_nudging_word.startswith(first_base_token):
+                 # the base model agrees with the current nudging word
+                 found_acc_token = True
+             else:
+                 # accept the token if it does not need nudging (i.e. the base model is confident)
+                 found_acc_token = not check_need_nudging(nudging_method,
+                                                          base_token_id=0,
+                                                          current_base_info=current_base_info,
+                                                          thresholds=thresholds)
+
+         # here either no base token was accepted (found_acc_token is False), so we keep the whole
+         # nudging completion, or found_acc_token is True and the accepted word prefix becomes the
+         # nudging tokens for this round
+
+         nudging_words = space_prefix + " ".join(current_nudging_words[:nudging_word_id])
+
+         # Heuristic: if the nudging words repeat three rounds in a row, break the loop
+         if nudging_words == last_nudging_word:
+             repeat_nudging_word += 1
+             if repeat_nudging_word >= 3:
+                 stop_reason = "repeated_nudging_words"
+                 break
+         else:
+             last_nudging_word = nudging_words
+             repeat_nudging_word = 0
+         all_nudging_words.append(nudging_words)
+         output += nudging_words
+
+         if not found_acc_token:
+             # no base token can be accepted: keep the nudging completion and go to the next round
+             all_nudging_and_completions.append(nudging_words)
+             current_nudging_info = {
+                 "completion": "",
+                 "tokens": [],
+                 "top_logprobs": [],
+                 "stop_reason": None,
+                 "num_logprobs": NUM_LOGPROBS[nudging_method],
+             }
+             continue
+         if current_base_info["completion"] == "":
+             # the base model thinks the completion is done; go to the next round so that
+             # current_base_info["completion"] is never empty when we enter stage 2
+             all_nudging_and_completions.append(nudging_words)
+             current_nudging_info = {
+                 "completion": "",
+                 "tokens": [],
+                 "top_logprobs": [],
+                 "stop_reason": None,
+                 "num_logprobs": NUM_LOGPROBS[nudging_method],
+             }
+             continue
+
+         # ===================================================================
+         # Stage 2: let the base model complete until a token meets the nudging criteria (needs nudging)
+         # ===================================================================
+         max_completion_token = max_token_total - len(encoding.encode(output))
+         completion_base, completion_base_all, current_nudging_info = complete_with_base(nudging_method=nudging_method,
+                                                                                         base_model=base_model,
+                                                                                         full_prefix_base=full_prefix_base,
+                                                                                         output=output,
+                                                                                         current_base_info=current_base_info,
+                                                                                         max_completion_token=max_completion_token,
+                                                                                         completion_token_num=completion_token_num,
+                                                                                         client_base=client_base,
+                                                                                         thresholds=thresholds,
+                                                                                         temperature=base_temperature,
+                                                                                         top_p=top_p,
+                                                                                         )
+
+         output += completion_base
+         # the tokens generated in each round; concatenating all completions yields the final output
+         all_nudging_and_completions.append(nudging_words + completion_base)
+         if print_intermediate_output:
+             print(f"************nudging round {nudging_round}************")
+             print(f"****nudging words from {nudging_model}****: {nudging_words}")
+             print(f"****nudging text****: {nudging_text}")
+             print(f"****completion from {base_model}****: {completion_base}")
+             print(f"****all completion from {base_model}****: {completion_base_all}")
+             print(f"****output****: {output}")
+
+     if nudging_round >= max_round and not stop_reason:
+         stop_reason = "round"
+     if len(encoding.encode(output)) >= max_token_total and not stop_reason:
+         stop_reason = "length"
+     output = remove_redundant_repetitions(output)
+     if print_intermediate_output:
+         print("************final output************")
+         print(f"****output****: {output}")
+
+     all_info = {
+         "question": question,
+         "context": context,
+         "raw_answer": output,
+         "all_nudging_words": all_nudging_words,
+         "all_completions": all_nudging_and_completions,
+         "stop_reason": stop_reason,
+         "system_prompt_base": system_prompt_base,
+         "system_prompt_nudging": system_prompt_nudging,
+         "full_prefix_base": full_prefix_base,
+         "full_prefix_nudging": full_prefix_nudging,
+     }
+     return all_info
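+ # all_info carries everything the demo needs: app.py renders "all_completions" with
+ # "all_nudging_words" highlighted, and "stop_reason" records why generation ended.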
+
+
+ def get_nudging_answer(base_model,
+                        nudging_model,
+                        system_prompt,
+                        question,
+                        context="",
+                        question_prompt="",
+                        answer_start_prompt_base="",
+                        answer_start_prompt_nudging="",
+                        completion_token_num=16,
+                        completion_token_num_nudging=16,
+                        max_token_total=256,
+                        max_round=150,
+                        nudging_temperature=0.0,
+                        base_temperature=0.0,
+                        nudging_method='top_prob',
+                        top_prob_thres=0.3,
+                        ):
+     base_model = MODEL_MAPPING[base_model]
+     nudging_model = MODEL_MAPPING[nudging_model]
+     togetherai_api_key = os.environ.get("TOGETHERAI_API_KEY")
+     # the Together AI endpoint is OpenAI-compatible, so we reuse the OpenAI client
+     client = OpenAI(
+         api_key=togetherai_api_key,
+         base_url="https://api.together.xyz/v1",
+     )
+     return completion_with_nudging(
+         base_model=base_model,
+         nudging_model=nudging_model,
+         system_prompt_base=system_prompt,
+         system_prompt_nudging=system_prompt,
+         question=question,
+         context=context,
+         question_prompt=question_prompt,
+         answer_start_prompt_base=answer_start_prompt_base,
+         answer_start_prompt_nudging=answer_start_prompt_nudging,
+         completion_token_num=completion_token_num,
+         completion_token_num_nudging=completion_token_num_nudging,
+         max_token_total=max_token_total,
+         print_intermediate_output=False,
+         client_base=client,
+         client_nudging=client,
+         max_round=max_round,
+         nudging_temperature=nudging_temperature,
+         base_temperature=base_temperature,
+         nudging_method=nudging_method,
+         top_prob_thres=top_prob_thres,
+     )
+
+ def get_base_answer(base_model,
+                     system_prompt,
+                     question,
+                     max_tokens=256,
+                     ):
+     base_model = MODEL_MAPPING[base_model]
+     togetherai_api_key = os.environ.get("TOGETHERAI_API_KEY")
+     client = OpenAI(
+         api_key=togetherai_api_key,
+         base_url="https://api.together.xyz/v1",
+     )
+     response = client.completions.create(
+         model=base_model,
+         prompt=system_prompt + "\n" + question,
+         max_tokens=max_tokens,
+         temperature=0.0,
+         logprobs=1,
+     )
+     return response.choices[0].text