File size: 10,053 Bytes
3a3b852 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 |
from commonsenseConstraint import evaluation as commonsense_eval
from hardConstraint import evaluation as hard_eval
import json
from tqdm import tqdm
from datasets import load_dataset
def load_line_json_data(filename):
    """Load a JSON Lines file into a list of decoded records.

    Each non-blank line of *filename* (read as UTF-8) is decoded with
    ``json.loads``. Blank lines, including a trailing newline or an entirely
    empty file, are skipped instead of raising ``json.JSONDecodeError``.

    Args:
        filename: path to the ``.jsonl`` file.

    Returns:
        A list with one decoded object per non-blank line, in file order.
    """
    with open(filename, 'r', encoding='utf-8') as f:
        # Iterating the file splits on '\n' only (text-mode newline handling),
        # matching the original split('\n') while avoiding a full read.
        return [json.loads(line) for line in f if line.strip()]
def count_true_false(data):
    """Tally how many elements of *data* compare equal to True and to False.

    Works on any sequence that provides ``.count`` (lists and tuples alike).
    Equality is used, so ``1``/``0`` are tallied as ``True``/``False`` too.

    Returns:
        A ``(true_count, false_count)`` tuple.
    """
    trues, falses = (data.count(flag) for flag in (True, False))
    return trues, falses
def statistics(commonsense_statistic):
    """Aggregate pass/fail tallies per check name for every level and day.

    Args:
        commonsense_statistic: nested mapping ``level -> day -> list``. Each list
            item is either falsy (skipped) or a dict mapping a check name to a
            sequence of results.

    Returns:
        ``level -> day -> check name -> {"true": n, "false": m}``. Here n and m
        sum, over all dicts in that bucket, how many entries of the result
        sequence equal True and False respectively. Every level and day of the
        input appears in the output, even when it has no checks.
    """
    summary = {}
    for level, days in commonsense_statistic.items():
        level_summary = summary[level] = {}
        for day, info_boxes in days.items():
            day_summary = level_summary[day] = {}
            for info_box in info_boxes:
                # None / empty boxes (e.g. checks that were never run) are ignored.
                if not info_box:
                    continue
                for check, outcome in info_box.items():
                    tally = day_summary.setdefault(check, {"true": 0, "false": 0})
                    tally["true"] += outcome.count(True)
                    tally["false"] += outcome.count(False)
    return summary
# Fixed denominators for each dataset split: number of queries, number of
# commonsense checks (8 keys x queries) and number of hard-constraint checks.
_SPLIT_DENOMINATORS = {
    'validation': {'queries': 180, 'commonsense': 1440, 'hard': 420},
    'test': {'queries': 1000, 'commonsense': 8000, 'hard': 2290},
}

# Info-box keys whose "true" tallies are summed into the micro pass rates.
# NOTE(review): 'is_reasonalbe_visiting_city' is kept spelled exactly as before;
# presumably it matches the key emitted by commonsense_eval -- verify before
# correcting the typo.
_COMMONSENSE_KEYS = (
    'is_valid_information_in_current_city',
    'is_valid_information_in_sandbox',
    'is_reasonalbe_visiting_city',
    'is_valid_restaurants',
    'is_valid_transportation',
    'is_valid_attractions',
    'is_valid_accommodation',
    'is_not_absent',
)
_HARD_KEYS = ('valid_cost', 'valid_room_rule', 'valid_cuisine', 'valid_room_type', 'valid_transportation')

_LEVELS = ('easy', 'medium', 'hard')
_DAYS = (3, 5, 7)


def _maybe_eval(value):
    """Evaluate *value* as a Python expression if it is a string; otherwise return it unchanged."""
    # SECURITY(review): eval() executes arbitrary code, and tested plans are read
    # from a caller-supplied file. If these strings are always plain Python
    # literals, ast.literal_eval would be the safe replacement -- confirm first.
    return eval(value) if isinstance(value, str) else value


def _evaluate_plans(query_data_list, tested_plans):
    """Run the commonsense and hard-constraint evaluators on each (query, plan) pair.

    Pairs are formed by position; plans beyond len(query_data_list) are ignored.

    Returns:
        (delivery_cnt, plan_constraint_store, commonsense_statistic, hard_statistic):
        - delivery_cnt: number of plans whose 'plan' entry is truthy.
        - plan_constraint_store: one {'commonsense_constraint', 'hard_constraint'}
          record per query.
        - the two statistics: level -> day -> list of info boxes, with None where
          an evaluator was not run.
    """
    commonsense_statistic = {level: {day: [] for day in _DAYS} for level in _LEVELS}
    hard_statistic = {level: {day: [] for day in _DAYS} for level in _LEVELS}
    delivery_cnt = 0
    plan_constraint_store = []
    for query_data, tested_plan in tqdm(zip(query_data_list, tested_plans), total=len(query_data_list)):
        query_data = _maybe_eval(query_data)
        tested_plan = _maybe_eval(tested_plan)
        if isinstance(query_data['local_constraint'], str):
            query_data['local_constraint'] = _maybe_eval(query_data['local_constraint'])

        plan = tested_plan['plan']
        if plan:
            delivery_cnt += 1
            commonsense_info_box = commonsense_eval(query_data, plan)
        else:
            commonsense_info_box = None

        # Hard constraints are evaluated only when both the 'is_not_absent' and
        # 'is_valid_information_in_sandbox' commonsense results are truthy.
        if (commonsense_info_box
                and commonsense_info_box['is_not_absent'][0]
                and commonsense_info_box['is_valid_information_in_sandbox'][0]):
            hard_info_box = hard_eval(query_data, plan)
        else:
            hard_info_box = None

        plan_constraint_store.append({'commonsense_constraint': commonsense_info_box,
                                      'hard_constraint': hard_info_box})
        commonsense_statistic[query_data['level']][query_data['days']].append(commonsense_info_box)
        hard_statistic[query_data['level']][query_data['days']].append(hard_info_box)
    return delivery_cnt, plan_constraint_store, commonsense_statistic, hard_statistic


