|
""" NOTE: THIS IS A BETA VERSION OF FUNSEARCH. NEW VERSION DOCUMENTATION WILL BE RELEASED SOON.""" |
|
|
|
|
|
from aiflows.base_flows import AtomicFlow |
|
from .Island import Island,ScoresPerTest |
|
import numpy as np |
|
from typing import Callable,Dict,Union, Any,Optional, List |
|
import time |
|
from aiflows.utils import logging |
|
from .artifacts import AbstractArtifact |
|
from .Program import Program,text_to_artifact |
|
import ast |
|
import os |
|
from aiflows.messages import FlowMessage |
|
# Module-level logger for this flow, named under the "aiflows." namespace.
log = logging.get_logger(f"aiflows.{__name__}")
|
|
|
class ProgramDBFlow(AtomicFlow):
    """This class implements a ProgramDBFlow: a flow that stores programs and their scores in a database of
    islands. It can also query the database for the best program of each island, or generate a prompt containing
    stored programs so that a SamplerFlow can evolve them. This code is an implementation of FunSearch
    (https://www.nature.com/articles/s41586-023-06924-6) and is heavily inspired by the original code
    (https://github.com/google-deepmind/funsearch).

    **Configuration Parameters**:

    - `name` (str): The name of the flow. Default: "ProgramDBFlow"
    - `description` (str): A description of the flow. This description is used to generate the help message of
      the flow. Default: " A flow that saves programs in a database of islands"
    - `artifact_to_evolve_name` (str): The name of the artifact/program to evolve. Default: "solve_function"
    - `evaluate_function` (str): The source of the function used to evaluate the program. It is included in the
      preface of the program template. No default value: this MUST be passed as a parameter.
    - `evaluate_file_full_content` (str): The full content of the file containing the evaluation function. It is
      returned as the `header` of every prompt. No default value: this MUST be passed as a parameter.
    - `num_islands` (int): The number of islands to use. Default: 3
    - `reset_period` (float): Minimum number of seconds between two island resets. Elapsed time is checked each
      time a program is registered. Default: 3600
    - `artifacts_per_prompt` (int): The number of previous artifacts/programs to include in a prompt. Default: 2
    - `temperature` (float): The temperature passed to each island. Default: 0.1
    - `temperature_period` (int): The temperature period passed to each island. Default: 30000
    - `sample_with_replacement` (bool): Whether the islands sample with replacement. Default: False
    - `portion_of_islands_to_reset` (float): The portion of islands (between 0 and 1) reset at each reset.
      Default: 0.5
    - `template` (dict): The template to use for a program. Its `preface` starts the preface of the program
      template shared by all islands. Default: {"preface": ""}

    **Input Interface**:

    - `operation` (str): The operation to perform. It must be one of the following:
      ["register_program","get_prompt","get_best_programs_per_island"]
    - `content` (dict, only needed for "register_program"): The program to register, with keys:
        - `artifact` (str): The program text, parsed with `text_to_artifact`
        - `scores_per_test` (dict): Per-test results. Each value is a dict with a `score` entry.
        - `island_id` (int, optional): The target island. If missing or None, the program is registered in every
          island.

    **Output Interface**:

    - `from` (str): Always "ProgramDBFlow"
    - `retrieved` (Any): The retrieved data. It can be one of the following:
        - If the operation is "get_prompt": False if no program has been registered yet. Otherwise a dictionary
          with the following keys:
            - `code` (str): The code of the prompt
            - `version_generated` (int): The version of the prompt generated
            - `island_id` (int): The id of the island that generated the prompt
            - `header` (str): The header of the prompt (the `evaluate_file_full_content` config value)
        - If the operation is "register_program": the string "Program registered" or "Program failed to register"
        - If the operation is "get_best_programs_per_island": a dictionary with the following key:
            - `best_island_programs` (List[Dict[str,Any]]): One entry per island, ordered from worst to best
              island, with the following keys:
                - `rank` (int): The rank of the program (1 is the best)
                - `score` (float): The score of the program
                - `program` (str): The program (None if the island has none yet)
                - `island_id` (int): The id of the island that generated the program

    **Citation**:

    @Article{FunSearch2023,
      author = {Romera-Paredes, Bernardino and Barekatain, Mohammadamin and Novikov, Alexander and Balog, Matej and Kumar, M. Pawan and Dupont, Emilien and Ruiz, Francisco J. R. and Ellenberg, Jordan and Wang, Pengming and Fawzi, Omar and Kohli, Pushmeet and Fawzi, Alhussein},
      journal = {Nature},
      title = {Mathematical discoveries from program search with large language models},
      year = {2023},
      doi = {10.1038/s41586-023-06924-6}
    }
    """

    def __init__(self, **kwargs):
        """Initialise the flow and cache its configuration values as attributes.

        :raises ValueError: If `portion_of_islands_to_reset` is not within [0, 1].
        """
        super().__init__(**kwargs)

        self.artifact_to_evolve_name: str = self.flow_config["artifact_to_evolve_name"]
        self.artifacts_per_prompt: int = self.flow_config["artifacts_per_prompt"]
        self.temperature: float = self.flow_config["temperature"]
        self.temperature_period: int = self.flow_config["temperature_period"]
        # Collapses the per-test scores of a program into the single score used for ranking.
        self.reduce_score_method: Callable = np.mean
        self.sample_with_replacement: bool = self.flow_config["sample_with_replacement"]
        self.num_islands: int = self.flow_config["num_islands"]
        self.portion_of_islands_to_reset: float = self.flow_config["portion_of_islands_to_reset"]
        self.reset_period: float = self.flow_config["reset_period"]

        self.evaluate_function = self.flow_config["evaluate_function"]
        self.evaluate_file_full_content = self.flow_config["evaluate_file_full_content"]

        # Explicit check rather than `assert`, so the validation is not stripped under `python -O`.
        if not 0.0 <= self.portion_of_islands_to_reset <= 1.0:
            raise ValueError("portion_of_islands_to_reset must be between 0 and 1")

        # Number of lowest-scoring islands replaced on each call to reset_islands().
        self.islands_to_reset = int(round(self.portion_of_islands_to_reset * self.num_islands))

    def _new_island(self) -> Island:
        """Create an empty Island configured from `flow_config` and the shared `self.template`."""
        # NOTE(review): reads self.flow_config instead of the attributes assigned in __init__, presumably
        # because set_up_flow_state() can run from super().__init__() before they exist — confirm.
        return Island(
            artifact_to_evolve_name=self.flow_config["artifact_to_evolve_name"],
            artifacts_per_prompt=self.flow_config["artifacts_per_prompt"],
            temperature=self.flow_config["temperature"],
            temperature_period=self.flow_config["temperature_period"],
            # Same reduction as self.reduce_score_method (set in __init__); keep the two in sync.
            reduce_score_method=np.mean,
            sample_with_replacement=self.flow_config["sample_with_replacement"],
            template=self.template,
        )

    def set_up_flow_state(self):
        """Set up the flow state: the program template, the islands and the per-island best-program records."""
        super().set_up_flow_state()

        preface = (
            self.flow_config["template"]["preface"]
            + "\n\n"
            + "#function used to evaluate the program:\n"
            + self.flow_config["evaluate_function"]
            + "\n\n"
        )
        self.template: Program = Program(preface=preface, artifacts=[])

        num_islands = self.flow_config["num_islands"]
        self.flow_state["islands"] = [self._new_island() for _ in range(num_islands)]
        self.flow_state["last_reset_time"] = time.time()
        self.flow_state["best_score_per_island"] = [float("-inf")] * num_islands
        # Best programs are stored as text (str(program)), not as artifacts.
        self.flow_state["best_program_per_island"] = [None] * num_islands
        self.flow_state["best_scores_per_test_per_island"] = [None] * num_islands
        self.flow_state["first_program_registered"] = False

    def get_prompt(self):
        """Build a prompt from an island chosen uniformly at random.

        :return: A tuple `(code, version_generated, island_id)`, where `island_id` is a plain `int`.
        """
        # int() converts numpy's integer so that island_id serialises like the ids from get_best_programs.
        island_id = int(np.random.choice(len(self.flow_state["islands"])))
        code, version_generated = self.flow_state["islands"][island_id].get_prompt()
        return code, version_generated, island_id

    def reset_islands(self):
        """Reset the worst islands and seed each one with the best program of a surviving island.

        Islands are ranked by their best score, with tiny random noise breaking ties. The lowest
        `self.islands_to_reset` islands are replaced by fresh islands. Each fresh island is then seeded with the best
        program (and its per-test scores) of a randomly chosen kept island, when such a program exists.
        """
        best_scores = np.array(self.flow_state["best_score_per_island"], dtype=float)
        # The 1e-6 noise breaks ties between equally scored islands at random.
        sorted_island_ids = np.argsort(best_scores + np.random.randn(len(best_scores)) * 1e-6)

        reset_island_ids = sorted_island_ids[: self.islands_to_reset]
        keep_island_ids = sorted_island_ids[self.islands_to_reset:]

        for reset_id in reset_island_ids:
            island_id = int(reset_id)
            self.flow_state["islands"][island_id] = self._new_island()
            self.flow_state["best_score_per_island"][island_id] = float("-inf")
            self.flow_state["best_program_per_island"][island_id] = None
            self.flow_state["best_scores_per_test_per_island"][island_id] = None

            if len(keep_island_ids) == 0:
                # Every island was reset (portion_of_islands_to_reset == 1.0): there is no founder to copy.
                continue

            founder_island_id = int(np.random.choice(keep_island_ids))
            founder_text = self.flow_state["best_program_per_island"][founder_island_id]
            founder_scores = self.flow_state["best_scores_per_test_per_island"][founder_island_id]
            if founder_text is None or founder_scores is None:
                # The chosen island has not registered any program yet.
                continue

            # Best programs are stored as str(program). This assumes str(artifact) round-trips through
            # text_to_artifact — TODO confirm.
            self._register_program_in_island(
                program=text_to_artifact(founder_text),
                island_id=island_id,
                scores_per_test=founder_scores,
            )

    def register_program(self, program: AbstractArtifact, island_id: int, scores_per_test: ScoresPerTest) -> bool:
        """Register a program in one island (or in all of them) and reset islands when the reset period elapsed.

        :param program: The program to register
        :type program: AbstractArtifact
        :param island_id: The id of the island to register the program in; None registers it in every island
        :type island_id: int
        :param scores_per_test: The scores per test of the program
        :type scores_per_test: ScoresPerTest
        :return: True if the program was registered, False if it was rejected because
            `program.calls_ancestor(...)` is true for `artifact_to_evolve_name`
        :rtype: bool
        """
        registered = not program.calls_ancestor(artifact_to_evolve=self.artifact_to_evolve_name)

        if registered:
            target_island_ids = range(self.num_islands) if island_id is None else [island_id]
            for target_island_id in target_island_ids:
                self._register_program_in_island(
                    program=program, island_id=target_island_id, scores_per_test=scores_per_test
                )

        # The reset check runs on every call, whether or not the program was accepted.
        if time.time() - self.flow_state["last_reset_time"] > self.reset_period:
            self.reset_islands()
            self.flow_state["last_reset_time"] = time.time()

        return registered

    def _register_program_in_island(self, program: AbstractArtifact, scores_per_test: ScoresPerTest, island_id: Optional[int] = None):
        """Register a program in a single island and update that island's best-program record if it improves.

        :param program: The program to register
        :type program: AbstractArtifact
        :param scores_per_test: The scores per test of the program; each value must contain a "score" entry
        :type scores_per_test: ScoresPerTest
        :param island_id: Index of the target island (required in practice: None is not a valid list index)
        :type island_id: Optional[int]
        """
        self.flow_state["islands"][island_id].register_program(program, scores_per_test)

        test_scores = np.array([test_result["score"] for test_result in scores_per_test.values()])
        score = float(self.reduce_score_method(test_scores))

        if score > self.flow_state["best_score_per_island"][island_id]:
            self.flow_state["best_score_per_island"][island_id] = score
            self.flow_state["best_program_per_island"][island_id] = str(program)
            self.flow_state["best_scores_per_test_per_island"][island_id] = scores_per_test

    def get_best_programs(self) -> Dict[str, List[Dict[str, Any]]]:
        """Return the best program of every island, ordered from the worst to the best island.

        :return: `{"best_island_programs": [...]}`, where each entry has `rank` (1 is the best), `score`,
            `program` and `island_id`.
        """
        best_scores = self.flow_state["best_score_per_island"]
        # argsort is ascending, so the last island in the order is the best one and gets rank 1.
        sorted_island_ids = np.argsort(np.array(best_scores))
        num_ranked = len(sorted_island_ids)
        return {
            "best_island_programs": [
                {
                    "rank": num_ranked - position,
                    "score": best_scores[island_id],
                    "program": self.flow_state["best_program_per_island"][island_id],
                    "island_id": int(island_id),
                }
                for position, island_id in enumerate(sorted_island_ids)
            ]
        }

    def _handle_register_program(self, content: Optional[Dict[str, Any]]) -> str:
        """Register the program described by a "register_program" request and return the status message."""
        try:
            artifact = text_to_artifact(content["artifact"])
            island_id = content.get("island_id", None)
            scores_per_test = content["scores_per_test"]
            if scores_per_test is None:
                # Without scores there is nothing to rank the program by.
                return "Program failed to register"
            if not self.register_program(program=artifact, island_id=island_id, scores_per_test=scores_per_test):
                return "Program failed to register"
        except Exception:
            # Best-effort: a malformed request must not crash the flow, but the cause is logged.
            log.warning("Program failed to register", exc_info=True)
            return "Program failed to register"

        self.flow_state["first_program_registered"] = True
        return "Program registered"

    def run(self, input_message: FlowMessage):
        """Perform the operation requested in the input message and send the reply.

        :param input_message: Message whose data holds `operation` and, for "register_program", `content`
        :type input_message: FlowMessage
        :raises ValueError: If `operation` is not a supported operation.
        """
        input_data = input_message.data
        operation = input_data["operation"]
        # Only "register_program" needs a payload; other operations may omit "content".
        content = input_data.get("content")

        possible_operations = [
            "register_program",
            "get_prompt",
            "get_best_programs_per_island",
        ]
        if operation not in possible_operations:
            raise ValueError(f"operation must be one of the following: {possible_operations}")

        response = {}
        if operation == "get_prompt":
            if not self.flow_state["first_program_registered"]:
                # No program has been registered yet, so the islands cannot build a prompt.
                response["retrieved"] = False
            else:
                code, version_generated, island_id = self.get_prompt()
                response["retrieved"] = {
                    "code": code,
                    "version_generated": version_generated,
                    "island_id": island_id,
                    "header": self.evaluate_file_full_content,
                }
        elif operation == "register_program":
            response["retrieved"] = self._handle_register_program(content)
        else:
            response["retrieved"] = self.get_best_programs()

        response["from"] = "ProgramDBFlow"
        reply = self.package_output_message(input_message, response)
        self.send_message(reply)