""" NOTE: THIS IS A BETA VERSION OF FUNSEARCH. NEW VERSION DOCUMENTATION WILL BE RELEASED SOON."""

import ast
import os
import signal
import threading
from typing import Any, Dict

from aiflows.base_flows import AtomicFlow
from aiflows.interfaces.key_interface import KeyInterface
from aiflows.messages import FlowMessage
from aiflows.utils import logging

log = logging.get_logger(f"aiflows.{__name__}")


class TimeoutException(Exception):
    pass


def timeout_handler(signum, frame):
    """Signal-based timeout handler. It is not registered in this file;
    `run_function_with_timeout` below relies on a thread-based timeout instead."""
    raise TimeoutException("Execution timed out")


class EvaluatorFlow(AtomicFlow):
    """This class implements an EvaluatorFlow: a flow that evaluates a program (Python code) using a given
    evaluation function. This code is an implementation of FunSearch (https://www.nature.com/articles/s41586-023-06924-6)
    and is heavily inspired by the original code (https://github.com/google-deepmind/funsearch).

    **Configuration Parameters**:

    - `name` (str): The name of the flow. Default: "EvaluatorFlow"
    - `description` (str): A description of the flow, used to generate its help message. Default: "A flow that evaluates code on tests"
    - `py_file` (str): The Python code containing the evaluation function. No default value; this MUST be passed as a parameter.
    - `function_to_run_name` (str): The name of the evaluation function defined in `py_file`. No default value; this MUST be passed as a parameter.
    - `test_inputs` (Dict[str, Any]): A dictionary of test inputs on which to evaluate the program. Default: {"test1": None, "test2": None}
    - `timeout_seconds` (int): The maximum number of seconds the evaluation function may run before an error is returned. Default: 10
    - `run_error_score` (int): The score to return if the evaluation function fails to run. Default: -100
    - `use_test_input_as_key` (bool): Whether to use the test input parameters (rather than the test name) as the key in the output dictionary. Default: False

    **Input Interface**:

    - `artifact` (str): The program/artifact to evaluate.

    **Output Interface**:

    - `scores_per_test` (Dict[str, Dict[str, Any]]): A dictionary of scores per test input.

    **Citation**:

    @Article{FunSearch2023,
        author  = {Romera-Paredes, Bernardino and Barekatain, Mohammadamin and Novikov, Alexander and Balog, Matej and Kumar, M. Pawan and Dupont, Emilien and Ruiz, Francisco J. R. and Ellenberg, Jordan and Wang, Pengming and Fawzi, Omar and Kohli, Pushmeet and Fawzi, Alhussein},
        journal = {Nature},
        title   = {Mathematical discoveries from program search with large language models},
        year    = {2023},
        doi     = {10.1038/s41586-023-06924-6}
    }
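
    **Example** (illustrative sketch, not prescribed by the flow): `py_file` must define the function named by
    `function_to_run_name`. That function receives the candidate program as a string, plus the keyword arguments
    of one `test_inputs` entry, and returns a numeric score (returning None is treated as a failed run). The
    names `evaluate` and `n` below are hypothetical::

        def evaluate(program: str, n: int = 10) -> float:
            namespace = {}
            exec(program, namespace)  # the candidate program is expected to define a callable, e.g. `solve`
            return float(namespace["solve"](n))

    used, for instance, with `function_to_run_name: "evaluate"` and `test_inputs: {"test1": {"n": 10}}`.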
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.evaluator_py_file = self.flow_config["py_file"]
        self.run_error_score = self.flow_config["run_error_score"]

        # Execute the evaluator code and look up the evaluation function in the resulting namespace.
        self.local_namespace = {}
        self.load_functions()
        self.function_to_run_name = self.flow_config["function_to_run_name"]
        assert (
            self.function_to_run_name in self.local_namespace
        ), f"Function '{self.function_to_run_name}' not found in the evaluator code passed via `py_file`"
        self.function_to_run = self.local_namespace.get(self.function_to_run_name)

        self.test_inputs = self.flow_config["test_inputs"]
        self.timeout_seconds = self.flow_config["timeout_seconds"]

        # Clearing the attribute is safe: the loaded function keeps its own reference
        # to the exec'd globals, so it remains callable.
        self.local_namespace = {}

        # Ensure an `island_id` key is present (defaulting to None) in the data dictionary.
        select_island_id_with_default = lambda data_dict, **kwargs: {
            **data_dict,
            "island_id": data_dict.get("island_id", None),
        }

        self.output_interface = KeyInterface(
            additional_transformations=[select_island_id_with_default],
            keys_to_select=["scores_per_test"],
        )

    def load_functions(self):
        """Load the functions defined in the evaluator py file (`py_file`) into `self.local_namespace` using ast parsing."""
        file_content = self.evaluator_py_file
        try:
            parsed_ast = ast.parse(file_content)

            # Execute the import statements individually first, so that a missing
            # dependency surfaces as a clear error before the rest of the file runs.
            for node in parsed_ast.body:
                if isinstance(node, (ast.Import, ast.ImportFrom)):
                    exec(compile(ast.Module(body=[node], type_ignores=[]), "<ast>", "exec"), self.local_namespace)

            # Execute the whole file so that all top-level definitions land in the namespace.
            exec(file_content, self.local_namespace)
        except Exception as e:
            log.error(f"Error loading functions: {e}")
            raise e
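
    # Illustrative example (hypothetical `py_file` contents): if the evaluator code is
    #
    #     import math
    #
    #     def evaluate(program, n=10):
    #         ...
    #
    # then after `load_functions` both `math` and `evaluate` live in `self.local_namespace`,
    # and `evaluate` is the callable picked up when `function_to_run_name == "evaluate"`.
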
    def run_function_with_timeout(self, program: str, **kwargs):
        """Run the evaluation function with a timeout.

        :param program: The program to evaluate
        :type program: str
        :param kwargs: The keyword arguments to pass to the evaluation function
        :type kwargs: Dict[str, Any]
        :return: A tuple (bool, result) where bool is True if the function ran successfully and result is the output of the function
        :rtype: Tuple[bool, Any]
        """
        self.result = None
        self.exception = None

        def target():
            # Run the evaluation function in a worker thread and record its outcome.
            try:
                result = self.function_to_run(program, **kwargs)
                self.result = result
            except Exception as e:
                self.exception = e

        # Daemon thread: if the evaluation hangs, it will not keep the process alive.
        thread = threading.Thread(target=target, daemon=True)
        thread.start()

        # Wait at most `timeout_seconds` for the evaluation to finish.
        thread.join(self.timeout_seconds)

        if thread.is_alive():
            # Python threads cannot be forcibly terminated; report the timeout and
            # let the orphaned daemon thread finish (or hang) in the background.
            return False, f"Function execution timed out after {self.timeout_seconds} seconds"

        if self.exception is not None:
            return False, str(self.exception)

        return True, self.result
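
    # Example return values (illustrative, hypothetical scores and errors):
    #   (True, 0.87)                                              -> evaluation succeeded with score 0.87
    #   (False, "Function execution timed out after 10 seconds")  -> timed out (default timeout_seconds: 10)
    #   (False, "name 'solve' is not defined")                    -> the evaluation function raised an exception
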
    def evaluate_program(self, program: str, **kwargs):
        """Evaluate the program using the evaluation function.

        :param program: The program to evaluate
        :type program: str
        :param kwargs: The keyword arguments to pass to the evaluation function
        :type kwargs: Dict[str, Any]
        :return: A tuple (bool, result) where bool is True if the function ran successfully and result is the output of the function
        :rtype: Tuple[bool, Any]
        """
        try:
            runs_ok, test_output = self.run_function_with_timeout(program, **kwargs)
            return runs_ok, test_output
        except Exception as e:
            log.debug(f"Error while running program: {e} (could be due to a syntax error from the LLM)")
            return False, e

    def analyse(self, program: str):
        """Analyse the program on the test inputs.

        :param program: The program to evaluate
        :type program: str
        :return: A dictionary of scores per test input
        :rtype: Dict[str, Dict[str, Any]]
        """
        # Strip a surrounding markdown code fence, if the LLM returned one.
        if program.startswith("```python"):
            program = program[len("```python"):]
        if program.endswith("```"):
            program = program[:-3]

        scores_per_test = {}
        for key, test_input in self.test_inputs.items():
            test_input_key = str(test_input) if self.flow_config["use_test_input_as_key"] else key

            # A None test input means the evaluation function is called without extra arguments.
            if test_input is None:
                runs_ok, test_output = self.evaluate_program(program)
            else:
                runs_ok, test_output = self.evaluate_program(program, **test_input)

            if runs_ok and test_output is not None:
                scores_per_test[test_input_key] = {"score": test_output, "feedback": "No feedback available."}
                log.debug(f"Program ran successfully for test case {test_input_key} with score: {test_output}")
            else:
                log.debug(
                    f"Error running program for test case {test_input_key}. Error: {test_output} "
                    "(could be due to a syntax error from the LLM)"
                )
                scores_per_test[test_input_key] = {"score": self.run_error_score, "feedback": str(test_output)}

        return scores_per_test
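
    # Illustrative result (hypothetical scores): with the default test_inputs {"test1": None, "test2": None},
    # `analyse` could return something like
    #   {
    #       "test1": {"score": 0.87, "feedback": "No feedback available."},
    #       "test2": {"score": -100, "feedback": "Function execution timed out after 10 seconds"},
    #   }
    # where -100 is the configured `run_error_score`.
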
    def run(self, input_message: FlowMessage):
        """Run the flow. This is the main method of the flow.

        :param input_message: The input message
        :type input_message: FlowMessage
        """
        input_data = input_message.data

        # Score the submitted program/artifact on every configured test input.
        scores_per_test = self.analyse(input_data["artifact"])

        response = {"scores_per_test": scores_per_test, "from": "EvaluatorFlow"}

        # Package the response and send it back to the caller.
        reply = self.package_output_message(
            input_message,
            response,
        )
        self.send_message(reply)
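

# Illustrative data flow through `run` (hypothetical values):
#   input_message.data -> {"artifact": "def solve(n):\n    return n * 2", ...}
#   reply data         -> {"scores_per_test": {"test1": {"score": 20, "feedback": "No feedback available."}},
#                          "from": "EvaluatorFlow"}
# The reply is built with `package_output_message` and returned to the caller via `send_message`.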
|