# Hugging Face Spaces page status ("Spaces: Running") captured along with the file; not part of the program.
import gradio as gr | |
import requests | |
import json | |
import os | |
from pathlib import Path | |
import inquirer | |
import typer | |
from rich.console import Console | |
from rich.prompt import IntPrompt, Prompt, Confirm | |
import argparse | |
import logging | |
import util | |
from model import get_all_embeddings, get_all_llms | |
from setting import Settings, get_all_model_settings, load_model_setting | |
# import Settings, get_all_model_settings, load_model_setting | |
from model import agi_init | |
import gym | |
from retriever import ( | |
create_new_memory_retriever, | |
) | |
import gym_coup | |
import random | |
from rlcard.utils import set_seed | |
import rlcard | |
from rlcard import models | |
from rlcard.models import leducholdem_rule_models | |
# Game-config JSON path for each supported game id.
_GAME_CONFIG_FILES = {
    'coup': 'game_config/coup.json',
    'leduc-holdem': 'game_config/leduc_limit.json',
    'limit-holdem': 'game_config/limit_holdem.json',
}


def _new_game_state():
    """Return a fresh per-session state dict holding every key predict() reads."""
    return {
        'verified': False,
        'settings': None,
        'env': None,
        'ctx': None,
        'bot_long_memory': [],
        'bot_short_memory': [],
        'agents_num': 2,
        'chips': [50, 50],
        'user_index': 1,
        'game_idx': 0,
        'mode': 'first_tom',
    }


def _current_player(env, is_coup):
    """Return the index of the player whose turn it is."""
    return env.game.whose_action if is_coup else env.get_player_id()


def _is_game_over(env, is_coup):
    """Return whether the current game has finished."""
    return env.game.game_over if is_coup else env.is_over()


def _legal_actions(env, is_coup, player=None):
    """Return the legal actions; for card games, those of `player` (default: player to move)."""
    if is_coup:
        return env.get_valid_actions(text=True)
    if player is None:
        player = env.get_player_id()
    return env.get_state(player)['raw_legal_actions']


def _observation(env, is_coup, player):
    """Return the observation of the game from `player`'s point of view."""
    if is_coup:
        return env.get_obs(text=True, p2_view=(player == 1))
    return env.get_state(player)['raw_obs']


def _apply_action(env, is_coup, act):
    """Apply `act` to the environment using the call convention of the selected game."""
    if is_coup:
        env.step(act)
    else:
        env.step(act, raw_action=True)


def _format_game_state(obs):
    """Render an observation for the status pane, hiding the 'legal_actions' entry of dicts."""
    if isinstance(obs, dict):
        return "".join(f"{key}: {value}\n" for key, value in obs.items() if key != 'legal_actions')
    # Non-dict observations (the setup path shows the Coup observation as-is) are displayed verbatim.
    return f"{obs}"


def _agent_move(env, ctx, is_coup, idx, chips, agents_num, game_idx, mode,
                bot_short_memory, bot_long_memory):
    """Let the agent at `idx` choose an action, apply it, and return (act, comm, short_mem, long_mem)."""
    agent = ctx.robot_agents[idx]
    agent_obs = _observation(env, is_coup, idx)
    if not is_coup:
        agent_obs['game_num'] = game_idx + 1
        agent_obs['rest_chips'] = chips[idx]
        agent_obs['opponent_rest_chips'] = chips[(idx + 1) % agents_num]
    legal = _legal_actions(env, is_coup)
    opponent_name = ctx.robot_agents[(idx + 1) % agents_num].name
    act, comm, bot_short_memory, bot_long_memory = agent.make_act(
        agent_obs, opponent_name, idx, legal, verbose_print=False,
        game_idx=game_idx, round=0, bot_short_memory=bot_short_memory,
        bot_long_memory=bot_long_memory, console=Console(),
        log_file_name=None, mode=mode)
    _apply_action(env, is_coup, act)
    return act, comm, bot_short_memory, bot_long_memory


