File size: 19,225 Bytes
41d1bc5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 |
from generators.model import ModelBase, message_to_str
from .generator_types import Generator
from .generator_utils import generic_generate_func_impl, generic_generate_internal_tests, generic_generate_self_reflection, generate_with_accumulated_context
from typing import Optional, List, Union
import ast
import re
from .parse import parse_code_block, add_code_block
# Instruction strings for function-implementation generation.
# PyGenerator.func_impl passes the two *_COMPLETION_INSTRUCTION strings,
# USE_PYTHON_CODEBLOCK_INSTRUCTION, PY_SIMPLE_CHAT_INSTRUCTION and
# PY_REFLEXION_CHAT_INSTRUCTION to the generic generation helpers;
# PyGenerator.self_reflection passes PY_SELF_REFLECTION_COMPLETION_INSTRUCTION.
# The *_V2 variants are not referenced in the visible part of this file.
# These strings are sent verbatim as prompt text, so any edit changes the prompt.
PY_SIMPLE_COMPLETION_INSTRUCTION = "# Write the body of this function only."
PY_REFLEXION_COMPLETION_INSTRUCTION = "You are a Python writing assistant. You will be given your past function implementation, a series of unit tests, and a hint to change the implementation appropriately. Write your full implementation (restate the function signature).\n\n-----"
PY_SELF_REFLECTION_COMPLETION_INSTRUCTION = "You are a Python writing assistant. You will be given a function implementation and a series of unit tests. Your goal is to write a few sentences to explain why your implementation is wrong as indicated by the tests. You will need this as a hint when you try again later. Only provide the few sentence description in your answer, not the implementation.\n\n-----"
USE_PYTHON_CODEBLOCK_INSTRUCTION = "Use a Python code block to write your response. For example:\n```python\nprint('Hello world!')\n```"
PY_SIMPLE_CHAT_INSTRUCTION = "You are an AI that only responds with python code, NOT ENGLISH. You will be given a function signature and its docstring by the user. Write your full implementation (restate the function signature)."
PY_SIMPLE_CHAT_INSTRUCTION_V2 = "You are an AI that only responds with only python code. You will be given a function signature and its docstring by the user. Write your full implementation (restate the function signature)."
PY_REFLEXION_CHAT_INSTRUCTION = "You are an AI Python assistant. You will be given your past function implementation, a series of unit tests, and a hint to change the implementation appropriately. Write your full implementation (restate the function signature)."
PY_REFLEXION_CHAT_INSTRUCTION_V2 = "You are an AI Python assistant. You will be given your previous implementation of a function, a series of unit tests results, and your self-reflection on your previous implementation. Write your full implementation (restate the function signature)."
# Short few-shot example (a buggy `add` and its corrected version) that
# PyGenerator.func_impl passes as `reflexion_few_shot` on both of its code paths.
# The section labels ("[previous impl]:", "[unit test results from previous impl]:",
# "[reflection on previous impl]:", "[improved impl]:") are part of the prompt text.
PY_REFLEXION_FEW_SHOT_ADD = '''Example 1:
[previous impl]:
```python
def add(a: int, b: int) -> int:
"""
Given integers a and b, return the total value of a and b.
"""
return a - b
```
[unit test results from previous impl]:
Tested passed:
Tests failed:
assert add(1, 2) == 3 # output: -1
assert add(1, 2) == 4 # output: -1
[reflection on previous impl]:
The implementation failed the test cases where the input integers are 1 and 2. The issue arises because the code does not add the two integers together, but instead subtracts the second integer from the first. To fix this issue, we should change the operator from `-` to `+` in the return statement. This will ensure that the function returns the correct output for the given input.
[improved impl]:
```python
def add(a: int, b: int) -> int:
"""
Given integers a and b, return the total value of a and b.
"""
return a + b
```
'''
# Longer few-shot example (text justification) in the same labelled layout as
# PY_REFLEXION_FEW_SHOT_ADD. It is not referenced in the visible part of this
# file; PyGenerator.func_impl uses PY_REFLEXION_FEW_SHOT_ADD instead.
PY_REFLEXION_FEW_SHOT = '''Example 1:
[previous impl]:
```python
from typing import *
def fullJustify(words: List[str], maxWidth: int) -> List[str]:
"""
Given an array of words and a width maxWidth, format the text such that each line has exactly maxWidth characters and is fully (left and right) justified.
You should pack your words in a greedy approach; that is, pack as many words as you can in each line. Pad extra spaces `' '` when necessary so that each line has exactly maxWidth characters.
Extra spaces between words should be distributed as evenly as possible. If the number of spaces on a line do not divide evenly between words, the empty slots on the left will be assigned more spaces than the slots on the right.
For the last line of text, it should be left justified and no extra space is inserted between words.
Note:
A word is defined as a character sequence consisting of non-space characters only.
Each word's length is guaranteed to be greater than 0 and not exceed maxWidth.
The input array `words` contains at least one word.
"""
res = []
cur_line = []
cur_len = 0
for word in words:
if cur_len + len(word) + len(cur_line) > maxWidth:
if len(cur_line) == 1:
res.append(cur_line[0] + ' ' * (maxWidth - cur_len))
else:
spaces = maxWidth - cur_len
space_between = spaces // (len(cur_line) - 1)
extra_spaces = spaces % (len(cur_line) - 1)
line = ''
for i, w in enumerate(cur_line[:-1]):
line += w + ' ' * (space_between + (i < extra_spaces))
line += cur_line[-1]
res.append(line)
cur_line = []
cur_len = 0
cur_line.append(word)
cur_len += len(word)
last_line = ' '.join(cur_line)
last_line += ' ' * (maxWidth - len(last_line))
res.append(last_line)
return res
```
[unit test results from previous impl]:
Tested passed:
Tests failed:
assert fullJustify([], 10) == [] # output: [' ']
assert fullJustify([], 0) == [] # output: ['']
[reflection on previous impl]:
The implementation failed the test cases where the input list of words is empty. The issue arises because the code does not handle the case where there are no words to process. As a result, it still appends a line with spaces to the result list, even when there are no words. To fix this issue, we should add a condition at the beginning of the function to check if the input list is empty, and return an empty list if it is. This will ensure that the function returns the correct output for empty input lists.
[improved impl]:
```python
from typing import *
def fullJustify(words: List[str], maxWidth: int) -> List[str]:
"""
Given an array of words and a width maxWidth, format the text such that each line has exactly maxWidth characters and is fully (left and right) justified.
You should pack your words in a greedy approach; that is, pack as many words as you can in each line. Pad extra spaces `' '` when necessary so that each line has exactly maxWidth characters.
Extra spaces between words should be distributed as evenly as possible. If the number of spaces on a line do not divide evenly between words, the empty slots on the left will be assigned more spaces than the slots on the right.
For the last line of text, it should be left justified and no extra space is inserted between words.
Note:
A word is defined as a character sequence consisting of non-space characters only.
Each word's length is guaranteed to be greater than 0 and not exceed maxWidth.
The input array `words` contains at least one word.
"""
if not words:
return []
res = []
cur_line = []
cur_len = 0
for word in words:
if cur_len + len(word) + len(cur_line) > maxWidth:
if len(cur_line) == 1:
res.append(cur_line[0] + ' ' * (maxWidth - cur_len))
else:
spaces = maxWidth - cur_len
space_between = spaces // (len(cur_line) - 1)
extra_spaces = spaces % (len(cur_line) - 1)
line = ''
for i, w in enumerate(cur_line[:-1]):
line += w + ' ' * (space_between + (i < extra_spaces))
line += cur_line[-1]
res.append(line)
cur_line = []
cur_len = 0
cur_line.append(word)
cur_len += len(word)
last_line = ' '.join(cur_line)
last_line += ' ' * (maxWidth - len(last_line))
res.append(last_line)
return res
```
END EXAMPLES
'''
# Chat-style instructions for the self-reflection step.
# PyGenerator.self_reflection passes PY_SELF_REFLECTION_CHAT_INSTRUCTION;
# the _V2 variant is not referenced in the visible part of this file.
PY_SELF_REFLECTION_CHAT_INSTRUCTION = "You are a Python programming assistant. You will be given a function implementation and a series of unit tests. Your goal is to write a few sentences to explain why your implementation is wrong as indicated by the tests. You will need this as a hint when you try again later. Only provide the few sentence description in your answer, not the implementation."
PY_SELF_REFLECTION_CHAT_INSTRUCTION_V2 = "You are a Python programming assistant. You will be given a function implementation and a series of unit test results. Your goal is to write a few sentences to explain why your implementation is wrong as indicated by the tests. You will need this as guidance when you try again later. Only provide the few sentence description in your answer, not the implementation. You will be given a few examples by the user."
# Two few-shot examples for the self-reflection step, each laid out as
# "[function impl]:" / "[unit test results]:" / "[self-reflection]:".
# PyGenerator.self_reflection passes this string as `self_reflection_few_shot`.
PY_SELF_REFLECTION_FEW_SHOT = """Example 1:
[function impl]:
```python
def longest_subarray_with_sum_limit(nums: List[int], target: int) -> List[int]:
n = len(nums)
left, right = 0, 0
max_length = 0
current_sum = 0
result = []
while right < n:
current_sum += nums[right]
while current_sum > target:
current_sum -= nums[left]
left += 1
if right - left + 1 >= max_length:
max_length = right - left + 1
result = nums[left:right+1]
right += 1
return result
```
[unit test results]:
Tests passing:
assert longest_subarray_with_sum_limit([1, 2, 3, 4, 5], 8) == [1, 2, 3]
assert longest_subarray_with_sum_limit([1, 2, 3, 4, 5], 15) == [1, 2, 3, 4, 5]
assert longest_subarray_with_sum_limit([1, -1, 2, -2, 3, -3], 2) == [1, -1, 2, -2, 3]
assert longest_subarray_with_sum_limit([], 10) == []
assert longest_subarray_with_sum_limit([], 0) == []
assert longest_subarray_with_sum_limit([], -5) == []
Tests failing:
assert longest_subarray_with_sum_limit([5, 6, 7, 8, 9], 4) == [] # output: [5]
[self-reflection]:
The implementation failed the where no subarray fulfills the condition. The issue in the implementation is due to the use of >= instead of > in the condition to update the result. Because of this, it returns a subarray even when the sum is greater than the target, as it still updates the result when the current subarray length is equal to the previous longest subarray length. To overcome this error, we should change the condition to only update the result when the current subarray length is strictly greater than the previous longest subarray length. This can be done by replacing >= with > in the condition.
Example 2:
[function impl]:
```python
def longest_subarray_with_sum_limit(nums: List[int], target: int) -> List[int]:
n = len(nums)
left, right = 0, 0
max_length = 0
current_sum = 0
result = []
while current_sum + nums[right] <= target:
current_sum += nums[right]
right += 1
while right < n:
current_sum += nums[right]
while current_sum > target:
current_sum -= nums[left]
left += 1
if right - left + 1 > max_length:
max_length = right - left + 1
result = nums[left:right+1]
right += 1
return result
```
[unit test results]:
Tests passing:
assert longest_subarray_with_sum_limit([], 10) == []
assert longest_subarray_with_sum_limit([], 0) == []
assert longest_subarray_with_sum_limit([], -5) == []
Tests failing:
assert longest_subarray_with_sum_limit([1, 2, 3, 4, 5], 8) == [1, 2, 3] # output: list index out of range
assert longest_subarray_with_sum_limit([1, 2, 3, 4, 5], 15) == [1, 2, 3, 4, 5] # output: list index out of range
assert longest_subarray_with_sum_limit([5, 6, 7, 8, 9], 4) == [] # output: list index out of range
assert longest_subarray_with_sum_limit([1, -1, 2, -2, 3, -3], 2) == [1, -1, 2, -2, 3] # output: list index out of range
[self-reflection]:
The implementation failed 4 out of the 7 test cases due to an IndexError. The issue stems from the while loop while current_sum + nums[right] <= target:, which directly accesses nums[right] without checking if right is within the bounds of the list. This results in a runtime error when right goes beyond the list length. To overcome this error, we need to add a bounds check for the right variable in the mentioned while loop. We can modify the loop condition to while right < len(nums) and current_sum + nums[right] <= target:. This change will ensure that we only access elements within the bounds of the list, thus avoiding the IndexError.
END OF EXAMPLES
"""
# Prompts for unit-test generation, all passed by PyGenerator.internal_tests.
# PY_TEST_GENERATION_FEW_SHOT is a single example pairing a function signature
# with a list of `assert` tests.
PY_TEST_GENERATION_FEW_SHOT = """Examples:
func signature:
def add3Numbers(x, y, z):
\"\"\" Add three numbers together.
This function takes three numbers as input and returns the sum of the three numbers.
\"\"\"
unit tests:
assert add3Numbers(1, 2, 3) == 6
assert add3Numbers(-1, 2, 3) == 4
assert add3Numbers(1, -2, 3) == 2
assert add3Numbers(1, 2, -3) == 0
assert add3Numbers(-3, -2, -1) == -6
assert add3Numbers(0, 0, 0) == 0
"""
# The completion-style instruction embeds the few-shot example via the f-string;
# the chat-style instruction below carries the same sentence without the example.
PY_TEST_GENERATION_COMPLETION_INSTRUCTION = f"""You are an AI coding assistant that can write unique, diverse, and intuitive unit tests for functions given the signature and docstring. Call your function answer().
{PY_TEST_GENERATION_FEW_SHOT}"""
PY_TEST_GENERATION_CHAT_INSTRUCTION = """You are an AI coding assistant that can write unique, diverse, and intuitive unit tests for functions given the signature and docstring. Call your function answer()."""
class PyGenerator(Generator):
    """Python-specific ``Generator``.

    Every method binds this module's Python prompt constants and
    code-block helpers, then delegates to the generic helpers imported
    from ``generator_utils``.
    """

    def self_reflection(self, func: str, feedback: str, model: ModelBase) -> str:
        """Return the result of ``generic_generate_self_reflection``.

        Args:
            func: The function implementation to reflect on.
            feedback: Feedback text forwarded unchanged to the helper.
            model: Model forwarded unchanged to the helper.
        """

        def fence_python(text):
            # Same wrapping the helper would get from add_code_block(x, "python").
            return add_code_block(text, "python")

        return generic_generate_self_reflection(
            func=func,
            feedback=feedback,
            model=model,
            self_reflection_chat_instruction=PY_SELF_REFLECTION_CHAT_INSTRUCTION,
            self_reflection_completion_instruction=PY_SELF_REFLECTION_COMPLETION_INSTRUCTION,
            add_code_block=fence_python,
            self_reflection_few_shot=PY_SELF_REFLECTION_FEW_SHOT,
        )

    def func_impl(
        self,
        func_sig: str,
        model: ModelBase,
        strategy: str,
        prev_func_impl: Optional[str] = None,
        feedback: Optional[str] = None,
        self_reflection: Optional[str] = None,
        num_comps: int = 1,
        temperature: float = 0.8,
        acc_feedback: Optional[str] = None,
        acc_reflection: Optional[str] = None,
    ) -> Union[str, List[str]]:
        """Generate one or more implementations for ``func_sig``.

        When ``strategy == "mcts"`` the call goes to
        ``generate_with_accumulated_context`` with ``strategy="reflexion"``
        and the accumulated feedback/reflection (``acc_feedback`` /
        ``acc_reflection``); ``feedback`` and ``self_reflection`` are not
        forwarded on that path. Any other strategy goes to
        ``generic_generate_func_impl`` with ``strategy``, ``feedback`` and
        ``self_reflection``; the accumulated values are not forwarded.

        Returns:
            Whatever the selected helper returns.
        """

        def fence_python(text):
            return add_code_block(text, "python")

        def unfence_python(text):
            return parse_code_block(text, "python")

        # Arguments that are identical on both code paths.
        shared_kwargs = dict(
            func_sig=func_sig,
            model=model,
            prev_func_impl=prev_func_impl,
            num_comps=num_comps,
            temperature=temperature,
            reflexion_chat_instruction=PY_REFLEXION_CHAT_INSTRUCTION,
            reflexion_few_shot=PY_REFLEXION_FEW_SHOT_ADD,
            simple_chat_instruction=PY_SIMPLE_CHAT_INSTRUCTION,
            reflexion_completion_instruction=PY_REFLEXION_COMPLETION_INSTRUCTION,
            simple_completion_instruction=PY_SIMPLE_COMPLETION_INSTRUCTION,
            code_block_instruction=USE_PYTHON_CODEBLOCK_INSTRUCTION,
            parse_code_block=unfence_python,
            add_code_block=fence_python,
        )

        if strategy != "mcts":
            return generic_generate_func_impl(
                strategy=strategy,
                feedback=feedback,
                self_reflection=self_reflection,
                **shared_kwargs,
            )

        # The "mcts" strategy reuses the reflexion prompts with accumulated context.
        return generate_with_accumulated_context(
            strategy="reflexion",
            accumulated_feedback=acc_feedback,
            accumulated_reflection=acc_reflection,
            **shared_kwargs,
        )

    def internal_tests(self, func_sig: str, model: ModelBase, max_num_tests: int = 4) -> List[str]:
        """Generates tests for a function.

        Delegates to ``generic_generate_internal_tests`` with the Python
        test-generation prompts, a line-based test parser and
        ``py_is_syntax_valid`` as the syntax check.
        """

        def parse_tests(tests: str) -> List[str]:
            # Keep every line that mentions "assert", stripped of
            # surrounding whitespace.
            kept = []
            for candidate in tests.splitlines():
                if "assert" in candidate:
                    kept.append(candidate.strip())
            return kept

        return generic_generate_internal_tests(
            func_sig=func_sig,
            model=model,
            max_num_tests=max_num_tests,
            test_generation_few_shot=PY_TEST_GENERATION_FEW_SHOT,
            test_generation_chat_instruction=PY_TEST_GENERATION_CHAT_INSTRUCTION,
            test_generation_completion_instruction=PY_TEST_GENERATION_COMPLETION_INSTRUCTION,
            parse_tests=parse_tests,
            is_syntax_valid=py_is_syntax_valid,
        )
