File size: 4,906 Bytes
97e363b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
"""
This code is an implementation of Funsearch (https://www.nature.com/articles/s41586-023-06924-6) and is heavily inspired by the original code (https://github.com/google-deepmind/funsearch)

    **Citation**:
    
    @Article{FunSearch2023,
        author  = {Romera-Paredes, Bernardino and Barekatain, Mohammadamin and Novikov, Alexander and Balog, Matej and Kumar, M. Pawan and Dupont, Emilien and Ruiz, Francisco J. R. and Ellenberg, Jordan and Wang, Pengming and Fawzi, Omar and Kohli, Pushmeet and Fawzi, Alhussein},
        journal = {Nature},
        title   = {Mathematical discoveries from program search with large language models},
        year    = {2023},
        doi     = {10.1038/s41586-023-06924-6}
    }
"""

from . import AbstractArtifact
import dataclasses
import tokenize
import io
from collections.abc import Iterator, MutableSet, Sequence

@dataclasses.dataclass
class FunctionArtifact(AbstractArtifact):
    """An artifact representing a single Python function.

    Renders the function back to source code (`__str__`) and provides
    token-level analysis of its body: detecting which functions it calls
    and renaming those calls.
    """

    def __str__(self) -> str:
        """Render the function as Python source code ending with a blank line."""
        return_type = f' -> {self.return_type}' if self.return_type else ''
        function = f'def {self.name}({self.args}){return_type}:\n'

        if self.docstring:
            # self.docstring is already indented on every line except the first
            # one. Here, we assume the indentation is always two spaces.
            new_line = '\n' if self.body else ''
            function += f'    """{self.docstring}"""{new_line}'

        # self.body is already indented.
        function += self.body + '\n\n'
        return function

    @staticmethod
    def _tokenize(code: str) -> Iterator[tokenize.TokenInfo]:
        """Transforms `code` into Python tokens."""
        code_bytes = code.encode()
        code_io = io.BytesIO(code_bytes)
        return tokenize.tokenize(code_io.readline)

    @staticmethod
    def _untokenize(tokens: Sequence[tokenize.TokenInfo]) -> str:
        """Transforms a list of Python tokens into code."""
        code_bytes = tokenize.untokenize(tokens)
        return code_bytes.decode()

    def _get_artifacts_called(self) -> MutableSet[str]:
        """Returns the set of all function names called in the function body."""
        code = str(self.body)
        return {token.string
                for token, is_call in self._yield_token_and_is_call(code)
                if is_call}

    def calls_ancestor(self, artifact_to_evolve: str) -> bool:
        """Returns whether the generated function is calling an earlier version.

        Args:
            artifact_to_evolve: Base name of the artifact being evolved.

        Returns:
            True if the body calls any `{artifact_to_evolve}_v...` name other
            than this artifact's own name.
        """
        for name in self._get_artifacts_called():
            # In the code passed into this artifact the most recently generated
            # function has already been renamed to `artifact_to_evolve` (without
            # the version suffix). Therefore any function call starting with
            # `{artifact_to_evolve}_v` is a call to an ancestor function.
            if name.startswith(f'{artifact_to_evolve}_v') and not name.startswith(self.name):
                return True
        return False

    # BUG FIX: this method takes `cls` and calls `cls._tokenize`, but was
    # missing the @classmethod decorator — it only worked because instance
    # calls bound the instance to `cls`. Now a proper classmethod.
    @classmethod
    def _yield_token_and_is_call(cls, code: str) -> Iterator[tuple[tokenize.TokenInfo, bool]]:
        """Yields each token with a bool indicating whether it is a function call.

        A token counts as a call when it is a NAME immediately followed by '(',
        unless it is an attribute access (preceded by '.'), e.g. `obj.method(`.
        """
        tokens = cls._tokenize(code)
        prev_token = None
        is_attribute_access = False
        for token in tokens:
            if (prev_token and  # If the previous token exists and
                prev_token.type == tokenize.NAME and  # it is a Python identifier
                token.type == tokenize.OP and  # and the current token is a delimiter
                token.string == "("
                ):  # and in particular it is '('.
                yield prev_token, not is_attribute_access
                is_attribute_access = False
            else:
                if prev_token:
                    # Track whether the upcoming NAME is reached via '.' so that
                    # `obj.f(` is not reported as a plain function call.
                    is_attribute_access = (
                        prev_token.type == tokenize.OP and prev_token.string == '.'
                    )
                    yield prev_token, False

            prev_token = token
        if prev_token:
            yield prev_token, False

    def rename_artifact_calls(self, source_name: str, target_name: str) -> str:
        """Renames all calls to `source_name` in this function to `target_name`.

        Args:
            source_name: Function name to replace at call sites.
            target_name: Replacement name.

        Returns:
            The full rendered source of this function with the call sites
            renamed. Attribute accesses (`obj.source_name(...)`) are untouched.
        """
        implementation = str(self)

        # Cheap substring pre-check: skip tokenization when nothing can match.
        if source_name not in implementation:
            return implementation

        modified_tokens = []
        for token, is_call in self._yield_token_and_is_call(implementation):
            if is_call and token.string == source_name:
                # Replace the function name token.
                modified_token = tokenize.TokenInfo(
                    type=token.type,
                    string=target_name,
                    start=token.start,
                    end=token.end,
                    line=token.line,
                )
                modified_tokens.append(modified_token)
            else:
                modified_tokens.append(token)
        return self._untokenize(modified_tokens)