Spaces:
Sleeping
Sleeping
import gradio as gr | |
import tempfile | |
import pytest | |
import io | |
import sys | |
import os | |
import requests | |
import ast | |
# Default endpoint configuration read by generate_test() through the
# module-level ``url`` and ``token`` names.
#
# The key is looked up leniently: main() receives its key from the UI, so a
# missing OPENAI_API_KEY must not abort the import of this module with a
# KeyError. An empty token simply makes generate_test() requests unauthenticated.
api_base = "https://api.endpoints.anyscale.com/v1"
token = os.environ.get("OPENAI_API_KEY", "")
url = f"{api_base}/chat/completions"
def extract_functions_from_file(filename): | |
"""Given a file written to disk, extract all functions from it into a list.""" | |
with open(filename, "r") as file: | |
tree = ast.parse(file.read()) | |
functions = [] | |
for node in ast.walk(tree): | |
if isinstance(node, ast.FunctionDef): | |
start_line = node.lineno | |
end_line = node.end_lineno if hasattr(node, "end_lineno") else start_line | |
with open(filename, "r") as file: | |
function_code = "".join( | |
[ | |
line | |
for i, line in enumerate(file) | |
if start_line <= i + 1 <= end_line | |
] | |
) | |
functions.append(function_code) | |
return functions | |
def extract_tests_from_list(l):
    """Keep only the strings in ``l`` that begin with ``def``.

    The check is a plain prefix match on each string, so entries with leading
    whitespace are not selected. Order of the input is preserved.
    """
    selected = []
    for candidate in l:
        if candidate.startswith("def"):
            selected.append(candidate)
    return selected
def remove_leading_whitespace(func_str):
    """Left-align a function's source while keeping its relative indentation.

    The indentation of the first non-blank line (the function signature, or
    its first decorator) is removed from every line, so the definition starts
    at column 0 and the body keeps its nesting. Blank lines are dropped.

    Args:
        func_str: Source text of a single function, possibly indented.

    Returns:
        The dedented source joined with ``"\\n"``; an empty string when the
        input contains no non-blank lines.
    """
    # Measure on the first line that has content: a leading blank line would
    # otherwise report zero indentation and nothing would be dedented.
    lines = [line for line in func_str.split("\n") if line.strip()]
    if not lines:
        return ""
    leading_whitespace = len(lines[0]) - len(lines[0].lstrip())

    new_lines = []
    for line in lines:
        own_indent = len(line) - len(line.lstrip())
        # Never cut more than the whitespace a line actually has, so lines
        # indented less than the signature do not lose real characters.
        new_lines.append(line[min(leading_whitespace, own_indent):])
    return "\n".join(new_lines)