# Wrapper used by py_fix_indentation: a candidate body is placed between this
# signature line and this call line before being handed to exec().
DUMMY_FUNC_SIG = "def func():"
DUMMY_FUNC_CALL = "func()"
def handle_first_line_indent(func_body: str) -> str:
    """Indent only the first line of ``func_body``.

    Args:
        func_body: Source text of a function body.

    Returns:
        ``func_body`` unchanged when it already starts with a space or is
        empty. Otherwise the first line prefixed with one space and
        followed by a newline, then the remaining lines joined with
        newlines.
    """
    if func_body.startswith(" "):
        return func_body
    lines = func_body.splitlines()
    if not lines:
        # "".splitlines() is []; indexing it used to raise IndexError.
        return func_body
    first, rest = lines[0], lines[1:]
    # NOTE(review): the indent unit here is a single space; confirm it
    # matches the indentation of the bodies being repaired.
    return f" {first}\n" + "\n".join(rest)
def handle_entire_body_indent(func_body: str) -> str:
    """Prefix every line of ``func_body`` with one space.

    Lines are split with ``str.splitlines`` and re-joined with ``"\\n"``,
    so a trailing newline in the input is not preserved.
    """
    indented_lines = (" " + row for row in func_body.splitlines())
    return "\n".join(indented_lines)
