dai
Merge remote-tracking branch 'my-hf/main'
0f7319f
from time import sleep
import ast
import astunparse
import openai
from openai.error import RateLimitError, APIConnectionError
from pygments import highlight
from pygments.lexers import PythonLexer
from pygments.formatters import TerminalFormatter
class LMP:
    """Language Model Program: turns a natural-language query into code via an
    OpenAI completion, logs it, and (unless in debug mode) executes it.

    Functions called by the generated code that are not yet defined are
    synthesized first through ``lmp_fgen`` and added to the variable vars.
    """

    def __init__(self, name, cfg, lmp_fgen, fixed_vars, variable_vars, md_logger):
        """
        Args:
            name: label used in console and markdown log output.
            cfg: config mapping; keys read by this class are 'prompt_path',
                'stop', 'maintain_session', 'query_prefix', 'query_suffix',
                'temperature', 'engine', 'max_tokens', 'include_context',
                'debug_mode', 'has_return' and 'return_val_name'.
            lmp_fgen: object providing ``create_new_fs_from_code(code_str)``,
                which returns a dict of newly defined functions.
            fixed_vars: globals made available to every execution.
            variable_vars: globals that grow over time; this dict is updated
                in place with new functions (and, with 'maintain_session',
                with the locals produced by each execution).
            md_logger: object with ``log_text`` and ``log_code`` methods.
        """
        self._name = name
        self._cfg = cfg
        self._md_logger = md_logger
        with open(self._cfg['prompt_path'], 'r') as f:
            self._base_prompt = f.read()
        self._stop_tokens = list(self._cfg['stop'])
        self._lmp_fgen = lmp_fgen
        self._fixed_vars = fixed_vars
        self._variable_vars = variable_vars
        self.exec_hist = ''

    def clear_exec_hist(self):
        """Forget all previously executed code (the session history)."""
        self.exec_hist = ''

    def build_prompt(self, query, context=''):
        """Assemble the completion prompt for ``query``.

        The ``{variable_vars_imports}`` placeholder in the base prompt is
        replaced with a ``from utils import ...`` line naming every variable
        var (or with nothing when there are none). Then the session history
        (if 'maintain_session'), ``context`` (if non-empty) and the wrapped
        query are appended, each on a new line.

        Returns:
            (prompt, use_query): the full prompt and the query wrapped in the
            configured prefix/suffix.
        """
        if len(self._variable_vars) > 0:
            variable_vars_imports_str = f"from utils import {', '.join(self._variable_vars.keys())}"
        else:
            variable_vars_imports_str = ''
        prompt = self._base_prompt.replace('{variable_vars_imports}', variable_vars_imports_str)
        if self._cfg['maintain_session']:
            prompt += f'\n{self.exec_hist}'
        if context != '':
            prompt += f'\n{context}'
        use_query = f'{self._cfg["query_prefix"]}{query}{self._cfg["query_suffix"]}'
        prompt += f'\n{use_query}'
        return prompt, use_query

    def _complete(self, prompt):
        """Return the stripped completion text for ``prompt``.

        Rate-limit and connection errors are retried every 2 seconds with no
        upper bound on attempts; any other error propagates.
        """
        while True:
            try:
                return openai.Completion.create(
                    prompt=prompt,
                    stop=self._stop_tokens,
                    temperature=self._cfg['temperature'],
                    engine=self._cfg['engine'],
                    max_tokens=self._cfg['max_tokens'],
                )['choices'][0]['text'].strip()
            except (RateLimitError, APIConnectionError) as e:
                print(f'OpenAI API got err {e}')
                print('Retrying after 2s.')
                sleep(2)

    def __call__(self, query, context='', **kwargs):
        """Generate code for ``query``, log it, and run it unless in debug mode.

        Args:
            query: natural-language instruction placed at the end of the prompt.
            context: optional code placed before the query in the prompt; when
                cfg['include_context'] is set it is also executed before the
                generated code.
            **kwargs: initial local variables for the executed code.

        Returns:
            When cfg['has_return'] is set, the value bound to
            cfg['return_val_name'] in the execution locals. In debug mode the
            code is not executed, so this is whatever ``kwargs`` supplied under
            that name, or None. Otherwise returns None.
        """
        prompt, use_query = self.build_prompt(query, context=context)
        code_str = self._complete(prompt)

        if self._cfg['include_context'] and context != '':
            to_exec = f'{context}\n{code_str}'
            to_log = f'{context}\n{use_query}\n{code_str}'
        else:
            to_exec = code_str
            to_log = f'{use_query}\n{to_exec}'

        to_log_pretty = highlight(to_log, PythonLexer(), TerminalFormatter())
        print(f'LMP {self._name} generated code:\n{to_log_pretty}')
        self._md_logger.log_text(f'LMP {self._name} Generated Code:')
        self._md_logger.log_code(to_log)

        # Define any functions the generated code calls that are not in scope yet.
        new_fs = self._lmp_fgen.create_new_fs_from_code(code_str)
        self._variable_vars.update(new_fs)

        gvars = merge_dicts([self._fixed_vars, self._variable_vars])
        lvars = kwargs
        if not self._cfg['debug_mode']:
            exec_safe(to_exec, gvars, lvars)

        self.exec_hist += f'\n{to_exec}'

        if self._cfg['maintain_session']:
            # Locals produced by this run become globals for later runs.
            self._variable_vars.update(lvars)

        if self._cfg['has_return']:
            if self._cfg['debug_mode']:
                # Nothing was executed, so the return variable is normally
                # unset; don't fail with KeyError in debug mode.
                return lvars.get(self._cfg['return_val_name'])
            return lvars[self._cfg['return_val_name']]
