Ron
commited on
Commit
·
41d1bc5
1
Parent(s):
40190c3
initial commit
Browse files- .gitignore +1 -0
- README.md +4 -5
- app.py +107 -2
- executors/__init__.py +2 -0
- executors/cargo_harness/Cargo.toml +8 -0
- executors/cargo_harness/src/.gitkeep +0 -0
- executors/executor_types.py +20 -0
- executors/executor_utils.py +46 -0
- executors/factory.py +8 -0
- executors/py_executor.py +88 -0
- generators/__init__.py +3 -0
- generators/factory.py +20 -0
- generators/generator_types.py +33 -0
- generators/generator_utils.py +286 -0
- generators/model.py +120 -0
- generators/parse.py +49 -0
- generators/py_generate.py +404 -0
- lats/.DS_Store +0 -0
- lats/lats.py +233 -0
- lats/lats_main.py +78 -0
- lats/requirements.txt +9 -0
- lats/utils.py +73 -0
- requirements.txt +9 -0
.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
**/__pycache__/
|
README.md
CHANGED
@@ -1,13 +1,12 @@
|
|
1 |
---
|
2 |
title: CodeLATS
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: streamlit
|
7 |
-
sdk_version: 1.27.
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
-
license: mit
|
11 |
---
|
12 |
|
13 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
---
|
2 |
title: CodeLATS
|
3 |
+
emoji: 🐢
|
4 |
+
colorFrom: gray
|
5 |
+
colorTo: yellow
|
6 |
sdk: streamlit
|
7 |
+
sdk_version: 1.27.1
|
8 |
app_file: app.py
|
9 |
pinned: false
|
|
|
10 |
---
|
11 |
|
12 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
CHANGED
@@ -1,4 +1,109 @@
|
|
1 |
import streamlit as st
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
-
x = st.slider('Select a value')
|
4 |
-
st.write(x, 'squared is', x * x)
|
|
|
1 |
import streamlit as st
|
2 |
+
import openai
|
3 |
+
import os
|
4 |
+
import sys
|
5 |
+
import argparse
|
6 |
+
sys.path.append('./LATS')
|
7 |
+
from lats_main import lats_main
|
8 |
+
|
9 |
+
st.set_page_config(layout="wide")
|
10 |
+
|
11 |
+
# Initialize session state variables if they don't exist.
|
12 |
+
if 'response_content' not in st.session_state:
|
13 |
+
st.session_state.response_content = None
|
14 |
+
|
15 |
+
# Creating main columns for the chat and runtime notifications
|
16 |
+
chat_col = st.container()
|
17 |
+
|
18 |
+
chat_col.title("CodeLATS")
|
19 |
+
description = """This tech demo is an implementation of Language Agent Tree Search (LATS) (https://arxiv.org/abs/2310.04406) built specifically for generating code in the form of python functions. It achieves **state-of-the-art** results on HumanEval with a **94.4% pass@1 rate** on GPT-4.
|
20 |
+
|
21 |
+
Listed below is an example programming problem (https://leetcode.com/problems/longest-valid-parentheses/description/) to get started with.
|
22 |
+
|
23 |
+
```python
|
24 |
+
Given a string containing just the characters '(' and ')', return the length of the longest valid (well-formed) parentheses substring
|
25 |
+
```
|
26 |
+
NOTE: On average a call for a HumanEval or Leetcode question will cost around 5-30 cents on GPT-4, using the default parameters. This value may change depending on problem difficulty and parameters.
|
27 |
+
"""
|
28 |
+
|
29 |
+
chat_col.markdown(description)
|
30 |
+
sidebar = st.sidebar
|
31 |
+
# Runtime Section
|
32 |
+
runtime_container = st.container()
|
33 |
+
|
34 |
+
# Parameters Section
|
35 |
+
sidebar.title("**An AI@UIUC Project** (https://uiuc.ai/)")
|
36 |
+
parameters_section = sidebar.expander("Parameters", expanded=False)
|
37 |
+
tree_width = parameters_section.number_input("Tree Width", min_value=1, max_value=5, value=1)
|
38 |
+
tree_depth = parameters_section.number_input("Tree Depth", min_value=1, max_value=8, value=3)
|
39 |
+
iterations = parameters_section.number_input("Iterations", min_value=1, max_value=4, value=2)
|
40 |
+
key = st.sidebar.text_input("Enter your OpenAI Api Key:", type="password")
|
41 |
+
sidebar.markdown('<hr style="margin-top: 0.5rem; margin-bottom: 0.5rem;">', unsafe_allow_html=True)
|
42 |
+
|
43 |
+
with sidebar:
|
44 |
+
runtime_container = st.container()
|
45 |
+
runtime_container.empty()
|
46 |
+
|
47 |
+
runtime_messages = []
|
48 |
+
|
49 |
+
def make_args(instruction, tree_depth, tree_width, iterations):
|
50 |
+
parser = argparse.ArgumentParser()
|
51 |
+
|
52 |
+
parser.add_argument("--strategy", default="mcts", help="Strategy to use")
|
53 |
+
parser.add_argument("--language", default="py", help="Programming language")
|
54 |
+
parser.add_argument("--model", default="gpt-4", help="Model type")
|
55 |
+
parser.add_argument("--max_iters", default=iterations, help="Maximum iterations")
|
56 |
+
parser.add_argument("--instruction", default=instruction, help="Instruction text")
|
57 |
+
parser.add_argument("--verbose", action="store_true", help="Verbose output")
|
58 |
+
parser.add_argument("--is_leetcode", action='store_true',
|
59 |
+
help="To run the leetcode benchmark") # Temporary
|
60 |
+
parser.add_argument("--n_samples", type=int,
|
61 |
+
help="The number of nodes added during expansion", default=tree_width)
|
62 |
+
parser.add_argument("--depth", type=int,
|
63 |
+
help="Tree depth", default=tree_depth)
|
64 |
+
args = parser.parse_args()
|
65 |
+
return args
|
66 |
+
|
67 |
+
def run_querry():
|
68 |
+
if user_input:
|
69 |
+
|
70 |
+
# Create a new container for each subsequent message
|
71 |
+
runtime_container.write("Initiating process...")
|
72 |
+
|
73 |
+
# Make it so that prints go to runtime_container writes instead
|
74 |
+
old_stdout = sys.stdout
|
75 |
+
sys.stdout = runtime_container
|
76 |
+
|
77 |
+
with chat_col:
|
78 |
+
|
79 |
+
with st.spinner('Running...'):
|
80 |
+
args = make_args(user_input, tree_depth, tree_width, iterations)
|
81 |
+
# main call
|
82 |
+
response = lats_main(args)
|
83 |
+
|
84 |
+
sys.stdout = old_stdout
|
85 |
+
runtime_container.write("Response fetched.")
|
86 |
+
chat_col.markdown('<hr style="margin-top: 0.5rem; margin-bottom: 0.5rem;">', unsafe_allow_html=True)
|
87 |
+
chat_col.write(f"```python\n{response} \n")
|
88 |
+
|
89 |
+
return response
|
90 |
+
|
91 |
+
# User input section at the bottom of the page
|
92 |
+
with chat_col:
|
93 |
+
user_input = st.text_area("Enter your message here:", placeholder="Type your message here...", label_visibility="collapsed")
|
94 |
+
button = st.button("Send")
|
95 |
+
|
96 |
+
if button:
|
97 |
+
fail = False
|
98 |
+
if key == "":
|
99 |
+
st.warning("Missing OpenAI API Key")
|
100 |
+
fail = True
|
101 |
+
|
102 |
+
if user_input == "":
|
103 |
+
st.warning("Missing a coding problem")
|
104 |
+
fail = True
|
105 |
+
|
106 |
+
if (not fail):
|
107 |
+
openai.api_key = key
|
108 |
+
run_querry()
|
109 |
|
|
|
|
executors/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .py_executor import PyExecutor
|
2 |
+
from .factory import executor_factory
|
executors/cargo_harness/Cargo.toml
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[package]
|
2 |
+
name = "cargo_harness"
|
3 |
+
version = "0.1.0"
|
4 |
+
edition = "2021"
|
5 |
+
|
6 |
+
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
7 |
+
|
8 |
+
[dependencies]
|
executors/cargo_harness/src/.gitkeep
ADDED
File without changes
|
executors/executor_types.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import NamedTuple, List, Tuple
|
2 |
+
from abc import ABC, abstractmethod
|
3 |
+
|
4 |
+
class ExecuteResult(NamedTuple):
|
5 |
+
is_passing: bool
|
6 |
+
feedback: str
|
7 |
+
state: Tuple[bool]
|
8 |
+
|
9 |
+
class Executor(ABC):
|
10 |
+
@abstractmethod
|
11 |
+
def execute(self, func: str, tests: List[str], timeout: int = 5) -> ExecuteResult:
|
12 |
+
...
|
13 |
+
|
14 |
+
@abstractmethod
|
15 |
+
def evaluate(self, name: str, func: str, test: str, timeout: int = 5) -> bool:
|
16 |
+
...
|
17 |
+
|
18 |
+
|
19 |
+
|
20 |
+
|
executors/executor_utils.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
def timeout_handler(_, __):
|
3 |
+
raise TimeoutError()
|
4 |
+
|
5 |
+
import os, json
|
6 |
+
def to_jsonl(dict_data, file_path):
|
7 |
+
with open(file_path, 'a') as file:
|
8 |
+
json_line = json.dumps(dict_data)
|
9 |
+
file.write(json_line + os.linesep)
|
10 |
+
|
11 |
+
from threading import Thread
|
12 |
+
class PropagatingThread(Thread):
|
13 |
+
def run(self):
|
14 |
+
self.exc = None
|
15 |
+
try:
|
16 |
+
if hasattr(self, '_Thread__target'):
|
17 |
+
# Thread uses name mangling prior to Python 3.
|
18 |
+
self.ret = self._Thread__target(*self._Thread__args, **self._Thread__kwargs)
|
19 |
+
else:
|
20 |
+
self.ret = self._target(*self._args, **self._kwargs)
|
21 |
+
except BaseException as e:
|
22 |
+
self.exc = e
|
23 |
+
|
24 |
+
def join(self, timeout=None):
|
25 |
+
super(PropagatingThread, self).join(timeout)
|
26 |
+
if self.exc:
|
27 |
+
raise self.exc
|
28 |
+
return self.ret
|
29 |
+
|
30 |
+
|
31 |
+
def function_with_timeout(func, args, timeout):
|
32 |
+
result_container = []
|
33 |
+
|
34 |
+
def wrapper():
|
35 |
+
result_container.append(func(*args))
|
36 |
+
|
37 |
+
thread = PropagatingThread(target=wrapper)
|
38 |
+
thread.start()
|
39 |
+
thread.join(timeout)
|
40 |
+
|
41 |
+
if thread.is_alive():
|
42 |
+
raise TimeoutError()
|
43 |
+
else:
|
44 |
+
return result_container[0]
|
45 |
+
|
46 |
+
|
executors/factory.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .py_executor import PyExecutor
|
2 |
+
from .executor_types import Executor
|
3 |
+
|
4 |
+
def executor_factory(lang: str) -> Executor:
|
5 |
+
if lang == "py" or lang == "python":
|
6 |
+
return PyExecutor()
|
7 |
+
else:
|
8 |
+
raise ValueError(f"Invalid language for executor: {lang}")
|
executors/py_executor.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import ast
|
2 |
+
import signal
|
3 |
+
import astunparse
|
4 |
+
|
5 |
+
from .executor_utils import function_with_timeout
|
6 |
+
|
7 |
+
from typing import List
|
8 |
+
from .executor_types import ExecuteResult, Executor
|
9 |
+
|
10 |
+
class PyExecutor(Executor):
|
11 |
+
def execute(self, func: str, tests: List[str], timeout: int = 5) -> ExecuteResult:
|
12 |
+
# Combine function code and assert statement
|
13 |
+
imports = 'from typing import *'
|
14 |
+
func_test_list = [f'{imports}\n{func}\n{test}' for test in tests]
|
15 |
+
|
16 |
+
# Run the tests and collect the results
|
17 |
+
success_tests = []
|
18 |
+
failed_tests = []
|
19 |
+
is_passing = True
|
20 |
+
num_tests = len(func_test_list)
|
21 |
+
for i in range(num_tests):
|
22 |
+
try:
|
23 |
+
|
24 |
+
function_with_timeout(exec, (func_test_list[i], globals()), timeout)
|
25 |
+
|
26 |
+
success_tests += [tests[i]]
|
27 |
+
except Exception:
|
28 |
+
output = get_output(func, tests[i], timeout=timeout)
|
29 |
+
failed_tests += [f"{tests[i]} # output: {output}"]
|
30 |
+
is_passing = False
|
31 |
+
|
32 |
+
state = []
|
33 |
+
for test in tests:
|
34 |
+
if test in success_tests:
|
35 |
+
state += [True]
|
36 |
+
else:
|
37 |
+
state += [False]
|
38 |
+
|
39 |
+
state = tuple(state)
|
40 |
+
|
41 |
+
feedback = "Tested passed:"
|
42 |
+
for test in success_tests:
|
43 |
+
feedback += f"\n{test}"
|
44 |
+
feedback += "\n\nTests failed:"
|
45 |
+
for test in failed_tests:
|
46 |
+
feedback += f"\n{test}"
|
47 |
+
|
48 |
+
return ExecuteResult(is_passing, feedback, state)
|
49 |
+
|
50 |
+
def evaluate(self, name: str, func: str, test: str, timeout: int = 5) -> bool:
|
51 |
+
"""
|
52 |
+
Evaluates the implementation on Human-Eval Python.
|
53 |
+
|
54 |
+
probably should be written in a dataset-agnostic way but not now
|
55 |
+
"""
|
56 |
+
code = f"""{func}
|
57 |
+
|
58 |
+
{test}
|
59 |
+
|
60 |
+
check({name})
|
61 |
+
"""
|
62 |
+
try:
|
63 |
+
|
64 |
+
function_with_timeout(exec, (code, globals()), timeout)
|
65 |
+
|
66 |
+
return True
|
67 |
+
except Exception:
|
68 |
+
return False
|
69 |
+
|
70 |
+
def get_call_str(assert_statement: str) -> str:
|
71 |
+
ast_parsed = ast.parse(assert_statement)
|
72 |
+
try:
|
73 |
+
call_str = ast_parsed.body[0].test.left # type: ignore
|
74 |
+
except:
|
75 |
+
call_str = ast_parsed.body[0].test # type: ignore
|
76 |
+
|
77 |
+
return astunparse.unparse(call_str).strip()
|
78 |
+
|
79 |
+
def get_output(func: str, assert_statement: str, timeout: int = 5) -> str:
|
80 |
+
try:
|
81 |
+
exec(f"from typing import *\n{func}", globals())
|
82 |
+
func_call = get_call_str(assert_statement)
|
83 |
+
output = function_with_timeout(eval, (func_call, globals()), timeout)
|
84 |
+
return output
|
85 |
+
except TimeoutError:
|
86 |
+
return "TIMEOUT"
|
87 |
+
except Exception as e:
|
88 |
+
return str(e)
|
generators/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .py_generate import PyGenerator
|
2 |
+
from .factory import generator_factory, model_factory
|
3 |
+
from .model import ModelBase, GPT4, GPT35
|
generators/factory.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .py_generate import PyGenerator
|
2 |
+
from .generator_types import Generator
|
3 |
+
from .model import ModelBase, GPT4, GPT35, GPTDavinci
|
4 |
+
|
5 |
+
def generator_factory(lang: str) -> Generator:
|
6 |
+
if lang == "py" or lang == "python":
|
7 |
+
return PyGenerator()
|
8 |
+
else:
|
9 |
+
raise ValueError(f"Invalid language for generator: {lang}")
|
10 |
+
|
11 |
+
|
12 |
+
def model_factory(model_name: str) -> ModelBase:
|
13 |
+
if model_name == "gpt-4":
|
14 |
+
return GPT4()
|
15 |
+
elif model_name == "gpt-3.5-turbo-0613":
|
16 |
+
return GPT35()
|
17 |
+
elif model_name.startswith("text-davinci"):
|
18 |
+
return GPTDavinci(model_name)
|
19 |
+
else:
|
20 |
+
raise ValueError(f"Invalid model name: {model_name}")
|
generators/generator_types.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Optional, Union
|
2 |
+
from abc import abstractmethod, ABC
|
3 |
+
|
4 |
+
from generators.model import ModelBase
|
5 |
+
|
6 |
+
|
7 |
+
class Generator:
|
8 |
+
@abstractmethod
|
9 |
+
def self_reflection(self, func: str, feedback: str, model: ModelBase) -> str:
|
10 |
+
...
|
11 |
+
|
12 |
+
@abstractmethod
|
13 |
+
def func_impl(
|
14 |
+
self,
|
15 |
+
func_sig: str,
|
16 |
+
model: ModelBase,
|
17 |
+
strategy: str,
|
18 |
+
prev_func_impl: Optional[str] = None,
|
19 |
+
feedback: Optional[str] = None,
|
20 |
+
self_reflection: Optional[str] = None,
|
21 |
+
num_comps: int = 1,
|
22 |
+
temperature: float = 0.0,
|
23 |
+
) -> Union[str, List[str]]:
|
24 |
+
...
|
25 |
+
|
26 |
+
@abstractmethod
|
27 |
+
def internal_tests(
|
28 |
+
self,
|
29 |
+
func_sig: str,
|
30 |
+
model: ModelBase,
|
31 |
+
max_num_tests: int = 5
|
32 |
+
) -> List[str]:
|
33 |
+
...
|
generators/generator_utils.py
ADDED
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from generators.model import ModelBase, Message
|
2 |
+
import random
|
3 |
+
import streamlit as st
|
4 |
+
|
5 |
+
from typing import Union, List, Optional, Callable
|
6 |
+
|
7 |
+
|
8 |
+
def generic_generate_func_impl(
|
9 |
+
func_sig: str,
|
10 |
+
model: ModelBase,
|
11 |
+
strategy: str,
|
12 |
+
prev_func_impl,
|
13 |
+
feedback,
|
14 |
+
self_reflection,
|
15 |
+
num_comps,
|
16 |
+
temperature,
|
17 |
+
reflexion_chat_instruction: str,
|
18 |
+
reflexion_few_shot: str,
|
19 |
+
simple_chat_instruction: str,
|
20 |
+
reflexion_completion_instruction: str,
|
21 |
+
simple_completion_instruction: str,
|
22 |
+
code_block_instruction: str,
|
23 |
+
parse_code_block: Callable[[str], str],
|
24 |
+
add_code_block: Callable[[str], str]
|
25 |
+
) -> Union[str, List[str]]:
|
26 |
+
if strategy != "reflexion" and strategy != "simple":
|
27 |
+
raise ValueError(
|
28 |
+
f"Invalid strategy: given `{strategy}` but expected one of `reflexion` or `simple`")
|
29 |
+
if strategy == "reflexion" and (prev_func_impl is None or feedback is None or self_reflection is None):
|
30 |
+
raise ValueError(
|
31 |
+
f"Invalid arguments: given `strategy=reflexion` but `prev_func_impl`, `feedback`, or `self_reflection` is None")
|
32 |
+
|
33 |
+
if model.is_chat:
|
34 |
+
if strategy == "reflexion":
|
35 |
+
message = f"{reflexion_few_shot}\n[previous impl]:\n{add_code_block(prev_func_impl)}\n\n[unit test results from previous impl]:\n{feedback}\n\n[reflection on previous impl]:\n{self_reflection}\n\n[improved impl]:\n{func_sig}"
|
36 |
+
prompt = f"{reflexion_chat_instruction}\n{code_block_instruction}"
|
37 |
+
# func_bodies is a really bad name, as it can also be just 1 string
|
38 |
+
print_messages(prompt, message)
|
39 |
+
messages = [
|
40 |
+
Message(
|
41 |
+
role="system",
|
42 |
+
content=prompt,
|
43 |
+
),
|
44 |
+
Message(
|
45 |
+
role="user", # TODO: check this
|
46 |
+
content=reflexion_few_shot,
|
47 |
+
),
|
48 |
+
Message(
|
49 |
+
role="assistant",
|
50 |
+
content=add_code_block(prev_func_impl),
|
51 |
+
),
|
52 |
+
Message(
|
53 |
+
role="user",
|
54 |
+
content=f"[unit test results from previous impl]:\n{feedback}\n\n[reflection on previous impl]:",
|
55 |
+
),
|
56 |
+
Message(
|
57 |
+
role="assistant",
|
58 |
+
content=self_reflection,
|
59 |
+
),
|
60 |
+
Message(
|
61 |
+
role="user",
|
62 |
+
content=f"[improved impl]:\n{func_sig}",
|
63 |
+
),
|
64 |
+
]
|
65 |
+
func_bodies = model.generate_chat(messages=messages, num_comps=num_comps, temperature=temperature)
|
66 |
+
else:
|
67 |
+
system_prompt = f"{simple_chat_instruction}\n{code_block_instruction}"
|
68 |
+
print_messages(system_prompt, func_sig)
|
69 |
+
messages = [
|
70 |
+
Message(
|
71 |
+
role="system",
|
72 |
+
content=f"{simple_chat_instruction}\n{code_block_instruction}",
|
73 |
+
),
|
74 |
+
Message(
|
75 |
+
role="user",
|
76 |
+
content=func_sig,
|
77 |
+
),
|
78 |
+
]
|
79 |
+
func_bodies = model.generate_chat(messages=messages, num_comps=num_comps, temperature=temperature)
|
80 |
+
else:
|
81 |
+
if strategy == "reflexion":
|
82 |
+
prompt = f"{reflexion_completion_instruction}\n{add_code_block(prev_func_impl)}\n\nunit tests:\n{feedback}\n\nhint:\n{self_reflection}\n\n# improved implementation\n{func_sig}\n{code_block_instruction}"
|
83 |
+
func_bodies = model.generate(
|
84 |
+
prompt, num_comps=num_comps, temperature=temperature)
|
85 |
+
else:
|
86 |
+
prompt = f"{simple_completion_instruction}\n{func_sig}\n{code_block_instruction}"
|
87 |
+
func_bodies = model.generate(
|
88 |
+
prompt, num_comps=num_comps, temperature=temperature)
|
89 |
+
|
90 |
+
if num_comps == 1:
|
91 |
+
assert isinstance(func_bodies, str)
|
92 |
+
func_body_str = parse_code_block(func_bodies)
|
93 |
+
print_generated_func_body(func_body_str)
|
94 |
+
return func_body_str
|
95 |
+
|
96 |
+
else:
|
97 |
+
func_bodies = [parse_code_block(func_body) for func_body in func_bodies]
|
98 |
+
print_generated_func_body("\n\n".join(func_bodies))
|
99 |
+
return func_bodies
|
100 |
+
|
101 |
+
|
102 |
+
def generate_with_accumulated_context(
|
103 |
+
func_sig: str,
|
104 |
+
model: ModelBase,
|
105 |
+
strategy: str,
|
106 |
+
prev_func_impl,
|
107 |
+
accumulated_feedback,
|
108 |
+
accumulated_reflection,
|
109 |
+
num_comps,
|
110 |
+
temperature,
|
111 |
+
reflexion_chat_instruction: str,
|
112 |
+
reflexion_few_shot: str,
|
113 |
+
simple_chat_instruction: str,
|
114 |
+
reflexion_completion_instruction: str,
|
115 |
+
simple_completion_instruction: str,
|
116 |
+
code_block_instruction: str,
|
117 |
+
parse_code_block: Callable[[str], str],
|
118 |
+
add_code_block: Callable[[str], str]
|
119 |
+
) -> Union[str, List[str]]:
|
120 |
+
# Ensure that the strategy is valid
|
121 |
+
if strategy != "reflexion" and strategy != "simple":
|
122 |
+
raise ValueError(
|
123 |
+
f"Invalid strategy: given `{strategy}` but expected one of `reflexion` or `simple`")
|
124 |
+
if strategy == "reflexion" and (prev_func_impl is None or accumulated_feedback is None or accumulated_reflection is None):
|
125 |
+
raise ValueError(
|
126 |
+
f"Invalid arguments: given `strategy=reflexion` but `prev_func_impl`, `feedback`, or `self_reflection` is None")
|
127 |
+
|
128 |
+
# Build the accumulated context from the provided feedback and reflections
|
129 |
+
accumulated_context = "\n\n".join(
|
130 |
+
[f"[previous impl {i+1}]:\n{add_code_block(impl)}\n[unit test results from previous impl {i+1}]:\n{feedback}\n[reflection on previous impl {i+1}]:\n{reflection}"
|
131 |
+
for i, (impl, feedback, reflection) in enumerate(zip(prev_func_impl, accumulated_feedback, accumulated_reflection))]
|
132 |
+
)
|
133 |
+
|
134 |
+
if model.is_chat:
|
135 |
+
if strategy == "reflexion":
|
136 |
+
# Constructing the message using a loop for accumulated context
|
137 |
+
messages = [
|
138 |
+
Message(role="system", content=f"{reflexion_chat_instruction}\n{code_block_instruction}"),
|
139 |
+
Message(role="user", content=reflexion_few_shot)
|
140 |
+
]
|
141 |
+
|
142 |
+
for impl, feedback, reflection in zip(prev_func_impl, accumulated_feedback, accumulated_reflection):
|
143 |
+
messages.append(Message(role="assistant", content=add_code_block(impl)))
|
144 |
+
messages.append(Message(role="user", content=f"[unit test results from previous impl]:\n{feedback}\n\n[reflection on previous impl]:\n{reflection}"))
|
145 |
+
|
146 |
+
messages.append(Message(role="user", content=f"[improved impl]:\n{func_sig}"))
|
147 |
+
prompt = "\n".join([message.content for message in messages])
|
148 |
+
message = (f"{reflexion_few_shot}\n{accumulated_context}\n\n[improved impl]:\n{func_sig}")
|
149 |
+
print_messages(prompt, message)
|
150 |
+
|
151 |
+
func_bodies = model.generate_chat(messages=messages, num_comps=num_comps, temperature=temperature)
|
152 |
+
else:
|
153 |
+
system_prompt = f"{simple_chat_instruction}\n{code_block_instruction}"
|
154 |
+
print_messages(system_prompt, func_sig)
|
155 |
+
messages = [
|
156 |
+
Message(role="system", content=f"{simple_chat_instruction}\n{code_block_instruction}"),
|
157 |
+
Message(role="user", content=func_sig)
|
158 |
+
]
|
159 |
+
func_bodies = model.generate_chat(messages=messages, num_comps=num_comps, temperature=temperature)
|
160 |
+
else:
|
161 |
+
if strategy == "reflexion":
|
162 |
+
prompt = f"{reflexion_completion_instruction}\n{accumulated_context}\n\n# improved implementation\n{func_sig}\n{code_block_instruction}"
|
163 |
+
func_bodies = model.generate(prompt, num_comps=num_comps, temperature=temperature)
|
164 |
+
print_messages(prompt, "")
|
165 |
+
else:
|
166 |
+
prompt = f"{simple_completion_instruction}\n{func_sig}\n{code_block_instruction}"
|
167 |
+
func_bodies = model.generate(prompt, num_comps=num_comps, temperature=temperature)
|
168 |
+
print_messages(prompt, "")
|
169 |
+
|
170 |
+
if num_comps == 1:
|
171 |
+
assert isinstance(func_bodies, str)
|
172 |
+
func_body_str = parse_code_block(func_bodies)
|
173 |
+
print_generated_func_body(func_body_str)
|
174 |
+
return func_body_str
|
175 |
+
|
176 |
+
else:
|
177 |
+
func_bodies = [parse_code_block(func_body) for func_body in func_bodies]
|
178 |
+
print_generated_func_body("\n\n".join(func_bodies))
|
179 |
+
return func_bodies
|
180 |
+
|
181 |
+
|
182 |
+
def generic_generate_internal_tests(
|
183 |
+
func_sig: str,
|
184 |
+
model: ModelBase,
|
185 |
+
max_num_tests: int,
|
186 |
+
test_generation_few_shot: str,
|
187 |
+
test_generation_chat_instruction: str,
|
188 |
+
test_generation_completion_instruction: str,
|
189 |
+
parse_tests: Callable[[str], List[str]],
|
190 |
+
is_syntax_valid: Callable[[str], bool],
|
191 |
+
is_react: bool = False
|
192 |
+
) -> List[str]:
|
193 |
+
"""Generates tests for a function."""
|
194 |
+
if model.is_chat:
|
195 |
+
if is_react:
|
196 |
+
messages = [
|
197 |
+
Message(
|
198 |
+
role="system",
|
199 |
+
content=test_generation_chat_instruction,
|
200 |
+
),
|
201 |
+
Message(
|
202 |
+
role="user",
|
203 |
+
content=f"{test_generation_few_shot}\n\n[func signature]:\n{func_sig}\n\n[think]:"
|
204 |
+
)
|
205 |
+
]
|
206 |
+
output = model.generate_chat(messages=messages, max_tokens=1024)
|
207 |
+
print(f'React test generation output: {output}')
|
208 |
+
else:
|
209 |
+
messages = [
|
210 |
+
Message(
|
211 |
+
role="system",
|
212 |
+
content=test_generation_chat_instruction,
|
213 |
+
),
|
214 |
+
Message(
|
215 |
+
role="user",
|
216 |
+
content=f"{test_generation_few_shot}\n\n[func signature]:\n{func_sig}\n\n[unit tests]:",
|
217 |
+
)
|
218 |
+
]
|
219 |
+
output = model.generate_chat(messages=messages, max_tokens=1024)
|
220 |
+
else:
|
221 |
+
prompt = f'{test_generation_completion_instruction}\n\nfunc signature:\n{func_sig}\nunit tests:'
|
222 |
+
output = model.generate(prompt, max_tokens=1024)
|
223 |
+
all_tests = parse_tests(output) # type: ignore
|
224 |
+
valid_tests = [test for test in all_tests if is_syntax_valid(test)]
|
225 |
+
|
226 |
+
# print(valid_tests)
|
227 |
+
|
228 |
+
return (valid_tests)
|
229 |
+
|
230 |
+
|
231 |
+
def generic_generate_self_reflection(
|
232 |
+
func: str,
|
233 |
+
feedback: str,
|
234 |
+
model: ModelBase,
|
235 |
+
self_reflection_chat_instruction: str,
|
236 |
+
self_reflection_completion_instruction: str,
|
237 |
+
add_code_block: Callable[[str], str],
|
238 |
+
self_reflection_few_shot: Optional[str] = None,
|
239 |
+
) -> str:
|
240 |
+
if model.is_chat:
|
241 |
+
if self_reflection_few_shot is not None:
|
242 |
+
messages = [
|
243 |
+
Message(
|
244 |
+
role="system",
|
245 |
+
content=self_reflection_chat_instruction,
|
246 |
+
),
|
247 |
+
Message(
|
248 |
+
role="user",
|
249 |
+
content=f'{self_reflection_few_shot}\n\n[function impl]:\n{add_code_block(func)}\n\n[unit test results]:\n{feedback}\n\n[self-reflection]:',
|
250 |
+
)
|
251 |
+
]
|
252 |
+
reflection = model.generate_chat(messages=messages)
|
253 |
+
print(f'|Self reflection output|: {reflection}')
|
254 |
+
else:
|
255 |
+
messages = [
|
256 |
+
Message(
|
257 |
+
role="system",
|
258 |
+
content=self_reflection_chat_instruction,
|
259 |
+
),
|
260 |
+
Message(
|
261 |
+
role="user",
|
262 |
+
content=f'[function impl]:\n{add_code_block(func)}\n\n[unit test results]:\n{feedback}\n\n[self-reflection]:',
|
263 |
+
)
|
264 |
+
]
|
265 |
+
reflection = model.generate_chat(messages=messages)
|
266 |
+
else:
|
267 |
+
reflection = model.generate(
|
268 |
+
f'{self_reflection_completion_instruction}\n{add_code_block(func)}\n\n{feedback}\n\nExplanation:')
|
269 |
+
return reflection # type: ignore
|
270 |
+
|
271 |
+
|
272 |
+
def sample_n_random(items: List[str], n: int) -> List[str]:
|
273 |
+
"""Sample min(n, len(items)) random items from a list"""
|
274 |
+
assert n >= 0
|
275 |
+
if n >= len(items):
|
276 |
+
return items
|
277 |
+
return random.sample(items, n)
|
278 |
+
|
279 |
+
def print_messages(system_message_text: str, user_message_text: str) -> None:
|
280 |
+
print(f"""{system_message_text}""")
|
281 |
+
print(f"""{user_message_text} \n""")
|
282 |
+
|
283 |
+
def print_generated_func_body(func_body_str: str) -> None:
|
284 |
+
print(f"""|GENERATED FUNCTION BODY| \n
|
285 |
+
```python\n{func_body_str} \n
|
286 |
+
""")
|
generators/model.py
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Union, Optional, Literal
|
2 |
+
import dataclasses
|
3 |
+
|
4 |
+
from tenacity import (
|
5 |
+
retry,
|
6 |
+
stop_after_attempt, # type: ignore
|
7 |
+
wait_random_exponential, # type: ignore
|
8 |
+
)
|
9 |
+
import openai
|
10 |
+
|
11 |
+
MessageRole = Literal["system", "user", "assistant"]
|
12 |
+
|
13 |
+
|
14 |
+
@dataclasses.dataclass()
|
15 |
+
class Message():
|
16 |
+
role: MessageRole
|
17 |
+
content: str
|
18 |
+
|
19 |
+
|
20 |
+
def message_to_str(message: Message) -> str:
|
21 |
+
return f"{message.role}: {message.content}"
|
22 |
+
|
23 |
+
|
24 |
+
def messages_to_str(messages: List[Message]) -> str:
|
25 |
+
return "\n".join([message_to_str(message) for message in messages])
|
26 |
+
|
27 |
+
|
28 |
+
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
|
29 |
+
def gpt_completion(
|
30 |
+
model: str,
|
31 |
+
prompt: str,
|
32 |
+
max_tokens: int = 1024,
|
33 |
+
stop_strs: Optional[List[str]] = None,
|
34 |
+
temperature: float = 0.0,
|
35 |
+
num_comps=1,
|
36 |
+
) -> Union[List[str], str]:
|
37 |
+
response = openai.Completion.create(
|
38 |
+
model=model,
|
39 |
+
prompt=prompt,
|
40 |
+
temperature=temperature,
|
41 |
+
max_tokens=max_tokens,
|
42 |
+
top_p=1,
|
43 |
+
frequency_penalty=0.0,
|
44 |
+
presence_penalty=0.0,
|
45 |
+
stop=stop_strs,
|
46 |
+
n=num_comps,
|
47 |
+
)
|
48 |
+
if num_comps == 1:
|
49 |
+
return response.choices[0].text # type: ignore
|
50 |
+
|
51 |
+
return [choice.text for choice in response.choices] # type: ignore
|
52 |
+
|
53 |
+
|
54 |
+
@retry(wait=wait_random_exponential(min=1, max=180), stop=stop_after_attempt(6))
|
55 |
+
def gpt_chat(
|
56 |
+
model: str,
|
57 |
+
messages: List,
|
58 |
+
max_tokens: int = 1024,
|
59 |
+
temperature: float = 0.0,
|
60 |
+
num_comps=1,
|
61 |
+
) -> Union[List[str], str]:
|
62 |
+
try:
|
63 |
+
response = openai.ChatCompletion.create(
|
64 |
+
model=model,
|
65 |
+
messages=[dataclasses.asdict(message) for message in messages],
|
66 |
+
max_tokens=max_tokens,
|
67 |
+
temperature=temperature,
|
68 |
+
top_p=1,
|
69 |
+
frequency_penalty=0.0,
|
70 |
+
presence_penalty=0.0,
|
71 |
+
n=num_comps,
|
72 |
+
)
|
73 |
+
if num_comps == 1:
|
74 |
+
return response.choices[0].message.content # type: ignore
|
75 |
+
return [choice.message.content for choice in response.choices] # type: ignore
|
76 |
+
|
77 |
+
except Exception as e:
|
78 |
+
print(f"An error occurred while calling OpenAI: {e}")
|
79 |
+
raise
|
80 |
+
|
81 |
+
class ModelBase():
|
82 |
+
def __init__(self, name: str):
|
83 |
+
self.name = name
|
84 |
+
self.is_chat = False
|
85 |
+
|
86 |
+
def __repr__(self) -> str:
|
87 |
+
return f'{self.name}'
|
88 |
+
|
89 |
+
def generate_chat(self, messages: List[Message], max_tokens: int = 1024, temperature: float = 0.2, num_comps: int = 1) -> Union[List[str], str]:
|
90 |
+
raise NotImplementedError
|
91 |
+
|
92 |
+
def generate(self, prompt: str, max_tokens: int = 1024, stop_strs: Optional[List[str]] = None, temperature: float = 0.0, num_comps=1) -> Union[List[str], str]:
|
93 |
+
raise NotImplementedError
|
94 |
+
|
95 |
+
|
96 |
+
class GPTChat(ModelBase):
|
97 |
+
def __init__(self, model_name: str):
|
98 |
+
self.name = model_name
|
99 |
+
self.is_chat = True
|
100 |
+
|
101 |
+
def generate_chat(self, messages: List[Message], max_tokens: int = 1024, temperature: float = 0.2, num_comps: int = 1) -> Union[List[str], str]:
|
102 |
+
return gpt_chat(self.name, messages, max_tokens, temperature, num_comps)
|
103 |
+
|
104 |
+
|
105 |
+
class GPT4(GPTChat):
|
106 |
+
def __init__(self):
|
107 |
+
super().__init__("gpt-4")
|
108 |
+
|
109 |
+
|
110 |
+
class GPT35(GPTChat):
|
111 |
+
def __init__(self):
|
112 |
+
super().__init__("gpt-3.5-turbo")
|
113 |
+
|
114 |
+
|
115 |
+
class GPTDavinci(ModelBase):
|
116 |
+
def __init__(self, model_name: str):
|
117 |
+
self.name = model_name
|
118 |
+
|
119 |
+
def generate(self, prompt: str, max_tokens: int = 1024, stop_strs: Optional[List[str]] = None, temperature: float = 0, num_comps=1) -> Union[List[str], str]:
|
120 |
+
return gpt_completion(self.name, prompt, max_tokens, stop_strs, temperature, num_comps)
|
generators/parse.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
from typing import Optional
|
3 |
+
|
4 |
+
|
5 |
+
def parse_code_block(string: str, lang: str) -> Optional[str]:
|
6 |
+
code_pattern = fr"```{lang}\n(.*?)\n```"
|
7 |
+
match = re.search(code_pattern, string, re.DOTALL)
|
8 |
+
|
9 |
+
if match:
|
10 |
+
return match.group(1)
|
11 |
+
|
12 |
+
generic_code_pattern = r"```\n(.*?)\n```"
|
13 |
+
match = re.search(generic_code_pattern, string, re.DOTALL)
|
14 |
+
|
15 |
+
if match:
|
16 |
+
return match.group(1)
|
17 |
+
|
18 |
+
return parse_first_func(string, lang)
|
19 |
+
|
20 |
+
|
21 |
+
def parse_first_func(code: str, lang: str) -> Optional[str]:
|
22 |
+
assert lang == "python", "Only python is supported for now. TODO: Rust"
|
23 |
+
code_lines = code.split("\n")
|
24 |
+
def_i = -1
|
25 |
+
last_i = 0
|
26 |
+
got_return = False
|
27 |
+
for i, line in enumerate(code_lines):
|
28 |
+
if line.startswith("def "):
|
29 |
+
if def_i == -1:
|
30 |
+
def_i = i
|
31 |
+
else:
|
32 |
+
break
|
33 |
+
elif "return" in line and def_i != -1:
|
34 |
+
got_return = True
|
35 |
+
if line == "" and def_i != -1 and got_return:
|
36 |
+
last_i = i
|
37 |
+
break
|
38 |
+
|
39 |
+
if last_i == 0:
|
40 |
+
last_i = len(code_lines) - 1
|
41 |
+
|
42 |
+
if def_i == -1:
|
43 |
+
return None
|
44 |
+
|
45 |
+
return "\n".join(code_lines[def_i:last_i+1]).rstrip("[/PYTHON]")
|
46 |
+
|
47 |
+
|
48 |
+
def add_code_block(string: str, lang: str) -> str:
|
49 |
+
return f"```{lang}\n{string}\n```"
|
generators/py_generate.py
ADDED
@@ -0,0 +1,404 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from generators.model import ModelBase, message_to_str
|
2 |
+
from .generator_types import Generator
|
3 |
+
from .generator_utils import generic_generate_func_impl, generic_generate_internal_tests, generic_generate_self_reflection, generate_with_accumulated_context
|
4 |
+
|
5 |
+
from typing import Optional, List, Union
|
6 |
+
import ast
|
7 |
+
import re
|
8 |
+
from .parse import parse_code_block, add_code_block
|
9 |
+
|
10 |
+
PY_SIMPLE_COMPLETION_INSTRUCTION = "# Write the body of this function only."
|
11 |
+
PY_REFLEXION_COMPLETION_INSTRUCTION = "You are a Python writing assistant. You will be given your past function implementation, a series of unit tests, and a hint to change the implementation appropriately. Write your full implementation (restate the function signature).\n\n-----"
|
12 |
+
PY_SELF_REFLECTION_COMPLETION_INSTRUCTION = "You are a Python writing assistant. You will be given a function implementation and a series of unit tests. Your goal is to write a few sentences to explain why your implementation is wrong as indicated by the tests. You will need this as a hint when you try again later. Only provide the few sentence description in your answer, not the implementation.\n\n-----"
|
13 |
+
USE_PYTHON_CODEBLOCK_INSTRUCTION = "Use a Python code block to write your response. For example:\n```python\nprint('Hello world!')\n```"
|
14 |
+
|
15 |
+
PY_SIMPLE_CHAT_INSTRUCTION = "You are an AI that only responds with python code, NOT ENGLISH. You will be given a function signature and its docstring by the user. Write your full implementation (restate the function signature)."
|
16 |
+
PY_SIMPLE_CHAT_INSTRUCTION_V2 = "You are an AI that only responds with only python code. You will be given a function signature and its docstring by the user. Write your full implementation (restate the function signature)."
|
17 |
+
PY_REFLEXION_CHAT_INSTRUCTION = "You are an AI Python assistant. You will be given your past function implementation, a series of unit tests, and a hint to change the implementation appropriately. Write your full implementation (restate the function signature)."
|
18 |
+
PY_REFLEXION_CHAT_INSTRUCTION_V2 = "You are an AI Python assistant. You will be given your previous implementation of a function, a series of unit tests results, and your self-reflection on your previous implementation. Write your full implementation (restate the function signature)."
|
19 |
+
PY_REFLEXION_FEW_SHOT_ADD = '''Example 1:
|
20 |
+
[previous impl]:
|
21 |
+
```python
|
22 |
+
def add(a: int, b: int) -> int:
|
23 |
+
"""
|
24 |
+
Given integers a and b, return the total value of a and b.
|
25 |
+
"""
|
26 |
+
return a - b
|
27 |
+
```
|
28 |
+
|
29 |
+
[unit test results from previous impl]:
|
30 |
+
Tested passed:
|
31 |
+
|
32 |
+
Tests failed:
|
33 |
+
assert add(1, 2) == 3 # output: -1
|
34 |
+
assert add(1, 2) == 4 # output: -1
|
35 |
+
|
36 |
+
[reflection on previous impl]:
|
37 |
+
The implementation failed the test cases where the input integers are 1 and 2. The issue arises because the code does not add the two integers together, but instead subtracts the second integer from the first. To fix this issue, we should change the operator from `-` to `+` in the return statement. This will ensure that the function returns the correct output for the given input.
|
38 |
+
|
39 |
+
[improved impl]:
|
40 |
+
```python
|
41 |
+
def add(a: int, b: int) -> int:
|
42 |
+
"""
|
43 |
+
Given integers a and b, return the total value of a and b.
|
44 |
+
"""
|
45 |
+
return a + b
|
46 |
+
```
|
47 |
+
'''
|
48 |
+
|
49 |
+
PY_REFLEXION_FEW_SHOT = '''Example 1:
|
50 |
+
[previous impl]:
|
51 |
+
```python
|
52 |
+
from typing import *
|
53 |
+
def fullJustify(words: List[str], maxWidth: int) -> List[str]:
|
54 |
+
"""
|
55 |
+
Given an array of words and a width maxWidth, format the text such that each line has exactly maxWidth characters and is fully (left and right) justified.
|
56 |
+
You should pack your words in a greedy approach; that is, pack as many words as you can in each line. Pad extra spaces `' '` when necessary so that each line has exactly maxWidth characters.
|
57 |
+
Extra spaces between words should be distributed as evenly as possible. If the number of spaces on a line do not divide evenly between words, the empty slots on the left will be assigned more spaces than the slots on the right.
|
58 |
+
For the last line of text, it should be left justified and no extra space is inserted between words.
|
59 |
+
Note:
|
60 |
+
A word is defined as a character sequence consisting of non-space characters only.
|
61 |
+
Each word's length is guaranteed to be greater than 0 and not exceed maxWidth.
|
62 |
+
The input array `words` contains at least one word.
|
63 |
+
"""
|
64 |
+
res = []
|
65 |
+
cur_line = []
|
66 |
+
cur_len = 0
|
67 |
+
|
68 |
+
for word in words:
|
69 |
+
if cur_len + len(word) + len(cur_line) > maxWidth:
|
70 |
+
if len(cur_line) == 1:
|
71 |
+
res.append(cur_line[0] + ' ' * (maxWidth - cur_len))
|
72 |
+
else:
|
73 |
+
spaces = maxWidth - cur_len
|
74 |
+
space_between = spaces // (len(cur_line) - 1)
|
75 |
+
extra_spaces = spaces % (len(cur_line) - 1)
|
76 |
+
line = ''
|
77 |
+
for i, w in enumerate(cur_line[:-1]):
|
78 |
+
line += w + ' ' * (space_between + (i < extra_spaces))
|
79 |
+
line += cur_line[-1]
|
80 |
+
res.append(line)
|
81 |
+
cur_line = []
|
82 |
+
cur_len = 0
|
83 |
+
cur_line.append(word)
|
84 |
+
cur_len += len(word)
|
85 |
+
|
86 |
+
last_line = ' '.join(cur_line)
|
87 |
+
last_line += ' ' * (maxWidth - len(last_line))
|
88 |
+
res.append(last_line)
|
89 |
+
|
90 |
+
return res
|
91 |
+
```
|
92 |
+
|
93 |
+
[unit test results from previous impl]:
|
94 |
+
Tested passed:
|
95 |
+
|
96 |
+
Tests failed:
|
97 |
+
assert fullJustify([], 10) == [] # output: [' ']
|
98 |
+
assert fullJustify([], 0) == [] # output: ['']
|
99 |
+
|
100 |
+
[reflection on previous impl]:
|
101 |
+
The implementation failed the test cases where the input list of words is empty. The issue arises because the code does not handle the case where there are no words to process. As a result, it still appends a line with spaces to the result list, even when there are no words. To fix this issue, we should add a condition at the beginning of the function to check if the input list is empty, and return an empty list if it is. This will ensure that the function returns the correct output for empty input lists.
|
102 |
+
|
103 |
+
[improved impl]:
|
104 |
+
```python
|
105 |
+
from typing import *
|
106 |
+
def fullJustify(words: List[str], maxWidth: int) -> List[str]:
|
107 |
+
"""
|
108 |
+
Given an array of words and a width maxWidth, format the text such that each line has exactly maxWidth characters and is fully (left and right) justified.
|
109 |
+
You should pack your words in a greedy approach; that is, pack as many words as you can in each line. Pad extra spaces `' '` when necessary so that each line has exactly maxWidth characters.
|
110 |
+
Extra spaces between words should be distributed as evenly as possible. If the number of spaces on a line do not divide evenly between words, the empty slots on the left will be assigned more spaces than the slots on the right.
|
111 |
+
For the last line of text, it should be left justified and no extra space is inserted between words.
|
112 |
+
Note:
|
113 |
+
A word is defined as a character sequence consisting of non-space characters only.
|
114 |
+
Each word's length is guaranteed to be greater than 0 and not exceed maxWidth.
|
115 |
+
The input array `words` contains at least one word.
|
116 |
+
"""
|
117 |
+
if not words:
|
118 |
+
return []
|
119 |
+
|
120 |
+
res = []
|
121 |
+
cur_line = []
|
122 |
+
cur_len = 0
|
123 |
+
|
124 |
+
for word in words:
|
125 |
+
if cur_len + len(word) + len(cur_line) > maxWidth:
|
126 |
+
if len(cur_line) == 1:
|
127 |
+
res.append(cur_line[0] + ' ' * (maxWidth - cur_len))
|
128 |
+
else:
|
129 |
+
spaces = maxWidth - cur_len
|
130 |
+
space_between = spaces // (len(cur_line) - 1)
|
131 |
+
extra_spaces = spaces % (len(cur_line) - 1)
|
132 |
+
line = ''
|
133 |
+
for i, w in enumerate(cur_line[:-1]):
|
134 |
+
line += w + ' ' * (space_between + (i < extra_spaces))
|
135 |
+
line += cur_line[-1]
|
136 |
+
res.append(line)
|
137 |
+
cur_line = []
|
138 |
+
cur_len = 0
|
139 |
+
cur_line.append(word)
|
140 |
+
cur_len += len(word)
|
141 |
+
|
142 |
+
last_line = ' '.join(cur_line)
|
143 |
+
last_line += ' ' * (maxWidth - len(last_line))
|
144 |
+
res.append(last_line)
|
145 |
+
|
146 |
+
return res
|
147 |
+
```
|
148 |
+
END EXAMPLES
|
149 |
+
|
150 |
+
'''
|
151 |
+
PY_SELF_REFLECTION_CHAT_INSTRUCTION = "You are a Python programming assistant. You will be given a function implementation and a series of unit tests. Your goal is to write a few sentences to explain why your implementation is wrong as indicated by the tests. You will need this as a hint when you try again later. Only provide the few sentence description in your answer, not the implementation."
|
152 |
+
PY_SELF_REFLECTION_CHAT_INSTRUCTION_V2 = "You are a Python programming assistant. You will be given a function implementation and a series of unit test results. Your goal is to write a few sentences to explain why your implementation is wrong as indicated by the tests. You will need this as guidance when you try again later. Only provide the few sentence description in your answer, not the implementation. You will be given a few examples by the user."
|
153 |
+
PY_SELF_REFLECTION_FEW_SHOT = """Example 1:
|
154 |
+
[function impl]:
|
155 |
+
```python
|
156 |
+
def longest_subarray_with_sum_limit(nums: List[int], target: int) -> List[int]:
|
157 |
+
n = len(nums)
|
158 |
+
left, right = 0, 0
|
159 |
+
max_length = 0
|
160 |
+
current_sum = 0
|
161 |
+
result = []
|
162 |
+
while right < n:
|
163 |
+
current_sum += nums[right]
|
164 |
+
while current_sum > target:
|
165 |
+
current_sum -= nums[left]
|
166 |
+
left += 1
|
167 |
+
if right - left + 1 >= max_length:
|
168 |
+
max_length = right - left + 1
|
169 |
+
result = nums[left:right+1]
|
170 |
+
right += 1
|
171 |
+
return result
|
172 |
+
```
|
173 |
+
[unit test results]:
|
174 |
+
Tests passing:
|
175 |
+
assert longest_subarray_with_sum_limit([1, 2, 3, 4, 5], 8) == [1, 2, 3]
|
176 |
+
assert longest_subarray_with_sum_limit([1, 2, 3, 4, 5], 15) == [1, 2, 3, 4, 5]
|
177 |
+
assert longest_subarray_with_sum_limit([1, -1, 2, -2, 3, -3], 2) == [1, -1, 2, -2, 3]
|
178 |
+
assert longest_subarray_with_sum_limit([], 10) == []
|
179 |
+
assert longest_subarray_with_sum_limit([], 0) == []
|
180 |
+
assert longest_subarray_with_sum_limit([], -5) == []
|
181 |
+
Tests failing:
|
182 |
+
assert longest_subarray_with_sum_limit([5, 6, 7, 8, 9], 4) == [] # output: [5]
|
183 |
+
[self-reflection]:
|
184 |
+
The implementation failed the where no subarray fulfills the condition. The issue in the implementation is due to the use of >= instead of > in the condition to update the result. Because of this, it returns a subarray even when the sum is greater than the target, as it still updates the result when the current subarray length is equal to the previous longest subarray length. To overcome this error, we should change the condition to only update the result when the current subarray length is strictly greater than the previous longest subarray length. This can be done by replacing >= with > in the condition.
|
185 |
+
|
186 |
+
Example 2:
|
187 |
+
[function impl]:
|
188 |
+
```python
|
189 |
+
def longest_subarray_with_sum_limit(nums: List[int], target: int) -> List[int]:
|
190 |
+
n = len(nums)
|
191 |
+
left, right = 0, 0
|
192 |
+
max_length = 0
|
193 |
+
current_sum = 0
|
194 |
+
result = []
|
195 |
+
while current_sum + nums[right] <= target:
|
196 |
+
current_sum += nums[right]
|
197 |
+
right += 1
|
198 |
+
while right < n:
|
199 |
+
current_sum += nums[right]
|
200 |
+
while current_sum > target:
|
201 |
+
current_sum -= nums[left]
|
202 |
+
left += 1
|
203 |
+
if right - left + 1 > max_length:
|
204 |
+
max_length = right - left + 1
|
205 |
+
result = nums[left:right+1]
|
206 |
+
right += 1
|
207 |
+
return result
|
208 |
+
```
|
209 |
+
[unit test results]:
|
210 |
+
Tests passing:
|
211 |
+
assert longest_subarray_with_sum_limit([], 10) == []
|
212 |
+
assert longest_subarray_with_sum_limit([], 0) == []
|
213 |
+
assert longest_subarray_with_sum_limit([], -5) == []
|
214 |
+
Tests failing:
|
215 |
+
assert longest_subarray_with_sum_limit([1, 2, 3, 4, 5], 8) == [1, 2, 3] # output: list index out of range
|
216 |
+
assert longest_subarray_with_sum_limit([1, 2, 3, 4, 5], 15) == [1, 2, 3, 4, 5] # output: list index out of range
|
217 |
+
assert longest_subarray_with_sum_limit([5, 6, 7, 8, 9], 4) == [] # output: list index out of range
|
218 |
+
assert longest_subarray_with_sum_limit([1, -1, 2, -2, 3, -3], 2) == [1, -1, 2, -2, 3] # output: list index out of range
|
219 |
+
[self-reflection]:
|
220 |
+
The implementation failed 4 out of the 7 test cases due to an IndexError. The issue stems from the while loop while current_sum + nums[right] <= target:, which directly accesses nums[right] without checking if right is within the bounds of the list. This results in a runtime error when right goes beyond the list length. To overcome this error, we need to add a bounds check for the right variable in the mentioned while loop. We can modify the loop condition to while right < len(nums) and current_sum + nums[right] <= target:. This change will ensure that we only access elements within the bounds of the list, thus avoiding the IndexError.
|
221 |
+
END OF EXAMPLES
|
222 |
+
"""
|
223 |
+
|
224 |
+
PY_TEST_GENERATION_FEW_SHOT = """Examples:
|
225 |
+
func signature:
|
226 |
+
def add3Numbers(x, y, z):
|
227 |
+
\"\"\" Add three numbers together.
|
228 |
+
This function takes three numbers as input and returns the sum of the three numbers.
|
229 |
+
\"\"\"
|
230 |
+
unit tests:
|
231 |
+
assert add3Numbers(1, 2, 3) == 6
|
232 |
+
assert add3Numbers(-1, 2, 3) == 4
|
233 |
+
assert add3Numbers(1, -2, 3) == 2
|
234 |
+
assert add3Numbers(1, 2, -3) == 0
|
235 |
+
assert add3Numbers(-3, -2, -1) == -6
|
236 |
+
assert add3Numbers(0, 0, 0) == 0
|
237 |
+
"""
|
238 |
+
|
239 |
+
PY_TEST_GENERATION_COMPLETION_INSTRUCTION = f"""You are an AI coding assistant that can write unique, diverse, and intuitive unit tests for functions given the signature and docstring. Call your function answer().
|
240 |
+
|
241 |
+
{PY_TEST_GENERATION_FEW_SHOT}"""
|
242 |
+
|
243 |
+
PY_TEST_GENERATION_CHAT_INSTRUCTION = """You are an AI coding assistant that can write unique, diverse, and intuitive unit tests for functions given the signature and docstring. Call your function answer()."""
|
244 |
+
|
245 |
+
|
246 |
+
class PyGenerator(Generator):
|
247 |
+
def self_reflection(self, func: str, feedback: str, model: ModelBase) -> str:
|
248 |
+
return generic_generate_self_reflection(
|
249 |
+
func=func,
|
250 |
+
feedback=feedback,
|
251 |
+
model=model,
|
252 |
+
self_reflection_chat_instruction=PY_SELF_REFLECTION_CHAT_INSTRUCTION,
|
253 |
+
self_reflection_completion_instruction=PY_SELF_REFLECTION_COMPLETION_INSTRUCTION,
|
254 |
+
add_code_block=lambda x: add_code_block(x, "python"),
|
255 |
+
self_reflection_few_shot=PY_SELF_REFLECTION_FEW_SHOT
|
256 |
+
)
|
257 |
+
|
258 |
+
def func_impl(
|
259 |
+
self,
|
260 |
+
func_sig: str,
|
261 |
+
model: ModelBase,
|
262 |
+
strategy: str,
|
263 |
+
prev_func_impl: Optional[str] = None,
|
264 |
+
feedback: Optional[str] = None,
|
265 |
+
self_reflection: Optional[str] = None,
|
266 |
+
num_comps: int = 1,
|
267 |
+
temperature: float = 0.8,
|
268 |
+
acc_feedback: Optional[str] = None,
|
269 |
+
acc_reflection: Optional[str] = None,
|
270 |
+
) -> Union[str, List[str]]:
|
271 |
+
if strategy == "mcts":
|
272 |
+
return generate_with_accumulated_context(
|
273 |
+
func_sig=func_sig,
|
274 |
+
model=model,
|
275 |
+
strategy="reflexion",
|
276 |
+
prev_func_impl=prev_func_impl,
|
277 |
+
accumulated_feedback=acc_feedback,
|
278 |
+
accumulated_reflection=acc_reflection,
|
279 |
+
num_comps=num_comps,
|
280 |
+
temperature=temperature,
|
281 |
+
reflexion_chat_instruction=PY_REFLEXION_CHAT_INSTRUCTION,
|
282 |
+
reflexion_few_shot=PY_REFLEXION_FEW_SHOT_ADD,
|
283 |
+
simple_chat_instruction=PY_SIMPLE_CHAT_INSTRUCTION,
|
284 |
+
reflexion_completion_instruction=PY_REFLEXION_COMPLETION_INSTRUCTION,
|
285 |
+
simple_completion_instruction=PY_SIMPLE_COMPLETION_INSTRUCTION,
|
286 |
+
code_block_instruction=USE_PYTHON_CODEBLOCK_INSTRUCTION,
|
287 |
+
parse_code_block=lambda x: parse_code_block(x, "python"),
|
288 |
+
add_code_block=lambda x: add_code_block(x, "python"),
|
289 |
+
)
|
290 |
+
else:
|
291 |
+
return generic_generate_func_impl(
|
292 |
+
func_sig=func_sig,
|
293 |
+
model=model,
|
294 |
+
strategy=strategy,
|
295 |
+
prev_func_impl=prev_func_impl,
|
296 |
+
feedback=feedback,
|
297 |
+
self_reflection=self_reflection,
|
298 |
+
num_comps=num_comps,
|
299 |
+
temperature=temperature,
|
300 |
+
reflexion_chat_instruction=PY_REFLEXION_CHAT_INSTRUCTION,
|
301 |
+
reflexion_few_shot=PY_REFLEXION_FEW_SHOT_ADD,
|
302 |
+
simple_chat_instruction=PY_SIMPLE_CHAT_INSTRUCTION,
|
303 |
+
reflexion_completion_instruction=PY_REFLEXION_COMPLETION_INSTRUCTION,
|
304 |
+
simple_completion_instruction=PY_SIMPLE_COMPLETION_INSTRUCTION,
|
305 |
+
code_block_instruction=USE_PYTHON_CODEBLOCK_INSTRUCTION,
|
306 |
+
parse_code_block=lambda x: parse_code_block(x, "python"),
|
307 |
+
add_code_block=lambda x: add_code_block(x, "python"),
|
308 |
+
)
|
309 |
+
|
310 |
+
def internal_tests(self, func_sig: str, model: ModelBase, max_num_tests: int = 4) -> List[str]:
|
311 |
+
def parse_tests(tests: str) -> List[str]:
|
312 |
+
return [test.strip() for test in tests.splitlines() if "assert" in test]
|
313 |
+
"""
|
314 |
+
Generates tests for a function.
|
315 |
+
"""
|
316 |
+
return generic_generate_internal_tests(
|
317 |
+
func_sig=func_sig,
|
318 |
+
model=model,
|
319 |
+
max_num_tests=max_num_tests,
|
320 |
+
test_generation_few_shot=PY_TEST_GENERATION_FEW_SHOT,
|
321 |
+
test_generation_chat_instruction=PY_TEST_GENERATION_CHAT_INSTRUCTION,
|
322 |
+
test_generation_completion_instruction=PY_TEST_GENERATION_COMPLETION_INSTRUCTION,
|
323 |
+
parse_tests=parse_tests,
|
324 |
+
is_syntax_valid=py_is_syntax_valid,
|
325 |
+
)
|
326 |
+
|
327 |
+
|
328 |
+
DUMMY_FUNC_SIG = "def func():"
|
329 |
+
DUMMY_FUNC_CALL = "func()"
|
330 |
+
|
331 |
+
|
332 |
+
def handle_first_line_indent(func_body: str) -> str:
|
333 |
+
if func_body.startswith(" "):
|
334 |
+
return func_body
|
335 |
+
split = func_body.splitlines()
|
336 |
+
return f" {split[0]}\n" + "\n".join(split[1:])
|
337 |
+
|
338 |
+
|
339 |
+
def handle_entire_body_indent(func_body: str) -> str:
|
340 |
+
split = func_body.splitlines()
|
341 |
+
res = "\n".join([" " + line for line in split])
|
342 |
+
return res
|
343 |
+
|
344 |
+
|
345 |
+
def fix_turbo_response(func_body: str) -> str:
|
346 |
+
return fix_markdown(remove_unindented_signatures(func_body))
|
347 |
+
|
348 |
+
|
349 |
+
def fix_markdown(func_body: str) -> str:
|
350 |
+
return re.sub("`{3}", "", func_body)
|
351 |
+
|
352 |
+
|
353 |
+
def remove_unindented_signatures(code: str) -> str:
|
354 |
+
regex = r"^def\s+\w+\s*\("
|
355 |
+
|
356 |
+
before_signature = []
|
357 |
+
after_signature = []
|
358 |
+
signature_found = False
|
359 |
+
|
360 |
+
for line in code.split("\n"):
|
361 |
+
if re.match(regex, line):
|
362 |
+
signature_found = True
|
363 |
+
continue
|
364 |
+
|
365 |
+
if signature_found:
|
366 |
+
after_signature.append(line)
|
367 |
+
else:
|
368 |
+
if not line.startswith(" ") and line.strip():
|
369 |
+
line = " " + line
|
370 |
+
before_signature.append(line)
|
371 |
+
|
372 |
+
return "\n".join(before_signature + after_signature)
|
373 |
+
|
374 |
+
|
375 |
+
def py_fix_indentation(func_body: str) -> str:
|
376 |
+
func_body = fix_turbo_response(func_body)
|
377 |
+
"""
|
378 |
+
3 cases:
|
379 |
+
1. good syntax
|
380 |
+
2. first line not good
|
381 |
+
3. entire body not good
|
382 |
+
"""
|
383 |
+
def parse_indent_rec(f_body: str, cur_state: int) -> str:
|
384 |
+
f_body = fix_markdown(f_body)
|
385 |
+
if cur_state > 1:
|
386 |
+
return f_body
|
387 |
+
code = f'{DUMMY_FUNC_SIG}\n{f_body}\n{DUMMY_FUNC_CALL}'
|
388 |
+
try:
|
389 |
+
exec(code)
|
390 |
+
return f_body
|
391 |
+
except (IndentationError, SyntaxError):
|
392 |
+
p_func = handle_first_line_indent if cur_state == 0 else handle_entire_body_indent
|
393 |
+
return parse_indent_rec(p_func(func_body), cur_state + 1)
|
394 |
+
except Exception:
|
395 |
+
return f_body
|
396 |
+
return parse_indent_rec(func_body, 0)
|
397 |
+
|
398 |
+
|
399 |
+
def py_is_syntax_valid(code: str) -> bool:
|
400 |
+
try:
|
401 |
+
ast.parse(code)
|
402 |
+
return True
|
403 |
+
except Exception:
|
404 |
+
return False
|
lats/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
lats/lats.py
ADDED
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from utils import enumerate_resume, make_printv, write_jsonl, resume_success_count
|
2 |
+
from executors import executor_factory
|
3 |
+
from generators import generator_factory, model_factory
|
4 |
+
from typing import List, Dict, Any
|
5 |
+
import math
|
6 |
+
from typing import Tuple
|
7 |
+
import sys
|
8 |
+
import random
|
9 |
+
|
10 |
+
sys.set_int_max_str_digits(100000) # Increase the limit to 10000 digits
|
11 |
+
|
12 |
+
react_prompt_header = "Here are some previous solutions and the corresponding test results.\n"
|
13 |
+
react_prompt_starter = "\n\nYour solution:\n"
|
14 |
+
extra_header = "\n\nName the function answer()"
|
15 |
+
|
16 |
+
class Node:
|
17 |
+
def __init__(self, solution: str, parent=None, context="", depth=0):
|
18 |
+
self.solution = solution
|
19 |
+
self.parent = parent
|
20 |
+
self.children = []
|
21 |
+
self.value = 0
|
22 |
+
self.visits = 0
|
23 |
+
self.context = ""
|
24 |
+
self.depth = depth
|
25 |
+
self.reflection = ""
|
26 |
+
self.test_feedback = ""
|
27 |
+
|
28 |
+
def uct(self, exploration_weight=1.0):
|
29 |
+
if self.visits == 0:
|
30 |
+
#return float('inf')
|
31 |
+
return self.value
|
32 |
+
return (self.value / self.visits) + exploration_weight * math.sqrt(math.log(self.parent.visits) / self.visits)
|
33 |
+
|
34 |
+
def best_child(self):
|
35 |
+
if not self.children: # Check if children list is empty
|
36 |
+
return None
|
37 |
+
return max(self.children, key=lambda child: child.uct())
|
38 |
+
|
39 |
+
def best_child_value(self):
|
40 |
+
if not self.children: # Check if children list is empty
|
41 |
+
return None
|
42 |
+
return max(self.children, key=lambda child: child.value)
|
43 |
+
|
44 |
+
def update(self, reward: float):
|
45 |
+
self.visits += 1
|
46 |
+
self.value += reward
|
47 |
+
|
48 |
+
|
49 |
+
def prune_context_blocks(context: str, max_length: int) -> str:
|
50 |
+
"""Prune the context to fit within the specified max_length by removing entire blocks of content using 'trial' as a delimiter."""
|
51 |
+
if len(context) <= max_length:
|
52 |
+
return context
|
53 |
+
|
54 |
+
# Split by the block delimiter "trial".
|
55 |
+
blocks = context.split('Previous Trial')
|
56 |
+
|
57 |
+
# Remove the earliest blocks until the context fits within max_length.
|
58 |
+
while len('trial'.join(blocks)) > max_length and blocks:
|
59 |
+
blocks.pop(0)
|
60 |
+
|
61 |
+
return 'trial'.join(blocks)
|
62 |
+
|
63 |
+
def gather_context_from_tree(node: Node) -> Tuple[List[str], List[str]]:
|
64 |
+
"""
|
65 |
+
Given a node, walk up its tree and gather the feedback and reflections
|
66 |
+
from each parent node until the root is reached.
|
67 |
+
|
68 |
+
Args:
|
69 |
+
node (Node): The node to start gathering context from.
|
70 |
+
|
71 |
+
Returns:
|
72 |
+
Tuple[List[str], List[str]]: Two lists containing the accumulated feedback and reflections.
|
73 |
+
"""
|
74 |
+
accumulated_feedback = []
|
75 |
+
accumulated_reflection = []
|
76 |
+
num_nodes = 0
|
77 |
+
|
78 |
+
while node and num_nodes < 2:
|
79 |
+
num_nodes += 1
|
80 |
+
if node.test_feedback:
|
81 |
+
accumulated_feedback.append(node.test_feedback)
|
82 |
+
if node.reflection:
|
83 |
+
accumulated_reflection.append(node.reflection)
|
84 |
+
node = node.parent
|
85 |
+
|
86 |
+
# Reverse the lists so that the context from the earliest nodes is first
|
87 |
+
return accumulated_feedback[::-1], accumulated_reflection[::-1]
|
88 |
+
|
89 |
+
def sample_n_random(items: List[str], n: int) -> List[str]:
|
90 |
+
"""Sample min(n, len(items)) random items from a list"""
|
91 |
+
assert n >= 0
|
92 |
+
if n >= len(items):
|
93 |
+
return items
|
94 |
+
return random.sample(items, n)
|
95 |
+
|
96 |
+
def run_lats(
|
97 |
+
model_name: str,
|
98 |
+
language: str,
|
99 |
+
max_iters: int,
|
100 |
+
verbose: bool,
|
101 |
+
instruction: str = "Write some code to print Hello World in Python",
|
102 |
+
n_samples: int = 3,
|
103 |
+
depth: int = 5,
|
104 |
+
) -> None:
|
105 |
+
exe = executor_factory(language)
|
106 |
+
gen = generator_factory(language)
|
107 |
+
model = model_factory(model_name)
|
108 |
+
|
109 |
+
|
110 |
+
num_success = 0 # Counter for successful solutions
|
111 |
+
cur_func_impl = None
|
112 |
+
|
113 |
+
item = {}
|
114 |
+
|
115 |
+
#for idx, item in enumerate(dataset):
|
116 |
+
|
117 |
+
tests = gen.internal_tests(instruction + extra_header, model, 1)
|
118 |
+
tests_i = sample_n_random(tests, 1)
|
119 |
+
|
120 |
+
while cur_func_impl is None:
|
121 |
+
cur_func_impl = gen.func_impl(instruction + extra_header, model, "simple")
|
122 |
+
root = Node(cur_func_impl) # initial solution (for pass@1 metric)
|
123 |
+
|
124 |
+
# Lists for logging
|
125 |
+
reflections = []
|
126 |
+
implementations = []
|
127 |
+
test_feedback = []
|
128 |
+
is_solved = False
|
129 |
+
|
130 |
+
# first attempt
|
131 |
+
|
132 |
+
implementations.append(cur_func_impl)
|
133 |
+
assert isinstance(cur_func_impl, str)
|
134 |
+
is_passing, feedback, _ = exe.execute(cur_func_impl, tests_i)
|
135 |
+
test_feedback.append(feedback)
|
136 |
+
|
137 |
+
# if solved, exit early
|
138 |
+
if is_passing:
|
139 |
+
num_success += 1
|
140 |
+
return cur_func_impl # GET SOLUTION
|
141 |
+
|
142 |
+
reflection = gen.self_reflection(cur_func_impl, feedback, model)
|
143 |
+
reflections += [reflection]
|
144 |
+
root.test_feedback = feedback
|
145 |
+
root.reflection = reflection
|
146 |
+
max_iters = int(max_iters)
|
147 |
+
for cur_iter in range(max_iters):
|
148 |
+
# Selection
|
149 |
+
tests_i = sample_n_random(tests, 1)
|
150 |
+
|
151 |
+
node = root
|
152 |
+
trajectory = {
|
153 |
+
'solutions': [],
|
154 |
+
'feedbacks': []
|
155 |
+
}
|
156 |
+
|
157 |
+
while node.children:
|
158 |
+
node = node.best_child()
|
159 |
+
trajectory['solutions'].append(node.solution)
|
160 |
+
|
161 |
+
# Expansion
|
162 |
+
for _ in range(n_samples):
|
163 |
+
new_solution = None
|
164 |
+
strategy = "mcts"
|
165 |
+
prev_func_impl = node.solution
|
166 |
+
feedback = node.test_feedback
|
167 |
+
reflection = node.reflection
|
168 |
+
acc_feedback, acc_reflection = gather_context_from_tree(node)
|
169 |
+
|
170 |
+
while new_solution is None:
|
171 |
+
new_solution = gen.func_impl(
|
172 |
+
func_sig=instruction+extra_header,
|
173 |
+
model=model,
|
174 |
+
strategy=strategy,
|
175 |
+
prev_func_impl=prev_func_impl,
|
176 |
+
feedback=feedback,
|
177 |
+
self_reflection=reflection,
|
178 |
+
acc_feedback = acc_feedback,
|
179 |
+
acc_reflection = acc_reflection
|
180 |
+
)
|
181 |
+
|
182 |
+
combined_context = "\nPrevious Trial\n\n" + new_solution
|
183 |
+
|
184 |
+
child = Node(new_solution, parent=node, context=combined_context, depth=node.depth + 1)
|
185 |
+
node.children.append(child)
|
186 |
+
|
187 |
+
# Simulation
|
188 |
+
reward_real = 0
|
189 |
+
for child in node.children:
|
190 |
+
is_passing_internal, feedback_internal, _ = exe.execute(child.solution, tests_i)
|
191 |
+
if not is_passing_internal:
|
192 |
+
reflection = gen.self_reflection(child.solution, feedback_internal, model)
|
193 |
+
reflections.append(reflection)
|
194 |
+
child.reflection = reflection
|
195 |
+
child.test_feedback = feedback_internal
|
196 |
+
child.context += "\n\nPrevious Trial\n\n" + child.solution + "\n\nTest results: \n" + feedback_internal + "\n\nSelf-reflection: " + reflection
|
197 |
+
else:
|
198 |
+
child.context += "\n\nPrevious Trial\n\n" + child.solution + "\n\nTest results: \n" + feedback_internal
|
199 |
+
child.reflection = ""
|
200 |
+
child.test_feedback = feedback_internal
|
201 |
+
|
202 |
+
if "Tested passed:" in feedback_internal:
|
203 |
+
# Split at "Tests failed:" and get the part before it (which contains the passed tests)
|
204 |
+
passed_section = feedback_internal.split("Tests failed:")[0]
|
205 |
+
# Split at "Tested passed:" and get the part after it, then count the non-empty lines
|
206 |
+
reward_internal = len([line for line in passed_section.split("Tested passed:")[1].splitlines() if line.strip() != ''])
|
207 |
+
reward_internal = reward_internal / len(tests_i)
|
208 |
+
else:
|
209 |
+
reward_internal = 0
|
210 |
+
if is_passing_internal or cur_iter == max_iters - 1:
|
211 |
+
item["solution"] = child.solution
|
212 |
+
break
|
213 |
+
|
214 |
+
if is_solved:
|
215 |
+
break
|
216 |
+
|
217 |
+
reward = reward_internal + reward_real
|
218 |
+
child.update(reward)
|
219 |
+
|
220 |
+
# Backpropagation
|
221 |
+
temp = child
|
222 |
+
while temp.parent:
|
223 |
+
temp = temp.parent
|
224 |
+
temp.update(reward)
|
225 |
+
|
226 |
+
# Choose the best solution after all iterations
|
227 |
+
if is_solved:
|
228 |
+
best_solution = item["solution"]
|
229 |
+
else:
|
230 |
+
best_solution = root.best_child_value().solution
|
231 |
+
item["solution"] = best_solution
|
232 |
+
|
233 |
+
return best_solution
|
lats/lats_main.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import argparse
|
3 |
+
|
4 |
+
from lats import run_lats
|
5 |
+
|
6 |
+
|
7 |
+
def get_args():
|
8 |
+
parser = argparse.ArgumentParser()
|
9 |
+
parser.add_argument("--run_name", type=str, help="The name of the run")
|
10 |
+
parser.add_argument("--root_dir", type=str,
|
11 |
+
help="The root logging directory", default="root")
|
12 |
+
parser.add_argument("--dataset_path", type=str,
|
13 |
+
help="The path to the benchmark dataset", default="root")
|
14 |
+
parser.add_argument("--strategy", type=str,
|
15 |
+
help="Strategy: `simple`, `reflexion`")
|
16 |
+
parser.add_argument("--language", type=str, help="Strategy: `py` or `rs`")
|
17 |
+
parser.add_argument(
|
18 |
+
"--model", type=str, help="OpenAI models only for now. For best results, use GPT-4")
|
19 |
+
parser.add_argument("--pass_at_k", type=int,
|
20 |
+
help="Pass@k metric", default=1)
|
21 |
+
parser.add_argument("--max_iters", type=int,
|
22 |
+
help="The maximum number of self-improvement iterations", default=10)
|
23 |
+
parser.add_argument("--expansion_factor", type=int,
|
24 |
+
help="The expansion factor for the reflexion UCS and A* strategy", default=3)
|
25 |
+
parser.add_argument("--verbose", action='store_true',
|
26 |
+
help="To print live logs")
|
27 |
+
parser.add_argument("--instruction", type=str,
|
28 |
+
help="text string", default="")
|
29 |
+
parser.add_argument("--n_samples", type=int,
|
30 |
+
help="The number of nodes added during expansion", default=3)
|
31 |
+
parser.add_argument("--depth", type=int,
|
32 |
+
help="Tree depth", default=5)
|
33 |
+
|
34 |
+
# TODO: implement this
|
35 |
+
# parser.add_argument("--is_resume", action='store_true', help="To resume run")
|
36 |
+
# parser.add_argument("--resume_dir", type=str, help="If resume, the logging directory", default="")
|
37 |
+
args = parser.parse_args()
|
38 |
+
return args
|
39 |
+
|
40 |
+
|
41 |
+
def strategy_factory(strategy: str):
|
42 |
+
def kwargs_wrapper_gen(func, delete_keys=[]):
|
43 |
+
def kwargs_wrapper(**kwargs):
|
44 |
+
for key in delete_keys:
|
45 |
+
del kwargs[key]
|
46 |
+
return func(**kwargs)
|
47 |
+
return kwargs_wrapper
|
48 |
+
|
49 |
+
return kwargs_wrapper_gen(run_lats, delete_keys=[])
|
50 |
+
|
51 |
+
|
52 |
+
def lats_main(args):
|
53 |
+
|
54 |
+
# check if the strategy is valid
|
55 |
+
run_strategy = strategy_factory(args.strategy)
|
56 |
+
|
57 |
+
# start the run
|
58 |
+
# evaluate with pass@k
|
59 |
+
x = run_strategy(
|
60 |
+
model_name=args.model,
|
61 |
+
language=args.language,
|
62 |
+
max_iters=args.max_iters,
|
63 |
+
verbose=args.verbose,
|
64 |
+
instruction=args.instruction,
|
65 |
+
n_samples=args.n_samples,
|
66 |
+
depth=args.depth
|
67 |
+
)
|
68 |
+
|
69 |
+
return x
|
70 |
+
|
71 |
+
|
72 |
+
|
73 |
+
def main(args):
|
74 |
+
lats_main(args)
|
75 |
+
|
76 |
+
if __name__ == "__main__":
|
77 |
+
args = get_args()
|
78 |
+
main(args)
|
lats/requirements.txt
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
jsonlines==3.1.0
|
2 |
+
openai==0.27.0
|
3 |
+
datasets==2.7.0
|
4 |
+
tenacity==8.1.0
|
5 |
+
astunparse==1.6.3
|
6 |
+
torch
|
7 |
+
xformers
|
8 |
+
transformers
|
9 |
+
accelerate
|
lats/utils.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import gzip
|
3 |
+
import json
|
4 |
+
import openai
|
5 |
+
import jsonlines
|
6 |
+
|
7 |
+
from typing import List
|
8 |
+
|
9 |
+
openai.api_key = os.getenv("OPENAI_API_KEY")
|
10 |
+
|
11 |
+
def make_printv(verbose: bool):
|
12 |
+
def print_v(*args, **kwargs):
|
13 |
+
if verbose:
|
14 |
+
kwargs["flush"] = True
|
15 |
+
print(*args, **kwargs)
|
16 |
+
else:
|
17 |
+
pass
|
18 |
+
return print_v
|
19 |
+
|
20 |
+
|
21 |
+
def read_jsonl(path: str) -> List[dict]:
|
22 |
+
if not os.path.exists(path):
|
23 |
+
raise FileNotFoundError(f"File `{path}` does not exist.")
|
24 |
+
elif not path.endswith(".jsonl"):
|
25 |
+
raise ValueError(f"File `{path}` is not a jsonl file.")
|
26 |
+
items = []
|
27 |
+
with jsonlines.open(path) as reader:
|
28 |
+
for item in reader:
|
29 |
+
items += [item]
|
30 |
+
return items
|
31 |
+
|
32 |
+
|
33 |
+
def write_jsonl(path: str, data: List[dict], append: bool = False):
|
34 |
+
with jsonlines.open(path, mode='a' if append else 'w') as writer:
|
35 |
+
for item in data:
|
36 |
+
writer.write(item)
|
37 |
+
|
38 |
+
|
39 |
+
def read_jsonl_gz(path: str) -> List[dict]:
|
40 |
+
if not path.endswith(".jsonl.gz"):
|
41 |
+
raise ValueError(f"File `{path}` is not a jsonl.gz file.")
|
42 |
+
with gzip.open(path, "rt") as f:
|
43 |
+
data = [json.loads(line) for line in f]
|
44 |
+
return data
|
45 |
+
|
46 |
+
|
47 |
+
# generator that returns the item and the index in the dataset.
|
48 |
+
# if the results_path exists, it will skip all items that have been processed
|
49 |
+
# before.
|
50 |
+
def enumerate_resume(dataset, results_path):
|
51 |
+
if not os.path.exists(results_path):
|
52 |
+
for i, item in enumerate(dataset):
|
53 |
+
yield i, item
|
54 |
+
else:
|
55 |
+
count = 0
|
56 |
+
with jsonlines.open(results_path) as reader:
|
57 |
+
for item in reader:
|
58 |
+
count += 1
|
59 |
+
|
60 |
+
for i, item in enumerate(dataset):
|
61 |
+
# skip items that have been processed before
|
62 |
+
if i < count:
|
63 |
+
continue
|
64 |
+
yield i, item
|
65 |
+
|
66 |
+
|
67 |
+
def resume_success_count(dataset) -> int:
|
68 |
+
count = 0
|
69 |
+
for item in dataset:
|
70 |
+
if "is_solved" in item and item["is_solved"]:
|
71 |
+
count += 1
|
72 |
+
return count
|
73 |
+
|
requirements.txt
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
jsonlines==3.1.0
|
2 |
+
openai==0.27.0
|
3 |
+
datasets==2.7.0
|
4 |
+
tenacity==8.1.0
|
5 |
+
astunparse==1.6.3
|
6 |
+
torch
|
7 |
+
transformers
|
8 |
+
accelerate
|
9 |
+
xformers
|