def fix_turbo_response(func_body: str) -> str:
    """Clean up a raw response: drop ``def`` lines, then strip code fences.

    Applies ``remove_unindented_signatures`` first and ``fix_markdown`` to
    its result.
    """
    without_signatures = remove_unindented_signatures(func_body)
    return fix_markdown(without_signatures)
def fix_markdown(func_body: str) -> str:
    """Remove every run of exactly three backticks from ``func_body``.

    Only the backticks themselves are removed; a language tag following an
    opening fence (e.g. ``python``) is left in place.
    """
    fence_pattern = "`{3}"
    cleaned = re.sub(fence_pattern, "", func_body)
    return cleaned
def remove_unindented_signatures(code: str) -> str:
    """Drop column-0 ``def name(`` lines and indent what precedes the first one.

    Every line that starts with ``def <name>(`` is removed. Non-blank lines
    that appear before the first such line and do not already start with a
    space get one space prepended; lines after it are kept as they are.
    """
    signature_re = re.compile(r"^def\s+\w+\s*\(")
    kept = []
    seen_signature = False
    for raw in code.split("\n"):
        if signature_re.match(raw):
            # Signature lines are dropped wherever they occur.
            seen_signature = True
        elif seen_signature or raw.startswith(" ") or not raw.strip():
            kept.append(raw)
        else:
            kept.append(" " + raw)
    return "\n".join(kept)
