""" NOTE: THIS IS A BETA VERSION OF FUNSEARCH. NEW VERSION DOCUMENTATION WILL BE RELEASED SOON."""


from aiflows.base_flows import AtomicFlow
from .Island import Island,ScoresPerTest
import numpy as np
from typing import Callable,Dict,Union, Any,Optional, List
import time
from aiflows.utils import logging
from .artifacts import AbstractArtifact
from .Program import Program,text_to_artifact
import ast
import os
from aiflows.messages import FlowMessage
log = logging.get_logger(f"aiflows.{__name__}")

class ProgramDBFlow(AtomicFlow):
    """ This class implements a ProgramDBFlow. It's a flow that stores programs and their scores in a database of islands. It can also query the database for the best program of each island, or generate a prompt containing stored programs so that a SamplerFlow can evolve them. This code is an implementation of FunSearch (https://www.nature.com/articles/s41586-023-06924-6) and is heavily inspired by the original code (https://github.com/google-deepmind/funsearch)
    
    **Configuration Parameters**:
    
    - `name` (str): The name of the flow. Default: "ProgramDBFlow"
    - `description` (str): A description of the flow. This description is used to generate the help message of the flow. Default: " A flow that saves programs in a database of islands"
    - `artifact_to_evolve_name` (str): The name of the artifact/program to evolve. Default: "solve_function"
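    - `evaluate_function` (str): The source code of the function used to evaluate the program. It is appended to the preface of the program template. No default value: this MUST be passed as a parameter.
    - `evaluate_file_full_content` (str): The full content of the file containing the evaluation function. It is returned as the `header` of every "get_prompt" response. No default value: this MUST be passed as a parameter.
    - `num_islands` (int): The number of islands to use. Default: 3
    - `reset_period` (float): The minimum time in seconds between two island resets. The check runs whenever a program is registered. Default: 3600
    - `artifacts_per_prompt` (int): The number of previous artifacts/programs to include in a prompt. Default: 2
    - `temperature` (float): The sampling temperature passed to each island. Default: 0.1
    - `temperature_period` (int): The period in seconds to change the temperature. Default: 30000
    - `sample_with_replacement` (bool): Whether to sample with replacement. Default: False
    - `portion_of_islands_to_reset` (float): The portion of islands to reset, between 0 and 1. The number of reset islands is `round(portion_of_islands_to_reset * num_islands)`. Default: 0.5
    - `template` (dict): The template to use for a program. Its `preface` is combined with the evaluation function to form the preface of the program template passed to every island. Default: {"preface": ""}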

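    The two reset parameters interact as in the standalone sketch below, which mirrors `reset_islands` on a plain list of best scores instead of `Island` objects (`split_islands` is a hypothetical helper, not part of this module). Islands are sorted by best score, with tiny Gaussian noise to break ties, and the lowest `round(portion_of_islands_to_reset * num_islands)` are reset. With the defaults (3 islands, 0.5) that is `round(1.5) == 2` islands, since Python rounds halves to even.

    ```python
    import numpy as np
    from typing import List, Sequence, Tuple


    def split_islands(best_scores: Sequence[float], portion_to_reset: float,
                      rng: np.random.Generator) -> Tuple[List[int], List[int]]:
        # Hypothetical helper mirroring ProgramDBFlow.reset_islands.
        assert 0.0 <= portion_to_reset <= 1.0, "portion_to_reset must be between 0 and 1"
        # Number of islands to reset; round() sends halves to the even integer.
        n_reset = int(round(portion_to_reset * len(best_scores)))
        # Tiny Gaussian noise breaks ties between islands with equal scores.
        noise = rng.standard_normal(len(best_scores)) * 1e-6
        order = np.argsort(np.asarray(best_scores, dtype=float) + noise)  # worst first
        return order[:n_reset].tolist(), order[n_reset:].tolist()


    rng = np.random.default_rng(0)
    reset_ids, keep_ids = split_islands([0.2, 0.9, 0.5, 0.1], 0.5, rng)
    # reset_ids == [3, 0] (the two worst islands), keep_ids == [2, 1]
    # Each reset island is then seeded by a randomly chosen surviving island.
    founder = int(rng.choice(keep_ids))
    ```
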
    **Input Interface**:
    
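    - `operation` (str): The operation to perform. It can be one of the following: ["register_program","get_prompt","get_best_programs_per_island"]
    - `content` (dict): The arguments of the operation. For "register_program" it must contain `artifact` (str, the program text), `scores_per_test` (the scores per test of the program, each value holding a `score` entry; if None, the program is not registered) and optionally `island_id` (int, or None to register the program in every island). The other operations do not use its value.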
    
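    A minimal sketch of the `data` dictionaries an input message could carry, based on the keys `run` reads. The program text, test names and scores are hypothetical placeholders, and `check_operation` is a hypothetical helper; the exact text format accepted by `text_to_artifact` is defined in the `Program` module and is not shown here.

    ```python
    from typing import Any, Dict

    POSSIBLE_OPERATIONS = ["register_program", "get_prompt", "get_best_programs_per_island"]

    # Hypothetical payloads; `content` only matters for "register_program".
    register_data: Dict[str, Any] = {
        "operation": "register_program",
        "content": {
            "artifact": "def solve_function(x): return x + 1",  # placeholder program text
            "island_id": None,  # None -> register in every island
            "scores_per_test": {"test_small": {"score": 0.5}, "test_large": {"score": 1.5}},
        },
    }
    get_prompt_data: Dict[str, Any] = {"operation": "get_prompt", "content": {}}
    best_programs_data: Dict[str, Any] = {"operation": "get_best_programs_per_island", "content": {}}


    def check_operation(data: Dict[str, Any]) -> str:
        # Same validation as ProgramDBFlow.run: unknown operations are rejected.
        operation = data["operation"]
        if operation not in POSSIBLE_OPERATIONS:
            raise ValueError(f"operation must be one of the following: {POSSIBLE_OPERATIONS}")
        return operation
    ```
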
    **Output Interface**:
    
    - `retrieved` (Any): The retrieved data. It can be one of the following:
        - If the operation is "get_prompt", it is `False` until a first program has been registered; afterwards it is a dictionary with the following keys:
            - `code` (str): The code of the prompt
            - `version_generated` (int): The version of the prompt generated
            - `island_id` (int): The id of the island that generated the prompt
            - `header` (str): The full content of the evaluation file (`evaluate_file_full_content`)
        - If the operation is "register_program", it is a string: "Program registered" or "Program failed to register"
        - If the operation is "get_best_programs_per_island", it is a dictionary with the following key:
            - `best_island_programs` (List[Dict[str,Any]]): A list with one dictionary per island, ordered from the worst island to the best one, with the following keys:
                - `rank` (int): The rank of the program (1 is the best)
                - `score` (float): The score of the program
                - `program` (str): The program, or None if the island has no program yet
                - `island_id` (int): The id of the island holding the program
    - `from` (str): The sender of the reply, always "ProgramDBFlow"
                
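    The ranking returned by "get_best_programs_per_island" can be reproduced with the standalone sketch below (`rank_islands` is a hypothetical helper working on a plain list of best scores). Because `np.argsort` sorts in ascending order, the list starts with the worst island, whose rank equals the number of islands, and ends with the best island (rank 1). An island with no program yet keeps a score of `-inf` and therefore receives the largest rank.

    ```python
    from typing import Any, Dict, List, Sequence

    import numpy as np


    def rank_islands(best_score_per_island: Sequence[float]) -> List[Dict[str, Any]]:
        # Hypothetical helper mirroring ProgramDBFlow.get_best_programs.
        order = np.argsort(np.array(best_score_per_island))  # ascending: worst first
        n = len(best_score_per_island)
        return [
            {"rank": n - pos, "score": best_score_per_island[i], "island_id": int(i)}
            for pos, i in enumerate(order)
        ]


    ranking = rank_islands([0.3, 0.9, float("-inf")])
    # ranking[0]  -> island 2, rank 3 (no program yet)
    # ranking[-1] -> island 1, rank 1 (best score 0.9)
    ```
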
    **Citation**:
    
    @Article{FunSearch2023,
        author  = {Romera-Paredes, Bernardino and Barekatain, Mohammadamin and Novikov, Alexander and Balog, Matej and Kumar, M. Pawan and Dupont, Emilien and Ruiz, Francisco J. R. and Ellenberg, Jordan and Wang, Pengming and Fawzi, Omar and Kohli, Pushmeet and Fawzi, Alhussein},
        journal = {Nature},
        title   = {Mathematical discoveries from program search with large language models},
        year    = {2023},
        doi     = {10.1038/s41586-023-06924-6}
    }
    """
    def __init__(self,
                 **kwargs
                ):
        super().__init__(**kwargs)
        
        #Unpack config (for clarity)
        self.artifact_to_evolve_name: str = self.flow_config["artifact_to_evolve_name"]
        
        self.artifacts_per_prompt: int = self.flow_config["artifacts_per_prompt"]
        self.temperature: float = self.flow_config["temperature"]
        self.temperature_period: int = self.flow_config["temperature_period"]
        self.reduce_score_method: Callable = np.mean 
        self.sample_with_replacement: bool = self.flow_config["sample_with_replacement"]
        self.num_islands: int = self.flow_config["num_islands"]
        self.portion_of_islands_to_reset: float = self.flow_config["portion_of_islands_to_reset"]
        self.reset_period: float = self.flow_config["reset_period"]
        
        self.evaluate_function = self.flow_config["evaluate_function"]
        self.evaluate_file_full_content = self.flow_config["evaluate_file_full_content"] 
        
        
        
        assert self.portion_of_islands_to_reset <= 1.0 and self.portion_of_islands_to_reset >= 0.0, "portion_of_islands_to_reset must be between 0 and 1"
        
        self.islands_to_reset = int(round(self.portion_of_islands_to_reset * self.num_islands)) #round to nearest integer
        
          
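    def set_up_flow_state(self):
        """ This method sets up the state of the flow: it builds the program template, instantiates the islands and initializes the per-island best scores and programs. It also clears the previous messages."""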
        super().set_up_flow_state()
        
        preface = (
            self.flow_config["template"]["preface"]
            + "\n\n#function used to evaluate the program:\n"
            + self.flow_config["evaluate_function"]
            + "\n\n"
        )
            
        self.template: Program = Program(preface =preface,artifacts = [])
        
        # ~~~instantiate islands~~~
        self.flow_state["islands"] = [
            Island(
                artifact_to_evolve_name =self.flow_config["artifact_to_evolve_name"],
                artifacts_per_prompt = self.flow_config["artifacts_per_prompt"],
                temperature = self.flow_config["temperature"],
                temperature_period = self.flow_config["temperature_period"],
                reduce_score_method = np.mean,
                sample_with_replacement = self.flow_config["sample_with_replacement"],
                template=self.template
                )
        for _ in range(self.flow_config["num_islands"])
        ]
        
        self.flow_state["last_reset_time"] = time.time()
        self.flow_state["best_score_per_island"] = [float("-inf") for _ in range(self.flow_config["num_islands"])]
        self.flow_state["best_program_per_island"] = [None for _ in range(self.flow_config["num_islands"])]
        self.flow_state["best_scores_per_test_per_island"] = [None for _ in range(self.flow_config["num_islands"])]
        self.flow_state["first_program_registered"] = False
          
    def get_prompt(self):
        """ This method gets a prompt from an island. It returns the code, the version generated and the island id."""
        # pick an island uniformly at random (cast to int so the id can be serialized in the reply)
        island_id = int(np.random.choice(len(self.flow_state["islands"])))
        code, version_generated = self.flow_state["islands"][island_id].get_prompt()
        return code,version_generated,island_id
    
    def reset_islands(self):
        """ This method resets the worst islands. Each reset island is replaced by an empty island, which is then seeded with the best program of a randomly chosen surviving island."""
        # gaussian noise to break ties
        sorted_island_ids = np.argsort(
            np.array(self.flow_state["best_score_per_island"]) +
            (np.random.randn(len(self.flow_state["best_score_per_island"])) * 1e-6)
        )
        
        reset_island_ids = sorted_island_ids[:self.islands_to_reset]
        keep_island_ids = sorted_island_ids[self.islands_to_reset:]
       
        for island_id in reset_island_ids:
            self.flow_state["islands"][island_id] = Island(
                               artifact_to_evolve_name =self.artifact_to_evolve_name,
                               artifacts_per_prompt = self.artifacts_per_prompt,
                               temperature = self.temperature,
                               temperature_period = self.temperature_period,
                               reduce_score_method = np.mean,
                               sample_with_replacement = self.sample_with_replacement,
                               template=self.template
                               )
            
            self.flow_state["best_score_per_island"][island_id] = float("-inf")
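            founder_island_id = np.random.choice(keep_island_ids)
            # seed the new island with the best program of a surviving island
            # (best programs are stored as text, so convert them back to an artifact)
            founder = text_to_artifact(self.flow_state["best_program_per_island"][founder_island_id])
            founder_scores = self.flow_state["best_scores_per_test_per_island"][founder_island_id]
            self._register_program_in_island(program=founder,island_id=island_id,scores_per_test=founder_scores)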
        
    
    
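    def register_program(self, program: AbstractArtifact, island_id: Optional[int], scores_per_test: ScoresPerTest):
        """ This method registers a program in one island, or in every island if `island_id` is None. Programs that call an ancestor of the artifact to evolve are discarded. Each call also resets the worst islands if more than `reset_period` seconds have elapsed since the last reset.
        
        :param program: The program to register
        :type program: AbstractArtifact
        :param island_id: The id of the island in which to register the program, or None to register it in every island
        :type island_id: Optional[int]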
        :param scores_per_test: The scores per test of the program
        :type scores_per_test: ScoresPerTest
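
        The routing and reset rules can be sketched as follows. This is a hypothetical, standalone illustration: `route_program` and `reset_due` are not part of this module, `calls_ancestor` is passed in as a boolean instead of being computed from the artifact, and islands are plain lists of program strings.

        ```python
        import time
        from typing import List, Optional


        def route_program(island_id: Optional[int], num_islands: int, calls_ancestor: bool) -> List[int]:
            # Programs calling an ancestor version of the evolved function are discarded.
            if calls_ancestor:
                return []
            # No island given: the program seeds every island.
            if island_id is None:
                return list(range(num_islands))
            return [island_id]


        def reset_due(last_reset_time: float, reset_period: float, now: Optional[float] = None) -> bool:
            # Islands are reset once strictly more than `reset_period` seconds have elapsed.
            now = time.time() if now is None else now
            return now - last_reset_time > reset_period


        islands: List[List[str]] = [[] for _ in range(3)]
        for i in route_program(None, num_islands=3, calls_ancestor=False):
            islands[i].append("def solve_function(x): return x")
        ```

        Note that the reset check runs on every call, even when the program itself is discarded.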
        """
        if not program.calls_ancestor(artifact_to_evolve=self.artifact_to_evolve_name):
            if island_id is None:
                # program added at the beginning (no island given), so add it to all islands
                for i in range(self.num_islands):
                    self._register_program_in_island(program=program,island_id=i,scores_per_test=scores_per_test)
                    
            else:
                self._register_program_in_island(program=program,island_id=island_id,scores_per_test=scores_per_test)
            
        #reset islands if needed
        if time.time() - self.flow_state["last_reset_time"]> self.reset_period:
            self.reset_islands()
            self.flow_state["last_reset_time"] = time.time()
    
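    def _register_program_in_island(self,program: AbstractArtifact, scores_per_test: ScoresPerTest, island_id: Optional[int] = None):
        """ This method registers a program in a given island. It also updates the island's best score, best program and best per-test scores if the program's reduced (mean) score is strictly higher than the current best.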
        
        :param program: The program to register
        :type program: AbstractArtifact
        :param scores_per_test: The scores per test of the program
        :type scores_per_test: ScoresPerTest
        :param island_id: The id of the island to register the program
        :type island_id: Optional[int]
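
        The score bookkeeping can be sketched as follows, assuming `scores_per_test` maps test names to dictionaries with a `score` entry (as this method reads it) and using plain lists in place of `flow_state`. `update_best` is a hypothetical helper, not part of this module.

        ```python
        from typing import Any, Dict, List, Optional

        import numpy as np


        def update_best(best_scores: List[float], best_programs: List[Optional[str]],
                        island_id: int, program: str, scores_per_test: Dict[str, Dict[str, Any]]) -> float:
            # Reduce the per-test scores with the mean, like reduce_score_method = np.mean.
            score = float(np.mean([s["score"] for s in scores_per_test.values()]))
            # Only a strictly better score replaces the island's best program.
            if score > best_scores[island_id]:
                best_scores[island_id] = score
                best_programs[island_id] = program
            return score


        best_scores = [float("-inf"), float("-inf")]
        best_programs: List[Optional[str]] = [None, None]
        update_best(best_scores, best_programs, 0, "v1", {"a": {"score": 1.0}, "b": {"score": 3.0}})
        update_best(best_scores, best_programs, 0, "v2", {"a": {"score": 1.0}, "b": {"score": 1.0}})
        # best_scores[0] stays 2.0 and best_programs[0] stays "v1"
        ```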
        """
        self.flow_state["islands"][island_id].register_program(program,scores_per_test)
        
        scores_per_test_values = np.array([ score_per_test["score"] for score_per_test in scores_per_test.values()])
        score = self.reduce_score_method(scores_per_test_values)
        
        if score > self.flow_state["best_score_per_island"][island_id]:
            self.flow_state["best_score_per_island"][island_id] = score
            self.flow_state["best_program_per_island"][island_id] = str(program)
            self.flow_state["best_scores_per_test_per_island"][island_id] = scores_per_test
    
    def get_best_programs(self) -> Dict[str, List[Dict[str, Any]]]:
        """ This method returns the best program of each island, ordered from the worst island to the best one (rank 1 is the best)."""
        sorted_island_ids = np.argsort(np.array(self.flow_state["best_score_per_island"]))
        return {
        "best_island_programs": [
            {
                "rank": self.num_islands - rank,
                "score": self.flow_state["best_score_per_island"][island_id],
                "program": self.flow_state["best_program_per_island"][island_id],
                "island_id": int(island_id),
                }
            for rank,island_id in enumerate(sorted_island_ids)
            ]
        }
    

    def run(self,input_message: FlowMessage):
        """ This method runs the flow. It performs the operation requested in the input message."""
        input_data = input_message.data
        operation = input_data["operation"]
        content = input_data.get("content", {})  # only used by "register_program"
        
        possible_operations = [
            "register_program",
            "get_prompt",
            "get_best_programs_per_island",
        ]
        
        if operation not in possible_operations:
            raise ValueError(f"operation must be one of the following: {possible_operations}")
        
        response = {}
        if operation == "get_prompt":
            # no prompt can be built before a first program has been registered
            if not self.flow_state["first_program_registered"]:
                response["retrieved"] = False
            
            else:
                code,version_generated,island_id = self.get_prompt()
                
                response["retrieved"] = {
                    "code":code,
                    "version_generated": version_generated,
                    "island_id":island_id,
                    "header": self.evaluate_file_full_content
                }
            
        elif operation == "register_program":
            try:
                artifact = text_to_artifact(content["artifact"])
                island_id = content.get("island_id",None)
                scores_per_test = content["scores_per_test"]
                if scores_per_test is not None:
                    self.register_program(program=artifact,island_id=island_id,scores_per_test=scores_per_test)
                    response["retrieved"] = "Program registered"
                    self.flow_state["first_program_registered"] = True
                else:
                    # a program without scores is not stored
                    response["retrieved"] = "Program failed to register"
            except Exception as e:
                log.warning(f"Program failed to register: {e}")
                response["retrieved"] = "Program failed to register"
        else:
            response["retrieved"] = self.get_best_programs()
        
        response["from"] = "ProgramDBFlow"
        reply = self.package_output_message(
            input_message,
            response
        )
        
        self.send_message(reply)