# Inference function
def predict(openai_gpt4_key, game_selection, action, inputs, top_p, temperature, chat_counter,
            dialogue_chatbot=None, system_chatbot=None, history=None):
    """Handle one UI event: start a game on the first call, otherwise play one turn.

    On the first call for a session (``history['env']`` is None) the OpenAI key
    is verified, the model settings and agents are initialised, the game
    environment is created and, if the agent moves first, its opening move is
    played. On later calls the user's ``action`` is validated and applied, the
    agent replies if the game is not over, and end-of-game results are reported.

    Args:
        openai_gpt4_key: OpenAI API key used for the LLM and the embeddings.
        game_selection: 'coup', 'leduc-holdem' or 'limit-holdem'.
        action: Action chosen by the user (None when nothing is selected).
        inputs: Optional message from the user to the agent.
        top_p, temperature: Accepted from the UI but not used here.
        chat_counter: Passed through unchanged on early returns.
        dialogue_chatbot: List of (user, agent) message tuples; extended in place.
        system_chatbot: List of (user, agent) game-status tuples; extended in place.
        history: Per-session state dict (see ``_new_game_state`` for the keys).

    Returns:
        Tuple ``(system_chatbot, dialogue_chatbot, dropdown_update, history,
        chat_counter, status_message)``; after a played turn the counter is 1.
    """
    # None defaults avoid sharing one mutable list/dict between unrelated calls.
    if dialogue_chatbot is None:
        dialogue_chatbot = []
    if system_chatbot is None:
        system_chatbot = []
    if history is None:
        history = _new_game_state()

    verified, env, ctx = history['verified'], history['env'], history['ctx']
    bot_long_memory, bot_short_memory = history['bot_long_memory'], history['bot_short_memory']
    agents_num, chips, user_index, game_idx, mode = (
        history['agents_num'], history['chips'], history['user_index'],
        history['game_idx'], history['mode'])
    is_coup = game_selection == 'coup'
    status_message = ''
    valid_actions = gr.Dropdown.update(choices=[], value=None)

    if env is None:
        # ---- First call: verify credentials and set the game up. ----
        if not verified:
            res = util.verify_openai_token(openai_gpt4_key)
            if res != "OK":
                return system_chatbot, dialogue_chatbot, valid_actions, history, chat_counter, res
            history['verified'] = True
        if game_selection == '' or game_selection is None:
            return system_chatbot, dialogue_chatbot, valid_actions, history, chat_counter, "Please select a game."
        game_config_file = _GAME_CONFIG_FILES.get(game_selection)
        if game_config_file is None:
            # Previously an unknown id fell through and failed on an unbound variable.
            return (system_chatbot, dialogue_chatbot, valid_actions, history, chat_counter,
                    f"Unsupported game: {game_selection}")

        settings = Settings()
        settings.model = load_model_setting("openai-gpt-4-0613")
        settings.model.llm.openai_api_key = openai_gpt4_key
        settings.model.embedding.openai_api_key = openai_gpt4_key
        res = util.verify_model_initialization(settings)
        if res != "OK":
            return system_chatbot, dialogue_chatbot, valid_actions, history, chat_counter, res

        # Read agent and game configs.
        agent_configs = []
        for agent_file in ('person_config/Persuader.json', 'person_config/GoodGuy.json'):
            agent_config = util.load_json(Path(agent_file))
            agent_config["path"] = agent_file
            agent_configs.append(agent_config)
        game_config = util.load_json(Path(game_config_file))
        game_config["path"] = game_config_file

        user_index = 1
        console = Console()
        ctx = agi_init(agent_configs, game_config, console, settings, user_index)
        os.environ["OPENAI_API_KEY"] = openai_gpt4_key
        print(game_selection)
        env = gym.make('coup-v0') if is_coup else rlcard.make(game_selection)
        env.reset()
        history['env'] = env
        history['ctx'] = ctx
        for _ in range(agents_num):
            bot_short_memory.append([f'{game_idx+1}th Game Start'])
            bot_long_memory.append([f'{game_idx+1}th Game Start'])
        status_message = 'Verified.'

        # If the agent is first to act, play its opening move before showing the board.
        idx = _current_player(env, is_coup)
        if idx != user_index:
            act, comm, bot_short_memory, bot_long_memory = _agent_move(
                env, ctx, is_coup, idx, chips, agents_num, game_idx, mode,
                bot_short_memory, bot_long_memory)
            if is_coup:
                # NOTE(review): nesting reconstructed from a flattened file; this
                # system-info handling is assumed to be Coup-only — confirm.
                win_message = env.game.call_system_info()
                if win_message is not None:
                    print(win_message)
                    win_message = win_message.replace('Player 0', ctx.robot_agents[0].name)
                    win_message = win_message.replace('Player 1', ctx.robot_agents[1].name)
                    # NOTE(review): replaces every capital 'I', including inside words — confirm intent.
                    win_message = win_message.replace('I', ctx.robot_agents[idx].name)
                    win_message = win_message.replace('the opponent', ctx.robot_agents[(idx + 1) % agents_num].name)
                    bot_short_memory.append(win_message)
                    bot_long_memory.append(win_message)
            dialogue_chatbot.append((None, comm))
            system_chatbot.append((None, f'Suspicion-Agent action: {act}'))

        # Show the user's view of the game and the actions available to them.
        user_obs = _observation(env, is_coup, user_index)
        if is_coup:
            system_chatbot.append((f'Game state:\n{user_obs}', None))
        else:
            user_obs['rest_chips'] = chips[user_index]
            user_obs['opponent_rest_chips'] = chips[(user_index + 1) % agents_num]
            system_chatbot.append((f'Game state:\n{_format_game_state(user_obs)}', None))
        valid_actions = gr.Dropdown.update(
            choices=_legal_actions(env, is_coup, user_index), value=None)
        return system_chatbot, dialogue_chatbot, valid_actions, history, chat_counter, status_message

    # ---- Later calls: play one turn. ----
    if _is_game_over(env, is_coup):
        status_message = "Game ended."
        return system_chatbot, dialogue_chatbot, valid_actions, history, chat_counter, status_message
    if action is None:
        status_message = "No action received."
        return system_chatbot, dialogue_chatbot, valid_actions, history, chat_counter, status_message
    legal_now = _legal_actions(env, is_coup)
    if action not in legal_now:
        status_message = "Not a valid action. Please enter a valid action."
        return system_chatbot, dialogue_chatbot, valid_actions, history, chat_counter, status_message

    # Record the user's move in the agents' memories (the message may be empty), then apply it.
    user_name = ctx.robot_agents[user_index].name
    obs_now = _observation(env, is_coup, _current_player(env, is_coup))
    bot_short_memory[(user_index + 1) % agents_num].append(
        f"The valid action list of {user_name} is {legal_now}, and he tries to take action: {action}. He said, {inputs}")
    bot_long_memory[user_index % agents_num].append(
        f"{user_name} have the observation: {obs_now}, and try to take action: {action}.")
    _apply_action(env, is_coup, action)

    # Agent reply, unless the user's move already ended the game.
    comm = None
    agent_act = None
    if not _is_game_over(env, is_coup):
        idx = _current_player(env, is_coup)
        agent_act, comm, bot_short_memory, bot_long_memory = _agent_move(
            env, ctx, is_coup, idx, chips, agents_num, game_idx, mode,
            bot_short_memory, bot_long_memory)

    # Refresh the user's view.
    user_obs = _observation(env, is_coup, user_index)
    if not is_coup:
        user_obs['rest_chips'] = chips[user_index]
        user_obs['opponent_rest_chips'] = chips[(user_index + 1) % agents_num]
    next_actions = _legal_actions(env, is_coup, user_index)
    game_state_string = _format_game_state(user_obs)

    dialogue_chatbot.append((inputs if inputs != "" else None, comm))
    # Only attribute an action to the agent when it actually moved this turn.
    agent_line = None if agent_act is None else f'Suspicion-Agent action: {agent_act}'
    system_chatbot.append((f'My action: {action}', agent_line))
    system_chatbot.append((f'Game state:\n{game_state_string}', None))

    if _is_game_over(env, is_coup):
        bot_idx = (user_index + 1) % agents_num
        if is_coup:
            bot_hand = _observation(env, is_coup, bot_idx)
            # NOTE(review): assumes call_system_info() may be queried once the game
            # has ended, as it is right after env.step in the opening move — confirm.
            win_message = env.game.call_system_info() or ''
        else:
            # Payoffs are doubled before being added to each player's chip count.
            pay_offs = [payoff * 2 for payoff in env.get_payoffs()]
            for player, gain in enumerate(pay_offs):
                chips[player] += gain
            if pay_offs[user_index] > 0:
                win_message = f'You win {pay_offs[user_index]} chips, Suspicion-Agent lose {pay_offs[user_index]} chips'
            else:
                win_message = f'Suspicion-Agent win {pay_offs[bot_idx]} chips, you lose {pay_offs[bot_idx]} chips'
            bot_hand = env.get_state(bot_idx)['raw_obs']['hand']
        system_chatbot.append((None, f'Suspicion-Agent hand: {bot_hand}'))
        system_chatbot.append((f'Gameover.\n {win_message}', None))
        next_actions = []

    status_message += " Message received."
    valid_actions = gr.Dropdown.update(choices=next_actions, value=None)
    return system_chatbot, dialogue_chatbot, valid_actions, history, 1, status_message