def main(fxn: str, openai_api_key, examples: str = "", temperature: float = 0.7):
    """Generate pytest tests for ``fxn`` with an LLM and sort them by outcome.

    Requires Anyscale Endpoints Alpha API access. The endpoint base URL is
    read from the ``OPENAI_API_BASE`` environment variable.

    The model's reply must contain exactly one fenced code block. Every
    function found in that block is written, together with ``fxn``, to its own
    temporary file and run through ``pytest.main``.

    NOTE: this executes the submitted code and model-generated code inside the
    current process. Only run it where that is acceptable (e.g. a sandbox).

    Args:
        fxn: Source code of the function to be tested.
        openai_api_key: Bearer token sent to the chat-completions endpoint.
        examples: Optional text with example input/output pairs; when
            non-empty it is appended to the prompt.
        temperature: Sampling temperature forwarded to the model.

    Returns:
        A ``(passed_tests, failed_tests)`` tuple of newline-joined source
        strings. Functions whose pytest run ends with any exit code other than
        ``OK`` or ``TESTS_FAILED`` appear in neither string.

    Raises:
        KeyError: If ``OPENAI_API_BASE`` is not set.
        requests.HTTPError: If the endpoint answers with an error status.
        ValueError: If the completion did not finish normally or does not
            contain exactly one fenced code block.
    """
    api_base = os.environ["OPENAI_API_BASE"]
    token = openai_api_key
    url = f"{api_base}/chat/completions"

    message = "Write me a test of this function\n{}".format(fxn)
    if examples:
        # Send the examples themselves, not just the header announcing them.
        message += "\nExample input output pairs:\n{}".format(examples)

    system_prompt = """
You are a helpful coding assistant.
Your job is to help people write unit tests for their python code. Please write all
unit tests in the format expected by pytest. If inputs and outputs are provided,
return a set of unit tests that will verify that the function will produce the
corect outputs. Also provide tests to handle base and edge cases. It is very
important that the code is formatted correctly for pytest.
"""
    body = {
        "model": "meta-llama/Llama-2-70b-chat-hf",
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": message},
        ],
        "temperature": temperature,
    }

    # Context-manage the session so its connection pool is always released.
    with requests.Session() as s:
        with s.post(
            url, headers={"Authorization": f"Bearer {token}"}, json=body
        ) as resp:
            # Surface HTTP failures (bad key, quota, ...) as such instead of a
            # KeyError on the missing "choices" field.
            resp.raise_for_status()
            response = resp.json()["choices"][0]

    if response["finish_reason"] != "stop":
        raise ValueError("Print please try again -- response was not finished!")

    # Parse the response to get the tests out: text / code block / text.
    split_response = response["message"]["content"].split("```")
    if len(split_response) != 3:
        raise ValueError("Please try again -- response generated too many code blocks!")
    all_tests = split_response[1]

    # Write all tests to a file, then extract each individual test function.
    with tempfile.NamedTemporaryFile(
        prefix="all_tests_", suffix=".py", mode="w"
    ) as temp:
        temp.write(all_tests)
        # Flush so the content is on disk before the file is re-read by name.
        temp.flush()
        parsed_tests = extract_functions_from_file(temp.name)

    # Run pytest on each test separately and bucket it by exit code.
    passed_tests, failed_tests = [], []
    for test in parsed_tests:
        test_formatted = remove_leading_whitespace(test)
        print("testing: \n {}".format(test_formatted))
        with tempfile.NamedTemporaryFile(
            prefix="test_", suffix=".py", mode="w"
        ) as temp:
            # Each file holds the function under test followed by one test.
            full_test_file = "#!/usr/bin/env python\n\nimport pytest\n{}\n{}".format(
                fxn, test_formatted
            )
            temp.write(full_test_file)
            temp.flush()
            retcode = pytest.main(["-x", temp.name])

        print(retcode.name)
        if retcode.name == "TESTS_FAILED":
            failed_tests.append(test)
            print("test failed")
        elif retcode.name == "OK":
            passed_tests.append(test)
            print("test passed")

    passed_tests = "\n".join(passed_tests)
    failed_tests = "\n".join(failed_tests)
    return passed_tests, failed_tests
def generate_test(code):
    """Ask the chat-completions endpoint for unit tests covering ``code``.

    The request is sent to the module-level ``url`` and authenticated with the
    module-level ``token``.

    Args:
        code: Source code of the function to be tested.

    Returns:
        The text of the single fenced code block found in the model's reply
        (everything between the two triple-backtick fences).

    Raises:
        requests.HTTPError: If the endpoint answers with an error status.
        ValueError: If the completion did not finish normally or does not
            contain exactly one fenced code block.
    """
    message = "Write me a test of this function\n{}".format(code)
    system_prompt = """
You are a helpful coding assistant.
Your job is to help people write unit tests for the python code.
If inputs and outputs are provided, please return a set of unit tests that will
verify that the function will produce the corect outputs. Also provide tests to
handle base and edge cases.
"""
    body = {
        "model": "meta-llama/Llama-2-70b-chat-hf",
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": message},
        ],
        "temperature": 0.7,
    }

    # Context-manage the session so its connection pool is always released.
    with requests.Session() as s:
        with s.post(
            url, headers={"Authorization": f"Bearer {token}"}, json=body
        ) as resp:
            resp.raise_for_status()
            response = resp.json()["choices"][0]

    if response["finish_reason"] != "stop":
        raise ValueError("Print please try again -- response was not finished!")

    # A well-formed reply splits into: text before / code block / text after.
    split_response = response["message"]["content"].split("```")
    if len(split_response) != 3:
        raise ValueError("Please try again -- response generated too many code blocks!")

    # Hand the validated code block back to the caller instead of dropping it.
    return split_response[1]
