# %%
import json
import pickle
import re
from pathlib import Path


# %%
def load_pickle(fp):
    """Yield every object stored back-to-back in the pickle file at *fp*.

    Objects are read with repeated ``pickle.load`` calls until the end of
    the file is reached; the resulting ``EOFError`` ends the iteration.

    Args:
        fp: Path of the pickle file, opened in binary read mode.

    Yields:
        Each unpickled object, in file order.
    """
    # NOTE: pickle.load can execute arbitrary code; only use this on files
    # produced by a trusted run.
    with open(fp, "rb") as f:
        try:
            while True:
                yield pickle.load(f)
        except EOFError:
            # End of file: no more pickled records.
            pass


# %%
fd = Path("model_outputs")

# %%

# %%

# %%
# # %%
# # concat pickle results (1/22)
# list(fd.glob("results_gemma_*"))[0]
#
# # %%
# fps = sorted(fd.glob("results_gemma_*"))
# all_responses = dict()
# errors = set()
# for fp in fps:
#     responses = list(load_pickle(str(fp)))
#     print(fp.name, len(responses), responses[0][0], responses[-1][0])
#     for r in responses:
#         if r[-1]:
#             errors.add((r[0], str(r[-1])))
#         all_responses.setdefault(r[:2], set())
#         all_responses[r[:2]].add(r)
# errors = sorted(errors)
#
# # %%
# assert all(len(v) == 1 for v in all_responses.values()), f"Duplicated response(s) found"
#
# # %%
# duplicated = {k: v for k, v in all_responses.items() if len(v) > 1}
#
# # %%
# concatenated = [list(v)[0] for v in all_responses.values()]
#
# # %%
# with open(fd / "gemma2_9b_results_depre_250122/results_gemma-2-9b-it.single_turn.jsonl", "w", encoding="utf8") as o:
#     for i in concatenated:
#         json.dump({
#             "game": i[0],
#             "session": i[1],
#             "turn": 1,
#             "response": i[2],
#             "solved": i[3][0],
#             "val_msg": i[3][1],
#             "error": repr(i[4]) if i[4] else i[4],
#         }, o, ensure_ascii=False)
#         o.write("\n")

# %%

# %%

# %%

# %%

# %%

# %%
# Rerun gemma, resolving errors

# %%
import os
import json
import pandas as pd

# %%
# NOTE(review): these assignments unconditionally overwrite any value already
# present in the environment, so the ``os.getenv`` defaults in the next cell
# never apply while this cell is executed.
os.environ["TG_GAME_ST"] = "7"
os.environ["TG_GAME_ED"] = "8"

# %%
# Start/end indices used to slice GAME_NAMES below; None means "no bound".
st, ed = os.getenv("TG_GAME_ST", None), os.getenv("TG_GAME_ED", None)
st, ed = ((None if x is None else int(x)) for x in (st, ed))
# The output file name carries a ".<st>" suffix when a start index is set.
fp_out = f"model_outputs/results_gemma-2-9b-it{'' if st is None else f'.{st}'}.jsonl"

# %%
from tqdm import tqdm
from itertools import product
from transformers import AutoTokenizer, AutoModelForCausalLM
from textgames import THE_GAMES, GAME_NAMES, LEVEL_IDS, game_filename, _game_class_from_name
os.environ.setdefault("TEXTGAMES_OUTPUT_DIR", "user_outputs")

# %%
# Previously recorded single-turn responses, one JSON object per line.
with open(fd / "gemma2_9b_results_depre_250122/results_gemma-2-9b-it.single_turn.jsonl", "r", encoding="utf-8") as f:
    df = pd.read_json(f, lines=True)

# %%
df.columns

# %%
from agents import run_with_agent
from agents.gemma_2_9b_it import gemma_postproc


# %%
import functools


@functools.lru_cache(maxsize=None)
def _buffered_lookup_tables(game_file):
    """Build (once per *game_file*) the tables used by get_buffered_response.

    Args:
        game_file: The ``"<game filename>_<difficulty level>"`` string that
            identifies both the rows of ``df`` (column ``game``) and the
            problem-set file ``problemsets/<game_file>.json``.

    Returns:
        A ``(responses, prompt_to_sid)`` tuple: the rows of ``df`` for this
        game indexed by ``(session, turn)``, and the problem-set JSON mapping
        inverted so that prompt text maps to its key.

    The result is memoised because it depends only on *game_file*; call
    ``_buffered_lookup_tables.cache_clear()`` after reloading ``df`` or
    editing a problem-set file in an interactive session.
    """
    responses = df.loc[(df.game == game_file)].set_index(["session", "turn"])
    with open(f"problemsets/{game_file}.json", "r", encoding="utf8") as f:
        _sid_prompt_dict = json.load(f)
    # If two keys share the same prompt text, the later key wins.
    prompt_to_sid = {v: k for k, v in _sid_prompt_dict.items()}
    return responses, prompt_to_sid


def get_buffered_response(texts, game_name, difficulty_level, turn):
    """Return the recorded first-turn response for a prompt, if there is one.

    Args:
        texts: Sequence whose first element is the prompt text; it is matched
            exactly against the values of the problem-set JSON.
        game_name: Game name, converted with ``game_filename``.
        difficulty_level: Difficulty suffix appended to the game file name.
        turn: Turn number; any value above 1 yields ``None`` without touching
            the recorded data.

    Returns:
        The ``response`` value of the row of ``df`` indexed by
        ``(session id, turn)``, or ``None`` when ``turn > 1`` or when no such
        row exists.

    Raises:
        KeyError: If ``texts[0]`` is not one of the problem-set prompts.
    """
    if turn > 1:
        return None
    cur_df, prompt_sid_dict = _buffered_lookup_tables(
        f"{game_filename(game_name)}_{difficulty_level}"
    )
    # Deliberately outside the try below: an unknown prompt is an error, only
    # a missing recorded response is treated as "nothing buffered".
    sid = prompt_sid_dict[texts[0]]
    try:
        return cur_df.loc[(sid, turn)].response
    except KeyError:
        return None


# %%
run_with_agent(fp_out, get_buffered_response, get_postprocess=gemma_postproc, game_names_list=GAME_NAMES[st:ed], n_turns=1)

# %%

# %%
# type(cur_df.loc[(sid, 1)].response)

# %%

# %%

# %%

# %%

# %%

# %%

# %%