# Resetting to blank
def reset_textbox():
    """Return a Gradio update that sets the target component's value to an empty string."""
    cleared = gr.update(value='')
    return cleared
# Hide a component (visible=False)
def set_visible_false():
    """Return a Gradio update that hides the target component."""
    hidden = gr.update(visible=False)
    return hidden
# Show a component (visible=True)
def set_visible_true():
    """Return a Gradio update that makes the target component visible."""
    shown = gr.update(visible=True)
    return shown
def update_instruction(game_selection):
    """Return the rules text shown in the instruction panel for the selected game.

    Args:
        game_selection: 'coup', 'leduc-holdem' or 'limit-holdem'; may be None
            or '' when nothing is selected.

    Returns:
        A string combining the config's 'game_rule' and 'observation_rule'
        entries, or None when no game is selected or the id is not recognised.
    """
    # Config file holding the rules text for each supported game id.
    config_paths = {
        'coup': './game_config/coup.json',
        'leduc-holdem': './game_config/leduc_limit.json',
        'limit-holdem': './game_config/limit_holdem.json',
    }
    if game_selection is None or game_selection == '':
        return None
    config_path = config_paths.get(game_selection)
    if config_path is None:
        # An unknown id used to fall through to an unbound `contents` variable.
        return None
    with open(config_path, encoding='utf-8') as file:
        contents = json.load(file)
    return f"Game rule: {contents['game_rule']}\n\n\nObservation Rule: {contents['observation_rule']}"
