|
from typing import Dict, Any

from aiflows.base_flows import CompositeFlow
from aiflows.interfaces import KeyInterface
from aiflows.messages import FlowMessage
from aiflows.utils import logging

# Module-level logger, namespaced under "aiflows." followed by this module's name.
log = logging.get_logger(f"aiflows.{__name__}")
|
|
|
class FunSearch(CompositeFlow):
    """ This class implements FunSearch. This code is an implementation of FunSearch (https://www.nature.com/articles/s41586-023-06924-6) and is heavily inspired by the original code (https://github.com/google-deepmind/funsearch). It's a Flow in charge of starting, stopping and managing (passing around messages) the FunSearch process. It passes messages around to the following subflows:

    - ProgramDBFlow: which is in charge of storing and retrieving programs.
    - SamplerFlow: which is in charge of sampling programs.
    - EvaluatorFlow: which is in charge of evaluating programs.

    *Configuration Parameters*:

    - `name` (str): The name of the flow. Default: "FunSearchFlow".
    - `description` (str): The description of the flow. Default: "A flow implementing FunSearch"
    - `subflows_config` (Dict[str,Any]): A dictionary of subflows configurations. Default:
        - `ProgramDBFlow`: By default, it uses the `ProgramDBFlow` class and uses its default parameters.
        - `SamplerFlow`: By default, it uses the `SamplerFlow` class and uses its default parameters.
        - `EvaluatorFlow`: By default, it uses the `EvaluatorFlow` class and uses its default parameters.

    **Input Interface**:

    - `from` (str): The flow from which the message is coming from. It can be one of the following: "FunSearch", "SamplerFlow", "EvaluatorFlow", "ProgramDBFlow". Defaults to "FunSearch" when absent.
    - `operation` (str): The operation to perform. It can be one of the following: "start", "stop", "get_prompt", "get_best_programs_per_island", "register_program".
    - `content` (Dict[str,Any]): The content associated to an operation. Here is the expected content for each operation:
        - "start":
            - `num_samplers` (int): The number of samplers to start up. Note that it's still restricted by the number of workers available. Default: 1.
        - "stop":
            - No content. Pass either an empty dictionary or None. Works also with no content.
        - "get_prompt":
            - No content. Pass either an empty dictionary or None. Works also with no content.
        - "get_best_programs_per_island":
            - No content. Pass either an empty dictionary or None. Works also with no content.
        - "register_program":
            - Not sent with a "FunSearch" origin: the operation is set internally when a `SamplerFlow` reply is received.

    **Output Interface**:

    - `retrieved` (Dict[str,Any]): The retrieved data.

    **Citation**:

    @Article{FunSearch2023,
        author  = {Romera-Paredes, Bernardino and Barekatain, Mohammadamin and Novikov, Alexander and Balog, Matej and Kumar, M. Pawan and Dupont, Emilien and Ruiz, Francisco J. R. and Ellenberg, Jordan and Wang, Pengming and Fawzi, Omar and Kohli, Pushmeet and Fawzi, Alhussein},
        journal = {Nature},
        title   = {Mathematical discoveries from program search with large language models},
        year    = {2023},
        doi     = {10.1038/s41586-023-06924-6}
    }
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Routing table: operation -> sender of the message just received -> next step.
        self.next_state_per_action = {
            "get_prompt": {
                "FunSearch": "ProgramDBFlow",
                "ProgramDBFlow": "SamplerFlow",
            },
            "get_best_programs_per_island": {
                "FunSearch": "ProgramDBFlow",
                "ProgramDBFlow": "GenerateReply",
            },
            "register_program": {
                "SamplerFlow": "EvaluatorFlow",
                "EvaluatorFlow": "ProgramDBFlow",
            },
            "start": {"FunSearch": "FunSearch"},
            "stop": {"FunSearch": "FunSearch"},
        }

        # Builds the payload of a self-addressed "get_prompt" request.
        self.make_request_for_prompt_data = KeyInterface(
            keys_to_set={"operation": "get_prompt", "content": {}, "from": "FunSearch"},
            keys_to_select=["operation", "content", "from"],
        )

    def make_request_for_prompt(self):
        """ This method makes a request for a prompt. It sends a message to itself with the operation "get_prompt" which will trigger the flow to call the `ProgramDBFlow` to get a prompt. """
        data = self.make_request_for_prompt_data({})
        msg = self.package_input_message(data=data, dst_flow="FunSearch")
        self.send_message(msg)

    def request_samplers(self, input_message: FlowMessage):
        """ This method requests samplers. It removes the state of the triggering message and sends `num_samplers` "get_prompt" messages to itself (one per sampler).

        :param input_message: The input message that triggered the request for samplers.
        :type input_message: FlowMessage
        """
        message_state = self.pop_message_from_state(input_message.input_message_id)

        # `content` may legitimately be None or missing; fall back to the default of 1 sampler.
        content = message_state.get("content") or {}
        num_samplers = content.get("num_samplers", 1)
        for _ in range(num_samplers):
            self.make_request_for_prompt()

    def get_next_state(self, input_message: FlowMessage):
        """ This method determines the next state of the flow based on the input message. It looks up the operation and the sender stored in the state of the message in the routing table.

        :param input_message: The input message that triggered the request for the next state.
        :type input_message: FlowMessage
        :return: The next state of the flow, or None if the routing table has no entry for the (operation, sender) pair.
        :rtype: str
        """
        message_state = self.get_message_from_state(input_message.input_message_id)

        # Unknown operations/senders yield None so that `run` can report them
        # instead of failing with a KeyError here.
        transitions = self.next_state_per_action.get(message_state.get("operation"), {})
        return transitions.get(message_state.get("from"))

    def set_up_flow_state(self):
        """ This method sets up the state of the flow. It's called at the beginning of the flow."""
        super().set_up_flow_state()

        # Per-message state, keyed by message id.
        self.flow_state["msg_requests"] = {}
        # Set to True once a "register_program" request has been forwarded to the ProgramDBFlow.
        self.flow_state["first_sample_saved_to_db"] = False
        # True between a "start" and a "stop" operation.
        self.flow_state["funsearch_running"] = False

    def save_message_to_state(self, msg_id: str, message: FlowMessage):
        """ This method saves a message to the state of the flow. It's used to keep track of state on a per message basis (i.e., state of the flow depending on the message received and id).

        :param msg_id: The id of the message to save.
        :type msg_id: str
        :param message: The message to save.
        :type message: FlowMessage
        """
        self.flow_state["msg_requests"][msg_id] = {"og_message": message}

    def rename_key_message_in_state(self, old_key: str, new_key: str):
        """ This method renames a key in the "msg_requests" dictionary of the flow state (i.e., it re-keys the state of a message under a new message id).

        :param old_key: The old key to rename.
        :type old_key: str
        :param new_key: The new key to rename to.
        :type new_key: str
        """
        requests = self.flow_state["msg_requests"]
        requests[new_key] = requests.pop(old_key)

    def message_in_state(self, msg_id: str) -> bool:
        """ This method checks if a message is in the state of the flow (in "msg_requests" dictionary).

        :param msg_id: The id of the message to check if it's in the state.
        :type msg_id: str
        :return: True if the message is in the state, otherwise False.
        :rtype: bool
        """
        return msg_id in self.flow_state["msg_requests"]

    def get_message_from_state(self, msg_id: str) -> Dict[str, Any]:
        """ This method returns the state associated with a message id in the state of the flow (in "msg_requests" dictionary).

        :param msg_id: The id of the message to get the state from.
        :type msg_id: str
        :return: The state associated with the message id.
        :rtype: Dict[str,Any]
        """
        return self.flow_state["msg_requests"][msg_id]

    def pop_message_from_state(self, msg_id: str) -> Dict[str, Any]:
        """ This method pops a message from the state of the flow (in "msg_requests" dictionary). It returns the state associated to a message and removes it from the state.

        :param msg_id: The id of the message to pop from the state.
        :type msg_id: str
        :return: The state associated with the message id.
        :rtype: Dict[str,Any]
        """
        return self.flow_state["msg_requests"].pop(msg_id)

    def merge_message_request_state(self, id: str, new_states: Dict[str, Any]):
        """ This method merges new states into the state of a message (in "msg_requests" dictionary). Keys of `new_states` override existing keys.

        :param id: The id of the message to merge new states to.
        :type id: str
        :param new_states: The new states to merge to the message.
        :type new_states: Dict[str,Any]
        """
        requests = self.flow_state["msg_requests"]
        requests[id] = {**requests[id], **new_states}

    def register_data_to_state(self, input_message: FlowMessage):
        """This method registers the input message data to the flow state. It's called everytime a new input message is received.

        :param input_message: The input message
        :type input_message: FlowMessage
        """
        msg_from = input_message.data.get("from", "FunSearch")
        msg_id = input_message.input_message_id

        # First time this id is seen: remember the message itself.
        if not self.message_in_state(msg_id):
            self.save_message_to_state(msg_id, input_message)

        message_state = self.get_message_from_state(msg_id)
        # The interface allows `content` to be None; treat it as empty so it can be extended below.
        previous_content = message_state.get("content") or {}

        if msg_from == "FunSearch":
            operation = input_message.data["operation"]
            to_add_to_state = {
                "content": input_message.data.get("content") or {},
                "operation": operation,
            }

        elif msg_from == "SamplerFlow":
            # A sampled program came back: from now on this request is a program registration.
            to_add_to_state = {
                "content": {**previous_content, "artifact": input_message.data["api_output"]},
                "operation": "register_program",
            }

        elif msg_from == "EvaluatorFlow":
            to_add_to_state = {
                "content": {
                    **previous_content,
                    "scores_per_test": input_message.data["scores_per_test"],
                },
            }

        elif msg_from == "ProgramDBFlow":
            retrieved = input_message.data["retrieved"]
            to_add_to_state = {"retrieved": retrieved}

            # Keep the island id the prompt was drawn from in the content of the request.
            if message_state.get("operation") == "get_prompt":
                to_add_to_state["content"] = {
                    **previous_content,
                    "island_id": retrieved["island_id"],
                }

        else:
            to_add_to_state = {}

        # Always record who sent the latest message; `get_next_state` routes on it.
        to_add_to_state["from"] = msg_from
        self.merge_message_request_state(msg_id, to_add_to_state)

    def call_program_db(self, input_message):
        """ This method calls the ProgramDBFlow. It sends a message to the ProgramDBFlow with the operation and content stored in the state of the input message.

        - "register_program": the state of the message is dropped and the request is sent without waiting for a reply.
        - "get_prompt" / "get_best_programs_per_island": if no "register_program" request has been forwarded yet, the input message is sent back to this flow (to be retried later); otherwise the state is re-keyed under the id of the outgoing message and a reply is requested.

        :param input_message: The input message to send to the ProgramDBFlow.
        :type input_message: FlowMessage
        """
        msg_id = input_message.input_message_id
        message_state = self.get_message_from_state(msg_id)

        operation = message_state["operation"]
        data = {
            "operation": operation,
            "content": message_state.get("content", {}),
        }

        if operation == "register_program":
            msg = self.package_input_message(data=data, dst_flow="ProgramDBFlow")
            self.pop_message_from_state(msg_id)
            self.flow_state["first_sample_saved_to_db"] = True
            self.subflows["ProgramDBFlow"].send_message(msg)

        elif operation in ["get_prompt", "get_best_programs_per_island"]:
            if not self.flow_state["first_sample_saved_to_db"]:
                # Retry later. The state must stay under `msg_id`: re-keying it to an
                # outgoing message that is never sent would orphan one entry per retry.
                self.send_message(input_message)
                return

            msg = self.package_input_message(data=data, dst_flow="ProgramDBFlow")
            # The state is re-keyed under the outgoing message id so it can be
            # looked up when the reply is registered.
            self.rename_key_message_in_state(msg_id, msg.message_id)
            self.subflows["ProgramDBFlow"].get_reply(msg)

        else:
            log.error("No operation found, input_message received: \n" + str(input_message))

    def call_evaluator(self, input_message):
        """ This method calls the EvaluatorFlow. It sends the artifact stored in the state of the input message to the EvaluatorFlow and requests a reply.

        :param input_message: The input message to send to the EvaluatorFlow.
        :type input_message: FlowMessage
        """
        msg_id = input_message.input_message_id
        message_state = self.get_message_from_state(msg_id)

        msg = self.package_input_message(
            data={"artifact": message_state["content"]["artifact"]},
            dst_flow="EvaluatorFlow",
        )

        # Re-key the state under the outgoing message id before requesting the reply.
        self.rename_key_message_in_state(msg_id, msg.message_id)
        self.subflows["EvaluatorFlow"].get_reply(msg)

    def call_sampler(self, input_message):
        """ This method calls the SamplerFlow. It sends the data retrieved from the ProgramDBFlow (stored in the state of the input message) to the SamplerFlow and requests a reply. If FunSearch is still running, it also issues a new request for a prompt.

        :param input_message: The input message to send to the SamplerFlow.
        :type input_message: FlowMessage
        """
        msg_id = input_message.input_message_id
        message_state = self.get_message_from_state(msg_id)

        msg = self.package_input_message(
            data={**message_state["retrieved"]},
            dst_flow="SamplerFlow",
        )

        # Re-key the state under the outgoing message id before requesting the reply.
        self.rename_key_message_in_state(msg_id, msg.message_id)
        self.subflows["SamplerFlow"].get_reply(msg)

        # Keep the sampling loop alive for as long as FunSearch has not been stopped.
        if self.flow_state["funsearch_running"]:
            self.make_request_for_prompt()

    def generate_reply(self, input_message: FlowMessage):
        """ This method generates a reply to the message originally sent by the user. It removes the state of the message, packages the retrieved data as the output message and sends it.

        :param input_message: The input message to generate a reply to.
        :type input_message: FlowMessage
        """
        message_state = self.pop_message_from_state(input_message.input_message_id)

        reply = self.package_output_message(
            message_state["og_message"],
            {"retrieved": message_state["retrieved"]},
        )
        self.send_message(reply)

    def run(self, input_message: FlowMessage):
        """ This method runs the flow. It's the main method of the flow. It's called when the flow is executed. It registers the message in the state, determines the next state and dispatches accordingly.

        :param input_message: The input message of the flow
        :type input_message: FlowMessage
        """
        self.register_data_to_state(input_message)
        next_state = self.get_next_state(input_message)

        if next_state == "ProgramDBFlow":
            self.call_program_db(input_message)

        elif next_state == "EvaluatorFlow":
            self.call_evaluator(input_message)

        elif next_state == "SamplerFlow":
            self.call_sampler(input_message)

        elif next_state == "GenerateReply":
            self.generate_reply(input_message)

        elif next_state == "FunSearch":
            operation = input_message.data["operation"]

            if operation == "start":
                self.flow_state["funsearch_running"] = True
                self.request_samplers(input_message)

            elif operation == "stop":
                self.flow_state["funsearch_running"] = False
                # Nothing else will ever consume the state of a "stop" message.
                self.pop_message_from_state(input_message.input_message_id)

        else:
            log.error("No next state found, input_message received: \n" + str(input_message))
            # The message cannot be routed any further; don't keep its state around.
            self.flow_state["msg_requests"].pop(input_message.input_message_id, None)