|
import json |
|
import os |
|
import time |
|
|
|
import instructor |
|
import openai |
|
import pydantic_core |
|
import tqdm |
|
from exp_model_class import ExtendedModelType |
|
from openai import OpenAI |
|
from pydantic import BaseModel |
|
|
|
client = instructor.patch(OpenAI(api_key=os.getenv("OPENAI_API_KEY"))) |
|
|
|
game_list = ["lottery", "trustee"] |
|
|
|
|
|
class money_extract(BaseModel): |
|
name: str |
|
Belief: str |
|
Desire: str |
|
Intention: str |
|
give_money_number: float |
|
|
|
|
|
class option_extract(BaseModel): |
|
name: str |
|
option_trust_or_not_trust: str |
|
Belief: str |
|
Desire: str |
|
Intention: str |
|
|
|
|
|
def check_substring(main_string, string_list=["lottery", "trustee"]): |
|
for s in string_list: |
|
if s in main_string: |
|
return True |
|
return False |
|
|
|
|
|
def get_struct_output(input, whether_money=False, test=False): |
|
if test: |
|
return (1, {}) |
|
if whether_money: |
|
response_mod = money_extract |
|
else: |
|
response_mod = option_extract |
|
ori_path = openai.api_base |
|
openai.api_base = "https://api.openai.com/v1" |
|
resp = openai.ChatCompletion.create( |
|
model=ExtendedModelType.GPT_3_5_TURBO, |
|
response_model=response_mod, |
|
messages=[ |
|
{"role": "user", "content": input}, |
|
], |
|
) |
|
openai.api_base = ori_path |
|
|
|
if response_mod.__name__ == "money_extract": |
|
given_money = resp.give_money_number |
|
return ( |
|
given_money, |
|
dict(resp), |
|
) |
|
else: |
|
option_trust_or_not_trust = resp.option_trust_or_not_trust |
|
return ( |
|
option_trust_or_not_trust, |
|
dict(resp), |
|
) |
|
|
|
|
|
def extrat_json(folder_path): |
|
dirs_path = os.listdir(folder_path) |
|
for file in dirs_path: |
|
if ( |
|
file.endswith(".json") |
|
and "map" not in file |
|
and "extract" not in file |
|
and file[:-5] + "_extract.json" not in dirs_path |
|
): |
|
print(file) |
|
with open(os.path.join(folder_path, file), "r", encoding="utf-8") as f: |
|
data = json.load(f) |
|
res = data["dialog"] |
|
new_res = [] |
|
|
|
for items in tqdm.trange(len(res)): |
|
item = res[items][-1] |
|
try: |
|
if check_substring(file, game_list): |
|
extract_res, structure_output = get_struct_output(item) |
|
else: |
|
extract_res, structure_output = get_struct_output( |
|
item, whether_money=True |
|
) |
|
new_res.append(extract_res) |
|
except openai.error.APIError: |
|
print("openai.error.APIError") |
|
items -= 1 |
|
except ( |
|
openai.error.Timeout or pydantic_core._pydantic_core.ValidationError |
|
): |
|
print("Time out error") |
|
time.sleep(30) |
|
except json.decoder.JSONDecodeError: |
|
extract_res = data["res"][items] |
|
data["res"] = new_res |
|
with open( |
|
os.path.join(folder_path, file[:-5] + "_extract.json"), |
|
"w", |
|
encoding="utf-8", |
|
) as f: |
|
json.dump(data, f, indent=4) |
|
|