def py_fix_indentation(func_body: str) -> str:
    """Best-effort repair of the indentation of a function body.

    The body is first cleaned with ``fix_turbo_response``. Up to three
    candidates are then considered, in order:

    1. good syntax: the cleaned body as-is,
    2. first line not good: ``handle_first_line_indent`` of the cleaned body,
    3. entire body not good: ``handle_entire_body_indent`` of the cleaned body.

    Candidates 1 and 2 are checked by wrapping them between
    ``DUMMY_FUNC_SIG`` and ``DUMMY_FUNC_CALL`` and executing the result.
    A candidate is returned unless that raises ``IndentationError`` or
    ``SyntaxError``. Candidate 3 is returned without being checked.
    """
    func_body = fix_turbo_response(func_body)
    candidate = fix_markdown(func_body)
    for repair in (handle_first_line_indent, handle_entire_body_indent):
        program = f'{DUMMY_FUNC_SIG}\n{candidate}\n{DUMMY_FUNC_CALL}'
        try:
            # NOTE(security): exec() actually runs the wrapped body, not
            # just parses it. If func_body can come from an untrusted
            # source (e.g. model output), this executes arbitrary code;
            # review whether a parse-only check would be sufficient.
            exec(program)
        except (IndentationError, SyntaxError):
            # Each repair is applied to the cleaned original body, not to
            # the previous candidate.
            candidate = fix_markdown(repair(func_body))
        except Exception:
            # Any other runtime failure means the code parsed; keep it.
            return candidate
        else:
            return candidate
    return candidate
def py_is_syntax_valid(code: str) -> bool:
    """Return True when ``ast.parse(code)`` succeeds, False on any exception."""
    try:
        ast.parse(code)
    except Exception:
        return False
    return True
|