File size: 3,340 Bytes
354fa18 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
import json
import os
import time
import instructor
import openai
import pydantic_core
import tqdm
from exp_model_class import ExtendedModelType
from openai import OpenAI
from pydantic import BaseModel
# OpenAI client wrapped by `instructor.patch`; the API key is read from the
# OPENAI_API_KEY environment variable (`os.getenv` yields None when unset).
# NOTE(review): presumably the patch is what enables `response_model=`-style
# structured outputs -- confirm against the installed instructor version.
client = instructor.patch(OpenAI(api_key=os.getenv("OPENAI_API_KEY")))
# Substrings looked for in a file name (via `check_substring`) to choose the
# option-style extraction schema instead of the money-amount one.
game_list = ["lottery", "trustee"]
class money_extract(BaseModel):
    """Structured-output schema used when a money amount is extracted.

    ``get_struct_output`` selects this model when ``whether_money`` is true and
    passes it as the ``response_model`` of the chat-completion request.
    """

    # NOTE(review): the meaning of the text fields below is inferred from
    # their names only (Belief / Desire / Intention) -- confirm with the
    # prompts that produce the dialogs.
    name: str
    Belief: str
    Desire: str
    Intention: str
    # Returned by get_struct_output as the first element of its result tuple.
    give_money_number: float
class option_extract(BaseModel):
    """Structured-output schema used when a trust / not-trust choice is extracted.

    ``get_struct_output`` selects this model when ``whether_money`` is false and
    passes it as the ``response_model`` of the chat-completion request.
    """

    name: str
    # Returned by get_struct_output as the first element of its result tuple.
    option_trust_or_not_trust: str
    # NOTE(review): the meaning of the text fields below is inferred from
    # their names only (Belief / Desire / Intention) -- confirm with the
    # prompts that produce the dialogs.
    Belief: str
    Desire: str
    Intention: str
def check_substring(main_string, string_list=("lottery", "trustee")):
    """Return True if any entry of *string_list* occurs inside *main_string*.

    Args:
        main_string: Text to search (callers in this file pass a file name).
        string_list: Iterable of substrings to look for. The default is an
            immutable tuple so the default object can never be mutated and
            shared between calls.

    Returns:
        True as soon as one substring is found, otherwise False (including
        when *string_list* is empty).
    """
    return any(candidate in main_string for candidate in string_list)
def get_struct_output(input, whether_money=False, test=False):
    """Ask the chat model to turn free text into a structured record.

    Args:
        input: Text sent as the single user message of the request.
        whether_money: When true, the ``money_extract`` schema is requested and
            the extracted ``give_money_number`` is returned; otherwise the
            ``option_extract`` schema is requested and
            ``option_trust_or_not_trust`` is returned.
        test: When true, skip the API call entirely and return ``(1, {})``.

    Returns:
        A tuple ``(extracted_value, dict(response))``.

    Raises:
        Any exception raised by the chat-completion call is propagated; the
        global ``openai.api_base`` is restored before it propagates.
    """
    if test:
        return (1, {})

    response_mod = money_extract if whether_money else option_extract

    # Temporarily point the module-level OpenAI settings at the official
    # endpoint. The previous value is restored in ``finally`` so a failing
    # request cannot leave the global override in place for later callers.
    ori_path = openai.api_base
    openai.api_base = "https://api.openai.com/v1"
    try:
        # NOTE(review): `response_model` is presumably handled by the
        # `instructor` patch, and `openai.ChatCompletion` is the pre-1.0
        # openai interface -- confirm both against the installed versions.
        resp = openai.ChatCompletion.create(
            model=ExtendedModelType.GPT_3_5_TURBO,  # TODO change if you need
            response_model=response_mod,
            messages=[
                {"role": "user", "content": input},
            ],
        )
    finally:
        openai.api_base = ori_path

    if response_mod is money_extract:
        return (resp.give_money_number, dict(resp))
    return (resp.option_trust_or_not_trust, dict(resp))
def _previous_result(data, index):
    """Return the already-stored ``data["res"][index]``, or None if absent."""
    try:
        return data["res"][index]
    except (KeyError, IndexError, TypeError):
        return None


def extrat_json(folder_path, max_retries=5, retry_wait=30):
    """Build ``<name>_extract.json`` files for the dialog files in a folder.

    Every ``*.json`` file in *folder_path* whose name contains neither "map"
    nor "extract", and that has no ``<name>_extract.json`` sibling yet, is
    loaded. The last element of each entry of ``data["dialog"]`` is passed to
    ``get_struct_output``; the extracted values replace ``data["res"]`` and
    the result is written to ``<name>_extract.json``.

    The output ``res`` list always has exactly one entry per dialog, so the
    two lists stay aligned:

    * API errors are retried immediately; timeouts and validation errors are
      retried after sleeping *retry_wait* seconds.
    * On a JSON decode error, or when *max_retries* attempts are exhausted,
      the value previously stored in ``data["res"]`` at that position is kept
      (``None`` when there is no such value).

    Args:
        folder_path: Directory to scan (not recursive).
        max_retries: Maximum number of extraction attempts per dialog.
        retry_wait: Seconds to sleep after a timeout / validation error.

    Returns:
        None. Results are written to disk.
    """
    dir_entries = os.listdir(folder_path)
    # Set membership avoids rescanning the whole listing for every file.
    existing_names = set(dir_entries)

    for file in dir_entries:
        if not file.endswith(".json") or "map" in file or "extract" in file:
            continue
        output_name = file[:-5] + "_extract.json"
        if output_name in existing_names:
            continue  # already extracted on an earlier run

        print(file)
        with open(os.path.join(folder_path, file), "r", encoding="utf-8") as f:
            data = json.load(f)

        dialogs = data["dialog"]
        # Files named after one of the games in `game_list` carry a
        # trust/not-trust option; every other file carries a money amount.
        whether_money = not check_substring(file, game_list)

        new_res = []
        for index in tqdm.trange(len(dialogs)):
            last_message = dialogs[index][-1]
            for _attempt in range(max_retries):
                try:
                    extract_res, _structure_output = get_struct_output(
                        last_message, whether_money=whether_money
                    )
                    break
                # NOTE(review): `openai.error.*` is the pre-1.0 openai
                # exception namespace -- confirm against the installed version.
                except openai.error.APIError:
                    print("openai.error.APIError")
                except (openai.error.Timeout, pydantic_core.ValidationError):
                    # A tuple is required here: `A or B` evaluates to `A`
                    # alone and would never catch the validation error.
                    print("Time out error")
                    time.sleep(retry_wait)
                except json.decoder.JSONDecodeError:
                    extract_res = _previous_result(data, index)
                    break
            else:
                # All attempts failed: keep the earlier value instead of
                # dropping the entry and shifting every later result.
                extract_res = _previous_result(data, index)
            new_res.append(extract_res)

        data["res"] = new_res
        with open(
            os.path.join(folder_path, output_name),
            "w",
            encoding="utf-8",
        ) as f:
            json.dump(data, f, indent=4)
|