"""
This code is an implementation of Funsearch (https://www.nature.com/articles/s41586-023-06924-6) and is heavily inspired by the original code (https://github.com/google-deepmind/funsearch)
**Citation**:
@Article{FunSearch2023,
author = {Romera-Paredes, Bernardino and Barekatain, Mohammadamin and Novikov, Alexander and Balog, Matej and Kumar, M. Pawan and Dupont, Emilien and Ruiz, Francisco J. R. and Ellenberg, Jordan and Wang, Pengming and Fawzi, Omar and Kohli, Pushmeet and Fawzi, Alhussein},
journal = {Nature},
title = {Mathematical discoveries from program search with large language models},
year = {2023},
doi = {10.1038/s41586-023-06924-6}
}
"""
import dataclasses
import io
import tokenize
from collections.abc import Iterator, MutableSet, Sequence

from . import AbstractArtifact


@dataclasses.dataclass
class FunctionArtifact(AbstractArtifact):

    def __str__(self) -> str:
        return_type = f' -> {self.return_type}' if self.return_type else ''
        function = f'def {self.name}({self.args}){return_type}:\n'
        if self.docstring:
            # self.docstring is already indented on every line except the first one.
            # Here, we assume the indentation is always two spaces.
            new_line = '\n' if self.body else ''
            function += f'  """{self.docstring}"""{new_line}'
        # self.body is already indented.
        function += self.body + '\n\n'
        return function
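
    # Example (illustrative): with name='priority_v2', args='el: int',
    # return_type='float', no docstring, and body='  return 0.0',
    # str(artifact) renders:
    #   def priority_v2(el: int) -> float:
    #     return 0.0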

    @staticmethod
    def _tokenize(code: str) -> Iterator[tokenize.TokenInfo]:
        """Transforms `code` into Python tokens."""
        code_bytes = code.encode()
        code_io = io.BytesIO(code_bytes)
        return tokenize.tokenize(code_io.readline)

    @staticmethod
    def _untokenize(tokens: Sequence[tokenize.TokenInfo]) -> str:
        """Transforms a list of Python tokens into code."""
        code_bytes = tokenize.untokenize(tokens)
        return code_bytes.decode()
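
    # Note: full TokenInfo tuples carry position information, so
    # `tokenize.untokenize` can reconstruct the surrounding source largely
    # unchanged; only tokens whose `string` is replaced (see
    # `rename_artifact_calls` below) end up different in the output.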

    def _get_artifacts_called(self) -> MutableSet[str]:
        """Returns the set of all functions called in this function's body."""
        code = str(self.body)
        return set(token.string for token, is_call in
                   self._yield_token_and_is_call(code) if is_call)
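
    # Example (illustrative): for a body containing 'helper(x) + obj.method(y)',
    # this returns {'helper'}; attribute accesses such as 'obj.method' are not
    # counted as plain function calls.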

    def calls_ancestor(self, artifact_to_evolve: str) -> bool:
        """Returns whether the generated function calls an earlier version of itself."""
        for name in self._get_artifacts_called():
            # By the time this check runs, the most recently generated function has
            # already been renamed to `artifact_to_evolve` (without the version
            # suffix). Therefore any call whose name starts with
            # `{artifact_to_evolve}_v` is a call to an ancestor version.
            if name.startswith(f'{artifact_to_evolve}_v') and not name.startswith(self.name):
                return True
        return False
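
    # Example (illustrative): if this artifact is named 'priority_v2' and its
    # body calls 'priority_v1(...)', then calls_ancestor('priority') returns
    # True, since 'priority_v1' refers to an earlier version of the evolved
    # function.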

    @classmethod
    def _yield_token_and_is_call(cls, code: str) -> Iterator[tuple[tokenize.TokenInfo, bool]]:
        """Yields each token with a bool indicating whether it is a function call."""
        tokens = cls._tokenize(code)
        prev_token = None
        is_attribute_access = False
        for token in tokens:
            if (prev_token and                            # If the previous token exists and
                    prev_token.type == tokenize.NAME and  # it is a Python identifier
                    token.type == tokenize.OP and         # and the current token is a delimiter,
                    token.string == '('):                 # in particular an opening parenthesis,
                # then the previous token names a call (unless it was an attribute access).
                yield prev_token, not is_attribute_access
                is_attribute_access = False
            else:
                if prev_token:
                    is_attribute_access = (
                        prev_token.type == tokenize.OP and prev_token.string == '.'
                    )
                    yield prev_token, False
            prev_token = token
        if prev_token:
            yield prev_token, False
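
    # Example (illustrative): for code 'foo(bar.baz(x))', the NAME token 'foo'
    # is yielded with is_call=True, while 'baz' is yielded with is_call=False
    # because it is preceded by '.', i.e. it is an attribute access.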

    def rename_artifact_calls(self, source_name: str, target_name: str) -> str:
        """Renames all calls to `source_name` in this artifact's source to `target_name`."""
        implementation = str(self)
        if source_name not in implementation:
            return implementation
        modified_tokens = []
        for token, is_call in self._yield_token_and_is_call(implementation):
            if is_call and token.string == source_name:
                # Replace the called function's name token, keeping its position.
                modified_token = tokenize.TokenInfo(
                    type=token.type,
                    string=target_name,
                    start=token.start,
                    end=token.end,
                    line=token.line,
                )
                modified_tokens.append(modified_token)
            else:
                modified_tokens.append(token)
        return self._untokenize(modified_tokens)
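

# Illustrative usage sketch (assumes AbstractArtifact is a dataclass exposing
# `name`, `args`, `return_type`, `docstring` and `body` fields, which is how
# this class reads them above):
#
#   artifact = FunctionArtifact(name='priority_v2', args='el: int',
#                               return_type='float', docstring=None,
#                               body='  return priority_v1(el) + 1.0')
#   artifact.calls_ancestor('priority')
#   # -> True, because the body calls the ancestor 'priority_v1'.
#   renamed = artifact.rename_artifact_calls('priority_v1', 'priority')
#   # -> the same source text, with the call site rewritten to 'priority(el)'.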