llama_test / app.py
dhuynh95's picture
Update app.py
4feb357
import gradio as gr
import tempfile
import pytest
import io
import sys
import os
import requests
import ast
# Default Anyscale Endpoints base URL and the chat-completions URL derived from it.
api_base = "https://api.endpoints.anyscale.com/v1"
# Server-side API key; in this file it is read only by ``generate_test`` (the Gradio
# UI passes the user's own key to ``main``). ``.get`` keeps the module importable
# when the variable is not set instead of failing with KeyError at startup.
token = os.environ.get("OPENAI_API_KEY", "")
url = f"{api_base}/chat/completions"
def extract_functions_from_file(filename):
"""Given a file written to disk, extract all functions from it into a list."""
with open(filename, "r") as file:
tree = ast.parse(file.read())
functions = []
for node in ast.walk(tree):
if isinstance(node, ast.FunctionDef):
start_line = node.lineno
end_line = node.end_lineno if hasattr(node, "end_lineno") else start_line
with open(filename, "r") as file:
function_code = "".join(
[
line
for i, line in enumerate(file)
if start_line <= i + 1 <= end_line
]
)
functions.append(function_code)
return functions
def extract_tests_from_list(l):
    """Keep only the snippets in ``l`` whose text begins with ``"def"``.

    Args:
        l: Iterable of source-code strings.

    Returns:
        A new list of the matching snippets, in their original order.
    """
    starts_with_def = lambda snippet: snippet.startswith("def")
    return list(filter(starts_with_def, l))
def remove_leading_whitespace(func_str):
    """Left-align a function's source using the indentation of its first line.

    The indentation of the first line (normally the ``def`` or decorator line)
    is measured, and up to that many leading whitespace characters are removed
    from every line, so the body keeps its relative indentation. Empty and
    whitespace-only lines are dropped.

    A line indented less than the first line (e.g. the continuation of a
    triple-quoted string written at column 0) only loses the whitespace it
    actually has; real code characters are never cut off.

    Args:
        func_str: Source text of a single function.

    Returns:
        The re-indented lines joined with newlines, with no trailing newline.
    """
    lines = func_str.split("\n")
    first = lines[0]
    indent = len(first) - len(first.lstrip())
    dedented = []
    for line in lines:
        if not line.strip():
            continue  # blank lines are dropped
        own_indent = len(line) - len(line.lstrip())
        # Never slice past the line's own whitespace into its content.
        dedented.append(line[min(indent, own_indent):])
    return "\n".join(dedented)
def _run_single_test(fxn, test_formatted):
    """Run one generated test against ``fxn`` in a throwaway module; return pytest's exit code."""
    # The temporary module holds a shebang, ``import pytest``, the code under
    # test and then the single test, so the test can reference both.
    full_test_file = "#!/usr/bin/env python\n\nimport pytest\n{}\n{}".format(
        fxn, test_formatted
    )
    with tempfile.NamedTemporaryFile(
        prefix="test_", suffix=".py", mode="w"
    ) as temp:
        temp.write(full_test_file)
        temp.flush()
        # SECURITY NOTE(review): pytest imports this module in-process, so
        # user-supplied and model-generated code runs unsandboxed in the server.
        return pytest.main(["-x", temp.name])