def execute_code(code, test):
    """Run ``test`` against ``code`` with pytest and return the captured stdout.

    ``code`` and ``test`` are written, in that order, to one temporary ``.py``
    file which is handed to ``pytest.main(["-x", ...])``. While pytest runs,
    ``sys.stdout`` is replaced by an in-memory buffer.

    Args:
        code: Source of the code under test.
        test: Source of the test(s) to run against it.

    Returns:
        Everything written to ``sys.stdout`` during the pytest run.
    """
    # delete=False lets the file be closed before pytest opens it by name; it
    # is removed explicitly below.
    with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.py') as f:
        f.writelines(code)
        # Make sure the test starts on its own line even when ``code`` does
        # not end with a newline.
        f.write("\n")
        f.writelines(test)
        temp_path = f.name

    # Capture the standard output in a StringIO object.
    old_stdout = sys.stdout
    new_stdout = io.StringIO()
    sys.stdout = new_stdout
    try:
        pytest.main(["-x", temp_path])
    finally:
        # Always restore stdout and clean up, even if pytest raises.
        sys.stdout = old_stdout
        os.remove(temp_path)

    return new_stdout.getvalue()
# Sample functions offered in the UI as ready-made inputs. Each entry is a
# self-contained snippet: it must import everything it uses, because it is
# executed on its own when the generated tests are run.
examples = ["""
def prime_factors(n):
    i = 2
    factors = []
    while i * i <= n:
        if n % i:
            i += 1
        else:
            n //= i
            factors.append(i)
    if n > 1:
        factors.append(n)
    return factors
""",
"""
import numpy as np
def matrix_multiplication(A, B):
    return np.dot(A, B)
""",
"""
import numpy as np
def efficient_is_semipositive_definite(matrix):
    try:
        # Attempt Cholesky decomposition
        np.linalg.cholesky(matrix)
        return True
    except np.linalg.LinAlgError:
        return False
""",
"""
import numpy as np
def is_semipositive_definite(matrix):
    # Compute the eigenvalues of the matrix
    eigenvalues = np.linalg.eigvals(matrix)
    # Check if all eigenvalues are non-negative
    return all(val >= 0 for val in eigenvalues)
"""
]
# Snippet pre-filled in the code editor when the page loads.
example = examples[0]
# Page layout: API-key box, code editor with selectable samples, a trigger
# button, and two side-by-side panes for the passing and failing tests.
with gr.Blocks() as demo:
    gr.Markdown(
        "<h1><center>Llama_test: generate unit test for your Python code</center></h1>"
    )

    # Masked input; its value is passed to main() as ``openai_api_key``.
    openai_api_key = gr.Textbox(
        type="password",
        lines=1,
        placeholder="Set your Anyscale API key here.",
        show_label=False,
    )

    # Editor pre-filled with the first sample function.
    code_input = gr.Code(
        value=example,
        language="python",
        label="Provide the code of the function you want to test",
    )
    gr.Examples(examples=examples, inputs=code_input)

    generate_btn = gr.Button("Generate test")

    with gr.Row():
        code_output = gr.Code(language="python", label="Passed tests")
        code_output2 = gr.Code(language="python", label="Failed tests")

    # main() returns (passed, failed), mapped onto the two output panes.
    generate_btn.click(
        main,
        inputs=[code_input, openai_api_key],
        outputs=[code_output, code_output2],
    )

# Start the web server only when run as a script.
if __name__ == "__main__":
    demo.launch()