nbaldwin's picture
first version FunSearch
97e363b
raw
history blame
14.3 kB
""" NOTE: THIS IS A BETA VERSION OF FUNSEARCH. NEW VERSION DOCUMENTATION WILL BE RELEASED SOON."""
from aiflows.base_flows import AtomicFlow
from .Island import Island,ScoresPerTest
import numpy as np
from typing import Callable,Dict,Union, Any,Optional, List
import time
from aiflows.utils import logging
from .artifacts import AbstractArtifact
from .Program import Program,text_to_artifact
import ast
import os
from aiflows.messages import FlowMessage
# Module-level logger, named after this module under the "aiflows." namespace.
log = logging.get_logger(f"aiflows.{__name__}")
class ProgramDBFlow(AtomicFlow):
    """This class implements a ProgramDBFlow. It's a flow that stores programs and their scores in a database of islands.
    It can also query the database for the best programs or generate a prompt containing stored programs in order to
    evolve them with a SamplerFlow. This code is an implementation of FunSearch
    (https://www.nature.com/articles/s41586-023-06924-6) and is heavily inspired by the original code
    (https://github.com/google-deepmind/funsearch).

    **Configuration Parameters**:

    - `name` (str): The name of the flow. Default: "ProgramDBFlow"
    - `description` (str): A description of the flow. This description is used to generate the help message of the flow.
      Default: " A flow that saves programs in a database of islands"
    - `artifact_to_evolve_name` (str): The name of the artifact/program to evolve. Default: "solve_function"
    - `evaluate_function` (str): The function used to evaluate the program. It is appended to the preface of every
      prompt template. No Default value. This MUST be passed as a parameter.
    - `evaluate_file_full_content` (str): The full content of the file containing the evaluation function. It is
      returned as the `header` of every prompt. No Default value. This MUST be passed as a parameter.
    - `num_islands`: The number of islands to use. Default: 3
    - `reset_period`: The period in seconds between two resets of the worst islands. Default: 3600
    - `artifacts_per_prompt`: The number of previous artifacts/programs to include in a prompt. Default: 2
    - `temperature`: The sampling temperature passed to each island. Default: 0.1
    - `temperature_period`: The temperature period passed to each island. Default: 30000
    - `sample_with_replacement`: Whether islands sample with replacement. Default: False
    - `portion_of_islands_to_reset`: The portion of islands (between 0 and 1) reset on each reset. Default: 0.5
    - `template` (dict): The template to use for a program; its `preface` entry starts the prompt preface.
      Default: {"preface": ""}

    **Input Interface**:

    - `operation` (str): The operation to perform. It can be one of the following:
      ["register_program", "get_prompt", "get_best_programs_per_island"]
    - `content` (dict, optional): The payload of the operation. Only read by "register_program", where it must contain:
        - `artifact` (str): The text of the program to register.
        - `scores_per_test` (ScoresPerTest): The scores per test of the program. If None, the program is not registered.
        - `island_id` (int, optional): The island to register the program in. If missing or None, the program is
          registered in every island.

    **Output Interface**:

    - `retrieved` (Any): The retrieved data. It can be one of the following:
        - If the operation is "get_prompt": False if no program has been registered yet, otherwise a dictionary with
          the following keys:
            - `code` (str): The code of the prompt
            - `version_generated` (int): The version of the prompt generated
            - `island_id` (int): The id of the island that generated the prompt
            - `header` (str): The header of the prompt
        - If the operation is "register_program": the string "Program registered" if the program was stored, otherwise
          "Program failed to register" (parsing error, missing scores, or the program calls an ancestor of the
          artifact to evolve).
        - If the operation is "get_best_programs_per_island": a dictionary with the following keys:
            - `best_island_programs` (List[Dict[str,Any]]): A list of dictionaries with the following keys:
                - `rank` (int): The rank of the program (1 is the best)
                - `score` (float): The score of the program
                - `program` (str): The program
                - `island_id` (int): The id of the island that generated the program
    - `from` (str): Always "ProgramDBFlow".

    **Citation**:

    @Article{FunSearch2023,
      author = {Romera-Paredes, Bernardino and Barekatain, Mohammadamin and Novikov, Alexander and Balog, Matej and Kumar, M. Pawan and Dupont, Emilien and Ruiz, Francisco J. R. and Ellenberg, Jordan and Wang, Pengming and Fawzi, Omar and Kohli, Pushmeet and Fawzi, Alhussein},
      journal = {Nature},
      title = {Mathematical discoveries from program search with large language models},
      year = {2023},
      doi = {10.1038/s41586-023-06924-6}
    }
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Unpack config (for clarity)
        self.artifact_to_evolve_name: str = self.flow_config["artifact_to_evolve_name"]
        self.artifacts_per_prompt: int = self.flow_config["artifacts_per_prompt"]
        self.temperature: float = self.flow_config["temperature"]
        self.temperature_period: int = self.flow_config["temperature_period"]
        self.reduce_score_method: Callable = np.mean
        self.sample_with_replacement: bool = self.flow_config["sample_with_replacement"]
        self.num_islands: int = self.flow_config["num_islands"]
        self.portion_of_islands_to_reset: float = self.flow_config["portion_of_islands_to_reset"]
        self.reset_period: float = self.flow_config["reset_period"]
        self.evaluate_function = self.flow_config["evaluate_function"]
        self.evaluate_file_full_content = self.flow_config["evaluate_file_full_content"]
        # Raise instead of assert: asserts are stripped when Python runs with -O.
        if not 0.0 <= self.portion_of_islands_to_reset <= 1.0:
            raise ValueError("portion_of_islands_to_reset must be between 0 and 1")
        # Number of (worst) islands wiped on each reset, rounded to the nearest integer.
        self.islands_to_reset = int(round(self.portion_of_islands_to_reset * self.num_islands))

    def set_up_flow_state(self):
        """Set up the state of the flow: build the shared program template and a fresh set of empty islands.

        Reads ``self.flow_config`` directly instead of the attributes unpacked in ``__init__``.
        NOTE(review): presumably because the base class may call this before ``__init__`` finishes unpacking
        the config -- confirm against AtomicFlow.
        """
        super().set_up_flow_state()
        preface = (
            self.flow_config["template"]["preface"]
            + "\n\n#function used to evaluate the program:\n"
            + self.flow_config["evaluate_function"]
            + "\n\n"
        )
        self.template: Program = Program(preface=preface, artifacts=[])

        num_islands = self.flow_config["num_islands"]
        # ~~~instantiate islands~~~
        self.flow_state["islands"] = [self._new_island() for _ in range(num_islands)]
        self.flow_state["last_reset_time"] = time.time()
        self.flow_state["best_score_per_island"] = [float("-inf")] * num_islands
        # Text of each island's best program (what get_best_programs reports).
        self.flow_state["best_program_per_island"] = [None] * num_islands
        # Artifact object of each island's best program (used to reseed islands on reset).
        self.flow_state["best_artifact_per_island"] = [None] * num_islands
        self.flow_state["best_scores_per_test_per_island"] = [None] * num_islands
        self.flow_state["first_program_registered"] = False

    def _new_island(self) -> Island:
        """Create an empty island configured from ``flow_config`` and the shared ``self.template``."""
        return Island(
            artifact_to_evolve_name=self.flow_config["artifact_to_evolve_name"],
            artifacts_per_prompt=self.flow_config["artifacts_per_prompt"],
            temperature=self.flow_config["temperature"],
            temperature_period=self.flow_config["temperature_period"],
            reduce_score_method=np.mean,
            sample_with_replacement=self.flow_config["sample_with_replacement"],
            template=self.template,
        )

    def get_prompt(self):
        """Get a prompt from a uniformly chosen island.

        :return: The code of the prompt, the version generated and the (plain ``int``) island id
        :rtype: Tuple[Any, Any, int]
        """
        # Cast to int so the id is a plain Python int (np.random.choice returns a NumPy integer).
        island_id = int(np.random.choice(len(self.flow_state["islands"])))
        code, version_generated = self.flow_state["islands"][island_id].get_prompt()
        return code, version_generated, island_id

    def reset_islands(self):
        """Reset the worst islands and reseed each of them with the best program of a random surviving island.

        Ties between equally scored islands are broken with tiny gaussian noise. Only surviving islands that hold a
        best program can act as founders; if there is none (e.g. ``portion_of_islands_to_reset == 1``), the reset
        islands are left empty.
        """
        best_scores = np.array(self.flow_state["best_score_per_island"])
        # gaussian noise to break ties
        sorted_island_ids = np.argsort(best_scores + np.random.randn(len(best_scores)) * 1e-6)
        reset_island_ids = [int(i) for i in sorted_island_ids[: self.islands_to_reset]]
        keep_island_ids = [int(i) for i in sorted_island_ids[self.islands_to_reset:]]
        # Reset and kept islands are disjoint, so the candidate list does not change inside the loop below.
        founder_candidates = [
            i for i in keep_island_ids if self.flow_state["best_artifact_per_island"][i] is not None
        ]

        for island_id in reset_island_ids:
            self.flow_state["islands"][island_id] = self._new_island()
            self.flow_state["best_score_per_island"][island_id] = float("-inf")
            self.flow_state["best_program_per_island"][island_id] = None
            self.flow_state["best_artifact_per_island"][island_id] = None
            self.flow_state["best_scores_per_test_per_island"][island_id] = None
            if not founder_candidates:
                log.warning(f"No surviving island holds a program; island {island_id} is reset empty.")
                continue
            founder_island_id = int(np.random.choice(founder_candidates))
            founder = self.flow_state["best_artifact_per_island"][founder_island_id]
            founder_scores = self.flow_state["best_scores_per_test_per_island"][founder_island_id]
            self._register_program_in_island(program=founder, island_id=island_id, scores_per_test=founder_scores)

        # If every island is now empty, prompts cannot be generated until a new program is registered.
        if all(artifact is None for artifact in self.flow_state["best_artifact_per_island"]):
            self.flow_state["first_program_registered"] = False

    def register_program(self, program: AbstractArtifact, island_id: Optional[int], scores_per_test: ScoresPerTest) -> bool:
        """Register a program in one island (or in all islands) and reset islands if the reset period elapsed.

        Programs that call an ancestor of the artifact to evolve are not registered.

        :param program: The program to register
        :type program: AbstractArtifact
        :param island_id: The id of the island to register the program in; None registers it in every island
        :type island_id: Optional[int]
        :param scores_per_test: The scores per test of the program
        :type scores_per_test: ScoresPerTest
        :return: True if the program was registered, False if it was skipped
        :rtype: bool
        """
        if program.calls_ancestor(artifact_to_evolve=self.artifact_to_evolve_name):
            return False

        # island_id is None for a program added at the beginning, so add it to all islands
        target_island_ids = range(self.num_islands) if island_id is None else [island_id]
        for target_island_id in target_island_ids:
            self._register_program_in_island(program=program, island_id=target_island_id, scores_per_test=scores_per_test)

        # reset islands if needed
        if time.time() - self.flow_state["last_reset_time"] > self.reset_period:
            self.reset_islands()
            self.flow_state["last_reset_time"] = time.time()
        return True

    def _register_program_in_island(self, program: AbstractArtifact, scores_per_test: ScoresPerTest, island_id: Optional[int] = None):
        """Register a program in a single island and update that island's best program if needed.

        :param program: The program to register
        :type program: AbstractArtifact
        :param scores_per_test: The scores per test of the program; each value must hold a "score" entry
        :type scores_per_test: ScoresPerTest
        :param island_id: The id of the island to register the program in (required despite the default)
        :type island_id: Optional[int]
        :raises ValueError: If island_id is None
        """
        if island_id is None:
            raise ValueError("island_id must be provided to register a program in an island")
        self.flow_state["islands"][island_id].register_program(program, scores_per_test)
        test_scores = np.array([test_result["score"] for test_result in scores_per_test.values()])
        score = self.reduce_score_method(test_scores)
        if score > self.flow_state["best_score_per_island"][island_id]:
            self.flow_state["best_score_per_island"][island_id] = score
            self.flow_state["best_program_per_island"][island_id] = str(program)
            self.flow_state["best_artifact_per_island"][island_id] = program
            self.flow_state["best_scores_per_test_per_island"][island_id] = scores_per_test

    def get_best_programs(self) -> Dict[str, List[Dict[str, Any]]]:
        """Return the best program of every island, ranked by score.

        :return: ``{"best_island_programs": [...]}``; entries are ordered from worst to best island, and rank 1 is the best
        :rtype: Dict[str, List[Dict[str, Any]]]
        """
        best_scores = self.flow_state["best_score_per_island"]
        # argsort is ascending: the last position holds the best island, which gets rank 1.
        sorted_island_ids = np.argsort(np.array(best_scores))
        num_ranked = len(sorted_island_ids)
        return {
            "best_island_programs": [
                {
                    "rank": num_ranked - position,
                    "score": best_scores[island_id],
                    "program": self.flow_state["best_program_per_island"][island_id],
                    "island_id": int(island_id),
                }
                for position, island_id in enumerate(sorted_island_ids)
            ]
        }

    def run(self, input_message: FlowMessage):
        """Perform the operation requested in the input message and send the reply.

        :param input_message: Message whose data holds ``operation`` and, for "register_program", ``content``
        :type input_message: FlowMessage
        :raises ValueError: If the operation is not supported
        """
        input_data = input_message.data
        operation = input_data["operation"]
        # "content" is only read by "register_program"; other operations may omit it.
        content = input_data.get("content")

        possible_operations = [
            "register_program",
            "get_prompt",
            "get_best_programs_per_island",
        ]
        if operation not in possible_operations:
            raise ValueError(f"operation must be one of the following: {possible_operations}")

        if operation == "get_prompt":
            retrieved = self._get_prompt_response()
        elif operation == "register_program":
            retrieved = self._register_program_from_content(content)
        else:
            retrieved = self.get_best_programs()

        response = {"retrieved": retrieved, "from": "ProgramDBFlow"}
        reply = self.package_output_message(input_message, response)
        self.send_message(reply)

    def _get_prompt_response(self) -> Union[bool, Dict[str, Any]]:
        """Build the "retrieved" payload of a "get_prompt" request (False until a program has been registered)."""
        if not self.flow_state["first_program_registered"]:
            return False
        code, version_generated, island_id = self.get_prompt()
        return {
            "code": code,
            "version_generated": version_generated,
            "island_id": island_id,
            "header": self.evaluate_file_full_content,
        }

    def _register_program_from_content(self, content: Optional[Dict[str, Any]]) -> str:
        """Parse and register the program described by ``content`` and return the status message.

        Only a program that is actually stored marks the database as ready to generate prompts.
        """
        try:
            artifact = text_to_artifact(content["artifact"])
            island_id = content.get("island_id", None)
            scores_per_test = content["scores_per_test"]
            if scores_per_test is None:
                # Nothing to rank the program by, so it is not stored.
                registered = False
            else:
                registered = self.register_program(program=artifact, island_id=island_id, scores_per_test=scores_per_test)
        except Exception:
            log.exception("Program failed to register")
            return "Program failed to register"

        if not registered:
            return "Program failed to register"
        self.flow_state["first_program_registered"] = True
        return "Program registered"