# nudging_align/utils.py
import numpy as np
from openai import OpenAI
import os
import tiktoken
encoding = tiktoken.encoding_for_model("gpt-3.5-turbo") # use the GPT-3.5 tokenizer for token counting, so we don't need to load the actual tokenizer for API models
NUM_LOGPROBS = {
'top_prob': 1,
}
MODEL_MAPPING = {
"Llama-2-70B": "meta-llama/Llama-2-70b-hf",
"Mistral-7B-v0.1": "mistralai/Mistral-7B-v0.1",
"Mixtral-8x7B-v0.1": "mistralai/Mixtral-8x7B-v0.1",
# Nudging models below
"Mistral-7B-v0.1-Instruct": "mistralai/Mistral-7B-Instruct-v0.1",
"Llama-2-13B-chat": "meta-llama/Llama-2-13b-chat-hf",
"Gemma-2-2B-it": "google/gemma-2b-it",
}
def apply_instruct_template(model_name, system_prompt, instruct_prompt, response_prompt, add_bos=False):
model_name = model_name.lower()
# print(model_name)
if "chat" in model_name and "llama" in model_name and "2" in model_name:
return llama_2_chat_template(system_prompt=system_prompt, instruct_prompt=instruct_prompt, response_prompt=response_prompt, add_bos=add_bos)
elif "instruct" in model_name and "llama" in model_name and "3" in model_name:
if "3.1" in model_name: # for llama-3.1 models, add knowledge cut in system prompmt
return llama_3_instruct_template(system_prompt=system_prompt, instruct_prompt=instruct_prompt, response_prompt=response_prompt, add_bos=add_bos, add_knowledge_cut=True)
else:
return llama_3_instruct_template(system_prompt=system_prompt, instruct_prompt=instruct_prompt, response_prompt=response_prompt, add_bos=add_bos)
elif "it" in model_name and "gemma" in model_name:
return gemma_instruct_template(system_prompt=system_prompt, instruct_prompt=instruct_prompt, response_prompt=response_prompt, add_bos=add_bos)
elif "instruct" in model_name and "olmo" in model_name:
return olmo_instruct_template(system_prompt=system_prompt, instruct_prompt=instruct_prompt, response_prompt=response_prompt, add_bos=add_bos)
elif "instruct" in model_name and "mistral" in model_name:
        return mistral_instruct_template(system_prompt=system_prompt, instruct_prompt=instruct_prompt, response_prompt=response_prompt, add_bos=True) # the Mistral instruct template expects the BOS token
else:
return f"{system_prompt}\n{instruct_prompt}\n{response_prompt}" # non-instruct model or models with unknown template
def mistral_instruct_template(system_prompt, instruct_prompt, response_prompt, add_bos=True):
"""
    Convert the input and output into the template used for training the Mistral instruct models.
"""
prefix = "<s>" if add_bos else ""
return prefix + f"[INST] {system_prompt}\n{instruct_prompt} [/INST] {response_prompt}"
def llama_2_chat_template(system_prompt, instruct_prompt, response_prompt, add_bos=False):
"""
    Convert the input and output into the template used for training the Llama-2 chat models.
"""
prefix = "<s>" if add_bos else ""
return prefix + f"[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n{instruct_prompt} [/INST] {response_prompt.lstrip()}" # for most servers that add <s> automatically so we don't need to add it here
def llama_3_instruct_template(system_prompt, instruct_prompt, response_prompt, add_bos=False, add_knowledge_cut=False):
"""
    Convert the input and output into the template used for training the Llama-3 instruct models.
"""
# print("applying llama-3 instruct template")
prefix = "<|begin_of_text|>" if add_bos else ""
if add_knowledge_cut:
system_prompt = f"Cutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\n"+ system_prompt
return prefix + f"<|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{instruct_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{response_prompt}"
def gemma_instruct_template(system_prompt, instruct_prompt, response_prompt, add_bos=False):
"""
    Convert the input and output into the template used for training the Gemma instruct models, e.g.:
<bos><start_of_turn>user
Write a hello world program<end_of_turn>
<start_of_turn>model
"""
prefix = "<bos>" if add_bos else ""
return prefix + f"<start_of_turn>user\n{system_prompt}\n{instruct_prompt}<end_of_turn>\n<start_of_turn>model\n{response_prompt}"
def olmo_instruct_template(system_prompt, instruct_prompt, response_prompt, add_bos=False):
"""
    Convert the input and output into the template used for training the OLMo instruct models.
"""
return f"<|endoftext|><|user|>\n{system_prompt}\n{instruct_prompt}\n<|assistant|>\n{response_prompt}"
def find_longest_repeated_suffix(s):
    # Helper function to check whether the suffix of a given length immediately repeats
    def has_repeated(s, length):
        if length < 30: # ignore short suffixes to avoid false positives
            return False
        # Extract the suffix of length 'length'
        suffix = s[-length:]
        # Check whether the same suffix occurs again immediately before it
        return s[:-length].endswith(suffix)
left, right = 0, len(s)
result = 0
    # Binary search for the longest repeated suffix (a heuristic: the predicate is not strictly monotonic in the length)
while left <= right:
mid = (left + right) // 2
if has_repeated(s, mid):
result = mid # Store the longest length found
left = mid + 1 # Try for a longer suffix
else:
right = mid - 1 # Try for a shorter suffix
    # Return the longest repeated suffix
    if result > 0:
        return s[-result:]
    return None # Return None if no repetition is found
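
# Illustrative example: a suffix of at least 30 characters that immediately
# repeats at the end of the string is recovered, e.g.
#   chunk = "we should stop repeating this exact sentence."  # 45 chars
#   find_longest_repeated_suffix(chunk * 2)  # -> chunk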
def remove_redundant_repetitions(s):
s = s.strip()
# Find the longest repeated suffix
longest_repeated_suffix = find_longest_repeated_suffix(s)
while longest_repeated_suffix:
# Remove the longest repeated suffix
s = s[:-len(longest_repeated_suffix)]
# Find the longest repeated suffix again
longest_repeated_suffix = find_longest_repeated_suffix(s)
return s
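
# Illustrative example: trailing verbatim repetitions are stripped one at a
# time; with chunk as above, remove_redundant_repetitions(chunk * 2)
# returns chunk.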
def repetition_check(new_completion, full_prefix, subseq_len=5):
    """
    Flag the completion as repetitive if it is longer than subseq_len words and already appears verbatim in the prefix.
    """
    words = new_completion.split(" ")
    if len(words) > subseq_len and new_completion in full_prefix:
        return True
    return False
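
# Illustrative example: a completion of more than subseq_len words that already
# appears verbatim in the prefix is flagged:
#   repetition_check("so the answer is 42 as shown",
#                    "... so the answer is 42 as shown above")  # -> True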
def convert_token_logprobs_to_top_logprobs(token_logprobs, tokens):
"""
Together AI now only returns token logprobs, this function converts token logprobs to top logprobs format: {token: logprob}
"""
top_logprobs = [{token: logprob} for token, logprob in zip(tokens, token_logprobs)]
return top_logprobs
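
# Illustrative example:
#   convert_token_logprobs_to_top_logprobs([-0.1, -2.3], [" The", " cat"])
#   -> [{" The": -0.1}, {" cat": -2.3}]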
def check_need_nudging(nudging_method,
base_token_id,
current_base_info,
thresholds,
):
if nudging_method == 'top_prob':
        # check if the base model's top token probability is below the threshold
        base_top_prob = np.exp(max(current_base_info["top_logprobs"][base_token_id].values()))
        need_nudging = base_top_prob < thresholds['top_prob']
else:
raise ValueError(f"Unknown nudging method {nudging_method}")
return need_nudging
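
# Illustrative example: with a confident base token no nudging is needed, e.g.
#   info = {"top_logprobs": [{" Paris": -0.05}]}
#   check_need_nudging('top_prob', 0, info, {'top_prob': 0.3})  # -> False
# since exp(-0.05) ~= 0.95 is not below the 0.3 threshold.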
def complete_with_base(nudging_method='top_prob',
base_model="davinci-002",
full_prefix_base="",
output="",
current_base_info=None,
max_completion_token=256,
completion_token_num=16,
client_base=None,
thresholds=None,
temperature=0.0,
top_p=0.9,
):
completion_base = "" if len(current_base_info["completion"]) == 0 else current_base_info["tokens"][0] # accept the first token from the 1st round which is the acc token from the first stage
completion_all = "" if len(current_base_info["completion"]) == 0 else current_base_info["tokens"][0] # completion_all records all the tokens from the base model including the tokens that are not accepted in the last round, for debugging and visualization
found_nudging_token = False
response = None
    has_acc_token_stage_1 = len(current_base_info["completion"]) > 0 # a non-empty completion means its first token was already accepted in stage 1
EMPTY_INFO_DICT = {
"completion": "",
"tokens": [],
"top_logprobs": [],
"stop_reason": None,
"num_logprobs": NUM_LOGPROBS[nudging_method],
}
next_nudging_info = EMPTY_INFO_DICT # for nudging methods that compute nudging info during base completion, we can save the info for the next round, currently not used for top_prob nudging
while len(encoding.encode(completion_base)) < max_completion_token and not found_nudging_token:
if current_base_info["completion"] == "":
# complete the sentence using the base model
response = client_base.completions.create(
model=base_model,
prompt=full_prefix_base + output + completion_base,
max_tokens=completion_token_num,
temperature=temperature,
logprobs=current_base_info["num_logprobs"],
top_p=top_p,
)
current_base_info["tokens"] = response.choices[0].logprobs.tokens
current_base_info["top_logprobs"] = response.choices[0].logprobs.top_logprobs
if current_base_info["top_logprobs"] is None:
current_base_info["top_logprobs"] = convert_token_logprobs_to_top_logprobs(response.choices[0].logprobs.token_logprobs, current_base_info["tokens"])
current_base_info["completion"] = response.choices[0].text
if has_acc_token_stage_1:
# pop the first token from the 1st round as it is already accepted from stage 1
current_base_info["tokens"] = current_base_info["tokens"][1:]
current_base_info["top_logprobs"] = current_base_info["top_logprobs"][1:]
current_base_info["completion"] = "".join(current_base_info["tokens"])
has_acc_token_stage_1 = False
completion = current_base_info["completion"]
tokens = current_base_info["tokens"]
if completion in completion_base:
break # repeated completion, break
nudging_position = -1
# find the first token that violates the nudging criteria
for base_idx in range(len(tokens)):
found_nudging_token = check_need_nudging(nudging_method=nudging_method, base_token_id=base_idx, current_base_info=current_base_info, thresholds=thresholds)
if found_nudging_token:
nudging_position = base_idx
break
        if nudging_position == -1:
            new_completion = "".join(tokens)
        else:
            new_completion = "".join(tokens[:nudging_position]) # keep only the tokens before the first token that needs nudging
# avoid repetition in answer
if repetition_check(new_completion, output + completion_base):
break
else:
completion_base += new_completion
if found_nudging_token: # if found the nudging token, break the loop, concat the last base completion to completion_all
completion_all += completion
else:
completion_all += new_completion
next_nudging_info = EMPTY_INFO_DICT
if response is not None and response.choices[0].finish_reason == "stop":
break
# reset the current_base_info
current_base_info['completion'] = ""
current_base_info['tokens'] = []
current_base_info['top_logprobs'] = []
return completion_base, completion_all, next_nudging_info
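
# Minimal usage sketch (hypothetical values; `client` is any OpenAI-compatible
# completions client, and `prefix`/`output_so_far` come from the caller):
#   current_base_info = {"completion": " Paris", "tokens": [" Paris"],
#                        "top_logprobs": [{" Paris": -0.05}], "num_logprobs": 1}
#   completion_base, completion_all, _ = complete_with_base(
#       base_model="mistralai/Mistral-7B-v0.1",
#       full_prefix_base=prefix, output=output_so_far,
#       current_base_info=current_base_info, client_base=client,
#       thresholds={"top_prob": 0.3})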
def completion_with_nudging(
base_model="davinci-002",
nudging_model="gpt-3.5-turbo",
system_prompt_base="Answer the question by walking through the reasoning step by step.",
system_prompt_nudging="Answer the question by walking through the reasoning step by step.",
question="",
context="",
question_prompt="Question: ",
answer_start_prompt_base="Answer: ",
answer_start_prompt_nudging="Answer: ",
completion_token_num=16,
completion_token_num_nudging=16,
max_token_total=256,
print_intermediate_output=False,
client=None, # default client
client_base=None,
client_nudging=None,
max_round=150,
nudging_temperature=0.0, # deterministic for nudging
base_temperature=0.0, # deterministic for base model
nudging_method='top_prob',
top_prob_thres=0.3,
top_p=0.9,
):
if client_base is None:
client_base = client
if client_nudging is None:
client_nudging = client
    if nudging_method not in NUM_LOGPROBS:
        raise ValueError(f"number of logprobs not defined for nudging method {nudging_method}")
full_prefix_base = apply_instruct_template(base_model, system_prompt_base, context + question_prompt + question, answer_start_prompt_base) # for base model this function just adds newlines
full_prefix_nudging = apply_instruct_template(nudging_model, system_prompt_nudging, context + question_prompt + question, answer_start_prompt_nudging)
thresholds = {
'top_prob': top_prob_thres,
}
output = ""
nudging_round = 0
all_nudging_words = []
all_nudging_and_completions = []
current_nudging_info = {
"completion": "",
"tokens": [],
"top_logprobs": [],
"stop_reason": None,
"num_logprobs": NUM_LOGPROBS[nudging_method],
}
stop_reason = None
repeat_nudging_word = 0
last_nudging_word = ""
    while len(encoding.encode(output)) < max_token_total and nudging_round < max_round: # use the GPT-3.5 token count to approximately control the length
nudging_round += 1
if current_nudging_info["completion"] == "":
response = client_nudging.completions.create(
model=nudging_model,
prompt=full_prefix_nudging + output,
max_tokens=completion_token_num_nudging,
temperature=nudging_temperature,
logprobs=current_nudging_info["num_logprobs"],
)
current_nudging_info["completion"] = response.choices[0].text
current_nudging_info["tokens"] = response.choices[0].logprobs.tokens
current_nudging_info["top_logprobs"] = response.choices[0].logprobs.top_logprobs
if current_nudging_info["top_logprobs"] is None:
current_nudging_info["top_logprobs"] = convert_token_logprobs_to_top_logprobs(response.choices[0].logprobs.token_logprobs, current_nudging_info["tokens"])
current_nudging_info["stop_reason"] = response.choices[0].finish_reason
# if finish_reason is stop, break the loop, also handles nudging completion from previous round
if current_nudging_info["stop_reason"] == "stop":
stop_reason = "nudging_model_stop"
if len(current_nudging_info["completion"]) > 0:
all_nudging_words.append(current_nudging_info["completion"])
all_nudging_and_completions.append(current_nudging_info["completion"])
output += current_nudging_info["completion"]
break
# ===================================================================
# Stage 1: use base model to find the first token that violates the nudging criteria (no need to nudge)
# ===================================================================
found_acc_token = False
current_base_info = { # will be passed to the next stage
"completion": "",
"tokens": [],
"top_logprobs": [],
"num_logprobs": NUM_LOGPROBS[nudging_method],
}
nudging_text = current_nudging_info["completion"]
num_whitespaces = len(nudging_text) - len(nudging_text.lstrip(" "))
space_prefix = " " * num_whitespaces
        current_nudging_words = nudging_text.lstrip(" ").split(" ") # splitting by tokens leads to unexpected behavior, so we split the nudging text into words
        nudging_word_id = 0 if len(current_nudging_words) > 1 else 1 # if there is only one word, skip the loop below (found_acc_token stays False) so the word is always accepted and we go to the next round
while not found_acc_token and nudging_word_id < len(current_nudging_words) - 1:
nudging_word_id += 1 # always accept the first word
nudging_gen_prefix = space_prefix + " ".join(current_nudging_words[:nudging_word_id])
current_nudging_word = " " + current_nudging_words[nudging_word_id] # add a leading space to the current nudging word since the nudging words a split by space
if current_nudging_word == " ": # skip the multiple space
continue
prefix = full_prefix_base + output + nudging_gen_prefix
response = client_base.completions.create(
model=base_model,
prompt=prefix,
max_tokens=completion_token_num,
temperature=base_temperature,
logprobs=current_base_info["num_logprobs"],
top_p=top_p,
)
current_base_info["tokens"] = response.choices[0].logprobs.tokens
current_base_info["top_logprobs"] = response.choices[0].logprobs.top_logprobs
if current_base_info["top_logprobs"] is None:
current_base_info["top_logprobs"] = convert_token_logprobs_to_top_logprobs(response.choices[0].logprobs.token_logprobs, current_base_info["tokens"])
current_base_info["completion"] = response.choices[0].text
# look for the first token that meets the nudging criteria
first_base_token = current_base_info["tokens"][0]
            if current_nudging_word.startswith(first_base_token): # the base model agrees: its first token is a prefix of the current nudging word
found_acc_token = True
else:
found_acc_token = not check_need_nudging(nudging_method, # check if the token violates the nudging criteria (no need to nudge)
base_token_id=0,
current_base_info=current_base_info,
thresholds=thresholds)
        # At this point either nudging_word_id reached the end of current_nudging_words (no base
        # token was accepted, so the whole nudging completion is used as nudging words), or
        # found_acc_token is True and the accepted words so far become the nudging prefix
        nudging_words = space_prefix + " ".join(current_nudging_words[:nudging_word_id])
# Heuristic: if the nudging words are the same as the last one for three rounds, break the loop
if nudging_words == last_nudging_word:
repeat_nudging_word += 1
if repeat_nudging_word >= 3:
stop_reason = "repeated_nudging_words"
break
else:
last_nudging_word = nudging_words
repeat_nudging_word = 0
all_nudging_words.append(nudging_words)
output += nudging_words
if not found_acc_token: # if no base token can be accepted, use the current nudging completion and go to the next round
all_nudging_and_completions.append(nudging_words)
# reset the current nudging info and continue to the next round
current_nudging_info = {
"completion": "",
"tokens": [],
"logprobs": [],
"stop_reason": None,
"num_logprobs": NUM_LOGPROBS[nudging_method],
}
continue
if current_base_info["completion"] == "": # the base model thinks the completion is done, go to the next round. Make sure current_base_info["completion"] is not empty if proceed to the next stage
all_nudging_and_completions.append(nudging_words)
current_nudging_info = {
"completion": "",
"tokens": [],
"logprobs": [],
"stop_reason": None,
"num_logprobs": NUM_LOGPROBS[nudging_method],
}
continue
# ===================================================================
        # Stage 2: let the base model complete until a token meets the nudging criteria (nudging is needed again)
# ===================================================================
max_completion_token = max_token_total - len(encoding.encode(output))
completion_base, completion_base_all, current_nudging_info = complete_with_base(nudging_method=nudging_method,
base_model=base_model,
full_prefix_base=full_prefix_base,
output=output,
current_base_info=current_base_info,
max_completion_token=max_completion_token,
completion_token_num=completion_token_num,
client_base=client_base,
thresholds=thresholds,
temperature=base_temperature,
top_p=top_p,
)
# print(f"next_nudging_info: {current_nudging_info}") # debug
output += completion_base
        all_nudging_and_completions.append(nudging_words + completion_base) # the tokens generated in each round; concatenating all completions yields the final output
if print_intermediate_output:
print(f"************nudging round {nudging_round}************")
print(f"****nudging words from {nudging_model}****: {nudging_words}")
print(f"****nudging text****: {nudging_text}")
print(f"****completion from {base_model}****: {completion_base}")
print(f"****all completion from {base_model}****: {completion_base_all}")
print(f"****output****: {output}")
if nudging_round >= max_round and not stop_reason:
stop_reason = "round"
if len(encoding.encode(output)) >= max_token_total and not stop_reason:
stop_reason = "length"
output = remove_redundant_repetitions(output)
if print_intermediate_output:
print(f"************final output************")
print(f"****output****: {output}")
all_info = {
"question": question,
"context": context,
"raw_answer": output,
"all_nudging_words": all_nudging_words,
"all_completions": all_nudging_and_completions,
"stop_reason": stop_reason,
"system_prompt_base": system_prompt_base,
"system_prompt_nudging": system_prompt_nudging,
"full_prefix_base": full_prefix_base,
"full_prefix_nudging": full_prefix_nudging,
}
return all_info
def get_nudging_answer(base_model,
nudging_model,
system_prompt,
question,
context="",
question_prompt="",
answer_start_prompt_base="",
answer_start_prompt_nudging="",
completion_token_num=16,
completion_token_num_nudging=16,
max_token_total=256,
max_round=150,
nudging_temperature=0.0,
base_temperature=0.0,
nudging_method='top_prob',
top_prob_thres=0.3,
):
base_model = MODEL_MAPPING[base_model]
nudging_model = MODEL_MAPPING[nudging_model]
togetherai_api_key = os.environ.get("TOGETHERAI_API_KEY")
client = OpenAI(
api_key=togetherai_api_key,
base_url="https://api.together.xyz/v1",
)
return completion_with_nudging(
base_model=base_model,
nudging_model=nudging_model,
system_prompt_base=system_prompt,
system_prompt_nudging=system_prompt,
question=question,
context=context,
question_prompt=question_prompt,
answer_start_prompt_base=answer_start_prompt_base,
answer_start_prompt_nudging=answer_start_prompt_nudging,
completion_token_num=completion_token_num,
completion_token_num_nudging=completion_token_num_nudging,
max_token_total=max_token_total,
print_intermediate_output=False,
client_base=client,
client_nudging=client,
max_round=max_round,
nudging_temperature=nudging_temperature,
base_temperature=base_temperature,
nudging_method=nudging_method,
top_prob_thres=top_prob_thres,
)
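
# Example usage (requires the TOGETHERAI_API_KEY environment variable; model
# names must be keys of MODEL_MAPPING):
#   result = get_nudging_answer(
#       base_model="Llama-2-70B",
#       nudging_model="Llama-2-13B-chat",
#       system_prompt="Answer the question by reasoning step by step.",
#       question="What is 12 * 7?",
#       question_prompt="Question: ",
#       answer_start_prompt_base="Answer: ",
#       answer_start_prompt_nudging="Answer: ")
#   print(result["raw_answer"])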
def get_base_answer(base_model,
                    system_prompt,
                    question,
                    max_tokens=256):
base_model = MODEL_MAPPING[base_model]
togetherai_api_key = os.environ.get("TOGETHERAI_API_KEY")
client = OpenAI(
api_key=togetherai_api_key,
base_url="https://api.together.xyz/v1",
)
response = client.completions.create(
model=base_model,
prompt=system_prompt+"\n"+ question,
max_tokens=max_tokens,
temperature=0.0,
logprobs=1,
)
return response.choices[0].text
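
# Example usage (requires TOGETHERAI_API_KEY):
#   print(get_base_answer("Mistral-7B-v0.1",
#                         "Answer the question concisely.",
#                         "Question: What is the capital of France?\nAnswer:"))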