|
from commonsenseConstraint import evaluation as commonsense_eval |
|
from hardConstraint import evaluation as hard_eval |
|
import json |
|
from tqdm import tqdm |
|
from datasets import load_dataset |
|
|
|
|
|
def load_line_json_data(filename):
    """Read a JSON-lines file and return its records as a list.

    The file is read line by line (UTF-8, text mode). Every non-blank line is
    decoded with ``json.loads``. Blank or whitespace-only lines are skipped, so
    a trailing newline or an entirely empty file is handled without error; an
    empty file yields ``[]``.

    Args:
        filename: Path of the JSON-lines file.

    Returns:
        List of decoded objects, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(filename, 'r', encoding='utf-8') as f:
        # Iterating the file streams it instead of loading it whole; skipping
        # blank lines avoids json.loads('') failures on empty/padded files.
        return [json.loads(line) for line in f if line.strip()]
|
|
|
def count_true_false(data):
    """Return a ``(true_count, false_count)`` pair for the elements of *data*.

    Tallies are taken with the sequence's own ``count`` method, which compares
    with ``==``; elements equal to ``True``/``False`` (such as ``1``/``0``) are
    therefore counted too.
    """
    return tuple(data.count(flag) for flag in (True, False))
|
|
|
def statistics(commonsense_statistic):
    """Aggregate per-check true/false tallies for every level and day.

    Args:
        commonsense_statistic: Nested mapping ``level -> day -> list of info
            boxes``. Each info box is either falsy (skipped) or a dict mapping
            a check name to a sequence passed to ``count_true_false``.

    Returns:
        Mapping ``level -> day -> check name -> {"true": int, "false": int}``.
        Every level/day present in the input appears in the output, even when
        none of its info boxes are truthy.
    """
    summary = {}
    for level, per_day in commonsense_statistic.items():
        summary[level] = {}
        for day, boxes in per_day.items():
            tallies = summary[level][day] = {}
            # filter(None, ...) drops the falsy boxes (e.g. undelivered plans).
            for box in filter(None, boxes):
                for check_name, outcome in box.items():
                    n_true, n_false = count_true_false(outcome)
                    slot = tallies.setdefault(check_name, {"true": 0, "false": 0})
                    slot["true"] += n_true
                    slot["false"] += n_false
    return summary
|
|
|
|
|
# Hard-coded denominators of the reported rates, per split:
# (number of queries, total commonsense checks, total hard-constraint checks).
_SPLIT_DENOMINATORS = {
    'validation': (180, 1440, 420),
    'test': (1000, 8000, 2290),
}

# Check names whose 'true' tallies are summed into the micro pass rates.
# NOTE(review): 'is_reasonalbe_visiting_city' is matched verbatim against the
# info-box keys; presumably commonsense_eval emits this exact spelling, so
# verify before "fixing" the typo.
_COMMONSENSE_MICRO_KEYS = (
    'is_valid_information_in_current_city',
    'is_valid_information_in_sandbox',
    'is_reasonalbe_visiting_city',
    'is_valid_restaurants',
    'is_valid_transportation',
    'is_valid_attractions',
    'is_valid_accommodation',
    'is_not_absent',
)
_HARD_MICRO_KEYS = (
    'valid_cost',
    'valid_room_rule',
    'valid_cuisine',
    'valid_room_type',
    'valid_transportation',
)

_LEVELS = ('easy', 'medium', 'hard')
_DAYS = (3, 5, 7)


def _sum_true_tallies(processed, keys):
    """Sum the 'true' tallies of *keys* over every level/day of a statistics() result."""
    return sum(
        per_day[key]['true']
        for per_level in processed.values()
        for per_day in per_level.values()
        for key in keys
        if key in per_day
    )


def eval_score(validation_or_test: str, file_path: str, TOKEN):
    """Score a file of generated travel plans against a TravelBenchEval split.

    Args:
        validation_or_test: Split to score against, ``'validation'`` or ``'test'``.
        file_path: JSON-lines file with one record per query, in the same order
            as the dataset split; each record (or its string form) must have a
            ``'plan'`` entry. Extra trailing records are ignored.
        TOKEN: Access token forwarded to ``load_dataset``.

    Returns:
        Dict with the keys 'Delivery Rate', 'Commonsense Constraint Micro Pass
        Rate', 'Commonsense Constraint Macro Pass Rate', 'Hard Constraint Micro
        Pass Rate', 'Hard Constraint Macro Pass Rate' and 'Final Pass Rate',
        each a float computed against the split's fixed denominators.

    Raises:
        ValueError: If the split name is unknown, or the plan file has fewer
            records than the split has queries.
    """
    if validation_or_test not in _SPLIT_DENOMINATORS:
        raise ValueError(
            f"validation_or_test must be 'validation' or 'test', got {validation_or_test!r}"
        )
    n_queries, n_commonsense_checks, n_hard_checks = _SPLIT_DENOMINATORS[validation_or_test]

    query_data_list = list(
        load_dataset('osunlp/TravelBenchEval', validation_or_test, token=TOKEN)[validation_or_test]
    )
    tested_plans = load_line_json_data(file_path)
    if len(tested_plans) < len(query_data_list):
        raise ValueError(
            f"{file_path} has {len(tested_plans)} plans but the {validation_or_test} "
            f"split has {len(query_data_list)} queries"
        )

    commonsense_by_group = {level: {day: [] for day in _DAYS} for level in _LEVELS}
    hard_by_group = {level: {day: [] for day in _DAYS} for level in _LEVELS}
    plan_constraint_store = []  # (commonsense_info_box, hard_info_box) per query
    delivery_cnt = 0

    for query_data, tested_plan in tqdm(zip(query_data_list, tested_plans), total=len(query_data_list)):
        # NOTE(security): eval() executes arbitrary Python taken from the plan
        # file / dataset. If these strings are always plain literals,
        # ast.literal_eval would be a safer drop-in - confirm before switching.
        if isinstance(query_data, str):
            query_data = eval(query_data)
        if isinstance(tested_plan, str):
            tested_plan = eval(tested_plan)
        if isinstance(query_data['local_constraint'], str):
            query_data['local_constraint'] = eval(query_data['local_constraint'])

        # A plan counts as delivered when its 'plan' entry is truthy; only
        # delivered plans are checked against the commonsense constraints.
        if tested_plan['plan']:
            delivery_cnt += 1
            commonsense_info_box = commonsense_eval(query_data, tested_plan['plan'])
        else:
            commonsense_info_box = None

        # Hard constraints are evaluated only when both the 'is_not_absent'
        # and 'is_valid_information_in_sandbox' verdicts are truthy.
        if (commonsense_info_box
                and commonsense_info_box['is_not_absent'][0]
                and commonsense_info_box['is_valid_information_in_sandbox'][0]):
            hard_info_box = hard_eval(query_data, tested_plan['plan'])
        else:
            hard_info_box = None

        plan_constraint_store.append((commonsense_info_box, hard_info_box))
        level, days = query_data['level'], query_data['days']
        commonsense_by_group[level][days].append(commonsense_info_box)
        hard_by_group[level][days].append(hard_info_box)

    # Micro rates: individual passed checks across all plans.
    commonsense_micro_pass = _sum_true_tallies(statistics(commonsense_by_group), _COMMONSENSE_MICRO_KEYS)
    hard_micro_pass = _sum_true_tallies(statistics(hard_by_group), _HARD_MICRO_KEYS)

    # Macro rates: plans whose every (non-None) verdict passes.
    final_commonsense_cnt = 0
    final_hardConstraint_cnt = 0
    final_all_cnt = 0
    for commonsense_box, hard_box in plan_constraint_store:
        if not commonsense_box:
            continue  # undelivered plan: fails every macro criterion
        commonsense_pass = not any(
            verdict[0] is not None and not verdict[0] for verdict in commonsense_box.values()
        )
        # Counted before the hard-constraint gate so the commonsense macro
        # rate does not depend on whether hard constraints were evaluated.
        if commonsense_pass:
            final_commonsense_cnt += 1
        if hard_box is None:
            continue
        # '== False' (not 'not ...') is kept from the original scoring rule.
        hard_pass = not any(
            verdict[0] is not None and verdict[0] == False  # noqa: E712
            for verdict in hard_box.values()
        )
        if hard_pass:
            final_hardConstraint_cnt += 1
        if commonsense_pass and hard_pass:
            final_all_cnt += 1

    return {
        'Delivery Rate': delivery_cnt / n_queries,
        'Commonsense Constraint Micro Pass Rate': commonsense_micro_pass / n_commonsense_checks,
        'Commonsense Constraint Macro Pass Rate': final_commonsense_cnt / n_queries,
        'Hard Constraint Micro Pass Rate': hard_micro_pass / n_hard_checks,
        'Hard Constraint Macro Pass Rate': final_hardConstraint_cnt / n_queries,
        'Final Pass Rate': final_all_cnt / n_queries,
    }
|
|
|
|