Spaces:
Running
Running
import streamlit as st | |
from random import choices, randint | |
import os | |
# Install runtime dependencies only when they are missing. Streamlit re-executes
# this whole script on every user interaction, so unconditional `pip install`
# calls would launch several pip processes on every rerun.
import importlib.util
import subprocess
import sys

# (module to probe, pip arguments used to install it)
_PIP_REQUIREMENTS = [
    ("transformers", ["transformers"]),
    (
        "torch",
        [
            "torch",
            "torchvision",
            "torchaudio",
            "--index-url",
            "https://download.pytorch.org/whl/cpu",
        ],
    ),
    ("einops", ["einops"]),
    ("sentencepiece", ["sentencepiece"]),
    ("openai", ["openai"]),
]
_installed_any = False
for _module, _pip_args in _PIP_REQUIREMENTS:
    if importlib.util.find_spec(_module) is None:
        # Use the running interpreter's pip; failures are not fatal here
        # (same best-effort behavior as the previous os.system calls).
        subprocess.run([sys.executable, "-m", "pip", "install", *_pip_args], check=False)
        _installed_any = True
if _installed_any:
    # Make freshly installed packages importable within this same process.
    importlib.invalidate_caches()
import json | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
import torch | |
from openai import OpenAI | |
# Who Is the Spy (谁是卧底): a Streamlit party game in which one human plays
# together with LLM-driven AI players.
state = st.session_state
# Page title.
st.title("谁是卧底😎")
# Load the keyword pairs; each entry is later read through its "spy_word" and
# "civilian_word" keys. The game text is Chinese, so decode explicitly as UTF-8
# instead of relying on the platform default encoding (not UTF-8 on Windows).
with open("word_list.json", "r", encoding="utf-8") as f:
    word_list = json.load(f)
# Chat avatars keyed by speaker id: the host, the AI seats P1..P10 (one
# vehicle emoji each, in seat order) and the human player "H".
_VEHICLE_AVATARS = ["🚀", "🚄", "🚁", "🚂", "🚢", "🚤", "🚙", "🚠", "🚲", "🚜"]
avatar_dict = {"host": "🐼"}
avatar_dict.update(
    {f"P{seat}": emoji for seat, emoji in enumerate(_VEHICLE_AVATARS, start=1)}
)
avatar_dict["H"] = "🤹♂️"
# Per-session game state, initialised only on the first run of the script.
## Global message stack: every chat message shown in the message window.
if "messages" not in state:
    state.messages = []
## Player list: dicts with "id" and "dignity" ("civilian" / "spy").
if "players" not in state:
    state.players = []
## System prompt template for AI players; it is formatted with the player's
## keyword before each model call. Read as UTF-8 explicitly so the (Chinese)
## prompt decodes the same way on every platform.
if "prompt" not in state:
    with open("prompt.txt", "r", encoding="utf-8") as f:
        state.prompt = f.read()
# Build the chat-completions client once per session. The API key and the
# endpoint are taken from the OPENAI_API_KEY and BASE_URL environment variables.
if "client" not in state:
    _client_kwargs = dict(
        api_key=os.getenv("OPENAI_API_KEY"),
        base_url=os.getenv("BASE_URL"),
    )
    state.client = OpenAI(**_client_kwargs)
    state.model_name = "internlm/internlm2_5-20b-chat"
    # Disabled alternative backend: a local InternLM checkpoint via transformers.
    # model_path = "internlm/internlm2_5-7b-chat"
    # state.model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, trust_remote_code=True)
    # state.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    # state.model.eval()
def response(messages, temperature=0.5):
    """Send a chat history to the configured model and return the reply text.

    Args:
        messages: Chat messages in chat-completions format
            (a list of {"role": ..., "content": ...} dicts).
        temperature: Sampling temperature passed to the API (default 0.5).

    Returns:
        The content of the first choice, or "" when the API returns no text
        content (the field is optional in the API), so callers can safely
        concatenate the result with other strings.
    """
    completion = state.client.chat.completions.create(
        model=state.model_name,
        messages=messages,
        temperature=temperature,
    )
    return completion.choices[0].message.content or ""
    # Disabled alternative using a local transformers model:
    # prompt = state.tokenizer.apply_chat_template(
    #     messages,
    #     tokenize=False,
    #     add_generation_prompt=True
    # )
    # response,history = state.model.chat(state.tokenizer,prompt,history=[])
    # return response
# Round bookkeeping defaults: a generous round cap (replaced by the user's
# choice when the sidebar settings form is submitted) and the round counter.
for _key, _default in (("max_round", 100), ("round", 0)):
    if _key not in state:
        state[_key] = _default
