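"""Streamlit demo for COFFEEPOTS.

For each HumanEvalPack problem, a critic model streams natural-language feedback
on the buggy solution, and an editor model then streams a corrected version of
the code conditioned on that feedback. Both models are served behind
OpenAI-compatible endpoints and accessed through LangChain.
"""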
import argparse
import re

import streamlit as st
from datasets import load_dataset
from langchain_community.llms import OpenAI
from tqdm import tqdm



def load_data(split="test"):
    # Load HumanEvalPack (trust_remote_code matches the module-level call below).
    data = load_dataset("bigcode/humanevalpack", trust_remote_code=True)
    print("=========== dataset statistics ===========")
    print(f"{split} split size: {len(data[split])}")
    print("==========================================")
    return data[split]



def split_function_header_and_docstring(s):
    # Match a triple-quoted docstring (double or single quotes).
    pattern = re.compile(r"(\"\"\"(.*?)\"\"\"|\'\'\'(.*?)\'\'\')", re.DOTALL)
    matches = pattern.findall(s)
    if not matches:
        raise ValueError(f"No docstring found in prompt:\n{s}")
    # findall yields (full_match, double_quoted_body, single_quoted_body) tuples;
    # keep the full quoted docstring of the last match.
    docstring = matches[-1][0]
    code_without_docstring = s.replace(docstring, "").replace('"' * 6, "").strip()
    docstring = docstring.replace('"', "")
    return code_without_docstring, docstring


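# Build the "Problem Description / Incorrect Code" prompt for a single
# HumanEvalPack example, returning the formatted prompt, the extracted problem
# description, and the function header.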
def prepare_model_input(code_data):
    prompt = """Provide feedback on the errors in the given code and suggest the correct code to address the described problem.

Problem Description:
{description}

Incorrect Code:
{wrong_code}"""

    # The HumanEvalPack "prompt" field is the function header plus its docstring.
    description = code_data["prompt"]
    function_header, docstring = split_function_header_and_docstring(description)
    # Keep only the natural-language part of the docstring (drop doctest examples).
    problem = docstring.split(">>>")[0]

    wrong_code = function_header + code_data["buggy_solution"]
    template_dict = {"description": problem, "wrong_code": wrong_code}
    model_input = prompt.format(**template_dict)
    return model_input, problem, function_header


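# Build the critic prompt for every problem in the split, keyed by task_id.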
def load_and_prepare_data():
    dataset = load_data()
    all_model_inputs = {}
    print("### load and prepare data")
    for data in tqdm(dataset):
        problem_id = data["task_id"]
        buggy_solution = data["buggy_solution"]
        model_input, problem, function_header = prepare_model_input(data)
        # The critic prompt used by the demo: problem description plus buggy code,
        # ending with "Feedback:" so the model continues with its critique.
        new_model_input = (
            "Provide feedback on the errors in the given code and suggest the correct "
            "code to address the described problem.\n"
            f"Problem Description:{problem}\n"
            f"Incorrect Code:\n{buggy_solution}\n"
            "Feedback:"
        )
        all_model_inputs[problem_id] = {
            "model_input": new_model_input,
            "header": function_header,
            "problem_description": problem,
            "data": data,
        }
    return all_model_inputs

 


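# Load the HumanEvalPack test split at module level so the task ids can
# populate the selector, and pre-build the critic prompt for every problem.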
dataset = load_dataset("bigcode/humanevalpack", split="test", trust_remote_code=True)

problem_ids = [problem['task_id'] for problem in dataset]
all_model_inputs = load_and_prepare_data()


# Ports for locally hosted OpenAI-compatible servers. They are kept for local
# runs, but this demo points at the hosted ngrok endpoints below instead.
parser = argparse.ArgumentParser()
parser.add_argument("--editor_port", type=str, default="6000")
parser.add_argument("--critic_port", type=str, default="6001")
args = parser.parse_args()

# LangChain OpenAI clients for the two COFFEEPOTS models. The endpoints speak
# the OpenAI-compatible API, so a dummy key ("EMPTY") is sufficient. For a local
# deployment, use e.g. openai_api_base=f"http://localhost:{args.critic_port}/v1".
editor_model = OpenAI(model="Anonymous-COFFEE/COFFEEPOTS-editor", api_key="EMPTY", openai_api_base="https://editor.jp.ngrok.io/v1")
critic_model = OpenAI(model="Anonymous-COFFEE/COFFEEPOTS-critic", api_key="EMPTY", openai_api_base="https://critic.jp.ngrok.io/v1")

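# Streamlit UI: pick a task, show its description and buggy code, and reserve
# placeholders for the streamed feedback and corrected code.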
st.title("Demo for COFFEEPOTS")

selected_task_id = st.selectbox("Select a problem ID:", problem_ids)

# Retrieve selected problem details
problem_details = dataset[problem_ids.index(selected_task_id)]

st.write(f"**Selected Problem ID:** {problem_details['task_id']}")
st.write(f"**Problem Description:**\n{all_model_inputs[selected_task_id]['problem_description']}")
# Display buggy code with syntax highlighting
st.code(problem_details['buggy_solution'], language='python')

# Placeholders that the streamed feedback and corrected code are written into.
status_text = st.empty()
code_output = st.empty()

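# Generators consumed by st.write_stream below: the critic streams feedback for
# the selected task, and the editor streams the corrected code.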
def generate_feedback():
    # Stream the critic's feedback for the selected task.
    return critic_model.stream(input=all_model_inputs[selected_task_id]["model_input"], logit_bias=None)


def generate_corrected_code(feedback):
    # Stream the editor's output, wrapped in a Markdown code fence so that
    # write_stream renders it as a Python code block.
    yield "```python"
    for text_chunk in editor_model.stream(
        input=f"[INST]Buggy Code:\n{problem_details['buggy_solution']}\nFeedback: {feedback}[/INST]",
        logit_bias=None,
    ):
        yield text_chunk
    yield "```"


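# Two-stage flow on button click: first stream the critic's feedback, then pass
# that feedback to the editor to stream the corrected solution.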
if st.button("Generate Feedback and Corrected Code"):
    with st.spinner("Generating feedback..."):
        # Stream the critic's feedback into its placeholder; write_stream
        # returns the full text once streaming finishes.
        print("model input for critic:")
        print(all_model_inputs[selected_task_id]["model_input"])
        feedback = status_text.write_stream(generate_feedback())

    with st.spinner("Generating corrected code..."):
        # Stream the editor's corrected code, conditioned on the feedback above.
        corrected_code = code_output.write_stream(generate_corrected_code(feedback))