debjitpaul committed
Commit fab8e16
1 Parent(s): c248499

Upload 9 files
src/__init__.py ADDED
File without changes
src/data_transformations/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .correctness_flag import CorrectnessFlag
+ from .testing_results_summary_generation import TestingResultsSummaryGeneration
src/data_transformations/correctness_flag.py ADDED
@@ -0,0 +1,14 @@
+ from typing import Dict, Any
+
+ from flows.data_transformations.abstract import DataTransformation
+
+
+ class CorrectnessFlag(DataTransformation):
+     def __init__(self, output_key, input_key):
+         super().__init__(output_key)
+         self.input_key = input_key
+
+     def __call__(self, data_dict: Dict[str, Any], **kwargs) -> Dict[str, Any]:
+         all_tests_passed = all(test_result["status"] for test_result in data_dict[self.input_key])
+         data_dict[self.output_key] = all_tests_passed
+         return data_dict
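
For context, a minimal usage sketch of this transformation (the keys and test data below are illustrative; running it assumes the `flows` package is installed):

    from src.data_transformations import CorrectnessFlag

    # flag the run as correct only if every test in the input list passed
    flag = CorrectnessFlag(output_key="all_tests_passed", input_key="public_tests_results")
    data = {"public_tests_results": [{"status": True}, {"status": False}]}
    data = flag(data)
    print(data["all_tests_passed"])  # False: one of the tests failed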
src/data_transformations/testing_results_summary_generation.py ADDED
@@ -0,0 +1,103 @@
+ from typing import Dict, Any
+
+ import jinja2
+
+ from flows.data_transformations.abstract import DataTransformation
+ from flows.utils.general_helpers import unflatten_dict
+
+
+ class TestingResultsSummaryGeneration(DataTransformation):
+     def __init__(self, output_key, **kwargs):
+         super().__init__(output_key)
+         self.params = kwargs
+
+     def __call__(self, data_dict: Dict[str, Any], **kwargs) -> Dict[str, Any]:
+         if data_dict["all_tests_passed"]:
+             # the execution did not result in any errors
+             data_dict[self.output_key] = self.params["no_error_template"]
+             return data_dict
+
+         test_data = unflatten_dict(data_dict)["raw_response"]
+
+         if not test_data["compilation_status"]:
+             # compilation error occurred
+             kwargs = {
+                 "error_message": test_data["compilation_error_message"].strip(),
+             }
+
+             message_content = (
+                 jinja2.Environment(loader=jinja2.BaseLoader())
+                 .from_string(self.params["compilation_error_template"])
+                 .render(**kwargs)
+             )
+         elif test_data["timeout_error"]:
+             # timeout error occurred
+             message_content = self.params["timeout_error_template"]
+         else:
+             # code compiled successfully without timeouts
+
+             # retrieve the failed tests
+             failed_tests = [
+                 test_result
+                 for test_result in test_data["public_tests_results"]
+                 if not test_result["status"]
+             ]
+
+             runtime_error_test = None
+             for test_result in failed_tests:
+                 if test_result["generated_output"] is None:
+                     # runtime error occurred
+                     runtime_error_test = test_result
+
+             if runtime_error_test:
+                 # construct the error message for the runtime error
+                 kwargs = {
+                     "test_input": runtime_error_test["input"],
+                     "error_message": runtime_error_test["error_message"].strip(),
+                 }
+
+                 message_content = (
+                     jinja2.Environment(loader=jinja2.BaseLoader())
+                     .from_string(self.params["runtime_error_template"])
+                     .render(**kwargs)
+                 )
+             else:
+                 # construct the error message corresponding to a logical error
+                 if self.params["single_test_error_message"]:
+                     # construct the error message for a single (the first) failed test
+                     first_failed_test = failed_tests[0]
+
+                     kwargs = {
+                         "test_input": first_failed_test["input"],
+                         "expected_output": first_failed_test["expected_output"],
+                         "generated_output": first_failed_test["generated_output"],
+                     }
+
+                     message_content = (
+                         jinja2.Environment(loader=jinja2.BaseLoader())
+                         .from_string(self.params["single_test_error_template"])
+                         .render(**kwargs)
+                     )
+                 else:
+                     # construct the error message covering all failed tests
+                     parts = [self.params["all_tests_header"]]
+
+                     for idx, test_result in enumerate(failed_tests):
+                         kwargs = {
+                             "idx": idx + 1,
+                             "test_input": test_result["input"],
+                             "expected_output": test_result["expected_output"],
+                             "generated_output": test_result["generated_output"],
+                         }
+
+                         parts.append(
+                             jinja2.Environment(loader=jinja2.BaseLoader())
+                             .from_string(self.params["test_error_template"])
+                             .render(**kwargs)
+                         )
+
+                     message_content = self.params["tests_separator"].join(parts)
+
+         data_dict[self.output_key] = message_content
+         return data_dict
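
For context, a sketch of how this transformation might be configured. The template strings below are assumptions for illustration (the real ones live in the flow's configuration), but the parameter names match the keys read from `self.params` above:

    summary = TestingResultsSummaryGeneration(
        output_key="testing_results_summary",
        no_error_template="All tests passed.",
        compilation_error_template="Compilation failed:\n{{ error_message }}",
        timeout_error_template="Execution timed out.",
        runtime_error_template="Runtime error on input {{ test_input }}:\n{{ error_message }}",
        single_test_error_message=True,
        single_test_error_template=(
            "Failed test.\nInput: {{ test_input }}\n"
            "Expected: {{ expected_output }}\nGot: {{ generated_output }}"
        ),
    )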
src/datasets/__init__.py ADDED
File without changes
src/datasets/schema.py ADDED
@@ -0,0 +1,74 @@
+ from typing import List, Tuple, Dict
+
+
+ def assert_test_format_codeforces(tests: List[Tuple[List[str], str]]):
+     assert isinstance(tests, list) or tests is None
+     if tests is None:
+         return
+     for test in tests:
+         assert isinstance(test, list)
+         assert len(test) == 2
+         inputs, outputs = test
+         assert isinstance(inputs, list)
+         assert isinstance(outputs, str)
+         for test_input in inputs:
+             assert isinstance(test_input, str)
+
+
+ def assert_entry_format_codeforces(obj: Dict):
+     # each data point must follow the same schema
+     assert isinstance(obj["id"], str)  # contest + problem_name = id, will not change when formatting changes
+     assert isinstance(obj["id_hash"], str)  # hashsum of all entries, any change to obj will change this
+     assert isinstance(obj["contest"], int)
+     assert isinstance(obj["problem_name"], str)
+     assert isinstance(obj["problem_url"], str)
+     assert isinstance(obj["solution_url"], str)
+
+     assert isinstance(obj["header"], str)
+     assert isinstance(obj["problem_description"], str)
+     assert isinstance(obj["input_description"], str)
+     assert isinstance(obj["output_description"], str)
+     assert isinstance(obj["note"], str) or obj["note"] is None
+
+     assert isinstance(obj["difficulty"], int)
+     assert isinstance(obj["tags"], list)
+     assert isinstance(obj["working_solution"], str)  # can be empty
+
+     assert_test_format_codeforces(obj["public_tests_io"])
+     assert_test_format_codeforces(obj["public_tests_individual_io"])
+     assert_test_format_codeforces(obj["hidden_tests_io"])
+
+
+ def assert_test_format_leetcode(tests: List[Tuple[List[str], str]]):
+     pass
+     # ToDo: Uncomment after the test format is updated
+     # assert isinstance(tests, list)
+     # for test in tests:
+     #     assert isinstance(test, tuple)
+     #     assert len(test) == 2
+     #     x, y = test
+     #     assert isinstance(x, str)
+     #     assert isinstance(y, str)
+
+
+ def assert_entry_format_leetcode(obj: Dict):
+     # each data point must follow the same schema
+     assert isinstance(obj["id"], str)  # contest + problem_name = id, will not change when formatting changes
+     assert isinstance(obj["id_hash"], str)  # hashsum of all entries, any change to obj will change this
+     assert isinstance(obj["index"], int)
+     assert isinstance(obj["problem_name"], str)
+     assert isinstance(obj["problem_url"], str)
+
+     assert isinstance(obj["problem_description"], str)
+     assert isinstance(obj["constraints"], str)
+     assert isinstance(obj["python_stub"], str)
+     assert isinstance(obj["difficulty"], str) and obj["difficulty"] in {"easy", "medium", "hard"}
+
+     # ToDo: Should be added
+     # assert isinstance(obj['tags'], list)
+     # assert isinstance(obj['solution_url'], str)
+     # assert isinstance(obj['working_solution'], str)  # can be empty
+
+     # ToDo: Uncomment after the test format is updated
+     # assert_test_format_leetcode(obj['public_tests_io'])
+     # assert_test_format_leetcode(obj['hidden_tests_io'])
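
For context, a minimal Codeforces entry that passes these checks; every field value below is made up for illustration:

    entry = {
        "id": "1_theatre_square",
        "id_hash": "0123abcd",
        "contest": 1,
        "problem_name": "Theatre Square",
        "problem_url": "https://codeforces.com/problemset/problem/1/A",
        "solution_url": "https://codeforces.com/contest/1/status",
        "header": "A. Theatre Square",
        "problem_description": "...",
        "input_description": "...",
        "output_description": "...",
        "note": None,
        "difficulty": 1000,
        "tags": ["math"],
        "working_solution": "",
        "public_tests_io": [[["6 6 4"], "4"]],
        "public_tests_individual_io": None,
        "hidden_tests_io": None,
    }
    assert_entry_format_codeforces(entry)  # raises AssertionError on any schema violation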
src/evaluation/__init__.py ADDED
File without changes
src/evaluation/testing_utils_codeforces.py ADDED
@@ -0,0 +1,447 @@
+ # This is based heavily on the huggingface APPS metric
+ import re
+
+ # to run the solution files we're using a timing based approach
+ import signal
+ import sys
+
+ # for capturing the stdout
+ from io import StringIO
+ from typing import List, Tuple
+
+ # used for testing the code that reads from input
+ from unittest.mock import patch, mock_open
+
+ import numpy as np
+ from pyext import RuntimeModule
+ from wrapt_timeout_decorator import timeout as wrapt_timeout
+ import threading
+
+ from ..datasets.schema import assert_test_format_codeforces
+
+ from flows import logging
+
+ log = logging.get_logger(__name__)
+ lock = threading.Lock()
+
+
+ def evaluate_solution_for_problem(
+     candidate_solution,
+     hidden_tests_io=None,
+     public_tests_io=None,
+     timeout=10,
+     debug=False,
+     add_extra_imports=False,
+     allow_truncated_io=False,
+ ):
+     """See the readme for the output format of this function."""
+     with lock:
+         if hidden_tests_io is None:
+             hidden_tests_io = []
+         if public_tests_io is None:
+             public_tests_io = []
+
+         if candidate_solution is None:
+             results_dict = {
+                 "compilation_status": False,
+                 "compilation_error_message": "No code was provided.",
+                 "timeout_error": False,
+                 "hidden_tests_results": [
+                     {
+                         "status": False,
+                         "error_message": "No code was provided.",
+                         "generated_output": None,
+                         "input": test[0],
+                         "expected_output": test[1],
+                     }
+                     for test in hidden_tests_io
+                 ],
+                 "public_tests_results": [
+                     {
+                         "status": False,
+                         "error_message": "No code was provided.",
+                         "generated_output": None,
+                         "input": test[0],
+                         "expected_output": test[1],
+                     }
+                     for test in public_tests_io
+                 ],
+             }
+             return results_dict
+
+         @wrapt_timeout(timeout, use_signals=False)
+         def run_tests():
+             hidden_tests_results = check_correctness(
+                 candidate_solution, hidden_tests_io, timeout, debug, add_extra_imports, allow_truncated_io
+             )
+             public_tests_results = check_correctness(
+                 candidate_solution, public_tests_io, timeout, debug, add_extra_imports, allow_truncated_io
+             )
+
+             return hidden_tests_results, public_tests_results
+
+         try:
+             hidden_tests_results, public_tests_results = run_tests()
+             timeout_error_occurred = False
+         except BaseException as e:
+             log.info(e)
+             hidden_tests_results = {}
+             public_tests_results = {}
+
+             hidden_tests_results["compilation_status"] = True
+             public_tests_results["compilation_status"] = True
+             timeout_error_occurred = True
+             hidden_tests_results["error_message"] = "Timeout error."
+
+             hidden_tests_results["results"] = [
+                 {
+                     "status": False,
+                     "error_message": hidden_tests_results["error_message"],
+                     "generated_output": None,
+                     "input": test[0],
+                     "expected_output": test[1],
+                 }
+                 for test in hidden_tests_io
+             ]
+             public_tests_results["results"] = [
+                 {
+                     "status": False,
+                     "error_message": hidden_tests_results["error_message"],
+                     "generated_output": None,
+                     "input": test[0],
+                     "expected_output": test[1],
+                 }
+                 for test in public_tests_io
+             ]
+
+         # the compilation status shouldn't depend on the tests
+         assert hidden_tests_results["compilation_status"] == public_tests_results["compilation_status"]
+
+         results_dict = {
+             "compilation_status": hidden_tests_results["compilation_status"],
+             "compilation_error_message": hidden_tests_results["error_message"],
+             "timeout_error": timeout_error_occurred,
+             "hidden_tests_results": hidden_tests_results["results"],
+             "public_tests_results": public_tests_results["results"],
+         }
+
+         return results_dict
+
+
+ def check_correctness(
+     candidate_solution: str,
+     tests: List[Tuple[List[str], str]],
+     timeout: int = 6000,
+     debug=True,
+     add_extra_imports=False,
+     allow_truncated_io=True,
+ ):
+     """Wraps the testing code in a global timeout; based on the huggingface code."""
+     assert_test_format_codeforces(tests)
+     inputs, outputs = [], []
+     if len(tests) > 0:
+         inputs, outputs = zip(*tests)
+
+     compilation_error, results = run_test(
+         candidate_solution, inputs, outputs, timeout, debug, add_extra_imports, allow_truncated_io
+     )
+
+     assert len(results) == len(inputs)
+
+     for result in results:
+         assert isinstance(result["generated_output"], str) or result["generated_output"] is None
+         assert isinstance(result["status"], bool)
+         assert isinstance(result["error_message"], str) or result["error_message"] is None
+         assert isinstance(result["input"], list)
+         assert isinstance(result["expected_output"], str)
+
+     compilation_status = compilation_error == ""
+     if compilation_status:
+         compilation_error = None
+
+     return {"compilation_status": compilation_status, "error_message": compilation_error, "results": results}
+
+
+ class TimeoutException(Exception):
+     pass
+
+
+ def timeout_handler(signum, frame):
+     log.info("alarm went off")
+     raise TimeoutException
+
+
+ signal.signal(signal.SIGALRM, timeout_handler)
+
+
+ # used to capture stdout as a list
+ # from https://stackoverflow.com/a/16571630/6416660
+ # alternative use redirect_stdout() from contextlib
+ class Capturing(list):
+     def __enter__(self):
+         self._stdout = sys.stdout
+         sys.stdout = self._stringio = StringIO()
+         # make closing the StringIO a no-op (accept any args so a plain close() call doesn't raise)
+         self._stringio.close = lambda *args: None
+         return self
+
+     def __exit__(self, *args):
+         self.extend(self._stringio.getvalue().splitlines())
+         del self._stringio  # free up some memory
+         sys.stdout = self._stdout
+
+
+ def run_test(code, inputs, outputs, timeout: int = 6000, debug=True, add_extra_imports=False, allow_truncated_io=True):
+     """
+     Runs the code and tries to match inputs and outputs.
+     The scraped testcases may be incomplete; if allow_truncated_io==True,
+     we ignore an EOF exception at the end of the generated output.
+     """
+     # Disable functionalities that can make destructive changes to the test.
+
+     results = []
+
+     if isinstance(code, list):
+         tmp_test = code
+     elif isinstance(code, str):
+         tmp_test = code.split("\n")
+     else:
+         raise AssertionError("code must be provided as a list of lines or a string with \\n linebreaks.")
+
+     # parse the code into code and imports
+     import_lines = []
+     future_import_lines = []
+     code_lines = []
+     for x in tmp_test:
+         if (not x.startswith("from ")) and (not x.startswith("import ")):
+             code_lines.append("\t" + x + "\n")
+         else:
+             if "__future__" in x:
+                 future_import_lines.append(x + "\n")
+             else:
+                 import_lines.append(x + "\n")
+
+     # assemble a new solution snippet which wraps the generated solution in a function code()
+     new_test = "stdin = sys.stdin\nstdout = sys.stdout\n"
+     new_test += '__name__="__main__"\n'
+     new_test += "def code():\n"
+     new_test += "\tstdin = sys.stdin\n\tstdout = sys.stdout\n"
+
+     for line in code_lines:
+         new_test += line
+
+     sol = "\n".join(future_import_lines)
+     sol += "import sys\n"
+     if add_extra_imports:
+         sol += "import time\nimport itertools\nfrom itertools import accumulate, product, permutations, combinations\nimport collections\nfrom collections import Counter, OrderedDict, deque, defaultdict, ChainMap\nfrom functools import lru_cache\nimport math\nfrom math import sqrt, sin, cos, tan, ceil, fabs, floor, gcd, exp, log, log2\nimport fractions\nfrom typing import List, Tuple\nimport numpy as np\nimport random\nimport heapq\nfrom heapq import *\n"
+     sol += "\n".join(import_lines) + "\n" + new_test
+
+     if debug:
+         log.info(f"sol = {sol}")
+     method_name = "code"
+     signal.alarm(timeout)
+
+     # convert the solution snippet into a pyext runtime module
+     sol_module = None
+     try:
+         sol_module = RuntimeModule.from_string("tmp_sol", "", sol)
+         signal.alarm(0)
+     except Exception as e:
+         signal.alarm(0)
+         if debug:
+             log.info(f"type 1 compilation error = {e}")
+         for inp, out in zip(inputs, outputs):
+             # consider all inputs failed
+             results.append(
+                 {
+                     "status": False,
+                     "input": inp,
+                     "expected_output": out,
+                     "generated_output": None,
+                     "error_message": repr(e),
+                 }
+             )
+         return repr(e), results
+
+     assert sol_module is not None
+     signal.alarm(0)
+
+     try:
+         method = getattr(sol_module, method_name)  # getattr's second arg must be str
+     except:
+         signal.alarm(0)
+         e = sys.exc_info()
+         log.info(f"unable to get function error = {e}")
+
+         for inp, out in zip(inputs, outputs):
+             # consider all inputs failed
+             results.append(
+                 {
+                     "status": False,
+                     "input": inp,
+                     "expected_output": out,
+                     "generated_output": None,
+                     "error_message": repr(e),
+                 }
+             )
+         return repr(e), results
+
+     # go through all tests, call our runtime module with the inputs
+     # then compare with the reference output
+     for index, (test_input, reference_output) in enumerate(zip(inputs, outputs)):
+         result_object = {
+             "input": test_input,
+             "expected_output": reference_output,
+         }
+
+         # if the last token of the input is truncated and marked with "..." we delete it
+         input_truncated = False
+         if "".join(test_input).strip().endswith("...") and allow_truncated_io:
+             test_input = test_input[:-1]
+             input_truncated = True
+
+         # sometimes the last input token is ""
+         # if len(test_input) > 0:
+         #     if test_input[-1] == "":
+         #         test_input = test_input[:-1]
+
+         error_code = None
+         with Capturing() as generated_output:
+             try:
+                 call_method(method, test_input)
+                 # reset the alarm
+                 signal.alarm(0)
+             except Exception as e:
+                 # runtime error or took too long
+                 signal.alarm(0)
+                 error_code = e
+                 if debug:
+                     log.info(f"Call-based runtime error or time limit exceeded error = {repr(e)}{e}")
+             signal.alarm(0)
+
+         # in some cases we run into truncated tests
+         # in such cases we expect the error code to be None, EOFError or ValueError
+         if (
+             (input_truncated or reference_output.strip().endswith("..."))
+             and allow_truncated_io
+             and (error_code is None or isinstance(error_code, (EOFError, ValueError)))
+         ):
+             generated_output = generated_output[:-1]
+             reference_output = reference_output.rstrip("...")
+             if len(generated_output) == 0:
+                 # no output left, we pass by default
+                 result_object.update(
+                     **{
+                         "status": True,
+                         "generated_output": "\n".join(generated_output),
+                         "error_message": None,
+                     }
+                 )
+                 results.append(result_object)
+             else:
+                 result_object.update(
+                     **{
+                         "status": string_compare(generated_output, reference_output, True),
+                         "generated_output": "\n".join(generated_output),
+                         "error_message": None,
+                     }
+                 )
+                 results.append(result_object)
+
+         # if the input and output are not truncated, we don't allow any errors
+         elif error_code is not None:
+             result_object.update(**{"status": False, "generated_output": None, "error_message": repr(error_code)})
+             results.append(result_object)
+         # finally, if there are no errors, we expect the output to match the reference output
+         else:
+             # the execution went well, let's compare the outputs
+             result_object.update(
+                 **{
+                     "status": string_compare(generated_output, reference_output, False),
+                     "generated_output": "\n".join(generated_output),
+                     "error_message": None,
+                 }
+             )
+             results.append(result_object)
+
+     return "", results
+
+
+ def string_compare(candidate, correct, truncate_output=False, floating_point_accuracy=0.01):
+     candidate = [o.strip().lower() for o in candidate]
+     correct = correct.strip().lower()
+
+     # normalize whitespace
+     candidate = "\n".join(candidate)
+     candidate = re.sub(r"\s+", " ", candidate).strip()
+     correct = re.sub(r"\s+", " ", correct).strip()
+
+     # split into individual tokens
+     candidate = candidate.split(" ")
+     correct = correct.split(" ")
+
+     # some tests may be truncated, if we allow this we don't enforce equal length of inputs/outputs
+     if not truncate_output:
+         if not len(candidate) == len(correct):
+             return False
+
+     # if we allow truncated io, the last token of the output may have been corrupted
+     if truncate_output:
+         correct = correct[:-1]
+
+     # when zip is used for lists of unequal length it will give as many pairs as there are items in the shorter list
+     for left, right in zip(candidate, correct):
+         if left == right:
+             continue
+
+         try:
+             int_left = int(left)
+             int_right = int(right)
+             if int_left == int_right:
+                 continue
+         except ValueError:
+             pass
+
+         try:
+             float_left = float(left)
+             float_right = float(right)
+             if np.abs(float_left - float_right) < floating_point_accuracy:
+                 continue
+         except ValueError:
+             pass
+
+         return False
+
+     return True
+
+
+ def call_method(method, inputs):
+     if isinstance(inputs, list):
+         inputs = "\n".join(inputs)
+
+     inputs_line_iterator = iter(inputs.split("\n"))
+
+     # sys.setrecursionlimit(10000)
+
+     # @patch('builtins.input', side_effect=inputs.split("\n"))
+     @patch("builtins.open", mock_open(read_data=inputs))
+     @patch("sys.stdin", StringIO(inputs))
+     @patch("sys.stdin.readline", lambda *args: next(inputs_line_iterator))
+     @patch("sys.stdin.readlines", lambda *args: inputs.split("\n"))
+     @patch("sys.stdin.read", lambda *args: inputs)
+     # @patch('sys.stdout.write', print)
+     def _inner_call_method(_method):
+         try:
+             return _method()
+         except SystemExit:
+             pass
+         finally:
+             pass
+
+     return _inner_call_method(method)
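
For context, a hypothetical invocation of this module's entry point (the solution and test are illustrative; running it assumes the pyext and wrapt_timeout_decorator dependencies are installed). Each test follows the [input_lines, expected_output] format validated by assert_test_format_codeforces:

    solution = "n = int(input())\nprint(n * 2)"
    results = evaluate_solution_for_problem(
        candidate_solution=solution,
        public_tests_io=[[["3"], "6"]],
    )
    print(results["compilation_status"])                 # True
    print(results["public_tests_results"][0]["status"])  # True if the printed output matched "6"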
src/evaluation/testing_utils_leetcode.py ADDED
@@ -0,0 +1,258 @@
+ # This is based heavily on the huggingface APPS metric
+ import logging
+ import re
+ import threading
+ from subprocess import Popen, PIPE, TimeoutExpired
+ from typing import List, Tuple
+
+ log = logging.getLogger(__name__)
+ lock = threading.Lock()
+
+
+ def evaluate_solution_for_problem(
+     candidate_solution,
+     python_stub,
+     hidden_tests_io=None,
+     public_tests_io=None,
+     timeout=10,
+     debug=False,
+     add_extra_imports=False,
+ ):
+     """See the readme for the output format of this function."""
+     with lock:
+         if hidden_tests_io is None:
+             hidden_tests_io = []
+         if public_tests_io is None:
+             public_tests_io = []
+
+         if candidate_solution is None:
+             results_dict = {
+                 "compilation_status": False,
+                 "compilation_error_message": "No code was provided.",
+                 "timeout_error": False,
+                 "hidden_tests_results": [
+                     {
+                         "status": False,
+                         "error_message": "No code was provided.",
+                         "generated_output": None,
+                         "input": test[0],
+                         "expected_output": test[1],
+                     }
+                     for test in hidden_tests_io
+                 ],
+                 "public_tests_results": [
+                     {
+                         "status": False,
+                         "error_message": "No code was provided.",
+                         "generated_output": None,
+                         "input": test[0],
+                         "expected_output": test[1],
+                     }
+                     for test in public_tests_io
+                 ],
+             }
+             return results_dict
+
+         hidden_tests_results = check_correctness(
+             candidate_solution, python_stub, hidden_tests_io, timeout, debug, add_extra_imports
+         )
+         public_tests_results = check_correctness(
+             candidate_solution, python_stub, public_tests_io, timeout, debug, add_extra_imports
+         )
+
+         # the compilation status shouldn't depend on the tests
+         if len(hidden_tests_io) > 0 and len(public_tests_io) > 0:
+             assert hidden_tests_results["compilation_status"] == public_tests_results["compilation_status"]
+
+         compilation_status = True
+         error_message = None
+         timeout_error = False
+
+         if len(hidden_tests_io) > 0:
+             compilation_status = compilation_status and hidden_tests_results["compilation_status"]
+             error_message = hidden_tests_results["error_message"]
+             timeout_error = timeout_error or hidden_tests_results["timeout_error"]
+
+         if len(public_tests_io) > 0:
+             compilation_status = compilation_status and public_tests_results["compilation_status"]
+             error_message = public_tests_results["error_message"]
+             timeout_error = timeout_error or public_tests_results["timeout_error"]
+
+         results_dict = {
+             "compilation_status": compilation_status,
+             "compilation_error_message": error_message,
+             "timeout_error": timeout_error,
+             "hidden_tests_results": hidden_tests_results["results"],
+             "public_tests_results": public_tests_results["results"],
+         }
+
+         return results_dict
+
+
+ def check_correctness(
+     candidate_solution: str,
+     python_stub: str,
+     tests: List[Tuple[str, str, str]],  # each test is an (input, expected_output, explanation) triple
+     timeout: int = 6000,
+     debug=True,
+     add_extra_imports=False,
+ ):
+     compilation_status = True
+     compilation_error = None
+     results = []
+     timeout_occurred = False
+
+     for idx, test in enumerate(tests):
+         inp, out, expl = test
+         result = one_test(
+             candidate_solution, python_stub, inp, out, timeout=timeout, debug=debug, add_extra_imports=add_extra_imports
+         )
+         error_message = result["error_message"]
+
+         if error_message is not None:
+             if "syntaxerror" in error_message.lower():
+                 compilation_status = False
+                 compilation_error = error_message
+             if "timeout" in error_message.lower():
+                 timeout_occurred = True
+         results.append(result)
+
+         if timeout_occurred:
+             break
+
+     if timeout_occurred:
+         return {
+             "compilation_status": True,
+             "timeout_error": True,
+             "error_message": "Timeout error.",
+             "results": results,
+         }
+
+     return {
+         "compilation_status": compilation_status,
+         "timeout_error": False,
+         "error_message": compilation_error,
+         "results": results,
+     }
+
+
+ def one_test(candidate_solution, python_stub, inp, out, timeout=10, debug=False, add_extra_imports=False):
+     python_stub = python_stub.strip()
+     candidate_solution = candidate_solution.strip()
+
+     out = out.replace("null", "None").replace("true", "True").replace("false", "False")
+
+     # reformat the solution and parse the class and method name
+     class_def, signature = python_stub.split(" def ")
+     class_name = class_def.split("class ")[1].strip().rstrip(":")
+     func_name = signature.split("(")[0]  # take everything before the first "(" so extra parens in the signature don't break unpacking
+
+     # reformat the input: strip the "param = " prefixes, e.g. "x = 1, y = 2" -> "1, 2"
+     first_param = r"^\w+\s\=\s"
+     later_params = r",\s\w+\s\=\s"
+
+     inp = re.sub(first_param, "", inp)
+     inp = re.sub(later_params, ", ", inp)
+
+     # we add custom code to invoke the solution and delimit its output
+     before_output = "AFTER THIS COMES OUR OWN GENERATED OUTPUT !@#!@!"
+     after_output = "AFTER THIS COMES OUR VERDICT !@#!@!"
+
+     if add_extra_imports:
+         sol = """from collections import *
+ from math import *
+ import math
+ from functools import *
+ from heapq import *
+ import heapq
+ import itertools
+ from itertools import *
+ import bisect
+ from bisect import *
+ """
+     else:
+         sol = ""
+
+     sol += f"""from typing import List, Tuple, Optional
+ {candidate_solution}
+ sfohsdfdsfjhsdkfjhsdkjfh = {class_name}()
+ res = sfohsdfdsfjhsdkfjhsdkjfh.{func_name}({inp})
+
+ def nested_list_convert(inp):
+     try:
+         try:
+             inp = list(inp)
+         except BaseException as e:
+             return inp
+         out = []
+         for i in inp:
+             out.append(nested_list_convert(i))
+     except BaseException as e:
+         return inp
+     return out
+
+ matching = False
+ matching = matching or res == {out}
+ matching = matching or nested_list_convert(res) == {out}
+ matching = matching or nested_list_convert(res) == nested_list_convert({out})
+ matching = matching or str({out}) == str(res).replace("{{", "[").replace("(", "[").replace("}}", "]").replace(")", "]")
+ print("res: ", res)
+ print("out: ", {out})
+ print("{before_output}")
+ print(res)
+ print("{after_output}")
+ print(matching)
+ """
+
+     cmd = "python3"
+
+     proc = Popen([cmd, "-c", sol], stdin=PIPE, stdout=PIPE, stderr=PIPE)
+
+     result_object = {"input": inp, "expected_output": out.strip('"')}
+
+     try:
+         # the pipes are opened in binary mode, so communicate expects bytes
+         stdout, stderr = proc.communicate(b"", timeout=timeout)
+     except TimeoutExpired:
+         if debug:
+             log.info(f"Timeout error, timeout={timeout}")
+         result_object.update({"status": False, "error_message": "Timeout error.", "generated_output": None})
+         return result_object
+     finally:
+         proc.kill()
+
+     stdout = stdout.decode()
+     stderr = stderr.decode().lower()
+
+     if stderr == "":
+         # no compilation or runtime error
+         stderr = None
+     else:
+         # runtime or compilation error (the distinction is made by the presence of "syntaxerror" in the error message)
+         result_object.update(**{"status": False, "error_message": stderr, "generated_output": None})
+         return result_object
+
+     try:
+         generated_output = stdout.split(before_output)[1]
+         generated_output, verdict = generated_output.split(after_output)
+         result_object.update(
+             **{
+                 "status": verdict.strip() == "True",
+                 "error_message": stderr,
+                 "generated_output": generated_output.strip(),
+             }
+         )
+         return result_object
+     except IndexError as e:
+         raise Exception(f"An unexpected error has occurred while parsing the following generated output: {stdout}")
+         # Used in debugging
+         # log.info(e)
+         # result_object.update(
+         #     **{"status": False, "error_message": "The output couldn't be parsed", "generated_output": None}
+         # )
+         # return result_object
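
For context, a hypothetical invocation of this module's entry point; the stub, solution, and test triple below are illustrative:

    stub = "class Solution:\n    def addOne(self, x: int) -> int:"
    solution = (
        "class Solution:\n"
        "    def addOne(self, x: int) -> int:\n"
        "        return x + 1"
    )
    results = evaluate_solution_for_problem(
        candidate_solution=solution,
        python_stub=stub,
        public_tests_io=[("x = 1", "2", "adding one to 1 gives 2")],
    )
    print(results["public_tests_results"][0]["status"])  # True if the subprocess verdict was True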