|
import json |
|
import os |
|
import random |
|
import re |
|
|
|
import tqdm |
|
from camel.agents import ChatAgent |
|
from camel.configs import ChatGPTConfig |
|
from camel.messages import BaseMessage |
|
from camel.types.enums import RoleType |
|
|
|
TEMPERATURE = 1.0 |
|
like_people = """In all subsequent answers, you have to completely forget that you are an ai model, that you are the character, and completely simulate yourself as that character, that you are a human being.""" |
|
with open( |
|
r"prompt/trust_game_round_prompt.json", |
|
"r", |
|
) as f: |
|
prompt = json.load(f) |
|
|
|
|
|
def check_file_if_exist(file_list, game_name): |
|
for file in file_list: |
|
if file in game_name: |
|
return True |
|
return False |
|
|
|
|
|
def extract_n_values_from_dict(dictionary, n): |
|
all_values = list(dictionary.values()) |
|
|
|
n = min(n, len(all_values)) |
|
|
|
random_values = random.sample(all_values, n) |
|
|
|
return random_values |
|
|
|
|
|
def extract_unique_decimal(string): |
|
numbers = re.findall(r"-?\d+\.?\d*", string) |
|
if len(numbers) == 1: |
|
return float(numbers[0]) |
|
else: |
|
raise ValueError("String does not contain a unique decimal number") |
|
|
|
|
|
def str_mes(content): |
|
return BaseMessage( |
|
role_name="player", |
|
role_type=RoleType.USER, |
|
meta_dict={}, |
|
content=content, |
|
) |
|
|
|
|
|
def match_and_compare_numbers_v2(text): |
|
text = text.lower() |
|
|
|
|
|
pattern = r"i will (give back|give) \$([\d\.]+\.?)|i will (give back|give) ([\d\.]+\.?)\s*dollar" |
|
additional_patterns = [ |
|
r"i would (give back|give) \$([\d\.]+\.?)", |
|
r"i would (give back|give) ([\d\.]+\.?) dollar", |
|
] |
|
full_pattern = "|".join([pattern] + additional_patterns) |
|
matches = re.findall(full_pattern, text) |
|
|
|
|
|
numbers = [] |
|
for match in matches: |
|
|
|
num_str = match[1] if match[1] else match[3] |
|
num_str = num_str.rstrip(".") |
|
try: |
|
num_float = float(num_str) |
|
numbers.append(num_float) |
|
except ValueError: |
|
continue |
|
|
|
if not numbers: |
|
return False |
|
|
|
if len(set(numbers)) == 1: |
|
return numbers[0] |
|
else: |
|
return False |
|
|
|
|
|
def classmate( |
|
player_1, |
|
player_2, |
|
first_round, |
|
first_prompt, |
|
second_prompt, |
|
k, |
|
): |
|
first_round_prompt = "This is the first round, answer the question." |
|
money_prompt = "Now,the another player give you {give} dollars,and You receive {N} dollars,the player left {left} dollars now. How much will you give back to the another player" |
|
return_money_prompt = "In last round ,You give the another player {give} dollars, The another player receive {receive} dollars, and The another player return you {N} dollars.Last round you left {left} dollars.This round is begin. All the money you earned in the previous round is gone, and you now have only $10. How much will you give to the another player?" |
|
player_2_end_prompt = "In last round, the another player give you {give} dollars, you receive {receive} dollars, and you return the another player {N} dollars.Last round you left {left} dollars. This round is begin. All the money you earned in the previous round is gone." |
|
grantee = "Your answer needs to include the content and analysis about your BELIEF, DESIRE and INTENTION. You should include your thought. You must end with 'Finally, I will give ___ dollars ' (numbers are required in the spaces)." |
|
res = [] |
|
cri_agent = ChatAgent( |
|
BaseMessage( |
|
role_name="critic", |
|
role_type=RoleType.ASSISTANT, |
|
meta_dict={}, |
|
content='How much would this person pay the other student? Only response with a specific price number like "5"!Don\'t response with a sentence', |
|
), |
|
output_language="English", |
|
|
|
) |
|
|
|
if first_round: |
|
player_1_response = player_1.step( |
|
str_mes(first_round_prompt + grantee)).msgs[0] |
|
ans = match_and_compare_numbers_v2(player_1_response.content) |
|
if ans: |
|
given_num = ans |
|
else: |
|
given_num = extract_unique_decimal( |
|
cri_agent.step( |
|
str_mes(player_1_response.content)).msgs[0].content |
|
) |
|
|
|
money_prompt = money_prompt.format( |
|
give=given_num, N=given_num * k, left=10 - given_num |
|
) |
|
player_2_response = player_2.step( |
|
str_mes(money_prompt + grantee)).msgs[0] |
|
else: |
|
player_1_response = player_1.step( |
|
str_mes(first_prompt + grantee)).msgs[0] |
|
print("player 1 input", first_prompt) |
|
print("Player_1_res", player_1_response.content) |
|
ans = match_and_compare_numbers_v2(player_1_response.content) |
|
if ans: |
|
given_num = ans |
|
else: |
|
given_num = extract_unique_decimal( |
|
cri_agent.step( |
|
str_mes(player_1_response.content)).msgs[0].content |
|
) |
|
money_prompt = money_prompt.format( |
|
give=given_num, N=given_num * k, left=10 - given_num |
|
) |
|
player_2_response = player_2.step( |
|
str_mes(second_prompt + money_prompt + grantee) |
|
) |
|
player_2_response = player_2_response.msgs[0] |
|
|
|
player_1.record_message(player_1_response) |
|
player_2.record_message(player_2_response) |
|
player_1_response = player_1_response.content |
|
player_2_response = player_2_response.content |
|
dia_history = { |
|
"Player_1": player_1_response, |
|
"Player_2": player_2_response, |
|
} |
|
|
|
cri_agent.reset() |
|
ans = match_and_compare_numbers_v2(player_2_response) |
|
if ans: |
|
return_num = ans |
|
else: |
|
return_num = extract_unique_decimal( |
|
cri_agent.step(str_mes(player_2_response)).msgs[0].content |
|
) |
|
return_money_prompt = return_money_prompt.format( |
|
give=given_num, |
|
receive=given_num * k, |
|
N=return_num, |
|
left=10 - given_num + return_num, |
|
) |
|
player_2_end_prompt = player_2_end_prompt.format( |
|
give=given_num, |
|
receive=given_num * k, |
|
N=return_num, |
|
left=given_num * k - return_num, |
|
) |
|
return ( |
|
[given_num, return_num], |
|
dia_history, |
|
return_money_prompt, |
|
player_2_end_prompt, |
|
) |
|
|
|
|
|
def save_experiment_result( |
|
final_res, save_path, prefix, k, exp_num, all_exp_num, group_num |
|
): |
|
save_file = ( |
|
save_path |
|
+ prefix |
|
+ "_" |
|
+ str(k) |
|
+ f"_exp_num_{exp_num}_total_num_{all_exp_num}_group_num_{group_num}" |
|
+ ".json" |
|
) |
|
print(save_file) |
|
with open( |
|
save_file, |
|
"w", |
|
) as json_file: |
|
json.dump(final_res, json_file) |
|
|
|
|
|
def multi_round( |
|
model_type, |
|
character_json, |
|
save_path, |
|
prompt, |
|
round_num=5, |
|
exp_num=1, |
|
round_num_inform=True, |
|
): |
|
random_character = character_json |
|
cha = [] |
|
front = "you are a person not an ai model." |
|
if round_num_inform: |
|
limited_prompt = f"You will repeat this game for {round_num} rounds. At the begining of each round, you will start from scratch with no dollars left." |
|
else: |
|
limited_prompt = f"You don't know how many rounds the game lasts. At the begining of each round, you will start from scratch with no dollars left." |
|
back = "you need to answer a specific price figure, not a price range!" |
|
if isinstance(model_type, list): |
|
prefix = model_type[0].value + "_" + model_type[1].value |
|
else: |
|
prefix = model_type.value |
|
input_record = {} |
|
chara_record = {} |
|
for k in range(3, 4): |
|
for i in range(len(random_character)): |
|
sys_prompt = ( |
|
random_character[i] |
|
+ like_people |
|
+ front |
|
+ limited_prompt |
|
+ str(prompt[str(i % 2 + 1)]).format(k=k) |
|
+ back |
|
) |
|
|
|
chara_record[f"cha_{i}_system_message"] = sys_prompt |
|
model_config = ChatGPTConfig(temperature=TEMPERATURE) |
|
cha.append( |
|
ChatAgent( |
|
BaseMessage( |
|
role_name="player", |
|
role_type=RoleType.USER, |
|
meta_dict={}, |
|
content=sys_prompt, |
|
), |
|
model_type=model_type |
|
if not isinstance(model_type, list) |
|
else model_type[i % 2], |
|
output_language="English", |
|
model_config=model_config, |
|
) |
|
) |
|
|
|
for group_num in tqdm.trange(0, len(cha), 2): |
|
round_res = [] |
|
dialog_history = [] |
|
first_prompt = "" |
|
second_prompt = "" |
|
save_file_check = ( |
|
prefix |
|
+ "_" |
|
+ str(k) |
|
+ f"_exp_num_{exp_num}_total_num_{round_num}_group_num_{group_num}" |
|
+ ".json" |
|
) |
|
existed_res = [item for item in os.listdir( |
|
save_path) if ".json" in item] |
|
if check_file_if_exist(existed_res, save_file_check): |
|
print(save_file_check + "is exist") |
|
continue |
|
for i in tqdm.tqdm(range(round_num)): |
|
input_record[f"round_{i}_input"] = [ |
|
first_prompt, second_prompt] |
|
res, dia, first_prompt, second_prompt = classmate( |
|
cha[group_num], |
|
cha[group_num + 1], |
|
i == 0, |
|
first_prompt, |
|
second_prompt, |
|
k, |
|
) |
|
round_res.append(res) |
|
dialog_history.append(dia) |
|
|
|
final_res = { |
|
i + 1: [round_res[i], dialog_history[i]] for i in range(len(round_res)) |
|
} |
|
final_res["input_record"] = input_record |
|
final_res["character_record"] = [ |
|
chara_record[f"cha_{group_num}_system_message"], |
|
chara_record[f"cha_{group_num+1}_system_message"], |
|
] |
|
save_experiment_result( |
|
final_res, |
|
save_path, |
|
prefix, |
|
k, |
|
exp_num, |
|
all_exp_num=round_num, |
|
group_num=group_num, |
|
) |
|
|