hsaest committed
Commit 342f407
1 Parent(s): 31e4cab

Delete evaluation

evaluation/.DS_Store DELETED
Binary file (6.15 kB)
 
evaluation/__pycache__/commonsenseConstraint.cpython-39.pyc DELETED
Binary file (14 kB)
 
evaluation/__pycache__/eval.cpython-39.pyc DELETED
Binary file (7.05 kB)
 
evaluation/__pycache__/hardConstraint.cpython-39.pyc DELETED
Binary file (8.13 kB)
 
evaluation/commonsenseConstraint.py DELETED
@@ -1,735 +0,0 @@
1
- from annotation.src.utils import get_valid_name_city,extract_before_parenthesis,extract_numbers_from_filenames
2
- from tools.flights.apis import Flights
3
- from tools.accommodations.apis import Accommodations
4
- from tools.restaurants.apis import Restaurants
5
- from tools.googleDistanceMatrix.apis import GoogleDistanceMatrix
6
- from tools.attractions.apis import Attractions
7
- import math
8
- import json
9
- import re
10
- import os
11
- import sys
12
- from tqdm import tqdm
13
- import argparse
14
-
15
- sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))
16
- os.chdir(os.path.dirname(os.path.abspath(__file__)))
17
-
18
- flight = Flights()
19
- accommodation = Accommodations()
20
- restaurants = Restaurants()
21
- googleDistanceMatrix = GoogleDistanceMatrix()
22
- attractions = Attractions()
23
-
24
- city_state_set = open('../database/background/citySet_with_states.txt','r').read().split('\n')
25
- city_state_map = {x:y for x,y in [unit.split('\t') for unit in city_state_set]}
26
-
27
-
28
- def load_line_json_data(filename):
29
- data = []
30
- with open(filename, 'r', encoding='utf-8') as f:
31
- for line in f.read().strip().split('\n'):
32
- unit = json.loads(line)
33
- data.append(unit)
34
- return data
35
-
36
-
37
- def count_consecutive_values(lst):
38
- if not lst:
39
- return []
40
-
41
- result = []
42
- current_string = lst[0]
43
- count = 1
44
-
45
- for i in range(1, len(lst)):
46
- if lst[i] == current_string:
47
- count += 1
48
- else:
49
- result.append((current_string, count))
50
- current_string = lst[i]
51
- count = 1
52
-
53
- result.append((current_string, count)) # Add the last group of values
54
- return result
55
-
56
-
57
- def transportation_match(text: str):
58
-
59
- if 'taxi' in text.lower():
60
- return 'Taxi'
61
-
62
- elif 'self-driving' in text.lower():
63
- return 'Self-driving'
64
-
65
- elif 'flight' in text.lower():
66
- return 'Flight'
67
-
68
-
69
- def extract_from_to(text: str):
70
- """
71
- Extracts 'A' and 'B' from the format "from A to B" in the given text, with B ending at a comma or the end of the string.
72
-
73
- Args:
74
- - text (str): The input string.
75
-
76
- Returns:
77
- - tuple: A tuple containing 'A' and 'B'. If no match is found, returns (None, None).
78
- """
79
- pattern = r"from\s+(.+?)\s+to\s+([^,]+)(?=[,\s]|$)"
80
- matches = re.search(pattern, text)
81
- return matches.groups() if matches else (None, None)
82
-
83
-
84
-
85
- def is_valid_city_sequence(city_list):
86
- """
87
- Checks if the city sequence is valid. A valid sequence has every city (except the first and last)
88
- appearing consecutively, and no city should appear again once its sequence is over.
89
-
90
- Args:
91
- - city_list (list): List of cities.
92
-
93
- Returns:
94
- - bool: True if the sequence is valid, False otherwise.
95
- """
96
-
97
- # If the list has less than 3 cities, it's invalid.
98
- if len(city_list) < 3:
99
- return False
100
-
101
- # Set to keep track of visited cities
102
- visited_cities = set()
103
-
104
- i = 0
105
- while i < len(city_list):
106
- city = city_list[i]
107
-
108
- # If the city was already visited, it's invalid.
109
- if city in visited_cities and (i != 0 and i != len(city_list) - 1):
110
- return False
111
-
112
- # Count the consecutive occurrences of the city
113
- count = 0
114
- while i < len(city_list) and city_list[i] == city:
115
- count += 1
116
- i += 1
117
-
118
- # If the city appeared only once in the middle, it's invalid.
119
- if count == 1 and 0 < i - 1 < len(city_list) - 1:
120
- return False
121
-
122
- visited_cities.add(city)
123
-
124
- return True
125
-
126
-
127
-
128
- def is_reasonalbe_visiting_city(question, tested_data):
129
-
130
- city_list = []
131
-
132
- # print(tested_data)
133
- for i in range(min(question['days'],len(tested_data))):
134
- city_value = tested_data[i]['current_city']
135
-
136
- if 'from' in city_value:
137
- city1, city2 = extract_from_to(city_value)
138
- city1 = extract_before_parenthesis(city1)
139
- city2 = extract_before_parenthesis(city2)
140
- if i==0 and city1 != question['org']:
141
- return False, f"The first day's city should be {question['org']}."
142
-
143
- city_list += [city1, city2]
144
-
145
- else:
146
- city_list.append(extract_before_parenthesis(city_value))
147
-
148
- if city_list[0] != city_list[-1]:
149
- return False, "The trip should be a closed circle."
150
-
151
- if not is_valid_city_sequence(city_list):
152
- return False, "The city sequence is invalid."
153
-
154
- for idx, city in enumerate(city_list):
155
- if city not in city_state_map:
156
- return False, f"{city} is not a valid city."
157
- if idx not in [0,len(city_list)-1] and question['days'] >3 and city_state_map[city] != question['dest']:
158
- return False, f"{city} is not in {question['dest']}."
159
-
160
- return True, None
161
-
162
-
163
- def is_valid_restaurants(question, tested_data):
164
-
165
- restaurants_list = []
166
-
167
- for i in range(min(question['days'],len(tested_data))):
168
- unit = tested_data[i]
169
-
170
- if 'breakfast' in unit and unit['breakfast'] and unit['breakfast'] != '-':
171
- if unit['breakfast'] not in restaurants_list:
172
- restaurants_list.append(unit['breakfast'])
173
- else:
174
- return False, f"The restaurant in day {i+1} breakfast is repeated."
175
- # elif 'breakfast' not in unit :
176
- # return False, f"No Breakfast Info."
177
-
178
- if 'lunch' in unit and unit['lunch'] and unit['lunch'] != '-':
179
- if unit['lunch'] not in restaurants_list:
180
- restaurants_list.append(unit['lunch'])
181
- else:
182
- return False, f"The restaurant in day {i+1} lunch {unit['lunch']} is repeated."
183
- # elif 'lunch' not in unit:
184
- # return False, f"No Lunch Info."
185
-
186
- if 'dinner' in unit and unit['dinner'] and unit['dinner'] != '-':
187
- if unit['dinner'] not in restaurants_list:
188
- restaurants_list.append(unit['dinner'])
189
- else:
190
- return False, f"The restaurant in day {i+1} dinner is repeated."
191
- # elif 'dinner' not in unit:
192
- # return False, f"No Dinner Info."
193
-
194
- return True, None
195
-
196
- def is_valid_attractions(question, tested_data):
197
-
198
- attractions_list = []
199
-
200
- for i in range(min(question['days'],len(tested_data))):
201
- unit = tested_data[i]
202
-
203
- if 'attraction' in unit and unit['attraction'] and unit['attraction'] != '-':
204
- for attraction in unit['attraction'].split(';')[:-1]:
205
- if attraction not in attractions_list:
206
- attractions_list.append(attraction)
207
- else:
208
- return False, f"The attraction '{attraction}' in day {i+1} is repeated."
209
-
210
- # elif 'attraction' not in unit:
211
- # return False, f"No Attraction Info."
212
-
213
- return True, None
214
-
215
- def is_valid_transportation(question, tested_data):
216
-
217
- if tested_data[0]['transportation'] and tested_data[0]['transportation'] != '-':
218
- transportation_list = [transportation_match(tested_data[0]['transportation'])]
219
-
220
- else:
221
- return False, "The transportation in day 1 should not be empty."
222
-
223
- for i in range(min(question['days'],len(tested_data))):
224
- unit = tested_data[i]
225
-
226
- if 'transportation' in unit and unit['transportation'] and unit['transportation'] != '-':
227
- transportation_list.append(transportation_match(unit['transportation']))
228
- # elif 'transportation' not in unit:
229
- # return False, f"No Transportation Info."
230
-
231
- if (('Self-driving' in transportation_list) and ('Flight' in transportation_list)) or (('Taxi' in transportation_list) and ('Self-driving' in transportation_list)):
232
- return False, "The transportation is conflicting."
233
-
234
- return True, None
235
-
236
- def is_valid_information_in_current_city(question, tested_data):
237
-
238
- for i in range(min(question['days'],len(tested_data))):
239
- unit = tested_data[i]
240
- current_city = unit['current_city']
241
- final_city_list = []
242
-
243
- if 'from' in current_city:
244
- city1, city2 = extract_from_to(current_city)
245
- city1 = extract_before_parenthesis(city1)
246
- city2 = extract_before_parenthesis(city2)
247
- final_city_list = [city1, city2]
248
- else:
249
- final_city_list = [extract_before_parenthesis(current_city)]  # keep this a list so the checks below iterate over city names, not characters
250
-
251
- if 'transportation' in unit and unit['transportation'] and unit['transportation'] != '-':
252
- for city in final_city_list:
253
- if city not in unit['transportation']:
254
- # print(city)
255
- return False, f"The transportation in day {i+1} is an invalid city choice."
256
- # elif 'transportation' not in unit:
257
- # return False, f"No Transportation Info."
258
-
259
- if 'breakfast' in unit and unit['breakfast'] and unit['breakfast'] != '-':
260
-
261
- flag = False
262
-
263
- for city in final_city_list:
264
- if city in unit['breakfast']:
265
- flag = True
266
-
267
- if not flag:
268
- return False, f"The breakfast in day {i+1} is an invalid city choice."
269
- # elif 'breakfast' not in unit:
270
- # return False, f"No Breakfast Info."
271
-
272
- if 'lunch' in unit and unit['lunch'] and unit['lunch'] != '-':
273
- flag = False
274
-
275
- for city in final_city_list:
276
- if city in unit['lunch']:
277
- flag = True
278
-
279
- if not flag:
280
- return False, f"The lunch in day {i+1} is an invalid city choice."
281
- # elif 'lunch' not in unit:
282
- # return False, f"No Lunch Info."
283
-
284
- if 'dinner' in unit and unit['dinner'] and unit['dinner'] != '-':
285
- flag = False
286
-
287
- for city in final_city_list:
288
- if city in unit['dinner']:
289
- flag = True
290
-
291
- if not flag:
292
- return False, f"The dinner in day {i+1} is an invalid city choice."
293
- # elif 'dinner' not in unit:
294
- # return False, f"No Dinner Info."
295
-
296
- if 'attraction' in unit and unit['attraction'] and unit['attraction'] != '-':
297
-
298
- attraction_list = unit['attraction'].split(';')[:-1]
299
-
300
- for attraction in attraction_list:
301
- flag = False
302
- for city in final_city_list:
303
- if city in attraction:
304
- flag = True
305
- if not flag:
306
- return False, f"The attraction in day {i+1} is an invalid city choice."
307
-
308
- # elif 'attraction' not in unit:
309
- # return False, f"No Attraction Info."
310
-
311
-
312
- if 'accommodation' in unit and unit['accommodation'] and unit['accommodation'] != '-':
313
-
314
- if final_city_list[-1] not in unit['accommodation']:
315
- return False, f"The accommodation in day {i+1} is an invalid city choice."
316
-
317
- # elif 'accommodation' not in unit:
318
- # return False, f"No Accommodation Info."
319
-
320
- return True, None
321
-
322
- # hallucination
323
- def is_valid_information_in_sandbox(question, tested_data):
324
-
325
- for i in range(min(question['days'],len(tested_data))):
326
- unit = tested_data[i]
327
-
328
- if unit['transportation'] and unit['transportation'] != '-':
329
- value = unit['transportation']
330
- org_city, dest_city = extract_from_to(value)
331
- if org_city == None or dest_city == None:
332
- org_city, dest_city = extract_from_to(unit['current_city'])
333
- if 'flight number' in value.lower():
334
- try:
335
- org_city = extract_before_parenthesis(org_city)
336
- dest_city = extract_before_parenthesis(dest_city)
337
- except TypeError:
338
- raise ValueError("The transportation {} in day {} can not be parsed.".format(value,i+1))
339
- # print(value)
340
- if len(flight.data[(flight.data['Flight Number'] == value.split('Flight Number: ')[1].split(',')[0]) & (flight.data['OriginCityName']==org_city) & (flight.data['DestCityName']==dest_city)]) < 1:
341
- return False, f"The flight number in day {i+1} is invalid in the sandbox."
342
-
343
- elif 'self-driving' in value.lower() or 'taxi' in value.lower():
344
- try:
345
- org_city = extract_before_parenthesis(org_city)
346
- dest_city = extract_before_parenthesis(dest_city)
347
- except TypeError:
348
- org_city = '-'
349
- dest_city = '-'
350
- print("The transportation {} in day {} can not be parsed and '-' will be used instead.".format(value,i+1))
351
-
352
- if 'self-driving' in value.lower():
353
- if googleDistanceMatrix.run_for_evaluation(org_city, dest_city, mode='self-driving')['cost'] == None:
354
- return False, f"The self-driving in day {i+1} is invalid in the sandbox."
355
- else:
356
- if googleDistanceMatrix.run_for_evaluation(org_city, dest_city, mode='taxi')['cost'] == None:
357
- return False, f"The taxi in day {i+1} is invalid in the sandbox."
358
-
359
- if 'breakfast' in unit and unit['breakfast'] and unit['breakfast'] != '-':
360
- name, city = get_valid_name_city(unit['breakfast'])
361
- if len(restaurants.data[(restaurants.data['Name'].astype(str).str.contains(re.escape(name))) & (restaurants.data['City'] == city)]) < 1:
362
- return False, f"The breakfast in day {i+1} is invalid in the sandbox."
363
- # elif 'breakfast' not in unit:
364
- # return False, f"No Breakfast Info."
365
-
366
- if 'lunch' in unit and unit['lunch'] and unit['lunch'] != '-':
367
- name, city = get_valid_name_city(unit['lunch'])
368
- if len(restaurants.data[(restaurants.data['Name'].astype(str).str.contains(re.escape(name))) & (restaurants.data['City'] == city)]) < 1:
369
- return False, f"The lunch in day {i+1} is invalid in the sandbox."
370
- # elif 'lunch' not in unit:
371
- # return False, f"No Lunch Info."
372
-
373
- if 'dinner' in unit and unit['dinner'] and unit['dinner'] != '-':
374
- name, city = get_valid_name_city(unit['dinner'])
375
- if len(restaurants.data[(restaurants.data['Name'].astype(str).str.contains(re.escape(name))) & (restaurants.data['City'] == city)]) < 1:
376
- return False, f"The dinner in day {i+1} is invalid in the sandbox."
377
- # elif 'dinner' not in unit:
378
- # return False, f"No Dinner Info."
379
-
380
- if 'attraction' in unit and unit['attraction'] and unit['attraction'] != '-':
381
- attractions_list = unit['attraction'].split(';')[:-1]
382
- for attraction in attractions_list:
383
- name, city = get_valid_name_city(attraction)
384
- if len(attractions.data[(attractions.data['Name'].astype(str).str.contains(re.escape(name))) & (attractions.data['City'] == city)]) < 1:
385
- return False, f"The attraction {attraction} in day {i+1} is invalid in the sandbox."
386
- # elif 'attraction' not in unit:
387
- # return False, f"No Attraction Info."
388
-
389
- if 'accommodation' in unit and unit['accommodation'] and unit['accommodation'] != '-':
390
- name, city = get_valid_name_city(unit['accommodation'])
391
- # print(name,city)
392
- # print(accommodation.data[accommodation.data['NAME'].astype(str).str.contains(re.escape(name))])
393
- if len(accommodation.data[(accommodation.data['NAME'].astype(str).str.contains(re.escape(name))) & (accommodation.data['city'] == city)]) < 1:
394
- return False, f"The accommodation in day {i+1} is invalid in the sandbox."
395
- # elif 'accommodation' not in unit:
396
- # return False, f"No Accommodation Info."
397
-
398
- return True, None
399
-
400
-
401
- def is_valid_accommodaton(question, tested_data):
402
- data = []
403
- for i in range(min(question['days'],len(tested_data))):
404
- unit = tested_data[i]
405
-
406
- if 'accommodation' not in unit:
407
- return False, f"No Accommodation Info."
408
-
409
- data.append(unit['accommodation'])
410
- # data = [unit['accommodation'] for unit in tested_data]
411
- consectutive_accommodation = count_consecutive_values(data)
412
- for unit in consectutive_accommodation:
413
- # print(unit)
414
- if unit and unit[0] not in ['-',''] :
415
- name, city = get_valid_name_city(unit[0])
416
- # print(unit[0],name,city)
417
- # try:
418
- if len(accommodation.data[(accommodation.data['NAME'].astype(str).str.contains(re.escape(name))) & (accommodation.data['city'] == city)]) == 1 and unit[1] < accommodation.data[(accommodation.data['NAME'].astype(str).str.contains(re.escape(name))) & (accommodation.data['city'] == city)].iloc[0]['minimum nights']:
419
- return False, f"The accommodation {unit[0]} does not obey the minimum nights rule."
420
- # can not parse data
421
- # except re.error:
422
- # continue
423
-
424
- return True, None
425
-
426
- def is_valid_visiting_city_number(question, tested_data):
427
-
428
- city_set = set()
429
-
430
-
431
- for i in range(min(question['days'],len(tested_data))):
432
- city_value = tested_data[i]['current_city']
433
-
434
- if 'from' in city_value:
435
- city1, city2 = extract_from_to(city_value)
436
- city1 = extract_before_parenthesis(city1)
437
- city2 = extract_before_parenthesis(city2)
438
- if i==0 and city1 != question['org']:
439
- return False, f"The first day's city should be {question['org']}."
440
-
441
- city_set.add(city1)
442
- city_set.add(city2)
443
-
444
- else:
445
- city_set.add(extract_before_parenthesis(city_value))
446
-
447
- city_set.discard(question['org'])
448
-
449
- if len(city_set) != question['visiting_city_number']:
450
- return False, f"The number of visiting cities should be {question['visiting_city_number']}."
451
-
452
- return True, None
453
-
454
- def is_valid_days(question, tested_data):
455
- lens = 0
456
- for i in range(min(question['days'],len(tested_data))):
457
- if tested_data[i] != {} and tested_data[i]['current_city'] != "You don't need to fill in the information for this or later days.":
458
- lens += 1
459
-
460
- if lens != question['days']:
461
- # print(lens)
462
- return False, f"The number of days should be {question['days']}."
463
- else:
464
- return True, None
465
-
466
- def is_not_absent(question, tested_data):
467
- needed_info = 6 * question['days']  # six per-day fields: transportation, breakfast, lunch, dinner, attraction, accommodation
468
- total_valid_info = 0
469
-
470
- if not is_valid_days(question, tested_data)[0]:
471
- return False, "Invalid Days"
472
-
473
- if not is_valid_visiting_city_number(question, tested_data)[0]:
474
- return False, "Invalid City Number"
475
-
476
- for i in range(min(question['days'],len(tested_data))):
477
- unit = tested_data[i]
478
-
479
- if 'transportation' not in unit:
480
- return False, f"No Transportation Info."
481
-
482
- if 'breakfast' not in unit:
483
- return False, f"No Breakfast Info."
484
-
485
- if 'lunch' not in unit:
486
- return False, f"No Lunch Info."
487
-
488
- if 'dinner' not in unit:
489
- return False, f"No Dinner Info."
490
-
491
- if 'attraction' not in unit:
492
- return False, f"No Attraction Info."
493
-
494
- if 'accommodation' not in unit:
495
- return False, f"No Accommodation Info."
496
-
497
- if ('from ' in unit['current_city'] or 'to ' in unit['current_city']) and unit['transportation'] in ['','-']:
498
- return False, f"No transportation in day {i+1} is not allowed."
499
-
500
- if ('from ' not in unit['current_city'] and ' to ' not in unit['current_city']) and unit['attraction'] in ['','-']:
501
- return False, f"No attraction in day {i+1} is not allowed."
502
-
503
- if i != question['days'] - 1 and unit['accommodation'] in ['','-']:
504
- return False, f"No accommodation in day {i+1} is not allowed."
505
-
506
- if (unit['breakfast'] in ['','-'] or unit['lunch'] in ['','-'] or unit['dinner'] in ['','-']) and 'from ' not in unit['current_city']:
507
- return False, f"No meal in day {i+1} is not allowed."
508
-
509
-
510
- for key in unit:
511
- if unit[key] and unit[key] != '-':
512
- total_valid_info += 1
513
-
514
-
515
- if total_valid_info * 1.0 / needed_info < 0.5:
516
- return False, f"The absent information is more than 50%."
517
-
518
- return True, None
519
-
520
-
521
- def evaluation(query_data, tested_data):
522
- return_info = {}
523
- return_info['is_reasonalbe_visiting_city'] = is_reasonalbe_visiting_city(query_data, tested_data)
524
- return_info['is_valid_restaurants'] = is_valid_restaurants(query_data, tested_data)
525
- return_info['is_valid_attractions'] = is_valid_attractions(query_data, tested_data)
526
- return_info['is_valid_accommodation'] = is_valid_accommodaton(query_data, tested_data)
527
- return_info['is_valid_transportation'] = is_valid_transportation(query_data, tested_data)
528
- return_info['is_valid_information_in_current_city'] = is_valid_information_in_current_city(query_data, tested_data)
529
- return_info['is_valid_information_in_sandbox'] = is_valid_information_in_sandbox(query_data, tested_data)
530
- return_info['is_not_absent'] = is_not_absent(query_data, tested_data)
531
- return return_info
532
-
533
- def boolean_evaluation(query_data, tested_data):
534
- return_info = {}
535
- return_info['is_reasonalbe_visiting_city'] = is_reasonalbe_visiting_city(query_data, tested_data)
536
- return_info['is_valid_restaurants'] = is_valid_restaurants(query_data, tested_data)
537
- return_info['is_valid_accommodation'] = is_valid_accommodaton(query_data, tested_data)
538
- return_info['is_valid_attractions'] = is_valid_attractions(query_data, tested_data)
539
- return_info['is_valid_transportation'] = is_valid_transportation(query_data, tested_data)
540
- return_info['is_valid_information_in_current_city'] = is_valid_information_in_current_city(query_data, tested_data)
541
- return_info['is_valid_information_in_sandbox'] = is_valid_information_in_sandbox(query_data, tested_data)
542
- return_info['is_not_absent'] = is_not_absent(query_data, tested_data)
543
- for key in return_info:
544
- if return_info[key][0] == False:
545
- print(return_info[key][1])
546
- return False
547
- return True
548
-
549
- # if __name__ == '__main__':
550
- # number_list = extract_numbers_from_filenames('/home/xj/toolAugEnv/code/toolConstraint/data/annotation/lrz')
551
- # # json_data = json.load(open('/home/xj/toolAugEnv/code/toolConstraint/data/annotation/x/annotation_4.json'))
552
- # query_data = load_line_json_data('/home/xj/toolAugEnv/code/toolConstraint/data/query/lrz.jsonl')
553
- # for idx in number_list:
554
- # json_data = json.load(open(f'/home/xj/toolAugEnv/code/toolConstraint/data/annotation/lrz/annotation_{idx}.json'))
555
- # print(str(idx), evaluation(query_data[idx-1], json_data))
556
- # # json_data = json.load(open(f'/home/xj/toolAugEnv/code/toolConstraint/results/turbo16k-turbo16k/plan_{idx}.json'))
557
- # # query_data = load_line_json_data('/home/xj/toolAugEnv/code/toolConstraint/data/query/test.jsonl')[idx-1]
558
- # # help me write all function name in this file, just the name
559
- # #
560
- # # list all function name in this file
561
- # # ['is_reasonalbe_visiting_city', 'is_valiable_restaurants', 'is_valiable_attractions', 'is_valiable_transportation', 'is_valid_information_in_current_city', 'is_valid_information_in_sandbox']
562
- # # print(is_valiable_restaurants(query_data, json_data))
563
-
564
- # if __name__ == "__main__":
565
- # user = 'zk'
566
- # query_data_list = load_line_json_data(f'/home/xj/toolAugEnv/code/toolConstraint/data/query/{user}.jsonl')
567
- # idx_number_list = extract_numbers_from_filenames(f'/home/xj/toolAugEnv/code/toolConstraint/data/annotation/{user}')
568
- # commonsense_statistic= {level:{day:[] for day in [3,5,7]} for level in ['easy','medium','hard']}
569
- # for idx in idx_number_list:
570
- # print(idx)
571
- # query_data = query_data_list[idx-1]
572
- # generated_plan = json.load(open(f'/home/xj/toolAugEnv/code/toolConstraint/results/turbo16k-turbo16k/{user}/plan_{idx}.json'))
573
- # # generated_plan = generated_plan[:-1]
574
- # if generated_plan[-1]['gpt-3.5-turbo-16k-result'] != 'Plan Fail':
575
- # info_box = evaluation(query_data, generated_plan[-1]['gpt-3.5-turbo-16k-result'])
576
- # generated_plan[-1]['toolAug-commonsense'] = info_box
577
- # else:
578
- # generated_plan[-1]['toolAug-commonsense'] = None
579
- # info_box = None
580
- # commonsense_statistic[query_data['level']][query_data['days']].append(info_box)
581
- # with open(f'/home/xj/toolAugEnv/code/toolConstraint/results/turbo16k-turbo16k/{user}/plan_{idx}.json','w') as f:
582
- # json.dump(generated_plan,f)
583
-
584
- # with open(f'/home/xj/toolAugEnv/code/toolConstraint/results/turbo16k-turbo16k/{user}/commonsense_statistic.json','w') as f:
585
- # json.dump(commonsense_statistic,f)
586
-
587
- # if __name__ == "__main__":
588
- # user = 'all'
589
- # model_type = ['chatgpt','gpt4','greedy_search'][2]
590
- # query_data_list = load_line_json_data(f'/home/xj/toolAugEnv/code/toolConstraint/data/query/{user}.jsonl')
591
- # # idx_number_list = extract_numbers_from_filenames(f'/home/xj/toolAugEnv/code/toolConstraint/data/annotation/{user}')
592
- # idx_number_list = [i for i in range(1,501)]
593
- # commonsense_statistic= {level:{day:[] for day in [3,5,7]} for level in ['easy','medium','hard']}
594
-
595
- # for idx in idx_number_list:
596
- # print(idx)
597
- # query_data = query_data_list[idx-1]
598
- # generated_plan = json.load(open(f'/home/xj/toolAugEnv/code/toolConstraint/results/pre2/{user}/plan_{idx}.json'))
599
- # # generated_plan = generated_plan[:-1]
600
- # if model_type == 'greedy_search':
601
- # info_box = evaluation(query_data, generated_plan[-1][f'greedy_search_plan'])
602
- # else:
603
- # info_box = evaluation(query_data, generated_plan[-1][f'{model_type}_human_collected_info_results_parsed'])
604
- # generated_plan[-1][f'{model_type}_with_human_collected_commonsense'] = info_box
605
- # commonsense_statistic[query_data['level']][query_data['days']].append(info_box)
606
-
607
- # with open(f'/home/xj/toolAugEnv/code/toolConstraint/results/pre2/{user}/plan_{idx}.json','w') as f:
608
- # json.dump(generated_plan,f)
609
-
610
- # with open(f'/home/xj/toolAugEnv/code/toolConstraint/results/pre2/{user}/{model_type}_with_human_collected_commonsense_statistic.json','w') as f:
611
- # json.dump(commonsense_statistic,f)
612
-
613
-
614
- # if __name__ == "__main__":
615
- # user = 'all'
616
- # query_data_list = load_line_json_data(f'/home/xj/toolAugEnv/code/toolConstraint/data/query/{user}.jsonl')
617
- # idx_number_list = extract_numbers_from_filenames(f'/home/xj/toolAugEnv/code/toolConstraint/data/annotation/{user}')
618
- # hardConstraint_statistic= {level:{day:[] for day in [3,5,7]} for level in ['easy','medium','hard']}
619
- # not_satified = []
620
- # for idx in tqdm(idx_number_list):
621
- # # print(idx)
622
- # query_data = query_data_list[idx-1]
623
- # generated_plan = json.load(open(f'/home/xj/toolAugEnv/code/toolConstraint/data/annotation/{user}/annotation_{idx}.json'))
624
-
625
- # if not boolean_evaluation(query_data, generated_plan):
626
- # not_satified.append(idx)
627
- # print(idx)
628
- # generated_plan = generated_plan[:-1]
629
- # print(not_satified)
630
-
631
- if __name__ == "__main__":
632
- set_type = ["train",'dev','test'][0]
633
- query_data_list = load_line_json_data(f'/home/xj/toolAugEnv/code/toolConstraint/data/final_data/{set_type}/query/query.jsonl')
634
- # idx_number_list = extract_numbers_from_filenames(f'/home/xj/toolAugEnv/code/toolConstraint/data/final_data/{set_type}/plan')
635
- commonsense_statistic= {level:{day:[] for day in [3,5,7]} for level in ['easy','medium','hard']}
636
- not_satified = []
637
- # print( idx_number_list)
638
- for idx in tqdm(range(1,len(query_data_list)+1)):
639
- # print(idx)
640
- query_data = query_data_list[idx-1]
641
- generated_plan = json.load(open(f'/home/xj/toolAugEnv/code/toolConstraint/data/final_data/{set_type}/plan/plan_{idx}.json'))
642
- try:
643
- store_plan = json.load(open(f'/home/xj/toolAugEnv/code/toolConstraint/results/{set_type}/plan_{idx}.json'))
644
- except FileNotFoundError:
645
- store_plan = [{}]
646
- info_box = evaluation(query_data,generated_plan[1])
647
- # if not boolean_evaluation(query_data, generated_plan[1]):
648
- # not_satified.append(idx)
649
- # print(idx)
650
- # print(store_plan[-1])
651
- store_plan[-1][f'human_anno_commonsense_constraint'] = info_box
652
- with open(f'/home/xj/toolAugEnv/code/toolConstraint/results/{set_type}/plan_{idx}.json','w') as f:
653
- json.dump(store_plan,f)
654
- commonsense_statistic[query_data['level']][query_data['days']].append(info_box)
655
- print(not_satified)
656
- with open(f'/home/xj/toolAugEnv/code/toolConstraint/results/{set_type}/human_anno_commonsense_constraint.json','w') as f:
657
- json.dump(commonsense_statistic,f)
658
-
659
- # if __name__ == "__main__":
660
- # user = 'all'
661
- # model_type = ['chatgpt','gpt4'][1]
662
- # query_data_list = load_line_json_data(f'/home/xj/toolAugEnv/code/toolConstraint/data/query/{user}.jsonl')
663
- # # idx_number_list = extract_numbers_from_filenames(f'/home/xj/toolAugEnv/code/toolConstraint/data/annotation/{user}')
664
- # idx_number_list = [i for i in range(1,501)]
665
- # commonsense_statistic= {level:{day:[] for day in [3,5,7]} for level in ['easy','medium','hard']}
666
- # cnt = 0
667
- # for idx in idx_number_list:
668
- # # print(idx)
669
- # query_data = query_data_list[idx-1]
670
- # generated_plan = json.load(open(f'/home/xj/toolAugEnv/code/toolConstraint/results/pre/{user}/plan_{idx}.json'))[-1]['gpt4_human_collected_info_results_parsed']
671
- # # generated_plan = generated_plan[:-1]
672
-
673
- # if not boolean_evaluation(query_data, generated_plan):
674
- # cnt += 1
675
- # print(idx)
676
- # print(cnt)
677
-
678
- # if __name__ == "__main__":
679
- # parser = argparse.ArgumentParser(description="")
680
- # # model_type = ['gpt-3.5-turbo-1106','gpt-4-1106-preview','greedy_search','mistral-7B-32K','gemini2','mixtral','gpt-3.5-turbo-11062'][-1]
681
- # # method = ['direct','cot','react','reflexion','tool-use'][-1]
682
- # # set_type = ['dev','test'][0]
683
- # parser.add_argument("--model_type", type=str, default="gpt-3.5-turbo-1106")
684
- # parser.add_argument("--method", type=str, default="direct")
685
- # parser.add_argument("--set_type", type=str, default="dev")
686
- # args = parser.parse_args()
687
- # directory = f'/home/xj/toolAugEnv/code/toolConstraint/data/final_data/{args.set_type}'
688
- # query_data_list = load_line_json_data(os.path.join(directory, 'query/query.jsonl'))
689
- # # idx_number_list = extract_numbers_from_filenames(f'/home/xj/toolAugEnv/code/toolConstraint/data/annotation/{user}')
690
- # idx_number_list = [i for i in range(1,len(query_data_list)+1)]
691
- # commonsense_statistic= {level:{day:[] for day in [3,5,7]} for level in ['easy','medium','hard']}
692
- # deliver_cnt = 0
693
- # if args.method == 'tool-use':
694
- # suffix = ''
695
- # else:
696
- # suffix = '_with_human_info'
697
- # for idx in tqdm(idx_number_list):
698
- # # print(idx)
699
- # query_data = query_data_list[idx-1]
700
- # generated_plan = json.load(open(f'/home/xj/toolAugEnv/code/toolConstraint/results/{args.set_type}/plan_{idx}.json'))
701
- # # generated_plan = generated_plan[:-1]
702
- # if args.model_type == 'greedy_search':
703
- # info_box = evaluation(query_data, generated_plan[-1][f'greedy_search_plan'])
704
- # else:
705
- # if args.method == 'tool-use':
706
- # suffix2 = ''
707
- # else:
708
- # suffix2 = '_collected'
709
- # if generated_plan[-1][f'{args.model_type}_{args.method}{suffix2}_info_results'] and generated_plan[-1][f'{args.model_type}_{args.method}{suffix2}_info_results']!='Max Token Length Exceeded.':
710
- # try:
711
- # info_box = evaluation(query_data, generated_plan[-1][f'{args.model_type}_{args.method}{suffix}_results_parsed'])
712
- # except KeyError:
713
- # info_box = None
714
- # generated_plan[-1][f'{args.model_type}_{args.method}{suffix2}_info_results'] = ""
715
- # except IndexError:
716
- # info_box = None
717
- # generated_plan[-1][f'{args.model_type}_{args.method}{suffix2}_info_results'] = ""
718
- # else:
719
- # info_box = None
720
- # if info_box:
721
- # deliver_cnt += 1
722
- # generated_plan[-1][f'{args.model_type}_{args.method}{suffix}_commonsense_constraint'] = info_box
723
- # commonsense_statistic[query_data['level']][query_data['days']].append(info_box)
724
-
725
- # with open(f'/home/xj/toolAugEnv/code/toolConstraint/results/{args.set_type}/plan_{idx}.json','w') as f:
726
- # json.dump(generated_plan,f)
727
-
728
- # with open(f'/home/xj/toolAugEnv/code/toolConstraint/results/{args.set_type}/{args.model_type}_{args.method}{suffix}_commonsense_constraint.json','w') as f:
729
- # json.dump(commonsense_statistic,f)
730
-
731
- # if args.set_type == 'dev':
732
- # print(f"Model:{args.model_type} Method:{args.method} Set: {args.set_type} \nDeliver Rate: {deliver_cnt/180}" )
733
- # elif args.set_type == 'test':
734
- # print(f"Model:{args.model_type} Method:{args.method} Set: {args.set_type} \nDeliver Rate: {deliver_cnt/1000}" )
735
-
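For orientation, a minimal sketch of the pure helpers defined above — extract_from_to, count_consecutive_values, and is_valid_city_sequence — on illustrative inputs; none of these touch the sandbox databases, and the city names are made up for the example:

    print(extract_from_to("from Seattle to Sacramento, driving"))
    # -> ('Seattle', 'Sacramento')
    print(count_consecutive_values(["A", "A", "B", "B", "B", "A"]))
    # -> [('A', 2), ('B', 3), ('A', 1)]
    print(is_valid_city_sequence(["Seattle", "Sacramento", "Sacramento", "Seattle"]))
    # -> True: the trip is a closed loop and the intermediate city appears consecutively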
 
evaluation/eval.py DELETED
@@ -1,181 +0,0 @@
1
- from commonsenseConstraint import evaluation as commonsense_eval
2
- from hardConstraint import evaluation as hard_eval
3
- import json
4
- from tqdm import tqdm
5
- from datasets import load_dataset
6
-
7
-
8
- def load_line_json_data(filename):
9
- data = []
10
- with open(filename, 'r', encoding='utf-8') as f:
11
- for line in f.read().strip().split('\n'):
12
- unit = json.loads(line)
13
- data.append(unit)
14
- return data
15
-
16
- def count_true_false(data):
17
- """Count the number of true and false values in a list."""
18
- true_count = data.count(True)
19
- false_count = data.count(False)
20
- return true_count, false_count
21
-
22
- def statistics(commonsense_statistic):
23
- """Generate statistics for each level and day in the given data with a different structure."""
24
- result = {level: {day: {} for day in commonsense_statistic[level]} for level in commonsense_statistic}
25
-
26
- for level, days in commonsense_statistic.items():
27
- for day, dicts in days.items():
28
- for dct in dicts:
29
- if dct:
30
- for key, data in dct.items():
31
- true_count, false_count = count_true_false(data)
32
- if key not in result[level][day]:
33
- result[level][day][key] = {"true": 0, "false": 0}
34
- result[level][day][key]["true"] += true_count
35
- result[level][day][key]["false"] += false_count
36
-
37
- return result
38
-
39
-
40
- def eval_score(validation_or_test: str, file_path: str, TOKEN):
41
-
42
- if validation_or_test == 'validation':
43
- query_data_list = load_dataset('osunlp/TravelBenchEval','validation',token=TOKEN)['validation']
44
- elif validation_or_test == 'test':
45
- query_data_list = load_dataset('osunlp/TravelBenchEval','test',token=TOKEN)['test']
46
-
47
- query_data_list = [x for x in query_data_list]
48
- hardConstraint_statistic= {level:{day:[] for day in [3,5,7]} for level in ['easy','medium','hard']}
49
- commonsenseConstraint_statistic = {level:{day:[] for day in [3,5,7]} for level in ['easy','medium','hard']}
50
- tested_plans = load_line_json_data(file_path)
51
- delivery_cnt = 0
52
- plan_constraint_store = []
53
- for idx in tqdm(range(0,len(query_data_list))):
54
- query_data = query_data_list[idx]
55
- tested_plan = tested_plans[idx]
56
- if type(query_data) == str:
57
- query_data = eval(query_data)
58
- if type(tested_plan) == str:
59
- tested_plan = eval(tested_plan)
60
- if type(query_data['local_constraint']) == str:
61
- query_data['local_constraint'] = eval(query_data['local_constraint'])
62
-
63
- if tested_plan['plan']:
64
- delivery_cnt += 1
65
- commonsense_info_box = commonsense_eval(query_data,tested_plan['plan'])
66
- else:
67
- commonsense_info_box = None
68
-
69
- if commonsense_info_box and commonsense_info_box['is_not_absent'][0] and commonsense_info_box['is_valid_information_in_sandbox'][0]:
70
- hard_info_box = hard_eval(query_data,tested_plan['plan'])
71
- else:
72
- hard_info_box = None
73
-
74
- plan_constraint_store.append({'commonsense_constraint':commonsense_info_box,'hard_constraint':hard_info_box})
75
-
76
- commonsenseConstraint_statistic[query_data['level']][query_data['days']].append(commonsense_info_box)
77
- hardConstraint_statistic[query_data['level']][query_data['days']].append(hard_info_box)
78
-
79
- commonsenseConstraint_statistic_processed = statistics(commonsenseConstraint_statistic)
80
- hardConstraint_statistic_processed = statistics(hardConstraint_statistic)
81
- # print(commonsenseConstraint_statistic_processed)
82
- # print(hardConstraint_statistic_processed)
83
- constraint_record = {key: {day: {'house rule':0, 'cuisine':0, 'room type':0, 'transportation':0} for day in [3,5,7]} for key in ['medium','hard']}
84
- constraint_mapping = {'house rule':'valid_room_rule','cuisine':'valid_cuisine','room type':'valid_room_type','transportation':'valid_transportation'}
85
- mapping_constraint_record = {key: {day: {'valid_room_rule':0, 'valid_cuisine':0, 'valid_room_type':0, 'valid_transportation':0} for day in [3,5,7]} for key in ['medium','hard']}
86
- count_record = {key:{day:0 for day in [3,5,7]} for key in ['easy','medium','hard']}
87
-
88
- for unit in query_data_list:
89
- count_record[unit['level']][unit['days']] += 1
90
- for key in constraint_record['medium'][3]:
91
- if unit['local_constraint'][key] != None:
92
- constraint_record[unit['level']][unit['days']][key] += 1
93
- mapping_constraint_record[unit['level']][unit['days']][constraint_mapping[key]] += 1
94
-
95
- data_record = {key:{day:[] for day in [3,5,7]} for key in ['easy','medium','hard']}
96
-
97
- constraint_dis_record = {"commonsense":{"pass":0,"total":0},"hard":{"pass":0,"total":0}}
98
-
99
- for constraint in ['commonsense','hard']:
100
- if constraint == 'commonsense':
101
- constraint_statistic = commonsenseConstraint_statistic_processed
102
- elif constraint == 'hard':
103
- constraint_statistic = hardConstraint_statistic_processed
104
-
105
- key_dict = {'commonsense':['is_valid_information_in_current_city','is_valid_information_in_sandbox','is_reasonalbe_visiting_city','is_valid_restaurants','is_valid_transportation','is_valid_attractions','is_valid_accommodation','is_not_absent'],'hard':['valid_cost','valid_room_rule','valid_cuisine','valid_room_type','valid_transportation']}
106
-
107
- for key in constraint_statistic:
108
- # level
109
- for key2 in constraint_statistic[key]:
110
- # day
111
- # print(key2)
112
- # key2 = eval(key2)
113
- if key2 == -1:
114
- print(constraint_statistic[key])
115
- exit(0)
116
- for key3 in key_dict[constraint]:
117
- data_record[key][key2].append('0/0')
118
- if key3 in constraint_statistic[key][key2]:
119
- constraint_dis_record[constraint]['pass'] += constraint_statistic[key][key2][key3]['true']
120
- if constraint == 'hard':
121
- if key == 'hard' and key3 in ['valid_room_rule','valid_cuisine','valid_room_type','valid_transportation']:
122
- data_record[key][key2][-1] = f"{constraint_statistic[key][key2][key3]['true']}/{mapping_constraint_record[key][key2][key3]}"
123
- constraint_dis_record[constraint]['total'] += mapping_constraint_record[key][key2][key3]
124
- elif key == 'medium' and key3 in ['valid_room_rule','valid_cuisine','valid_room_type']:
125
- data_record[key][key2][-1] = f"{constraint_statistic[key][key2][key3]['true']}/{mapping_constraint_record[key][key2][key3]}"
126
- constraint_dis_record[constraint]['total'] += mapping_constraint_record[key][key2][key3]
127
- else:
128
- data_record[key][key2][-1] = f"{constraint_statistic[key][key2][key3]['true']}/{count_record[key][key2]}"
129
- if key3 in ['valid_cost','valid_visitng_city_number','valid_days']:
130
- constraint_dis_record[constraint]['total'] += count_record[key][key2]
131
- else:
132
- data_record[key][key2][-1] = f"{constraint_statistic[key][key2][key3]['true']}/{count_record[key][key2]}"
133
- constraint_dis_record[constraint]['total'] += count_record[key][key2]
134
-
135
- final_all_cnt = 0
136
- final_commonsense_cnt = 0
137
- final_hardConstraint_cnt = 0
138
- final_all_cnt_map = {level:0 for level in ['easy','medium','hard']}
139
- for idx in (range(0,len(query_data_list))):
140
- if plan_constraint_store[idx]['commonsense_constraint']:
141
- final_commonsense_pass = True
142
- final_hardConstraint_pass = True
143
- for item in plan_constraint_store[idx]['commonsense_constraint']:
144
- if plan_constraint_store[idx]['commonsense_constraint'][item][0] is not None and not plan_constraint_store[idx]['commonsense_constraint'][item][0]:
145
- final_commonsense_pass = False
146
- break
147
- if plan_constraint_store[idx]['hard_constraint'] is None:
148
- continue
149
- for item in plan_constraint_store[idx]['hard_constraint']:
150
- if plan_constraint_store[idx]['hard_constraint'][item][0] is not None and plan_constraint_store[idx]['hard_constraint'][item][0] == False:
151
- final_hardConstraint_pass = False
152
- break
153
-
154
- if final_commonsense_pass:
155
- final_commonsense_cnt += 1
156
- if final_hardConstraint_pass:
157
- final_hardConstraint_cnt += 1
158
- if final_commonsense_pass and final_hardConstraint_pass:
159
- final_all_cnt += 1
160
- final_all_cnt_map[query_data_list[idx]['level']] += 1
161
-
162
- result = {}
163
-
164
- if validation_or_test == 'validation':
165
- result['Delivery Rate'] = delivery_cnt / 180
166
- result['Commonsense Constraint Micro Pass Rate'] = constraint_dis_record['commonsense']['pass'] / 1440
167
- result['Commonsense Constraint Macro Pass Rate'] = final_commonsense_cnt / 180
168
- result['Hard Constraint Micro Pass Rate'] = constraint_dis_record['hard']['pass'] / 420
169
- result['Hard Constraint Macro Pass Rate'] = final_hardConstraint_cnt / 180
170
- result['Final Pass Rate'] = final_all_cnt / 180
171
-
172
- elif validation_or_test == 'test':
173
- result['Delivery Rate'] = delivery_cnt / 1000
174
- result['Commonsense Constraint Micro Pass Rate'] = constraint_dis_record['commonsense']['pass'] / 8000
175
- result['Commonsense Constraint Macro Pass Rate'] = final_commonsense_cnt / 1000
176
- result['Hard Constraint Micro Pass Rate'] = constraint_dis_record['hard']['pass'] / 2290
177
- result['Hard Constraint Macro Pass Rate'] = final_hardConstraint_cnt / 1000
178
- result['Final Pass Rate'] = final_all_cnt / 1000
179
-
180
- return result
181
-
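A minimal sketch of how eval_score above might be invoked on a submission file; the token and path are placeholders, and the printed keys mirror the scored files further below:

    # Assumes access to the gated osunlp/TravelBenchEval dataset via TOKEN.
    TOKEN = "hf_..."  # placeholder Hugging Face access token
    scores = eval_score("validation", "submissions/validation_two-stage.jsonl", TOKEN)
    for metric, value in scores.items():
        print(f"{metric}: {value:.4f}")
    # Keys: Delivery Rate, Commonsense Constraint Micro/Macro Pass Rate,
    # Hard Constraint Micro/Macro Pass Rate, Final Pass Rate.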
 
evaluation/hardConstraint.py DELETED
@@ -1,266 +0,0 @@
1
- from annotation.src.utils import get_valid_name_city,extract_before_parenthesis,extract_numbers_from_filenames
2
- from tools.flights.apis import Flights
3
- from tools.accommodations.apis import Accommodations
4
- from tools.restaurants.apis import Restaurants
5
- from tools.googleDistanceMatrix.apis import GoogleDistanceMatrix
6
- from tools.attractions.apis import Attractions
7
- import math
8
- import json
9
- import re
10
- import numpy as np
11
- import os
12
- import sys
13
- from tqdm import tqdm
14
- import argparse
15
-
16
- sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))
17
- os.chdir(os.path.dirname(os.path.abspath(__file__)))
18
-
19
-
20
- flight = Flights()
21
- accommodation = Accommodations()
22
- restaurants = Restaurants()
23
- googleDistanceMatrix = GoogleDistanceMatrix()
24
- attractions = Attractions()
25
-
26
-
27
- def load_line_json_data(filename):
28
- data = []
29
- with open(filename, 'r', encoding='utf-8') as f:
30
- for line in f.read().strip().split('\n'):
31
- unit = json.loads(line)
32
- data.append(unit)
33
- return data
34
-
35
-
36
- def convert_bool_values(item):
37
- if isinstance(item, dict):
38
- # If the item is a dictionary, recurse on each value
39
- return {key: convert_bool_values(value) for key, value in item.items()}
40
- elif isinstance(item, list):
41
- # If the item is a list, recurse on each item in the list
42
- return [convert_bool_values(value) for value in item]
43
- elif isinstance(item, tuple):
44
- # If the item is a tuple, recurse on each item in the tuple and repackage as a tuple
45
- return tuple(convert_bool_values(value) for value in item)
46
- elif isinstance(item, np.bool_): # Here we check for numpy's bool_ type
47
- # If the item is a numpy bool_, convert it to a standard Python bool
48
- return bool(item)
49
- else:
50
- # If the item is any other type, return it unchanged
51
- return item
52
-
53
-
54
-
55
-
56
- def extract_from_to(text: str):
57
- """
58
- Extracts 'A' and 'B' from the format "from A to B" in the given text, with B ending at a comma or the end of the string.
59
-
60
- Args:
61
- - text (str): The input string.
62
-
63
- Returns:
64
- - tuple: A tuple containing 'A' and 'B'. If no match is found, returns (None, None).
65
- """
66
- pattern = r"from\s+(.+?)\s+to\s+([^,]+)(?=[,\s]|$)"
67
- matches = re.search(pattern, text)
68
- return matches.groups() if matches else (None, None)
69
-
70
-
71
- def get_total_cost(question, tested_data):
72
- total_cost = 0
73
- for i in range(min(question['days'],len(tested_data))):
74
- unit = tested_data[i]
75
- # transportation
76
- if unit['transportation'] and unit['transportation'] != '-':
77
- value = unit['transportation']
78
- org_city, dest_city = extract_from_to(value)
79
- if org_city == None or dest_city == None:
80
- org_city, dest_city = extract_from_to(unit['current_city'])
81
-
82
- if org_city == None or dest_city == None:
83
- pass
84
- else:
85
- if 'flight number' in value.lower():
86
- res = flight.data[flight.data['Flight Number'] == value.split('Flight Number: ')[1].split(',')[0]]
87
- if len(res) > 0:
88
- total_cost += res['Price'].values[0] * question['people_number']
89
-
90
- elif 'self-driving' in value.lower() or 'taxi' in value.lower():
91
- if 'self-driving' in value.lower():
92
- # print(org_city,dest_city)
93
- cost = googleDistanceMatrix.run_for_evaluation(org_city,dest_city,'self-driving')['cost']
94
- total_cost += cost * math.ceil(question['people_number'] * 1.0 / 5)
95
- else:
96
- cost = googleDistanceMatrix.run_for_evaluation(org_city,dest_city,'taxi')['cost']
97
- total_cost += cost * math.ceil(question['people_number'] * 1.0 / 4)
98
-
99
- # breakfast
100
- if unit['breakfast'] and unit['breakfast'] != '-':
101
- name, city = get_valid_name_city(unit['breakfast'])
102
- res = restaurants.data[(restaurants.data['Name'].astype(str).str.contains(re.escape(name))) & (restaurants.data['City'] == city)]
103
- if len(res) > 0:
104
- total_cost += res['Average Cost'].values[0] * question['people_number']
105
-
106
-
107
- # lunch
108
- if unit['lunch'] and unit['lunch'] != '-':
109
- name, city = get_valid_name_city(unit['lunch'])
110
- res = restaurants.data[(restaurants.data['Name'].astype(str).str.contains(re.escape(name))) & (restaurants.data['City'] == city)]
111
- if len(res) > 0:
112
- total_cost += res['Average Cost'].values[0] * question['people_number']
113
-
114
- # dinner
115
- if unit['dinner'] and unit['dinner'] != '-':
116
- name, city = get_valid_name_city(unit['dinner'])
117
- res = restaurants.data[(restaurants.data['Name'].astype(str).str.contains(re.escape(name))) & (restaurants.data['City'] == city)]
118
- if len(res) > 0:
119
- total_cost += res['Average Cost'].values[0] * question['people_number']
120
-
121
- # accommodation
122
- if unit['accommodation'] and unit['accommodation'] != '-':
123
- name, city = get_valid_name_city(unit['accommodation'])
124
- res = accommodation.data[(accommodation.data['NAME'].astype(str).str.contains(re.escape(name))) & (accommodation.data['city'] == city)]
125
- if len(res) > 0:
126
- total_cost += res['price'].values[0] * math.ceil(question['people_number'] * 1.0 / res['maximum occupancy'].values[0])
127
- # print(total_cost)
128
- return total_cost
129
-
130
-
131
- def is_valid_room_rule(question, tested_data):
132
-
133
- if question['local_constraint']['house rule'] is None:
134
- return None,None
135
-
136
- for i in range(min(question['days'],len(tested_data))):
137
- unit = tested_data[i]
138
- if unit['accommodation'] and unit['accommodation'] != '-':
139
- name, city = get_valid_name_city(unit['accommodation'])
140
- res = accommodation.data[(accommodation.data['NAME'].astype(str).str.contains(re.escape(name))) & (accommodation.data['city'] == city)]
141
- if len(res) > 0:
142
- if question['local_constraint']['house rule'] == 'smoking' and 'No smoking' in str(res['house_rules'].values[0]):
143
- return False, f"The house rule should be {question['local_constraint']['house rule']}."
144
- if question['local_constraint']['house rule'] == 'parities' and 'No parties' in str(res['house_rules'].values[0]):
145
- return False, f"The house rule should be {question['local_constraint']['house rule']}."
146
- if question['local_constraint']['house rule'] == 'children under 10' and 'No children under 10' in str(res['house_rules'].values[0]):
147
- return False, f"The house rule should be {question['local_constraint']['house rule']}."
148
- if question['local_constraint']['house rule'] == 'visitors' and 'No visitors' in str(res['house_rules'].values[0]):
149
- return False, f"The house rule should be {question['local_constraint']['house rule']}."
150
- if question['local_constraint']['house rule'] == 'pets' and 'No pets' in str(res['house_rules'].values[0]):
151
- return False, f"The house rule should be {question['local_constraint']['house rule']}."
152
-
153
-
154
- return True, None
155
-
156
-
157
-
158
- def is_valid_cuisine(question, tested_data):
159
- cuisine_set = set()
160
- if question['local_constraint']['cuisine']:
161
- for i in range(min(question['days'],len(tested_data))):
162
- unit = tested_data[i]
163
-
164
- if unit['breakfast'] and unit['breakfast'] != '-':
165
- name, city = get_valid_name_city(unit['breakfast'])
166
- if city == question['org']:
167
- continue
168
- res = restaurants.data[(restaurants.data['Name'].astype(str).str.contains(re.escape(name))) & (restaurants.data['City'] == city)]
169
- if len(res) > 0:
170
- for cuisine in question['local_constraint']['cuisine']:
171
- if cuisine in res.iloc[0]['Cuisines']:
172
- cuisine_set.add(cuisine)
173
-
174
- if unit['lunch'] and unit['lunch'] != '-':
175
- name, city = get_valid_name_city(unit['lunch'])
176
- if city == question['org']:
177
- continue
178
- res = restaurants.data[(restaurants.data['Name'].astype(str).str.contains(re.escape(name))) & (restaurants.data['City'] == city)]
179
- if len(res) > 0:
180
- for cuisine in question['local_constraint']['cuisine']:
181
- if cuisine in res.iloc[0]['Cuisines']:
182
- cuisine_set.add(cuisine)
183
-
184
- if unit['dinner'] and unit['dinner'] != '-':
185
- name, city = get_valid_name_city(unit['dinner'])
186
- if city == question['org']:
187
- continue
188
- res = restaurants.data[(restaurants.data['Name'].astype(str).str.contains(re.escape(name))) & (restaurants.data['City'] == city)]
189
- if len(res) > 0:
190
- for cuisine in question['local_constraint']['cuisine']:
191
- if cuisine in res.iloc[0]['Cuisines']:
192
- cuisine_set.add(cuisine)
193
-
194
- if len(cuisine_set) == len(question['local_constraint']['cuisine']):
195
- return True, None
196
- else:
197
- # judge which cuisine is not satisfied
198
- for cuisine in question['local_constraint']['cuisine']:
199
- if cuisine not in cuisine_set:
200
- return False, f"The cuisine {cuisine} is not satisfied."
201
- # return False, f"The cuisine should be {question['local_constraint']['cuisine']}."
202
- else:
203
- return None,None
204
-
205
-
206
- def is_valid_transportation(question, tested_data):
207
- if question['local_constraint']['transportation'] is None:
208
- return None,None
209
- for i in range(min(question['days'],len(tested_data))):
210
- unit = tested_data[i]
211
- if unit['transportation'] and unit['transportation'] != '-':
212
- value = unit['transportation']
213
- if question['local_constraint']['transportation'] == 'no flight' and 'Flight' in value:
214
- return False, f"The transportation should not be {question['local_constraint']['transportation']}."
215
- elif question['local_constraint']['transportation'] == 'no self-driving' and 'Self-driving' in value:
216
- return False, f"The transportation should not be {question['local_constraint']['transportation']}."
217
-
218
- return True, None
219
-
220
-
221
- def is_valid_room_type(question, tested_data):
222
- if question['local_constraint']['room type'] is None:
223
- return None,None
224
- for i in range(min(question['days'],len(tested_data))):
225
- unit = tested_data[i]
226
- if unit['accommodation'] and unit['accommodation'] != '-':
227
- name, city = get_valid_name_city(unit['accommodation'])
228
- res = accommodation.data[(accommodation.data['NAME'].astype(str).str.contains(re.escape(name))) & (accommodation.data['city'] == city)]
229
- if len(res) > 0:
230
- if question['local_constraint']['room type'] == 'not shared room' and res['room type'].values[0] == 'Shared room':
231
- return False, f"The room type should be {question['local_constraint']['room type']}."
232
- # "shared room", "not shared room", "private room", "entire room"
233
- elif question['local_constraint']['room type'] == 'shared room' and res['room type'].values[0] != 'Shared room':
234
- return False, f"The room type should be {question['local_constraint']['room type']}."
235
-
236
- elif question['local_constraint']['room type'] == 'private room' and res['room type'].values[0] != 'Private room':
237
- return False, f"The room type should be {question['local_constraint']['room type']}."
238
-
239
- elif question['local_constraint']['room type'] == 'entire room' and res['room type'].values[0] != 'Entire home/apt':
240
- return False, f"The room type should be {question['local_constraint']['room type']}."
241
-
242
- return True, None
243
-
244
-
245
- def evaluation(query_data, tested_data):
246
- return_info = {}
247
- return_info['valid_cuisine'] = is_valid_cuisine(query_data, tested_data)
248
- return_info['valid_room_rule'] = is_valid_room_rule(query_data, tested_data)
249
- return_info['valid_transportation'] = is_valid_transportation(query_data, tested_data)
250
- return_info['valid_room_type'] = is_valid_room_type(query_data, tested_data)
251
- return_info['valid_cost'] = (bool(get_total_cost(query_data, tested_data) <= query_data['budget']), None)
252
- return return_info
253
-
254
- def boolean_evaluation(query_data, tested_data):
255
- return_info = {}
256
- return_info['valid_cuisine'] = is_valid_cuisine(query_data, tested_data)
257
- return_info['valid_room_rule'] = is_valid_room_rule(query_data, tested_data)
258
- return_info['valid_transportation'] = is_valid_transportation(query_data, tested_data)
259
- return_info['valid_room_type'] = is_valid_room_type(query_data, tested_data)
260
- return_info['valid_cost'] = (bool(get_total_cost(query_data, tested_data) <= query_data['budget']), None)
261
- for key in return_info:
262
- if return_info[key][0] == False:
263
- print(key)
264
- return False
265
- return True
266
-
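As a rough illustration, the transportation constraint above can be exercised without touching the sandbox databases (assuming the module's tool imports succeed); all values below are made up:

    query = {
        "days": 1,
        "local_constraint": {"house rule": None, "cuisine": None,
                             "room type": None, "transportation": "no flight"},
    }
    plan = [{"current_city": "from Seattle to Sacramento",
             "transportation": "Flight Number: F1234567, from Seattle to Sacramento",
             "breakfast": "-", "lunch": "-", "dinner": "-",
             "attraction": "-", "accommodation": "-"}]
    print(is_valid_transportation(query, plan))
    # -> (False, 'The transportation should not be no flight.')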
 
evaluation/scored/1_validation_two-stage_1.jsonl DELETED
@@ -1 +0,0 @@
- {"Delivery Rate": 0.8944444444444445, "Commonsense Constraint Micro Pass Rate": 0.6111111111111112, "Commonsense Constraint Macro Pass Rate": 0.027777777777777776, "Hard Constraint Micro Pass Rate": 0.1523809523809524, "Hard Constraint Macro Pass Rate": 0.10555555555555556, "Final Pass Rate": 0.005555555555555556}
 
 
evaluation/scored/textbox_validation_two-stage_1.jsonl DELETED
@@ -1 +0,0 @@
- {"Delivery Rate": 0.8944444444444445, "Commonsense Constraint Micro Pass Rate": 0.6111111111111112, "Commonsense Constraint Macro Pass Rate": 0.027777777777777776, "Hard Constraint Micro Pass Rate": 0.1523809523809524, "Hard Constraint Macro Pass Rate": 0.10555555555555556, "Final Pass Rate": 0.005555555555555556}
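Each of these scored files holds a single JSON object; a record in this shape would presumably be produced by serializing the eval_score result, e.g.:

    import json
    with open("evaluation/scored/1_validation_two-stage_1.jsonl", "w") as f:
        f.write(json.dumps(scores) + "\n")  # `scores` as returned by eval_score above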