class LMPFGen:
    """Generates Python functions from example call signatures with an OpenAI
    completion, recursively defining any undefined helpers they call."""

    def __init__(self, cfg, fixed_vars, variable_vars, md_logger):
        """
        Args:
            cfg: config mapping; keys read by this class are 'prompt_path',
                'stop', 'query_prefix', 'query_suffix', 'temperature',
                'engine' and 'max_tokens'.
            fixed_vars: globals available to every generated function.
            variable_vars: additional globals available to generated functions.
            md_logger: object with ``log_text`` and ``log_code`` methods.
        """
        self._cfg = cfg
        self._stop_tokens = list(self._cfg['stop'])
        self._fixed_vars = fixed_vars
        self._variable_vars = variable_vars
        self._md_logger = md_logger
        with open(self._cfg['prompt_path'], 'r') as f:
            self._base_prompt = f.read()

    def _complete(self, prompt):
        """Return the stripped completion text for ``prompt``.

        Rate-limit and connection errors are retried every 2 seconds with no
        upper bound on attempts; any other error propagates.
        """
        while True:
            try:
                return openai.Completion.create(
                    prompt=prompt,
                    stop=self._stop_tokens,
                    temperature=self._cfg['temperature'],
                    engine=self._cfg['engine'],
                    max_tokens=self._cfg['max_tokens'],
                )['choices'][0]['text'].strip()
            except (RateLimitError, APIConnectionError) as e:
                print(f'OpenAI API got err {e}')
                print('Retrying after 2s.')
                sleep(2)

    def create_f_from_sig(self, f_name, f_sig, other_vars=None, fix_bugs=False, return_src=False):
        """Generate and define the function ``f_name`` described by ``f_sig``.

        Args:
            f_name: name the generated source is expected to define.
            f_sig: example call or assignment placed in the prompt query.
            other_vars: extra globals for the function, layered over the fixed
                and variable vars (later entries win).
            fix_bugs: if true, post-process the source with the OpenAI edit model.
            return_src: also return the generated source.

        Returns:
            The function, or ``(function, source)`` when ``return_src`` is true.

        Raises:
            KeyError: if the generated source does not define ``f_name``.
        """
        print(f'Creating function: {f_sig}')
        use_query = f'{self._cfg["query_prefix"]}{f_sig}{self._cfg["query_suffix"]}'
        prompt = f'{self._base_prompt}\n{use_query}'
        f_src = self._complete(prompt)

        if fix_bugs:
            # NOTE(review): the '# ' prefix comments out the first line of the
            # generated source (normally the `def` line) - confirm intended.
            f_src = openai.Edit.create(
                model='code-davinci-edit-001',
                input='# ' + f_src,
                temperature=0,
                instruction='Fix the bug if there is one. Improve readability. Keep same inputs and outputs. Only small changes. No comments.',
            )['choices'][0]['text'].strip()

        if other_vars is None:
            other_vars = {}
        gvars = merge_dicts([self._fixed_vars, self._variable_vars, other_vars])
        lvars = {}
        exec_safe(f_src, gvars, lvars)
        f = lvars[f_name]

        to_print = f'{use_query}\n{f_src}'
        to_print_pretty = highlight(to_print, PythonLexer(), TerminalFormatter())
        print(f'LMPFGen generated code:\n{to_print_pretty}')
        self._md_logger.log_text('Generated Function:')
        self._md_logger.log_code(to_print)

        if return_src:
            return f, f_src
        return f

    def create_new_fs_from_code(self, code_str, other_vars=None, fix_bugs=False, return_src=False):
        """Define every plain-name function called in ``code_str`` that is not in scope.

        Calls are collected with FunctionParser. For a function whose call
        result is assigned somewhere (e.g. ``y = f(x)``), that assignment
        statement is used as its signature instead of the bare call. Each
        missing function is generated with create_f_from_sig, and its body is
        then scanned recursively so the helpers it calls get defined too.

        Args:
            code_str: Python source to scan.
            other_vars: extra names treated as already defined; they are also
                visible in the generated functions' globals.
            fix_bugs: forwarded to create_f_from_sig.
            return_src: also return the generated sources.

        Returns:
            ``new_fs`` (name -> function), or ``(new_fs, srcs)`` with ``srcs``
            mapping name -> source when ``return_src`` is true.
        """
        fs, f_assigns = {}, {}
        f_parser = FunctionParser(fs, f_assigns)
        f_parser.visit(ast.parse(code_str))
        for f_name, f_assign in f_assigns.items():
            if f_name in fs:
                fs[f_name] = f_assign

        if other_vars is None:
            other_vars = {}

        new_fs = {}
        srcs = {}
        for f_name, f_sig in fs.items():
            all_vars = merge_dicts([self._fixed_vars, self._variable_vars, new_fs, other_vars])
            if var_exists(f_name, all_vars):
                continue

            # Give the new function the same scope it was checked against;
            # passing only new_fs left names from other_vars out of its globals.
            f, f_src = self.create_f_from_sig(
                f_name, f_sig, merge_dicts([new_fs, other_vars]), fix_bugs=fix_bugs, return_src=True
            )

            # Recursively define child functions called in the body. f itself
            # counts as defined here, otherwise a self-recursive f would be
            # regenerated and rescanned without end.
            f_def_body = astunparse.unparse(ast.parse(f_src).body[0].body)
            child_fs, child_f_srcs = self.create_new_fs_from_code(
                f_def_body,
                other_vars=merge_dicts([all_vars, {f_name: f}]),
                fix_bugs=fix_bugs,
                return_src=True,
            )

            if len(child_fs) > 0:
                new_fs.update(child_fs)
                srcs.update(child_f_srcs)
                # Redefine the parent so the newly created children are in its globals.
                gvars = merge_dicts([self._fixed_vars, self._variable_vars, new_fs, other_vars])
                lvars = {}
                exec_safe(f_src, gvars, lvars)
                f = lvars[f_name]

            new_fs[f_name], srcs[f_name] = f, f_src

        if return_src:
            return new_fs, srcs
        return new_fs
