# bc_eval / execution.py
# gabeorlanski's picture
# Update execution.py
# 589a6ad
import datetime
import gc
import multiprocessing as mp
import pathlib
import subprocess
from dataclasses import dataclass
from typing import Dict, List
from tqdm import tqdm
@dataclass
class CommandResult:
    """Outcome of running one command in a subprocess.

    Attributes:
        return_code: Exit code reported for the command.
        runtime: How long the command ran, in seconds.
        stdout: Text captured from the command's standard output.
        stderr: Text captured from the command's standard error.
        timed_out: Whether the command exceeded its time limit.
    """

    return_code: int
    runtime: float
    stdout: str
    stderr: str
    timed_out: bool
def safe_execute(
    command_to_run: List[str],
    working_dir: pathlib.Path,
    timeout: int = 10,
) -> CommandResult:
    """Run one command in a subprocess, killing it if it exceeds a time limit.

    Args:
        command_to_run: The command to run, as a list of program and arguments.
        working_dir: The working directory to run the command in.
        timeout: How many seconds to wait for the command to finish.

    Returns:
        The result of executing the command:

        * Finished in time: the process's return code, the elapsed seconds
          and its decoded stdout/stderr.
        * Timed out: ``timed_out`` is True, ``return_code`` is -1,
          ``runtime`` equals ``timeout`` and stdout/stderr are empty strings.
        * Could not be run (any other exception): ``return_code`` and
          ``runtime`` are -1, stdout is empty and stderr holds the exception
          message.
    """
    timed_out = False
    return_code = -1
    runtime = timeout
    # Start from empty strings so the result always carries text, even when
    # the command times out before any output has been collected.
    stdout = ""
    stderr = ""
    start_time = datetime.datetime.now()
    try:
        # Leaving the ``with`` block closes the pipes and waits for the child,
        # so a killed process is reaped instead of lingering as a zombie.
        with subprocess.Popen(
            command_to_run,
            cwd=str(working_dir),
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        ) as execution_process:
            try:
                raw_stdout, raw_stderr = execution_process.communicate(
                    timeout=timeout
                )
                runtime = (datetime.datetime.now() - start_time).total_seconds()
                return_code = execution_process.returncode
                # Replace undecodable bytes rather than raising: output that
                # is not valid UTF-8 must not turn a completed run into the
                # "could not be run" result below.
                stdout = raw_stdout.decode("utf-8", errors="replace")
                stderr = raw_stderr.decode("utf-8", errors="replace")
            except subprocess.TimeoutExpired:
                timed_out = True
                runtime = timeout
            finally:
                # Does nothing if the process has already exited.
                execution_process.kill()
    except Exception as e:
        # Best-effort boundary: report the failure (e.g. the executable was
        # not found) through the result instead of raising to the caller.
        stderr = str(e)
        stdout = ""
        return_code = -1
        runtime = -1
        timed_out = False
    return CommandResult(
        return_code=return_code,
        runtime=runtime,
        stderr=stderr,
        stdout=stdout,
        timed_out=timed_out,
    )
def execute_code(sample: Dict):
    """Run every command of a sample in order, stopping at the first problem.

    Args:
        sample: The sample to run. It is read for the keys "cwd" (a path
            object: a file, whose parent directory is used, or a directory),
            "commands" (a sequence of mappings with "command" and "timeout"),
            "qid" and "idx".

    Returns:
        A dict with the sample's "qid" and "idx", the resolved "file_path",
        the list of per-command "results", and the flags "failed" and
        "timed_out". Commands after the first timeout or non-zero return code
        are not run.
    """
    target_path = sample["cwd"]
    # A file is executed from the directory that contains it.
    run_dir = target_path if not target_path.is_file() else target_path.parent
    run_dir = run_dir.resolve().absolute()

    outcomes = []
    hit_timeout = False
    hit_failure = False
    for spec in sample['commands']:
        outcome = safe_execute(
            spec['command'],
            working_dir=run_dir,
            timeout=spec['timeout'],
        )
        outcomes.append(outcome)
        if outcome.timed_out:
            hit_timeout = True
        elif outcome.return_code != 0:
            hit_failure = True
        else:
            # Command succeeded; move on to the next one.
            continue
        break

    return {
        "qid": sample['qid'],
        "idx": sample["idx"],
        "file_path": str(target_path.absolute().resolve()),
        "results": outcomes,
        "failed": hit_failure,
        "timed_out": hit_timeout,
    }
def execute_predictions(
    predictions: List[Dict],
    num_workers: int = 1,
    max_task_per_child: int = 1,
    garbage_collection_freq: int = 500,
) -> List[Dict]:
    """Execute a list of predictions in a pool of worker processes.

    Args:
        predictions: List of predictions; each one is passed to
            ``execute_code``.
        num_workers: The number of worker processes to use.
        max_task_per_child: The maximum number of tasks a worker process runs
            before it is replaced by a fresh one.
        garbage_collection_freq: Run garbage collection in this process each
            time this many predictions have completed. 0 disables the
            explicit collection.

    Returns:
        The list of raw execution results, in the order in which they
        completed (which may differ from the order of ``predictions``).
    """
    results = []
    with mp.Pool(num_workers, maxtasksperchild=max_task_per_child) as pool:
        # imap_unordered yields each result as soon as it is ready, which is
        # what lets the progress bar advance while work is still running.
        completed = tqdm(
            pool.imap_unordered(execute_code, predictions),
            total=len(predictions),
            desc="Executing",
        )
        for num_completed, result in enumerate(completed, start=1):
            results.append(result)
            # The truthiness check avoids a ZeroDivisionError when the
            # frequency is 0.
            if (
                garbage_collection_freq
                and num_completed % garbage_collection_freq == 0
            ):
                gc.collect()
        # Every result has been consumed; leaving the ``with`` block
        # terminates the pool, so no explicit close()/terminate() is needed.
    return results