hsaest committed on
Commit
3a3b852
1 Parent(s): 02421e8

Upload folder using huggingface_hub

Files changed (5)
  1. app.py +1 -2
  2. commonsenseConstraint.py +735 -0
  3. eval.py +181 -0
  4. hardConstraint.py +266 -0
  5. requirements.txt +1 -2
app.py CHANGED
@@ -2,7 +2,6 @@ import os
  import sys
  sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "./leaderboard/evaluation")))
  sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "./leaderboard")))
- print(sys.path)
  os.chdir(os.path.dirname(os.path.abspath(__file__)))
  import json
  import datetime
@@ -19,7 +18,7 @@ from huggingface_hub import HfApi
  # InfoStrings
  # from scorer import question_scorer
  from content import format_error, format_warning, format_log, TITLE, INTRODUCTION_TEXT, CITATION_BUTTON_LABEL, CITATION_BUTTON_TEXT, model_hyperlink
- from evaluation.eval import eval_score
+ from eval import eval_score

  TOKEN = os.environ.get("TOKEN", None)
 
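(For reference: with the import switched from evaluation.eval to the local eval module added in this commit, the app can score an uploaded plan file directly. A minimal sketch of that wiring is below; the handler name and file path are hypothetical, and only eval_score(validation_or_test, file_path, TOKEN) from eval.py and the TOKEN environment variable come from the files in this commit.)

    import os
    from eval import eval_score

    def score_submission(file_path: str, split: str = "validation") -> dict:
        # Hypothetical helper: the actual Gradio callback in app.py is not shown in this diff.
        token = os.environ.get("TOKEN", None)
        # eval_score returns a dict of Delivery / Commonsense / Hard-constraint pass rates.
        return eval_score(split, file_path, token)

    # e.g. score_submission("submission.jsonl", "validation")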
commonsenseConstraint.py ADDED
@@ -0,0 +1,735 @@
1
+ from annotation.src.utils import get_valid_name_city,extract_before_parenthesis,extract_numbers_from_filenames
2
+ from tools.flights.apis import Flights
3
+ from tools.accommodations.apis import Accommodations
4
+ from tools.restaurants.apis import Restaurants
5
+ from tools.googleDistanceMatrix.apis import GoogleDistanceMatrix
6
+ from tools.attractions.apis import Attractions
7
+ import math
8
+ import json
9
+ import re
10
+ import os
11
+ import sys
12
+ from tqdm import tqdm
13
+ import argparse
14
+
15
+ sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))
16
+ os.chdir(os.path.dirname(os.path.abspath(__file__)))
17
+
18
+ flight = Flights()
19
+ accommodation = Accommodations()
20
+ restaurants = Restaurants()
21
+ googleDistanceMatrix = GoogleDistanceMatrix()
22
+ attractions = Attractions()
23
+
24
+ city_state_set = open('../database/background/citySet_with_states.txt','r').read().split('\n')
25
+ city_state_map = {x:y for x,y in [unit.split('\t') for unit in city_state_set]}
26
+
27
+
28
+ def load_line_json_data(filename):
29
+ data = []
30
+ with open(filename, 'r', encoding='utf-8') as f:
31
+ for line in f.read().strip().split('\n'):
32
+ unit = json.loads(line)
33
+ data.append(unit)
34
+ return data
35
+
36
+
37
+ def count_consecutive_values(lst):
38
+ if not lst:
39
+ return []
40
+
41
+ result = []
42
+ current_string = lst[0]
43
+ count = 1
44
+
45
+ for i in range(1, len(lst)):
46
+ if lst[i] == current_string:
47
+ count += 1
48
+ else:
49
+ result.append((current_string, count))
50
+ current_string = lst[i]
51
+ count = 1
52
+
53
+ result.append((current_string, count)) # Add the last group of values
54
+ return result
55
+
56
+
57
+ def transportation_match(text: str):
58
+
59
+ if 'taxi' in text.lower():
60
+ return 'Taxi'
61
+
62
+ elif 'self-driving' in text.lower():
63
+ return 'Self-driving'
64
+
65
+ elif 'flight' in text.lower():
66
+ return 'Flight'
67
+
68
+
69
+ def extract_from_to(text: str):
70
+ """
71
+ Extracts 'A' and 'B' from the format "from A to B" in the given text, with B ending at a comma or the end of the string.
72
+
73
+ Args:
74
+ - text (str): The input string.
75
+
76
+ Returns:
77
+ - tuple: A tuple containing 'A' and 'B'. If no match is found, returns (None, None).
78
+ """
79
+ pattern = r"from\s+(.+?)\s+to\s+([^,]+)(?=[,\s]|$)"
80
+ matches = re.search(pattern, text)
81
+ return matches.groups() if matches else (None, None)
82
+
83
+
84
+
85
+ def is_valid_city_sequence(city_list):
86
+ """
87
+ Checks if the city sequence is valid. A valid sequence has every city (except the first and last)
88
+ appearing consecutively, and no city should appear again once its sequence is over.
89
+
90
+ Args:
91
+ - city_list (list): List of cities.
92
+
93
+ Returns:
94
+ - bool: True if the sequence is valid, False otherwise.
95
+ """
96
+
97
+ # If the list has less than 3 cities, it's invalid.
98
+ if len(city_list) < 3:
99
+ return False
100
+
101
+ # Set to keep track of visited cities
102
+ visited_cities = set()
103
+
104
+ i = 0
105
+ while i < len(city_list):
106
+ city = city_list[i]
107
+
108
+ # If the city was already visited, it's invalid.
109
+ if city in visited_cities and (i != 0 and i != len(city_list) - 1):
110
+ return False
111
+
112
+ # Count the consecutive occurrences of the city
113
+ count = 0
114
+ while i < len(city_list) and city_list[i] == city:
115
+ count += 1
116
+ i += 1
117
+
118
+ # If the city appeared only once in the middle of the sequence, it's invalid.
119
+ if count == 1 and 0 < i - 1 < len(city_list) - 1:
120
+ return False
121
+
122
+ visited_cities.add(city)
123
+
124
+ return True
125
+
126
+
127
+
128
+ def is_reasonalbe_visiting_city(question, tested_data):
129
+
130
+ city_list = []
131
+
132
+ # print(tested_data)
133
+ for i in range(min(question['days'],len(tested_data))):
134
+ city_value = tested_data[i]['current_city']
135
+
136
+ if 'from' in city_value:
137
+ city1, city2 = extract_from_to(city_value)
138
+ city1 = extract_before_parenthesis(city1)
139
+ city2 = extract_before_parenthesis(city2)
140
+ if i==0 and city1 != question['org']:
141
+ return False, f"The first day's city should be {question['org']}."
142
+
143
+ city_list += [city1, city2]
144
+
145
+ else:
146
+ city_list.append(extract_before_parenthesis(city_value))
147
+
148
+ if city_list[0] != city_list[-1]:
149
+ return False, "The trip should be a closed circle."
150
+
151
+ if not is_valid_city_sequence(city_list):
152
+ return False, "The city sequence is invalid."
153
+
154
+ for idx, city in enumerate(city_list):
155
+ if city not in city_state_map:
156
+ return False, f"{city} is not a valid city."
157
+ if idx not in [0,len(city_list)-1] and question['days'] >3 and city_state_map[city] != question['dest']:
158
+ return False, f"{city} is not in {question['dest']}."
159
+
160
+ return True, None
161
+
162
+
163
+ def is_valid_restaurants(question, tested_data):
164
+
165
+ restaurants_list = []
166
+
167
+ for i in range(min(question['days'],len(tested_data))):
168
+ unit = tested_data[i]
169
+
170
+ if 'breakfast' in unit and unit['breakfast'] and unit['breakfast'] != '-':
171
+ if unit['breakfast'] not in restaurants_list:
172
+ restaurants_list.append(unit['breakfast'])
173
+ else:
174
+ return False, f"The restaurant in day {i+1} breakfast is repeated."
175
+ # elif 'breakfast' not in unit :
176
+ # return False, f"No Breakfast Info."
177
+
178
+ if 'lunch' in unit and unit['lunch'] and unit['lunch'] != '-':
179
+ if unit['lunch'] not in restaurants_list:
180
+ restaurants_list.append(unit['lunch'])
181
+ else:
182
+ return False, f"The restaurant in day {i+1} lunch {unit['lunch']} is repeated."
183
+ # elif 'lunch' not in unit:
184
+ # return False, f"No Lunch Info."
185
+
186
+ if 'dinner' in unit and unit['dinner'] and unit['dinner'] != '-':
187
+ if unit['dinner'] not in restaurants_list:
188
+ restaurants_list.append(unit['dinner'])
189
+ else:
190
+ return False, f"The restaurant in day {i+1} dinner is repeated."
191
+ # elif 'dinner' not in unit:
192
+ # return False, f"No Dinner Info."
193
+
194
+ return True, None
195
+
196
+ def is_valid_attractions(question, tested_data):
197
+
198
+ attractions_list = []
199
+
200
+ for i in range(min(question['days'],len(tested_data))):
201
+ unit = tested_data[i]
202
+
203
+ if 'attraction' in unit and unit['attraction'] and unit['attraction'] != '-':
204
+ for attraction in unit['attraction'].split(';')[:-1]:
205
+ if attraction not in attractions_list:
206
+ attractions_list.append(attraction)
207
+ else:
208
+ return False, f"The attraction '{attraction}' in day {i+1} is repeated."
209
+
210
+ # elif 'attraction' not in unit:
211
+ # return False, f"No Attraction Info."
212
+
213
+ return True, None
214
+
215
+ def is_valid_transportation(question, tested_data):
216
+
217
+ if tested_data[0]['transportation'] and tested_data[0]['transportation'] != '-':
218
+ transportation_list = [transportation_match(tested_data[0]['transportation'])]
219
+
220
+ else:
221
+ return False, "The transportation in day 1 should not be empty."
222
+
223
+ for i in range(min(question['days'],len(tested_data))):
224
+ unit = tested_data[i]
225
+
226
+ if 'transportation' in unit and unit['transportation'] and unit['transportation'] != '-':
227
+ transportation_list.append(transportation_match(unit['transportation']))
228
+ # elif 'transportation' not in unit:
229
+ # return False, f"No Transportation Info."
230
+
231
+ if (('Self-driving' in transportation_list) and ('Flight' in transportation_list)) or (('Taxi' in transportation_list) and ('Self-driving' in transportation_list)):
232
+ return False, "The transportation is conflicting."
233
+
234
+ return True, None
235
+
236
+ def is_valid_information_in_current_city(question, tested_data):
237
+
238
+ for i in range(min(question['days'],len(tested_data))):
239
+ unit = tested_data[i]
240
+ current_city = unit['current_city']
241
+ final_city_list = []
242
+
243
+ if 'from' in current_city:
244
+ city1, city2 = extract_from_to(current_city)
245
+ city1 = extract_before_parenthesis(city1)
246
+ city2 = extract_before_parenthesis(city2)
247
+ final_city_list = [city1, city2]
248
+ else:
249
+ final_city_list = [extract_before_parenthesis(current_city)]
250
+
251
+ if 'transportation' in unit and unit['transportation'] and unit['transportation'] != '-':
252
+ for city in final_city_list:
253
+ if city not in unit['transportation']:
254
+ # print(city)
255
+ return False, f"The transportation in day {i+1} is invalid city choice."
256
+ # elif 'transportation' not in unit:
257
+ # return False, f"No Transportation Info."
258
+
259
+ if 'breakfast' in unit and unit['breakfast'] and unit['breakfast'] != '-':
260
+
261
+ flag = False
262
+
263
+ for city in final_city_list:
264
+ if city in unit['breakfast']:
265
+ flag = True
266
+
267
+ if not flag:
268
+ return False, f"The breakfast in day {i+1} is invalid city choice."
269
+ # elif 'breakfast' not in unit:
270
+ # return False, f"No Breakfast Info."
271
+
272
+ if 'lunch' in unit and unit['lunch'] and unit['lunch'] != '-':
273
+ flag = False
274
+
275
+ for city in final_city_list:
276
+ if city in unit['lunch']:
277
+ flag = True
278
+
279
+ if not flag:
280
+ return False, f"The lunch in day {i+1} is invalid city choice."
281
+ # elif 'lunch' not in unit:
282
+ # return False, f"No Lunch Info."
283
+
284
+ if 'dinner' in unit and unit['dinner'] and unit['dinner'] != '-':
285
+ flag = False
286
+
287
+ for city in final_city_list:
288
+ if city in unit['dinner']:
289
+ flag = True
290
+
291
+ if not flag:
292
+ return False, f"The dinner in day {i+1} is invalid city choice."
293
+ # elif 'dinner' not in unit:
294
+ # return False, f"No Dinner Info."
295
+
296
+ if 'attraction' in unit and unit['attraction'] and unit['attraction'] != '-':
297
+
298
+ attraction_list = unit['attraction'].split(';')[:-1]
299
+
300
+ for attraction in attraction_list:
301
+ flag = False
302
+ for city in final_city_list:
303
+ if city in attraction:
304
+ flag = True
305
+ if not flag:
306
+ return False, f"The attraction in day {i+1} is invalid city choice."
307
+
308
+ # elif 'attraction' not in unit:
309
+ # return False, f"No Attraction Info."
310
+
311
+
312
+ if 'accommodation' in unit and unit['accommodation'] and unit['accommodation'] != '-':
313
+
314
+ if final_city_list[-1] not in unit['accommodation']:
315
+ return False, f"The accommodation in day {i+1} is invalid city choice."
316
+
317
+ # elif 'accommodation' not in unit:
318
+ # return False, f"No Accommodation Info."
319
+
320
+ return True, None
321
+
322
+ # hallucination
323
+ def is_valid_information_in_sandbox(question, tested_data):
324
+
325
+ for i in range(min(question['days'],len(tested_data))):
326
+ unit = tested_data[i]
327
+
328
+ if unit['transportation'] and unit['transportation'] != '-':
329
+ value = unit['transportation']
330
+ org_city, dest_city = extract_from_to(value)
331
+ if org_city == None or dest_city == None:
332
+ org_city, dest_city = extract_from_to(unit['current_city'])
333
+ if 'flight number' in value.lower():
334
+ try:
335
+ org_city = extract_before_parenthesis(org_city)
336
+ dest_city = extract_before_parenthesis(dest_city)
337
+ except TypeError:
338
+ raise ValueError("The transportation {} in day {} can not be parsed.".format(value,i+1))
339
+ # print(value)
340
+ if len(flight.data[(flight.data['Flight Number'] == value.split('Flight Number: ')[1].split(',')[0]) & (flight.data['OriginCityName']==org_city) & (flight.data['DestCityName']==dest_city)]) < 1:
341
+ return False, f"The flight number in day {i+1} is invalid in the sandbox."
342
+
343
+ elif 'self-driving' in value.lower() or 'taxi' in value.lower():
344
+ try:
345
+ org_city = extract_before_parenthesis(org_city)
346
+ dest_city = extract_before_parenthesis(dest_city)
347
+ except TypeError:
348
+ org_city = '-'
349
+ dest_city = '-'
350
+ print("The transportation {} in day {} can not be parsed and '-' will be used instead.".format(value,i+1))
351
+
352
+ if 'self-driving' in value.lower():
353
+ if googleDistanceMatrix.run_for_evaluation(org_city, dest_city, mode='self-driving')['cost'] == None:
354
+ return False, f"The self-driving in day {i+1} is invalid in the sandbox."
355
+ else:
356
+ if googleDistanceMatrix.run_for_evaluation(org_city, dest_city, mode='taxi')['cost'] == None:
357
+ return False, f"The taxi in day {i+1} is invalid in the sandbox."
358
+
359
+ if 'breakfast' in unit and unit['breakfast'] and unit['breakfast'] != '-':
360
+ name, city = get_valid_name_city(unit['breakfast'])
361
+ if len(restaurants.data[(restaurants.data['Name'].astype(str).str.contains(re.escape(name))) & (restaurants.data['City'] == city)]) < 1:
362
+ return False, f"The breakfast in day {i+1} is invalid in the sandbox."
363
+ # elif 'breakfast' not in unit:
364
+ # return False, f"No Breakfast Info."
365
+
366
+ if 'lunch' in unit and unit['lunch'] and unit['lunch'] != '-':
367
+ name, city = get_valid_name_city(unit['lunch'])
368
+ if len(restaurants.data[(restaurants.data['Name'].astype(str).str.contains(re.escape(name))) & (restaurants.data['City'] == city)]) < 1:
369
+ return False, f"The lunch in day {i+1} is invalid in the sandbox."
370
+ # elif 'lunch' not in unit:
371
+ # return False, f"No Lunch Info."
372
+
373
+ if 'dinner' in unit and unit['dinner'] and unit['dinner'] != '-':
374
+ name, city = get_valid_name_city(unit['dinner'])
375
+ if len(restaurants.data[(restaurants.data['Name'].astype(str).str.contains(re.escape(name))) & (restaurants.data['City'] == city)]) < 1:
376
+ return False, f"The dinner in day {i+1} is invalid in the sandbox."
377
+ # elif 'dinner' not in unit:
378
+ # return False, f"No Dinner Info."
379
+
380
+ if 'attraction' in unit and unit['attraction'] and unit['attraction'] != '-':
381
+ attractions_list = unit['attraction'].split(';')[:-1]
382
+ for attraction in attractions_list:
383
+ name, city = get_valid_name_city(attraction)
384
+ if len(attractions.data[(attractions.data['Name'].astype(str).str.contains(re.escape(name))) & (attractions.data['City'] == city)]) < 1:
385
+ return False, f"The attraction {attraction} in day {i+1} is invalid in the sandbox."
386
+ # elif 'attraction' not in unit:
387
+ # return False, f"No Attraction Info."
388
+
389
+ if 'accommodation' in unit and unit['accommodation'] and unit['accommodation'] != '-':
390
+ name, city = get_valid_name_city(unit['accommodation'])
391
+ # print(name,city)
392
+ # print(accommodation.data[accommodation.data['NAME'].astype(str).str.contains(re.escape(name))])
393
+ if len(accommodation.data[(accommodation.data['NAME'].astype(str).str.contains(re.escape(name))) & (accommodation.data['city'] == city)]) < 1:
394
+ return False, f"The accommodation in day {i+1} is invalid in the sandbox."
395
+ # elif 'accommodation' not in unit:
396
+ # return False, f"No Accommodation Info."
397
+
398
+ return True, None
399
+
400
+
401
+ def is_valid_accommodaton(question, tested_data):
402
+ data = []
403
+ for i in range(min(question['days'],len(tested_data))):
404
+ unit = tested_data[i]
405
+
406
+ if 'accommodation' not in unit:
407
+ return False, f"No Accommodation Info."
408
+
409
+ data.append(unit['accommodation'])
410
+ # data = [unit['accommodation'] for unit in tested_data]
411
+ consectutive_accommodation = count_consecutive_values(data)
412
+ for unit in consectutive_accommodation:
413
+ # print(unit)
414
+ if unit and unit[0] not in ['-',''] :
415
+ name, city = get_valid_name_city(unit[0])
416
+ # print(unit[0],name,city)
417
+ # try:
418
+ if len(accommodation.data[(accommodation.data['NAME'].astype(str).str.contains(re.escape(name))) & (accommodation.data['city'] == city)]) == 1 and unit[1] < accommodation.data[(accommodation.data['NAME'].astype(str).str.contains(re.escape(name))) & (accommodation.data['city'] == city)].iloc[0]['minimum nights']:
419
+ return False, f"The accommodation {unit[0]} does not obey the minimum nights rule."
420
+ # can not parse data
421
+ # except re.error:
422
+ # continue
423
+
424
+ return True, None
425
+
426
+ def is_valid_visiting_city_number(question, tested_data):
427
+
428
+ city_set = set()
429
+
430
+
431
+ for i in range(min(question['days'],len(tested_data))):
432
+ city_value = tested_data[i]['current_city']
433
+
434
+ if 'from' in city_value:
435
+ city1, city2 = extract_from_to(city_value)
436
+ city1 = extract_before_parenthesis(city1)
437
+ city2 = extract_before_parenthesis(city2)
438
+ if i==0 and city1 != question['org']:
439
+ return False, f"The first day's city should be {question['org']}."
440
+
441
+ city_set.add(city1)
442
+ city_set.add(city2)
443
+
444
+ else:
445
+ city_set.add(extract_before_parenthesis(city_value))
446
+
447
+ city_set.discard(question['org'])
448
+
449
+ if len(city_set) != question['visiting_city_number']:
450
+ return False, f"The number of visiting cities should be {question['visiting_city_number']}."
451
+
452
+ return True, None
453
+
454
+ def is_valid_days(question, tested_data):
455
+ lens = 0
456
+ for i in range(min(question['days'],len(tested_data))):
457
+ if tested_data[i] != {} and tested_data[i]['current_city'] != "You don't need to fill in the information for this or later days.":
458
+ lens += 1
459
+
460
+ if lens != question['days']:
461
+ # print(lens)
462
+ return False, f"The number of days should be {question['days']}."
463
+ else:
464
+ return True, None
465
+
466
+ def is_not_absent(question, tested_data):
467
+ needed_info = 6 * question['days']
468
+ total_valid_info = 0
469
+
470
+ if not is_valid_days(question, tested_data)[0]:
471
+ return False, "Invalid Days"
472
+
473
+ if not is_valid_visiting_city_number(question, tested_data)[0]:
474
+ return False, "Invalid City Number"
475
+
476
+ for i in range(min(question['days'],len(tested_data))):
477
+ unit = tested_data[i]
478
+
479
+ if 'transportation' not in unit:
480
+ return False, f"No Transportation Info."
481
+
482
+ if 'breakfast' not in unit:
483
+ return False, f"No Breakfast Info."
484
+
485
+ if 'lunch' not in unit:
486
+ return False, f"No Lunch Info."
487
+
488
+ if 'dinner' not in unit:
489
+ return False, f"No Dinner Info."
490
+
491
+ if 'attraction' not in unit:
492
+ return False, f"No Attraction Info."
493
+
494
+ if 'accommodation' not in unit:
495
+ return False, f"No Accommodation Info."
496
+
497
+ if ('from ' in unit['current_city'] or 'to ' in unit['current_city']) and unit['transportation'] in ['','-']:
498
+ return False, f"No transportation in day {i+1} is not allowed."
499
+
500
+ if ('from ' not in unit['current_city'] and ' to ' not in unit['current_city']) and unit['attraction'] in ['','-']:
501
+ return False, f"No attraction in day {i+1} is not allowed."
502
+
503
+ if i != question['days'] - 1 and unit['accommodation'] in ['','-']:
504
+ return False, f"No accommodation in day {i+1} is not allowed."
505
+
506
+ if (unit['breakfast'] in ['','-'] or unit['lunch'] in ['','-'] or unit['dinner'] in ['','-']) and 'from ' not in unit['current_city']:
507
+ return False, f"No meal in day {i+1} is not allowed."
508
+
509
+
510
+ for key in unit:
511
+ if unit[key] and unit[key] != '-':
512
+ total_valid_info += 1
513
+
514
+
515
+ if total_valid_info * 1.0 / needed_info < 0.5:
516
+ return False, f"The absent information is more than 50%."
517
+
518
+ return True, None
519
+
520
+
521
+ def evaluation(query_data, tested_data):
522
+ return_info = {}
523
+ return_info['is_reasonalbe_visiting_city'] = is_reasonalbe_visiting_city(query_data, tested_data)
524
+ return_info['is_valid_restaurants'] = is_valid_restaurants(query_data, tested_data)
525
+ return_info['is_valid_attractions'] = is_valid_attractions(query_data, tested_data)
526
+ return_info['is_valid_accommodation'] = is_valid_accommodaton(query_data, tested_data)
527
+ return_info['is_valid_transportation'] = is_valid_transportation(query_data, tested_data)
528
+ return_info['is_valid_information_in_current_city'] = is_valid_information_in_current_city(query_data, tested_data)
529
+ return_info['is_valid_information_in_sandbox'] = is_valid_information_in_sandbox(query_data, tested_data)
530
+ return_info['is_not_absent'] = is_not_absent(query_data, tested_data)
531
+ return return_info
532
+
533
+ def boolean_evaluation(query_data, tested_data):
534
+ return_info = {}
535
+ return_info['is_reasonalbe_visiting_city'] = is_reasonalbe_visiting_city(query_data, tested_data)
536
+ return_info['is_valid_restaurants'] = is_valid_restaurants(query_data, tested_data)
537
+ return_info['is_valid_accommodation'] = is_valid_accommodaton(query_data, tested_data)
538
+ return_info['is_valid_attractions'] = is_valid_attractions(query_data, tested_data)
539
+ return_info['is_valid_transportation'] = is_valid_transportation(query_data, tested_data)
540
+ return_info['is_valid_information_in_current_city'] = is_valid_information_in_current_city(query_data, tested_data)
541
+ return_info['is_valid_information_in_sandbox'] = is_valid_information_in_sandbox(query_data, tested_data)
542
+ return_info['is_not_absent'] = is_not_absent(query_data, tested_data)
543
+ for key in return_info:
544
+ if return_info[key][0] == False:
545
+ print(return_info[key][1])
546
+ return False
547
+ return True
548
+
549
+ # if __name__ == '__main__':
550
+ # number_list = extract_numbers_from_filenames('/home/xj/toolAugEnv/code/toolConstraint/data/annotation/lrz')
551
+ # # json_data = json.load(open('/home/xj/toolAugEnv/code/toolConstraint/data/annotation/x/annotation_4.json'))
552
+ # query_data = load_line_json_data('/home/xj/toolAugEnv/code/toolConstraint/data/query/lrz.jsonl')
553
+ # for idx in number_list:
554
+ # json_data = json.load(open(f'/home/xj/toolAugEnv/code/toolConstraint/data/annotation/lrz/annotation_{idx}.json'))
555
+ # print(str(idx), evaluation(query_data[idx-1], json_data))
556
+ # # json_data = json.load(open(f'/home/xj/toolAugEnv/code/toolConstraint/results/turbo16k-turbo16k/plan_{idx}.json'))
557
+ # # query_data = load_line_json_data('/home/xj/toolAugEnv/code/toolConstraint/data/query/test.jsonl')[idx-1]
558
+ # # help me write all function name in this file, just the name
559
+ # #
560
+ # # list all function name in this file
561
+ # # ['is_reasonalbe_visiting_city', 'is_valiable_restaurants', 'is_valiable_attractions', 'is_valiable_transportation', 'is_valid_information_in_current_city', 'is_valid_information_in_sandbox']
562
+ # # print(is_valiable_restaurants(query_data, json_data))
563
+
564
+ # if __name__ == "__main__":
565
+ # user = 'zk'
566
+ # query_data_list = load_line_json_data(f'/home/xj/toolAugEnv/code/toolConstraint/data/query/{user}.jsonl')
567
+ # idx_number_list = extract_numbers_from_filenames(f'/home/xj/toolAugEnv/code/toolConstraint/data/annotation/{user}')
568
+ # commonsense_statistic= {level:{day:[] for day in [3,5,7]} for level in ['easy','medium','hard']}
569
+ # for idx in idx_number_list:
570
+ # print(idx)
571
+ # query_data = query_data_list[idx-1]
572
+ # generated_plan = json.load(open(f'/home/xj/toolAugEnv/code/toolConstraint/results/turbo16k-turbo16k/{user}/plan_{idx}.json'))
573
+ # # generated_plan = generated_plan[:-1]
574
+ # if generated_plan[-1]['gpt-3.5-turbo-16k-result'] != 'Plan Fail':
575
+ # info_box = evaluation(query_data, generated_plan[-1]['gpt-3.5-turbo-16k-result'])
576
+ # generated_plan[-1]['toolAug-commonsense'] = info_box
577
+ # else:
578
+ # generated_plan[-1]['toolAug-commonsense'] = None
579
+ # info_box = None
580
+ # commonsense_statistic[query_data['level']][query_data['days']].append(info_box)
581
+ # with open(f'/home/xj/toolAugEnv/code/toolConstraint/results/turbo16k-turbo16k/{user}/plan_{idx}.json','w') as f:
582
+ # json.dump(generated_plan,f)
583
+
584
+ # with open(f'/home/xj/toolAugEnv/code/toolConstraint/results/turbo16k-turbo16k/{user}/commonsense_statistic.json','w') as f:
585
+ # json.dump(commonsense_statistic,f)
586
+
587
+ # if __name__ == "__main__":
588
+ # user = 'all'
589
+ # model_type = ['chatgpt','gpt4','greedy_search'][2]
590
+ # query_data_list = load_line_json_data(f'/home/xj/toolAugEnv/code/toolConstraint/data/query/{user}.jsonl')
591
+ # # idx_number_list = extract_numbers_from_filenames(f'/home/xj/toolAugEnv/code/toolConstraint/data/annotation/{user}')
592
+ # idx_number_list = [i for i in range(1,501)]
593
+ # commonsense_statistic= {level:{day:[] for day in [3,5,7]} for level in ['easy','medium','hard']}
594
+
595
+ # for idx in idx_number_list:
596
+ # print(idx)
597
+ # query_data = query_data_list[idx-1]
598
+ # generated_plan = json.load(open(f'/home/xj/toolAugEnv/code/toolConstraint/results/pre2/{user}/plan_{idx}.json'))
599
+ # # generated_plan = generated_plan[:-1]
600
+ # if model_type == 'greedy_search':
601
+ # info_box = evaluation(query_data, generated_plan[-1][f'greedy_search_plan'])
602
+ # else:
603
+ # info_box = evaluation(query_data, generated_plan[-1][f'{model_type}_human_collected_info_results_parsed'])
604
+ # generated_plan[-1][f'{model_type}_with_human_collected_commonsense'] = info_box
605
+ # commonsense_statistic[query_data['level']][query_data['days']].append(info_box)
606
+
607
+ # with open(f'/home/xj/toolAugEnv/code/toolConstraint/results/pre2/{user}/plan_{idx}.json','w') as f:
608
+ # json.dump(generated_plan,f)
609
+
610
+ # with open(f'/home/xj/toolAugEnv/code/toolConstraint/results/pre2/{user}/{model_type}_with_human_collected_commonsense_statistic.json','w') as f:
611
+ # json.dump(commonsense_statistic,f)
612
+
613
+
614
+ # if __name__ == "__main__":
615
+ # user = 'all'
616
+ # query_data_list = load_line_json_data(f'/home/xj/toolAugEnv/code/toolConstraint/data/query/{user}.jsonl')
617
+ # idx_number_list = extract_numbers_from_filenames(f'/home/xj/toolAugEnv/code/toolConstraint/data/annotation/{user}')
618
+ # hardConstraint_statistic= {level:{day:[] for day in [3,5,7]} for level in ['easy','medium','hard']}
619
+ # not_satified = []
620
+ # for idx in tqdm(idx_number_list):
621
+ # # print(idx)
622
+ # query_data = query_data_list[idx-1]
623
+ # generated_plan = json.load(open(f'/home/xj/toolAugEnv/code/toolConstraint/data/annotation/{user}/annotation_{idx}.json'))
624
+
625
+ # if not boolean_evaluation(query_data, generated_plan):
626
+ # not_satified.append(idx)
627
+ # print(idx)
628
+ # generated_plan = generated_plan[:-1]
629
+ # print(not_satified)
630
+
631
+ if __name__ == "__main__":
632
+ set_type = ["train",'dev','test'][0]
633
+ query_data_list = load_line_json_data(f'/home/xj/toolAugEnv/code/toolConstraint/data/final_data/{set_type}/query/query.jsonl')
634
+ # idx_number_list = extract_numbers_from_filenames(f'/home/xj/toolAugEnv/code/toolConstraint/data/final_data/{set_type}/plan')
635
+ commonsense_statistic= {level:{day:[] for day in [3,5,7]} for level in ['easy','medium','hard']}
636
+ not_satified = []
637
+ # print( idx_number_list)
638
+ for idx in tqdm(range(1,len(query_data_list)+1)):
639
+ # print(idx)
640
+ query_data = query_data_list[idx-1]
641
+ generated_plan = json.load(open(f'/home/xj/toolAugEnv/code/toolConstraint/data/final_data/{set_type}/plan/plan_{idx}.json'))
642
+ try:
643
+ store_plan = json.load(open(f'/home/xj/toolAugEnv/code/toolConstraint/results/{set_type}/plan_{idx}.json'))
644
+ except FileNotFoundError:
645
+ store_plan = [{}]
646
+ info_box = evaluation(query_data,generated_plan[1])
647
+ # if not boolean_evaluation(query_data, generated_plan[1]):
648
+ # not_satified.append(idx)
649
+ # print(idx)
650
+ # print(store_plan[-1])
651
+ store_plan[-1][f'human_anno_commonsense_constraint'] = info_box
652
+ with open(f'/home/xj/toolAugEnv/code/toolConstraint/results/{set_type}/plan_{idx}.json','w') as f:
653
+ json.dump(store_plan,f)
654
+ commonsense_statistic[query_data['level']][query_data['days']].append(info_box)
655
+ print(not_satified)
656
+ with open(f'/home/xj/toolAugEnv/code/toolConstraint/results/{set_type}/human_anno_commonsense_constraint.json','w') as f:
657
+ json.dump(commonsense_statistic,f)
658
+
659
+ # if __name__ == "__main__":
660
+ # user = 'all'
661
+ # model_type = ['chatgpt','gpt4'][1]
662
+ # query_data_list = load_line_json_data(f'/home/xj/toolAugEnv/code/toolConstraint/data/query/{user}.jsonl')
663
+ # # idx_number_list = extract_numbers_from_filenames(f'/home/xj/toolAugEnv/code/toolConstraint/data/annotation/{user}')
664
+ # idx_number_list = [i for i in range(1,501)]
665
+ # commonsense_statistic= {level:{day:[] for day in [3,5,7]} for level in ['easy','medium','hard']}
666
+ # cnt = 0
667
+ # for idx in idx_number_list:
668
+ # # print(idx)
669
+ # query_data = query_data_list[idx-1]
670
+ # generated_plan = json.load(open(f'/home/xj/toolAugEnv/code/toolConstraint/results/pre/{user}/plan_{idx}.json'))[-1]['gpt4_human_collected_info_results_parsed']
671
+ # # generated_plan = generated_plan[:-1]
672
+
673
+ # if not boolean_evaluation(query_data, generated_plan):
674
+ # cnt += 1
675
+ # print(idx)
676
+ # print(cnt)
677
+
678
+ # if __name__ == "__main__":
679
+ # parser = argparse.ArgumentParser(description="")
680
+ # # model_type = ['gpt-3.5-turbo-1106','gpt-4-1106-preview','greedy_search','mistral-7B-32K','gemini2','mixtral','gpt-3.5-turbo-11062'][-1]
681
+ # # method = ['direct','cot','react','reflexion','tool-use'][-1]
682
+ # # set_type = ['dev','test'][0]
683
+ # parser.add_argument("--model_type", type=str, default="gpt-3.5-turbo-1106")
684
+ # parser.add_argument("--method", type=str, default="direct")
685
+ # parser.add_argument("--set_type", type=str, default="dev")
686
+ # args = parser.parse_args()
687
+ # directory = f'/home/xj/toolAugEnv/code/toolConstraint/data/final_data/{args.set_type}'
688
+ # query_data_list = load_line_json_data(os.path.join(directory, 'query/query.jsonl'))
689
+ # # idx_number_list = extract_numbers_from_filenames(f'/home/xj/toolAugEnv/code/toolConstraint/data/annotation/{user}')
690
+ # idx_number_list = [i for i in range(1,len(query_data_list)+1)]
691
+ # commonsense_statistic= {level:{day:[] for day in [3,5,7]} for level in ['easy','medium','hard']}
692
+ # deliver_cnt = 0
693
+ # if args.method == 'tool-use':
694
+ # suffix = ''
695
+ # else:
696
+ # suffix = '_with_human_info'
697
+ # for idx in tqdm(idx_number_list):
698
+ # # print(idx)
699
+ # query_data = query_data_list[idx-1]
700
+ # generated_plan = json.load(open(f'/home/xj/toolAugEnv/code/toolConstraint/results/{args.set_type}/plan_{idx}.json'))
701
+ # # generated_plan = generated_plan[:-1]
702
+ # if args.model_type == 'greedy_search':
703
+ # info_box = evaluation(query_data, generated_plan[-1][f'greedy_search_plan'])
704
+ # else:
705
+ # if args.method == 'tool-use':
706
+ # suffix2 = ''
707
+ # else:
708
+ # suffix2 = '_collected'
709
+ # if generated_plan[-1][f'{args.model_type}_{args.method}{suffix2}_info_results'] and generated_plan[-1][f'{args.model_type}_{args.method}{suffix2}_info_results']!='Max Token Length Exceeded.':
710
+ # try:
711
+ # info_box = evaluation(query_data, generated_plan[-1][f'{args.model_type}_{args.method}{suffix}_results_parsed'])
712
+ # except KeyError:
713
+ # info_box = None
714
+ # generated_plan[-1][f'{args.model_type}_{args.method}{suffix2}_info_results'] = ""
715
+ # except IndexError:
716
+ # info_box = None
717
+ # generated_plan[-1][f'{args.model_type}_{args.method}{suffix2}_info_results'] = ""
718
+ # else:
719
+ # info_box = None
720
+ # if info_box:
721
+ # deliver_cnt += 1
722
+ # generated_plan[-1][f'{args.model_type}_{args.method}{suffix}_commonsense_constraint'] = info_box
723
+ # commonsense_statistic[query_data['level']][query_data['days']].append(info_box)
724
+
725
+ # with open(f'/home/xj/toolAugEnv/code/toolConstraint/results/{args.set_type}/plan_{idx}.json','w') as f:
726
+ # json.dump(generated_plan,f)
727
+
728
+ # with open(f'/home/xj/toolAugEnv/code/toolConstraint/results/{args.set_type}/{args.model_type}_{args.method}{suffix}_commonsense_constraint.json','w') as f:
729
+ # json.dump(commonsense_statistic,f)
730
+
731
+ # if args.set_type == 'dev':
732
+ # print(f"Model:{args.model_type} Method:{args.method} Set: {args.set_type} \nDeliver Rate: {deliver_cnt/180}" )
733
+ # elif args.set_type == 'test':
734
+ # print(f"Model:{args.model_type} Method:{args.method} Set: {args.set_type} \nDeliver Rate: {deliver_cnt/1000}" )
735
+
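(For reference: every checker in commonsenseConstraint.py indexes the same per-day dictionary keys, so the expected input shape is worth spelling out. Below is a sketch of one query/plan pair in that shape; the field names and formats follow the checks above, while the concrete cities, venues, and flight number are invented, and the "Name, City" convention for venues is inferred from get_valid_name_city.)

    # Illustrative shapes for evaluation(query_data, tested_data); values are invented.
    query_data = {
        "org": "City A", "dest": "Some State", "days": 3,
        "visiting_city_number": 1, "people_number": 1, "level": "easy",
    }
    tested_data = [
        {
            "current_city": "from City A to City B",
            "transportation": "Flight Number: F0000001, from City A to City B",
            "breakfast": "-",                        # '-' marks an empty slot
            "lunch": "Restaurant X, City B",
            "dinner": "Restaurant Y, City B",
            "attraction": "Attraction Z, City B;",   # ';'-separated, ';'-terminated
            "accommodation": "Hotel W, City B",
        },
        # ... one dict per day; the last day's current_city returns to the origin,
        # since is_reasonalbe_visiting_city requires a closed circle.
    ]
    # Each checker returns (True, None) or (False, "reason"); evaluation() returns
    # a dict mapping checker names to those tuples.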
eval.py ADDED
@@ -0,0 +1,181 @@
1
+ from commonsenseConstraint import evaluation as commonsense_eval
2
+ from hardConstraint import evaluation as hard_eval
3
+ import json
4
+ from tqdm import tqdm
5
+ from datasets import load_dataset
6
+
7
+
8
+ def load_line_json_data(filename):
9
+ data = []
10
+ with open(filename, 'r', encoding='utf-8') as f:
11
+ for line in f.read().strip().split('\n'):
12
+ unit = json.loads(line)
13
+ data.append(unit)
14
+ return data
15
+
16
+ def count_true_false(data):
17
+ """Count the number of true and false values in a list."""
18
+ true_count = data.count(True)
19
+ false_count = data.count(False)
20
+ return true_count, false_count
21
+
22
+ def statistics(commonsense_statistic):
23
+ """Generate statistics for each level and day in the given data with a different structure."""
24
+ result = {level: {day: {} for day in commonsense_statistic[level]} for level in commonsense_statistic}
25
+
26
+ for level, days in commonsense_statistic.items():
27
+ for day, dicts in days.items():
28
+ for dct in dicts:
29
+ if dct:
30
+ for key, data in dct.items():
31
+ true_count, false_count = count_true_false(data)
32
+ if key not in result[level][day]:
33
+ result[level][day][key] = {"true": 0, "false": 0}
34
+ result[level][day][key]["true"] += true_count
35
+ result[level][day][key]["false"] += false_count
36
+
37
+ return result
38
+
39
+
40
+ def eval_score(validation_or_test: str, file_path: str, TOKEN):
41
+
42
+ if validation_or_test == 'validation':
43
+ query_data_list = load_dataset('osunlp/TravelBenchEval','validation',token=TOKEN)['validation']
44
+ elif validation_or_test == 'test':
45
+ query_data_list = load_dataset('osunlp/TravelBenchEval','test',token=TOKEN)['test']
46
+
47
+ query_data_list = [x for x in query_data_list]
48
+ hardConstraint_statistic= {level:{day:[] for day in [3,5,7]} for level in ['easy','medium','hard']}
49
+ commonsenseConstraint_statistic = {level:{day:[] for day in [3,5,7]} for level in ['easy','medium','hard']}
50
+ tested_plans = load_line_json_data(file_path)
51
+ delivery_cnt = 0
52
+ plan_constraint_store = []
53
+ for idx in tqdm(range(0,len(query_data_list))):
54
+ query_data = query_data_list[idx]
55
+ tested_plan = tested_plans[idx]
56
+ if type(query_data) == str:
57
+ query_data = eval(query_data)
58
+ if type(tested_plan) == str:
59
+ tested_plan = eval(tested_plan)
60
+ if type(query_data['local_constraint']) == str:
61
+ query_data['local_constraint'] = eval(query_data['local_constraint'])
62
+
63
+ if tested_plan['plan']:
64
+ delivery_cnt += 1
65
+ commonsense_info_box = commonsense_eval(query_data,tested_plan['plan'])
66
+ else:
67
+ commonsense_info_box = None
68
+
69
+ if commonsense_info_box and commonsense_info_box['is_not_absent'][0] and commonsense_info_box['is_valid_information_in_sandbox'][0]:
70
+ hard_info_box = hard_eval(query_data,tested_plan['plan'])
71
+ else:
72
+ hard_info_box = None
73
+
74
+ plan_constraint_store.append({'commonsense_constraint':commonsense_info_box,'hard_constraint':hard_info_box})
75
+
76
+ commonsenseConstraint_statistic[query_data['level']][query_data['days']].append(commonsense_info_box)
77
+ hardConstraint_statistic[query_data['level']][query_data['days']].append(hard_info_box)
78
+
79
+ commonsenseConstraint_statistic_processed = statistics(commonsenseConstraint_statistic)
80
+ hardConstraint_statistic_processed = statistics(hardConstraint_statistic)
81
+ # print(commonsenseConstraint_statistic_processed)
82
+ # print(hardConstraint_statistic_processed)
83
+ constraint_record = {key: {day: {'house rule':0, 'cuisine':0, 'room type':0, 'transportation':0} for day in [3,5,7]} for key in ['medium','hard']}
84
+ constraint_mapping = {'house rule':'valid_room_rule','cuisine':'valid_cuisine','room type':'valid_room_type','transportation':'valid_transportation'}
85
+ mapping_constraint_record = {key: {day: {'valid_room_rule':0, 'valid_cuisine':0, 'valid_room_type':0, 'valid_transportation':0} for day in [3,5,7]} for key in ['medium','hard']}
86
+ count_record = {key:{day:0 for day in [3,5,7]} for key in ['easy','medium','hard']}
87
+
88
+ for unit in query_data_list:
89
+ count_record[unit['level']][unit['days']] += 1
90
+ for key in constraint_record['medium'][3]:
91
+ if unit['local_constraint'][key] != None:
92
+ constraint_record[unit['level']][unit['days']][key] += 1
93
+ mapping_constraint_record[unit['level']][unit['days']][constraint_mapping[key]] += 1
94
+
95
+ data_record = {key:{day:[] for day in [3,5,7]} for key in ['easy','medium','hard']}
96
+
97
+ constraint_dis_record = {"commonsense":{"pass":0,"total":0},"hard":{"pass":0,"total":0}}
98
+
99
+ for constraint in ['commonsense','hard']:
100
+ if constraint == 'commonsense':
101
+ constraint_statistic = commonsenseConstraint_statistic_processed
102
+ elif constraint == 'hard':
103
+ constraint_statistic = hardConstraint_statistic_processed
104
+
105
+ key_dict = {'commonsense':['is_valid_information_in_current_city','is_valid_information_in_sandbox','is_reasonalbe_visiting_city','is_valid_restaurants','is_valid_transportation','is_valid_attractions','is_valid_accommodation','is_not_absent'],'hard':['valid_cost','valid_room_rule','valid_cuisine','valid_room_type','valid_transportation']}
106
+
107
+ for key in constraint_statistic:
108
+ # level
109
+ for key2 in constraint_statistic[key]:
110
+ # day
111
+ # print(key2)
112
+ # key2 = eval(key2)
113
+ if key2 == -1:
114
+ print(constraint_statistic[key])
115
+ exit(0)
116
+ for key3 in key_dict[constraint]:
117
+ data_record[key][key2].append('0/0')
118
+ if key3 in constraint_statistic[key][key2]:
119
+ constraint_dis_record[constraint]['pass'] += constraint_statistic[key][key2][key3]['true']
120
+ if constraint == 'hard':
121
+ if key == 'hard' and key3 in ['valid_room_rule','valid_cuisine','valid_room_type','valid_transportation']:
122
+ data_record[key][key2][-1] = f"{constraint_statistic[key][key2][key3]['true']}/{mapping_constraint_record[key][key2][key3]}"
123
+ constraint_dis_record[constraint]['total'] += mapping_constraint_record[key][key2][key3]
124
+ elif key == 'medium' and key3 in ['valid_room_rule','valid_cuisine','valid_room_type']:
125
+ data_record[key][key2][-1] = f"{constraint_statistic[key][key2][key3]['true']}/{mapping_constraint_record[key][key2][key3]}"
126
+ constraint_dis_record[constraint]['total'] += mapping_constraint_record[key][key2][key3]
127
+ else:
128
+ data_record[key][key2][-1] = f"{constraint_statistic[key][key2][key3]['true']}/{count_record[key][key2]}"
129
+ if key3 in ['valid_cost','valid_visitng_city_number','valid_days']:
130
+ constraint_dis_record[constraint]['total'] += count_record[key][key2]
131
+ else:
132
+ data_record[key][key2][-1] = f"{constraint_statistic[key][key2][key3]['true']}/{count_record[key][key2]}"
133
+ constraint_dis_record[constraint]['total'] += count_record[key][key2]
134
+
135
+ final_all_cnt = 0
136
+ final_commonsense_cnt = 0
137
+ final_hardConstraint_cnt = 0
138
+ final_all_cnt_map = {level:0 for level in ['easy','medium','hard']}
139
+ for idx in (range(0,len(query_data_list))):
140
+ if plan_constraint_store[idx]['commonsense_constraint']:
141
+ final_commonsense_pass = True
142
+ final_hardConstraint_pass = True
143
+ for item in plan_constraint_store[idx]['commonsense_constraint']:
144
+ if plan_constraint_store[idx]['commonsense_constraint'][item][0] is not None and not plan_constraint_store[idx]['commonsense_constraint'][item][0]:
145
+ final_commonsense_pass = False
146
+ break
147
+ if plan_constraint_store[idx]['hard_constraint'] is None:
148
+ continue
149
+ for item in plan_constraint_store[idx]['hard_constraint']:
150
+ if plan_constraint_store[idx]['hard_constraint'][item][0] is not None and plan_constraint_store[idx]['hard_constraint'][item][0] == False:
151
+ final_hardConstraint_pass = False
152
+ break
153
+
154
+ if final_commonsense_pass:
155
+ final_commonsense_cnt += 1
156
+ if final_hardConstraint_pass:
157
+ final_hardConstraint_cnt += 1
158
+ if final_commonsense_pass and final_hardConstraint_pass:
159
+ final_all_cnt += 1
160
+ final_all_cnt_map[query_data_list[idx]['level']] += 1
161
+
162
+ result = {}
163
+
164
+ if validation_or_test == 'validation':
165
+ result['Delivery Rate'] = delivery_cnt / 180
166
+ result['Commonsense Constraint Micro Pass Rate'] = constraint_dis_record['commonsense']['pass'] / 1440
167
+ result['Commonsense Constraint Macro Pass Rate'] = final_commonsense_cnt / 180
168
+ result['Hard Constraint Micro Pass Rate'] = constraint_dis_record['hard']['pass'] / 420
169
+ result['Hard Constraint Macro Pass Rate'] = final_hardConstraint_cnt / 180
170
+ result['Final Pass Rate'] = final_all_cnt / 180
171
+
172
+ elif validation_or_test == 'test':
173
+ result['Delivery Rate'] = delivery_cnt / 1000
174
+ result['Commonsense Constraint Micro Pass Rate'] = constraint_dis_record['commonsense']['pass'] / 8000
175
+ result['Commonsense Constraint Macro Pass Rate'] = final_commonsense_cnt / 1000
176
+ result['Hard Constraint Micro Pass Rate'] = constraint_dis_record['hard']['pass'] / 2290
177
+ result['Hard Constraint Macro Pass Rate'] = final_hardConstraint_cnt / 1000
178
+ result['Final Pass Rate'] = final_all_cnt / 1000
179
+
180
+ return result
181
+
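(For reference: eval_score loads the submission with load_line_json_data and indexes it in the same order as the osunlp/TravelBenchEval split, so a submission is a JSONL file with one record per query and a 'plan' field in each record. A minimal sketch of writing such a file follows; the file name and plan payload are placeholders, with the per-day shape described in the commonsenseConstraint.py section above.)

    import json

    # One JSON object per line, in the same order as the evaluation split.
    records = [
        {"plan": [{"current_city": "from City A to City B", "transportation": "...",
                   "breakfast": "-", "lunch": "-", "dinner": "-",
                   "attraction": "-", "accommodation": "-"}]},
        {"plan": None},  # a query with no delivered plan still needs its own line
    ]
    with open("submission.jsonl", "w", encoding="utf-8") as f:
        for record in records:
            f.write(json.dumps(record) + "\n")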
hardConstraint.py ADDED
@@ -0,0 +1,266 @@
1
+ from annotation.src.utils import get_valid_name_city,extract_before_parenthesis,extract_numbers_from_filenames
2
+ from tools.flights.apis import Flights
3
+ from tools.accommodations.apis import Accommodations
4
+ from tools.restaurants.apis import Restaurants
5
+ from tools.googleDistanceMatrix.apis import GoogleDistanceMatrix
6
+ from tools.attractions.apis import Attractions
7
+ import math
8
+ import json
9
+ import re
10
+ import numpy as np
11
+ import os
12
+ import sys
13
+ from tqdm import tqdm
14
+ import argparse
15
+
16
+ sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))
17
+ os.chdir(os.path.dirname(os.path.abspath(__file__)))
18
+
19
+
20
+ flight = Flights()
21
+ accommodation = Accommodations()
22
+ restaurants = Restaurants()
23
+ googleDistanceMatrix = GoogleDistanceMatrix()
24
+ attractions = Attractions()
25
+
26
+
27
+ def load_line_json_data(filename):
28
+ data = []
29
+ with open(filename, 'r', encoding='utf-8') as f:
30
+ for line in f.read().strip().split('\n'):
31
+ unit = json.loads(line)
32
+ data.append(unit)
33
+ return data
34
+
35
+
36
+ def convert_bool_values(item):
37
+ if isinstance(item, dict):
38
+ # If the item is a dictionary, recurse on each value
39
+ return {key: convert_bool_values(value) for key, value in item.items()}
40
+ elif isinstance(item, list):
41
+ # If the item is a list, recurse on each item in the list
42
+ return [convert_bool_values(value) for value in item]
43
+ elif isinstance(item, tuple):
44
+ # If the item is a tuple, recurse on each item in the tuple and repackage as a tuple
45
+ return tuple(convert_bool_values(value) for value in item)
46
+ elif isinstance(item, np.bool_): # Here we check for numpy's bool_ type
47
+ # If the item is a numpy bool_, convert it to a standard Python bool
48
+ return bool(item)
49
+ else:
50
+ # If the item is any other type, return it unchanged
51
+ return item
52
+
53
+
54
+
55
+
56
+ def extract_from_to(text: str):
57
+ """
58
+ Extracts 'A' and 'B' from the format "from A to B" in the given text, with B ending at a comma or the end of the string.
59
+
60
+ Args:
61
+ - text (str): The input string.
62
+
63
+ Returns:
64
+ - tuple: A tuple containing 'A' and 'B'. If no match is found, returns (None, None).
65
+ """
66
+ pattern = r"from\s+(.+?)\s+to\s+([^,]+)(?=[,\s]|$)"
67
+ matches = re.search(pattern, text)
68
+ return matches.groups() if matches else (None, None)
69
+
70
+
71
+ def get_total_cost(question, tested_data):
72
+ total_cost = 0
73
+ for i in range(min(question['days'],len(tested_data))):
74
+ unit = tested_data[i]
75
+ # transportation
76
+ if unit['transportation'] and unit['transportation'] != '-':
77
+ value = unit['transportation']
78
+ org_city, dest_city = extract_from_to(value)
79
+ if org_city == None or dest_city == None:
80
+ org_city, dest_city = extract_from_to(unit['current_city'])
81
+
82
+ if org_city == None or dest_city == None:
83
+ pass
84
+ else:
85
+ if 'flight number' in value.lower():
86
+ res = flight.data[flight.data['Flight Number'] == value.split('Flight Number: ')[1].split(',')[0]]
87
+ if len(res) > 0:
88
+ total_cost += res['Price'].values[0] * question['people_number']
89
+
90
+ elif 'self-driving' in value.lower() or 'taxi' in value.lower():
91
+ if 'self-driving' in value.lower():
92
+ # print(org_city,dest_city)
93
+ cost = googleDistanceMatrix.run_for_evaluation(org_city,dest_city,'self-driving')['cost']
94
+ total_cost += cost * math.ceil(question['people_number'] * 1.0 / 5)
95
+ else:
96
+ cost = googleDistanceMatrix.run_for_evaluation(org_city,dest_city,'taxi')['cost']
97
+ total_cost += cost * math.ceil(question['people_number'] * 1.0 / 4)
98
+
99
+ # breakfast
100
+ if unit['breakfast'] and unit['breakfast'] != '-':
101
+ name, city = get_valid_name_city(unit['breakfast'])
102
+ res = restaurants.data[(restaurants.data['Name'].astype(str).str.contains(re.escape(name))) & (restaurants.data['City'] == city)]
103
+ if len(res) > 0:
104
+ total_cost += res['Average Cost'].values[0] * question['people_number']
105
+
106
+
107
+ # lunch
108
+ if unit['lunch'] and unit['lunch'] != '-':
109
+ name, city = get_valid_name_city(unit['lunch'])
110
+ res = restaurants.data[(restaurants.data['Name'].astype(str).str.contains(re.escape(name))) & (restaurants.data['City'] == city)]
111
+ if len(res) > 0:
112
+ total_cost += res['Average Cost'].values[0] * question['people_number']
113
+
114
+ # dinner
115
+ if unit['dinner'] and unit['dinner'] != '-':
116
+ name, city = get_valid_name_city(unit['dinner'])
117
+ res = restaurants.data[(restaurants.data['Name'].astype(str).str.contains(re.escape(name))) & (restaurants.data['City'] == city)]
118
+ if len(res) > 0:
119
+ total_cost += res['Average Cost'].values[0] * question['people_number']
120
+
121
+ # accommodation
122
+ if unit['accommodation'] and unit['accommodation'] != '-':
123
+ name, city = get_valid_name_city(unit['accommodation'])
124
+ res = accommodation.data[(accommodation.data['NAME'].astype(str).str.contains(re.escape(name))) & (accommodation.data['city'] == city)]
125
+ if len(res) > 0:
126
+ total_cost += res['price'].values[0] * math.ceil(question['people_number'] * 1.0 / res['maximum occupancy'].values[0])
127
+ # print(total_cost)
128
+ return total_cost
129
+
130
+
131
+ def is_valid_room_rule(question, tested_data):
132
+
133
+ if question['local_constraint']['house rule'] is None:
134
+ return None,None
135
+
136
+ for i in range(min(question['days'],len(tested_data))):
137
+ unit = tested_data[i]
138
+ if unit['accommodation'] and unit['accommodation'] != '-':
139
+ name, city = get_valid_name_city(unit['accommodation'])
140
+ res = accommodation.data[(accommodation.data['NAME'].astype(str).str.contains(re.escape(name))) & (accommodation.data['city'] == city)]
141
+ if len(res) > 0:
142
+ if question['local_constraint']['house rule'] == 'smoking' and 'No smoking' in str(res['house_rules'].values[0]):
143
+ return False, f"The house rule should be {question['local_constraint']['house rule']}."
144
+ if question['local_constraint']['house rule'] == 'parities' and 'No parties' in str(res['house_rules'].values[0]):
145
+ return False, f"The house rule should be {question['local_constraint']['house rule']}."
146
+ if question['local_constraint']['house rule'] == 'children under 10' and 'No children under 10' in str(res['house_rules'].values[0]):
147
+ return False, f"The house rule should be {question['local_constraint']['house rule']}."
148
+ if question['local_constraint']['house rule'] == 'visitors' and 'No visitors' in str(res['house_rules'].values[0]):
149
+ return False, f"The house rule should be {question['local_constraint']['house rule']}."
150
+ if question['local_constraint']['house rule'] == 'pets' and 'No pets' in str(res['house_rules'].values[0]):
151
+ return False, f"The house rule should be {question['local_constraint']['house rule']}."
152
+
153
+
154
+ return True, None
155
+
156
+
157
+
158
+ def is_valid_cuisine(question, tested_data):
159
+ cuisine_set = set()
160
+ if question['local_constraint']['cuisine']:
161
+ for i in range(min(question['days'],len(tested_data))):
162
+ unit = tested_data[i]
163
+
164
+ if unit['breakfast'] and unit['breakfast'] != '-':
165
+ name, city = get_valid_name_city(unit['breakfast'])
166
+ if city == question['org']:
167
+ continue
168
+ res = restaurants.data[(restaurants.data['Name'].astype(str).str.contains(re.escape(name))) & (restaurants.data['City'] == city)]
169
+ if len(res) > 0:
170
+ for cuisine in question['local_constraint']['cuisine']:
171
+ if cuisine in res.iloc[0]['Cuisines']:
172
+ cuisine_set.add(cuisine)
173
+
174
+ if unit['lunch'] and unit['lunch'] != '-':
175
+ name, city = get_valid_name_city(unit['lunch'])
176
+ if city == question['org']:
177
+ continue
178
+ res = restaurants.data[(restaurants.data['Name'].astype(str).str.contains(re.escape(name))) & (restaurants.data['City'] == city)]
179
+ if len(res) > 0:
180
+ for cuisine in question['local_constraint']['cuisine']:
181
+ if cuisine in res.iloc[0]['Cuisines']:
182
+ cuisine_set.add(cuisine)
183
+
184
+ if unit['dinner'] and unit['dinner'] != '-':
185
+ name, city = get_valid_name_city(unit['dinner'])
186
+ if city == question['org']:
187
+ continue
188
+ res = restaurants.data[(restaurants.data['Name'].astype(str).str.contains(re.escape(name))) & (restaurants.data['City'] == city)]
189
+ if len(res) > 0:
190
+ for cuisine in question['local_constraint']['cuisine']:
191
+ if cuisine in res.iloc[0]['Cuisines']:
192
+ cuisine_set.add(cuisine)
193
+
194
+ if len(cuisine_set) == len(question['local_constraint']['cuisine']):
195
+ return True, None
196
+ else:
197
+ # judge which cuisine is not satisfied
198
+ for cuisine in question['local_constraint']['cuisine']:
199
+ if cuisine not in cuisine_set:
200
+ return False, f"The cuisine {cuisine} is not satisfied."
201
+ # return False, f"The cuisine should be {question['local_constraint']['cuisine']}."
202
+ else:
203
+ return None,None
204
+
205
+
206
+ def is_valid_transportation(question, tested_data):
207
+ if question['local_constraint']['transportation'] is None:
208
+ return None,None
209
+ for i in range(min(question['days'],len(tested_data))):
210
+ unit = tested_data[i]
211
+ if unit['transportation'] and unit['transportation'] != '-':
212
+ value = unit['transportation']
213
+ if question['local_constraint']['transportation'] == 'no flight' and 'Flight' in value:
214
+ return False, f"The transportation should not be {question['local_constraint']['transportation']}."
215
+ elif question['local_constraint']['transportation'] == 'no self-driving' and 'Self-driving' in value:
216
+ return False, f"The transportation should not be {question['local_constraint']['transportation']}."
217
+
218
+ return True, None
219
+
220
+
221
+ def is_valid_room_type(question, tested_data):
222
+ if question['local_constraint']['room type'] is None:
223
+ return None,None
224
+ for i in range(min(question['days'],len(tested_data))):
225
+ unit = tested_data[i]
226
+ if unit['accommodation'] and unit['accommodation'] != '-':
227
+ name, city = get_valid_name_city(unit['accommodation'])
228
+ res = accommodation.data[(accommodation.data['NAME'].astype(str).str.contains(re.escape(name))) & (accommodation.data['city'] == city)]
229
+ if len(res) > 0:
230
+ if question['local_constraint']['room type'] == 'not shared room' and res['room type'].values[0] == 'Shared room':
231
+ return False, f"The room type should be {question['local_constraint']['room type']}."
232
+ # "shared room", "not shared room", "private room", "entire room"
233
+ elif question['local_constraint']['room type'] == 'shared room' and res['room type'].values[0] != 'Shared room':
234
+ return False, f"The room type should be {question['local_constraint']['room type']}."
235
+
236
+ elif question['local_constraint']['room type'] == 'private room' and res['room type'].values[0] != 'Private room':
237
+ return False, f"The room type should be {question['local_constraint']['room type']}."
238
+
239
+ elif question['local_constraint']['room type'] == 'entire room' and res['room type'].values[0] != 'Entire home/apt':
240
+ return False, f"The room type should be {question['local_constraint']['room type']}."
241
+
242
+ return True, None
243
+
244
+
245
+ def evaluation(query_data, tested_data):
246
+ return_info = {}
247
+ return_info['valid_cuisine'] = is_valid_cuisine(query_data, tested_data)
248
+ return_info['valid_room_rule'] = is_valid_room_rule(query_data, tested_data)
249
+ return_info['valid_transportation'] = is_valid_transportation(query_data, tested_data)
250
+ return_info['valid_room_type'] = is_valid_room_type(query_data, tested_data)
251
+ return_info['valid_cost'] = (bool(get_total_cost(query_data, tested_data) <= query_data['budget']), None)
252
+ return return_info
253
+
254
+ def boolean_evaluation(query_data, tested_data):
255
+ return_info = {}
256
+ return_info['valid_cuisine'] = is_valid_cuisine(query_data, tested_data)
257
+ return_info['valid_room_rule'] = is_valid_room_rule(query_data, tested_data)
258
+ return_info['valid_transportation'] = is_valid_transportation(query_data, tested_data)
259
+ return_info['valid_room_type'] = is_valid_room_type(query_data, tested_data)
260
+ return_info['valid_cost'] = (bool(get_total_cost(query_data, tested_data) <= query_data['budget']), None)
261
+ for key in return_info:
262
+ if return_info[key][0] == False:
263
+ print(key)
264
+ return False
265
+ return True
266
+
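(For reference: unlike the commonsense checks, a hard-constraint checker returns (None, None) when the query simply does not set that constraint, so None and False have to be distinguished when reading the result. A short sketch of inspecting the dict returned by evaluation() above; it assumes the hardConstraint.py module context, and query_data / tested_data are placeholders with the shapes used elsewhere in this commit.)

    # result is a dict of (status, reason) pairs, where status is True, False, or
    # None (constraint not present in the query), e.g.:
    #   {'valid_cuisine': (None, None), 'valid_room_rule': (True, None),
    #    'valid_transportation': (False, 'The transportation should not be no flight.'),
    #    'valid_room_type': (None, None), 'valid_cost': (True, None)}
    result = evaluation(query_data, tested_data)
    failed = {key: reason for key, (status, reason) in result.items() if status is False}
    print("all hard constraints satisfied" if not failed else failed)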
requirements.txt CHANGED
@@ -1,4 +1,3 @@
  datasets==2.16.1
  gradio==3.50.2
- huggingface-hub==0.20.2
- APScheduler==3.10.1
+ huggingface-hub==0.20.2