def main(fxn: str, openai_api_key, examples: str = "", temperature: float = 0.7):
    """Generate pytest unit tests for ``fxn`` with Llama-2 and run each of them.

    Requires Anyscale Endpoints Alpha API access. The endpoint base URL is read
    from the ``OPENAI_API_BASE`` environment variable and falls back to the
    public Anyscale endpoint when that variable is unset.

    Args:
        fxn: Source code of the function under test.
        openai_api_key: Bearer token sent to the chat-completions endpoint.
        examples: Optional input/output pairs. When non-empty they are appended
            verbatim to the prompt after an "Example input output pairs" heading.
        temperature: Sampling temperature forwarded to the model.

    Returns:
        ``(passed_tests, failed_tests)``: the source of the generated tests that
        passed and that failed, each joined with newlines. Tests ending with any
        other pytest exit code (e.g. no tests collected, collection error)
        appear in neither string.

    Raises:
        ValueError: If the reply was truncated or does not contain exactly one
            fenced code block.
        requests.HTTPError: If the endpoint answers with an error status.
    """
    # Same default as the module-level ``api_base``.
    base = os.environ.get("OPENAI_API_BASE", "https://api.endpoints.anyscale.com/v1")
    endpoint = f"{base}/chat/completions"
    message = "Write me a test of this function\n{}".format(fxn)
    if examples:
        # Previously only the heading was sent and the examples were dropped.
        message += "\nExample input output pairs:\n" + examples
    system_prompt = """
You are a helpful coding assistant.
Your job is to help people write unit tests for their python code. Please write all
unit tests in the format expected by pytest. If inputs and outputs are provided,
return a set of unit tests that will verify that the function will produce the
corect outputs. Also provide tests to handle base and edge cases. It is very
important that the code is formatted correctly for pytest.
"""
    body = {
        "model": "meta-llama/Llama-2-70b-chat-hf",
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": message},
        ],
        "temperature": temperature,
    }
    # ``with`` closes the session (and its connection pool) once the reply is read.
    with requests.Session() as s:
        with s.post(
            endpoint, headers={"Authorization": f"Bearer {openai_api_key}"}, json=body
        ) as resp:
            resp.raise_for_status()
            response = resp.json()["choices"][0]
    if response["finish_reason"] != "stop":
        raise ValueError("Please try again -- response was not finished!")
    # Splitting on ``` gives [prose, code, prose] when there is exactly one fenced
    # block. The code part may start with a language tag such as "python"; that
    # line parses as a bare expression statement and is ignored by the extractor.
    split_response = response["message"]["content"].split("```")
    if len(split_response) != 3:
        raise ValueError(
            "Please try again -- expected exactly one code block in the response!"
        )
    all_tests = split_response[1]
    # Write all tests to disk so the ast-based extractor can split them into
    # individual functions.
    with tempfile.NamedTemporaryFile(
        prefix="all_tests_", suffix=".py", mode="w"
    ) as temp:
        temp.write(all_tests)
        temp.flush()
        parsed_tests = extract_functions_from_file(temp.name)
    # Run each test in isolation and sort it by pytest's exit code.
    passed_tests, failed_tests = [], []
    for test in parsed_tests:
        test_formatted = remove_leading_whitespace(test)
        print("testing: \n {}".format(test_formatted))
        retcode = _run_single_test(fxn, test_formatted)
        # pytest.main usually returns an ExitCode but may return a plain int.
        print(getattr(retcode, "name", retcode))
        if retcode == pytest.ExitCode.TESTS_FAILED:
            failed_tests.append(test)
            print("test failed")
        elif retcode == pytest.ExitCode.OK:
            passed_tests.append(test)
            print("test passed")
    return "\n".join(passed_tests), "\n".join(failed_tests)
def generate_test(code):
    """Ask Llama-2 for unit tests of ``code`` and return the generated code block.

    Uses the module-level ``url`` endpoint and ``token`` API key.

    Args:
        code: Source code of the function to write tests for.

    Returns:
        The text between the first pair of ``` fences in the model's reply
        (it may begin with a language tag such as "python").

    Raises:
        ValueError: If the reply was truncated or does not contain exactly one
            fenced code block.
        requests.HTTPError: If the endpoint answers with an error status.
    """
    message = "Write me a test of this function\n{}".format(code)
    system_prompt = """
You are a helpful coding assistant.
Your job is to help people write unit tests for the python code.
If inputs and outputs are provided, please return a set of unit tests that will
verify that the function will produce the corect outputs. Also provide tests to
handle base and edge cases.
"""
    body = {
        "model": "meta-llama/Llama-2-70b-chat-hf",
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": message},
        ],
        "temperature": 0.7,
    }
    # ``with`` closes the session once the reply has been read.
    with requests.Session() as s:
        with s.post(url, headers={"Authorization": f"Bearer {token}"}, json=body) as resp:
            resp.raise_for_status()
            response = resp.json()["choices"][0]
    if response["finish_reason"] != "stop":
        raise ValueError("Please try again -- response was not finished!")
    split_response = response["message"]["content"].split("```")
    if len(split_response) != 3:
        raise ValueError(
            "Please try again -- expected exactly one code block in the response!"
        )
    # Previously the validated block was computed and then discarded (returned None).
    return split_response[1]
