import itertools
import json
import os
import sys
from collections import Counter, defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path

import fire
import jsonlines
import numpy as np
import tqdm

# Make the sibling packages importable (sys.path entries should be strings).
sys.path.extend(
    [
        str(Path(__file__).parent.parent),
        str(Path(__file__).parent.parent / "execution_engine"),
    ]
)

from api_comm import APICommunication
from exec_outcome import ExecOutcome
from yaml import safe_load


def estimate_pass_at_k(
    num_samples: int | list[int] | np.ndarray,
    num_correct: list[int] | np.ndarray,
    k: int,
) -> np.ndarray:
    """
    Estimates pass@k of each problem and returns them in an array.
    """

    def estimator(n: int, c: int, k: int) -> float:
        """
        Calculates 1 - comb(n - c, k) / comb(n, k).
        """
        if n - c < k:
            return 1.0
        return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))

    if isinstance(num_samples, int):
        num_samples_it = itertools.repeat(num_samples, len(num_correct))
    else:
        assert len(num_samples) == len(num_correct)
        num_samples_it = iter(num_samples)

    return np.array(
        [estimator(int(n), int(c), k) for n, c in zip(num_samples_it, num_correct)]
    )
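

# A quick sanity check of the unbiased estimator above (the numbers are purely
# illustrative, not from any benchmark): with n = 5 samples per problem and
# c = 2 correct ones, pass@1 = 1 - C(3, 1) / C(5, 1) = 0.4, while pass@5 = 1.0
# because any draw of 5 samples is guaranteed to contain a correct one.
#
# estimate_pass_at_k(num_samples=5, num_correct=[2], k=1)  # -> array([0.4])
# estimate_pass_at_k(num_samples=5, num_correct=[2], k=5)  # -> array([1.])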


def evaluate_functional_correctness(
    sample_file: str,
    k: list[int] = [1, 10, 100],
    n_workers: int = 4,
    limits_by_lang: dict = {},
    compile_n_execute_args_by_lang: dict = {},
    eval_result_file: str | None = None,
    unittest_file: str = "unittest_db.json",
    execeval_url: str = "http://localhost:5000",
    block_network: bool = True,
    stop_on_first_fail: bool = True,
    use_sanitizer: bool = False,
):
| """ | |
| Evaluates the functional correctness of generated samples, and writes | |
| results to f"{sample_file}_results.jsonl.gz" | |
| """ | |
| if eval_result_file is None: | |
| eval_result_file = f"{sample_file.split('.')[0]}-evaluated.jsonl" | |
    with open(unittest_file) as ut_rp:
        unittest_db = json.load(ut_rp)

    # Check the generated samples against the test suites.
    with APICommunication(execeval_url) as execeval:
        execute_code = execeval.execute_code
        supported_langs = {r["runtime_name"] for r in execeval.get_runtimes()}
        with ThreadPoolExecutor(max_workers=n_workers) as executor:
            futures = []
            completion_id = Counter()
            n_samples = 0
            results = defaultdict(list)

            with jsonlines.open(sample_file) as sample_rp:
                for idx, sample in tqdm.tqdm(
                    enumerate(sample_rp), desc="Reading samples"
                ):
                    src_uid = sample["src_uid"]
                    source_code = sample["source_code"]
                    task_id = sample["task_id"]
                    lang = sample["lang"]
                    # Skip samples without a test suite or with an unsupported runtime.
                    if src_uid not in unittest_db:
                        continue
                    unittests = unittest_db[src_uid]
                    if len(unittests) == 0:
                        continue
                    if lang not in supported_langs:
                        continue
                    args = (
                        lang,
                        source_code,
                        unittests,
                        limits_by_lang[lang],
                        block_network,
                        stop_on_first_fail,
                        use_sanitizer,
                        compile_n_execute_args_by_lang.get(lang, {}).get("compile_cmd"),
                        compile_n_execute_args_by_lang.get(lang, {}).get(
                            "compile_flags"
                        ),
                        compile_n_execute_args_by_lang.get(lang, {}).get("execute_cmd"),
                        compile_n_execute_args_by_lang.get(lang, {}).get(
                            "execute_flags"
                        ),
                        idx,
                        task_id,
                    )
                    future = executor.submit(execute_code, *args)
                    futures.append(future)
                    completion_id[task_id] += 1
                    n_samples += 1
| print("Running test suites...") | |
| for idx, future in tqdm.tqdm( | |
| enumerate(as_completed(futures)), | |
| desc="Test running", | |
| total=len(futures), | |
| ): | |
| result = future.result() | |
| unittests, sample_idx, task_id = result | |
| if not isinstance(unittests, list) and "error" in unittests: | |
| """ | |
| [TODO] log it | |
| """ | |
| print("ERROR: ", unittests["error"]) | |
| continue | |
| results[task_id].append((sample_idx, unittests)) | |
| print("Calculate pass@k.") | |
| total, correct = [], [] | |
| for result in results.values(): | |
| result.sort() | |
| passed = [ | |
| all(x["exec_outcome"] == ExecOutcome.PASSED.value for x in r[1]) | |
| for r in result | |
| ] | |
| total.append(len(passed)) | |
| correct.append(sum(passed)) | |
| total = np.array(total) | |
| correct = np.array(correct) | |
| ks = k | |
| pass_at_k = { | |
| f"pass@{k}": estimate_pass_at_k(total, correct, k).mean() | |
| for k in ks | |
| if (total >= k).all() | |
| } | |

    # Finally, save the results in one file:
    def combine_results():
        with jsonlines.open(sample_file) as sample_rp:
            cnt = 0
            for idx, sample in enumerate(sample_rp):
                cnt += 1
                if sample["lang"] not in supported_langs:
                    continue
                task_id = sample["task_id"]
                if len(results[task_id]) == 0:
                    continue
                # Results for a task are sorted by sample index; skip samples
                # that never received a result.
                if results[task_id][0][0] > idx:
                    continue
                result = results[task_id].pop(0)
                sample["unittests"] = result[1]
                # The sample's outcome is its first non-PASSED unit-test outcome,
                # or PASSED if every unit test passed.
                _exec_outcomes = [
                    r["exec_outcome"]
                    for r in result[1]
                    if r["exec_outcome"] != ExecOutcome.PASSED.value
                ] + [ExecOutcome.PASSED.value]
                sample["exec_outcome"] = _exec_outcomes[0]
                yield sample
| print(f"Writing results to {eval_result_file}...") | |
| with jsonlines.open(eval_result_file, "w") as result_wp: | |
| for result in tqdm.tqdm(combine_results(), total=n_samples): | |
| result_wp.write(result) | |
| return pass_at_k | |
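

# A minimal programmatic use of evaluate_functional_correctness, assuming an
# ExecEval server is already running at the default URL and that the file names
# below (placeholders, not part of this repo) exist on disk:
#
# with open("limits_by_lang.yaml") as fp:
#     limits = safe_load(fp)
# scores = evaluate_functional_correctness(
#     "samples.jsonl",
#     k=[1, 5],
#     limits_by_lang=limits,
# )
# print(scores)  # e.g. {"pass@1": ..., "pass@5": ...}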


def entry_point(
    sample_file: str,
    k: str | list | tuple = "1,2,5,10",
    n_workers: int = 4,
    compile_n_execute_args_by_lang_cfg_file: str | None = None,
    limits_by_lang_cfg_file: str | None = None,
    unittest_file: str = "unittest_db.json",
    execeval_url: str = "http://localhost:5000",
    block_network: bool = True,
    stop_on_first_fail: bool = True,
    use_sanitizer: bool = False,
):
    """
    Evaluates the functional correctness of generated samples and writes the
    per-sample results to f"{sample_file.split('.')[0]}-evaluated.jsonl".
    """
    # TODO: compile_n_execute_args_by_lang_cfg_file and limits_by_lang_cfg_file
    # are assumed to be YAML files; consider config.yaml for the compile/execute
    # args and resource_limits.py for limits_by_lang.
    limits_by_lang, compile_n_execute_args_by_lang = None, {}
    if limits_by_lang_cfg_file is None:
        limits_by_lang_cfg_file = "limits_by_lang.yaml"
    if not os.path.exists(limits_by_lang_cfg_file):
        print(
            "Need resource limit defaults for all runtimes; provide the path to "
            "the default 'limits_by_lang.yaml' or to a modified copy."
        )
        exit(-1)
    with open(limits_by_lang_cfg_file) as limit_cfg_rp:
        limits_by_lang = safe_load(limit_cfg_rp)

    if compile_n_execute_args_by_lang_cfg_file is not None and os.path.exists(
        compile_n_execute_args_by_lang_cfg_file
    ):
        with open(
            compile_n_execute_args_by_lang_cfg_file
        ) as compile_n_execute_args_by_lang_rp:
            compile_n_execute_args_by_lang = safe_load(
                compile_n_execute_args_by_lang_rp
            )

    ks = list(map(int, k.split(","))) if isinstance(k, str) else list(k)
    results = evaluate_functional_correctness(
        sample_file,
        ks,
        n_workers,
        block_network=block_network,
        limits_by_lang=limits_by_lang,
        compile_n_execute_args_by_lang=compile_n_execute_args_by_lang,
        unittest_file=unittest_file,
        execeval_url=execeval_url,
        stop_on_first_fail=stop_on_first_fail,
        use_sanitizer=use_sanitizer,
    )
    print(results)


def main():
    fire.Fire(entry_point)


if __name__ == "__main__":
    sys.exit(main())
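

# Example CLI invocation (the script and sample file names are placeholders;
# python-fire derives the --flags from entry_point's parameter names):
#
#   python evaluate.py \
#       --sample_file samples.jsonl \
#       --k "1,5" \
#       --unittest_file unittest_db.json \
#       --execeval_url http://localhost:5000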