# Who_is_The_Spy / app.py
# sci-m-wang's picture
# Update app.py
# 6823e55 verified
import streamlit as st
from random import choices, randint
import os
os.system("pip install transformers")
os.system("pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu")
os.system("pip install einops")
os.system("pip install sentencepiece")
os.system("pip install openai")
import json
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from openai import OpenAI
# """
# 谁是卧底游戏
# """
# Short alias for the per-session Streamlit state used throughout the script.
state = st.session_state
# Page title.
st.title("谁是卧底😎")
# Load the candidate keyword pairs. Each entry is later indexed with
# "spy_word" and "civilian_word". The encoding is pinned to UTF-8 so the
# read does not depend on the platform's default text encoding.
with open("word_list.json", "r", encoding="utf-8") as f:
    word_list = json.load(f)
# Emoji avatars shown next to chat messages, keyed by speaker id:
# "host" is the moderator, "P1".."P10" are the AI seats, "H" is the human.
avatar_dict = {"host": "🐼"}
avatar_dict.update(
    (f"P{seat}", icon)
    for seat, icon in enumerate(
        ("🚀", "🚄", "🚁", "🚂", "🚢", "🚤", "🚙", "🚠", "🚲", "🚜"),
        start=1,
    )
)
avatar_dict["H"] = "🤹‍♂️"
# Initialise the per-session game data on the first run only.
## Global message stack: every entry is rendered in the chat window.
if "messages" not in state:
    state.messages = []
## Player list.
if "players" not in state:
    state.players = []
## System prompt template for the AI players; the keyword is filled in later
## with str.format. UTF-8 is requested explicitly so the read does not depend
## on the platform's default text encoding.
if "prompt" not in state:
    with open("prompt.txt", "r", encoding="utf-8") as f:
        state.prompt = f.read()
# Build the OpenAI client once per session and remember which model to query.
if "client" not in state:
    # The API key and endpoint URL are taken from the environment.
    state.client = OpenAI(
        base_url=os.environ.get("BASE_URL"),
        api_key=os.environ.get("OPENAI_API_KEY"),
    )
    state.model_name = "internlm/internlm2_5-20b-chat"
    # Disabled alternative: load a model locally through transformers.
    # model_path = "internlm/internlm2_5-7b-chat"
    # state.model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, trust_remote_code=True)
    # state.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    # state.model.eval()
def response(messages):
    """Return the chat model's reply text for ``messages``.

    Args:
        messages: Conversation to send, as a list of
            ``{"role": ..., "content": ...}`` dicts.

    Returns:
        The text of the first completion choice, or ``""`` when the API
        returns no text content.
    """
    completion = state.client.chat.completions.create(
        model=state.model_name,
        messages=messages,
        temperature=0.5,
    )
    # ``content`` can be None; callers concatenate the result with other
    # strings, so normalise a missing reply to an empty string.
    return completion.choices[0].message.content or ""
    # Disabled alternative: generate with the locally loaded model.
    # prompt = state.tokenizer.apply_chat_template(
    # messages,
    # tokenize=False,
    # add_generation_prompt=True
    # )
    # response,history = state.model.chat(state.tokenizer,prompt,history=[])
    # return response
# Round bookkeeping.
## Seed the round limit and the current round counter on the first run.
for _key, _default in (("max_round", 100), ("round", 0)):
    if _key not in state:
        state[_key] = _default