# update valid actions list
def set_valid_actions(game_selection=None, env=None):
    """Return a Gradio update whose value is the current legal-action list as text.

    The original body read `game_selection` and `env` as globals and printed the
    legal actions before checking whether `env` was None. No module-level `env`
    is visible in this file, so both are now optional parameters and the None
    check comes first. Calling with no arguments still works and yields an
    empty value.

    Args:
        game_selection: Selected game id; 'coup' uses the Coup env API, any
            other value uses the card-game env API.
        env: The active game environment, or None when no game has started.

    Returns:
        ``gr.update(value='')`` when there is no environment, otherwise an
        update whose value is the string form of the legal actions.
    """
    if env is None:
        return gr.update(value='')
    if game_selection != 'coup':
        valid_actions_list = env.get_state(env.get_player_id())['raw_legal_actions']
    else:
        valid_actions_list = env.get_valid_actions(text=True)
    print(valid_actions_list)
    return gr.update(value=f'{valid_actions_list}')
# Page heading rendered at the top of the demo.
title = """<h1 align="center">Suspicion-Agent Demo</h1>"""
# Introductory HTML blurb shown under the heading (links to the paper and notes on agent behaviour).
theme_addon_msg = """<center>This is an official Demo for <b>Suspicion-Agent: Playing Imperfect Information Games with Theory of Mind Aware GPT4</b>. Check out our paper for more details <a href="https://arxiv.org/abs/2309.17277" target="_blank">here</a>! Some Notes: In the initial games in the demo, Suspicion Agent typically exhibits aggressive play. This is because of the nature of GPT-4 and it is also observed in <a href="https://arxiv.org/abs/2308.12466" target="_blank">here</a>. </center>
"""
# Help text about GPT-4 system messages; only referenced by the commented-out system-message textbox.
system_msg_info = """A conversation could begin with a system message to gently instruct the assistant.
System message helps set the behavior of the AI Assistant. For example, the assistant could be instructed with 'You are a helpful assistant.'"""
# Customise the built-in Gradio Soft theme (hues and text size).
theme = gr.themes.Soft(primary_hue="zinc", secondary_hue="blue", neutral_hue="blue",
                      text_size=gr.themes.sizes.text_lg)
