Delete evaluation
Browse files- evaluation/.DS_Store +0 -0
- evaluation/__pycache__/commonsenseConstraint.cpython-39.pyc +0 -0
- evaluation/__pycache__/eval.cpython-39.pyc +0 -0
- evaluation/__pycache__/hardConstraint.cpython-39.pyc +0 -0
- evaluation/commonsenseConstraint.py +0 -735
- evaluation/eval.py +0 -181
- evaluation/hardConstraint.py +0 -266
- evaluation/scored/1_validation_two-stage_1.jsonl +0 -1
- evaluation/scored/textbox_validation_two-stage_1.jsonl +0 -1
evaluation/.DS_Store
DELETED
Binary file (6.15 kB)
|
|
evaluation/__pycache__/commonsenseConstraint.cpython-39.pyc
DELETED
Binary file (14 kB)
|
|
evaluation/__pycache__/eval.cpython-39.pyc
DELETED
Binary file (7.05 kB)
|
|
evaluation/__pycache__/hardConstraint.cpython-39.pyc
DELETED
Binary file (8.13 kB)
|
|
evaluation/commonsenseConstraint.py
DELETED
@@ -1,735 +0,0 @@
|
|
1 |
-
from annotation.src.utils import get_valid_name_city,extract_before_parenthesis,extract_numbers_from_filenames
|
2 |
-
from tools.flights.apis import Flights
|
3 |
-
from tools.accommodations.apis import Accommodations
|
4 |
-
from tools.restaurants.apis import Restaurants
|
5 |
-
from tools.googleDistanceMatrix.apis import GoogleDistanceMatrix
|
6 |
-
from tools.attractions.apis import Attractions
|
7 |
-
import math
|
8 |
-
import json
|
9 |
-
import re
|
10 |
-
import os
|
11 |
-
import sys
|
12 |
-
from tqdm import tqdm
|
13 |
-
import argparse
|
14 |
-
|
15 |
-
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))
|
16 |
-
os.chdir(os.path.dirname(os.path.abspath(__file__)))
|
17 |
-
|
18 |
-
flight = Flights()
|
19 |
-
accommodation = Accommodations()
|
20 |
-
restaurants = Restaurants()
|
21 |
-
googleDistanceMatrix = GoogleDistanceMatrix()
|
22 |
-
attractions = Attractions()
|
23 |
-
|
24 |
-
city_state_set = open('../database/background/citySet_with_states.txt','r').read().split('\n')
|
25 |
-
city_state_map = {x:y for x,y in [unit.split('\t') for unit in city_state_set]}
|
26 |
-
|
27 |
-
|
28 |
-
def load_line_json_data(filename):
|
29 |
-
data = []
|
30 |
-
with open(filename, 'r', encoding='utf-8') as f:
|
31 |
-
for line in f.read().strip().split('\n'):
|
32 |
-
unit = json.loads(line)
|
33 |
-
data.append(unit)
|
34 |
-
return data
|
35 |
-
|
36 |
-
|
37 |
-
def count_consecutive_values(lst):
|
38 |
-
if not lst:
|
39 |
-
return []
|
40 |
-
|
41 |
-
result = []
|
42 |
-
current_string = lst[0]
|
43 |
-
count = 1
|
44 |
-
|
45 |
-
for i in range(1, len(lst)):
|
46 |
-
if lst[i] == current_string:
|
47 |
-
count += 1
|
48 |
-
else:
|
49 |
-
result.append((current_string, count))
|
50 |
-
current_string = lst[i]
|
51 |
-
count = 1
|
52 |
-
|
53 |
-
result.append((current_string, count)) # Add the last group of values
|
54 |
-
return result
|
55 |
-
|
56 |
-
|
57 |
-
def transportation_match(text: str):
|
58 |
-
|
59 |
-
if 'taxi' in text.lower():
|
60 |
-
return 'Taxi'
|
61 |
-
|
62 |
-
elif 'self-driving' in text.lower():
|
63 |
-
return 'Self-driving'
|
64 |
-
|
65 |
-
elif 'flight' in text.lower():
|
66 |
-
return 'Flight'
|
67 |
-
|
68 |
-
|
69 |
-
def extract_from_to(text: str):
|
70 |
-
"""
|
71 |
-
Extracts 'A' and 'B' from the format "from A to B" in the given text, with B ending at a comma or the end of the string.
|
72 |
-
|
73 |
-
Args:
|
74 |
-
- text (str): The input string.
|
75 |
-
|
76 |
-
Returns:
|
77 |
-
- tuple: A tuple containing 'A' and 'B'. If no match is found, returns (None, None).
|
78 |
-
"""
|
79 |
-
pattern = r"from\s+(.+?)\s+to\s+([^,]+)(?=[,\s]|$)"
|
80 |
-
matches = re.search(pattern, text)
|
81 |
-
return matches.groups() if matches else (None, None)
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
def is_valid_city_sequence(city_list):
|
86 |
-
"""
|
87 |
-
Checks if the city sequence is valid. A valid sequence has every city (except the first and last)
|
88 |
-
appearing consecutively, and no city should appear again once its sequence is over.
|
89 |
-
|
90 |
-
Args:
|
91 |
-
- city_list (list): List of cities.
|
92 |
-
|
93 |
-
Returns:
|
94 |
-
- bool: True if the sequence is valid, False otherwise.
|
95 |
-
"""
|
96 |
-
|
97 |
-
# If the list has less than 3 cities, it's invalid.
|
98 |
-
if len(city_list) < 3:
|
99 |
-
return False
|
100 |
-
|
101 |
-
# Set to keep track of visited cities
|
102 |
-
visited_cities = set()
|
103 |
-
|
104 |
-
i = 0
|
105 |
-
while i < len(city_list):
|
106 |
-
city = city_list[i]
|
107 |
-
|
108 |
-
# If the city was already visited, it's invalid.
|
109 |
-
if city in visited_cities and (i != 0 and i != len(city_list) - 1):
|
110 |
-
return False
|
111 |
-
|
112 |
-
# Count the consecutive occurrences of the city
|
113 |
-
count = 0
|
114 |
-
while i < len(city_list) and city_list[i] == city:
|
115 |
-
count += 1
|
116 |
-
i += 1
|
117 |
-
|
118 |
-
# If the city appeared only once in the medium, it's invalid.
|
119 |
-
if count == 1 and 0 < i - 1 < len(city_list) - 1:
|
120 |
-
return False
|
121 |
-
|
122 |
-
visited_cities.add(city)
|
123 |
-
|
124 |
-
return True
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
def is_reasonalbe_visiting_city(question, tested_data):
|
129 |
-
|
130 |
-
city_list = []
|
131 |
-
|
132 |
-
# print(tested_data)
|
133 |
-
for i in range(min(question['days'],len(tested_data))):
|
134 |
-
city_value = tested_data[i]['current_city']
|
135 |
-
|
136 |
-
if 'from' in city_value:
|
137 |
-
city1, city2 = extract_from_to(city_value)
|
138 |
-
city1 = extract_before_parenthesis(city1)
|
139 |
-
city2 = extract_before_parenthesis(city2)
|
140 |
-
if i==0 and city1 != question['org']:
|
141 |
-
return False, f"The first day's city should be {question['org']}."
|
142 |
-
|
143 |
-
city_list += [city1, city2]
|
144 |
-
|
145 |
-
else:
|
146 |
-
city_list.append(extract_before_parenthesis(city_value))
|
147 |
-
|
148 |
-
if city_list[0] != city_list[-1]:
|
149 |
-
return False, "The trip should be a closed circle."
|
150 |
-
|
151 |
-
if not is_valid_city_sequence(city_list):
|
152 |
-
return False, "The city sequence is invalid."
|
153 |
-
|
154 |
-
for idx, city in enumerate(city_list):
|
155 |
-
if city not in city_state_map:
|
156 |
-
return False, f"{city} is not a valid city."
|
157 |
-
if idx not in [0,len(city_list)-1] and question['days'] >3 and city_state_map[city] != question['dest']:
|
158 |
-
return False, f"{city} is not in {question['dest']}."
|
159 |
-
|
160 |
-
return True, None
|
161 |
-
|
162 |
-
|
163 |
-
def is_valid_restaurants(question, tested_data):
|
164 |
-
|
165 |
-
restaurants_list = []
|
166 |
-
|
167 |
-
for i in range(min(question['days'],len(tested_data))):
|
168 |
-
unit = tested_data[i]
|
169 |
-
|
170 |
-
if 'breakfast' in unit and unit['breakfast'] and unit['breakfast'] != '-':
|
171 |
-
if unit['breakfast'] not in restaurants_list:
|
172 |
-
restaurants_list.append(unit['breakfast'])
|
173 |
-
else:
|
174 |
-
return False, f"The restaurant in day {i+1} breakfast is repeated."
|
175 |
-
# elif 'breakfast' not in unit :
|
176 |
-
# return False, f"No Breakfast Info."
|
177 |
-
|
178 |
-
if 'lunch' in unit and unit['lunch'] and unit['lunch'] != '-':
|
179 |
-
if unit['lunch'] not in restaurants_list:
|
180 |
-
restaurants_list.append(unit['lunch'])
|
181 |
-
else:
|
182 |
-
return False, f"The restaurant in day {i+1} lunch {unit['lunch']} is repeated."
|
183 |
-
# elif 'lunch' not in unit:
|
184 |
-
# return False, f"No Lunch Info."
|
185 |
-
|
186 |
-
if 'dinner' in unit and unit['dinner'] and unit['dinner'] != '-':
|
187 |
-
if unit['dinner'] not in restaurants_list:
|
188 |
-
restaurants_list.append(unit['dinner'])
|
189 |
-
else:
|
190 |
-
return False, f"The restaurant in day {i+1} dinner is repeated."
|
191 |
-
# elif 'dinner' not in unit:
|
192 |
-
# return False, f"No Dinner Info."
|
193 |
-
|
194 |
-
return True, None
|
195 |
-
|
196 |
-
def is_valid_attractions(question, tested_data):
|
197 |
-
|
198 |
-
attractions_list = []
|
199 |
-
|
200 |
-
for i in range(min(question['days'],len(tested_data))):
|
201 |
-
unit = tested_data[i]
|
202 |
-
|
203 |
-
if 'attraction' in unit and unit['attraction'] and unit['attraction'] != '-':
|
204 |
-
for attraction in unit['attraction'].split(';')[:-1]:
|
205 |
-
if attraction not in attractions_list:
|
206 |
-
attractions_list.append(attraction)
|
207 |
-
else:
|
208 |
-
return False, f"The attraction '{attraction}' in day {i+1} is repeated."
|
209 |
-
|
210 |
-
# elif 'attraction' not in unit:
|
211 |
-
# return False, f"No Attraction Info."
|
212 |
-
|
213 |
-
return True, None
|
214 |
-
|
215 |
-
def is_valid_transportation(question, tested_data):
|
216 |
-
|
217 |
-
if tested_data[0]['transportation'] and tested_data[0]['transportation'] != '-':
|
218 |
-
transportation_list = [transportation_match(tested_data[0]['transportation'])]
|
219 |
-
|
220 |
-
else:
|
221 |
-
return False, "The transportation in day 1 should not be empty."
|
222 |
-
|
223 |
-
for i in range(min(question['days'],len(tested_data))):
|
224 |
-
unit = tested_data[i]
|
225 |
-
|
226 |
-
if 'transportation' in unit and unit['transportation'] and unit['transportation'] != '-':
|
227 |
-
transportation_list.append(transportation_match(unit['transportation']))
|
228 |
-
# elif 'transportation' not in unit:
|
229 |
-
# return False, f"No Transportation Info."
|
230 |
-
|
231 |
-
if (('Self-driving' in transportation_list) and ('Flight' in transportation_list)) or (('Taxi' in transportation_list) and ('Self-driving' in transportation_list)):
|
232 |
-
return False, "The transportation is conflicting."
|
233 |
-
|
234 |
-
return True, None
|
235 |
-
|
236 |
-
def is_valid_information_in_current_city(question, tested_data):
|
237 |
-
|
238 |
-
for i in range(min(question['days'],len(tested_data))):
|
239 |
-
unit = tested_data[i]
|
240 |
-
current_city = unit['current_city']
|
241 |
-
final_city_list = []
|
242 |
-
|
243 |
-
if 'from' in current_city:
|
244 |
-
city1, city2 = extract_from_to(current_city)
|
245 |
-
city1 = extract_before_parenthesis(city1)
|
246 |
-
city2 = extract_before_parenthesis(city2)
|
247 |
-
final_city_list = [city1, city2]
|
248 |
-
else:
|
249 |
-
final_city_list = extract_before_parenthesis(current_city)
|
250 |
-
|
251 |
-
if 'transportation' in unit and unit['transportation'] and unit['transportation'] != '-':
|
252 |
-
for city in final_city_list:
|
253 |
-
if city not in unit['transportation']:
|
254 |
-
# print(city)
|
255 |
-
return False, f"The transportation in day {i+1} is invalid city choice."
|
256 |
-
# elif 'transportation' not in unit:
|
257 |
-
# return False, f"No Transportation Info."
|
258 |
-
|
259 |
-
if 'breakfast' in unit and unit['breakfast'] and unit['breakfast'] != '-':
|
260 |
-
|
261 |
-
flag = False
|
262 |
-
|
263 |
-
for city in final_city_list:
|
264 |
-
if city in unit['breakfast']:
|
265 |
-
flag = True
|
266 |
-
|
267 |
-
if not flag:
|
268 |
-
return False, f"The breakfast in day {i+1} is invalid city choice."
|
269 |
-
# elif 'breakfast' not in unit:
|
270 |
-
# return False, f"No Breakfast Info."
|
271 |
-
|
272 |
-
if 'lunch' in unit and unit['lunch'] and unit['lunch'] != '-':
|
273 |
-
flag = False
|
274 |
-
|
275 |
-
for city in final_city_list:
|
276 |
-
if city in unit['lunch']:
|
277 |
-
flag = True
|
278 |
-
|
279 |
-
if not flag:
|
280 |
-
return False, f"The lunch in day {i+1} is invalid city choice."
|
281 |
-
# elif 'lunch' not in unit:
|
282 |
-
# return False, f"No Lunch Info."
|
283 |
-
|
284 |
-
if 'dinner' in unit and unit['dinner'] and unit['dinner'] != '-':
|
285 |
-
flag = False
|
286 |
-
|
287 |
-
for city in final_city_list:
|
288 |
-
if city in unit['dinner']:
|
289 |
-
flag = True
|
290 |
-
|
291 |
-
if not flag:
|
292 |
-
return False, f"The dinner in day {i+1} is invalid city choice."
|
293 |
-
# elif 'dinner' not in unit:
|
294 |
-
# return False, f"No Dinner Info."
|
295 |
-
|
296 |
-
if 'attraction' in unit and unit['attraction'] and unit['attraction'] != '-':
|
297 |
-
|
298 |
-
attraction_list = unit['attraction'].split(';')[:-1]
|
299 |
-
|
300 |
-
for attraction in attraction_list:
|
301 |
-
flag = False
|
302 |
-
for city in final_city_list:
|
303 |
-
if city in attraction:
|
304 |
-
flag = True
|
305 |
-
if not flag:
|
306 |
-
return False, f"The attraction in day {i+1} is invalid city choice."
|
307 |
-
|
308 |
-
# elif 'attraction' not in unit:
|
309 |
-
# return False, f"No Attraction Info."
|
310 |
-
|
311 |
-
|
312 |
-
if 'accommodation' in unit and unit['accommodation'] and unit['accommodation'] != '-':
|
313 |
-
|
314 |
-
if final_city_list[-1] not in unit['accommodation']:
|
315 |
-
return False, f"The accommodation in day {i+1} is invalid city choice."
|
316 |
-
|
317 |
-
# elif 'accommodation' not in unit:
|
318 |
-
# return False, f"No Accommodation Info."
|
319 |
-
|
320 |
-
return True, None
|
321 |
-
|
322 |
-
# hallucination
|
323 |
-
def is_valid_information_in_sandbox(question, tested_data):
|
324 |
-
|
325 |
-
for i in range(min(question['days'],len(tested_data))):
|
326 |
-
unit = tested_data[i]
|
327 |
-
|
328 |
-
if unit['transportation'] and unit['transportation'] != '-':
|
329 |
-
value = unit['transportation']
|
330 |
-
org_city, dest_city = extract_from_to(value)
|
331 |
-
if org_city == None or dest_city == None:
|
332 |
-
org_city, dest_city = extract_from_to(unit['current_city'])
|
333 |
-
if 'flight number' in value.lower():
|
334 |
-
try:
|
335 |
-
org_city = extract_before_parenthesis(org_city)
|
336 |
-
dest_city = extract_before_parenthesis(dest_city)
|
337 |
-
except TypeError:
|
338 |
-
raise ValueError("The transportation {} in day {} can not be parsed.".format(value,i+1))
|
339 |
-
# print(value)
|
340 |
-
if len(flight.data[(flight.data['Flight Number'] == value.split('Flight Number: ')[1].split(',')[0]) & (flight.data['OriginCityName']==org_city) & (flight.data['DestCityName']==dest_city)]) < 1:
|
341 |
-
return False, f"The flight number in day {i+1} is invalid in the sandbox."
|
342 |
-
|
343 |
-
elif 'self-driving' in value.lower() or 'taxi' in value.lower():
|
344 |
-
try:
|
345 |
-
org_city = extract_before_parenthesis(org_city)
|
346 |
-
dest_city = extract_before_parenthesis(dest_city)
|
347 |
-
except TypeError:
|
348 |
-
org_city = '-'
|
349 |
-
dest_city = '-'
|
350 |
-
print("The transportation {} in day {} can not be parsed and '-' will be used instead.".format(value,i+1))
|
351 |
-
|
352 |
-
if 'self-driving' in value.lower():
|
353 |
-
if googleDistanceMatrix.run_for_evaluation(org_city, dest_city, mode='self-driving')['cost'] == None:
|
354 |
-
return False, f"The self-driving in day {i+1} is invalid in the sandbox."
|
355 |
-
else:
|
356 |
-
if googleDistanceMatrix.run_for_evaluation(org_city, dest_city, mode='taxi')['cost'] == None:
|
357 |
-
return False, f"The taxi in day {i+1} is invalid in the sandbox."
|
358 |
-
|
359 |
-
if 'breakfast' in unit and unit['breakfast'] and unit['breakfast'] != '-':
|
360 |
-
name, city = get_valid_name_city(unit['breakfast'])
|
361 |
-
if len(restaurants.data[(restaurants.data['Name'].astype(str).str.contains(re.escape(name))) & (restaurants.data['City'] == city)]) < 1:
|
362 |
-
return False, f"The breakfast in day {i+1} is invalid in the sandbox."
|
363 |
-
# elif 'breakfast' not in unit:
|
364 |
-
# return False, f"No Breakfast Info."
|
365 |
-
|
366 |
-
if 'lunch' in unit and unit['lunch'] and unit['lunch'] != '-':
|
367 |
-
name, city = get_valid_name_city(unit['lunch'])
|
368 |
-
if len(restaurants.data[(restaurants.data['Name'].astype(str).str.contains(re.escape(name))) & (restaurants.data['City'] == city)]) < 1:
|
369 |
-
return False, f"The lunch in day {i+1} is invalid in the sandbox."
|
370 |
-
# elif 'lunch' not in unit:
|
371 |
-
# return False, f"No Lunch Info."
|
372 |
-
|
373 |
-
if 'dinner' in unit and unit['dinner'] and unit['dinner'] != '-':
|
374 |
-
name, city = get_valid_name_city(unit['dinner'])
|
375 |
-
if len(restaurants.data[(restaurants.data['Name'].astype(str).str.contains(re.escape(name))) & (restaurants.data['City'] == city)]) < 1:
|
376 |
-
return False, f"The dinner in day {i+1} is invalid in the sandbox."
|
377 |
-
# elif 'dinner' not in unit:
|
378 |
-
# return False, f"No Dinner Info."
|
379 |
-
|
380 |
-
if 'attraction' in unit and unit['attraction'] and unit['attraction'] != '-':
|
381 |
-
attractions_list = unit['attraction'].split(';')[:-1]
|
382 |
-
for attraction in attractions_list:
|
383 |
-
name, city = get_valid_name_city(attraction)
|
384 |
-
if len(attractions.data[(attractions.data['Name'].astype(str).str.contains(re.escape(name))) & (attractions.data['City'] == city)]) < 1:
|
385 |
-
return False, f"The attraction {attraction} in day {i+1} is invalid in the sandbox."
|
386 |
-
# elif 'attraction' not in unit:
|
387 |
-
# return False, f"No Attraction Info."
|
388 |
-
|
389 |
-
if 'accommodation' in unit and unit['accommodation'] and unit['accommodation'] != '-':
|
390 |
-
name, city = get_valid_name_city(unit['accommodation'])
|
391 |
-
# print(name,city)
|
392 |
-
# print(accommodation.data[accommodation.data['NAME'].astype(str).str.contains(re.escape(name))])
|
393 |
-
if len(accommodation.data[(accommodation.data['NAME'].astype(str).str.contains(re.escape(name))) & (accommodation.data['city'] == city)]) < 1:
|
394 |
-
return False, f"The accommodation in day {i+1} is invalid in the sandbox."
|
395 |
-
# elif 'accommodation' not in unit:
|
396 |
-
# return False, f"No Accommodation Info."
|
397 |
-
|
398 |
-
return True, None
|
399 |
-
|
400 |
-
|
401 |
-
def is_valid_accommodaton(question, tested_data):
|
402 |
-
data = []
|
403 |
-
for i in range(min(question['days'],len(tested_data))):
|
404 |
-
unit = tested_data[i]
|
405 |
-
|
406 |
-
if 'accommodation' not in unit:
|
407 |
-
return False, f"No Accommodation Info."
|
408 |
-
|
409 |
-
data.append(unit['accommodation'])
|
410 |
-
# data = [unit['accommodation'] for unit in tested_data]
|
411 |
-
consectutive_accommodation = count_consecutive_values(data)
|
412 |
-
for unit in consectutive_accommodation:
|
413 |
-
# print(unit)
|
414 |
-
if unit and unit[0] not in ['-',''] :
|
415 |
-
name, city = get_valid_name_city(unit[0])
|
416 |
-
# print(unit[0],name,city)
|
417 |
-
# try:
|
418 |
-
if len(accommodation.data[(accommodation.data['NAME'].astype(str).str.contains(re.escape(name))) & (accommodation.data['city'] == city)]) == 1 and unit[1] < accommodation.data[(accommodation.data['NAME'].astype(str).str.contains(re.escape(name))) & (accommodation.data['city'] == city)].iloc[0]['minimum nights']:
|
419 |
-
return False, f"The accommodation {unit[0]} do not obey the minumum nights rule."
|
420 |
-
# can not parse data
|
421 |
-
# except re.error:
|
422 |
-
# continue
|
423 |
-
|
424 |
-
return True, None
|
425 |
-
|
426 |
-
def is_valid_visiting_city_number(question, tested_data):
|
427 |
-
|
428 |
-
city_set = set()
|
429 |
-
|
430 |
-
|
431 |
-
for i in range(min(question['days'],len(tested_data))):
|
432 |
-
city_value = tested_data[i]['current_city']
|
433 |
-
|
434 |
-
if 'from' in city_value:
|
435 |
-
city1, city2 = extract_from_to(city_value)
|
436 |
-
city1 = extract_before_parenthesis(city1)
|
437 |
-
city2 = extract_before_parenthesis(city2)
|
438 |
-
if i==0 and city1 != question['org']:
|
439 |
-
return False, f"The first day's city should be {question['org']}."
|
440 |
-
|
441 |
-
city_set.add(city1)
|
442 |
-
city_set.add(city2)
|
443 |
-
|
444 |
-
else:
|
445 |
-
city_set.add(extract_before_parenthesis(city_value))
|
446 |
-
|
447 |
-
city_set.discard(question['org'])
|
448 |
-
|
449 |
-
if len(city_set) != question['visiting_city_number']:
|
450 |
-
return False, f"The number of visiting cities should be {question['visiting_city_number']}."
|
451 |
-
|
452 |
-
return True, None
|
453 |
-
|
454 |
-
def is_valid_days(question, tested_data):
|
455 |
-
lens = 0
|
456 |
-
for i in range(min(question['days'],len(tested_data))):
|
457 |
-
if tested_data[i] != {} and tested_data[i]['current_city'] != "You don't need to fill in the information for this or later days.":
|
458 |
-
lens += 1
|
459 |
-
|
460 |
-
if lens != question['days']:
|
461 |
-
# print(lens)
|
462 |
-
return False, f"The number of days should be {question['days']}."
|
463 |
-
else:
|
464 |
-
return True, None
|
465 |
-
|
466 |
-
def is_not_absent(question, tested_data):
|
467 |
-
needed_info = 6 * question['days']
|
468 |
-
total_valid_info = 0
|
469 |
-
|
470 |
-
if not is_valid_days(question, tested_data)[0]:
|
471 |
-
return False, "Invalid Days"
|
472 |
-
|
473 |
-
if not is_valid_visiting_city_number(question, tested_data)[0]:
|
474 |
-
return False, "Invalid City Number"
|
475 |
-
|
476 |
-
for i in range(min(question['days'],len(tested_data))):
|
477 |
-
unit = tested_data[i]
|
478 |
-
|
479 |
-
if 'transportation' not in unit:
|
480 |
-
return False, f"No Transportation Info."
|
481 |
-
|
482 |
-
if 'breakfast' not in unit:
|
483 |
-
return False, f"No Breakfast Info."
|
484 |
-
|
485 |
-
if 'lunch' not in unit:
|
486 |
-
return False, f"No Lunch Info."
|
487 |
-
|
488 |
-
if 'dinner' not in unit:
|
489 |
-
return False, f"No Dinner Info."
|
490 |
-
|
491 |
-
if 'attraction' not in unit:
|
492 |
-
return False, f"No Attraction Info."
|
493 |
-
|
494 |
-
if 'accommodation' not in unit:
|
495 |
-
return False, f"No Accommodation Info."
|
496 |
-
|
497 |
-
if ('from ' in unit['current_city'] or 'to ' in unit['current_city']) and unit['transportation'] in ['','-']:
|
498 |
-
return False, f"No transportation in day {i+1} is not allowed."
|
499 |
-
|
500 |
-
if ('from ' not in unit['current_city'] and ' to ' not in unit['current_city']) and unit['attraction'] in ['','-']:
|
501 |
-
return False, f"No attaction in day {i+1} is not allowed."
|
502 |
-
|
503 |
-
if i != question['days'] - 1 and unit['accommodation'] in ['','-']:
|
504 |
-
return False, f"No accommodation in day {i+1} is not allowed."
|
505 |
-
|
506 |
-
if (unit['breakfast'] in ['','-'] or unit['lunch'] in ['','-'] or unit['dinner'] in ['','-']) and 'from ' not in unit['current_city']:
|
507 |
-
return False, f"No meal in day {i+1} is not allowed."
|
508 |
-
|
509 |
-
|
510 |
-
for key in unit:
|
511 |
-
if unit[key] and unit[key] != '-':
|
512 |
-
total_valid_info += 1
|
513 |
-
|
514 |
-
|
515 |
-
if total_valid_info * 1.0 / needed_info < 0.5:
|
516 |
-
return False, f"The absent information is more than 50%."
|
517 |
-
|
518 |
-
return True, None
|
519 |
-
|
520 |
-
|
521 |
-
def evaluation(query_data, tested_data):
|
522 |
-
return_info = {}
|
523 |
-
return_info['is_reasonalbe_visiting_city'] = is_reasonalbe_visiting_city(query_data, tested_data)
|
524 |
-
return_info['is_valid_restaurants'] = is_valid_restaurants(query_data, tested_data)
|
525 |
-
return_info['is_valid_attractions'] = is_valid_attractions(query_data, tested_data)
|
526 |
-
return_info['is_valid_accommodation'] = is_valid_accommodaton(query_data, tested_data)
|
527 |
-
return_info['is_valid_transportation'] = is_valid_transportation(query_data, tested_data)
|
528 |
-
return_info['is_valid_information_in_current_city'] = is_valid_information_in_current_city(query_data, tested_data)
|
529 |
-
return_info['is_valid_information_in_sandbox'] = is_valid_information_in_sandbox(query_data, tested_data)
|
530 |
-
return_info['is_not_absent'] = is_not_absent(query_data, tested_data)
|
531 |
-
return return_info
|
532 |
-
|
533 |
-
def boolean_evaluation(query_data, tested_data):
|
534 |
-
return_info = {}
|
535 |
-
return_info['is_reasonalbe_visiting_city'] = is_reasonalbe_visiting_city(query_data, tested_data)
|
536 |
-
return_info['is_valid_restaurants'] = is_valid_restaurants(query_data, tested_data)
|
537 |
-
return_info['is_valid_accommodation'] = is_valid_accommodaton(query_data, tested_data)
|
538 |
-
return_info['is_valid_attractions'] = is_valid_attractions(query_data, tested_data)
|
539 |
-
return_info['is_valid_transportation'] = is_valid_transportation(query_data, tested_data)
|
540 |
-
return_info['is_valid_information_in_current_city'] = is_valid_information_in_current_city(query_data, tested_data)
|
541 |
-
return_info['is_valid_information_in_sandbox'] = is_valid_information_in_sandbox(query_data, tested_data)
|
542 |
-
return_info['is_not_absent'] = is_not_absent(query_data, tested_data)
|
543 |
-
for key in return_info:
|
544 |
-
if return_info[key][0] == False:
|
545 |
-
print(return_info[key][1])
|
546 |
-
return False
|
547 |
-
return True
|
548 |
-
|
549 |
-
# if __name__ == '__main__':
|
550 |
-
# number_list = extract_numbers_from_filenames('/home/xj/toolAugEnv/code/toolConstraint/data/annotation/lrz')
|
551 |
-
# # json_data = json.load(open('/home/xj/toolAugEnv/code/toolConstraint/data/annotation/x/annotation_4.json'))
|
552 |
-
# query_data = load_line_json_data('/home/xj/toolAugEnv/code/toolConstraint/data/query/lrz.jsonl')
|
553 |
-
# for idx in number_list:
|
554 |
-
# json_data = json.load(open(f'/home/xj/toolAugEnv/code/toolConstraint/data/annotation/lrz/annotation_{idx}.json'))
|
555 |
-
# print(str(idx), evaluation(query_data[idx-1], json_data))
|
556 |
-
# # json_data = json.load(open(f'/home/xj/toolAugEnv/code/toolConstraint/results/turbo16k-turbo16k/plan_{idx}.json'))
|
557 |
-
# # query_data = load_line_json_data('/home/xj/toolAugEnv/code/toolConstraint/data/query/test.jsonl')[idx-1]
|
558 |
-
# # help me write all function name in this file, just the name
|
559 |
-
# #
|
560 |
-
# # list all function name in this file
|
561 |
-
# # ['is_reasonalbe_visiting_city', 'is_valiable_restaurants', 'is_valiable_attractions', 'is_valiable_transportation', 'is_valid_information_in_current_city', 'is_valid_information_in_sandbox']
|
562 |
-
# # print(is_valiable_restaurants(query_data, json_data))
|
563 |
-
|
564 |
-
# if __name__ == "__main__":
|
565 |
-
# user = 'zk'
|
566 |
-
# query_data_list = load_line_json_data(f'/home/xj/toolAugEnv/code/toolConstraint/data/query/{user}.jsonl')
|
567 |
-
# idx_number_list = extract_numbers_from_filenames(f'/home/xj/toolAugEnv/code/toolConstraint/data/annotation/{user}')
|
568 |
-
# commonsense_statistic= {level:{day:[] for day in [3,5,7]} for level in ['easy','medium','hard']}
|
569 |
-
# for idx in idx_number_list:
|
570 |
-
# print(idx)
|
571 |
-
# query_data = query_data_list[idx-1]
|
572 |
-
# generated_plan = json.load(open(f'/home/xj/toolAugEnv/code/toolConstraint/results/turbo16k-turbo16k/{user}/plan_{idx}.json'))
|
573 |
-
# # generated_plan = generated_plan[:-1]
|
574 |
-
# if generated_plan[-1]['gpt-3.5-turbo-16k-result'] != 'Plan Fail':
|
575 |
-
# info_box = evaluation(query_data, generated_plan[-1]['gpt-3.5-turbo-16k-result'])
|
576 |
-
# generated_plan[-1]['toolAug-commonsense'] = info_box
|
577 |
-
# else:
|
578 |
-
# generated_plan[-1]['toolAug-commonsense'] = None
|
579 |
-
# info_box = None
|
580 |
-
# commonsense_statistic[query_data['level']][query_data['days']].append(info_box)
|
581 |
-
# with open(f'/home/xj/toolAugEnv/code/toolConstraint/results/turbo16k-turbo16k/{user}/plan_{idx}.json','w') as f:
|
582 |
-
# json.dump(generated_plan,f)
|
583 |
-
|
584 |
-
# with open(f'/home/xj/toolAugEnv/code/toolConstraint/results/turbo16k-turbo16k/{user}/commonsense_statistic.json','w') as f:
|
585 |
-
# json.dump(commonsense_statistic,f)
|
586 |
-
|
587 |
-
# if __name__ == "__main__":
|
588 |
-
# user = 'all'
|
589 |
-
# model_type = ['chatgpt','gpt4','greedy_search'][2]
|
590 |
-
# query_data_list = load_line_json_data(f'/home/xj/toolAugEnv/code/toolConstraint/data/query/{user}.jsonl')
|
591 |
-
# # idx_number_list = extract_numbers_from_filenames(f'/home/xj/toolAugEnv/code/toolConstraint/data/annotation/{user}')
|
592 |
-
# idx_number_list = [i for i in range(1,501)]
|
593 |
-
# commonsense_statistic= {level:{day:[] for day in [3,5,7]} for level in ['easy','medium','hard']}
|
594 |
-
|
595 |
-
# for idx in idx_number_list:
|
596 |
-
# print(idx)
|
597 |
-
# query_data = query_data_list[idx-1]
|
598 |
-
# generated_plan = json.load(open(f'/home/xj/toolAugEnv/code/toolConstraint/results/pre2/{user}/plan_{idx}.json'))
|
599 |
-
# # generated_plan = generated_plan[:-1]
|
600 |
-
# if model_type == 'greedy_search':
|
601 |
-
# info_box = evaluation(query_data, generated_plan[-1][f'greedy_search_plan'])
|
602 |
-
# else:
|
603 |
-
# info_box = evaluation(query_data, generated_plan[-1][f'{model_type}_human_collected_info_results_parsed'])
|
604 |
-
# generated_plan[-1][f'{model_type}_with_human_collected_commonsense'] = info_box
|
605 |
-
# commonsense_statistic[query_data['level']][query_data['days']].append(info_box)
|
606 |
-
|
607 |
-
# with open(f'/home/xj/toolAugEnv/code/toolConstraint/results/pre2/{user}/plan_{idx}.json','w') as f:
|
608 |
-
# json.dump(generated_plan,f)
|
609 |
-
|
610 |
-
# with open(f'/home/xj/toolAugEnv/code/toolConstraint/results/pre2/{user}/{model_type}_with_human_collected_commonsense_statistic.json','w') as f:
|
611 |
-
# json.dump(commonsense_statistic,f)
|
612 |
-
|
613 |
-
|
614 |
-
# if __name__ == "__main__":
|
615 |
-
# user = 'all'
|
616 |
-
# query_data_list = load_line_json_data(f'/home/xj/toolAugEnv/code/toolConstraint/data/query/{user}.jsonl')
|
617 |
-
# idx_number_list = extract_numbers_from_filenames(f'/home/xj/toolAugEnv/code/toolConstraint/data/annotation/{user}')
|
618 |
-
# hardConstraint_statistic= {level:{day:[] for day in [3,5,7]} for level in ['easy','medium','hard']}
|
619 |
-
# not_satified = []
|
620 |
-
# for idx in tqdm(idx_number_list):
|
621 |
-
# # print(idx)
|
622 |
-
# query_data = query_data_list[idx-1]
|
623 |
-
# generated_plan = json.load(open(f'/home/xj/toolAugEnv/code/toolConstraint/data/annotation/{user}/annotation_{idx}.json'))
|
624 |
-
|
625 |
-
# if not boolean_evaluation(query_data, generated_plan):
|
626 |
-
# not_satified.append(idx)
|
627 |
-
# print(idx)
|
628 |
-
# generated_plan = generated_plan[:-1]
|
629 |
-
# print(not_satified)
|
630 |
-
|
631 |
-
if __name__ == "__main__":
|
632 |
-
set_type = ["train",'dev','test'][0]
|
633 |
-
query_data_list = load_line_json_data(f'/home/xj/toolAugEnv/code/toolConstraint/data/final_data/{set_type}/query/query.jsonl')
|
634 |
-
# idx_number_list = extract_numbers_from_filenames(f'/home/xj/toolAugEnv/code/toolConstraint/data/final_data/{set_type}/plan')
|
635 |
-
commonsense_statistic= {level:{day:[] for day in [3,5,7]} for level in ['easy','medium','hard']}
|
636 |
-
not_satified = []
|
637 |
-
# print( idx_number_list)
|
638 |
-
for idx in tqdm(range(1,len(query_data_list)+1)):
|
639 |
-
# print(idx)
|
640 |
-
query_data = query_data_list[idx-1]
|
641 |
-
generated_plan = json.load(open(f'/home/xj/toolAugEnv/code/toolConstraint/data/final_data/{set_type}/plan/plan_{idx}.json'))
|
642 |
-
try:
|
643 |
-
store_plan = json.load(open(f'/home/xj/toolAugEnv/code/toolConstraint/results/{set_type}/plan_{idx}.json'))
|
644 |
-
except FileNotFoundError:
|
645 |
-
store_plan = [{}]
|
646 |
-
info_box = evaluation(query_data,generated_plan[1])
|
647 |
-
# if not boolean_evaluation(query_data, generated_plan[1]):
|
648 |
-
# not_satified.append(idx)
|
649 |
-
# print(idx)
|
650 |
-
# print(store_plan[-1])
|
651 |
-
store_plan[-1][f'human_anno_commonsense_constraint'] = info_box
|
652 |
-
with open(f'/home/xj/toolAugEnv/code/toolConstraint/results/{set_type}/plan_{idx}.json','w') as f:
|
653 |
-
json.dump(store_plan,f)
|
654 |
-
commonsense_statistic[query_data['level']][query_data['days']].append(info_box)
|
655 |
-
print(not_satified)
|
656 |
-
with open(f'/home/xj/toolAugEnv/code/toolConstraint/results/{set_type}/human_anno_commonsense_constraint.json','w') as f:
|
657 |
-
json.dump(commonsense_statistic,f)
|
658 |
-
|
659 |
-
# if __name__ == "__main__":
|
660 |
-
# user = 'all'
|
661 |
-
# model_type = ['chatgpt','gpt4'][1]
|
662 |
-
# query_data_list = load_line_json_data(f'/home/xj/toolAugEnv/code/toolConstraint/data/query/{user}.jsonl')
|
663 |
-
# # idx_number_list = extract_numbers_from_filenames(f'/home/xj/toolAugEnv/code/toolConstraint/data/annotation/{user}')
|
664 |
-
# idx_number_list = [i for i in range(1,501)]
|
665 |
-
# commonsense_statistic= {level:{day:[] for day in [3,5,7]} for level in ['easy','medium','hard']}
|
666 |
-
# cnt = 0
|
667 |
-
# for idx in idx_number_list:
|
668 |
-
# # print(idx)
|
669 |
-
# query_data = query_data_list[idx-1]
|
670 |
-
# generated_plan = json.load(open(f'/home/xj/toolAugEnv/code/toolConstraint/results/pre/{user}/plan_{idx}.json'))[-1]['gpt4_human_collected_info_results_parsed']
|
671 |
-
# # generated_plan = generated_plan[:-1]
|
672 |
-
|
673 |
-
# if not boolean_evaluation(query_data, generated_plan):
|
674 |
-
# cnt += 1
|
675 |
-
# print(idx)
|
676 |
-
# print(cnt)
|
677 |
-
|
678 |
-
# if __name__ == "__main__":
|
679 |
-
# parser = argparse.ArgumentParser(description="")
|
680 |
-
# # model_type = ['gpt-3.5-turbo-1106','gpt-4-1106-preview','greedy_search','mistral-7B-32K','gemini2','mixtral','gpt-3.5-turbo-11062'][-1]
|
681 |
-
# # method = ['direct','cot','react','reflexion','tool-use'][-1]
|
682 |
-
# # set_type = ['dev','test'][0]
|
683 |
-
# parser.add_argument("--model_type", type=str, default="gpt-3.5-turbo-1106")
|
684 |
-
# parser.add_argument("--method", type=str, default="direct")
|
685 |
-
# parser.add_argument("--set_type", type=str, default="dev")
|
686 |
-
# args = parser.parse_args()
|
687 |
-
# directory = f'/home/xj/toolAugEnv/code/toolConstraint/data/final_data/{args.set_type}'
|
688 |
-
# query_data_list = load_line_json_data(os.path.join(directory, 'query/query.jsonl'))
|
689 |
-
# # idx_number_list = extract_numbers_from_filenames(f'/home/xj/toolAugEnv/code/toolConstraint/data/annotation/{user}')
|
690 |
-
# idx_number_list = [i for i in range(1,len(query_data_list)+1)]
|
691 |
-
# commonsense_statistic= {level:{day:[] for day in [3,5,7]} for level in ['easy','medium','hard']}
|
692 |
-
# deliver_cnt = 0
|
693 |
-
# if args.method == 'tool-use':
|
694 |
-
# suffix = ''
|
695 |
-
# else:
|
696 |
-
# suffix = '_with_human_info'
|
697 |
-
# for idx in tqdm(idx_number_list):
|
698 |
-
# # print(idx)
|
699 |
-
# query_data = query_data_list[idx-1]
|
700 |
-
# generated_plan = json.load(open(f'/home/xj/toolAugEnv/code/toolConstraint/results/{args.set_type}/plan_{idx}.json'))
|
701 |
-
# # generated_plan = generated_plan[:-1]
|
702 |
-
# if args.model_type == 'greedy_search':
|
703 |
-
# info_box = evaluation(query_data, generated_plan[-1][f'greedy_search_plan'])
|
704 |
-
# else:
|
705 |
-
# if args.method == 'tool-use':
|
706 |
-
# suffix2 = ''
|
707 |
-
# else:
|
708 |
-
# suffix2 = '_collected'
|
709 |
-
# if generated_plan[-1][f'{args.model_type}_{args.method}{suffix2}_info_results'] and generated_plan[-1][f'{args.model_type}_{args.method}{suffix2}_info_results']!='Max Token Length Exceeded.':
|
710 |
-
# try:
|
711 |
-
# info_box = evaluation(query_data, generated_plan[-1][f'{args.model_type}_{args.method}{suffix}_results_parsed'])
|
712 |
-
# except KeyError:
|
713 |
-
# info_box = None
|
714 |
-
# generated_plan[-1][f'{args.model_type}_{args.method}{suffix2}_info_results'] = ""
|
715 |
-
# except IndexError:
|
716 |
-
# info_box = None
|
717 |
-
# generated_plan[-1][f'{args.model_type}_{args.method}{suffix2}_info_results'] = ""
|
718 |
-
# else:
|
719 |
-
# info_box = None
|
720 |
-
# if info_box:
|
721 |
-
# deliver_cnt += 1
|
722 |
-
# generated_plan[-1][f'{args.model_type}_{args.method}{suffix}_commonsense_constraint'] = info_box
|
723 |
-
# commonsense_statistic[query_data['level']][query_data['days']].append(info_box)
|
724 |
-
|
725 |
-
# with open(f'/home/xj/toolAugEnv/code/toolConstraint/results/{args.set_type}/plan_{idx}.json','w') as f:
|
726 |
-
# json.dump(generated_plan,f)
|
727 |
-
|
728 |
-
# with open(f'/home/xj/toolAugEnv/code/toolConstraint/results/{args.set_type}/{args.model_type}_{args.method}{suffix}_commonsense_constraint.json','w') as f:
|
729 |
-
# json.dump(commonsense_statistic,f)
|
730 |
-
|
731 |
-
# if args.set_type == 'dev':
|
732 |
-
# print(f"Model:{args.model_type} Method:{args.method} Set: {args.set_type} \nDeliver Rate: {deliver_cnt/180}" )
|
733 |
-
# elif args.set_type == 'test':
|
734 |
-
# print(f"Model:{args.model_type} Method:{args.method} Set: {args.set_type} \nDeliver Rate: {deliver_cnt/1000}" )
|
735 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
evaluation/eval.py
DELETED
@@ -1,181 +0,0 @@
|
|
1 |
-
from commonsenseConstraint import evaluation as commonsense_eval
|
2 |
-
from hardConstraint import evaluation as hard_eval
|
3 |
-
import json
|
4 |
-
from tqdm import tqdm
|
5 |
-
from datasets import load_dataset
|
6 |
-
|
7 |
-
|
8 |
-
def load_line_json_data(filename):
|
9 |
-
data = []
|
10 |
-
with open(filename, 'r', encoding='utf-8') as f:
|
11 |
-
for line in f.read().strip().split('\n'):
|
12 |
-
unit = json.loads(line)
|
13 |
-
data.append(unit)
|
14 |
-
return data
|
15 |
-
|
16 |
-
def count_true_false(data):
|
17 |
-
"""Count the number of true and false values in a list."""
|
18 |
-
true_count = data.count(True)
|
19 |
-
false_count = data.count(False)
|
20 |
-
return true_count, false_count
|
21 |
-
|
22 |
-
def statistics(commonsense_statistic):
|
23 |
-
"""Generate statistics for each level and day in the given data with a different structure."""
|
24 |
-
result = {level: {day: {} for day in commonsense_statistic[level]} for level in commonsense_statistic}
|
25 |
-
|
26 |
-
for level, days in commonsense_statistic.items():
|
27 |
-
for day, dicts in days.items():
|
28 |
-
for dct in dicts:
|
29 |
-
if dct:
|
30 |
-
for key, data in dct.items():
|
31 |
-
true_count, false_count = count_true_false(data)
|
32 |
-
if key not in result[level][day]:
|
33 |
-
result[level][day][key] = {"true": 0, "false": 0}
|
34 |
-
result[level][day][key]["true"] += true_count
|
35 |
-
result[level][day][key]["false"] += false_count
|
36 |
-
|
37 |
-
return result
|
38 |
-
|
39 |
-
|
40 |
-
def eval_score(validation_or_test: str, file_path: str, TOKEN):
|
41 |
-
|
42 |
-
if validation_or_test == 'validation':
|
43 |
-
query_data_list = load_dataset('osunlp/TravelBenchEval','validation',token=TOKEN)['validation']
|
44 |
-
elif validation_or_test == 'test':
|
45 |
-
query_data_list = load_dataset('osunlp/TravelBenchEval','test',token=TOKEN)['test']
|
46 |
-
|
47 |
-
query_data_list = [x for x in query_data_list]
|
48 |
-
hardConstraint_statistic= {level:{day:[] for day in [3,5,7]} for level in ['easy','medium','hard']}
|
49 |
-
commonsenseConstraint_statistic = {level:{day:[] for day in [3,5,7]} for level in ['easy','medium','hard']}
|
50 |
-
tested_plans = load_line_json_data(file_path)
|
51 |
-
delivery_cnt = 0
|
52 |
-
plan_constraint_store = []
|
53 |
-
for idx in tqdm(range(0,len(query_data_list))):
|
54 |
-
query_data = query_data_list[idx]
|
55 |
-
tested_plan = tested_plans[idx]
|
56 |
-
if type(query_data) == str:
|
57 |
-
query_data = eval(query_data)
|
58 |
-
if type(tested_plan) == str:
|
59 |
-
tested_plan = eval(tested_plan)
|
60 |
-
if type(query_data['local_constraint']) == str:
|
61 |
-
query_data['local_constraint'] = eval(query_data['local_constraint'])
|
62 |
-
|
63 |
-
if tested_plan['plan']:
|
64 |
-
delivery_cnt += 1
|
65 |
-
commonsense_info_box = commonsense_eval(query_data,tested_plan['plan'])
|
66 |
-
else:
|
67 |
-
commonsense_info_box = None
|
68 |
-
|
69 |
-
if commonsense_info_box and commonsense_info_box['is_not_absent'][0] and commonsense_info_box['is_valid_information_in_sandbox'][0]:
|
70 |
-
hard_info_box = hard_eval(query_data,tested_plan['plan'])
|
71 |
-
else:
|
72 |
-
hard_info_box = None
|
73 |
-
|
74 |
-
plan_constraint_store.append({'commonsense_constraint':commonsense_info_box,'hard_constraint':hard_info_box})
|
75 |
-
|
76 |
-
commonsenseConstraint_statistic[query_data['level']][query_data['days']].append(commonsense_info_box)
|
77 |
-
hardConstraint_statistic[query_data['level']][query_data['days']].append(hard_info_box)
|
78 |
-
|
79 |
-
commonsenseConstraint_statistic_processed = statistics(commonsenseConstraint_statistic)
|
80 |
-
hardConstraint_statistic_processed = statistics(hardConstraint_statistic)
|
81 |
-
# print(commonsenseConstraint_statistic_processed)
|
82 |
-
# print(hardConstraint_statistic_processed)
|
83 |
-
constraint_record = {key: {day: {'house rule':0, 'cuisine':0, 'room type':0, 'transportation':0} for day in [3,5,7]} for key in ['medium','hard']}
|
84 |
-
constraint_mapping = {'house rule':'valid_room_rule','cuisine':'valid_cuisine','room type':'valid_room_type','transportation':'valid_transportation'}
|
85 |
-
mapping_constraint_record = {key: {day: {'valid_room_rule':0, 'valid_cuisine':0, 'valid_room_type':0, 'valid_transportation':0} for day in [3,5,7]} for key in ['medium','hard']}
|
86 |
-
count_record = {key:{day:0 for day in [3,5,7]} for key in ['easy','medium','hard']}
|
87 |
-
|
88 |
-
for unit in query_data_list:
|
89 |
-
count_record[unit['level']][unit['days']] += 1
|
90 |
-
for key in constraint_record['medium'][3]:
|
91 |
-
if unit['local_constraint'][key] != None:
|
92 |
-
constraint_record[unit['level']][unit['days']][key] += 1
|
93 |
-
mapping_constraint_record[unit['level']][unit['days']][constraint_mapping[key]] += 1
|
94 |
-
|
95 |
-
data_record = {key:{day:[] for day in [3,5,7]} for key in ['easy','medium','hard']}
|
96 |
-
|
97 |
-
constraint_dis_record = {"commonsense":{"pass":0,"total":0},"hard":{"pass":0,"total":0}}
|
98 |
-
|
99 |
-
for constraint in ['commonsense','hard']:
|
100 |
-
if constraint == 'commonsense':
|
101 |
-
constraint_statistic = commonsenseConstraint_statistic_processed
|
102 |
-
elif constraint == 'hard':
|
103 |
-
constraint_statistic = hardConstraint_statistic_processed
|
104 |
-
|
105 |
-
key_dict = {'commonsense':['is_valid_information_in_current_city','is_valid_information_in_sandbox','is_reasonalbe_visiting_city','is_valid_restaurants','is_valid_transportation','is_valid_attractions','is_valid_accommodation','is_not_absent'],'hard':['valid_cost','valid_room_rule','valid_cuisine','valid_room_type','valid_transportation']}
|
106 |
-
|
107 |
-
for key in constraint_statistic:
|
108 |
-
# level
|
109 |
-
for key2 in constraint_statistic[key]:
|
110 |
-
# day
|
111 |
-
# print(key2)
|
112 |
-
# key2 = eval(key2)
|
113 |
-
if key2 == -1:
|
114 |
-
print(constraint_statistic[key])
|
115 |
-
exit(0)
|
116 |
-
for key3 in key_dict[constraint]:
|
117 |
-
data_record[key][key2].append('0/0')
|
118 |
-
if key3 in constraint_statistic[key][key2]:
|
119 |
-
constraint_dis_record[constraint]['pass'] += constraint_statistic[key][key2][key3]['true']
|
120 |
-
if constraint == 'hard':
|
121 |
-
if key == 'hard' and key3 in ['valid_room_rule','valid_cuisine','valid_room_type','valid_transportation']:
|
122 |
-
data_record[key][key2][-1] = f"{constraint_statistic[key][key2][key3]['true']}/{mapping_constraint_record[key][key2][key3]}"
|
123 |
-
constraint_dis_record[constraint]['total'] += mapping_constraint_record[key][key2][key3]
|
124 |
-
elif key == 'medium' and key3 in ['valid_room_rule','valid_cuisine','valid_room_type']:
|
125 |
-
data_record[key][key2][-1] = f"{constraint_statistic[key][key2][key3]['true']}/{mapping_constraint_record[key][key2][key3]}"
|
126 |
-
constraint_dis_record[constraint]['total'] += mapping_constraint_record[key][key2][key3]
|
127 |
-
else:
|
128 |
-
data_record[key][key2][-1] = f"{constraint_statistic[key][key2][key3]['true']}/{count_record[key][key2]}"
|
129 |
-
if key3 in ['valid_cost','valid_visitng_city_number','valid_days']:
|
130 |
-
constraint_dis_record[constraint]['total'] += count_record[key][key2]
|
131 |
-
else:
|
132 |
-
data_record[key][key2][-1] = f"{constraint_statistic[key][key2][key3]['true']}/{count_record[key][key2]}"
|
133 |
-
constraint_dis_record[constraint]['total'] += count_record[key][key2]
|
134 |
-
|
135 |
-
final_all_cnt = 0
|
136 |
-
final_commonsense_cnt = 0
|
137 |
-
final_hardConstraint_cnt = 0
|
138 |
-
final_all_cnt_map = {level:0 for level in ['easy','medium','hard']}
|
139 |
-
for idx in (range(0,len(query_data_list))):
|
140 |
-
if plan_constraint_store[idx]['commonsense_constraint']:
|
141 |
-
final_commonsense_pass = True
|
142 |
-
final_hardConstraint_pass = True
|
143 |
-
for item in plan_constraint_store[idx]['commonsense_constraint']:
|
144 |
-
if plan_constraint_store[idx]['commonsense_constraint'][item][0] is not None and not plan_constraint_store[idx]['commonsense_constraint'][item][0]:
|
145 |
-
final_commonsense_pass = False
|
146 |
-
break
|
147 |
-
if plan_constraint_store[idx]['hard_constraint'] is None:
|
148 |
-
continue
|
149 |
-
for item in plan_constraint_store[idx]['hard_constraint']:
|
150 |
-
if plan_constraint_store[idx]['hard_constraint'][item][0] is not None and plan_constraint_store[idx]['hard_constraint'][item][0] == False:
|
151 |
-
final_hardConstraint_pass = False
|
152 |
-
break
|
153 |
-
|
154 |
-
if final_commonsense_pass:
|
155 |
-
final_commonsense_cnt += 1
|
156 |
-
if final_hardConstraint_pass:
|
157 |
-
final_hardConstraint_cnt += 1
|
158 |
-
if final_commonsense_pass and final_hardConstraint_pass:
|
159 |
-
final_all_cnt += 1
|
160 |
-
final_all_cnt_map[query_data_list[idx]['level']] += 1
|
161 |
-
|
162 |
-
result = {}
|
163 |
-
|
164 |
-
if validation_or_test == 'validation':
|
165 |
-
result['Delivery Rate'] = delivery_cnt / 180
|
166 |
-
result['Commonsense Constraint Micro Pass Rate'] = constraint_dis_record['commonsense']['pass'] / 1440
|
167 |
-
result['Commonsense Constraint Macro Pass Rate'] = final_commonsense_cnt / 180
|
168 |
-
result['Hard Constraint Micro Pass Rate'] = constraint_dis_record['hard']['pass'] / 420
|
169 |
-
result['Hard Constraint Macro Pass Rate'] = final_hardConstraint_cnt / 180
|
170 |
-
result['Final Pass Rate'] = final_all_cnt / 180
|
171 |
-
|
172 |
-
elif validation_or_test == 'test':
|
173 |
-
result['Delivery Rate'] = delivery_cnt / 1000
|
174 |
-
result['Commonsense Constraint Micro Pass Rate'] = constraint_dis_record['commonsense']['pass'] / 8000
|
175 |
-
result['Commonsense Constraint Macro Pass Rate'] = final_commonsense_cnt / 1000
|
176 |
-
result['Hard Constraint Micro Pass Rate'] = constraint_dis_record['hard']['pass'] / 2290
|
177 |
-
result['Hard Constraint Macro Pass Rate'] = final_hardConstraint_cnt / 1000
|
178 |
-
result['Final Pass Rate'] = final_all_cnt / 1000
|
179 |
-
|
180 |
-
return result
|
181 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
evaluation/hardConstraint.py
DELETED
@@ -1,266 +0,0 @@
|
|
1 |
-
from annotation.src.utils import get_valid_name_city,extract_before_parenthesis,extract_numbers_from_filenames
|
2 |
-
from tools.flights.apis import Flights
|
3 |
-
from tools.accommodations.apis import Accommodations
|
4 |
-
from tools.restaurants.apis import Restaurants
|
5 |
-
from tools.googleDistanceMatrix.apis import GoogleDistanceMatrix
|
6 |
-
from tools.attractions.apis import Attractions
|
7 |
-
import math
|
8 |
-
import json
|
9 |
-
import re
|
10 |
-
import numpy as np
|
11 |
-
import os
|
12 |
-
import sys
|
13 |
-
from tqdm import tqdm
|
14 |
-
import argparse
|
15 |
-
|
16 |
-
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))
|
17 |
-
os.chdir(os.path.dirname(os.path.abspath(__file__)))
|
18 |
-
|
19 |
-
|
20 |
-
flight = Flights()
|
21 |
-
accommodation = Accommodations()
|
22 |
-
restaurants = Restaurants()
|
23 |
-
googleDistanceMatrix = GoogleDistanceMatrix()
|
24 |
-
attractions = Attractions()
|
25 |
-
|
26 |
-
|
27 |
-
def load_line_json_data(filename):
|
28 |
-
data = []
|
29 |
-
with open(filename, 'r', encoding='utf-8') as f:
|
30 |
-
for line in f.read().strip().split('\n'):
|
31 |
-
unit = json.loads(line)
|
32 |
-
data.append(unit)
|
33 |
-
return data
|
34 |
-
|
35 |
-
|
36 |
-
def convert_bool_values(item):
|
37 |
-
if isinstance(item, dict):
|
38 |
-
# If the item is a dictionary, recurse on each value
|
39 |
-
return {key: convert_bool_values(value) for key, value in item.items()}
|
40 |
-
elif isinstance(item, list):
|
41 |
-
# If the item is a list, recurse on each item in the list
|
42 |
-
return [convert_bool_values(value) for value in item]
|
43 |
-
elif isinstance(item, tuple):
|
44 |
-
# If the item is a tuple, recurse on each item in the tuple and repackage as a tuple
|
45 |
-
return tuple(convert_bool_values(value) for value in item)
|
46 |
-
elif isinstance(item, np.bool_): # Here we check for numpy's bool_ type
|
47 |
-
# If the item is a numpy bool_, convert it to a standard Python bool
|
48 |
-
return bool(item)
|
49 |
-
else:
|
50 |
-
# If the item is any other type, return it unchanged
|
51 |
-
return item
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
def extract_from_to(text: str):
|
57 |
-
"""
|
58 |
-
Extracts 'A' and 'B' from the format "from A to B" in the given text, with B ending at a comma or the end of the string.
|
59 |
-
|
60 |
-
Args:
|
61 |
-
- text (str): The input string.
|
62 |
-
|
63 |
-
Returns:
|
64 |
-
- tuple: A tuple containing 'A' and 'B'. If no match is found, returns (None, None).
|
65 |
-
"""
|
66 |
-
pattern = r"from\s+(.+?)\s+to\s+([^,]+)(?=[,\s]|$)"
|
67 |
-
matches = re.search(pattern, text)
|
68 |
-
return matches.groups() if matches else (None, None)
|
69 |
-
|
70 |
-
|
71 |
-
def get_total_cost(question, tested_data):
|
72 |
-
total_cost = 0
|
73 |
-
for i in range(min(question['days'],len(tested_data))):
|
74 |
-
unit = tested_data[i]
|
75 |
-
# transporation
|
76 |
-
if unit['transportation'] and unit['transportation'] != '-':
|
77 |
-
value = unit['transportation']
|
78 |
-
org_city, dest_city = extract_from_to(value)
|
79 |
-
if org_city == None or dest_city == None:
|
80 |
-
org_city, dest_city = extract_from_to(unit['current_city'])
|
81 |
-
|
82 |
-
if org_city == None or dest_city == None:
|
83 |
-
pass
|
84 |
-
else:
|
85 |
-
if 'flight number' in value.lower():
|
86 |
-
res = flight.data[flight.data['Flight Number'] == value.split('Flight Number: ')[1].split(',')[0]]
|
87 |
-
if len(res) > 0:
|
88 |
-
total_cost += res['Price'].values[0] * question['people_number']
|
89 |
-
|
90 |
-
elif 'self-driving' in value.lower() or 'taxi' in value.lower():
|
91 |
-
if 'self-driving' in value.lower():
|
92 |
-
# print(org_city,dest_city)
|
93 |
-
cost = googleDistanceMatrix.run_for_evaluation(org_city,dest_city,'self-driving')['cost']
|
94 |
-
total_cost += cost * math.ceil(question['people_number'] * 1.0 / 5)
|
95 |
-
else:
|
96 |
-
cost = googleDistanceMatrix.run_for_evaluation(org_city,dest_city,'taxi')['cost']
|
97 |
-
total_cost += cost * math.ceil(question['people_number'] * 1.0 / 4)
|
98 |
-
|
99 |
-
# breakfast
|
100 |
-
if unit['breakfast'] and unit['breakfast'] != '-':
|
101 |
-
name, city = get_valid_name_city(unit['breakfast'])
|
102 |
-
res = restaurants.data[(restaurants.data['Name'].astype(str).str.contains(re.escape(name))) & (restaurants.data['City'] == city)]
|
103 |
-
if len(res) > 0:
|
104 |
-
total_cost += res['Average Cost'].values[0] * question['people_number']
|
105 |
-
|
106 |
-
|
107 |
-
# lunch
|
108 |
-
if unit['lunch'] and unit['lunch'] != '-':
|
109 |
-
name, city = get_valid_name_city(unit['lunch'])
|
110 |
-
res = restaurants.data[(restaurants.data['Name'].astype(str).str.contains(re.escape(name))) & (restaurants.data['City'] == city)]
|
111 |
-
if len(res) > 0:
|
112 |
-
total_cost += res['Average Cost'].values[0] * question['people_number']
|
113 |
-
|
114 |
-
# dinner
|
115 |
-
if unit['dinner'] and unit['dinner'] != '-':
|
116 |
-
name, city = get_valid_name_city(unit['dinner'])
|
117 |
-
res = restaurants.data[(restaurants.data['Name'].astype(str).str.contains(re.escape(name))) & (restaurants.data['City'] == city)]
|
118 |
-
if len(res) > 0:
|
119 |
-
total_cost += res['Average Cost'].values[0] * question['people_number']
|
120 |
-
|
121 |
-
# accommodation
|
122 |
-
if unit['accommodation'] and unit['accommodation'] != '-':
|
123 |
-
name, city = get_valid_name_city(unit['accommodation'])
|
124 |
-
res = accommodation.data[(accommodation.data['NAME'].astype(str).str.contains(re.escape(name))) & (accommodation.data['city'] == city)]
|
125 |
-
if len(res) > 0:
|
126 |
-
total_cost += res['price'].values[0] * math.ceil(question['people_number'] * 1.0 / res['maximum occupancy'].values[0])
|
127 |
-
# print(total_cost)
|
128 |
-
return total_cost
|
129 |
-
|
130 |
-
|
131 |
-
def is_valid_room_rule(question, tested_data):
|
132 |
-
|
133 |
-
if question['local_constraint']['house rule'] is None:
|
134 |
-
return None,None
|
135 |
-
|
136 |
-
for i in range(min(question['days'],len(tested_data))):
|
137 |
-
unit = tested_data[i]
|
138 |
-
if unit['accommodation'] and unit['accommodation'] != '-':
|
139 |
-
name, city = get_valid_name_city(unit['accommodation'])
|
140 |
-
res = accommodation.data[(accommodation.data['NAME'].astype(str).str.contains(re.escape(name))) & (accommodation.data['city'] == city)]
|
141 |
-
if len(res) > 0:
|
142 |
-
if question['local_constraint']['house rule'] == 'smoking' and 'No smoking' in str(res['house_rules'].values[0]):
|
143 |
-
return False, f"The house rule should be {question['local_constraint']['house rule']}."
|
144 |
-
if question['local_constraint']['house rule'] == 'parities' and 'No parties' in str(res['house_rules'].values[0]):
|
145 |
-
return False, f"The house rule should be {question['local_constraint']['house rule']}."
|
146 |
-
if question['local_constraint']['house rule'] == 'children under 10' and 'No children under 10' in str(res['house_rules'].values[0]):
|
147 |
-
return False, f"The house rule should be {question['local_constraint']['house rule']}."
|
148 |
-
if question['local_constraint']['house rule'] == 'visitors' and 'No visitors' in str(res['house_rules'].values[0]):
|
149 |
-
return False, f"The house rule should be {question['local_constraint']['house rule']}."
|
150 |
-
if question['local_constraint']['house rule'] == 'pets' and 'No pets' in str(res['house_rules'].values[0]):
|
151 |
-
return False, f"The house rule should be {question['local_constraint']['house rule']}."
|
152 |
-
|
153 |
-
|
154 |
-
return True, None
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
def is_valid_cuisine(question, tested_data):
|
159 |
-
cuisine_set = set()
|
160 |
-
if question['local_constraint']['cuisine']:
|
161 |
-
for i in range(min(question['days'],len(tested_data))):
|
162 |
-
unit = tested_data[i]
|
163 |
-
|
164 |
-
if unit['breakfast'] and unit['breakfast'] != '-':
|
165 |
-
name, city = get_valid_name_city(unit['breakfast'])
|
166 |
-
if city == question['org']:
|
167 |
-
continue
|
168 |
-
res = restaurants.data[(restaurants.data['Name'].astype(str).str.contains(re.escape(name))) & (restaurants.data['City'] == city)]
|
169 |
-
if len(res) > 0:
|
170 |
-
for cuisine in question['local_constraint']['cuisine']:
|
171 |
-
if cuisine in res.iloc[0]['Cuisines']:
|
172 |
-
cuisine_set.add(cuisine)
|
173 |
-
|
174 |
-
if unit['lunch'] and unit['lunch'] != '-':
|
175 |
-
name, city = get_valid_name_city(unit['lunch'])
|
176 |
-
if city == question['org']:
|
177 |
-
continue
|
178 |
-
res = restaurants.data[(restaurants.data['Name'].astype(str).str.contains(re.escape(name))) & (restaurants.data['City'] == city)]
|
179 |
-
if len(res) > 0:
|
180 |
-
for cuisine in question['local_constraint']['cuisine']:
|
181 |
-
if cuisine in res.iloc[0]['Cuisines']:
|
182 |
-
cuisine_set.add(cuisine)
|
183 |
-
|
184 |
-
if unit['dinner'] and unit['dinner'] != '-':
|
185 |
-
name, city = get_valid_name_city(unit['dinner'])
|
186 |
-
if city == question['org']:
|
187 |
-
continue
|
188 |
-
res = restaurants.data[(restaurants.data['Name'].astype(str).str.contains(re.escape(name))) & (restaurants.data['City'] == city)]
|
189 |
-
if len(res) > 0:
|
190 |
-
for cuisine in question['local_constraint']['cuisine']:
|
191 |
-
if cuisine in res.iloc[0]['Cuisines']:
|
192 |
-
cuisine_set.add(cuisine)
|
193 |
-
|
194 |
-
if len(cuisine_set) == len(question['local_constraint']['cuisine']):
|
195 |
-
return True, None
|
196 |
-
else:
|
197 |
-
# judge which cuisine is not satisfied
|
198 |
-
for cuisine in question['local_constraint']['cuisine']:
|
199 |
-
if cuisine not in cuisine_set:
|
200 |
-
return False, f"The cuisine {cuisine} is not satisfied."
|
201 |
-
# return False, f"The cuisine should be {question['local_constraint']['cuisine']}."
|
202 |
-
else:
|
203 |
-
return None,None
|
204 |
-
|
205 |
-
|
206 |
-
def is_valid_transportation(question, tested_data):
|
207 |
-
if question['local_constraint']['transportation'] is None:
|
208 |
-
return None,None
|
209 |
-
for i in range(min(question['days'],len(tested_data))):
|
210 |
-
unit = tested_data[i]
|
211 |
-
if unit['transportation'] and unit['transportation'] != '-':
|
212 |
-
value = unit['transportation']
|
213 |
-
if question['local_constraint']['transportation'] == 'no flight' and 'Flight' in value:
|
214 |
-
return False, f"The transportation should not be {question['local_constraint']['transportation']}."
|
215 |
-
elif question['local_constraint']['transportation'] == 'no self-driving' and 'Self-driving' in value:
|
216 |
-
return False, f"The transportation should not be {question['local_constraint']['transportation']}."
|
217 |
-
|
218 |
-
return True, None
|
219 |
-
|
220 |
-
|
221 |
-
def is_valid_room_type(question, tested_data):
|
222 |
-
if question['local_constraint']['room type'] is None:
|
223 |
-
return None,None
|
224 |
-
for i in range(min(question['days'],len(tested_data))):
|
225 |
-
unit = tested_data[i]
|
226 |
-
if unit['accommodation'] and unit['accommodation'] != '-':
|
227 |
-
name, city = get_valid_name_city(unit['accommodation'])
|
228 |
-
res = accommodation.data[(accommodation.data['NAME'].astype(str).str.contains(re.escape(name))) & (accommodation.data['city'] == city)]
|
229 |
-
if len(res) > 0:
|
230 |
-
if question['local_constraint']['room type'] == 'not shared room' and res['room type'].values[0] == 'Shared room':
|
231 |
-
return False, f"The room type should be {question['local_constraint']['room type']}."
|
232 |
-
# "shared room", "not shared room", "private room", "entire room"
|
233 |
-
elif question['local_constraint']['room type'] == 'shared room' and res['room type'].values[0] != 'Shared room':
|
234 |
-
return False, f"The room type should be {question['local_constraint']['room type']}."
|
235 |
-
|
236 |
-
elif question['local_constraint']['room type'] == 'private room' and res['room type'].values[0] != 'Private room':
|
237 |
-
return False, f"The room type should be {question['local_constraint']['room type']}."
|
238 |
-
|
239 |
-
elif question['local_constraint']['room type'] == 'entire room' and res['room type'].values[0] != 'Entire home/apt':
|
240 |
-
return False, f"The room type should be {question['local_constraint']['room type']}."
|
241 |
-
|
242 |
-
return True, None
|
243 |
-
|
244 |
-
|
245 |
-
def evaluation(query_data, tested_data):
|
246 |
-
return_info = {}
|
247 |
-
return_info['valid_cuisine'] = is_valid_cuisine(query_data, tested_data)
|
248 |
-
return_info['valid_room_rule'] = is_valid_room_rule(query_data, tested_data)
|
249 |
-
return_info['valid_transportation'] = is_valid_transportation(query_data, tested_data)
|
250 |
-
return_info['valid_room_type'] = is_valid_room_type(query_data, tested_data)
|
251 |
-
return_info['valid_cost'] = (bool(get_total_cost(query_data, tested_data) <= query_data['budget']), None)
|
252 |
-
return return_info
|
253 |
-
|
254 |
-
def boolean_evaluation(query_data, tested_data):
|
255 |
-
return_info = {}
|
256 |
-
return_info['valid_cuisine'] = is_valid_cuisine(query_data, tested_data)
|
257 |
-
return_info['valid_room_rule'] = is_valid_room_rule(query_data, tested_data)
|
258 |
-
return_info['valid_transportation'] = is_valid_transportation(query_data, tested_data)
|
259 |
-
return_info['valid_room_type'] = is_valid_room_type(query_data, tested_data)
|
260 |
-
return_info['valid_cost'] = (bool(get_total_cost(query_data, tested_data) <= query_data['budget']), None)
|
261 |
-
for key in return_info:
|
262 |
-
if return_info[key][0] == False:
|
263 |
-
print(key)
|
264 |
-
return False
|
265 |
-
return True
|
266 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
evaluation/scored/1_validation_two-stage_1.jsonl
DELETED
@@ -1 +0,0 @@
|
|
1 |
-
{"Delivery Rate": 0.8944444444444445, "Commonsense Constraint Micro Pass Rate": 0.6111111111111112, "Commonsense Constraint Macro Pass Rate": 0.027777777777777776, "Hard Constraint Micro Pass Rate": 0.1523809523809524, "Hard Constraint Macro Pass Rate": 0.10555555555555556, "Final Pass Rate": 0.005555555555555556}
|
|
|
|
evaluation/scored/textbox_validation_two-stage_1.jsonl
DELETED
@@ -1 +0,0 @@
|
|
1 |
-
{"Delivery Rate": 0.8944444444444445, "Commonsense Constraint Micro Pass Rate": 0.6111111111111112, "Commonsense Constraint Macro Pass Rate": 0.027777777777777776, "Hard Constraint Micro Pass Rate": 0.1523809523809524, "Hard Constraint Macro Pass Rate": 0.10555555555555556, "Final Pass Rate": 0.005555555555555556}
|
|
|
|