## Sidebar: game settings (keyword pair, number of players, number of rounds).
with st.sidebar:
    st.write("游戏设置")
    with st.form(key="game_setting"):
        # Draw the keyword pair only once per session so reruns keep it.
        if "words" not in state:
            words_num = randint(0, len(word_list) - 1)
            state.words = word_list[words_num]
        total_num = st.number_input("总人数", 5, 10, 5)
        spy_num = st.number_input("卧底人数", 1, total_num // 2, 1)
        max_round = st.number_input("最大回合数", 5, 10, 10)
        submitted = st.form_submit_button("保存设置")
        ## On submit: store the settings and deal the roles.
        if submitted:
            # Local import: the spies must be drawn without replacement.
            from random import sample

            state.spy_word = state.words["spy_word"]  # keyword given to spies
            state.civilian_word = state.words["civilian_word"]  # keyword given to civilians
            state.total_num = total_num  # total number of players
            state.spy_num = spy_num  # number of spies
            state.max_round = max_round  # maximum number of rounds

            # The human is a spy or a civilian with equal probability and is
            # shown the matching keyword.
            human_is_spy = randint(0, 1) == 1
            if human_is_spy:
                st.write("你的关键词是{}".format(state.spy_word))
            else:
                st.write("你的关键词是{}".format(state.civilian_word))

            # The human occupies one seat; the rest are AI players P1..P(n-1).
            ai_ids = [f"P{seat}" for seat in range(1, total_num)]
            # A human spy already fills one spy slot. sample() picks distinct
            # ids, so exactly the configured number of spies is dealt.
            ai_spy_count = spy_num - 1 if human_is_spy else spy_num
            spy_ids = set(sample(ai_ids, ai_spy_count))

            # Build the list in one go so the human's role is never
            # overwritten by the AI role assignment.
            state.players = [
                {"id": "H", "dignity": "spy" if human_is_spy else "civilian"}
            ] + [
                {"id": pid, "dignity": "spy" if pid in spy_ids else "civilian"}
                for pid in ai_ids
            ]
# Message window: replay every stored message inside a fixed-height container.
with st.container(height=300):
    for entry in state.messages:
        speaker = entry["id"]
        # Each message is drawn with the speaker's emoji avatar, then the
        # speaker id, then the message body.
        with st.chat_message(speaker, avatar=avatar_dict[speaker]):
            st.text(speaker)
            st.write(entry["message"])
# Main game section: only active while rounds remain.
if state.round < state.max_round:
    # History of descriptions given so far, one "<id>:<text>" entry each.
    if "description" not in state:
        state.description = []

    ## Round start.
    start = st.button("开始第{}轮游戏".format(state.round+1))
    if start:
        state.messages.append({"id":"host", "message":f"第{state.round+1}轮游戏开始"})
        ## Description phase: every AI player describes its keyword in turn.
        for player in state.players:
            # The human describes through the chat input at the bottom.
            if player["id"] == "H":
                continue
            if player["dignity"] == "spy":
                word = state.spy_word
                instruction = "/describe "+"请根据描述历史记录描述你的关键词,需要注意不能与已有的描述重复,下面是描述历史记录:\n"
            else:
                word = state.civilian_word
                instruction = "/describe "+"请根据描述历史记录描述你的关键词,下面是描述历史记录:\n"
            # The history is re-joined per player because it grows as each
            # AI player adds its own description.
            text = response([
                {"role":"system", "content":state.prompt.format(word)},
                {"role":"system", "content":instruction+"\n".join(state.description)},
            ])
            state.messages.append({"id":player["id"], "message": text})
            state.description.append(player["id"] + ":" + text)
        st.rerun()

    ## Voting phase: the AI players state their votes; the human then tallies
    ## them and picks the player to eliminate.
    col1, col2 = st.columns([8,2])
    with col1:
        vote_id = st.selectbox("投票对象", [a["id"] for a in state.players])
    with col2:
        if st.button("开始投票"):
            # Neither the history nor the roster changes while votes are cast.
            history = "\n".join(state.description)
            alive_ids = ",".join(a["id"] for a in state.players)
            for player in state.players:
                if player["id"] == "H":
                    continue
                # Each AI player votes knowing its own keyword, matching the
                # description phase, rather than always the spy keyword.
                if player["dignity"] == "spy":
                    word = state.spy_word
                else:
                    word = state.civilian_word
                text = response([
                    {"role":"system", "content":state.prompt.format(word)},
                    {"role":"system", "content":"/vote "+"请根据描述历史记录选择要投出的玩家,下面是描述历史记录:\n"+history+"\n"+"当前场上存活玩家id为:\n"+alive_ids},
                ])
                state.messages.append({"id":player["id"], "message": text})
            st.rerun()

    if st.button("投出玩家"):
        state.messages.append({"id":"host", "message":f"玩家{vote_id}被投票出局"})
        state.round += 1
        # Remove the eliminated player from the roster.
        state.players = [p for p in state.players if p["id"] != vote_id]

        ## Check the win conditions on the remaining players.
        spies_alive = sum(p["dignity"] == "spy" for p in state.players)
        if spies_alive == 0:
            # No spy left: civilians win, provided anyone is still in play.
            if state.players:
                state.messages.append({"id":"host", "message":"平民胜利!"})
        # NOTE(review): with floor division 2 spies among 5 players already
        # satisfy this test although that is less than half; confirm the
        # threshold is intended.
        elif spies_alive >= len(state.players)//2:
            state.messages.append({"id":"host", "message":"卧底胜利!"})
        st.rerun()

    # The human may add a description only while still in the game.
    human_live = any(p["id"] == "H" for p in state.players)
    if human_live and (des := st.chat_input()):
        state.messages.append({"id":"H", "message":des})
        state.description.append("H:"+des)
        st.rerun()