# Build the demo page: header, key/game selection, two chat panes, action and
# message inputs, then wire the events and launch.
with gr.Blocks(css="""#col_container { margin-left: auto; margin-right: auto;} #chatbot {height: 520px; overflow: auto;}""",
               theme=theme) as demo:
    gr.HTML(title)
    # The heading is opened with <h3>, so it must be closed with </h3> (was </h1>).
    gr.HTML("""<h3 align="center">🔥This Huggingface Gradio Demo provides a variety of game matches against a GPT4 agent. Please note that you would be needing an OPENAI API key for GPT4 access</h3>""")
    gr.HTML(theme_addon_msg)
    # NOTE(review): the link text says "Github Codes" but the href is the arXiv
    # paper URL — confirm the intended repository link.
    gr.HTML("""<center><a href="https://arxiv.org/abs/2309.17277" target="_blank">Github Codes</a></center>""")
    gr.HTML('''<center><a href="https://huggingface.co/spaces/cr7-gjx/Suspicion-Agent-Demo?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>Duplicate the Space and run securely with your OpenAI API Key</center>''')

    # Initial per-session state handed to predict() through gr.State.
    initial_state = {
        'verified': False,
        'settings': None,
        'env': None,
        'ctx': None,
        'bot_long_memory': [],
        'bot_short_memory': [],
        'agents_num': 2,
        'chips': [50, 50],
        'user_index': 1,
        'game_idx': 0,
        'mode': 'first_tom',
    }

    with gr.Column(elem_id="col_container"):
        # Users need to provide their own GPT4 API key, it is no longer provided by Huggingface
        with gr.Row():
            with gr.Column():
                openai_gpt4_key = gr.Textbox(label="OpenAI GPT4 Key", value="", type="password", placeholder="sk..", info="You have to provide your own GPT4 keys for this app to function properly",)
            with gr.Column():
                game_selection = gr.Dropdown(
                    ["leduc-holdem", "limit-holdem", "coup"], label="Game Selections", info="Select the game to play from the dropdown"
                )
        with gr.Row():
            instruction_panel = gr.Textbox(label='Game Instructions')
        with gr.Row():
            with gr.Column():
                system_chatbot = gr.Chatbot(label='Game Status', elem_id="system_chatbot")
            with gr.Column():
                dialogue_chatbot = gr.Chatbot(label='Dialogue with GPT4', elem_id="dialogue_chatbot")
        action = gr.Dropdown(placeholder="", label="Select an action.", info="")
        inputs = gr.Textbox(placeholder="", label="Type a message for the opponent. Messages are optional.")
        state = gr.State(initial_state)
        with gr.Row():
            with gr.Column(scale=7):
                b1 = gr.Button().style(full_width=True)
            with gr.Column(scale=3):
                server_status_code = gr.Textbox(label="Status code from OpenAI server", )
        # top_p, temperature
        with gr.Accordion("Parameters", open=False):
            top_p = gr.Slider(minimum=0, maximum=1.0, value=1.0, step=0.05, interactive=True, label="Top-p (nucleus sampling)",)
            temperature = gr.Slider(minimum=0, maximum=5.0, value=1.0, step=0.1, interactive=True, label="Temperature",)
            chat_counter = gr.Number(value=0, visible=False, precision=0)

    # Event handling: pressing Enter in the message box and clicking the button
    # both run predict() with the same inputs and outputs.
    predict_inputs = [openai_gpt4_key, game_selection, action, inputs, top_p, temperature,
                      chat_counter, dialogue_chatbot, system_chatbot, state]
    predict_outputs = [system_chatbot, dialogue_chatbot, action, state, chat_counter, server_status_code]
    inputs.submit(predict, predict_inputs, predict_outputs,)
    b1.click(predict, predict_inputs, predict_outputs,)
    # Show the selected game's rules as soon as a game is picked.
    game_selection.select(update_instruction, [game_selection], [instruction_panel])
    # Clear the message box after each submission.
    b1.click(reset_textbox, [], [inputs])
    inputs.submit(reset_textbox, [], [inputs])

demo.queue(max_size=99, concurrency_count=20).launch(debug=True)