bc_eval / execution.py
gabeorlanski's picture
BC eval
1359055 unverified
raw
history blame
3.84 kB
import datetime
import gc
import multiprocessing as mp
import pathlib
import subprocess
from dataclasses import dataclass
from typing import Dict, List, Optional

from tqdm import tqdm
@dataclass
class CommandResult:
    """Outcome of running a single command in a subprocess.

    Attributes:
        return_code: Exit status of the process, or -1 when the command
            timed out before an exit status was recorded.
        runtime: Wall-clock seconds the command took; equal to the timeout
            value when the command timed out.
        stdout: Decoded standard output, or None when the command timed out.
        stderr: Decoded standard error, or None when the command timed out.
        timed_out: True when the command exceeded its timeout.
    """

    return_code: int
    runtime: float
    # stdout/stderr are Optional because a timed-out command yields no
    # captured output.
    stdout: Optional[str]
    stderr: Optional[str]
    timed_out: bool
def safe_execute(
    command_to_run: List[str],
    working_dir: pathlib.Path,
    timeout: int = 10,
) -> CommandResult:
    """Run one command in a subprocess, bounded by a timeout.

    Exceptions raised while starting the process (e.g. ``OSError`` from
    ``subprocess.Popen`` for a missing executable) are not caught here.

    Args:
        command_to_run: The command and its arguments, as an argv list.
        working_dir: The directory to run the command in.
        timeout: Maximum number of seconds to wait for the command.

    Returns:
        The result of executing the command. When the command times out,
        ``timed_out`` is True, ``return_code`` is -1, ``runtime`` equals
        ``timeout`` and ``stdout``/``stderr`` are None.
    """
    timed_out = False
    return_code = -1
    runtime = timeout
    stderr = None
    stdout = None

    start_time = datetime.datetime.now()
    # Using Popen as a context manager guarantees that, on exit, the pipes
    # are closed and the child is waited on, so a timed-out command does not
    # leave a zombie process or open file descriptors behind.
    with subprocess.Popen(
        command_to_run,
        cwd=str(working_dir),
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    ) as execution_process:
        try:
            raw_stdout, raw_stderr = execution_process.communicate(
                timeout=timeout
            )
            # Measure before decoding so the runtime reflects the command
            # itself rather than post-processing of its output.
            runtime = (datetime.datetime.now() - start_time).total_seconds()
            return_code = execution_process.returncode
            # The executed program may emit arbitrary bytes; replace invalid
            # sequences instead of raising UnicodeDecodeError.
            stdout = raw_stdout.decode('utf-8', errors='replace')
            stderr = raw_stderr.decode('utf-8', errors='replace')
        except subprocess.TimeoutExpired:
            timed_out = True
            runtime = timeout
        finally:
            # No-op when the process has already exited; otherwise stops it
            # so the context manager's wait returns promptly.
            execution_process.kill()

    return CommandResult(
        return_code=return_code,
        runtime=runtime,
        stderr=stderr,
        stdout=stdout,
        timed_out=timed_out,
    )
def execute_code(sample: Dict):
    """Run every command of a sample, stopping at the first problem.

    Commands are executed in order inside the sample's working directory.
    Execution stops as soon as one command times out or exits with a
    non-zero return code.

    Args:
        sample: The sample to run. Reads the keys ``cwd`` (a path object),
            ``commands`` (each entry providing ``command`` and ``timeout``),
            ``qid`` and ``idx``.

    Returns:
        A dict with the sample's ``qid`` and ``idx``, the resolved
        ``file_path`` as a string, the list of per-command ``results`` and
        the ``failed`` / ``timed_out`` flags.
    """
    target = sample["cwd"]

    # A file is executed from its containing directory; a directory is used
    # as-is.
    if target.is_file():
        run_dir = target.parent
    else:
        run_dir = target
    run_dir = run_dir.resolve().absolute()

    outcomes = []
    timed_out = failed = False
    for spec in sample['commands']:
        outcome = safe_execute(
            spec['command'],
            working_dir=run_dir,
            timeout=spec['timeout'],
        )
        outcomes.append(outcome)

        # A timeout takes precedence over a non-zero exit status.
        if outcome.timed_out:
            timed_out = True
        elif outcome.return_code != 0:
            failed = True

        if timed_out or failed:
            break

    resolved_path = str(target.absolute().resolve())
    return {
        "qid": sample['qid'],
        "idx": sample["idx"],
        "file_path": resolved_path,
        "results": outcomes,
        "failed": failed,
        "timed_out": timed_out,
    }
def execute_predictions(
    predictions: List[Dict],
    num_workers: int = 1,
    max_task_per_child: int = 1,
    garbage_collection_freq: int = 500,
) -> List[Dict]:
    """Execute a list of predictions in a pool of worker processes.

    Results are collected in completion order, which is not necessarily the
    order of ``predictions``.

    Args:
        predictions: List of predictions; each one is passed to
            ``execute_code``.
        num_workers: The number of worker processes to use.
        max_task_per_child: The maximum tasks ran per child before it is
            replaced by a fresh process.
        garbage_collection_freq: Run garbage collection in the parent after
            every this-many completed predictions. A value of 0 disables
            the periodic collection.

    Returns:
        The list of raw execution results, one per prediction.
    """
    results = []
    with mp.Pool(num_workers, maxtasksperchild=max_task_per_child) as pool:
        # Wrap the lazy result iterator so progress is shown as workers
        # finish.
        progress = tqdm(
            pool.imap_unordered(execute_code, predictions),
            total=len(predictions),
            desc="Executing",
        )
        for num_completed, result in enumerate(progress, start=1):
            results.append(result)
            # Guard against a frequency of 0, which would otherwise raise
            # ZeroDivisionError in the modulo below.
            if (
                garbage_collection_freq
                and num_completed % garbage_collection_freq == 0
            ):
                gc.collect()
    # Leaving the ``with`` block terminates the pool, so no explicit
    # close()/terminate() is needed.
    return results