Upload folder using huggingface_hub
Browse files- .gitattributes +1 -0
- .gitignore +2 -3
- agents/Gemma-2-9b-it.ipynb +1 -0
- agents/__init__.py +5 -0
- agents/_reference.py +216 -0
- agents/chatgpt.py +145 -0
- agents/dsr1_distill.py +138 -0
- agents/gemma_2_9b_it.py +104 -0
- agents/llama3.py +102 -0
- agents/qwen2_5_7b_instruct.py +112 -0
- agents/qwen2_5_math.py +137 -0
- agents/runner.py +89 -0
- play_gradio.py +2 -2
- play_helper.py +69 -27
- play_with_auth.py +1 -1
- play_with_hf.py +132 -0
- problemsets/Anagram Scribble_1.json +0 -0
- problemsets/Anagram Scribble_2.json +0 -0
- problemsets/Anagram Scribble_3.json +0 -0
- problemsets/Bracket Game_1.json +0 -0
- problemsets/Bracket Game_2.json +0 -0
- problemsets/Bracket Game_3.json +0 -0
- problemsets/Crossword Arranger_1.json +0 -0
- problemsets/Crossword Arranger_2.json +0 -0
- problemsets/Crossword Arranger_3.json +0 -0
- reval_ana3.py +87 -0
- reval_bracket_all.py +94 -0
- reval_bracket_rerun.py +46 -0
- reval_crosswords_all.py +94 -0
- reval_sudoku_all.py +94 -0
- textgames-scrabble-black2-ss.png +0 -0
- textgames/__init__.py +10 -7
- textgames/anagram_scribble/anagram_scribble.py +40 -8
- textgames/bracket_game/bracket_game.py +97 -38
- textgames/crossword_arranger/crossword_arranger.py +30 -2
- textgames/islands/islands.py +15 -3
- textgames/ordering_text/ordering_text.py +42 -12
- textgames/password_game/password_game.py +10 -0
- textgames/string_search/string_search.py +13 -1
- textgames/sudoku/sudoku.py +32 -8
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
textgames-scrabble-black2-ss.png filter=lfs diff=lfs merge=lfs -text
|
.gitignore
CHANGED
|
@@ -1,10 +1,9 @@
|
|
| 1 |
*/*.DS_Store
|
| 2 |
.DS_Store
|
| 3 |
|
| 4 |
-
|
| 5 |
-
problemsets_*
|
| 6 |
-
|
| 7 |
user_outputs/
|
|
|
|
| 8 |
|
| 9 |
.idea/
|
| 10 |
|
|
|
|
| 1 |
*/*.DS_Store
|
| 2 |
.DS_Store
|
| 3 |
|
| 4 |
+
agents/*.sh
|
|
|
|
|
|
|
| 5 |
user_outputs/
|
| 6 |
+
model_outputs/__runs__
|
| 7 |
|
| 8 |
.idea/
|
| 9 |
|
agents/Gemma-2-9b-it.ipynb
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":1,"metadata":{"id":"Rli_enT6lBDT","executionInfo":{"status":"ok","timestamp":1737395007014,"user_tz":-540,"elapsed":5212,"user":{"displayName":"Frederikus Hudi","userId":"06160664103998835801"}}},"outputs":[],"source":["##%%\n","import os\n","import pickle\n","import json\n","# import random\n","# import torch\n","# import numpy as np\n","# import argparse\n","# import cohere\n","# from openai import OpenAI\n"]},{"cell_type":"code","source":["##%%\n","# import hashlib\n","from tqdm import tqdm\n","from itertools import product\n","# from collections import Counter\n","\n","# from transformers import LlamaForCausalLM, AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM\n","from transformers import AutoTokenizer, AutoModelForCausalLM\n","from textgames import GAME_NAMES, LEVEL_IDS, game_filename, _game_class_from_name\n"],"metadata":{"id":"dp1F32B8oSfD","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1737395010583,"user_tz":-540,"elapsed":3547,"user":{"displayName":"Frederikus Hudi","userId":"06160664103998835801"}},"outputId":"e9adeb5f-70eb-4ca9-dcbb-428e4b28ab41"},"execution_count":2,"outputs":[{"output_type":"stream","name":"stderr","text":["/home/is/frederikus-h/miniconda3/envs/textgame/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n"," from .autonotebook import tqdm as notebook_tqdm\n"]}]},{"cell_type":"code","source":["os.environ.setdefault(\"TEXTGAMES_OUTPUT_DIR\", \"user_outputs\")"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"2wEu1V1wvxn0","executionInfo":{"status":"ok","timestamp":1737395010664,"user_tz":-540,"elapsed":67,"user":{"displayName":"Frederikus Hudi","userId":"06160664103998835801"}},"outputId":"cdcad20f-e357-4009-9f4f-0d4495ebd894"},"execution_count":3,"outputs":[{"output_type":"execute_result","data":{"text/plain":["'user_outputs'"]},"metadata":{},"execution_count":3}]},{"cell_type":"code","source":["##%%\n","gen_model_checkpoint = \"google/gemma-2-9b-it\"\n","quantize = True"],"metadata":{"id":"jZF8bkUcojTX","executionInfo":{"status":"ok","timestamp":1737395010678,"user_tz":-540,"elapsed":13,"user":{"displayName":"Frederikus Hudi","userId":"06160664103998835801"}}},"execution_count":4,"outputs":[]},{"cell_type":"code","source":["kwargs = {\n"," \"device_map\": \"auto\",\n","} if quantize else {}"],"metadata":{"id":"VAF5sR9arYzS","executionInfo":{"status":"ok","timestamp":1737395010683,"user_tz":-540,"elapsed":2,"user":{"displayName":"Frederikus Hudi","userId":"06160664103998835801"}}},"execution_count":5,"outputs":[]},{"cell_type":"code","source":["##%%\n","gen_model = AutoModelForCausalLM.from_pretrained(gen_model_checkpoint, **kwargs)\n","tokenizer = AutoTokenizer.from_pretrained(gen_model_checkpoint, **kwargs)"],"metadata":{"id":"tzqldl8ooRVL","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1737395038547,"user_tz":-540,"elapsed":27859,"user":{"displayName":"Frederikus Hudi","userId":"06160664103998835801"}},"outputId":"902b638c-e6ce-4f8a-bba2-e9f7241c9a27"},"execution_count":6,"outputs":[{"output_type":"stream","name":"stderr","text":["Loading checkpoint shards: 
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:24<00:00, 6.19s/it]\n"]}]},{"cell_type":"code","source":["gen_model.device"],"metadata":{"id":"FeBUXdkWsWrL","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1737395038552,"user_tz":-540,"elapsed":3,"user":{"displayName":"Frederikus Hudi","userId":"06160664103998835801"}},"outputId":"6437d1b7-02f8-47f5-d519-e979cefde795"},"execution_count":7,"outputs":[{"output_type":"execute_result","data":{"text/plain":["device(type='cuda', index=0)"]},"metadata":{},"execution_count":7}]},{"cell_type":"code","source":["def get_gemma_response(text):\n"," # global gen_model, tokenizer\n"," messages = [\n"," {\"role\": \"user\", \"content\": text},\n"," ]\n","\n"," input_ids = tokenizer.apply_chat_template(\n"," messages,\n"," add_generation_prompt=True,\n"," return_tensors=\"pt\"\n"," ).to(gen_model.device)\n","\n"," terminators = [\n"," tokenizer.eos_token_id,\n"," tokenizer.convert_tokens_to_ids(\"<|eot_id|>\")\n"," ]\n","\n"," outputs = gen_model.generate(\n"," input_ids,\n"," max_new_tokens=100,\n"," eos_token_id=terminators,\n"," do_sample=True,\n"," temperature=.001,\n"," top_p=1,\n"," )\n","\n"," response = outputs[0][input_ids.shape[-1]:]\n"," return tokenizer.decode(response, skip_special_tokens=True)"],"metadata":{"id":"R5D4K-P2sPaj","executionInfo":{"status":"ok","timestamp":1737395038554,"user_tz":-540,"elapsed":1,"user":{"displayName":"Frederikus Hudi","userId":"06160664103998835801"}}},"execution_count":8,"outputs":[]},{"cell_type":"markdown","source":["---\n","Example Call"],"metadata":{"id":"s5FEwOOvxf4h"}},{"cell_type":"code","source":["# @title\n","text = \\\n","\"\"\"\n","Given a set of rules to calculate point, sort the set of words in decreasing order.\n","When there 2 or more words with same point, sort lexicographically.\n","\n","Rules:\n","- every pair of 
consecutive consonant gets 5 points\n","- every pair of consecutive vowel gets 3 points\n","- add 1 point if there exists exactly 1 'g' in the word\n","- word less than 5 characters gets 10 points\n","- word starts with gen gets 100 points\n","- word ends with ta gets -1000 point\n","\n","Words:\n","- genta\n","- winata\n","- hudi\n","- alham\n","- aji\n","- ruochen\n","\n","Print only the answer.\n","\"\"\"\n","\n","print(text)"],"metadata":{"id":"T_tk4hTGsxsR","colab":{"base_uri":"https://localhost:8080/"},"cellView":"form","executionInfo":{"status":"ok","timestamp":1737392776367,"user_tz":-540,"elapsed":27,"user":{"displayName":"Frederikus Hudi","userId":"06160664103998835801"}},"outputId":"d5ea884f-d0fa-4134-ecd9-690eab51c976"},"execution_count":14,"outputs":[{"output_type":"stream","name":"stdout","text":["\n","Given a set of rules to calculate point, sort the set of words in decreasing order.\n","When there 2 or more words with same point, sort lexicographically.\n","\n","Rules:\n","- every pair of consecutive consonant gets 5 points\n","- every pair of consecutive vowel gets 3 points\n","- add 1 point if there exists exactly 1 'g' in the word\n","- word less than 5 characters gets 10 points\n","- word starts with gen gets 100 points\n","- word ends with ta gets -1000 point\n","\n","Words:\n","- genta\n","- winata\n","- hudi\n","- alham\n","- aji\n","- ruochen\n","\n","Print only the answer.\n","\n"]}]},{"cell_type":"code","source":["# Gold Answer:\n","# - aji 10\n","# - hudi 10\n","# - ruochen 5 3\n","# - alham 5\n","# - genta 5 1 100 -1000\n","# - winata -1000"],"metadata":{"id":"G-5yS4S-rdsN"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["print(get_gemma_response(text))"],"metadata":{"id":"05OI36v6vGoY","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1737392724119,"user_tz":-540,"elapsed":6741,"user":{"displayName":"Frederikus 
Hudi","userId":"06160664103998835801"}},"outputId":"fe5d6ed2-d063-4f1c-b2e1-b3af8dbc456e"},"execution_count":9,"outputs":[{"output_type":"stream","name":"stdout","text":["genta\n","winata\n","ruochen\n","hudi\n","alham\n","aji \n","\n"]}]},{"cell_type":"markdown","source":["---\n","Automate run all sessions"],"metadata":{"id":"cxJ4WqHpxi75"}},{"cell_type":"code","source":["for game_name, difficulty_level in product([GAME_NAMES[4], *GAME_NAMES[:4], *GAME_NAMES[5:]], LEVEL_IDS[:3]):\n"," game_cls = _game_class_from_name(game_name)\n"," with open(f\"problemsets/{game_filename(game_name)}_{difficulty_level}.json\", \"r\", encoding=\"utf8\") as f:\n"," sid_prompt_dict = json.load(f)\n","\n"," correct_cnt = 0\n"," for sid, prompt in tqdm(list(sid_prompt_dict.items()), desc=f\"{game_filename(game_name)}_-_{difficulty_level}\"):\n"," cur_game = game_cls()\n"," cur_game.load_game(prompt)\n"," response = get_gemma_response(cur_game.get_prompt()).strip()\n"," solved, val_msg = cur_game.validate(response)\n"," with open(f\"model_outputs/results_gemma_2_9B_it.pkl\", \"ab\") as o:\n"," pickle.dump((f\"{game_filename(game_name)}_{difficulty_level}\", sid, response, (solved, val_msg)), o)\n"," if solved:\n"," correct_cnt += 1\n","\n"," print(f\"{game_name}_-_{difficulty_level}\")\n"," print(f\" Acc.: {correct_cnt / len(sid_prompt_dict):.2%}\")"],"metadata":{"id":"hCTXYpXa1UQ6"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"GC-zkVI52IJX"},"execution_count":null,"outputs":[]}]}
|
agents/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Define the __all__ variable
|
| 2 |
+
__all__ = ["run_with_agent"]
|
| 3 |
+
|
| 4 |
+
# Import the submodules
|
| 5 |
+
from .runner import run_with_agent
|
agents/_reference.py
ADDED
|
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import random
|
| 4 |
+
import numpy as np
|
| 5 |
+
import argparse
|
| 6 |
+
import json
|
| 7 |
+
import cohere
|
| 8 |
+
from openai import OpenAI
|
| 9 |
+
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
|
| 12 |
+
from collections import Counter
|
| 13 |
+
|
| 14 |
+
from transformers import LlamaForCausalLM, AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
|
| 15 |
+
import hashlib
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
OPENAI_TOKEN = ""
|
| 19 |
+
COHERE_TOKEN = ""
|
| 20 |
+
HF_TOKEN = ""
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def argmax(array):
    """Return the index of the maximum, breaking ties deterministically.

    When several positions share the maximum value, one of them is picked
    by hashing the raw bytes of the input, so the same input always yields
    the same index.

    Args:
        array: 1-D array-like of numbers (an ndarray or a plain sequence).

    Returns:
        The NumPy integer index of one maximal element.

    Raises:
        ValueError: If ``array`` is empty (raised by ``np.max``).
    """
    # Convert up front: with a plain list, ``array == np.max(array)`` is a
    # scalar ``False``, the selection below is empty and the modulo divides
    # by zero.
    values = np.asarray(array)
    max_indices = np.arange(len(values))[values == np.max(values)]
    digest = hashlib.sha256(values.tobytes()).hexdigest()
    return max_indices[int(digest, 16) % len(max_indices)]
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def logsumexp(x):
    """Compute ``log(sum(exp(x)))`` without overflowing.

    The maximum is subtracted before exponentiating so the largest term
    becomes ``exp(0)``.

    Args:
        x: Array-like of log-space values.

    Returns:
        The log of the summed exponentials as a NumPy scalar.

    Raises:
        ValueError: If ``x`` is empty (raised by ``max``).
    """
    values = np.asarray(x)
    c = values.max()
    if not np.isfinite(c):
        # With a non-finite maximum (all -inf, or a +inf / nan entry),
        # ``values - c`` would produce nan; the maximum itself is the result.
        return c
    return c + np.log(np.sum(np.exp(values - c)))
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def normalize(x):
    """Exponentiate log-space scores after subtracting their log-sum-exp.

    Args:
        x: Array-like of log-space scores.

    Returns:
        An ndarray ``exp(x - logsumexp(x))``.
    """
    log_scores = np.array(x)
    log_total = logsumexp(log_scores)
    return np.exp(log_scores - log_total)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) random generators.

    Args:
        seed: Integer seed handed unchanged to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for seeder in seeders:
        seeder(seed)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def get_commandr_chat_response(gen_model, gen_model_checkpoint, text, seed):
    """Send one chat message to the ``command-r`` model and return its text.

    Args:
        gen_model: Client object exposing ``chat(**kwargs)``.
        gen_model_checkpoint: Not used; the model name is fixed to
            ``"command-r"``.
        text: The user message.
        seed: Sampling seed forwarded to the API.

    Returns:
        The ``text`` attribute of the chat response.
    """
    request = {
        "model": "command-r",
        "message": text,
        "temperature": 0,
        "max_tokens": 64,
        "seed": seed,
        "p": 1,
    }
    reply = gen_model.chat(**request)
    return reply.text
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def get_mt0_response(gen_model, tokenizer, gen_model_checkpoint, text, seed):
    """Generate a short sampled completion for a raw-text prompt.

    The prompt is encoded directly, without a chat template, and the whole
    first output sequence is decoded.

    Args:
        gen_model: Model exposing ``device`` and ``generate``.
        tokenizer: Tokenizer exposing ``encode`` and ``decode``.
        gen_model_checkpoint: Not used by this function.
        text: Prompt text.
        seed: Not used by this function.

    Returns:
        The decoded first output sequence with special tokens skipped.
    """
    sampling = {
        "max_new_tokens": 10,
        "do_sample": True,
        "temperature": 0.2,
        "top_p": 1,
    }
    prompt_ids = tokenizer.encode(text, return_tensors="pt").to(gen_model.device)
    generated = gen_model.generate(prompt_ids, **sampling)
    return tokenizer.decode(generated[0], skip_special_tokens=True)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def get_gemma_response(gen_model, tokenizer, gen_model_checkpoint, text, seed):
    """Generate a short single-turn reply from a Gemma instruct model.

    Args:
        gen_model: Causal LM exposing ``device`` and ``generate``.
        tokenizer: Matching tokenizer with a chat template.
        gen_model_checkpoint: Not used by this function.
        text: The user prompt.
        seed: Not used by this function.

    Returns:
        The decoded continuation only; the prompt tokens are sliced off.
    """
    messages = [
        {"role": "user", "content": text},
    ]

    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(gen_model.device)

    # Gemma closes a turn with ``<end_of_turn>``. The previous ``<|eot_id|>``
    # is Llama 3's marker and is not a Gemma token, so generation never
    # stopped at the real end of the turn.
    terminators = [
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<end_of_turn>")
    ]

    outputs = gen_model.generate(
        input_ids,
        max_new_tokens=10,
        eos_token_id=terminators,
        do_sample=True,
        temperature=0.2,
        top_p=1
    )

    # Keep only the newly generated tokens.
    response = outputs[0][input_ids.shape[-1]:]
    return tokenizer.decode(response, skip_special_tokens=True)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def get_mistral_instruct_chat_response(gen_model, tokenizer, gen_model_checkpoint, text, seed):
    """Generate a short single-turn reply from a Mistral instruct model.

    Args:
        gen_model: Causal LM exposing ``device`` and ``generate``.
        tokenizer: Matching tokenizer with a chat template.
        gen_model_checkpoint: Not used by this function.
        text: The user prompt.
        seed: Not used by this function.

    Returns:
        The decoded continuation only; the prompt tokens are sliced off.
    """
    messages = [
        {"role": "user", "content": text},
    ]

    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(gen_model.device)

    # ``<|eot_id|>`` is a Llama 3 token. On a tokenizer that does not know it,
    # the lookup yields None or the unknown-token id; using that as a stop
    # token would end generation on ``<unk>``, so it is only added when the
    # vocabulary really contains it.
    terminators = [tokenizer.eos_token_id]
    eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
    if eot_id is not None and eot_id != getattr(tokenizer, "unk_token_id", None):
        terminators.append(eot_id)

    outputs = gen_model.generate(
        input_ids,
        max_new_tokens=10,
        eos_token_id=terminators,
        do_sample=True,
        temperature=0.2,
        top_p=1
    )

    # Keep only the newly generated tokens.
    response = outputs[0][input_ids.shape[-1]:]
    return tokenizer.decode(response, skip_special_tokens=True)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def get_llama3_instruct_chat_response(gen_model, tokenizer, gen_model_checkpoint, text, seed):
    """Generate a short single-turn reply from a Llama 3 instruct model.

    Args:
        gen_model: Causal LM exposing ``device`` and ``generate``.
        tokenizer: Matching tokenizer with a chat template.
        gen_model_checkpoint: Not used by this function.
        text: The user prompt.
        seed: Not used by this function.

    Returns:
        The decoded continuation only; the prompt tokens are sliced off.
    """
    conversation = [{"role": "user", "content": text}]
    prompt_ids = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        return_tensors="pt",
    )
    prompt_ids = prompt_ids.to(gen_model.device)

    # Stop on either the regular EOS token or the end-of-turn marker.
    stop_ids = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")]
    generation_kwargs = {
        "max_new_tokens": 10,
        "eos_token_id": stop_ids,
        "do_sample": True,
        "temperature": 0.2,
        "top_p": 1,
    }
    sequences = gen_model.generate(prompt_ids, **generation_kwargs)

    prompt_length = prompt_ids.shape[-1]
    return tokenizer.decode(sequences[0][prompt_length:], skip_special_tokens=True)
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def get_openai_chat_response(gen_model, gen_model_checkpoint, text, seed):
    """Request a single chat completion and return the reply content.

    Args:
        gen_model: Client exposing ``chat.completions.create``.
        gen_model_checkpoint: Model name sent in the request.
        text: The user message.
        seed: Sampling seed forwarded to the API.

    Returns:
        The message content of the first returned choice.
    """
    completion = gen_model.chat.completions.create(
        model=gen_model_checkpoint,
        messages=[{"role": "user", "content": text}],
        temperature=0,
        max_tokens=64,
        top_p=1,
        seed=seed,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def load_model(gen_model_checkpoint, load_in_8bit=False):
    """Instantiate the model (and tokenizer, if any) for a checkpoint name.

    The checkpoint name is matched by substring against the known model
    families, in the order of the branches below.

    Args:
        gen_model_checkpoint: Hub checkpoint id or API model name.
        load_in_8bit: When true, Hub models are loaded with
            ``device_map="auto"`` and ``load_in_8bit=True``.

    Returns:
        A ``(gen_model, tokenizer)`` pair. ``tokenizer`` is ``None`` for the
        OpenAI and Cohere clients, and both are ``None`` when the checkpoint
        matches no known family.
    """

    def _matches(*markers):
        # True when any marker occurs in the checkpoint name.
        return any(marker in gen_model_checkpoint for marker in markers)

    def _from_hub(model_cls):
        # Load the model and its tokenizer with identical keyword arguments,
        # as every branch of the original did.
        extra = {"device_map": "auto", "load_in_8bit": True} if load_in_8bit else {}
        hub_model = model_cls.from_pretrained(gen_model_checkpoint, token=HF_TOKEN, **extra)
        hub_tokenizer = AutoTokenizer.from_pretrained(gen_model_checkpoint, token=HF_TOKEN, **extra)
        return hub_model, hub_tokenizer

    gen_model = None
    tokenizer = None

    if _matches("mistralai/Mistral-7B-Instruct-v0.3",
                "meta-llama/Meta-Llama-3-8B-Instruct",
                "google/gemma-1.1-7b-it"):
        gen_model, tokenizer = _from_hub(AutoModelForCausalLM)
    elif _matches("CohereForAI/aya-101", "bigscience/mt0"):
        gen_model, tokenizer = _from_hub(AutoModelForSeq2SeqLM)
    # The last test used to read ``args.gen_model_checkpoint``; ``args`` is
    # not defined here, so reaching this branch raised NameError.
    elif _matches("facebook/xglm", "bigscience/bloomz", "aya-23-8B"):
        gen_model, tokenizer = _from_hub(AutoModelForCausalLM)
    elif _matches("gpt-3.5-turbo", "gpt-4"):
        gen_model = OpenAI(api_key=OPENAI_TOKEN)
    elif "command-r" in gen_model_checkpoint:
        gen_model = cohere.Client(COHERE_TOKEN)

    return gen_model, tokenizer
|
| 216 |
+
|
agents/chatgpt.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#%%
|
| 2 |
+
import os
|
| 3 |
+
import re
|
| 4 |
+
|
| 5 |
+
#%%
|
| 6 |
+
import os
|
| 7 |
+
import json
|
| 8 |
+
import torch
|
| 9 |
+
import numpy as np
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from transformers import set_seed
|
| 12 |
+
from textgames import GAME_NAMES, LEVEL_IDS, game_filename
|
| 13 |
+
from agents import run_with_agent
|
| 14 |
+
|
| 15 |
+
#%%
|
| 16 |
+
def set_all_seed(seed=42):
    """Apply one seed to transformers, NumPy and PyTorch (CPU and all GPUs).

    Args:
        seed: Integer seed handed unchanged to every seeding function.
    """
    seeders = (
        set_seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
#%%
|
| 24 |
+
def _getenv_as_int(attr, default=None):
    """Read environment variable ``attr`` as an integer.

    Args:
        attr: Name of the environment variable.
        default: Value used when the variable is unset; it goes through the
            same ``int`` conversion, so an int or a numeric string both work.

    Returns:
        The converted integer, or ``None`` when the variable is unset and
        ``default`` is ``None``.

    Raises:
        ValueError: If the value cannot be parsed by ``int``.
    """
    raw_value = os.getenv(attr, default)
    if raw_value is None:
        return None
    return int(raw_value)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
GAME_ST, GAME_ED = _getenv_as_int("TG_GAME_ST", None), _getenv_as_int("TG_GAME_ED", None)
|
| 30 |
+
LVL_ST, LVL_ED = _getenv_as_int("TG_LEVEL_ST", None), _getenv_as_int("TG_LEVEL_ED", '3')
|
| 31 |
+
SID_ST, SID_ED = _getenv_as_int("TG_SID_ST", None), _getenv_as_int("TG_SID_ED", None)
|
| 32 |
+
N_TURNS = _getenv_as_int("TG_N_TURNS", 1)
|
| 33 |
+
ONE_SHOT = bool(int(os.getenv("TG_ONESHOT", "0")))
|
| 34 |
+
GPT_MODEL = os.getenv("TG_GPT_MODEL", "")
|
| 35 |
+
# MAX_NEW_TOKENS = _getenv_as_int("TG_MAX_NEW_TOKENS", 12000)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
#%%
|
| 39 |
+
def preload_responses():
    """Load the stored ChatGPT batch outputs for each available turn.

    For every turn ``1..N_TURNS`` the JSONL path comes from the environment
    variable ``TG_GPT_OUTPUT_FILE_TURN_<turn>``, falling back to a default
    path built from ``GPT_MODEL`` and ``ONE_SHOT``. Loading stops at the
    first turn whose file does not exist.

    Returns:
        A dict keyed by ``(game_and_level, turn)`` whose values map a session
        id to the response message content. Both key parts are taken from the
        last two ``-``-separated fields of each record's ``custom_id``.
    """
    responses_all = dict()
    for _turn in range(1, N_TURNS + 1):
        fp = os.getenv(
            f"TG_GPT_OUTPUT_FILE_TURN_{_turn}",
            (f"model_outputs/__runs__/chatgpt_4o_mini_results/raw/batch_output_chatgpt-{GPT_MODEL}_turn{_turn}"
             f"{'.1s' if ONE_SHOT else '.zs'}.jsonl")
        )
        if not Path(fp).exists():
            # Only a gap before the last turn is reported.
            # NOTE(review): the nesting of ``break`` is reconstructed from a
            # flattened listing; it is placed so a missing file never reaches
            # ``open`` below. Confirm against the original file.
            if _turn < N_TURNS:
                print(f" batch_output turn {_turn} is not available. path: \"{fp}\"")
            break
        with open(fp, "r", encoding="utf8") as i:
            for line in i:
                if not line.strip():
                    # Tolerate blank lines (e.g. a trailing newline).
                    continue
                d = json.loads(line)
                sid, g = d['custom_id'].rsplit('-', 2)[-2:]
                msg = d['response']['body']['choices'][0]['message']
                # A single store; the original repeated this assignment.
                responses_all.setdefault((g, _turn), dict())[sid] = msg['content']
    return responses_all
|
| 62 |
+
RESPONSES_ALL = preload_responses()
|
| 63 |
+
print(f"len(RESPONSES_ALL) = {len(RESPONSES_ALL)}")
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
#%%
|
| 67 |
+
# Answer-extraction patterns, tried in order; group 1 is the answer.
_GPT_ANSWER_PATTERNS = (
    # Leading fenced block: ```answer```
    re.compile(r'^```\n?([^`]*)\n?```'),
    # re.compile(r'\*\*\"?([^\"*]*)\"?\*\*'),
    # Everything before a trailing "Explanation:" section. ``[\s\S]*`` matches
    # the same text as the former ``(.|\n)*`` without per-character
    # alternation and group bookkeeping.
    re.compile(r'([\s\S]*)\n\nExplanation:\n'),
)


def gpt_postproc(response_txt_batch, *args, **kwargs):
    """Extract the bare answer from one raw model response.

    Despite the parameter name, a single response string is expected.

    Args:
        response_txt_batch: The raw response text, or ``None``.
        *args: Ignored.
        **kwargs: Ignored.

    Returns:
        ``None`` for a ``None`` input. Otherwise the stripped capture of the
        first pattern that matches, or the unchanged response when no pattern
        matches or the capture is empty.
    """
    response_txt = response_txt_batch
    if response_txt is None:
        return None
    for pat in _GPT_ANSWER_PATTERNS:
        match = pat.search(response_txt)
        if match:
            # The first matching pattern decides, even if its capture is
            # empty, in which case the raw response is returned.
            return match.group(1).strip() or response_txt
    return response_txt
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
#%%
|
| 90 |
+
def get_gpt_response(texts, game_name, difficulty_level, turn, *args, **kwargs):
    """Look up a stored response, or queue a batch request for the next turn.

    Args:
        texts: Conversation so far; even positions become ``user`` messages
            and odd positions ``assistant`` messages.
        game_name: Game name, mapped through ``game_filename``.
        difficulty_level: Level id appended to the game file name.
        turn: Turn number used in the ``RESPONSES_ALL`` lookup key.
        **kwargs: Must contain ``sid``, the session id.

    Returns:
        The stored response for ``sid`` when responses exist for this game,
        level and turn; otherwise ``None``. In the latter case, if
        ``TG_GPT_NEXTTURN_OUTPUT_FILE`` is set, one batch-request line is
        appended to that file.

    Raises:
        KeyError: If ``sid`` is not passed, or is absent from the non-empty
            stored responses.
    """
    sid = kwargs['sid']  # sid must be fed as params
    role_by_parity = ("user", "assistant")
    messages = [
        {"role": role_by_parity[idx % 2], "content": text}
        for idx, text in enumerate(texts)
    ]

    stored = RESPONSES_ALL.get((f"{game_filename(game_name)}_{difficulty_level}", turn), {})
    if stored:
        return stored[sid]

    fp_next = os.getenv("TG_GPT_NEXTTURN_OUTPUT_FILE", None)
    if fp_next:
        request = {
            'custom_id': f"{sid}-{game_filename(game_name)}_{difficulty_level}",
            "method": "POST", "url": "/v1/chat/completions",
            "body": {
                "model": "gpt-4o-mini-2024-07-18",
                "max_completion_tokens": 200,
                'messages': messages,
                "seed": 42,
                "temperature": 0,
            }
        }
        with open(fp_next, "a", encoding="utf8") as out_file:
            out_file.write(json.dumps(request))
            out_file.write("\n")

    return None
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
#%%
|
| 124 |
+
if __name__ == "__main__":
    # The output name encodes model, shot mode and the optional game / level
    # start offsets.
    name_suffixes = [
        '.1s' if ONE_SHOT else '.zs',
        '' if GAME_ST is None else f'.{GAME_ST}',
        '' if LVL_ST is None else f'.{LVL_ST}',
    ]
    fp_out = (f"model_outputs/__runs__/chatgpt_4o_mini_results/process/results_chatgpt-{GPT_MODEL}"
              + "".join(name_suffixes)
              + ".jsonl")

    set_all_seed()

    # Restrict to a session range only when a bound was given.
    if SID_ST or SID_ED:
        selected_sids = [f"session_{r:04}" for r in range(SID_ST or 0, SID_ED or 1000)]
    else:
        selected_sids = None

    run_with_agent(
        fp_out,
        get_gpt_response,
        gpt_postproc,
        n_turns=N_TURNS,
        game_names_list=GAME_NAMES[GAME_ST:GAME_ED],
        level_ids_list=LEVEL_IDS[LVL_ST:LVL_ED],
        sid_indices=selected_sids,
        prepend_example=ONE_SHOT,
        # remove_if_output_file_exist=False,
        assistant_uses_raw_response=False,
    )
|
agents/dsr1_distill.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#%%
|
| 2 |
+
import os
|
| 3 |
+
import re
|
| 4 |
+
|
| 5 |
+
#%%
|
| 6 |
+
import torch
|
| 7 |
+
import numpy as np
|
| 8 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
|
| 9 |
+
from textgames import THE_GAMES, GAME_NAMES, LEVEL_IDS
|
| 10 |
+
from agents import run_with_agent
|
| 11 |
+
|
| 12 |
+
#%%
|
| 13 |
+
def set_all_seed(seed=42):
    """Seed transformers, NumPy and PyTorch (CPU and every CUDA device).

    Args:
        seed: Integer seed passed to each seeding function in turn.
    """
    for seed_fn in (set_seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
#%%
|
| 21 |
+
def _getenv_as_int(attr, default=None):
    """Return environment variable ``attr`` converted to ``int``.

    Args:
        attr: Environment variable name.
        default: Fallback used when the variable is unset; it is converted
            with ``int`` as well, so ints and numeric strings are accepted.

    Returns:
        The integer value, or ``None`` if the variable is unset and no
        default was given.

    Raises:
        ValueError: If the value is not a valid integer literal.
    """
    value = os.getenv(attr, default)
    if value is not None:
        return int(value)
    return None
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
GAME_ST, GAME_ED = _getenv_as_int("TG_GAME_ST", None), _getenv_as_int("TG_GAME_ED", None)
|
| 27 |
+
LVL_ST, LVL_ED = _getenv_as_int("TG_LEVEL_ST", None), _getenv_as_int("TG_LEVEL_ED", '3')
|
| 28 |
+
SID_ST, SID_ED = _getenv_as_int("TG_SID_ST", None), _getenv_as_int("TG_SID_ED", None)
|
| 29 |
+
N_TURNS = _getenv_as_int("TG_N_TURNS", 1)
|
| 30 |
+
ONE_SHOT = bool(int(os.getenv("TG_ONESHOT", "0")))
|
| 31 |
+
MAX_NEW_TOKENS = _getenv_as_int("TG_MAX_NEW_TOKENS", 12000)
|
| 32 |
+
DSR1_SIZE = os.getenv("TG_DSR1_SIZE", "14") # {1.5, 7, 8, 14, 32, 70}
|
| 33 |
+
DSR1_NAME = {
|
| 34 |
+
"1.5": "Qwen-1.5",
|
| 35 |
+
"7": "Qwen-7",
|
| 36 |
+
"8": "Llama-8",
|
| 37 |
+
"14": "Qwen-14",
|
| 38 |
+
"32": "Qwen-32",
|
| 39 |
+
"70": "Llama-70",
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
#%%
|
| 44 |
+
def dsr1_postproc(response_txt_batch, *args, **kwargs):
    """Extract the final answer from one raw DeepSeek-R1 response.

    Despite the parameter name, a single response string is expected.
    Patterns are tried in order: the content of ``\\boxed{...}``, the text
    after ``</think>``, then a leading fenced code block.

    Args:
        response_txt_batch: The raw response text; may be ``None`` or empty.
        *args: Ignored.
        **kwargs: Ignored.

    Returns:
        The stripped capture of the first matching pattern. When nothing
        matches, the first 256 characters of the response, stripped. An
        empty string for a ``None`` or empty response.
    """
    response_txt = response_txt_batch
    if not response_txt:
        # Checked before any regex runs: ``pat.search(None)`` raises
        # TypeError, which made the empty-response fallback unreachable.
        return ""
    for pat in [
        re.compile(r'\\boxed\{([\s\S]*)}'),
        re.compile(r'</think>\n([\s\S]*)$'),
        re.compile(r'^```\n?([^`]*)\n?```'),
    ]:
        matches = pat.search(response_txt)
        if matches:
            return matches.group(1).strip()
    return response_txt[:256].strip()
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
#%%
|
| 66 |
+
def get_dsr1_response(texts_batch, *args, **kwargs):
    """Generate one reply for a single conversation.

    Args:
        texts_batch: Alternating user/assistant turns (user first, user last)
            of ONE conversation; batching is not supported despite the name.
        *args, **kwargs: Ignored.

    Returns:
        The decoded, stripped text of the newly generated tokens.

    Relies on the module-level ``tokenizer``, ``model`` and ``MAX_NEW_TOKENS``.
    """
    # global model, tokenizer
    # Work on a copy so the caller's list (reused across turns) is not mutated.
    texts = list(texts_batch)
    # One-shot layout is [example prompt, example answer, "Correct guess. ..."]:
    # present the example answer in the \boxed{} format requested of the model.
    if len(texts) > 2 and texts[2].startswith('Correct guess.'):
        texts[1] = f"\\boxed{{{texts[1]}}}"
    messages = [[
        {"role": "user",
         "content": f"{text}\nPlease reason step by step, and put your final answer within \\boxed{{}} as plain text."}
        if i % 2 == 0 else
        {"role": "assistant", "content": text}  # plain string (was a set literal)
        for i, text in enumerate(texts)
    ]]
    text_inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    # The rendered template already holds the special tokens; do not add them again.
    model_inputs = tokenizer(text_inputs, return_tensors="pt", add_special_tokens=False).to(model.device)
    output_ids = model.generate(
        **model_inputs,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Drop the prompt tokens so only the continuation is decoded.
    generated_ids = [
        _output_ids[len(input_ids):] for input_ids, _output_ids in zip(model_inputs.input_ids, output_ids)
    ]
    response = [r.strip() for r in tokenizer.batch_decode(generated_ids, skip_special_tokens=True)]
    return response[0]
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
#%%
|
| 102 |
+
# response = get_dsr1_response(texts)
|
| 103 |
+
# print(dsr1_postproc(response))
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
#%%
|
| 107 |
+
if __name__ == "__main__":
    # Output file name encodes model size, prompting mode (.1s one-shot / .zs
    # zero-shot) and, when set, the starting game / level offsets.
    fp_out = (f"model_outputs/__runs__/results_deepseek-r1-distill-{DSR1_SIZE}b"
              f"{'.1s' if ONE_SHOT else '.zs'}"
              f"{'' if GAME_ST is None else f'.{GAME_ST}'}"
              f"{'' if LVL_ST is None else f'.{LVL_ST}'}"
              f".jsonl")

    set_all_seed()
    model_name = f"deepseek-ai/DeepSeek-R1-Distill-{DSR1_NAME[DSR1_SIZE]}B"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        torch_dtype="auto",
    )
    # Clear sampling parameters; generation is called with do_sample=False.
    model.generation_config.temperature = None
    model.generation_config.top_k = None
    model.generation_config.top_p = None

    run_with_agent(
        fp_out,
        get_dsr1_response,
        dsr1_postproc,
        n_turns=N_TURNS,
        game_names_list=GAME_NAMES[GAME_ST:GAME_ED],
        level_ids_list=LEVEL_IDS[LVL_ST:LVL_ED],
        # Restrict to sessions [SID_ST, SID_ED) only when at least one bound is set (non-zero).
        sid_indices=(list(map(lambda r: f"session_{r:04}", range(SID_ST or 0, SID_ED or 1000)))
                     if SID_ST or SID_ED else None),
        prepend_example=ONE_SHOT,
        # remove_if_output_file_exist=False,
        assistant_uses_raw_response=False,
    )
|
agents/gemma_2_9b_it.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#%%
|
| 2 |
+
import os
|
| 3 |
+
import re
|
| 4 |
+
|
| 5 |
+
#%%
|
| 6 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 7 |
+
from textgames import THE_GAMES, GAME_NAMES, LEVEL_IDS
|
| 8 |
+
from agents import run_with_agent
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
#%%
|
| 12 |
+
def _getenv_as_int(attr, default=None):
    """Read environment variable *attr* and convert it to ``int``.

    Args:
        attr: Name of the environment variable.
        default: Value used when the variable is unset or blank; may be an
            ``int``, a numeric ``str`` or ``None``.

    Returns:
        ``int`` of the variable (or of *default*), or ``None`` when neither
        provides a value.

    Raises:
        ValueError: If the value is non-blank but not a valid integer.
    """
    ret = os.getenv(attr)
    # Treat `export TG_X=` (blank) like "not set" instead of failing on int('').
    if ret is None or not ret.strip():
        ret = default
    return None if ret is None else int(ret)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# Run configuration, read once at import time from TG_* environment variables.
# The *_ST / *_ED pairs are used below as slice/range bounds (None = open end).
GAME_ST, GAME_ED = _getenv_as_int("TG_GAME_ST", None), _getenv_as_int("TG_GAME_ED", None)
LVL_ST, LVL_ED = _getenv_as_int("TG_LEVEL_ST", None), _getenv_as_int("TG_LEVEL_ED", '3')
SID_ST, SID_ED = _getenv_as_int("TG_SID_ST", None), _getenv_as_int("TG_SID_ED", None)
N_TURNS = _getenv_as_int("TG_N_TURNS", 3)  # passed to run_with_agent as n_turns
ONE_SHOT = bool(int(os.getenv("TG_ONESHOT", "0")))  # "1" -> prepend a solved example
GEMMA_SIZE = int(os.getenv("TG_GEMMA_SIZE", "9"))  # {2, 9, 27} -- Gemma 2 is released in 2B/9B/27B
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
#%%
|
| 26 |
+
def gemma_postproc(response_txt, game_name, difficulty_level, *args, **kwargs):
    """Extract the answer from a raw Gemma response.

    Tried in order:
      1. A fenced code block at the very start of the response; its content
         is returned stripped and with all spaces removed.
      2. The first ``**bold**`` span (optionally double-quoted), stripped.
      3. Otherwise the response itself.

    Args:
        response_txt: Raw response string; ``None`` or ``""`` is accepted.
        game_name, difficulty_level: Unused (kept for the runner's call
            signature).

    Returns:
        The extracted answer, or ``""`` for an empty/``None`` response.
    """
    # Guard first: the regex search would raise on None before any fallback.
    if not response_txt:
        return ""

    # if game_name in [THE_GAMES[i] for i in ["1", "7"]]:  # crossword
    fenced = re.search(r'^```\n?([^`]*)\n?```', response_txt)
    if fenced:
        return fenced.group(1).strip().replace(" ", "")

    # elif game_name == THE_GAMES["6"]:  # anagram
    bold = re.search(r'\*\*\"?([^\"*]*)\"?\*\*', response_txt)
    if bold:
        return bold.group(1).strip()

    return response_txt
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
#%%
|
| 43 |
+
def get_gemma_response(texts, game_name, difficulty_level, turn, *args, **kwargs):
    """Generate one reply for a conversation.

    Args:
        texts: Alternating user/model turns, user first.
        game_name, difficulty_level, turn: Unused (kept for the runner's
            call signature).

    Returns:
        The decoded, stripped text of the newly generated tokens (at most
        100 new tokens, greedy decoding).

    Relies on the module-level ``tokenizer`` and ``gen_model``.
    """
    # global gen_model, tokenizer
    messages = [
        {"role": ("model" if i % 2 else "user"), "content": text}
        for i, text in enumerate(texts)
    ]

    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(gen_model.device)

    # Stop on EOS or on Gemma's end-of-turn marker. (The previous "<|eot_id|>"
    # is a Llama-3 token that Gemma's vocabulary does not contain.)
    terminators = [tokenizer.eos_token_id]
    end_of_turn_id = tokenizer.convert_tokens_to_ids("<end_of_turn>")
    if end_of_turn_id is not None and end_of_turn_id != tokenizer.unk_token_id:
        terminators.append(end_of_turn_id)

    gen_model.generation_config.temperature = None
    outputs = gen_model.generate(
        input_ids,
        max_new_tokens=100,
        eos_token_id=terminators,
        do_sample=False,
        # temperature=.0,
        # top_p=1,
    )

    # Keep only the tokens generated after the prompt.
    response = outputs[0][input_ids.shape[-1]:]
    return tokenizer.decode(response, skip_special_tokens=True).strip()
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
#%%
|
| 76 |
+
if __name__ == "__main__":
    # Output file name encodes model size, prompting mode (.1s one-shot / .zs
    # zero-shot) and, when set, the starting game / level offsets.
    fp_out = (f"model_outputs/results_gemma-2-{GEMMA_SIZE}b-it"
              f"{'.1s' if ONE_SHOT else '.zs'}"
              f"{'' if GAME_ST is None else f'.{GAME_ST}'}"
              f"{'' if LVL_ST is None else f'.{LVL_ST}'}"
              f".jsonl")
    gen_model_checkpoint = f"google/gemma-2-{GEMMA_SIZE}b-it"

    # NOTE(review): despite its name this flag passes no quantization config;
    # it only toggles device_map="auto".
    quantize = True
    _kwargs = {
        "device_map": "auto",
    } if quantize else {}

    gen_model = AutoModelForCausalLM.from_pretrained(gen_model_checkpoint, **_kwargs)
    # NOTE(review): the same kwargs (device_map) are forwarded to the tokenizer too.
    tokenizer = AutoTokenizer.from_pretrained(gen_model_checkpoint, **_kwargs)
    print(f" > model.dtype: {gen_model.dtype}")

    run_with_agent(
        fp_out,
        get_gemma_response,
        gemma_postproc,
        n_turns=N_TURNS,
        game_names_list=GAME_NAMES[GAME_ST:GAME_ED],
        level_ids_list=LEVEL_IDS[LVL_ST:LVL_ED],
        # Restrict to sessions [SID_ST, SID_ED) only when at least one bound is set (non-zero).
        sid_indices=(list(map(lambda r: f"session_{r:04}", range(SID_ST or 0, SID_ED or 1000)))
                     if SID_ST or SID_ED else None),
        prepend_example=ONE_SHOT,
        # remove_if_output_file_exist=False,
    )
|
agents/llama3.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#%%
|
| 2 |
+
import os
|
| 3 |
+
import re
|
| 4 |
+
|
| 5 |
+
#%%
|
| 6 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 7 |
+
from textgames import THE_GAMES, GAME_NAMES, LEVEL_IDS
|
| 8 |
+
from agents import run_with_agent
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
#%%
|
| 12 |
+
def _getenv_as_int(attr, default=None):
    """Read environment variable *attr* and convert it to ``int``.

    Args:
        attr: Name of the environment variable.
        default: Value used when the variable is unset or blank; may be an
            ``int``, a numeric ``str`` or ``None``.

    Returns:
        ``int`` of the variable (or of *default*), or ``None`` when neither
        provides a value.

    Raises:
        ValueError: If the value is non-blank but not a valid integer.
    """
    ret = os.getenv(attr)
    # Treat `export TG_X=` (blank) like "not set" instead of failing on int('').
    if ret is None or not ret.strip():
        ret = default
    return None if ret is None else int(ret)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# Run configuration, read once at import time from TG_* environment variables.
# The *_ST / *_ED pairs are used below as slice/range bounds (None = open end).
GAME_ST, GAME_ED = _getenv_as_int("TG_GAME_ST", None), _getenv_as_int("TG_GAME_ED", None)
LVL_ST, LVL_ED = _getenv_as_int("TG_LEVEL_ST", None), _getenv_as_int("TG_LEVEL_ED", '3')
SID_ST, SID_ED = _getenv_as_int("TG_SID_ST", None), _getenv_as_int("TG_SID_ED", None)
N_TURNS = _getenv_as_int("TG_N_TURNS", 3)  # passed to run_with_agent as n_turns
ONE_SHOT = bool(int(os.getenv("TG_ONESHOT", "0")))  # "1" -> prepend a solved example
# Inserted after "Llama-3." in the checkpoint name, e.g. "1-8" -> "Llama-3.1-8B-Instruct".
LLAMA_SIZE = os.getenv("TG_LLAMA_SIZE", "1-8")
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
#%%
|
| 26 |
+
def llama_postproc(response_txt, *args, **kwargs):
    """Return the raw response unchanged, mapping any falsy value to ``""``.

    No answer extraction is performed for this agent; extra positional and
    keyword arguments are accepted and ignored.
    """
    if not response_txt:
        return ""
    return response_txt
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
#%%
|
| 42 |
+
def get_llama_response(texts, *args, **kwargs):
    """Generate one reply for a conversation.

    Args:
        texts: Alternating user/assistant turns, user first.
        *args, **kwargs: Ignored.

    Returns:
        The decoded, stripped text of the newly generated tokens (at most
        128 new tokens, greedy decoding).

    Relies on the module-level ``tokenizer`` and ``model``.
    """
    # global model, tokenizer

    messages = [
        {"role": ("assistant" if i % 2 else "user"), "content": text}
        for i, text in enumerate(texts)
    ]

    text_inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    # The rendered template already contains the special tokens (incl. BOS);
    # adding them again would prepend a second BOS to the prompt.
    model_inputs = tokenizer([text_inputs], return_tensors="pt", add_special_tokens=False).to(model.device)

    # Clear sampling parameters so greedy decoding is used.
    model.generation_config.do_sample = False
    model.generation_config.temperature = None
    model.generation_config.top_k = None
    model.generation_config.top_p = None
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=128,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Drop the prompt tokens so only the continuation is decoded.
    generated_ids = [
        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]

    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return response.strip()
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
#%%
|
| 76 |
+
if __name__ == "__main__":
    # Output file name encodes model size, prompting mode (.1s one-shot / .zs
    # zero-shot) and, when set, the starting game / level offsets.
    fp_out = (f"model_outputs/__runs__/results_llama-3.{LLAMA_SIZE}b-instruct"
              f"{'.1s' if ONE_SHOT else '.zs'}"
              f"{'' if GAME_ST is None else f'.{GAME_ST}'}"
              f"{'' if LVL_ST is None else f'.{LVL_ST}'}"
              f".jsonl")

    model_name = f"meta-llama/Llama-3.{LLAMA_SIZE}B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        torch_dtype="bfloat16",
    )

    run_with_agent(
        fp_out,
        get_llama_response,
        llama_postproc,
        n_turns=N_TURNS,
        game_names_list=GAME_NAMES[GAME_ST:GAME_ED],
        level_ids_list=LEVEL_IDS[LVL_ST:LVL_ED],
        # Restrict to sessions [SID_ST, SID_ED) only when at least one bound is set (non-zero).
        sid_indices=(list(map(lambda r: f"session_{r:04}", range(SID_ST or 0, SID_ED or 1000)))
                     if SID_ST or SID_ED else None),
        prepend_example=ONE_SHOT,
        # remove_if_output_file_exist=False,
    )
|
agents/qwen2_5_7b_instruct.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#%%
|
| 2 |
+
import os
|
| 3 |
+
import re
|
| 4 |
+
|
| 5 |
+
#%%
|
| 6 |
+
import torch
|
| 7 |
+
import numpy as np
|
| 8 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
|
| 9 |
+
from textgames import GAME_NAMES, LEVEL_IDS
|
| 10 |
+
from agents import run_with_agent
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
#%%
|
| 14 |
+
def set_all_seed(seed=42):
    """Seed every random number generator used by this script with *seed*.

    Calls, in order: transformers' ``set_seed``, NumPy's global RNG, torch's
    CPU RNG and torch's RNG on all CUDA devices.
    """
    seeders = (set_seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seed_fn in seeders:
        seed_fn(seed)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
#%%
|
| 22 |
+
def _getenv_as_int(attr, default=None):
    """Read environment variable *attr* and convert it to ``int``.

    Args:
        attr: Name of the environment variable.
        default: Value used when the variable is unset or blank; may be an
            ``int``, a numeric ``str`` or ``None``.

    Returns:
        ``int`` of the variable (or of *default*), or ``None`` when neither
        provides a value.

    Raises:
        ValueError: If the value is non-blank but not a valid integer.
    """
    ret = os.getenv(attr)
    # Treat `export TG_X=` (blank) like "not set" instead of failing on int('').
    if ret is None or not ret.strip():
        ret = default
    return None if ret is None else int(ret)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# Run configuration, read once at import time from TG_* environment variables.
# The *_ST / *_ED pairs are used below as slice/range bounds (None = open end).
GAME_ST, GAME_ED = _getenv_as_int("TG_GAME_ST", None), _getenv_as_int("TG_GAME_ED", None)
LVL_ST, LVL_ED = _getenv_as_int("TG_LEVEL_ST", None), _getenv_as_int("TG_LEVEL_ED", '3')
SID_ST, SID_ED = _getenv_as_int("TG_SID_ST", None), _getenv_as_int("TG_SID_ED", None)
N_TURNS = _getenv_as_int("TG_N_TURNS", 3)  # passed to run_with_agent as n_turns
ONE_SHOT = bool(int(os.getenv("TG_ONESHOT", "0")))  # "1" -> prepend a solved example
# Parsed with int(), so fractional sizes cannot be selected through this variable.
QWEN_SIZE = int(os.getenv("TG_QWEN_SIZE", "32"))  # {3, 7, 14, 32, 72}  unsupported: {0.5, 1.5}
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
#%%
|
| 36 |
+
def qwen_postproc(response_txt, game_name, difficulty_level, *args, **kwargs):
    """Return the raw response unchanged, mapping any falsy value to ``""``.

    No answer extraction is performed for this agent; ``game_name`` and
    ``difficulty_level`` are accepted for the runner's call signature and
    are not used.
    """
    if not response_txt:
        return ""
    return response_txt
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
#%%
|
| 52 |
+
def get_qwen_response(texts_batch, game_name, difficulty_level, turn, *args, **kwargs):
    """Generate one reply for a single conversation.

    Args:
        texts_batch: Alternating user/assistant turns (user first) of ONE
            conversation; batching is not supported despite the name.
        game_name, difficulty_level, turn: Unused (kept for the runner's
            call signature).

    Returns:
        The decoded, stripped text of the newly generated tokens (at most
        128 new tokens, greedy decoding).

    Relies on the module-level ``tokenizer`` and ``model``.
    """
    # global model, tokenizer
    conversation = [
        {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
        *[{"role": ("assistant" if i % 2 else "user"), "content": text} for i, text in enumerate(texts_batch)],
    ]

    # A batch of one conversation renders to a list holding one prompt string.
    text_inputs = tokenizer.apply_chat_template(
        [conversation],
        tokenize=False,
        add_generation_prompt=True
    )
    # Pass the list as-is (wrapping it again yields a nested list), and do not
    # re-add special tokens already present in the rendered template.
    model_inputs = tokenizer(text_inputs, return_tensors="pt", add_special_tokens=False).to(model.device)

    # Clear sampling parameters so greedy decoding is used.
    model.generation_config.temperature = None
    model.generation_config.top_k = None
    model.generation_config.top_p = None
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=128,
        do_sample=False,
    )
    # Drop the prompt tokens so only the continuation is decoded.
    generated_ids = [
        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]

    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return response.strip()
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
#%%
|
| 84 |
+
if __name__ == "__main__":
    # Output file name encodes model size, prompting mode (.1s one-shot / .zs
    # zero-shot) and, when set, the starting game / level offsets.
    fp_out = (f"model_outputs/__runs__/results_qwen2-5-{QWEN_SIZE}b-instruct"
              f"{'.1s' if ONE_SHOT else '.zs'}"
              f"{'' if GAME_ST is None else f'.{GAME_ST}'}"
              f"{'' if LVL_ST is None else f'.{LVL_ST}'}"
              f".jsonl")

    set_all_seed()
    model_name = f"Qwen/Qwen2.5-{QWEN_SIZE}B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        torch_dtype="auto",
    )
    print(f" > model.dtype: {model.dtype}")

    run_with_agent(
        fp_out,
        get_qwen_response,
        qwen_postproc,
        n_turns=N_TURNS,
        game_names_list=GAME_NAMES[GAME_ST:GAME_ED],
        level_ids_list=LEVEL_IDS[LVL_ST:LVL_ED],
        # Restrict to sessions [SID_ST, SID_ED) only when at least one bound is set (non-zero).
        sid_indices=(list(map(lambda r: f"session_{r:04}", range(SID_ST or 0, SID_ED or 1000)))
                     if SID_ST or SID_ED else None),
        prepend_example=ONE_SHOT,
        # remove_if_output_file_exist=False,
    )
|
agents/qwen2_5_math.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#%%
|
| 2 |
+
import os
|
| 3 |
+
import re
|
| 4 |
+
|
| 5 |
+
#%%
|
| 6 |
+
import torch
|
| 7 |
+
import numpy as np
|
| 8 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed, BitsAndBytesConfig
|
| 9 |
+
from textgames import GAME_NAMES, LEVEL_IDS
|
| 10 |
+
from agents import run_with_agent
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
#%%
|
| 14 |
+
def set_all_seed(seed=42):
    """Seed every random number generator used by this script with *seed*.

    Calls, in order: transformers' ``set_seed``, NumPy's global RNG, torch's
    CPU RNG and torch's RNG on all CUDA devices.
    """
    seeders = (set_seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seed_fn in seeders:
        seed_fn(seed)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
#%%
|
| 22 |
+
def _getenv_as_int(attr, default=None):
    """Read environment variable *attr* and convert it to ``int``.

    Args:
        attr: Name of the environment variable.
        default: Value used when the variable is unset or blank; may be an
            ``int``, a numeric ``str`` or ``None``.

    Returns:
        ``int`` of the variable (or of *default*), or ``None`` when neither
        provides a value.

    Raises:
        ValueError: If the value is non-blank but not a valid integer.
    """
    ret = os.getenv(attr)
    # Treat `export TG_X=` (blank) like "not set" instead of failing on int('').
    if ret is None or not ret.strip():
        ret = default
    return None if ret is None else int(ret)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# Run configuration, read once at import time from TG_* environment variables.
# The *_ST / *_ED pairs are used below as slice/range bounds (None = open end).
GAME_ST, GAME_ED = _getenv_as_int("TG_GAME_ST", None), _getenv_as_int("TG_GAME_ED", None)
LVL_ST, LVL_ED = _getenv_as_int("TG_LEVEL_ST", None), _getenv_as_int("TG_LEVEL_ED", '3')
SID_ST, SID_ED = _getenv_as_int("TG_SID_ST", None), _getenv_as_int("TG_SID_ED", None)
N_TURNS = _getenv_as_int("TG_N_TURNS", 3)  # passed to run_with_agent as n_turns
ONE_SHOT = bool(int(os.getenv("TG_ONESHOT", "0")))  # "1" -> prepend a solved example
# MAX_NEW_TOKENS = _getenv_as_int("TG_MAX_NEW_TOKENS", 4096)
QWEN_MATH_SIZE = os.getenv("TG_QWEN_MATH_SIZE", "7")  # {1.5, 7, 72}
# Bit width for quantized loading; only consulted for the 72B checkpoint (see __main__).
QUANTIZE = _getenv_as_int("TG_QUANTIZE", 4)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
#%%
|
| 38 |
+
# Answer-extraction patterns, tried in priority order; compiled once at import.
_QWENMATH_ANSWER_PATTERNS = (
    re.compile(r'\\boxed\{([\s\S]*)}'),
    re.compile(r'^```\n?([^`]*)\n?```'),
    # re.compile(r'</think>\n([\s\S]*)$'),
)


def qwenmath_postproc(response_txt_batch, *args, **kwargs):
    """Extract the final answer from a single raw model response.

    The first matching pattern wins: contents of ``\\boxed{...}`` (greedy, up
    to the last ``}``), then a leading fenced code block. When nothing
    matches, the response is returned unchanged.

    Args:
        response_txt_batch: One raw response string (batching is not
            supported despite the name); ``None`` or ``""`` is accepted.
        *args, **kwargs: Ignored.

    Returns:
        The extracted answer, or ``""`` for an empty/``None`` response.
    """
    response_txt = response_txt_batch
    # Guard first: the regex search would raise on None before any fallback.
    if not response_txt:
        return ""
    for pat in _QWENMATH_ANSWER_PATTERNS:
        found = pat.search(response_txt)
        if found:
            return found.group(1).strip()
    return response_txt
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
#%%
|
| 60 |
+
def get_qwenmath_response(texts_batch, *args, **kwargs):
    """Generate one reply for a single conversation.

    Args:
        texts_batch: Alternating user/assistant turns (user first, user last)
            of ONE conversation; batching is not supported despite the name.
        *args, **kwargs: Ignored.

    Returns:
        The decoded, stripped text of the newly generated tokens (at most
        512 new tokens, greedy decoding).

    Relies on the module-level ``tokenizer`` and ``model``.
    """
    # global model, tokenizer
    # Work on a copy: the caller reuses this list on every turn, so wrapping in
    # place would nest the example answer in another \boxed{} each turn.
    texts = list(texts_batch)
    # One-shot layout is [example prompt, example answer, "Correct guess. ..."].
    if len(texts) > 2 and texts[2].startswith('Correct guess.'):
        texts[1] = f"\\boxed{{{texts[1]}}}"
    messages = [[
        {"role": "system", "content": "Please reason step by step, and put your final answer within \\boxed{} as plain text."},
        *[{"role": ("user" if i % 2 == 0 else "assistant"), "content": text} for i, text in enumerate(texts)],
    ]]
    # print(f"\n{messages[0]}", end="\n=====\n\n")

    text_inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    # The rendered template already holds the special tokens; do not add them again.
    model_inputs = tokenizer(text_inputs, return_tensors="pt", add_special_tokens=False).to(model.device)

    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=512,
        do_sample=False,
    )
    # Drop the prompt tokens so only the continuation is decoded.
    generated_ids = [
        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]

    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return response.strip()
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
#%%
|
| 96 |
+
if __name__ == "__main__":
    # Output file name encodes model size, QUANTIZE, prompting mode (.1s / .zs)
    # and, when set, the starting game / level offsets.
    # NOTE(review): the "_{QUANTIZE}bit" suffix is written even when the
    # non-quantized branch below is taken.
    fp_out = (f"model_outputs/__runs__/results_qwen2-5-math-{QWEN_MATH_SIZE}b-instruct_{QUANTIZE}bit"
              f"{'.1s' if ONE_SHOT else '.zs'}"
              f"{'' if GAME_ST is None else f'.{GAME_ST}'}"
              f"{'' if LVL_ST is None else f'.{LVL_ST}'}"
              f".jsonl")

    set_all_seed()
    # Only the 72B checkpoint is loaded quantized: 8-bit when QUANTIZE == 8,
    # otherwise 4-bit.
    if QWEN_MATH_SIZE in ['72'] and QUANTIZE < 16:
        _additional_kwargs = {
            "quantization_config": (
                BitsAndBytesConfig(load_in_8bit=True)
                if QUANTIZE == 8 else
                BitsAndBytesConfig(load_in_4bit=True)
            ),
            "low_cpu_mem_usage": True,
        }
    else:
        _additional_kwargs = {"device_map": "auto"}

    model_name = f"Qwen/Qwen2.5-Math-{QWEN_MATH_SIZE}B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",
        **_additional_kwargs,
    )
    print(f" > model.dtype: {model.dtype}")

    run_with_agent(
        fp_out,
        get_qwenmath_response,
        qwenmath_postproc,
        n_turns=N_TURNS,
        game_names_list=GAME_NAMES[GAME_ST:GAME_ED],
        level_ids_list=LEVEL_IDS[LVL_ST:LVL_ED],
        # Restrict to sessions [SID_ST, SID_ED) only when at least one bound is set (non-zero).
        sid_indices=(list(map(lambda r: f"session_{r:04}", range(SID_ST or 0, SID_ED or 1000)))
                     if SID_ST or SID_ED else None),
        prepend_example=ONE_SHOT,
        # remove_if_output_file_exist=False,
        assistant_uses_raw_response=True,
    )
|
agents/runner.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#%%
|
| 2 |
+
import os
|
| 3 |
+
import json
|
| 4 |
+
|
| 5 |
+
from textgames import GAME_NAMES, LEVEL_IDS, game_filename, _game_class_from_name
|
| 6 |
+
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
from itertools import product
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import Union, Callable
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def response_postprocess(response_txt, game_name, difficulty_level):
    """Default post-processor: pass the response through, mapping falsy values to ``""``.

    ``game_name`` and ``difficulty_level`` are accepted for signature
    compatibility and are not used.
    """
    if not response_txt:
        return ""
    return response_txt
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def run_with_agent(fp_out: Union[str, Path],
                   get_response: Callable,
                   get_postprocess: Callable = response_postprocess,
                   n_turns=3,
                   game_names_list=GAME_NAMES,
                   level_ids_list=LEVEL_IDS[:3],
                   sid_indices=None,  # sid_index_range=range(0, 1000),
                   remove_if_output_file_exist=True,
                   prepend_example=False,
                   assistant_uses_raw_response=True,
                   ) -> None:
    """Play every (game, level, session) with an agent and log each turn as JSONL.

    Args:
        fp_out: Output ``.jsonl`` path; parent directories are created.
        get_response: Called as ``get_response(texts, game_name,
            difficulty_level, turn, sid=sid)``; returns the raw reply.
        get_postprocess: Called as ``get_postprocess(response_raw, game_name,
            difficulty_level)``; returns the answer passed to validation.
        n_turns: Maximum number of attempts per session.
        game_names_list: Games to run.
        level_ids_list: Difficulty levels to run.
        sid_indices: Session keys to run; ``None`` runs every session in the
            problem set. A key missing from the problem set raises ``KeyError``.
        remove_if_output_file_exist: Truncate ``fp_out`` first when true;
            otherwise records are appended.
        prepend_example: Prepend the game's solved example (one-shot).
        assistant_uses_raw_response: Put the raw reply (true) or the
            post-processed answer (false) into the conversation history.

    Problem sets are read from ``problemsets/<game>_<level>.json`` relative to
    the working directory. A session ends at the first solved turn, at the
    first exception (recorded in the ``error`` field), or after ``n_turns``.
    """
    os.makedirs(os.path.dirname(os.path.abspath(fp_out)), exist_ok=True)
    print(fp_out)
    if remove_if_output_file_exist:
        # Opening in write mode truncates the file.
        with open(fp_out, "wb"):
            pass

    for game_name, difficulty_level in product(game_names_list, level_ids_list):
        game_str = f"{game_filename(game_name)}_{difficulty_level}"
        game_cls = _game_class_from_name(game_name)
        with open(f"problemsets/{game_str}.json", "r", encoding="utf8") as f:
            sid_prompt_dict = json.load(f)
        if sid_indices is not None:
            sid_prompt_dict = {k: sid_prompt_dict[k] for k in sid_indices}

        correct_cnt, exception_cnt = 0, 0
        for sid, prompt in tqdm(sid_prompt_dict.items(), desc=game_str, total=len(sid_prompt_dict)):
            cur_game = game_cls()
            cur_game.load_game(prompt)
            if prepend_example:
                texts = [*cur_game.example(), f"Correct guess. Now let's try another example.\n{cur_game.get_prompt()}"]
            else:
                texts = [cur_game.get_prompt()]
            for turn in range(1, n_turns + 1):
                response_raw, response, e = None, None, None
                solved, val_msg = False, None
                try:
                    response_raw = get_response(texts, game_name, difficulty_level, turn, sid=sid)
                    response = get_postprocess(response_raw, game_name, difficulty_level)
                    texts.append(response_raw if assistant_uses_raw_response else response)
                    solved, val_msg = (False, None) if response is None else cur_game.validate(response)
                    texts.append(
                        f"Bad guess (Wrong Answer).\n{val_msg}\nPlease try again and print the answer only."
                        if not solved else "Correct guess."
                    )
                except Exception as _e:
                    # Best effort: record the failure and stop this session.
                    e = _e
                    # print(e)
                # Append one record per turn so partial runs remain usable.
                with open(fp_out, "a", encoding="utf8") as o:
                    json.dump({
                        "game": game_str,
                        "session": sid,
                        "turn": turn,
                        "response": response,
                        "solved": solved,
                        "val_msg": val_msg,
                        "response_raw": response_raw,
                        "error": repr(e) if e else e,
                    }, o, ensure_ascii=False)
                    o.write("\n")
                if solved:
                    correct_cnt += 1
                if e:
                    exception_cnt += 1
                if solved or e:
                    break

        # Avoid ZeroDivisionError when no session was selected (e.g. empty sid_indices).
        denom = max(len(sid_prompt_dict), 1)
        print(f"{game_filename(game_name)}_-_{difficulty_level}")
        print(f"    > Correct: {correct_cnt:>6,} ({correct_cnt / denom:.2%})")
        print(f"    > Error  : {exception_cnt:>6,} ({exception_cnt / denom:.2%})")
|
| 89 |
+
|
play_gradio.py
CHANGED
|
@@ -51,7 +51,7 @@ def greet(request: gr.Request):
|
|
| 51 |
|
| 52 |
#%%
|
| 53 |
with gr.Blocks(title="TextGames", css=css, delete_cache=(3600, 3600)) as demo:
|
| 54 |
-
((m, logout_btn, solved_games_df, game_radio, level_radio, new_game_btn, render_toggle),
|
| 55 |
(session_state, is_solved, solved_games, user_state, uid_state),
|
| 56 |
) = declare_components(demo, greet)
|
| 57 |
|
|
@@ -64,7 +64,7 @@ demo.launch(
|
|
| 64 |
auth=file_based_auth,
|
| 65 |
favicon_path=favicon_path if os.path.exists(favicon_path) else None,
|
| 66 |
share=True,
|
| 67 |
-
|
| 68 |
)
|
| 69 |
|
| 70 |
|
|
|
|
| 51 |
|
| 52 |
#%%
|
| 53 |
with gr.Blocks(title="TextGames", css=css, delete_cache=(3600, 3600)) as demo:
|
| 54 |
+
((m, logout_btn, solved_games_df, game_radio, level_radio, new_game_btn, render_toggle, reset_sid_btn),
|
| 55 |
(session_state, is_solved, solved_games, user_state, uid_state),
|
| 56 |
) = declare_components(demo, greet)
|
| 57 |
|
|
|
|
| 64 |
auth=file_based_auth,
|
| 65 |
favicon_path=favicon_path if os.path.exists(favicon_path) else None,
|
| 66 |
share=True,
|
| 67 |
+
show_api=False,
|
| 68 |
)
|
| 69 |
|
| 70 |
|
play_helper.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
# %%
|
| 2 |
import os
|
| 3 |
import time
|
|
|
|
| 4 |
import pandas as pd
|
| 5 |
import gradio as gr
|
| 6 |
import hashlib
|
|
@@ -19,19 +20,27 @@ from googleapiclient.discovery import build
|
|
| 19 |
from googleapiclient.errors import HttpError
|
| 20 |
from googleapiclient.http import MediaFileUpload, MediaIoBaseDownload
|
| 21 |
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
# %%
|
| 24 |
-
def declare_components(demo, greet):
|
| 25 |
with gr.Row():
|
| 26 |
with gr.Column(scale=1):
|
| 27 |
m = gr.Markdown("Welcome to TextGames!", elem_id="md-greeting")
|
| 28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
with gr.Column(scale=2):
|
| 30 |
-
solved_games_df = gr.DataFrame(headers=[g.split('\t', 1)[0] for g in GAME_NAMES], label="
|
| 31 |
-
interactive=False, elem_id="df-solved-games")
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
new_game_btn = gr.Button("Start Game", elem_id="btn-start-game")
|
| 35 |
render_toggle = gr.Checkbox(False, visible=False, interactive=False)
|
| 36 |
|
| 37 |
# cur_game_start = gr.BrowserState()
|
|
@@ -41,9 +50,12 @@ def declare_components(demo, greet):
|
|
| 41 |
user_state = gr.State()
|
| 42 |
uid_state = gr.State()
|
| 43 |
|
|
|
|
|
|
|
|
|
|
| 44 |
session_state.change(
|
| 45 |
-
lambda s: session_state_change_fn(s, 2, 0,
|
| 46 |
-
[session_state], [game_radio, level_radio, new_game_btn, logout_btn], js=js_remove_input_helper,
|
| 47 |
)
|
| 48 |
new_game_btn.click(check_to_start_new_game, [game_radio, level_radio, user_state, uid_state], [session_state])
|
| 49 |
solved_games.change(solved_games_change_fn, solved_games, solved_games_df)
|
|
@@ -54,13 +66,15 @@ def declare_components(demo, greet):
|
|
| 54 |
).then(
|
| 55 |
lambda: gr.update(interactive=False), None, [new_game_btn],
|
| 56 |
).then(
|
| 57 |
-
check_played_game, [solved_games,
|
| 58 |
).then(
|
| 59 |
-
lambda: gr.update(interactive=True)
|
|
|
|
|
|
|
| 60 |
)
|
| 61 |
|
| 62 |
return (
|
| 63 |
-
(m, logout_btn, solved_games_df, game_radio, level_radio, new_game_btn, render_toggle),
|
| 64 |
(session_state, is_solved, solved_games, user_state, uid_state),
|
| 65 |
)
|
| 66 |
|
|
@@ -489,7 +503,8 @@ def _is_checksum_same(fp_out, matches=None, mime_type="application/octet-stream"
|
|
| 489 |
matches = _files.list(
|
| 490 |
q=f"'{_folder_id}' in parents and mimeType='{mime_type}' and name = '{fp_out.rsplit('/', 1)[-1]}'",
|
| 491 |
fields=f"files(name, id, {_cksm_methods_str})",
|
| 492 |
-
).execute()
|
|
|
|
| 493 |
if not os.path.exists(fp_out):
|
| 494 |
return None, None, matches
|
| 495 |
with open(fp_out, "rb") as o:
|
|
@@ -502,9 +517,9 @@ def _is_checksum_same(fp_out, matches=None, mime_type="application/octet-stream"
|
|
| 502 |
|
| 503 |
|
| 504 |
# %%
|
| 505 |
-
def upload_to_drive(fp_out, matches=None, mime_type="application/octet-stream", compare_checksum=True):
|
| 506 |
if compare_checksum:
|
| 507 |
-
same_checksum, _,
|
| 508 |
# same_checksum, _, _ = _is_checksum_same(
|
| 509 |
# fp_out, **{k: v for k, v in [('matches', matches), ('mime_type', mime_type)] if v})
|
| 510 |
if same_checksum:
|
|
@@ -513,7 +528,11 @@ def upload_to_drive(fp_out, matches=None, mime_type="application/octet-stream",
|
|
| 513 |
file_metadata = {"name": fn, "parents": [_folder_id]}
|
| 514 |
media = MediaFileUpload(fp_out)
|
| 515 |
try:
|
| 516 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 517 |
except HttpError as error:
|
| 518 |
msg = f"Failed to upload the file, error: {error}"
|
| 519 |
print(msg)
|
|
@@ -547,7 +566,7 @@ def download_from_drive(fp_out, matches=None, mime_type="application/octet-strea
|
|
| 547 |
|
| 548 |
# %%
|
| 549 |
def start_new_game(game_name, level, session_state_component, is_solved_component, solved_games_component,
|
| 550 |
-
user=None, show_timer=False, uid=None):
|
| 551 |
# cur_game_id = GAME_IDS[GAME_NAMES.index(game_name)]
|
| 552 |
difficulty_level = LEVEL_IDS[LEVELS.index(level)]
|
| 553 |
|
|
@@ -555,11 +574,16 @@ def start_new_game(game_name, level, session_state_component, is_solved_componen
|
|
| 555 |
# elapsed_text = gr.Textbox("N/A", label=f"{game_name}", info=f"{level}", )
|
| 556 |
# gr.Timer(.3).tick(_calc_time_elapsed, [cur_game_start, elapsed_text, is_solved_component], [elapsed_text])
|
| 557 |
|
| 558 |
-
|
|
|
|
|
|
|
|
|
|
| 559 |
cur_game = (
|
| 560 |
new_game(game_name, difficulty_level)
|
| 561 |
if user is None else
|
| 562 |
preload_game(game_name, difficulty_level, user)
|
|
|
|
|
|
|
| 563 |
)
|
| 564 |
cur_game.attach_stats_output_(fp_out)
|
| 565 |
cur_game.flush_stats_(user_info_to_flush=user)
|
|
@@ -616,8 +640,12 @@ def start_new_game(game_name, level, session_state_component, is_solved_componen
|
|
| 616 |
js=js_submit)
|
| 617 |
give_up_checkbox = gr.Checkbox(False, visible=False, interactive=False)
|
| 618 |
give_up_btn.click(
|
|
|
|
|
|
|
| 619 |
lambda x: x, [give_up_checkbox], [give_up_checkbox],
|
| 620 |
js="(x) => confirm('🥹 Give-up? 💸')"
|
|
|
|
|
|
|
| 621 |
)
|
| 622 |
|
| 623 |
def _forfeiting(confirmed, _solved_games):
|
|
@@ -640,6 +668,8 @@ def start_new_game(game_name, level, session_state_component, is_solved_componen
|
|
| 640 |
def game_is_solved(_is_solved, _session_state, _solved_games, progress=gr.Progress()):
|
| 641 |
if _is_solved:
|
| 642 |
if level in LEVELS and level not in _solved_games[game_name]:
|
|
|
|
|
|
|
| 643 |
_solved_games[game_name].append(level)
|
| 644 |
return (
|
| 645 |
2,
|
|
@@ -655,8 +685,16 @@ def start_new_game(game_name, level, session_state_component, is_solved_componen
|
|
| 655 |
|
| 656 |
def finalize_game(_is_solved):
|
| 657 |
if _is_solved:
|
| 658 |
-
gr.Info("
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 659 |
upload_to_drive(fp_out)
|
|
|
|
| 660 |
return gr.update(interactive=True)
|
| 661 |
return gr.update()
|
| 662 |
|
|
@@ -673,13 +711,14 @@ def start_new_game(game_name, level, session_state_component, is_solved_componen
|
|
| 673 |
|
| 674 |
|
| 675 |
# %%
|
| 676 |
-
def check_to_start_new_game(game_name, level, user=None, uid=None):
|
| 677 |
-
print(game_name, level)
|
| 678 |
if game_name is None or level is None:
|
| 679 |
raise gr.Error("please choose both Game & Level")
|
| 680 |
-
fp = _get_file_output(game_name, LEVEL_IDS[LEVELS.index(level)], uid)
|
| 681 |
if os.path.exists(fp):
|
| 682 |
-
raise gr.Error(f"You have done this game already.<br/>{game_name} - {level}")
|
|
|
|
| 683 |
if user is None:
|
| 684 |
gr.Warning("no user, game will be generated randomly")
|
| 685 |
# else:
|
|
@@ -691,16 +730,19 @@ def check_to_start_new_game(game_name, level, user=None, uid=None):
|
|
| 691 |
|
| 692 |
|
| 693 |
# %%
|
| 694 |
-
def check_played_game(solved_games,
|
|
|
|
|
|
|
| 695 |
matches = _files.list(
|
| 696 |
q=f"'{_folder_id}' in parents and mimeType='application/octet-stream' and name contains '{uid}_-_'",
|
| 697 |
fields=f"files(name, id, {_cksm_methods_str})",
|
| 698 |
-
).execute()
|
|
|
|
| 699 |
ret = dict()
|
| 700 |
for game_name in solved_games.keys():
|
| 701 |
cur = []
|
| 702 |
for level, level_id in zip(LEVELS, LEVEL_IDS):
|
| 703 |
-
fp_out = _get_file_output(game_name, level_id, uid)
|
| 704 |
_matches = list(filter(lambda m: fp_out.endswith(m['name']), matches))
|
| 705 |
if os.path.exists(fp_out):
|
| 706 |
upload_to_drive(fp_out, _matches)
|
|
@@ -708,7 +750,7 @@ def check_played_game(solved_games, uid, progress=gr.Progress()):
|
|
| 708 |
download_from_drive(fp_out, _matches)
|
| 709 |
if os.path.exists(fp_out):
|
| 710 |
cur.append(level)
|
| 711 |
-
ret[game_name] = cur
|
| 712 |
return ret, gr.update()
|
| 713 |
|
| 714 |
|
|
|
|
| 1 |
# %%
|
| 2 |
import os
|
| 3 |
import time
|
| 4 |
+
import json
|
| 5 |
import pandas as pd
|
| 6 |
import gradio as gr
|
| 7 |
import hashlib
|
|
|
|
| 20 |
from googleapiclient.errors import HttpError
|
| 21 |
from googleapiclient.http import MediaFileUpload, MediaIoBaseDownload
|
| 22 |
|
| 23 |
+
# %%
|
| 24 |
+
_leaderboards = f"{os.getenv('TEXTGAMES_OUTPUT_DIR', '.')}/_leaderboards.jsonl"
|
| 25 |
+
|
| 26 |
|
| 27 |
# %%
|
| 28 |
+
def declare_components(demo, greet, use_login_button=False):
|
| 29 |
with gr.Row():
|
| 30 |
with gr.Column(scale=1):
|
| 31 |
m = gr.Markdown("Welcome to TextGames!", elem_id="md-greeting")
|
| 32 |
+
if use_login_button:
|
| 33 |
+
logout_btn = gr.LoginButton(size='sm')
|
| 34 |
+
reset_sid_btn = gr.Button("♻️ Reset Game Progress", variant='huggingface', size='sm')
|
| 35 |
+
else:
|
| 36 |
+
logout_btn = gr.Button("Logout", link="/logout", variant='huggingface', size='sm', elem_id="btn-logout")
|
| 37 |
+
reset_sid_btn = gr.Button(interactive=False, visible=False, size='sm')
|
| 38 |
with gr.Column(scale=2):
|
| 39 |
+
solved_games_df = gr.DataFrame(headers=[g.split('\t', 1)[0] for g in GAME_NAMES], label="Attempted Games",
|
| 40 |
+
row_count=(1, 'fixed'), interactive=False, elem_id="df-solved-games")
|
| 41 |
+
level_radio = gr.Radio(LEVELS, label="Level", elem_id="radio-level-name", visible=False)
|
| 42 |
+
game_radio = gr.Radio(GAME_NAMES, label="Game", elem_id="radio-game-name", visible=False)
|
| 43 |
+
new_game_btn = gr.Button("Start Game", elem_id="btn-start-game", visible=False)
|
| 44 |
render_toggle = gr.Checkbox(False, visible=False, interactive=False)
|
| 45 |
|
| 46 |
# cur_game_start = gr.BrowserState()
|
|
|
|
| 50 |
user_state = gr.State()
|
| 51 |
uid_state = gr.State()
|
| 52 |
|
| 53 |
+
if not os.path.exists(_leaderboards):
|
| 54 |
+
download_from_drive(_leaderboards, compare_checksum=False)
|
| 55 |
+
|
| 56 |
session_state.change(
|
| 57 |
+
lambda s: session_state_change_fn(s, 2, 0, 3, 0),
|
| 58 |
+
[session_state], [game_radio, level_radio, new_game_btn, logout_btn, reset_sid_btn], js=js_remove_input_helper,
|
| 59 |
)
|
| 60 |
new_game_btn.click(check_to_start_new_game, [game_radio, level_radio, user_state, uid_state], [session_state])
|
| 61 |
solved_games.change(solved_games_change_fn, solved_games, solved_games_df)
|
|
|
|
| 66 |
).then(
|
| 67 |
lambda: gr.update(interactive=False), None, [new_game_btn],
|
| 68 |
).then(
|
| 69 |
+
check_played_game, [solved_games, user_state], [solved_games, solved_games_df]
|
| 70 |
).then(
|
| 71 |
+
lambda uid: ([gr.update(visible=True, interactive=True)] if uid else
|
| 72 |
+
[gr.update(visible=False, interactive=False)]) * 3,
|
| 73 |
+
[uid_state], [level_radio, game_radio, new_game_btn]
|
| 74 |
)
|
| 75 |
|
| 76 |
return (
|
| 77 |
+
(m, logout_btn, solved_games_df, game_radio, level_radio, new_game_btn, render_toggle, reset_sid_btn),
|
| 78 |
(session_state, is_solved, solved_games, user_state, uid_state),
|
| 79 |
)
|
| 80 |
|
|
|
|
| 503 |
matches = _files.list(
|
| 504 |
q=f"'{_folder_id}' in parents and mimeType='{mime_type}' and name = '{fp_out.rsplit('/', 1)[-1]}'",
|
| 505 |
fields=f"files(name, id, {_cksm_methods_str})",
|
| 506 |
+
).execute()
|
| 507 |
+
matches = matches['files']
|
| 508 |
if not os.path.exists(fp_out):
|
| 509 |
return None, None, matches
|
| 510 |
with open(fp_out, "rb") as o:
|
|
|
|
| 517 |
|
| 518 |
|
| 519 |
# %%
|
| 520 |
+
def upload_to_drive(fp_out, matches=None, mime_type="application/octet-stream", compare_checksum=True, update=False):
|
| 521 |
if compare_checksum:
|
| 522 |
+
same_checksum, _, matches = _is_checksum_same(fp_out, matches, mime_type)
|
| 523 |
# same_checksum, _, _ = _is_checksum_same(
|
| 524 |
# fp_out, **{k: v for k, v in [('matches', matches), ('mime_type', mime_type)] if v})
|
| 525 |
if same_checksum:
|
|
|
|
| 528 |
file_metadata = {"name": fn, "parents": [_folder_id]}
|
| 529 |
media = MediaFileUpload(fp_out)
|
| 530 |
try:
|
| 531 |
+
if update and matches:
|
| 532 |
+
file_metadata.pop("parents")
|
| 533 |
+
_files.update(fileId=matches[0]['id'], body=file_metadata, media_body=media).execute()
|
| 534 |
+
else:
|
| 535 |
+
_files.create(body=file_metadata, media_body=media).execute()
|
| 536 |
except HttpError as error:
|
| 537 |
msg = f"Failed to upload the file, error: {error}"
|
| 538 |
print(msg)
|
|
|
|
| 566 |
|
| 567 |
# %%
|
| 568 |
def start_new_game(game_name, level, session_state_component, is_solved_component, solved_games_component,
|
| 569 |
+
user=None, show_timer=False, uid=None, sid=None):
|
| 570 |
# cur_game_id = GAME_IDS[GAME_NAMES.index(game_name)]
|
| 571 |
difficulty_level = LEVEL_IDS[LEVELS.index(level)]
|
| 572 |
|
|
|
|
| 574 |
# elapsed_text = gr.Textbox("N/A", label=f"{game_name}", info=f"{level}", )
|
| 575 |
# gr.Timer(.3).tick(_calc_time_elapsed, [cur_game_start, elapsed_text, is_solved_component], [elapsed_text])
|
| 576 |
|
| 577 |
+
if (not sid) and user and ('sid' in user):
|
| 578 |
+
sid = user['sid']
|
| 579 |
+
|
| 580 |
+
fp_out = _get_file_output(game_name, difficulty_level, f"{uid}_{sid}")
|
| 581 |
cur_game = (
|
| 582 |
new_game(game_name, difficulty_level)
|
| 583 |
if user is None else
|
| 584 |
preload_game(game_name, difficulty_level, user)
|
| 585 |
+
if sid is None else
|
| 586 |
+
preload_game(game_name, difficulty_level, user, sid=sid)
|
| 587 |
)
|
| 588 |
cur_game.attach_stats_output_(fp_out)
|
| 589 |
cur_game.flush_stats_(user_info_to_flush=user)
|
|
|
|
| 640 |
js=js_submit)
|
| 641 |
give_up_checkbox = gr.Checkbox(False, visible=False, interactive=False)
|
| 642 |
give_up_btn.click(
|
| 643 |
+
lambda: (gr.update(interactive=False), gr.update(interactive=False)), None, [submit_btn, give_up_btn]
|
| 644 |
+
).then(
|
| 645 |
lambda x: x, [give_up_checkbox], [give_up_checkbox],
|
| 646 |
js="(x) => confirm('🥹 Give-up? 💸')"
|
| 647 |
+
).then(
|
| 648 |
+
lambda: (gr.update(interactive=True), gr.update(interactive=True)), None, [submit_btn, give_up_btn]
|
| 649 |
)
|
| 650 |
|
| 651 |
def _forfeiting(confirmed, _solved_games):
|
|
|
|
| 668 |
def game_is_solved(_is_solved, _session_state, _solved_games, progress=gr.Progress()):
|
| 669 |
if _is_solved:
|
| 670 |
if level in LEVELS and level not in _solved_games[game_name]:
|
| 671 |
+
if isinstance(_solved_games[game_name], str):
|
| 672 |
+
_solved_games[game_name] = []
|
| 673 |
_solved_games[game_name].append(level)
|
| 674 |
return (
|
| 675 |
2,
|
|
|
|
| 685 |
|
| 686 |
def finalize_game(_is_solved):
|
| 687 |
if _is_solved:
|
| 688 |
+
gr.Info(f"Wrapping things up... Please click the button when available...<br/>"
|
| 689 |
+
f"Time: {cur_game.end_timestamp-cur_game.start_timestamp:4.1f} sec. Attempt: {cur_game.attempt_count}.")
|
| 690 |
+
with open(_leaderboards, "a", encoding="utf-8") as f:
|
| 691 |
+
json.dump({'uid': uid, 'sid': sid, 'turns': cur_game.attempt_count,
|
| 692 |
+
'st': cur_game.start_timestamp, 'ed': cur_game.end_timestamp,
|
| 693 |
+
'game_name': game_name, 'difficulty_level': difficulty_level,
|
| 694 |
+
}, f)
|
| 695 |
+
f.write("\n")
|
| 696 |
upload_to_drive(fp_out)
|
| 697 |
+
upload_to_drive(_leaderboards, update=True)
|
| 698 |
return gr.update(interactive=True)
|
| 699 |
return gr.update()
|
| 700 |
|
|
|
|
| 711 |
|
| 712 |
|
| 713 |
# %%
|
| 714 |
+
def check_to_start_new_game(game_name, level, user=None, uid=None, sid=None):
|
| 715 |
+
print(game_name, level, uid, sid)
|
| 716 |
if game_name is None or level is None:
|
| 717 |
raise gr.Error("please choose both Game & Level")
|
| 718 |
+
fp = _get_file_output(game_name, LEVEL_IDS[LEVELS.index(level)], f"{uid}_{sid}")
|
| 719 |
if os.path.exists(fp):
|
| 720 |
+
# raise gr.Error(f"You have done this game already.<br/>{game_name} - {level}")
|
| 721 |
+
gr.Warning("You have done this game already. Only first attempt is recorded in the scoreboard.")
|
| 722 |
if user is None:
|
| 723 |
gr.Warning("no user, game will be generated randomly")
|
| 724 |
# else:
|
|
|
|
| 730 |
|
| 731 |
|
| 732 |
# %%
|
| 733 |
+
def check_played_game(solved_games, user, progress=gr.Progress()):
|
| 734 |
+
uid = user['email']
|
| 735 |
+
sid = user.get('sid', None)
|
| 736 |
matches = _files.list(
|
| 737 |
q=f"'{_folder_id}' in parents and mimeType='application/octet-stream' and name contains '{uid}_-_'",
|
| 738 |
fields=f"files(name, id, {_cksm_methods_str})",
|
| 739 |
+
).execute()
|
| 740 |
+
matches = matches['files']
|
| 741 |
ret = dict()
|
| 742 |
for game_name in solved_games.keys():
|
| 743 |
cur = []
|
| 744 |
for level, level_id in zip(LEVELS, LEVEL_IDS):
|
| 745 |
+
fp_out = _get_file_output(game_name, level_id, f"{uid}_{sid}")
|
| 746 |
_matches = list(filter(lambda m: fp_out.endswith(m['name']), matches))
|
| 747 |
if os.path.exists(fp_out):
|
| 748 |
upload_to_drive(fp_out, _matches)
|
|
|
|
| 750 |
download_from_drive(fp_out, _matches)
|
| 751 |
if os.path.exists(fp_out):
|
| 752 |
cur.append(level)
|
| 753 |
+
ret[game_name] = cur or '∅'
|
| 754 |
return ret, gr.update()
|
| 755 |
|
| 756 |
|
play_with_auth.py
CHANGED
|
@@ -130,7 +130,7 @@ with gr.Blocks(title="TextGames") as login_demo:
|
|
| 130 |
app = gr.mount_gradio_app(app, login_demo, path="/login")
|
| 131 |
|
| 132 |
with gr.Blocks(title="TextGames", css=css, delete_cache=(3600, 3600)) as demo:
|
| 133 |
-
((m, logout_btn, solved_games_df, game_radio, level_radio, new_game_btn, render_toggle),
|
| 134 |
(session_state, is_solved, solved_games, user_state, uid_state),
|
| 135 |
) = declare_components(demo, greet)
|
| 136 |
|
|
|
|
| 130 |
app = gr.mount_gradio_app(app, login_demo, path="/login")
|
| 131 |
|
| 132 |
with gr.Blocks(title="TextGames", css=css, delete_cache=(3600, 3600)) as demo:
|
| 133 |
+
((m, logout_btn, solved_games_df, game_radio, level_radio, new_game_btn, render_toggle, reset_sid_btn),
|
| 134 |
(session_state, is_solved, solved_games, user_state, uid_state),
|
| 135 |
) = declare_components(demo, greet)
|
| 136 |
|
play_with_hf.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
#%%
|
| 4 |
+
import os
|
| 5 |
+
# os.environ.setdefault("GRADIO_SERVER_PORT", "1080")
|
| 6 |
+
# os.environ.setdefault("TEXTGAMES_SHOW_HIDDEN_LEVEL", "1")
|
| 7 |
+
os.environ.setdefault("TEXTGAMES_LOADGAME_DIR", "problemsets")
|
| 8 |
+
os.environ.setdefault("TEXTGAMES_LOADGAME_ID", "42")
|
| 9 |
+
os.environ.setdefault("TEXTGAMES_MOCKUSER", "")
|
| 10 |
+
os.environ.setdefault("TEXTGAMES_OUTPUT_DIR", "user_outputs")
|
| 11 |
+
favicon_path = "textgames-scrabble-black2-ss.png"
|
| 12 |
+
|
| 13 |
+
#%%
|
| 14 |
+
from play_helper import css, declare_components, start_new_game, check_played_game, download_from_drive, upload_to_drive, _leaderboards
|
| 15 |
+
import pandas as pd
|
| 16 |
+
import gradio as gr
|
| 17 |
+
import random
|
| 18 |
+
import json
|
| 19 |
+
from textgames import GAME_NAMES
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
#%%
|
| 23 |
+
os.makedirs(os.getenv('TEXTGAMES_OUTPUT_DIR', '.'), exist_ok=True)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
#%%
|
| 27 |
+
def generate_sid(fp):
    """Write a freshly drawn session id to ``fp`` and push the file to Drive.

    The id has the form ``session_NNNN`` where ``NNNN`` is a zero-padded
    random integer in ``[0, 1000]``.  The file is overwritten, so after this
    call it holds exactly one line.

    Args:
        fp: Path of the local sid file to (re)write.
    """
    # Collect whatever ids the local file currently holds so a reset never
    # hands back the id the user is trying to get away from; drawing the same
    # number again would leave the "reset" progress untouched.
    # NOTE(review): only the local copy is consulted; an id that exists only
    # on Drive can still be drawn again.
    previous = set()
    if os.path.exists(fp):
        with open(fp, "r", encoding="utf8") as f:
            previous = {line.strip() for line in f if line.strip()}

    sid = f"session_{random.randint(0, 1000):04}"
    while sid in previous:
        sid = f"session_{random.randint(0, 1000):04}"

    with open(fp, "w", encoding="utf8") as f:
        f.write(f"{sid}\n")
    upload_to_drive(fp, mime_type="text/plain", update=True)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
#%%
|
| 35 |
+
def get_sid(uid, force_generate_sid=False):
    """Return the session id stored for user ``uid``.

    The id is read from ``<TEXTGAMES_OUTPUT_DIR>/<uid>_sid.txt``.  When the
    file is missing locally it is first requested from Drive, and generated
    from scratch if that does not produce it either.

    Args:
        uid: User identifier used as the file-name prefix.
        force_generate_sid: When true, a new id is generated before reading,
            replacing whatever was stored.

    Returns:
        The last non-blank line of the sid file, stripped of whitespace.

    Raises:
        RuntimeError: If the file still holds no id after regenerating it.
    """
    # Default to the current directory, as the other TEXTGAMES_OUTPUT_DIR
    # look-ups in this code base do; without a default an unset variable
    # produced the literal path "None/<uid>_sid.txt".
    fp = f"{os.getenv('TEXTGAMES_OUTPUT_DIR', '.')}/{uid}_sid.txt"
    if force_generate_sid:
        generate_sid(fp)
    if not os.path.exists(fp):
        download_from_drive(fp, mime_type="text/plain", compare_checksum=False)
    if not os.path.exists(fp):
        generate_sid(fp)

    # Two passes at most: read, and if the file turned out to be empty or
    # blank (e.g. a truncated download), regenerate once and read again.
    for attempt in range(2):
        with open(fp, "r", encoding="utf8") as f:
            # Skip blank lines so a trailing empty line does not yield "".
            sids = [line.strip() for line in f if line.strip()]
        if sids:
            return sids[-1]
        if attempt == 0:
            generate_sid(fp)
    raise RuntimeError(f"no session id could be read from {fp}")
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
#%%
|
| 49 |
+
def greet(request: gr.OAuthProfile | None):
|
| 50 |
+
user = {'email': os.getenv('TEXTGAMES_MOCKUSER', ''), 'name': ""}
|
| 51 |
+
if request is not None:
|
| 52 |
+
user = {'email': request.username, 'name': request.name, 'sid': get_sid(request.username)}
|
| 53 |
+
return f"""
|
| 54 |
+
Welcome to TextGames, {user['name'] or 'please login'}!
|
| 55 |
+
""", user, user['email']
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
#%%
|
| 59 |
+
with gr.Blocks(title="TextGames", css=css, delete_cache=(3600, 3600)) as demo:
|
| 60 |
+
((m, logout_btn, solved_games_df, game_radio, level_radio, new_game_btn, render_toggle, reset_sid_btn),
|
| 61 |
+
(session_state, is_solved, solved_games, user_state, uid_state),
|
| 62 |
+
) = declare_components(demo, greet, use_login_button=True)
|
| 63 |
+
logout_btn.activate()
|
| 64 |
+
|
| 65 |
+
reset_sid_checkbox = gr.Checkbox(False, visible=False, interactive=False)
|
| 66 |
+
reset_sid_btn.click(
|
| 67 |
+
lambda: [gr.update(interactive=False)]*2, None, [reset_sid_btn, new_game_btn]
|
| 68 |
+
).then(
|
| 69 |
+
lambda x: x, [reset_sid_checkbox], [reset_sid_checkbox],
|
| 70 |
+
js="(x) => confirm('Reset Progress? (cannot be undone)')"
|
| 71 |
+
).then(
|
| 72 |
+
lambda: [gr.update(interactive=True)]*2, None, [reset_sid_btn, new_game_btn]
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
def _resetting(confirmed, user):
|
| 76 |
+
uid = user.get('email', None) if isinstance(user, dict) else None
|
| 77 |
+
if uid is None:
|
| 78 |
+
gr.Warning("You need to log in first!")
|
| 79 |
+
elif confirmed:
|
| 80 |
+
user['sid'] = get_sid(uid, force_generate_sid=True)
|
| 81 |
+
return user, False
|
| 82 |
+
reset_sid_checkbox.change(
|
| 83 |
+
lambda: [gr.update(interactive=False)]*3, None, [logout_btn, reset_sid_btn, new_game_btn]
|
| 84 |
+
).then(
|
| 85 |
+
_resetting, [reset_sid_checkbox, user_state], [user_state, reset_sid_checkbox]
|
| 86 |
+
).then(
|
| 87 |
+
check_played_game, [solved_games, user_state], [solved_games, solved_games_df]
|
| 88 |
+
).then(
|
| 89 |
+
lambda: [gr.update(interactive=True)]*3, None, [logout_btn, reset_sid_btn, new_game_btn]
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
@gr.render(inputs=[game_radio, level_radio, user_state, session_state, uid_state], triggers=[render_toggle.change])
|
| 94 |
+
def _start_new_game(game_name, level, user, _session_state, _uid_state):
|
| 95 |
+
if _session_state in [1, 2]:
|
| 96 |
+
start_new_game(game_name, level, session_state, is_solved, solved_games, user=user, uid=_uid_state)
|
| 97 |
+
|
| 98 |
+
#%%
|
| 99 |
+
with demo.route("Leaderboards", "/leaderboard") as demo_leaderboard:
|
| 100 |
+
gr.Markdown("Under Construction. Will be available soon.")
|
| 101 |
+
leaderboards = []
|
| 102 |
+
for tab in ["🚅 Easy", "🚀 Medium", "🛸 Hard"]:
|
| 103 |
+
with gr.Tab(tab):
|
| 104 |
+
leaderboards.append(gr.DataFrame(label="Rankings"))
|
| 105 |
+
|
| 106 |
+
# if os.path.exists(_leaderboards):
|
| 107 |
+
# datas = []
|
| 108 |
+
# with open(_leaderboards, "r", encoding="utf8") as f:
|
| 109 |
+
# for line in f:
|
| 110 |
+
# datas.append(json.loads(line))
|
| 111 |
+
# concat = [{'Level': d['difficulty_level'], 'User': d['uid'], 'Game': d['game_name'].split('\t', 1)[0], 'Attempts': d['turns'],
|
| 112 |
+
# "Time": d['ed'] - d['st']} for d in datas]
|
| 113 |
+
# else:
|
| 114 |
+
def add_dummies():
|
| 115 |
+
return pd.DataFrame({
|
| 116 |
+
'User': ['dummy'],
|
| 117 |
+
'Solved': [' '.join([g.split('\t', 1)[0] for g in GAME_NAMES])],
|
| 118 |
+
'Attempts': [8],
|
| 119 |
+
'Time': [7200.8],
|
| 120 |
+
})
|
| 121 |
+
for l in leaderboards:
|
| 122 |
+
demo_leaderboard.load(add_dummies, None, [l])
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
#%%
|
| 126 |
+
# demo.launch()
|
| 127 |
+
demo.launch(
|
| 128 |
+
favicon_path=favicon_path if os.path.exists(favicon_path) else None,
|
| 129 |
+
show_api=False,
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
|
problemsets/Anagram Scribble_1.json
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
problemsets/Anagram Scribble_2.json
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
problemsets/Anagram Scribble_3.json
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
problemsets/Bracket Game_1.json
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
problemsets/Bracket Game_2.json
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
problemsets/Bracket Game_3.json
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
problemsets/Crossword Arranger_1.json
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
problemsets/Crossword Arranger_2.json
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
problemsets/Crossword Arranger_3.json
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
reval_ana3.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
from textgames import GAME_NAMES, game_filename, _game_class_from_name
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
GAME_NAME = GAME_NAMES[5]
|
| 8 |
+
PROBLEMSET_DIR = Path(os.getenv("TG_PROBLEMSET_DIR", "problemsets"))
|
| 9 |
+
MODEL_OUTPUT_DIR = Path(os.getenv("TG_MODEL_OUTPUT_DIR", "model_outputs"))
|
| 10 |
+
OUTPUT_FILENAMES = [
|
| 11 |
+
# "results_gemma-2-9b-it.1s.jsonl",
|
| 12 |
+
# "results_gemma-2-9b-it.zs.jsonl",
|
| 13 |
+
# "results_gemma-2-27b-it.1s.jsonl",
|
| 14 |
+
# "results_gemma-2-27b-it.zs.jsonl",
|
| 15 |
+
#
|
| 16 |
+
# "results_llama-3.1-8b-instruct.1s.jsonl",
|
| 17 |
+
# "results_llama-3.1-8b-instruct.zs.jsonl",
|
| 18 |
+
# "results_llama-3.1-70b-instruct.1s.jsonl",
|
| 19 |
+
# "results_llama-3.1-70b-instruct.zs.jsonl",
|
| 20 |
+
# "results_llama-3.3-70b-instruct.1s.jsonl",
|
| 21 |
+
# "results_llama-3.3-70b-instruct.zs.jsonl",
|
| 22 |
+
#
|
| 23 |
+
# "results_qwen2-5-7b-instruct.1s.jsonl",
|
| 24 |
+
# "results_qwen2-5-7b-instruct.zs.jsonl",
|
| 25 |
+
# "results_qwen2-5-14b-instruct.1s.jsonl",
|
| 26 |
+
# "results_qwen2-5-14b-instruct.zs.jsonl",
|
| 27 |
+
# "results_qwen2-5-32b-instruct.1s.jsonl",
|
| 28 |
+
# "results_qwen2-5-32b-instruct.zs.jsonl",
|
| 29 |
+
# "results_qwen2-5-72b-instruct.1s.jsonl",
|
| 30 |
+
# "results_qwen2-5-72b-instruct.zs.jsonl",
|
| 31 |
+
#
|
| 32 |
+
# "results_deepseek-r1-distill-14b.1s.jsonl",
|
| 33 |
+
# "results_deepseek-r1-distill-14b.zs.jsonl",
|
| 34 |
+
# "results_deepseek-r1-distill-14b.rerun.1s.jsonl",
|
| 35 |
+
#
|
| 36 |
+
# "results_chatgpt-4o-mini.zs.jsonl",
|
| 37 |
+
# "results_chatgpt-o3-mini.zs.jsonl",
|
| 38 |
+
#
|
| 39 |
+
# "results_qwen2-5-7b-instruct_sp.1s.jsonl",
|
| 40 |
+
# "results_qwen2-5-7b-instruct_sp.zs.jsonl",
|
| 41 |
+
|
| 42 |
+
# "results_deepseek-r1-distill-8b.1s.jsonl",
|
| 43 |
+
"results_deepseek-r1-distill-8b.zs.jsonl",
|
| 44 |
+
]
|
| 45 |
+
|
| 46 |
+
# !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
|
| 47 |
+
# !!! Must run reval_bracket_rerun.py first !!!
|
| 48 |
+
# !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def revalidate_anagram_3(fp, reval_dir="revalidate_anagram_3", source_dir="prior_revalidate"):
    """Re-run validation for the level-3 records of ``GAME_NAME`` in one file.

    Reads ``MODEL_OUTPUT_DIR/source_dir/fp`` line by line (one JSON object per
    line) and writes the records to ``MODEL_OUTPUT_DIR/reval_dir/fp``.  Records
    of other games are copied through re-serialised; records of the target
    game get their ``"solved"`` field replaced by the result of validating
    ``"response"`` against the game loaded for that session.  Turns that come
    after a turn already validated as solved are dropped.

    Relies on the module-level names ``game_cls`` and ``sid_prompt_dict``,
    which the script assigns in its ``__main__`` section.

    Returns:
        ``(count_pos, count_neg)``: how many records flipped from unsolved to
        solved, and from solved to unsolved.
    """
    out_dir = MODEL_OUTPUT_DIR/reval_dir
    os.makedirs(out_dir, exist_ok=True)
    newly_solved = 0
    newly_unsolved = 0
    with (open(MODEL_OUTPUT_DIR/source_dir/fp, "r", encoding="utf8") as src,
          open(out_dir/fp, "w", encoding="utf8") as dst,
          tqdm(total=1000, desc=fp) as pbar,
          ):
        for raw in src:
            record = json.loads(raw)
            if record['game'] == f"{game_filename(GAME_NAME)}_3":
                if record['turn'] == 1:
                    # First turn of a session: load that session's puzzle.
                    cur_sid = record["session"]
                    cur_game = game_cls()
                    cur_game.load_game(sid_prompt_dict[cur_sid])
                    pbar.update(1)
                elif solved == True:
                    # Session already solved on an earlier turn: drop the rest.
                    continue
                else:
                    assert cur_sid == record["session"]
                solved, _ = cur_game.validate(record["response"])
                previously = record["solved"]
                if solved and not previously:
                    newly_solved += 1
                elif not solved and previously:
                    newly_unsolved += 1
                record["solved"] = solved
            dst.write(json.dumps(record))
            dst.write("\n")
    return newly_solved, newly_unsolved
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
if __name__ == "__main__":
|
| 83 |
+
game_cls = _game_class_from_name(GAME_NAME)
|
| 84 |
+
with open(f"{PROBLEMSET_DIR}/{game_filename(GAME_NAME)}_3.json", "r", encoding="utf8") as f:
|
| 85 |
+
sid_prompt_dict = json.load(f)
|
| 86 |
+
for fp in OUTPUT_FILENAMES:
|
| 87 |
+
print(revalidate_anagram_3(fp))
|
reval_bracket_all.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
from textgames import GAME_NAMES, game_filename, _game_class_from_name
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
GAME_NAME = GAME_NAMES[6]
|
| 8 |
+
PROBLEMSET_DIR = Path(os.getenv("TG_PROBLEMSET_DIR", "problemsets"))
|
| 9 |
+
MODEL_OUTPUT_DIR = Path(os.getenv("TG_MODEL_OUTPUT_DIR", "model_outputs"))
|
| 10 |
+
OUTPUT_FILENAMES = [
|
| 11 |
+
# "results_gemma-2-9b-it.1s.jsonl",
|
| 12 |
+
# "results_gemma-2-9b-it.zs.jsonl",
|
| 13 |
+
# "results_gemma-2-27b-it.1s.jsonl",
|
| 14 |
+
# "results_gemma-2-27b-it.zs.jsonl",
|
| 15 |
+
#
|
| 16 |
+
# "results_llama-3.1-8b-instruct.1s.jsonl",
|
| 17 |
+
# "results_llama-3.1-8b-instruct.zs.jsonl",
|
| 18 |
+
# "results_llama-3.1-70b-instruct.1s.jsonl",
|
| 19 |
+
# "results_llama-3.1-70b-instruct.zs.jsonl",
|
| 20 |
+
# "results_llama-3.3-70b-instruct.1s.jsonl",
|
| 21 |
+
# "results_llama-3.3-70b-instruct.zs.jsonl",
|
| 22 |
+
#
|
| 23 |
+
# "results_qwen2-5-7b-instruct.1s.jsonl",
|
| 24 |
+
# "results_qwen2-5-7b-instruct.zs.jsonl",
|
| 25 |
+
# "results_qwen2-5-14b-instruct.1s.jsonl",
|
| 26 |
+
# "results_qwen2-5-14b-instruct.zs.jsonl",
|
| 27 |
+
# "results_qwen2-5-32b-instruct.1s.jsonl",
|
| 28 |
+
# "results_qwen2-5-32b-instruct.zs.jsonl",
|
| 29 |
+
# "results_qwen2-5-72b-instruct.1s.jsonl",
|
| 30 |
+
# "results_qwen2-5-72b-instruct.zs.jsonl",
|
| 31 |
+
#
|
| 32 |
+
# "results_deepseek-r1-distill-14b.1s.jsonl",
|
| 33 |
+
# "results_deepseek-r1-distill-14b.zs.jsonl",
|
| 34 |
+
# "results_deepseek-r1-distill-14b.rerun.1s.jsonl",
|
| 35 |
+
#
|
| 36 |
+
# "results_chatgpt-4o-mini.1s.jsonl",
|
| 37 |
+
# "results_chatgpt-4o-mini.zs.jsonl",
|
| 38 |
+
# "results_chatgpt-o3-mini.zs.jsonl",
|
| 39 |
+
#
|
| 40 |
+
# "results_qwen2-5-7b-instruct_sp.1s.jsonl",
|
| 41 |
+
# "results_qwen2-5-7b-instruct_sp.zs.jsonl",
|
| 42 |
+
|
| 43 |
+
# "results_deepseek-r1-distill-8b.1s.jsonl",
|
| 44 |
+
"results_deepseek-r1-distill-8b.zs.jsonl",
|
| 45 |
+
]
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def revalidate_bracket(fp, reval_dir="revalidate_bracket_all",
                       source_dirs=("revalidate_bracket_rerun", "revalidate_anagram_3",)):
    """Re-run validation for every record of ``GAME_NAME`` in one results file.

    The input is the first ``MODEL_OUTPUT_DIR/<dir>/fp`` that exists among
    ``source_dirs``, falling back to ``MODEL_OUTPUT_DIR/fp``.  Each line is a
    JSON object; records of other games are copied through re-serialised,
    while records whose ``"game"`` starts with the target game's file name get
    their ``"solved"`` field replaced by the result of validating
    ``"response"`` against the game loaded for that session.  Turns that come
    after a turn already validated as solved are dropped.  Output goes to
    ``MODEL_OUTPUT_DIR/reval_dir/fp``.

    Relies on the module-level names ``game_cls`` and ``sid_prompt_dicts``
    (keyed by the level suffix of the game name), which the script assigns in
    its ``__main__`` section.

    Args:
        fp: File name of the results file.
        reval_dir: Sub-directory of ``MODEL_OUTPUT_DIR`` to write into.
        source_dirs: Candidate input sub-directories, in priority order.

    Returns:
        ``(count_pos, count_neg)``: how many records flipped from unsolved to
        solved, and from solved to unsolved.

    Raises:
        ValueError: If a non-first turn does not belong to the session that
            is currently loaded.
    """
    os.makedirs(MODEL_OUTPUT_DIR / reval_dir, exist_ok=True)
    count_pos, count_neg = 0, 0
    # First candidate that actually holds the file, else ".".  The earlier
    # "assign '.', then loop with the same variable" form always ended on the
    # last candidate when nothing matched, so the "." fallback never applied.
    source_dir = next(
        (d for d in source_dirs if (MODEL_OUTPUT_DIR / d / fp).exists()), ".")
    # Session state; initialised so a file that does not start with a turn-1
    # record fails with a clear error instead of an unbound-name error.
    cur_sid, cur_game, solved = None, None, False
    with (open(MODEL_OUTPUT_DIR / source_dir / fp, "r", encoding="utf8") as i,
          open(MODEL_OUTPUT_DIR / reval_dir / fp, "w", encoding="utf8") as o,
          tqdm(total=3000, desc=fp) as pbar,
          ):
        for line in i:
            res = json.loads(line)
            if res['game'].startswith(f"{game_filename(GAME_NAME)}"):
                # The level suffix ("1".."3") selects the prompt table.
                sid_prompt_dict = sid_prompt_dicts[res['game'].rsplit("_", 1)[-1]]
                if res['turn'] == 1:
                    # First turn of a session: load that session's puzzle.
                    cur_sid = res["session"]
                    prompt = sid_prompt_dict[cur_sid]
                    cur_game = game_cls()
                    cur_game.load_game(prompt)
                    pbar.update(1)
                elif solved:
                    # Session already solved on an earlier turn: drop the rest.
                    continue
                elif cur_sid != res["session"]:
                    raise ValueError(
                        f"turn {res['turn']} of session {res['session']!r} "
                        f"does not follow the loaded session {cur_sid!r}")
                solved, _ = cur_game.validate(res["response"])
                if solved and not res["solved"]:
                    count_pos += 1
                elif not solved and res["solved"]:
                    count_neg += 1
                res["solved"] = solved
            o.write(json.dumps(res))
            o.write("\n")
    return count_pos, count_neg
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
if __name__ == "__main__":
|
| 86 |
+
def load(k):
|
| 87 |
+
with open(f"{PROBLEMSET_DIR}/{game_filename(GAME_NAME)}_{k}.json", "r", encoding="utf8") as f:
|
| 88 |
+
sid_prompt_dict = json.load(f)
|
| 89 |
+
return sid_prompt_dict
|
| 90 |
+
sid_prompt_dicts = {k: load(k) for k in map(str, range(1, 4))}
|
| 91 |
+
game_cls = _game_class_from_name(GAME_NAME)
|
| 92 |
+
for fp in OUTPUT_FILENAMES:
|
| 93 |
+
print(revalidate_bracket(fp))
|
| 94 |
+
|
reval_bracket_rerun.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @title ##### Combine Rerun of the Bracket - All
|
| 2 |
+
import os
|
| 3 |
+
import json
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
# Directory layout: original results, the re-run Bracket Game results, and the merged output.
MODEL_OUTPUT_DIR = Path(os.getenv("TG_MODEL_OUTPUT_DIR", "model_outputs"))
fd_new = MODEL_OUTPUT_DIR / "__runs__" / "_redo_bracket"
fd_ori = MODEL_OUTPUT_DIR / "revalidate_anagram_3"
fd_out = MODEL_OUTPUT_DIR / "revalidate_bracket_rerun"

# One one-shot (".1s") and one zero-shot (".zs") result file per model, in that order.
OUTPUT_FILENAMES = [
    f"results_{model}.{shot}.jsonl"
    for model in (
        "gemma-2-9b-it",
        "gemma-2-27b-it",
        "llama-3.1-8b-instruct",
        "llama-3.1-70b-instruct",
        "llama-3.3-70b-instruct",
        "qwen2-5-7b-instruct",
        "qwen2-5-14b-instruct",
        "qwen2-5-32b-instruct",
        "qwen2-5-72b-instruct",
    )
    for shot in ("1s", "zs")
]

fd_out.mkdir(parents=True, exist_ok=True)
for fp in tqdm(OUTPUT_FILENAMES):
    rerun_path = (fd_new / fp).with_suffix(".6.jsonl")
    with open(fd_out / fp, "w", encoding="utf8") as o:
        # Keep every original record except the Bracket Game ones ...
        with open(fd_ori / fp, "r", encoding="utf8") as i:
            o.writelines(
                line for line in i
                if not json.loads(line)['game'].startswith("Bracket Game")
            )
        # ... and append the re-run Bracket Game records verbatim.
        with open(rerun_path, "r", encoding="utf8") as i:
            o.writelines(i)
|
reval_crosswords_all.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
from textgames import GAME_NAMES, game_filename, _game_class_from_name
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
GAME_NAME = GAME_NAMES[0]
|
| 8 |
+
PROBLEMSET_DIR = Path(os.getenv("TG_PROBLEMSET_DIR", "problemsets"))
|
| 9 |
+
MODEL_OUTPUT_DIR = Path(os.getenv("TG_MODEL_OUTPUT_DIR", "model_outputs"))
|
| 10 |
+
OUTPUT_FILENAMES = [
|
| 11 |
+
# "results_gemma-2-9b-it.1s.jsonl",
|
| 12 |
+
# "results_gemma-2-9b-it.zs.jsonl",
|
| 13 |
+
# "results_gemma-2-27b-it.1s.jsonl",
|
| 14 |
+
# "results_gemma-2-27b-it.zs.jsonl",
|
| 15 |
+
#
|
| 16 |
+
# "results_llama-3.1-8b-instruct.1s.jsonl",
|
| 17 |
+
# "results_llama-3.1-8b-instruct.zs.jsonl",
|
| 18 |
+
# "results_llama-3.1-70b-instruct.1s.jsonl",
|
| 19 |
+
# "results_llama-3.1-70b-instruct.zs.jsonl",
|
| 20 |
+
# "results_llama-3.3-70b-instruct.1s.jsonl",
|
| 21 |
+
# "results_llama-3.3-70b-instruct.zs.jsonl",
|
| 22 |
+
#
|
| 23 |
+
# "results_qwen2-5-7b-instruct.1s.jsonl",
|
| 24 |
+
# "results_qwen2-5-7b-instruct.zs.jsonl",
|
| 25 |
+
# "results_qwen2-5-14b-instruct.1s.jsonl",
|
| 26 |
+
# "results_qwen2-5-14b-instruct.zs.jsonl",
|
| 27 |
+
# "results_qwen2-5-32b-instruct.1s.jsonl",
|
| 28 |
+
# "results_qwen2-5-32b-instruct.zs.jsonl",
|
| 29 |
+
# "results_qwen2-5-72b-instruct.1s.jsonl",
|
| 30 |
+
# "results_qwen2-5-72b-instruct.zs.jsonl",
|
| 31 |
+
#
|
| 32 |
+
# "results_deepseek-r1-distill-14b.1s.jsonl",
|
| 33 |
+
# "results_deepseek-r1-distill-14b.zs.jsonl",
|
| 34 |
+
# # "results_deepseek-r1-distill-14b.rerun.1s.jsonl",
|
| 35 |
+
#
|
| 36 |
+
# "results_chatgpt-4o-mini.1s.jsonl",
|
| 37 |
+
# "results_chatgpt-4o-mini.zs.jsonl",
|
| 38 |
+
# "results_chatgpt-o3-mini.zs.jsonl",
|
| 39 |
+
#
|
| 40 |
+
# # "results_qwen2-5-7b-instruct_sp.1s.jsonl",
|
| 41 |
+
# # "results_qwen2-5-7b-instruct_sp.zs.jsonl",
|
| 42 |
+
|
| 43 |
+
"results_deepseek-r1-distill-8b.1s.jsonl",
|
| 44 |
+
"results_deepseek-r1-distill-8b.zs.jsonl",
|
| 45 |
+
]
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def revalidate_bracket(fp, reval_dir="revalidate_crosswords_all",
                       source_dirs=("revalidate_bracket_all",)):
    """Re-score the logged responses of one result file for ``GAME_NAME``.

    The file ``fp`` is read from the first directory in ``source_dirs`` that
    contains it (falling back to ``MODEL_OUTPUT_DIR`` itself when none does).
    Records whose ``game`` field starts with ``game_filename(GAME_NAME)`` are
    validated again with a game re-loaded from the stored prompt and their
    ``solved`` flag is overwritten; records of other games are copied as-is.
    Later turns of a session that was already re-validated as solved are not
    written to the output.

    Args:
        fp: File name of the JSON-lines result file.
        reval_dir: Output directory name, relative to ``MODEL_OUTPUT_DIR``.
        source_dirs: Candidate input directory names, relative to
            ``MODEL_OUTPUT_DIR``, tried in order.

    Returns:
        ``(count_pos, count_neg)``: the number of turns flipped from unsolved
        to solved, and from solved to unsolved, by the re-validation.
    """
    os.makedirs(MODEL_OUTPUT_DIR / reval_dir, exist_ok=True)
    count_pos, count_neg = 0, 0

    # for/else: the "." fallback only applies when no candidate holds the file.
    # (Assigning it before the loop had no effect: the loop variable always
    # overwrote it with the last candidate.)
    for source_dir in source_dirs:
        if (MODEL_OUTPUT_DIR / source_dir / fp).exists():
            break
    else:
        source_dir = "."

    game_prefix = f"{game_filename(GAME_NAME)}"
    # Session state; initialised so that a file not starting at turn 1 fails
    # on the session check below instead of raising NameError.
    cur_sid, cur_game, solved = None, None, False

    with (open(MODEL_OUTPUT_DIR / source_dir / fp, "r", encoding="utf8") as i,
          open(MODEL_OUTPUT_DIR / reval_dir / fp, "w", encoding="utf8") as o,
          # NOTE(review): total is hard-coded; presumably the expected number
          # of sessions per file -- confirm.
          tqdm(total=3000, desc=fp) as pbar,
          ):
        for line in i:
            res = json.loads(line)
            if res['game'].startswith(game_prefix):
                # The suffix after the last "_" of the game id selects the problem set.
                sid_prompt_dict = sid_prompt_dicts[res['game'].rsplit("_", 1)[-1]]
                if res['turn'] == 1:
                    # First turn of a session: rebuild the game from its prompt.
                    cur_sid = res["session"]
                    cur_game = game_cls()
                    cur_game.load_game(sid_prompt_dict[cur_sid])
                    pbar.update(1)
                elif solved:
                    # Session already solved on an earlier turn: drop the rest.
                    continue
                else:
                    assert cur_sid == res["session"], "turn does not belong to the current session"

                solved, _ = cur_game.validate(res["response"])
                if solved and not res["solved"]:
                    count_pos += 1
                elif not solved and res["solved"]:
                    count_neg += 1
                res["solved"] = solved
            o.write(json.dumps(res) + "\n")
    return count_pos, count_neg
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
if __name__ == "__main__":
    def load(k):
        """Return the parsed problem-set JSON for level ``k`` of ``GAME_NAME``."""
        problemset_path = f"{PROBLEMSET_DIR}/{game_filename(GAME_NAME)}_{k}.json"
        with open(problemset_path, "r", encoding="utf8") as f:
            return json.load(f)

    # Both names are module globals that revalidate_bracket() reads.
    sid_prompt_dicts = dict((str(n), load(str(n))) for n in range(1, 4))
    game_cls = _game_class_from_name(GAME_NAME)

    # One (count_pos, count_neg) report per result file.
    for fp in OUTPUT_FILENAMES:
        counts = revalidate_bracket(fp)
        print(counts)
|
| 94 |
+
|
reval_sudoku_all.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
from textgames import GAME_NAMES, game_filename, _game_class_from_name
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
GAME_NAME = GAME_NAMES[1]
|
| 8 |
+
PROBLEMSET_DIR = Path(os.getenv("TG_PROBLEMSET_DIR", "problemsets"))
|
| 9 |
+
MODEL_OUTPUT_DIR = Path(os.getenv("TG_MODEL_OUTPUT_DIR", "model_outputs"))
|
| 10 |
+
OUTPUT_FILENAMES = [
|
| 11 |
+
# "results_gemma-2-9b-it.1s.jsonl",
|
| 12 |
+
# "results_gemma-2-9b-it.zs.jsonl",
|
| 13 |
+
# "results_gemma-2-27b-it.1s.jsonl",
|
| 14 |
+
# "results_gemma-2-27b-it.zs.jsonl",
|
| 15 |
+
#
|
| 16 |
+
# "results_llama-3.1-8b-instruct.1s.jsonl",
|
| 17 |
+
# "results_llama-3.1-8b-instruct.zs.jsonl",
|
| 18 |
+
# "results_llama-3.1-70b-instruct.1s.jsonl",
|
| 19 |
+
# "results_llama-3.1-70b-instruct.zs.jsonl",
|
| 20 |
+
# "results_llama-3.3-70b-instruct.1s.jsonl",
|
| 21 |
+
# "results_llama-3.3-70b-instruct.zs.jsonl",
|
| 22 |
+
#
|
| 23 |
+
# "results_qwen2-5-7b-instruct.1s.jsonl",
|
| 24 |
+
# "results_qwen2-5-7b-instruct.zs.jsonl",
|
| 25 |
+
# "results_qwen2-5-14b-instruct.1s.jsonl",
|
| 26 |
+
# "results_qwen2-5-14b-instruct.zs.jsonl",
|
| 27 |
+
# "results_qwen2-5-32b-instruct.1s.jsonl",
|
| 28 |
+
# "results_qwen2-5-32b-instruct.zs.jsonl",
|
| 29 |
+
# "results_qwen2-5-72b-instruct.1s.jsonl",
|
| 30 |
+
# "results_qwen2-5-72b-instruct.zs.jsonl",
|
| 31 |
+
#
|
| 32 |
+
# "results_deepseek-r1-distill-14b.1s.jsonl",
|
| 33 |
+
# "results_deepseek-r1-distill-14b.zs.jsonl",
|
| 34 |
+
# # "results_deepseek-r1-distill-14b.rerun.1s.jsonl",
|
| 35 |
+
#
|
| 36 |
+
# "results_chatgpt-4o-mini.1s.jsonl",
|
| 37 |
+
# "results_chatgpt-4o-mini.zs.jsonl",
|
| 38 |
+
# "results_chatgpt-o3-mini.zs.jsonl",
|
| 39 |
+
#
|
| 40 |
+
# # "results_qwen2-5-7b-instruct_sp.1s.jsonl",
|
| 41 |
+
# # "results_qwen2-5-7b-instruct_sp.zs.jsonl",
|
| 42 |
+
|
| 43 |
+
"results_deepseek-r1-distill-8b.1s.jsonl",
|
| 44 |
+
"results_deepseek-r1-distill-8b.zs.jsonl",
|
| 45 |
+
]
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def revalidate_bracket(fp, reval_dir="revalidate_sudoku_all",
                       source_dirs=("revalidate_crosswords_all",)):
    """Re-score the logged responses of one result file for ``GAME_NAME``.

    The file ``fp`` is read from the first directory in ``source_dirs`` that
    contains it (falling back to ``MODEL_OUTPUT_DIR`` itself when none does).
    Records whose ``game`` field starts with ``game_filename(GAME_NAME)`` are
    validated again with a game re-loaded from the stored prompt and their
    ``solved`` flag is overwritten; records of other games are copied as-is.
    Later turns of a session that was already re-validated as solved are not
    written to the output.

    Args:
        fp: File name of the JSON-lines result file.
        reval_dir: Output directory name, relative to ``MODEL_OUTPUT_DIR``.
        source_dirs: Candidate input directory names, relative to
            ``MODEL_OUTPUT_DIR``, tried in order.

    Returns:
        ``(count_pos, count_neg)``: the number of turns flipped from unsolved
        to solved, and from solved to unsolved, by the re-validation.
    """
    os.makedirs(MODEL_OUTPUT_DIR / reval_dir, exist_ok=True)
    count_pos, count_neg = 0, 0

    # for/else: the "." fallback only applies when no candidate holds the file.
    # (Assigning it before the loop had no effect: the loop variable always
    # overwrote it with the last candidate.)
    for source_dir in source_dirs:
        if (MODEL_OUTPUT_DIR / source_dir / fp).exists():
            break
    else:
        source_dir = "."

    game_prefix = f"{game_filename(GAME_NAME)}"
    # Session state; initialised so that a file not starting at turn 1 fails
    # on the session check below instead of raising NameError.
    cur_sid, cur_game, solved = None, None, False

    with (open(MODEL_OUTPUT_DIR / source_dir / fp, "r", encoding="utf8") as i,
          open(MODEL_OUTPUT_DIR / reval_dir / fp, "w", encoding="utf8") as o,
          # NOTE(review): total is hard-coded; presumably the expected number
          # of sessions per file -- confirm.
          tqdm(total=3000, desc=fp) as pbar,
          ):
        for line in i:
            res = json.loads(line)
            if res['game'].startswith(game_prefix):
                # The suffix after the last "_" of the game id selects the problem set.
                sid_prompt_dict = sid_prompt_dicts[res['game'].rsplit("_", 1)[-1]]
                if res['turn'] == 1:
                    # First turn of a session: rebuild the game from its prompt.
                    cur_sid = res["session"]
                    cur_game = game_cls()
                    cur_game.load_game(sid_prompt_dict[cur_sid])
                    pbar.update(1)
                elif solved:
                    # Session already solved on an earlier turn: drop the rest.
                    continue
                else:
                    assert cur_sid == res["session"], "turn does not belong to the current session"

                solved, _ = cur_game.validate(res["response"])
                if solved and not res["solved"]:
                    count_pos += 1
                elif not solved and res["solved"]:
                    count_neg += 1
                res["solved"] = solved
            o.write(json.dumps(res) + "\n")
    return count_pos, count_neg
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
if __name__ == "__main__":
    def load(k):
        """Return the parsed problem-set JSON for level ``k`` of ``GAME_NAME``."""
        problemset_path = f"{PROBLEMSET_DIR}/{game_filename(GAME_NAME)}_{k}.json"
        with open(problemset_path, "r", encoding="utf8") as f:
            return json.load(f)

    # Both names are module globals that revalidate_bracket() reads.
    sid_prompt_dicts = dict((str(n), load(str(n))) for n in range(1, 4))
    game_cls = _game_class_from_name(GAME_NAME)

    # One (count_pos, count_neg) report per result file.
    for fp in OUTPUT_FILENAMES:
        counts = revalidate_bracket(fp)
        print(counts)
|
| 94 |
+
|
textgames-scrabble-black2-ss.png
CHANGED
|
|
Git LFS Details
|
textgames/__init__.py
CHANGED
|
@@ -14,8 +14,10 @@ from pandas import read_csv
|
|
| 14 |
import json
|
| 15 |
|
| 16 |
|
| 17 |
-
# [
|
| 18 |
-
# "
|
|
|
|
|
|
|
| 19 |
THE_GAMES = {
|
| 20 |
k: v.get_game_name() for k, v in [
|
| 21 |
("1", CrosswordArrangerGame),
|
|
@@ -60,12 +62,13 @@ def _game_class_from_name(game_name):
|
|
| 60 |
return None
|
| 61 |
|
| 62 |
|
| 63 |
-
def preload_game(game_name, level_id, user):
|
| 64 |
game_cls = _game_class_from_name(game_name)
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
|
|
|
| 69 |
print(f"preload_game('{game_name}', '{level_id}', '{user['email']}') on {sid}")
|
| 70 |
|
| 71 |
with open(f"problemsets/{game_filename(game_name)}_{level_id}.json", "r", encoding="utf8") as f:
|
|
|
|
| 14 |
import json
|
| 15 |
|
| 16 |
|
| 17 |
+
# [
|
| 18 |
+
# "📰\tCrossword Arranger", "🧩\tText Sudoku", "🏝️\tIslands", "🔑\tPassword Game",
|
| 19 |
+
# "📈\tOrdering Text", "🔤\tAnagram Scribble", "🗳️\tBracket Game", "🔎\tString Search",
|
| 20 |
+
# ]
|
| 21 |
THE_GAMES = {
|
| 22 |
k: v.get_game_name() for k, v in [
|
| 23 |
("1", CrosswordArrangerGame),
|
|
|
|
| 62 |
return None
|
| 63 |
|
| 64 |
|
| 65 |
+
def preload_game(game_name, level_id, user, sid=None):
|
| 66 |
game_cls = _game_class_from_name(game_name)
|
| 67 |
+
if not sid:
|
| 68 |
+
email_sid_dict = read_csv(
|
| 69 |
+
f"{os.getenv('TEXTGAMES_OUTPUT_DIR')}/textgames_userauth.tsv", sep='\t'
|
| 70 |
+
).dropna().set_index("EMAIL").SID.to_dict()
|
| 71 |
+
sid = email_sid_dict.get(user["email"])
|
| 72 |
print(f"preload_game('{game_name}', '{level_id}', '{user['email']}') on {sid}")
|
| 73 |
|
| 74 |
with open(f"problemsets/{game_filename(game_name)}_{level_id}.json", "r", encoding="utf8") as f:
|
textgames/anagram_scribble/anagram_scribble.py
CHANGED
|
@@ -5,6 +5,7 @@ import json
|
|
| 5 |
import string
|
| 6 |
import re
|
| 7 |
|
|
|
|
| 8 |
class AnagramScribble(BaseGame):
|
| 9 |
@staticmethod
|
| 10 |
def get_game_name() -> str:
|
|
@@ -43,6 +44,18 @@ class AnagramScribble(BaseGame):
|
|
| 43 |
if total_chars_extraction != "Error loading game state.":
|
| 44 |
characters = total_chars_extraction.split(",")
|
| 45 |
self.total_chars = [char.strip().strip("'") for char in characters]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
|
| 47 |
def _generate_new_game(self, *args, **kwargs) -> None:
|
| 48 |
self.low_num_chars = kwargs['low_num_chars']
|
|
@@ -57,16 +70,16 @@ class AnagramScribble(BaseGame):
|
|
| 57 |
|
| 58 |
def _get_prompt(self) -> str:
|
| 59 |
if self.allow_repeat:
|
| 60 |
-
prompt = f"Construct a valid {self.num_chars}-character English word from the following letters:\n{self.total_chars}.\nEach character can be used multiple times. Please write None if there is no valid combination."
|
| 61 |
else:
|
| 62 |
-
prompt = f"Construct a valid {self.num_chars}-character English word from the following letters:\n{self.total_chars}.\nEach character can only be used once. Please write None if there is no valid combination."
|
| 63 |
return prompt
|
| 64 |
|
| 65 |
def _validate(self, answer: str) -> (bool, str):
|
| 66 |
-
answer
|
| 67 |
-
if self.possible_ans != "" and answer == "none":
|
| 68 |
val_msg = "There is a valid answer."
|
| 69 |
return False, val_msg
|
|
|
|
| 70 |
if len(answer) != self.num_chars:
|
| 71 |
val_msg = f"Your answer must be exactly {self.num_chars} characters long"
|
| 72 |
return False, val_msg
|
|
@@ -74,12 +87,31 @@ class AnagramScribble(BaseGame):
|
|
| 74 |
if char not in self.total_chars:
|
| 75 |
val_msg = "Your answer must only contain the characters provided"
|
| 76 |
return False, val_msg
|
| 77 |
-
if (not self.allow_repeat and (len(set(answer)) != len(answer))
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
if answer not in self.WORD_LIST_BIN[str(self.num_chars)]:
|
| 82 |
val_msg = "Your answer is not a valid English word"
|
| 83 |
return False, val_msg
|
| 84 |
|
| 85 |
return True, ""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
import string
|
| 6 |
import re
|
| 7 |
|
| 8 |
+
|
| 9 |
class AnagramScribble(BaseGame):
|
| 10 |
@staticmethod
|
| 11 |
def get_game_name() -> str:
|
|
|
|
| 44 |
if total_chars_extraction != "Error loading game state.":
|
| 45 |
characters = total_chars_extraction.split(",")
|
| 46 |
self.total_chars = [char.strip().strip("'") for char in characters]
|
| 47 |
+
self.possible_ans = ""
|
| 48 |
+
_chars = sorted(self.total_chars)
|
| 49 |
+
for w in self.WORD_LIST_BIN[str(self.num_chars)]:
|
| 50 |
+
_ans = sorted(w)
|
| 51 |
+
j, k = 0, 0
|
| 52 |
+
while j < len(_ans) and k < len(_chars):
|
| 53 |
+
if _ans[j] == _chars[k]:
|
| 54 |
+
j += 1
|
| 55 |
+
k += 1
|
| 56 |
+
if j >= len(_ans):
|
| 57 |
+
self.possible_ans = w
|
| 58 |
+
break
|
| 59 |
|
| 60 |
def _generate_new_game(self, *args, **kwargs) -> None:
|
| 61 |
self.low_num_chars = kwargs['low_num_chars']
|
|
|
|
| 70 |
|
| 71 |
def _get_prompt(self) -> str:
|
| 72 |
if self.allow_repeat:
|
| 73 |
+
prompt = f"Construct a valid {self.num_chars}-character English word from the following letters:\n{self.total_chars}.\nEach character can be used multiple times. Please write None if there is no valid combination. Print only the answer.\n"
|
| 74 |
else:
|
| 75 |
+
prompt = f"Construct a valid {self.num_chars}-character English word from the following letters:\n{self.total_chars}.\nEach character can only be used once. Please write None if there is no valid combination. Print only the answer.\n"
|
| 76 |
return prompt
|
| 77 |
|
| 78 |
def _validate(self, answer: str) -> (bool, str):
|
| 79 |
+
if self.possible_ans != "" and answer == "None":
|
|
|
|
| 80 |
val_msg = "There is a valid answer."
|
| 81 |
return False, val_msg
|
| 82 |
+
answer = answer.lower()
|
| 83 |
if len(answer) != self.num_chars:
|
| 84 |
val_msg = f"Your answer must be exactly {self.num_chars} characters long"
|
| 85 |
return False, val_msg
|
|
|
|
| 87 |
if char not in self.total_chars:
|
| 88 |
val_msg = "Your answer must only contain the characters provided"
|
| 89 |
return False, val_msg
|
| 90 |
+
# if (not self.allow_repeat and (len(set(answer)) != len(answer))
|
| 91 |
+
# and (len(self.possible_ans) == len(set(self.possible_ans)))):
|
| 92 |
+
if not self.allow_repeat:
|
| 93 |
+
_ans = sorted(answer)
|
| 94 |
+
_chars = sorted(self.total_chars)
|
| 95 |
+
j, k = 0, 0
|
| 96 |
+
while j < len(_ans) and k < len(_chars):
|
| 97 |
+
if _ans[j] == _chars[k]:
|
| 98 |
+
j += 1
|
| 99 |
+
k += 1
|
| 100 |
+
if j < len(_ans):
|
| 101 |
+
val_msg = "Your answer must not contain repeated characters"
|
| 102 |
+
return False, val_msg
|
| 103 |
if answer not in self.WORD_LIST_BIN[str(self.num_chars)]:
|
| 104 |
val_msg = "Your answer is not a valid English word"
|
| 105 |
return False, val_msg
|
| 106 |
|
| 107 |
return True, ""
|
| 108 |
+
|
| 109 |
+
@staticmethod
|
| 110 |
+
def example() -> (str, str):
|
| 111 |
+
prompt = ("Construct a valid 5-character English word from the following letters:\n"
|
| 112 |
+
"['e', 'l', 'o', 'b', 's', 'p'].\n"
|
| 113 |
+
"Each character can be used multiple times. Please write None if there is no valid combination."
|
| 114 |
+
" Print only the answer.\n")
|
| 115 |
+
answer = "sleep"
|
| 116 |
+
return prompt, answer
|
| 117 |
+
|
textgames/bracket_game/bracket_game.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
import random
|
| 2 |
import re
|
|
|
|
| 3 |
from pathlib import Path
|
| 4 |
from textgames.base_game import BaseGame
|
| 5 |
#%%
|
|
@@ -57,48 +58,94 @@ class BracketGame(BaseGame):
|
|
| 57 |
self.MULTI_WORD_LIST.append(self.WORD_LIST[num1] + self.WORD_LIST[num2])
|
| 58 |
|
| 59 |
def _validate(self, answer: str) -> (bool, str):
|
| 60 |
-
|
| 61 |
-
arr = answer.split(rule[0])
|
| 62 |
-
|
| 63 |
-
if rule[1][1] not in arr[0] or rule[1][2] not in arr[1]:
|
| 64 |
-
val_msg = f"{rule[0]} is not between the correct bracket, {rule[1][1]} not in {arr[0]} and {rule[1][2]} not in {arr[1]}"
|
| 65 |
-
return False, val_msg
|
| 66 |
-
|
| 67 |
-
filter_answer = answer
|
| 68 |
-
for i in range(0, 26):
|
| 69 |
-
cc = chr(ord("a") + i)
|
| 70 |
-
filter_answer = filter_answer.replace(cc,"")
|
| 71 |
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
}
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
|
|
|
|
|
|
|
|
|
| 93 |
val_msg = "There is a closing bracket without an open bracket"
|
| 94 |
return False, val_msg
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
val_msg = f"The depth of the bracket is {
|
| 100 |
return False, val_msg
|
| 101 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
def _generate_new_game(self, *args, **kwargs) -> None:
|
| 103 |
num_words = kwargs["num_words"]
|
| 104 |
num_rules = kwargs["num_rules"]
|
|
@@ -141,6 +188,7 @@ class BracketGame(BaseGame):
|
|
| 141 |
prompt = f"You are given a text {self.string} Your job is to put some valid parenthesis brackets in the text such that:\n"
|
| 142 |
for rule in self.rules:
|
| 143 |
prompt += f"- \"{rule[0]}\" is inside a {rule[1][0]} bracket\n"
|
|
|
|
| 144 |
prompt += f"The bracket depth must be {self.depth} and print only the answer\n"
|
| 145 |
return prompt
|
| 146 |
|
|
@@ -159,7 +207,7 @@ class BracketGame(BaseGame):
|
|
| 159 |
else:
|
| 160 |
return 0
|
| 161 |
|
| 162 |
-
content = state_string.split("the text such that:")[1].split("\nThe
|
| 163 |
|
| 164 |
self.words = []
|
| 165 |
self.rules = []
|
|
@@ -188,3 +236,14 @@ class BracketGame(BaseGame):
|
|
| 188 |
self.create_multiple_words()
|
| 189 |
|
| 190 |
sort_game_states(self)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import random
|
| 2 |
import re
|
| 3 |
+
from bisect import bisect_left
|
| 4 |
from pathlib import Path
|
| 5 |
from textgames.base_game import BaseGame
|
| 6 |
#%%
|
|
|
|
| 58 |
self.MULTI_WORD_LIST.append(self.WORD_LIST[num1] + self.WORD_LIST[num2])
|
| 59 |
|
| 60 |
def _validate(self, answer: str) -> (bool, str):
|
| 61 |
+
answer = "".join(answer.split()).lower()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
+
if ("".join(filter(lambda a: a.isalpha(), answer)) !=
|
| 64 |
+
"".join(filter(lambda a: a.isalpha(), self.string.lower()))):
|
| 65 |
+
val_msg = f"You are not allowed to change the character sequence of base text '{self.string}'."
|
| 66 |
+
return False, val_msg
|
| 67 |
+
|
| 68 |
+
char2type_op = {b[1]: b[0] for b in self.BRACKETS}
|
| 69 |
+
char2type_ed = {b[2]: b[0] for b in self.BRACKETS}
|
| 70 |
+
|
| 71 |
+
depth_count = {b[0]: [(-1, 0)] for b in self.BRACKETS}
|
| 72 |
+
|
| 73 |
+
def push(dc, v):
|
| 74 |
+
cur_depth = dc[-1][-1]
|
| 75 |
+
if cur_depth < 0:
|
| 76 |
+
return False
|
| 77 |
+
dc.append((i, cur_depth + v))
|
| 78 |
+
return True
|
| 79 |
+
|
| 80 |
+
mak, cur_mak = 0, 0
|
| 81 |
+
for i, c in enumerate(answer):
|
| 82 |
+
if c in char2type_op:
|
| 83 |
+
push(depth_count[char2type_op[c]], 1)
|
| 84 |
+
cur_mak += 1
|
| 85 |
+
elif c in char2type_ed:
|
| 86 |
+
if not push(depth_count[char2type_ed[c]], -1):
|
| 87 |
val_msg = "There is a closing bracket without an open bracket"
|
| 88 |
return False, val_msg
|
| 89 |
+
cur_mak -= 1
|
| 90 |
+
mak = max(mak, cur_mak)
|
| 91 |
+
|
| 92 |
+
if mak != self.depth:
|
| 93 |
+
val_msg = f"The depth of the bracket is {mak}. The expected depth is {self.depth}"
|
| 94 |
return False, val_msg
|
| 95 |
|
| 96 |
+
for rule in self.rules:
|
| 97 |
+
i = answer.find(rule[0])
|
| 98 |
+
if i < 0:
|
| 99 |
+
val_msg = f"The text '{rule[0]}' is not found in your answer."
|
| 100 |
+
return False, val_msg
|
| 101 |
+
|
| 102 |
+
i_depth = bisect_left(depth_count[rule[1][0]], (i, -1)) - 1
|
| 103 |
+
if depth_count[rule[1][0]][i_depth][-1] < 1:
|
| 104 |
+
val_msg = f"The text '{rule[0]}' is not inside any {rule[1][0]} bracket {rule[1][1]} {rule[1][2]}"
|
| 105 |
+
return False, val_msg
|
| 106 |
+
|
| 107 |
+
# arr = answer.split(rule[0])
|
| 108 |
+
# if rule[1][1] not in arr[0] or rule[1][2] not in arr[1]:
|
| 109 |
+
# val_msg = f"The text '{rule[0]}' is not between the correct bracket, {rule[1][1]} not in {arr[0]} and {rule[1][2]} not in {arr[1]}"
|
| 110 |
+
# return False, val_msg
|
| 111 |
+
|
| 112 |
+
return True, ""
|
| 113 |
+
|
| 114 |
+
# filter_answer = answer
|
| 115 |
+
# for i in range(0, 26):
|
| 116 |
+
# cc = chr(ord("a") + i)
|
| 117 |
+
# filter_answer = filter_answer.replace(cc,"")
|
| 118 |
+
#
|
| 119 |
+
# cc = chr(ord("A") + i)
|
| 120 |
+
# filter_answer = filter_answer.replace(cc,"")
|
| 121 |
+
#
|
| 122 |
+
# open_bracket_list = ["[", "{", "(", "<"]
|
| 123 |
+
# close_bracket_map = {
|
| 124 |
+
# "[":"]", "{":"}", "(":")", "<":">"
|
| 125 |
+
# }
|
| 126 |
+
#
|
| 127 |
+
# # check max depth
|
| 128 |
+
# count = 0
|
| 129 |
+
# st = []
|
| 130 |
+
#
|
| 131 |
+
# for i in range(len(filter_answer)):
|
| 132 |
+
# if (filter_answer[i] in open_bracket_list):
|
| 133 |
+
# st.append(filter_answer[i]) # pushing the bracket in the stack
|
| 134 |
+
# else:
|
| 135 |
+
# if len(st) > 0 and (filter_answer[i] == close_bracket_map[st[-1]]):
|
| 136 |
+
# if (count < len(st)):
|
| 137 |
+
# count = len(st)
|
| 138 |
+
# st.pop()
|
| 139 |
+
# else:
|
| 140 |
+
# val_msg = "There is a closing bracket without an open bracket"
|
| 141 |
+
# return False, val_msg
|
| 142 |
+
#
|
| 143 |
+
# if count == self.depth:
|
| 144 |
+
# return True, ""
|
| 145 |
+
# else:
|
| 146 |
+
# val_msg = f"The depth of the bracket is {count}. The expected depth is {self.depth}"
|
| 147 |
+
# return False, val_msg
|
| 148 |
+
|
| 149 |
def _generate_new_game(self, *args, **kwargs) -> None:
|
| 150 |
num_words = kwargs["num_words"]
|
| 151 |
num_rules = kwargs["num_rules"]
|
|
|
|
| 188 |
prompt = f"You are given a text {self.string} Your job is to put some valid parenthesis brackets in the text such that:\n"
|
| 189 |
for rule in self.rules:
|
| 190 |
prompt += f"- \"{rule[0]}\" is inside a {rule[1][0]} bracket\n"
|
| 191 |
+
prompt += "The open and close parenthesis for block is [ ], curly is { }, round is ( ), and angle is < >\n"
|
| 192 |
prompt += f"The bracket depth must be {self.depth} and print only the answer\n"
|
| 193 |
return prompt
|
| 194 |
|
|
|
|
| 207 |
else:
|
| 208 |
return 0
|
| 209 |
|
| 210 |
+
content = state_string.split("the text such that:")[1].split("\nThe open and close parenthesis ")[0].split("\n")
|
| 211 |
|
| 212 |
self.words = []
|
| 213 |
self.rules = []
|
|
|
|
| 236 |
self.create_multiple_words()
|
| 237 |
|
| 238 |
sort_game_states(self)
|
| 239 |
+
|
| 240 |
+
@staticmethod
|
| 241 |
+
def example() -> (str, str):
|
| 242 |
+
prompt = ("You are given a text fabuloustextgames Your job is to put some valid parenthesis brackets in the text such that:\n"
|
| 243 |
+
"- \"games\" is inside a round bracket\n"
|
| 244 |
+
"- \"text\" is inside a angle bracket\n"
|
| 245 |
+
"- \"fabulous\" is inside a block bracket\n"
|
| 246 |
+
"The open and close parenthesis for block is [ ], curly is { }, round is ( ), and angle is < >\n"
|
| 247 |
+
"The bracket depth must be 2 and print only the answer\n")
|
| 248 |
+
answer = "[[fabulous]<text>(games)]"
|
| 249 |
+
return prompt, answer
|
textgames/crossword_arranger/crossword_arranger.py
CHANGED
|
@@ -125,19 +125,47 @@ class CrosswordArrangerGame(BaseGame):
|
|
| 125 |
return prompt
|
| 126 |
|
| 127 |
def _validate(self, answer: str) -> (bool, str):
|
| 128 |
-
|
|
|
|
|
|
|
| 129 |
val_msg = ""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
if len(ans_hor) != self.board_size:
|
| 131 |
val_msg = f"Mismatch answer length found!! Expected size of {self.board_size}, got {len(ans_hor)}."
|
| 132 |
return False, val_msg
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
ans_ver = [''.join(ans_hor[r][c] for r in range(self.board_size)) for c in range(self.board_size)]
|
| 134 |
word_set = set(self.word_list)
|
| 135 |
-
for w in chain(ans_hor, ans_ver):
|
| 136 |
if w not in word_set:
|
|
|
|
|
|
|
| 137 |
return False, val_msg
|
| 138 |
word_set.remove(w)
|
| 139 |
return True, val_msg
|
| 140 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
|
| 142 |
#%%
|
| 143 |
|
|
|
|
| 125 |
return prompt
|
| 126 |
|
| 127 |
def _validate(self, answer: str) -> (bool, str):
|
| 128 |
+
answer = answer if answer else ""
|
| 129 |
+
# ans_hor = list(filter(None, answer.lower().replace(' ', '\n').split("\n")))
|
| 130 |
+
ans_hor = answer.lower().split()
|
| 131 |
val_msg = ""
|
| 132 |
+
if len(ans_hor) != self.board_size:
|
| 133 |
+
arr = answer.lower().split()
|
| 134 |
+
if all(len(l) == 1 for l in arr) and (len(arr) == self.board_size * self.board_size):
|
| 135 |
+
ans_hor = ["".join(arr[i:i+self.board_size]) for i in range(0, len(arr), self.board_size)]
|
| 136 |
if len(ans_hor) != self.board_size:
|
| 137 |
val_msg = f"Mismatch answer length found!! Expected size of {self.board_size}, got {len(ans_hor)}."
|
| 138 |
return False, val_msg
|
| 139 |
+
for w in ans_hor:
|
| 140 |
+
if len(w) != self.board_size:
|
| 141 |
+
val_msg = f"Mismatch answer length found!! Expected size of {self.board_size}, got {len(w)}."
|
| 142 |
+
return False, val_msg
|
| 143 |
ans_ver = [''.join(ans_hor[r][c] for r in range(self.board_size)) for c in range(self.board_size)]
|
| 144 |
word_set = set(self.word_list)
|
| 145 |
+
for i, w in enumerate(chain(ans_hor, ans_ver)):
|
| 146 |
if w not in word_set:
|
| 147 |
+
val_msg = (f"Mismatch answer word found!! {'Horizontal' if i < self.board_size else 'Vertical'} word"
|
| 148 |
+
f" '{w}' is not in the word set.")
|
| 149 |
return False, val_msg
|
| 150 |
word_set.remove(w)
|
| 151 |
return True, val_msg
|
| 152 |
|
| 153 |
+
@staticmethod
|
| 154 |
+
def example() -> (str, str):
|
| 155 |
+
prompt = (f"Given a board size of 3x3, arrange a possible crossword puzzle answer from a list of words.\n"
|
| 156 |
+
f"Item in the list can only be used once.\n\n"
|
| 157 |
+
f"List of words:\n"
|
| 158 |
+
f"- app\n"
|
| 159 |
+
f"- all\n"
|
| 160 |
+
f"- and\n"
|
| 161 |
+
f"- lee\n"
|
| 162 |
+
f"- let\n"
|
| 163 |
+
f"- pat\n"
|
| 164 |
+
f"- pee\n"
|
| 165 |
+
f"- pet\n\n"
|
| 166 |
+
f"Print only the answer.")
|
| 167 |
+
answer = "app\nlee\nlet"
|
| 168 |
+
return prompt, answer
|
| 169 |
|
| 170 |
#%%
|
| 171 |
|
textgames/islands/islands.py
CHANGED
|
@@ -99,8 +99,8 @@ class Islands(BaseGame):
|
|
| 99 |
answer = [a.replace(" ", "").lower().strip() for a in answer]
|
| 100 |
|
| 101 |
# check the size
|
| 102 |
-
if len(answer) != self.N or len(
|
| 103 |
-
val_msg = f"2D grid is not {self.N} x {self.N}. ({len(answer)} x {len(answer
|
| 104 |
return False, val_msg
|
| 105 |
|
| 106 |
# check the tiles, ensure they are valid
|
|
@@ -194,4 +194,16 @@ Your 2D grid must follow the following rules:
|
|
| 194 |
|
| 195 |
Print only the answer.
|
| 196 |
"""
|
| 197 |
-
return prompt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
answer = [a.replace(" ", "").lower().strip() for a in answer]
|
| 100 |
|
| 101 |
# check the size
|
| 102 |
+
if len(answer) != self.N or any((len(a) < self.N) for a in answer):
|
| 103 |
+
val_msg = f"2D grid is not {self.N} x {self.N}. ({len(answer)} x {set(len(a) for a in answer)})"
|
| 104 |
return False, val_msg
|
| 105 |
|
| 106 |
# check the tiles, ensure they are valid
|
|
|
|
| 194 |
|
| 195 |
Print only the answer.
|
| 196 |
"""
|
| 197 |
+
return prompt
|
| 198 |
+
|
| 199 |
+
@staticmethod
|
| 200 |
+
def example() -> (str, str):
|
| 201 |
+
prompt = ("You are asked to construct a 2D 5 x 5 grid, consisting of water tiles (denoted by \u2019.\u2019), \n"
|
| 202 |
+
"land tiles (denoted by \u2019#\u2019). \n\n"
|
| 203 |
+
"A group of connected land tiles in 4 cardinal directions forms an island.\n\n"
|
| 204 |
+
"Your 2D grid must follow the following rules:\n"
|
| 205 |
+
"- There must be exactly 1 islands.\n"
|
| 206 |
+
"- The size of each island must be from 1 to 2 tiles.\n\n"
|
| 207 |
+
"Print only the answer.\n")
|
| 208 |
+
answer = "...##\n.....\n.....\n.....\n....."
|
| 209 |
+
return prompt, answer
|
textgames/ordering_text/ordering_text.py
CHANGED
|
@@ -5,10 +5,10 @@ Rules Description
|
|
| 5 |
|
| 6 |
word length:
|
| 7 |
- example: word less than 5 characters gets 10 points
|
| 8 |
-
- possible operands: {
|
| 9 |
-
-
|
| 10 |
-
- possible combinations: {
|
| 11 |
-
- only 1
|
| 12 |
|
| 13 |
neighboring / consecutive chars
|
| 14 |
- example: every pair of consecutive consonant gets 5 points
|
|
@@ -66,6 +66,15 @@ from textgames.base_game import BaseGame
|
|
| 66 |
from textgames.assets.word_list import WORDS_LIST, WORDS_BY_LEN
|
| 67 |
|
| 68 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
#%%
|
| 70 |
class Scoring:
|
| 71 |
def __init__(self, point: int):
|
|
@@ -505,14 +514,15 @@ class OrderingTextGame(BaseGame):
|
|
| 505 |
return self.answer # sorted(self.words, key=lambda word: (self.get_point(word), word))
|
| 506 |
|
| 507 |
def _validate(self, answer: str) -> (bool, str):
|
| 508 |
-
answer = answer.lower().replace('
|
| 509 |
-
|
| 510 |
-
|
| 511 |
-
|
| 512 |
-
|
| 513 |
-
|
| 514 |
-
|
| 515 |
-
|
|
|
|
| 516 |
|
| 517 |
def _generate_new_game(self, *args, **kwargs) -> None:
|
| 518 |
if "preset_config" in kwargs:
|
|
@@ -588,6 +598,26 @@ class OrderingTextGame(BaseGame):
|
|
| 588 |
prompt += "\nPrint only the answer."
|
| 589 |
return prompt
|
| 590 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 591 |
|
| 592 |
#%%
|
| 593 |
|
|
|
|
| 5 |
|
| 6 |
word length:
|
| 7 |
- example: word less than 5 characters gets 10 points
|
| 8 |
+
- possible operands: {\\eq, \\lt, \\gt, \\ne}
|
| 9 |
+
- \\le and \\ge will be randomized for prompt generation
|
| 10 |
+
- possible combinations: {\\gt\\lt, \\gt\\lt\\ne}
|
| 11 |
+
- only 1 \\ne is considered
|
| 12 |
|
| 13 |
neighboring / consecutive chars
|
| 14 |
- example: every pair of consecutive consonant gets 5 points
|
|
|
|
| 66 |
from textgames.assets.word_list import WORDS_LIST, WORDS_BY_LEN
|
| 67 |
|
| 68 |
|
| 69 |
+
#%%
|
| 70 |
+
# Ordinal words keyed by 1-based position (1 through 18).
index_to_word = dict(enumerate(
    (
        "first", "second", "third", "fourth", "fifth",
        "sixth", "seventh", "eighth", "ninth", "tenth",
        "eleventh", "twelfth", "thirteenth", "fourteenth",
        "fifteenth", "sixteenth", "seventeenth", "eighteenth",
    ),
    start=1,
))
|
| 76 |
+
|
| 77 |
+
|
| 78 |
#%%
|
| 79 |
class Scoring:
|
| 80 |
def __init__(self, point: int):
|
|
|
|
| 514 |
return self.answer # sorted(self.words, key=lambda word: (self.get_point(word), word))
|
| 515 |
|
| 516 |
def _validate(self, answer: str) -> (bool, str):
|
| 517 |
+
answer = answer.lower().replace(',', ' ').split()
|
| 518 |
+
gold = self.get_answer()
|
| 519 |
+
if len(answer) < len(gold):
|
| 520 |
+
return False, f"Your answer is too short. There should be {len(gold)} items."
|
| 521 |
+
for i, (a, b) in enumerate(zip(answer, self.get_answer()), 1):
|
| 522 |
+
if a != b:
|
| 523 |
+
val_msg = f"'{a}' is not supposed to be the {index_to_word[i]} word in the order."
|
| 524 |
+
return False, val_msg
|
| 525 |
+
return True, ""
|
| 526 |
|
| 527 |
def _generate_new_game(self, *args, **kwargs) -> None:
|
| 528 |
if "preset_config" in kwargs:
|
|
|
|
| 598 |
prompt += "\nPrint only the answer."
|
| 599 |
return prompt
|
| 600 |
|
| 601 |
+
@staticmethod
def example() -> (str, str):
    """Return a fixed demonstration pair for the ordering puzzle.

    Returns:
        A ``(prompt, answer)`` tuple. The prompt lists one scoring rule and
        four words; the answer lists the same four words, one per line.
    """
    prompt_pieces = (
        "Given a set of rules to calculate point, sort the set of words in decreasing order.\n",
        "When there 2 or more words with same point, sort lexicographically.\n\n",
        "Rules:\n",
        "- add 10 points if there exists 'u' in the word\n\n",
        "Words:\n",
        "- hudi\n",
        "- genta\n",
        "- aji\n",
        "- ruochen\n\n",
        "Print only the answer.",
    )
    answer_pieces = (
        "hudi\n",
        "ruochen\n",
        "aji\n",
        "genta",
    )
    return "".join(prompt_pieces), "".join(answer_pieces)
|
| 620 |
+
|
| 621 |
|
| 622 |
#%%
|
| 623 |
|
textgames/password_game/password_game.py
CHANGED
|
@@ -274,3 +274,13 @@ class PasswordGame(BaseGame):
|
|
| 274 |
self.rules = [rule for rule in new_rules]
|
| 275 |
|
| 276 |
sort_game_states(self)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 274 |
self.rules = [rule for rule in new_rules]
|
| 275 |
|
| 276 |
sort_game_states(self)
|
| 277 |
+
|
| 278 |
+
@staticmethod
def example() -> (str, str):
    """Return a fixed demonstration pair for the password puzzle.

    Returns:
        A ``(prompt, answer)`` tuple: a prompt with two criteria (6 english
        characters, 0 uppercase characters) and the answer ``"hoodie"``.
    """
    prompt_pieces = (
        "Please write a text string without any space by following a set of given rules.",
        " Please write only the answer and follow the following criteria:\n",
        "- the text has 6 english character\n",
        "- the text has 0 uppercase characters\n",
    )
    sample_prompt = "".join(prompt_pieces)
    sample_answer = "hoodie"
    return sample_prompt, sample_answer
|
| 286 |
+
|
textgames/string_search/string_search.py
CHANGED
|
@@ -309,4 +309,16 @@ Find a substring of exactly {self.answer_len} characters long that:
|
|
| 309 |
{extra_constraints}
|
| 310 |
Print only the answer.
|
| 311 |
"""
|
| 312 |
-
return prompt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 309 |
{extra_constraints}
|
| 310 |
Print only the answer.
|
| 311 |
"""
|
| 312 |
+
return prompt
|
| 313 |
+
|
| 314 |
+
@staticmethod
def example() -> (str, str):
    """Return a fixed demonstration pair for the substring-search puzzle.

    Returns:
        A ``(prompt, answer)`` tuple. The prompt asks for a 3-character
        substring of ``hudigentaajiruochen`` that contains ``t`` and neither
        ``i`` nor ``a``; the answer is ``"ent"``.
    """
    prompt_pieces = (
        "You are given the following string:\n",
        "hudigentaajiruochen\n\n",
        "Find a substring of exactly 3 characters long that:\n",
        " - Contains t\n",
        " - Does not contain i and a\n\n",
        "Print only the answer.\n",
    )
    sample_prompt = "".join(prompt_pieces)
    sample_answer = "ent"
    return sample_prompt, sample_answer
|
| 324 |
+
|
textgames/sudoku/sudoku.py
CHANGED
|
@@ -9,6 +9,7 @@ Please solve the 9x9 sudoku with 1,2,3,4,5,6,7,8,9 as the values and fill _ with
|
|
| 9 |
Print only the answer.
|
| 10 |
"""
|
| 11 |
|
|
|
|
| 12 |
#%%
|
| 13 |
class Sudoku(BaseGame):
|
| 14 |
@staticmethod
|
|
@@ -28,34 +29,47 @@ class Sudoku(BaseGame):
|
|
| 28 |
for j in range(self.size):
|
| 29 |
num = mat[i][j]
|
| 30 |
if num == self.empty_character:
|
| 31 |
-
return False
|
| 32 |
|
| 33 |
subgrid_index = (i // self.srn) * self.srn + (j // self.srn)
|
| 34 |
|
| 35 |
-
if num in rows[i]
|
| 36 |
-
return False
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
rows[i].add(num)
|
| 39 |
cols[j].add(num)
|
| 40 |
subgrids[subgrid_index].add(num)
|
| 41 |
|
| 42 |
-
return True
|
| 43 |
|
| 44 |
def _validate(self, input) -> (bool, str):
|
| 45 |
mat = [[self.empty_character for i in range(self.size)] for j in range(self.size)]
|
| 46 |
|
|
|
|
| 47 |
arr = input.split()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
for i in range(len(arr)):
|
| 49 |
for j in range(len(arr[i])):
|
| 50 |
if arr[i][j] not in self.char_to_id:
|
| 51 |
-
val_msg = "
|
| 52 |
return False, val_msg
|
| 53 |
|
| 54 |
mat[i][j] = self.char_to_id[arr[i][j]]
|
| 55 |
if arr[i][j] != self.mat[i][j] and self.mat[i][j] != self.empty_character:
|
| 56 |
val_msg = "One or more characters are replaced"
|
| 57 |
return False, val_msg
|
| 58 |
-
|
|
|
|
| 59 |
|
| 60 |
def _generate_new_game(self, *args, **kwargs) -> None:
|
| 61 |
size=kwargs["size"]
|
|
@@ -228,4 +242,14 @@ class Sudoku(BaseGame):
|
|
| 228 |
self.char_to_id = {}
|
| 229 |
for c_id in range(len(self.characters)):
|
| 230 |
self.char_to_id[self.characters[c_id]] = c_id
|
| 231 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
Print only the answer.
|
| 10 |
"""
|
| 11 |
|
| 12 |
+
|
| 13 |
#%%
|
| 14 |
class Sudoku(BaseGame):
|
| 15 |
@staticmethod
|
|
|
|
| 29 |
for j in range(self.size):
|
| 30 |
num = mat[i][j]
|
| 31 |
if num == self.empty_character:
|
| 32 |
+
return False, "There are unfilled cells"
|
| 33 |
|
| 34 |
subgrid_index = (i // self.srn) * self.srn + (j // self.srn)
|
| 35 |
|
| 36 |
+
if num in rows[i]:
|
| 37 |
+
return False, f"Duplicated row value ({num}) for cell in row {i+1} column {j+1}."
|
| 38 |
+
elif num in cols[j]:
|
| 39 |
+
return False, f"Duplicated column value ({num}) for cell in row {i+1} column {j+1}."
|
| 40 |
+
elif num in subgrids[subgrid_index]:
|
| 41 |
+
return False, f"Duplicated subgrid value ({num}) for cell in row {i+1} column {j+1}."
|
| 42 |
+
|
| 43 |
rows[i].add(num)
|
| 44 |
cols[j].add(num)
|
| 45 |
subgrids[subgrid_index].add(num)
|
| 46 |
|
| 47 |
+
return True, ""
|
| 48 |
|
| 49 |
def _validate(self, input) -> (bool, str):
|
| 50 |
mat = [[self.empty_character for i in range(self.size)] for j in range(self.size)]
|
| 51 |
|
| 52 |
+
input = input if input else ""
|
| 53 |
arr = input.split()
|
| 54 |
+
if all(len(l) == 1 for l in arr) and (len(arr) == self.size * self.size):
|
| 55 |
+
arr = ["".join(arr[i:i+self.size]) for i in range(0, len(arr), self.size)]
|
| 56 |
+
if (len(arr) != self.size) or any(len(arr[i]) != self.size for i in range(len(arr))):
|
| 57 |
+
arr = input.split("\n")
|
| 58 |
+
val_msg = f"Your answer is wrong in shape, it should be {self.size}x{self.size} sudoku."
|
| 59 |
+
return False, val_msg
|
| 60 |
+
|
| 61 |
for i in range(len(arr)):
|
| 62 |
for j in range(len(arr[i])):
|
| 63 |
if arr[i][j] not in self.char_to_id:
|
| 64 |
+
val_msg = "There are unrecognized characters, or possibly unfilled cells."
|
| 65 |
return False, val_msg
|
| 66 |
|
| 67 |
mat[i][j] = self.char_to_id[arr[i][j]]
|
| 68 |
if arr[i][j] != self.mat[i][j] and self.mat[i][j] != self.empty_character:
|
| 69 |
val_msg = "One or more characters are replaced"
|
| 70 |
return False, val_msg
|
| 71 |
+
|
| 72 |
+
return self.is_valid_sudoku(mat)
|
| 73 |
|
| 74 |
def _generate_new_game(self, *args, **kwargs) -> None:
|
| 75 |
size=kwargs["size"]
|
|
|
|
| 242 |
self.char_to_id = {}
|
| 243 |
for c_id in range(len(self.characters)):
|
| 244 |
self.char_to_id[self.characters[c_id]] = c_id
|
| 245 |
+
|
| 246 |
+
@staticmethod
def example() -> (str, str):
    """Return a fixed demonstration pair for a 4x4 sudoku.

    Returns:
        A ``(prompt, answer)`` tuple. The prompt contains the puzzle
        ``A_CD CD_B _AD_ DCBA``; the answer is the filled grid, one row
        per line.
    """
    prompt_pieces = (
        "Please solve the 4x4 sudoku with A,B,C,D as the values and fill _ with the possible value and",
        " only print the answer. Follow the sudoku rule.\nA_CD CD_B _AD_ DCBA",
    )
    answer_pieces = (
        "ABCD\n",
        "CDAB\n",
        "BADC\n",
        "DCBA",
    )
    return "".join(prompt_pieces), "".join(answer_pieces)
|
| 255 |
+
|