# Sidebar: game settings (keyword pair, player count, spy count, max rounds).
with st.sidebar:
    st.write("游戏设置")
    with st.form(key="game_setting"):
        # Draw a keyword pair once per session.
        if "words" not in state:
            words_num = randint(0, len(word_list) - 1)
            state.words = word_list[words_num]
        total_num = st.number_input("总人数", 5, 10, 5)
        spy_num = st.number_input("卧底人数", 1, total_num // 2, 1)
        max_round = st.number_input("最大回合数", 5, 10, 10)
        submitted = st.form_submit_button("保存设置")
        # On submit, start a new game: save the settings, reset per-game state
        # and deal the roles.
        if submitted:
            from random import sample

            state.spy_word = state.words["spy_word"]  # spy keyword
            state.civilian_word = state.words["civilian_word"]  # civilian keyword
            state.total_num = total_num  # total number of players
            state.spy_num = spy_num  # number of spies
            state.max_round = max_round  # maximum number of rounds
            # Without this reset a new game would inherit the previous round
            # counter, chat log and description history (sent to the model).
            state.round = 0
            state.messages = []
            state.description = []
            # Seats: the human ("H") first, then AI players P1..P{total_num-1}.
            player_ids = ["H"] + [f"P{i}" for i in range(1, total_num)]
            # Deal exactly spy_num distinct spies uniformly over all seats.
            # `sample` never picks a seat twice (unlike `choices`), and drawing
            # "H" together with the AI seats keeps the total at spy_num whether
            # or not the human is a spy.
            spy_ids = set(sample(player_ids, spy_num))
            state.players = [
                {"id": pid, "dignity": "spy" if pid in spy_ids else "civilian"}
                for pid in player_ids
            ]
            # Remember the human's keyword so it stays visible on later reruns.
            state.human_word = state.spy_word if "H" in spy_ids else state.civilian_word
    # Rendered on every rerun, not only right after submitting, so the human
    # keeps seeing their keyword while the game is in progress.
    if "human_word" in state:
        st.write("你的关键词是{}".format(state.human_word))
# Message window: replay the whole chat log inside a fixed-height container,
# each entry labelled with its speaker id and that speaker's avatar.
with st.container(height=300):
    for entry in state.messages:
        speaker = entry["id"]
        with st.chat_message(speaker, avatar=avatar_dict[speaker]):
            st.text(speaker)
            st.write(entry["message"])
# Main game flow. Shown only while rounds remain and after a game has been
# dealt (players exist once the sidebar settings form is submitted); a win
# below ends the game by exhausting the remaining rounds.
if state.round < state.max_round and state.players:
    if "description" not in state:
        state.description = []

    ## Round control: each click runs one description round for the AI players.
    if st.button("开始第{}轮游戏".format(state.round+1)):
        state.messages.append({"id":"host", "message":f"第{state.round+1}轮游戏开始"})
        for player in state.players:
            # The human describes through the chat input at the bottom.
            if player["id"] == "H":
                continue
            # Rebuilt per player so later speakers also see the descriptions
            # given earlier in this round.
            history = "\n".join(state.description)
            if player["dignity"] == "spy":
                # Spies are additionally told not to repeat earlier descriptions.
                prompt_messages = [
                    {"role":"system", "content":state.prompt.format(state.spy_word)},
                    {"role":"system", "content":"/describe "+"请根据描述历史记录描述你的关键词,需要注意不能与已有的描述重复,下面是描述历史记录:\n"+history},
                ]
            else:
                prompt_messages = [
                    {"role":"system", "content":state.prompt.format(state.civilian_word)},
                    {"role":"system", "content":"/describe "+"请根据描述历史记录描述你的关键词,下面是描述历史记录:\n"+history},
                ]
            text = response(prompt_messages)
            state.messages.append({"id":player["id"], "message": text})
            state.description.append(player["id"] + ":" + text)
        st.rerun()

    ## Voting: every AI player suggests whom to vote out; the human reads the
    ## suggestions and picks the final target in the select box.
    col1, col2 = st.columns([8,2])
    with col1:
        vote_id = st.selectbox("投票对象", [a["id"] for a in state.players])
    with col2:
        if st.button("开始投票"):
            # Neither list changes during the voting loop, so build them once.
            history = "\n".join(state.description)
            alive_ids = ",".join([a["id"] for a in state.players])
            for player in state.players:
                if player["id"] == "H":
                    continue
                # Each AI votes from its own perspective, i.e. with its own
                # keyword (spy word for spies, civilian word for civilians).
                own_word = state.spy_word if player["dignity"] == "spy" else state.civilian_word
                text = response([
                    {"role":"system", "content":state.prompt.format(own_word)},
                    {"role":"system", "content":"/vote "+"请根据描述历史记录选择要投出的玩家,下面是描述历史记录:\n"+history+"\n"+"当前场上存活玩家id为:\n"+alive_ids},
                ])
                state.messages.append({"id":player["id"], "message": text})
            st.rerun()

    ## Vote out the selected player, advance the round and check for a winner.
    if st.button("投出玩家"):
        state.messages.append({"id":"host", "message":f"玩家{vote_id}被投票出局"})
        state.round += 1
        state.players = [p for p in state.players if p["id"] != vote_id]
        spy_count = sum(1 for p in state.players if p["dignity"] == "spy")
        if spy_count == 0 and state.players:
            # No spy left: civilians win and the game ends.
            state.messages.append({"id":"host", "message":"平民胜利!"})
            state.round = state.max_round
        elif spy_count and spy_count >= len(state.players)//2:
            # Spies win once they number at least len(players)//2.
            state.messages.append({"id":"host", "message":"卧底胜利!"})
            state.round = state.max_round
        st.rerun()

    ## While the human is still in the game, they describe via the chat box.
    if any(p["id"] == "H" for p in state.players):
        if des := st.chat_input():
            state.messages.append({"id":"H", "message":des})
            state.description.append("H:"+des)
            st.rerun()