def execute_code(code, test):
    """Run pytest on ``code`` followed by ``test`` and return the captured stdout.

    Both pieces are written to one temporary ``.py`` file, pytest is run on it
    in-process with ``-x`` (stop at first failure), and the temporary file is
    removed afterwards.

    Args:
        code: Source of the code under test (a string, or an iterable of lines).
        test: Source of the pytest tests (a string, or an iterable of lines).

    Returns:
        Everything pytest printed to ``sys.stdout`` during the run.
    """
    from contextlib import redirect_stdout

    with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".py") as f:
        f.writelines(code)
        # Start the tests on a fresh line; otherwise ``code`` without a trailing
        # newline would be glued onto the first line of ``test``.
        f.write("\n")
        f.writelines(test)
        temp_path = f.name
    captured = io.StringIO()
    try:
        # redirect_stdout restores sys.stdout even if pytest raises.
        with redirect_stdout(captured):
            # SECURITY NOTE(review): executes the given code in this process.
            pytest.main(["-x", temp_path])
    finally:
        os.remove(temp_path)  # delete=False above, so clean up explicitly
    return captured.getvalue()
# Sample functions offered as one-click examples in the UI. Each string is sent to
# the model and then prepended to every generated test module, so it must be
# standalone Python that imports everything it uses.
examples = ["""
def prime_factors(n):
    i = 2
    factors = []
    while i * i <= n:
        if n % i:
            i += 1
        else:
            n //= i
            factors.append(i)
    if n > 1:
        factors.append(n)
    return factors
""",
# Previously ``import numpy`` while calling ``np.dot``, which raised NameError on
# every call and made all generated tests fail for an unrelated reason.
"""
import numpy as np
def matrix_multiplication(A, B):
    return np.dot(A, B)
""",
# NOTE(review): numpy's Cholesky only succeeds for positive-*definite* input, so
# this returns False for singular positive-semidefinite matrices. Kept as-is since
# it may be an intentional example of buggy code for the generated tests to catch.
"""
import numpy as np
def efficient_is_semipositive_definite(matrix):
    try:
        # Attempt Cholesky decomposition
        np.linalg.cholesky(matrix)
        return True
    except np.linalg.LinAlgError:
        return False
""",
"""
import numpy as np
def is_semipositive_definite(matrix):
    # Compute the eigenvalues of the matrix
    eigenvalues = np.linalg.eigvals(matrix)
    # Check if all eigenvalues are non-negative
    return all(val >= 0 for val in eigenvalues)
"""
]
# Snippet pre-loaded into the code editor when the app starts.
example = examples[0]
# Gradio UI: an API-key box, a code editor pre-filled with ``example``, one-click
# example snippets, a "Generate test" button and two panes for the results.
# Components render in the order they are created, so statement order matters.
with gr.Blocks() as demo:
    gr.Markdown("<h1><center>Llama_test: generate unit test for your Python code</center></h1>")
    # Masked input for the user's Anyscale key; forwarded to ``main`` as ``openai_api_key``.
    openai_api_key = gr.Textbox(
        show_label=False,
        placeholder="Set your Anyscale API key here.",
        lines=1,
        type="password"
    )
    code_input = gr.Code(example, language="python", label="Provide the code of the function you want to test")
    # Clickable sample snippets that populate ``code_input``.
    gr.Examples(
        examples=examples,
        inputs=code_input,)
    generate_btn = gr.Button("Generate test")
    # Side-by-side panes filled from the (passed, failed) strings ``main`` returns.
    with gr.Row():
        code_output = gr.Code(language="python", label="Passed tests")
        code_output2 = gr.Code(language="python", label="Failed tests")
    # ``main`` receives only (code, api key); ``examples`` and ``temperature``
    # therefore keep their defaults.
    generate_btn.click(main, inputs=[code_input, openai_api_key], outputs=[code_output, code_output2])
# Start the web server only when run as a script.
if __name__ == "__main__":
    demo.launch()