class FunctionParser(ast.NodeTransformer):
    """AST walker that records calls to plain-name functions.

    Findings go into the two dicts supplied at construction:

    * ``fs``: callee name -> source text of a call to it (only calls whose
      callee is a bare ``ast.Name``);
    * ``f_assigns``: unparsed callee -> source text of an assignment whose
      right-hand side is a call to it.

    When a name occurs more than once, the last occurrence visited wins. Every
    visitor returns its node unchanged, so the tree is not rewritten.
    """

    def __init__(self, fs, f_assigns):
        super().__init__()
        self._fs = fs
        self._f_assigns = f_assigns

    def visit_Call(self, node):
        # Children (nested calls in the arguments) are visited first.
        self.generic_visit(node)
        callee = node.func
        if not isinstance(callee, ast.Name):
            return node
        callee_name = astunparse.unparse(callee).strip()
        self._fs[callee_name] = astunparse.unparse(node).strip()
        return node

    def visit_Assign(self, node):
        self.generic_visit(node)
        rhs = node.value
        if not isinstance(rhs, ast.Call):
            return node
        callee_src = astunparse.unparse(rhs.func).strip()
        self._f_assigns[callee_src] = astunparse.unparse(node).strip()
        return node
def var_exists(name, all_vars):
    """Return True if ``name`` evaluates without error against ``all_vars``.

    ``all_vars`` serves as the globals of ``eval``. Because ``eval`` supplies
    the builtins when the globals lack ``__builtins__``, builtin names such as
    ``len`` also count as existing. Evaluation uses a copy, so the caller's
    dict does not gain a ``__builtins__`` entry.

    NOTE: ``name`` is evaluated as an expression; pass a plain identifier.
    """
    try:
        eval(name, dict(all_vars))
    except Exception:
        # Narrower than a bare except: KeyboardInterrupt/SystemExit propagate.
        return False
    return True
def merge_dicts(dicts):
    """Combine a sequence of dicts into one new dict.

    When a key appears in several inputs, the value from the last one wins,
    while the key keeps the position of its first appearance. The inputs are
    not modified.
    """
    merged = {}
    for mapping in dicts:
        for key, value in mapping.items():
            merged[key] = value
    return merged
def exec_safe(code_str, gvars=None, lvars=None):
    """Execute ``code_str`` after a crude safety screen.

    Code containing the substring 'import' or '__' is rejected. In the globals
    the code runs with, ``exec`` and ``eval`` are replaced by no-op functions.
    This guards against accidental misuse only; it is not a sandbox. Other
    builtins (e.g. ``open``, ``getattr``) stay reachable unless ``gvars``
    overrides them.

    Args:
        code_str: Python source to execute.
        gvars: globals for the code. Copied before use, so this mapping
            itself is not modified.
        lvars: locals dict; names bound by the code are written into it.

    Raises:
        AssertionError: if a banned phrase appears in ``code_str``.
    """
    banned_phrases = ['import', '__']
    for phrase in banned_phrases:
        # Explicit raise instead of `assert`: `python -O` strips asserts,
        # which would silently disable this check. AssertionError is kept so
        # callers see the same exception type as before.
        if phrase in code_str:
            raise AssertionError(f'banned phrase {phrase!r} found in code to execute')
    if gvars is None:
        gvars = {}
    if lvars is None:
        lvars = {}

    def empty_fn(*args, **kwargs):
        return None

    custom_gvars = {**gvars, 'exec': empty_fn, 'eval': empty_fn}
    exec(code_str, custom_gvars, lvars)