# textgames / textgames_check_model_outputs.py
# fhudi's picture
# Upload folder using huggingface_hub
# c9d7b4f verified
# %%
import json
import pickle
import re
from pathlib import Path
# %%
def load_pickle(fp):
    """Yield every object stored back-to-back in the pickle file at ``fp``.

    The file is read as a sequence of pickled objects written one after
    another on the same stream. Objects are produced lazily, in file order,
    and iteration stops once the stream is exhausted, so an empty file
    yields nothing. The file stays open while the generator is suspended
    and is closed when the generator finishes or is closed.

    Args:
        fp: Path of the pickle file; anything accepted by ``open``.

    Yields:
        Each unpickled object, in the order it appears in the file.

    NOTE: ``pickle.load`` can execute arbitrary code while loading; only
    use this on result files you produced yourself.
    """
    with open(fp, "rb") as f:
        while True:
            # Guard only the load itself so that nothing else is silenced.
            try:
                obj = pickle.load(f)
            except EOFError:
                # No more data left in the stream: normal end of iteration.
                return
            yield obj
# %%
# Directory (relative to the working directory) that holds the model result files.
fd: Path = Path("model_outputs")
# %%
# %%
# %%
# # %%
# # concat pickle results (1/22)
# list(fd.glob("results_gemma_*"))[0]
#
# # %%
# fps = sorted(fd.glob("results_gemma_*"))
# all_responses = dict()
# errors = set()
# for fp in fps:
# responses = list(load_pickle(str(fp)))
# print(fp.name, len(responses), responses[0][0], responses[-1][0])
# for r in responses:
# if r[-1]:
# errors.add((r[0], str(r[-1])))
# all_responses.setdefault(r[:2], set())
# all_responses[r[:2]].add(r)
# errors = sorted(errors)
#
# # %%
# assert all(len(v) == 1 for v in all_responses.values()), f"Duplicated response(s) found"
#
# # %%
# duplicated = {k: v for k, v in all_responses.items() if len(v) > 1}
#
# # %%
# concatenated = [list(v)[0] for v in all_responses.values()]
#
# # %%
# with open(fd / "gemma2_9b_results_depre_250122/results_gemma-2-9b-it.single_turn.jsonl", "w", encoding="utf8") as o:
# for i in concatenated:
# json.dump({
# "game": i[0],
# "session": i[1],
# "turn": 1,
# "response": i[2],
# "solved": i[3][0],
# "val_msg": i[3][1],
# "error": repr(i[4]) if i[4] else i[4],
# }, o, ensure_ascii=False)
# o.write("\n")
# %%
# %%
# %%
# %%
# %%
# %%
# Rerun gemma, resolving errors
# %%
import os
import json
import pandas as pd
# %%
# Pin the slice of GAME_NAMES handled by this run.
# NOTE: this overwrites any TG_GAME_ST / TG_GAME_ED already present in the environment.
os.environ.update({"TG_GAME_ST": "7", "TG_GAME_ED": "8"})
# %%
# Read the slice bounds back; a missing variable becomes None (open-ended slice).
st, ed = (os.environ.get(key) for key in ("TG_GAME_ST", "TG_GAME_ED"))
st, ed = [int(raw) if raw is not None else None for raw in (st, ed)]
# The output file name carries the start index only when one is set.
fp_out = f"model_outputs/results_gemma-2-9b-it{'' if st is None else f'.{st}'}.jsonl"
# %%
from tqdm import tqdm
from itertools import product
from transformers import AutoTokenizer, AutoModelForCausalLM
from textgames import THE_GAMES, GAME_NAMES, LEVEL_IDS, game_filename, _game_class_from_name
# Use "user_outputs" as the output directory unless the caller already chose one.
if "TEXTGAMES_OUTPUT_DIR" not in os.environ:
    os.environ["TEXTGAMES_OUTPUT_DIR"] = "user_outputs"
# %%
# Load the earlier single-turn Gemma results (JSON Lines: one record per line)
# into a DataFrame that get_buffered_response queries.
with open(fd / "gemma2_9b_results_depre_250122/results_gemma-2-9b-it.single_turn.jsonl", "r", encoding="utf-8") as f:
    df = pd.read_json(f, lines=True)
# %%
# Notebook cell: display the columns that were loaded.
df.columns
# %%
from agents import run_with_agent
from agents.gemma_2_9b_it import gemma_postproc
# %%
def get_buffered_response(texts, game_name, difficulty_level, turn):
    """Return the previously saved first-turn response for a prompt, if any.

    The prompt ``texts[0]`` is mapped back to its session id through
    ``problemsets/<game_filename(game_name)>_<difficulty_level>.json`` (the
    JSON object's keys are used as session ids, its values as prompts), and
    the matching ``response`` is then read from the module-level ``df``.

    Args:
        texts: Sequence whose first element is the prompt text to look up.
        game_name: Game name, converted with ``game_filename``.
        difficulty_level: Level id appended to the game file name.
        turn: Turn number; only turn 1 has buffered responses.

    Returns:
        The stored ``response`` value, or ``None`` when ``turn > 1`` or when
        ``df`` has no row for the ``(session, turn)`` pair.

    Raises:
        KeyError: If ``texts[0]`` is not one of the prompts in the problem set.
    """
    if turn > 1:
        return None

    # Same "<game>_<level>" key names both the problem-set file and the df rows.
    game_key = f"{game_filename(game_name)}_{difficulty_level}"
    cur_df = df.loc[df.game == game_key].set_index(["session", "turn"])

    with open(f"problemsets/{game_key}.json", "r", encoding="utf8") as f:
        _sid_prompt_dict = json.load(f)
    # Invert session-id -> prompt so the prompt text identifies the session.
    prompt_sid_dict = {prompt: sid for sid, prompt in _sid_prompt_dict.items()}
    sid = prompt_sid_dict[texts[0]]

    try:
        return cur_df.loc[(sid, turn)].response
    except KeyError:
        # No buffered row for this session/turn.
        return None
# %%
# Run the selected slice of games for a single turn, using the buffered
# responses as the agent and writing to fp_out.
run_with_agent(
    fp_out,
    get_buffered_response,
    get_postprocess=gemma_postproc,
    game_names_list=GAME_NAMES[st:ed],
    n_turns=1,
)
# %%
# %%
# type(cur_df.loc[(sid, 1)].response)
# %%
# %%
# %%
# %%
# %%
# %%
# %%