def _micro_pass_count(processed_statistic, keys):
    """Sum the 'true' tallies of *keys* over every level and day of a statistics() result."""
    return sum(
        per_day[key]['true']
        for per_level in processed_statistic.values()
        for per_day in per_level.values()
        for key in keys
        if key in per_day
    )


def _macro_pass_counts(plan_constraint_store):
    """Count plans passing all commonsense checks, all hard checks, and both.

    A commonsense result fails when its first element is not None and falsy.
    A hard result fails when its first element is not None and equals False.
    A first element of None is treated as "not applicable".

    Returns:
        (commonsense_cnt, hard_cnt, all_cnt)
    """
    commonsense_cnt = hard_cnt = all_cnt = 0
    for record in plan_constraint_store:
        commonsense_box = record['commonsense_constraint']
        hard_box = record['hard_constraint']
        # No commonsense box means no plan was delivered. No hard box means the
        # gating commonsense checks did not pass. Either way the plan counts
        # towards none of the three totals.
        if not commonsense_box or hard_box is None:
            continue
        commonsense_pass = not any(
            result[0] is not None and not result[0] for result in commonsense_box.values()
        )
        hard_pass = not any(
            result[0] is not None and result[0] == False  # noqa: E712 -- equality, as before
            for result in hard_box.values()
        )
        if commonsense_pass:
            commonsense_cnt += 1
        if hard_pass:
            hard_cnt += 1
        if commonsense_pass and hard_pass:
            all_cnt += 1
    return commonsense_cnt, hard_cnt, all_cnt


def eval_score(validation_or_test: str, file_path: str, TOKEN):
    """Score a file of generated travel plans against a TravelBenchEval split.

    Args:
        validation_or_test: dataset split to score against, 'validation' or 'test'.
        file_path: JSON Lines file with one record per query, in dataset order.
            Each record must have a 'plan' key; a falsy value counts as "not delivered".
        TOKEN: forwarded as ``token`` to ``datasets.load_dataset``.

    Returns:
        dict with 'Delivery Rate', the commonsense and hard-constraint micro and
        macro pass rates, and 'Final Pass Rate'. Each rate is divided by the
        split's fixed denominator.

    Raises:
        ValueError: if the split name is unknown, or *file_path* holds fewer plans
            than the split has queries.
    """
    if validation_or_test not in _SPLIT_DENOMINATORS:
        raise ValueError(
            f"validation_or_test must be 'validation' or 'test', got {validation_or_test!r}")
    query_data_list = list(
        load_dataset('osunlp/TravelBenchEval', validation_or_test, token=TOKEN)[validation_or_test])
    tested_plans = load_line_json_data(file_path)
    if len(tested_plans) < len(query_data_list):
        raise ValueError(
            f"{file_path} contains {len(tested_plans)} plans but the {validation_or_test} "
            f"split has {len(query_data_list)} queries")

    delivery_cnt, plan_constraint_store, commonsense_statistic, hard_statistic = \
        _evaluate_plans(query_data_list, tested_plans)

    commonsense_micro_pass = _micro_pass_count(statistics(commonsense_statistic), _COMMONSENSE_KEYS)
    hard_micro_pass = _micro_pass_count(statistics(hard_statistic), _HARD_KEYS)
    commonsense_macro_pass, hard_macro_pass, final_pass = _macro_pass_counts(plan_constraint_store)

    denominators = _SPLIT_DENOMINATORS[validation_or_test]
    num_queries = denominators['queries']
    return {
        'Delivery Rate': delivery_cnt / num_queries,
        'Commonsense Constraint Micro Pass Rate': commonsense_micro_pass / denominators['commonsense'],
        'Commonsense Constraint Macro Pass Rate': commonsense_macro_pass / num_queries,
        'Hard Constraint Micro Pass Rate': hard_micro_pass / denominators['hard'],
        'Hard Constraint Macro Pass Rate': hard_macro_pass / num_queries,
        'Final Pass Rate': final_pass / num_queries,
    }
|