Jialin Song committed on
Commit
429bd99
1 Parent(s): 4fcb844

add option to use a subset of test cases

Browse files
Files changed (1) hide show
  1. testing_util.py +14 -2
testing_util.py CHANGED
@@ -48,13 +48,13 @@ class Capturing(list):
48
  sys.stdout = self._stdout
49
 
50
 
51
- def run_test(sample, test=None, debug=False):
52
  """
53
  if test(generated_code) is not None it'll try to run the code.
54
  otherwise it'll just return an input and output pair.
55
  """
56
  # Disable functionalities that can make destructive changes to the test.
57
- reliability_guard()
58
 
59
  if debug:
60
  print(f"start = {datetime.now().time()}")
@@ -158,6 +158,10 @@ def run_test(sample, test=None, debug=False):
158
 
159
  program_outputs = {}
160
  for index, inputs in enumerate(in_outs["inputs"]):
 
 
 
 
161
  # JSON forces dictionaries to have string keys; this undoes this (assuming a singleton list)
162
  try:
163
  if isinstance(inputs[0], dict):
@@ -216,6 +220,14 @@ def run_test(sample, test=None, debug=False):
216
  if debug:
217
  print(f"Standard input runtime error or time limit exceeded error = {e}")
218
  results.append(-1)
 
 
 
 
 
 
 
 
219
  continue
220
  faulthandler.disable()
221
  signal.alarm(0)
 
48
  sys.stdout = self._stdout
49
 
50
 
51
+ def run_test(sample, test=None, debug=False, num_test=-1):
52
  """
53
  if test(generated_code) is not None it'll try to run the code.
54
  otherwise it'll just return an input and output pair.
55
  """
56
  # Disable functionalities that can make destructive changes to the test.
57
+ # reliability_guard()
58
 
59
  if debug:
60
  print(f"start = {datetime.now().time()}")
 
158
 
159
  program_outputs = {}
160
  for index, inputs in enumerate(in_outs["inputs"]):
161
+ # use at most num_test number of tests
162
+ if num_test > 0 and index >= num_test:
163
+ break
164
+
165
  # JSON forces dictionaries to have string keys; this undoes this (assuming a singleton list)
166
  try:
167
  if isinstance(inputs[0], dict):
 
220
  if debug:
221
  print(f"Standard input runtime error or time limit exceeded error = {e}")
222
  results.append(-1)
223
+ program_outputs[index] = {
224
+ "pass": False,
225
+ "pass_pct": 0,
226
+ "pass_res": [0],
227
+ "output": "",
228
+ "input": inputs,
229
+ "ground_truth": in_outs["outputs"][index]
230
+ }
231
  continue
232
  faulthandler.disable()
233
  signal.alarm(0)