import numpy as np
from openai import OpenAI
import os
import tiktoken

encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")  # use the GPT-3.5 tokenizer to approximate token counts, so we don't need to load the actual tokenizer for API models

NUM_LOGPROBS = {
    'top_prob': 1,
}
MODEL_MAPPING = {
    "Llama-2-70B": "meta-llama/Llama-2-70b-hf",
    "Mistral-7B-v0.1": "mistralai/Mistral-7B-v0.1",
    "Mixtral-8x7B-v0.1": "mistralai/Mixtral-8x7B-v0.1",
    # Nudging models below
    "Mistral-7B-v0.1-Instruct": "mistralai/Mistral-7B-Instruct-v0.1",
    "Llama-2-13B-chat": "meta-llama/Llama-2-13b-chat-hf",
"Gemma-2-2B-it": "google/gemma-2b-it", | |
}
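# To try a different base/nudging pair, add its Together AI model ID here, e.g.
# (an assumption for illustration, not verified to be served on Together AI):
# "Llama-2-13B": "meta-llama/Llama-2-13b-hf".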
def apply_instruct_template(model_name, system_prompt, instruct_prompt, response_prompt, add_bos=False):
    model_name = model_name.lower()
    # print(model_name)
    if "chat" in model_name and "llama" in model_name and "2" in model_name:
        return llama_2_chat_template(system_prompt=system_prompt, instruct_prompt=instruct_prompt, response_prompt=response_prompt, add_bos=add_bos)
    elif "instruct" in model_name and "llama" in model_name and "3" in model_name:
        if "3.1" in model_name:  # for llama-3.1 models, add the knowledge cutoff to the system prompt
            return llama_3_instruct_template(system_prompt=system_prompt, instruct_prompt=instruct_prompt, response_prompt=response_prompt, add_bos=add_bos, add_knowledge_cut=True)
        else:
            return llama_3_instruct_template(system_prompt=system_prompt, instruct_prompt=instruct_prompt, response_prompt=response_prompt, add_bos=add_bos)
    elif "it" in model_name and "gemma" in model_name:
        return gemma_instruct_template(system_prompt=system_prompt, instruct_prompt=instruct_prompt, response_prompt=response_prompt, add_bos=add_bos)
    elif "instruct" in model_name and "olmo" in model_name:
        return olmo_instruct_template(system_prompt=system_prompt, instruct_prompt=instruct_prompt, response_prompt=response_prompt, add_bos=add_bos)
    elif "instruct" in model_name and "mistral" in model_name:
        return mistral_instruct_template(system_prompt=system_prompt, instruct_prompt=instruct_prompt, response_prompt=response_prompt, add_bos=True)
    else:
        return f"{system_prompt}\n{instruct_prompt}\n{response_prompt}"  # non-instruct model or model with an unknown template
def mistral_instruct_template(system_prompt, instruct_prompt, response_prompt, add_bos=True):
    """
    Convert the input and output into the template used for training the Mistral instruct models.
    """
    prefix = "<s>" if add_bos else ""
    return prefix + f"[INST] {system_prompt}\n{instruct_prompt} [/INST] {response_prompt}"
def llama_2_chat_template(system_prompt, instruct_prompt, response_prompt, add_bos=False):
    """
    Convert the input and output into the template used for training the Llama-2 chat models.
    """
    prefix = "<s>" if add_bos else ""
    return prefix + f"[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n{instruct_prompt} [/INST] {response_prompt.lstrip()}"  # most serving backends add <s> automatically, so we don't add it here by default
def llama_3_instruct_template(system_prompt, instruct_prompt, response_prompt, add_bos=False, add_knowledge_cut=False):
    """
    Convert the input and output into the template used for training the Llama-3 instruct models.
    """
    # print("applying llama-3 instruct template")
    prefix = "<|begin_of_text|>" if add_bos else ""
    if add_knowledge_cut:
        system_prompt = "Cutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\n" + system_prompt
    return prefix + f"<|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{instruct_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{response_prompt}"
def gemma_instruct_template(system_prompt, instruct_prompt, response_prompt, add_bos=False):
    """
    Convert the input and output into the template used for training the Gemma instruct models, e.g.:
    <bos><start_of_turn>user
    Write a hello world program<end_of_turn>
    <start_of_turn>model
    """
    prefix = "<bos>" if add_bos else ""
    return prefix + f"<start_of_turn>user\n{system_prompt}\n{instruct_prompt}<end_of_turn>\n<start_of_turn>model\n{response_prompt}"
def olmo_instruct_template(system_prompt, instruct_prompt, response_prompt, add_bos=False):
    """
    Convert the input and output into the template used for training the OLMo instruct models.
    """
    return f"<|endoftext|><|user|>\n{system_prompt}\n{instruct_prompt}\n<|assistant|>\n{response_prompt}"
def find_longest_repeated_suffix(s):
    # Helper function to check whether the suffix of the given length is immediately repeated
    def has_repeated(s, length):
        if length < 30:
            return False
        # Extract the suffix of length 'length'
        suffix = s[-length:]
        # Check whether the preceding text ends with the same suffix
        # return s[:-length].find(suffix) != -1
        return s[:-length].endswith(suffix)

    left, right = 0, len(s)
    result = 0
    # Binary search for the longest repeated suffix
    while left <= right:
        mid = (left + right) // 2
        if has_repeated(s, mid):
            result = mid  # Store the longest length found
            left = mid + 1  # Try for a longer suffix
        else:
            right = mid - 1  # Try for a shorter suffix
    # Return the longest repeated suffix
    if result > 0:
        return s[-result:]
    return None  # Return None if no repetition is found
def remove_redundant_repetitions(s):
    s = s.strip()
    # Find the longest repeated suffix
    longest_repeated_suffix = find_longest_repeated_suffix(s)
    while longest_repeated_suffix:
        # Remove the longest repeated suffix
        s = s[:-len(longest_repeated_suffix)]
        # Find the longest repeated suffix again
        longest_repeated_suffix = find_longest_repeated_suffix(s)
    return s
def repetition_check(new_completion, full_prefix, subseq_len=5):
    words = new_completion.split(" ")
    if len(words) > subseq_len and new_completion in full_prefix:
        return True
    return False
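# Example (hypothetical strings): a candidate continuation longer than subseq_len words that
# already appears verbatim in the prefix is flagged as a repetition.
# repetition_check("the answer is 42 as shown", "... so the answer is 42 as shown above ...")
# -> True, because the 6-word completion is a substring of the prefix.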
def convert_token_logprobs_to_top_logprobs(token_logprobs, tokens):
    """
    Together AI now only returns token logprobs; this function converts them into the top-logprobs format: a list of {token: logprob} dicts.
    """
    top_logprobs = [{token: logprob} for token, logprob in zip(tokens, token_logprobs)]
    return top_logprobs
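# Example with made-up tokens and logprobs:
# convert_token_logprobs_to_top_logprobs([-0.1, -2.3], [" The", " cat"])
# -> [{" The": -0.1}, {" cat": -2.3}]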
def check_need_nudging(nudging_method,
                       base_token_id,
                       current_base_info,
                       thresholds,
                       ):
    if nudging_method == 'top_prob':
        # check if the top token probability is below the threshold
        sorted_base_top_logprobs = {k: v for k, v in sorted(current_base_info["top_logprobs"][base_token_id].items(), key=lambda item: item[1], reverse=True)}
        base_top_prob = np.exp(list(sorted_base_top_logprobs.values())[0])
        need_nudging = base_top_prob < thresholds['top_prob']
    else:
        raise ValueError(f"Unknown nudging method {nudging_method}")
    return need_nudging
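# Worked example of the top_prob criterion (the token and numbers are hypothetical): with a
# threshold of 0.3, a base token whose best logprob is -1.6 has probability exp(-1.6) ~ 0.20,
# which is below 0.3, so the base model is considered uncertain and the nudging model takes over.
def _example_check_need_nudging():  # illustrative sketch, not used by the pipeline
    info = {"top_logprobs": [{" maybe": -1.6}]}
    return check_need_nudging('top_prob', base_token_id=0,
                              current_base_info=info, thresholds={'top_prob': 0.3})  # -> True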
def complete_with_base(nudging_method='top_prob',
                       base_model="davinci-002",
                       full_prefix_base="",
                       output="",
                       current_base_info=None,
                       max_completion_token=256,
                       completion_token_num=16,
                       client_base=None,
                       thresholds=None,
                       temperature=0.0,
                       top_p=0.9,
                       ):
    completion_base = "" if len(current_base_info["completion"]) == 0 else current_base_info["tokens"][0]  # accept the first token of the first round, which is the token already accepted in stage 1
    completion_all = "" if len(current_base_info["completion"]) == 0 else current_base_info["tokens"][0]  # completion_all records all tokens from the base model, including tokens not accepted in the last round, for debugging and visualization
    found_nudging_token = False
    response = None
    has_acc_token_stage_1 = True if len(current_base_info["completion"]) > 0 else False  # if current_base_info["completion"] is not empty, the first token of the base completion was already accepted in stage 1
    EMPTY_INFO_DICT = {
        "completion": "",
        "tokens": [],
        "top_logprobs": [],
        "stop_reason": None,
        "num_logprobs": NUM_LOGPROBS[nudging_method],
    }
    next_nudging_info = EMPTY_INFO_DICT  # for nudging methods that compute nudging info during base completion, the info could be reused in the next round; currently unused for top_prob nudging
    while len(encoding.encode(completion_base)) < max_completion_token and not found_nudging_token:
        if current_base_info["completion"] == "":
            # complete the sentence using the base model
            response = client_base.completions.create(
                model=base_model,
                prompt=full_prefix_base + output + completion_base,
                max_tokens=completion_token_num,
                temperature=temperature,
                logprobs=current_base_info["num_logprobs"],
                top_p=top_p,
            )
            current_base_info["tokens"] = response.choices[0].logprobs.tokens
            current_base_info["top_logprobs"] = response.choices[0].logprobs.top_logprobs
            if current_base_info["top_logprobs"] is None:
                current_base_info["top_logprobs"] = convert_token_logprobs_to_top_logprobs(response.choices[0].logprobs.token_logprobs, current_base_info["tokens"])
            current_base_info["completion"] = response.choices[0].text
        if has_acc_token_stage_1:
            # pop the first token of the first round, as it was already accepted in stage 1
            current_base_info["tokens"] = current_base_info["tokens"][1:]
            current_base_info["top_logprobs"] = current_base_info["top_logprobs"][1:]
            current_base_info["completion"] = "".join(current_base_info["tokens"])
            has_acc_token_stage_1 = False
        completion = current_base_info["completion"]
        tokens = current_base_info["tokens"]
        if completion in completion_base:
            break  # repeated completion, stop
        nudging_position = -1
        # find the first token that violates the nudging criteria
        for base_idx in range(len(tokens)):
            found_nudging_token = check_need_nudging(nudging_method=nudging_method, base_token_id=base_idx, current_base_info=current_base_info, thresholds=thresholds)
            if found_nudging_token:
                nudging_position = base_idx
                break
        if nudging_position == -1:
            new_completion = "".join(tokens)
        else:
            new_completion = "".join(tokens[:nudging_position])  # keep only the tokens before the nudging token
        # avoid repetition in the answer
        if repetition_check(new_completion, output + completion_base):
            break
        else:
            completion_base += new_completion
            if found_nudging_token:  # if the nudging token was found, the loop will end; concatenate the full base completion to completion_all
                completion_all += completion
            else:
                completion_all += new_completion
                next_nudging_info = EMPTY_INFO_DICT
                if response is not None and response.choices[0].finish_reason == "stop":
                    break
        # reset current_base_info
        current_base_info['completion'] = ""
        current_base_info['tokens'] = []
        current_base_info['top_logprobs'] = []
    return completion_base, completion_all, next_nudging_info
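# Sketch of the current_base_info dict that complete_with_base consumes and refreshes each
# round, assuming the OpenAI-compatible completions endpoint returns logprobs in the legacy
# format (tokens / top_logprobs lists). The concrete tokens and values are made up.
_EXAMPLE_BASE_INFO = {
    "completion": " The answer is 4.",
    "tokens": [" The", " answer", " is", " 4", "."],
    "top_logprobs": [{" The": -0.1}, {" answer": -0.3}, {" is": -0.05}, {" 4": -1.4}, {".": -0.2}],
    "num_logprobs": 1,
}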
def completion_with_nudging(
    base_model="davinci-002",
    nudging_model="gpt-3.5-turbo",
    system_prompt_base="Answer the question by walking through the reasoning step by step.",
    system_prompt_nudging="Answer the question by walking through the reasoning step by step.",
    question="",
    context="",
    question_prompt="Question: ",
    answer_start_prompt_base="Answer: ",
    answer_start_prompt_nudging="Answer: ",
    completion_token_num=16,
    completion_token_num_nudging=16,
    max_token_total=256,
    print_intermediate_output=False,
    client=None,  # default client
    client_base=None,
    client_nudging=None,
    max_round=150,
    nudging_temperature=0.0,  # deterministic for the nudging model
    base_temperature=0.0,  # deterministic for the base model
    nudging_method='top_prob',
    top_prob_thres=0.3,
    top_p=0.9,
):
    if client_base is None:
        client_base = client
    if client_nudging is None:
        client_nudging = client
    if nudging_method not in NUM_LOGPROBS.keys():
        raise ValueError(f"number of logprobs not defined for nudging method {nudging_method}")
    full_prefix_base = apply_instruct_template(base_model, system_prompt_base, context + question_prompt + question, answer_start_prompt_base)  # for a base model this just joins the parts with newlines
    full_prefix_nudging = apply_instruct_template(nudging_model, system_prompt_nudging, context + question_prompt + question, answer_start_prompt_nudging)
    thresholds = {
        'top_prob': top_prob_thres,
    }
output = "" | |
nudging_round = 0 | |
all_nudging_words = [] | |
all_nudging_and_completions = [] | |
current_nudging_info = { | |
"completion": "", | |
"tokens": [], | |
"top_logprobs": [], | |
"stop_reason": None, | |
"num_logprobs": NUM_LOGPROBS[nudging_method], | |
} | |
stop_reason = None | |
repeat_nudging_word = 0 | |
last_nudging_word = "" | |
    while len(encoding.encode(output)) < max_token_total and nudging_round < max_round:  # use the number of GPT-3.5 tokens to approximately control the length
        nudging_round += 1
        if current_nudging_info["completion"] == "":
            response = client_nudging.completions.create(
                model=nudging_model,
                prompt=full_prefix_nudging + output,
                max_tokens=completion_token_num_nudging,
                temperature=nudging_temperature,
                logprobs=current_nudging_info["num_logprobs"],
            )
            current_nudging_info["completion"] = response.choices[0].text
            current_nudging_info["tokens"] = response.choices[0].logprobs.tokens
            current_nudging_info["top_logprobs"] = response.choices[0].logprobs.top_logprobs
            if current_nudging_info["top_logprobs"] is None:
                current_nudging_info["top_logprobs"] = convert_token_logprobs_to_top_logprobs(response.choices[0].logprobs.token_logprobs, current_nudging_info["tokens"])
            current_nudging_info["stop_reason"] = response.choices[0].finish_reason
        # if finish_reason is stop, break the loop; this also handles a nudging completion carried over from the previous round
        if current_nudging_info["stop_reason"] == "stop":
            stop_reason = "nudging_model_stop"
            if len(current_nudging_info["completion"]) > 0:
                all_nudging_words.append(current_nudging_info["completion"])
                all_nudging_and_completions.append(current_nudging_info["completion"])
                output += current_nudging_info["completion"]
            break
        # ===================================================================
        # Stage 1: use the base model to find the first token that violates the nudging criteria (no need to nudge)
        # ===================================================================
        found_acc_token = False
        current_base_info = {  # will be passed to the next stage
            "completion": "",
            "tokens": [],
            "top_logprobs": [],
            "num_logprobs": NUM_LOGPROBS[nudging_method],
        }
        nudging_text = current_nudging_info["completion"]
        num_whitespaces = len(nudging_text) - len(nudging_text.lstrip(" "))
        space_prefix = " " * num_whitespaces
        current_nudging_words = nudging_text.lstrip(" ").split(" ")  # splitting by tokens leads to some unexpected behaviors, so we split by words instead
        nudging_word_id = 0 if len(current_nudging_words) > 1 else 1  # if there is only one word, always accept it and go to the next round: the loop below is skipped and found_acc_token stays False
        while not found_acc_token and nudging_word_id < len(current_nudging_words) - 1:
            nudging_word_id += 1  # always accept the first word
            nudging_gen_prefix = space_prefix + " ".join(current_nudging_words[:nudging_word_id])
            current_nudging_word = " " + current_nudging_words[nudging_word_id]  # add a leading space to the current nudging word since the nudging words are split by spaces
            if current_nudging_word == " ":  # skip multiple spaces
                continue
            prefix = full_prefix_base + output + nudging_gen_prefix
            response = client_base.completions.create(
                model=base_model,
                prompt=prefix,
                max_tokens=completion_token_num,
                temperature=base_temperature,
                logprobs=current_base_info["num_logprobs"],
                top_p=top_p,
            )
            current_base_info["tokens"] = response.choices[0].logprobs.tokens
            current_base_info["top_logprobs"] = response.choices[0].logprobs.top_logprobs
            if current_base_info["top_logprobs"] is None:
                current_base_info["top_logprobs"] = convert_token_logprobs_to_top_logprobs(response.choices[0].logprobs.token_logprobs, current_base_info["tokens"])
            current_base_info["completion"] = response.choices[0].text
            # check whether the base model's first token can be accepted at this position
            first_base_token = current_base_info["tokens"][0]
            if current_nudging_word.startswith(first_base_token):  # check whether the current nudging word is the same as, or starts with, the first base token
                found_acc_token = True
            else:
                # check whether the token violates the nudging criteria (i.e., no need to nudge)
                found_acc_token = not check_need_nudging(
                    nudging_method,
                    base_token_id=0,
                    current_base_info=current_base_info,
                    thresholds=thresholds,
                )
        # at this point either nudging_word_id == len(current_nudging_words) - 1: no base token could be accepted, so we keep the current nudging completion,
        # or found_acc_token == True: a base token was accepted, so we use the prefix up to here as the nudging words
        nudging_words = space_prefix + " ".join(current_nudging_words[:nudging_word_id])
        # Heuristic: if the nudging words are identical to the previous round's for three rounds, break the loop
        if nudging_words == last_nudging_word:
            repeat_nudging_word += 1
            if repeat_nudging_word >= 3:
                stop_reason = "repeated_nudging_words"
                break
        else:
            last_nudging_word = nudging_words
            repeat_nudging_word = 0
        all_nudging_words.append(nudging_words)
        output += nudging_words
        if not found_acc_token:  # if no base token can be accepted, use the current nudging completion and go to the next round
            all_nudging_and_completions.append(nudging_words)
            # reset the current nudging info and continue to the next round
            current_nudging_info = {
                "completion": "",
                "tokens": [],
                "top_logprobs": [],
                "stop_reason": None,
                "num_logprobs": NUM_LOGPROBS[nudging_method],
            }
            continue
        if current_base_info["completion"] == "":  # the base model thinks the completion is done, go to the next round; this also guarantees current_base_info["completion"] is not empty when proceeding to the next stage
            all_nudging_and_completions.append(nudging_words)
            current_nudging_info = {
                "completion": "",
                "tokens": [],
                "top_logprobs": [],
                "stop_reason": None,
                "num_logprobs": NUM_LOGPROBS[nudging_method],
            }
            continue
        # ===================================================================
        # Stage 2: let the base model complete until a token meets the nudging criteria (needs nudging)
        # ===================================================================
        max_completion_token = max_token_total - len(encoding.encode(output))
        completion_base, completion_base_all, current_nudging_info = complete_with_base(
            nudging_method=nudging_method,
            base_model=base_model,
            full_prefix_base=full_prefix_base,
            output=output,
            current_base_info=current_base_info,
            max_completion_token=max_completion_token,
            completion_token_num=completion_token_num,
            client_base=client_base,
            thresholds=thresholds,
            temperature=base_temperature,
            top_p=top_p,
        )
        # print(f"next_nudging_info: {current_nudging_info}")  # debug
        output += completion_base
        all_nudging_and_completions.append(nudging_words + completion_base)  # the text generated in each round; concatenating all completions gives the final output
        if print_intermediate_output:
            print(f"************nudging round {nudging_round}************")
            print(f"****nudging words from {nudging_model}****: {nudging_words}")
            print(f"****nudging text****: {nudging_text}")
            print(f"****completion from {base_model}****: {completion_base}")
            print(f"****all completion from {base_model}****: {completion_base_all}")
            print(f"****output****: {output}")
    if nudging_round >= max_round and not stop_reason:
        stop_reason = "round"
    if len(encoding.encode(output)) >= max_token_total and not stop_reason:
        stop_reason = "length"
    output = remove_redundant_repetitions(output)
    if print_intermediate_output:
        print(f"************final output************")
        print(f"****output****: {output}")
    all_info = {
        "question": question,
        "context": context,
        "raw_answer": output,
        "all_nudging_words": all_nudging_words,
        "all_completions": all_nudging_and_completions,
        "stop_reason": stop_reason,
        "system_prompt_base": system_prompt_base,
        "system_prompt_nudging": system_prompt_nudging,
        "full_prefix_base": full_prefix_base,
        "full_prefix_nudging": full_prefix_nudging,
    }
    return all_info
def get_nudging_answer(base_model,
                       nudging_model,
                       system_prompt,
                       question,
                       context="",
                       question_prompt="",
                       answer_start_prompt_base="",
                       answer_start_prompt_nudging="",
                       completion_token_num=16,
                       completion_token_num_nudging=16,
                       max_token_total=256,
                       max_round=150,
                       nudging_temperature=0.0,
                       base_temperature=0.0,
                       nudging_method='top_prob',
                       top_prob_thres=0.3,
                       ):
    base_model = MODEL_MAPPING[base_model]
    nudging_model = MODEL_MAPPING[nudging_model]
    # with open('TOGETHER_KEY.txt', 'r') as f:
    #     togetherai_api_key = f.read().strip()
    togetherai_api_key = os.environ.get("TOGETHERAI_API_KEY")
    client = OpenAI(
        api_key=togetherai_api_key,
        base_url="https://api.together.xyz/v1",
    )
    return completion_with_nudging(
        base_model=base_model,
        nudging_model=nudging_model,
        system_prompt_base=system_prompt,
        system_prompt_nudging=system_prompt,
        question=question,
        context=context,
        question_prompt=question_prompt,
        answer_start_prompt_base=answer_start_prompt_base,
        answer_start_prompt_nudging=answer_start_prompt_nudging,
        completion_token_num=completion_token_num,
        completion_token_num_nudging=completion_token_num_nudging,
        max_token_total=max_token_total,
        print_intermediate_output=False,
        client_base=client,
        client_nudging=client,
        max_round=max_round,
        nudging_temperature=nudging_temperature,
        base_temperature=base_temperature,
        nudging_method=nudging_method,
        top_prob_thres=top_prob_thres,
    )
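# Example usage (a sketch: it assumes TOGETHERAI_API_KEY is set and that both display names
# exist in MODEL_MAPPING; the question text is made up):
def _example_get_nudging_answer():  # illustrative only, not called by the app
    result = get_nudging_answer(
        base_model="Llama-2-70B",
        nudging_model="Llama-2-13B-chat",
        system_prompt="Answer the question by reasoning step by step.",
        question="Question: A box holds 12 eggs. How many eggs are in 7 boxes?",
        max_token_total=256,
    )
    return result["raw_answer"], result["stop_reason"]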
def get_base_answer(base_model,
                    system_prompt,
                    question,
                    max_tokens=256,
                    ):
    base_model = MODEL_MAPPING[base_model]
    # with open('TOGETHER_KEY.txt', 'r') as f:
    #     togetherai_api_key = f.read().strip()
    togetherai_api_key = os.environ.get("TOGETHERAI_API_KEY")
    client = OpenAI(
        api_key=togetherai_api_key,
        base_url="https://api.together.xyz/v1",
    )
    response = client.completions.create(
        model=base_model,
        prompt=system_prompt + "\n" + question,
        max_tokens=max_tokens,
        temperature=0.0,
        logprobs=1,
    )
    return response.choices[0].text
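# Example usage for the base-only baseline (same assumptions as above; the question is made up):
# get_base_answer("Llama-2-70B",
#                 system_prompt="Answer the question by reasoning step by step.",
#                 question="Question: What is 12 * 7?")
# returns the raw text completion from the base model.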