from aiflows.base_flows import CompositeFlow
from aiflows.utils import logging
from aiflows.interfaces import KeyInterface
from aiflows.messages import FlowMessage
from typing import Dict, Any
log = logging.get_logger(f"aiflows.{__name__}")

class FunSearch(CompositeFlow):
    """ This class implements FunSearch (https://www.nature.com/articles/s41586-023-06924-6) and is heavily inspired by the original code (https://github.com/google-deepmind/funsearch). It's a Flow in charge of starting, stopping and managing the FunSearch process by passing messages between the following subflows:
    
    - ProgramDBFlow: which is in charge of storing and retrieving programs.
    - SamplerFlow: which is in charge of sampling programs.
    - EvaluatorFlow: which is in charge of evaluating programs.
    
    *Configuration Parameters*:
    
    - `name` (str): The name of the flow. Default: "FunSearchFlow".
    - `description` (str): The description of the flow. Default: "A flow implementing FunSearch"
    - `subflows_config` (Dict[str,Any]): A dictionary of subflows configurations. Default:
        - `ProgramDBFlow`: By default, it uses the `ProgramDBFlow` class and uses its default parameters.
        - `SamplerFlow`: By default, it uses the `SamplerFlow` class and uses its default parameters.
        - `EvaluatorFlow`: By default, it uses the `EvaluatorFlow` class and uses its default parameters.
        
    **Input Interface**:
    
    - `from` (str): The flow the message comes from. It can be one of the following: "FunSearch", "SamplerFlow", "EvaluatorFlow", "ProgramDBFlow".
    - `operation` (str): The operation to perform. It can be one of the following: "start", "stop", "get_prompt", "get_best_programs_per_island", "register_program".
    - `content` (Dict[str,Any]): The content associated with an operation. Here is the expected content for each operation:
        - "start":
            - `num_samplers` (int): The number of samplers to start up. Note that it's still restricted by the number of workers available. Default: 1.
        - "stop":
            - No content. Pass either an empty dictionary or None. Works also with no content.
        - "get_prompt":
            - No content. Pass either an empty dictionary or None. Works also with no content.
        - "get_best_programs_per_island":
            - No content. Pass either an empty dictionary or None. Works also with no content.
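    
    As a rough sketch of the input interface above, the `data` payloads for a few operations might look like the dictionaries below. The variable names are hypothetical, and how a payload is wrapped into a `FlowMessage` and sent to the flow is deliberately left out:
    
    ```python
    # Hypothetical payloads, shown for illustration only.
    # "from" falls back to "FunSearch" when it is missing (see `register_data_to_state`).
    start_request = {
        "from": "FunSearch",
        "operation": "start",
        "content": {"num_samplers": 2},  # ask for two samplers
    }
    stop_request = {
        "from": "FunSearch",
        "operation": "stop",
        "content": {},  # "stop" takes no content
    }
    best_programs_request = {
        "from": "FunSearch",
        "operation": "get_best_programs_per_island",
        "content": {},  # no content either
    }
    ```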
    
    **Output Interface**:
    
    - `retrieved` (Dict[str,Any]): The retrieved data.
    
    **Citation**:
    
    @Article{FunSearch2023,
        author  = {Romera-Paredes, Bernardino and Barekatain, Mohammadamin and Novikov, Alexander and Balog, Matej and Kumar, M. Pawan and Dupont, Emilien and Ruiz, Francisco J. R. and Ellenberg, Jordan and Wang, Pengming and Fawzi, Omar and Kohli, Pushmeet and Fawzi, Alhussein},
        journal = {Nature},
        title   = {Mathematical discoveries from program search with large language models},
        year    = {2023},
        doi     = {10.1038/s41586-023-06924-6}
    }
    """
    
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        
        #next state per action
        #this is a dictionary that maps the next state of the flow based on the action and the current state
        self.next_state_per_action = {
            "get_prompt": {
                "FunSearch": "ProgramDBFlow",
                "ProgramDBFlow": "SamplerFlow",
            },
            
            "get_best_programs_per_island": {
                "FunSearch": "ProgramDBFlow",
                "ProgramDBFlow": "GenerateReply",
            },
            "register_program": {
                "SamplerFlow": "EvaluatorFlow",
                "EvaluatorFlow": "ProgramDBFlow",
            },
            "start":
                {"FunSearch": "FunSearch"},
            "stop":
                {"FunSearch": "FunSearch"},
        }
        
        #key interface to make a request for a prompt
        self.make_request_for_prompt_data = KeyInterface(
            keys_to_set= {"operation": "get_prompt", "content": {}, "from": "FunSearch"},
            keys_to_select= ["operation", "content", "from"]
        )
    
    def make_request_for_prompt(self):
        """ This method makes a request for a prompt. It sends a message to itself with the operation "get_prompt" which will trigger the flow to call the `ProgramDBFlow` to get a prompt. """
        
        #Prepare data to make request for prompt
        data = self.make_request_for_prompt_data({})
        
        #Package message to make request for prompt
        msg = self.package_input_message(
            data=data,
            dst_flow="FunSearch"
        )
        #Send message to itself to start the process of getting a prompt
        self.send_message(
            msg
        )
        
    def request_samplers(self, input_message: FlowMessage):
        """ This method requests samplers. It reads `num_samplers` from the content of the message (default: 1) and sends that many "get_prompt" messages to itself, each of which triggers a call to the `ProgramDBFlow` to get a prompt.
        
        :param input_message: The input message that triggered the request for samplers.
        :type input_message: FlowMessage
        """
        
        #Get state associated with the message (and remove it from the flow state)
        message_state = self.pop_message_from_state(input_message.input_message_id)
        #Get number of samplers to request
        num_samplers = message_state["content"].get("num_samplers",1)
        for _ in range(num_samplers):
            self.make_request_for_prompt()
    
    def get_next_state(self, input_message: FlowMessage):
        """ This method determines the next state of the flow based on the input message. It will return the next state based on the current state and the message received.
        
        :param input_message: The input message that triggered the request for the next state.
        :type input_message: FlowMessage
        :return: The next state of the flow.
        :rtype: str
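        
        The lookup can be illustrated with a standalone sketch. The table is copied from `__init__`, while `lookup_next_state` is a hypothetical helper used only for this example:
        
        ```python
        # Copy of the routing table built in `__init__` (illustration only).
        next_state_per_action = {
            "get_prompt": {"FunSearch": "ProgramDBFlow", "ProgramDBFlow": "SamplerFlow"},
            "get_best_programs_per_island": {"FunSearch": "ProgramDBFlow", "ProgramDBFlow": "GenerateReply"},
            "register_program": {"SamplerFlow": "EvaluatorFlow", "EvaluatorFlow": "ProgramDBFlow"},
            "start": {"FunSearch": "FunSearch"},
            "stop": {"FunSearch": "FunSearch"},
        }

        def lookup_next_state(operation, message_from):
            # Hypothetical helper: the same two-level lookup as `get_next_state`.
            return next_state_per_action[operation][message_from]

        # A "get_prompt" request first goes to the database ...
        first_hop = lookup_next_state("get_prompt", "FunSearch")
        # ... and the database's reply is then forwarded to the sampler.
        second_hop = lookup_next_state("get_prompt", "ProgramDBFlow")
        ```
        
        An unknown operation or sender raises a `KeyError` in this sketch, just as the plain dictionary lookup does.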
        """
        #Get state associated with the message
        message_state = self.get_message_from_state(input_message.input_message_id)
        message_from = message_state["from"]
        operation = message_state["operation"]
        #Get next state based on the action and the current state
        next_state = self.next_state_per_action[operation][message_from]
        return next_state 
    
    def set_up_flow_state(self):
        """ This method sets up the state of the flow. It's called at the beginning of the flow."""
        super().set_up_flow_state()
        #Dictionary containing the state of each message currently being handled by FunSearch
        #Each message has its own state in the flow state
        #Once a message is done being handled, it's removed from the state
        self.flow_state["msg_requests"] = {}
        #Flag to keep track if the first sample has been saved to the db
        self.flow_state["first_sample_saved_to_db"] = False
        #Flag to keep track if FunSearch is running
        self.flow_state["funsearch_running"] = False
        
    def save_message_to_state(self,msg_id: str, message: FlowMessage):
        """ This method saves a message to the state of the flow. It's used to keep track of state on a per message basis (i.e., state of the flow depending on the message received and id).
        
        :param msg_id: The id of the message to save.
        :type msg_id: str
        :param message: The message to save.
        :type message: FlowMessage
        """
        self.flow_state["msg_requests"][msg_id] = {"og_message":  message}
        
    def rename_key_message_in_state(self, old_key: str, new_key: str):
        """ This method renames a key in the "msg_requests" dictionary of the flow state. It's used to re-key the state of a message under a new message id.
        
        :param old_key: The old key to rename.
        :type old_key: str
        :param new_key: The new key to rename to.
        :type new_key: str
        """
        self.flow_state["msg_requests"][new_key]  = self.flow_state["msg_requests"].pop(old_key) 

    def message_in_state(self,msg_id: str) -> bool:
        """ This method checks if a message is in the state of the flow (in "msg_requests" dictionary). It returns True if the message is in the state, otherwise it returns False.
        
        :param msg_id: The id of the message to check if it's in the state.
        :type msg_id: str
        :return: True if the message is in the state, otherwise False.
        :rtype: bool
        """
        
        return msg_id in self.flow_state["msg_requests"]

    def get_message_from_state(self, msg_id: str) -> Dict[str,Any]:
        """ This method returns the state associated with a message id in the state of the flow (in "msg_requests" dictionary).
        
        :param msg_id: The id of the message to get the state from.
        :type msg_id: str
        :return: The state associated with the message id.
        :rtype: Dict[str,Any]
        """
        return self.flow_state["msg_requests"][msg_id]

    def pop_message_from_state(self, msg_id: str) -> Dict[str,Any]:
        """ This method pops a message from the state of the flow (in "msg_requests" dictionary). It returns the state associated with a message and removes it from the flow state.
        
        :param msg_id: The id of the message to pop from the state.
        :type msg_id: str
        :return: The state associated with the message id.
        :rtype: Dict[str,Any]
        """
        return self.flow_state["msg_requests"].pop(msg_id)
    
    def merge_message_request_state(self,id: str, new_states: Dict[str,Any]):
        """ This method merges new states into the state of a message (in "msg_requests" dictionary). Keys in `new_states` overwrite existing keys with the same name.
        
        :param id: The id of the message to merge new states to.
        :type id: str
        :param new_states: The new states to merge to the message.
        :type new_states: Dict[str,Any]
        """
        self.flow_state["msg_requests"][id] = {**self.flow_state["msg_requests"][id], **new_states}
    
    def register_data_to_state(self, input_message: FlowMessage):
        """This method registers the input message data to the flow state. It's called every time a new input message is received.
        
        :param input_message: The input message
        :type input_message: FlowMessage
        """

        #Determine who the message is from (should be either FunSearch, SamplerFlow, EvaluatorFlow, or ProgramDBFlow)
        msg_from = input_message.data.get("from", "FunSearch")
        #Check if this is a first request or part of a message that is already being handled (it is being handled if the message is in the state)
        msg_id = input_message.input_message_id   
        msg_in_state = self.message_in_state(msg_id)
        
        #If message is not in state, save it to state
        if not msg_in_state:
            self.save_message_to_state(msg_id, input_message)
        
        #Get the state associated to the message
        message_state = self.get_message_from_state(msg_id)
        
        #Determine what to do based on who the message is from
        
        if msg_from == "FunSearch":
            #Calls From FunSearch expect operation and content
            operation = input_message.data["operation"]
            content = input_message.data.get("content",{})
            to_add_to_state = {
                "content": content,
                "operation": operation
            }
            #save operation and content to state   
            self.merge_message_request_state(msg_id, to_add_to_state)
            
        elif msg_from == "SamplerFlow":
            #Calls From SamplerFlow expect api_output, merge it to state
            to_add_to_state = {
                "content": {
                    **message_state.get("content",{}),
                    **{"artifact": input_message.data["api_output"]}
                },
                "operation": "register_program"
            }
            self.merge_message_request_state(msg_id, to_add_to_state)
                        
        elif msg_from == "EvaluatorFlow":
            #Calls From EvaluatorFlow expect scores_per_test, merge it to state
            message_state = self.get_message_from_state(msg_id)
            to_add_to_state = {
                "content": {
                    **message_state.get("content",{}),
                    **{"scores_per_test": input_message.data["scores_per_test"]}
                }
            }
            self.merge_message_request_state(msg_id, to_add_to_state)
                        
        elif msg_from == "ProgramDBFlow":
            #Calls From ProgramDBFlow expect retrieved, merge it to state
            to_add_to_state = {
                "retrieved": input_message.data["retrieved"],
            }
            
            #if the message from ProgramDBFlow is associated with a "get_prompt" operation,
            # save island_id to state
            if message_state["operation"] == "get_prompt":
                island_id = input_message.data["retrieved"]["island_id"]
                to_add_to_state["content"] = {
                    **message_state.get("content",{}),
                    **{"island_id": island_id}
                }
                
            self.merge_message_request_state(msg_id, to_add_to_state) 
        
        #save from to state
        self.merge_message_request_state(msg_id, {"from": msg_from})

    def call_program_db(self, input_message):
        """ This method calls the ProgramDBFlow. It sends a message to the ProgramDBFlow with the data of the input message.
        
        :param input_message: The input message to send to the ProgramDBFlow.
        :type input_message: FlowMessage
        """
        
        #Fetch state associated with the message
        msg_id = input_message.input_message_id
        message_state = self.get_message_from_state(input_message.input_message_id)

        #Get operation and content from state to send to ProgramDBFlow
        operation = message_state["operation"]
        content = message_state.get("content", {})
        
        data = {
            "operation": operation,
            "content": content
        }
        #package message to send to ProgramDBFlow
        msg = self.package_input_message(
            data = data,
            dst_flow = "ProgramDBFlow"
        )      
        
        #If operation is "register_program",
        # pop message from state (because the initial message has been fully handled) and set first_sample_saved_to_db to True
        #Send a message to register program without expecting a reply (no need to wait for a reply, just save to db and move on)
        if data["operation"] == "register_program":
            self.pop_message_from_state(msg_id)
            
            self.flow_state["first_sample_saved_to_db"] = True
            
            self.subflows["ProgramDBFlow"].send_message(
                msg
            )
        
        # If operation is "get_prompt" or "get_best_programs_per_island"
        # rename key in state to new message id (in order to keep track of the message in state when the reply arrives)
        elif data["operation"] in ["get_prompt","get_best_programs_per_island"]:
            self.rename_key_message_in_state(msg_id, msg.message_id)  
            #if no sample has been saved to db, Send input message back to itself (to try again, hopefully this time a sample will be saved to db)
            if not self.flow_state["first_sample_saved_to_db"]:
                #send back to itself message (to try again)
                self.send_message(
                    input_message
                )
            #If a sample has been saved to db, send message to ProgramDBFlow to fetch prompt or best programs per island
            else:
                self.subflows["ProgramDBFlow"].get_reply(
                    msg
                )
        #If operation is not "register_program", "get_prompt" or "get_best_programs_per_island"
        else:
            log.error("No operation found, input_message received: \n" + str(input_message))
        
    def call_evaluator(self, input_message):
        """ This method calls the EvaluatorFlow. It sends a message to the EvaluatorFlow with the data of the input message.
        
        :param input_message: The input message to send to the EvaluatorFlow.
        :type input_message: FlowMessage
        """
    
        #Fetch state associated with the message
        msg_id = input_message.input_message_id 
        message_state = self.get_message_from_state(msg_id)
        
        #Get data to send to EvaluatorFlow (artifact generated by Sampler to be evaluated)
        data = {
            "artifact": message_state["content"]["artifact"]
        }
            
        msg = self.package_input_message(
            data = data,
            dst_flow = "EvaluatorFlow"
        )
        # rename key in state to new message id (in order to keep track of the message in state when the reply arrives)
        self.rename_key_message_in_state(msg_id, msg.message_id)  
        #Send message to EvaluatorFlow and expect a reply to be sent back to FunSearch's input message queue    
        self.subflows["EvaluatorFlow"].get_reply(
            msg
        )
        
    def call_sampler(self, input_message):
        """ This method calls the SamplerFlow. It sends a message to the SamplerFlow with the data of the input message.
        
        :param input_message: The input message to send to the SamplerFlow.
        :type input_message: FlowMessage
        """
        
        #Fetch state associated with the message
        msg_id = input_message.input_message_id
        message_state = self.get_message_from_state(msg_id)
        
        #Get data to send to SamplerFlow (prompt to generate a program)
        data = {
            **message_state["retrieved"],
        }
        msg = self.package_input_message(
            data = data,
            dst_flow = "SamplerFlow"
        )
        # rename key in state to new message id (in order to keep track of the message in state when the reply arrives)
        self.rename_key_message_in_state(msg_id, msg.message_id)        
        
        #send message to SamplerFlow and expect a reply to be sent back to FunSearch's input message queue
        self.subflows["SamplerFlow"].get_reply(
            msg
        )
        #If FunSearch is running, make a new request for a prompt (to keep the process going)
        if self.flow_state["funsearch_running"]:
            self.make_request_for_prompt()
        
        
    def generate_reply(self, input_message: FlowMessage):
        """ This method generates the reply to send back to the user. It packages the output message and sends it.
        
        :param input_message: The input message to generate a reply to.
        :type input_message: FlowMessage
        """
        
        #Fetch state associated with the message
        msg_id = input_message.input_message_id
        message_state = self.pop_message_from_state(msg_id)
        #Prepare response to send to user (due to a call to get_best_programs_per_island)
        response = {
            "retrieved": message_state["retrieved"]
        }
        reply = self.package_output_message(
            message_state["og_message"],
            response
        )
        
        self.send_message(
            reply
        )
        
    def run(self,input_message: FlowMessage):
        """ This method runs the flow. It's the main method of the flow. It's called when the flow is executed.
        
        :param input_message: The input message of the flow
        :type input_message: FlowMessage
        """
        self.register_data_to_state(input_message)
        
        next_state = self.get_next_state(input_message)
        
        if next_state == "ProgramDBFlow":
            self.call_program_db(input_message)
            
        elif next_state == "EvaluatorFlow":
            self.call_evaluator(input_message)
            
        elif next_state == "SamplerFlow":
            self.call_sampler(input_message)
        
        elif next_state == "GenerateReply":
            self.generate_reply(input_message)
            
        elif next_state == "FunSearch":
            #If operation is "start", set funsearch_running to True and make a request for a prompt
            if input_message.data["operation"] == "start":
                self.flow_state["funsearch_running"] = True
                self.request_samplers(input_message)
            #If operation is "stop", set funsearch_running to False (will stop the process of generating new samples)
            elif input_message.data["operation"] == "stop":
                self.flow_state["funsearch_running"] = False
        else:
            log.error("No next state found, input_message received: \n" + str(input_message))