# solcoder/code_assistant_runnable.py
from langchain_openai import ChatOpenAI
from langchain_ollama import ChatOllama
from langchain_groq import ChatGroq
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
# from langchain_community.vectorstores import Chroma
from langchain_chroma import Chroma
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
import pickle
from langchain_core.prompts import ChatPromptTemplate
from pydantic import BaseModel, Field
from typing import List
from typing_extensions import TypedDict
from langgraph.graph import END, StateGraph, START
import subprocess
import time
import re
import json
import os
from dotenv import load_dotenv
load_dotenv()
# Silence the HuggingFace tokenizers fork/parallelism warning
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# llm = ChatOllama(model="codestral")
expt_llm = "gpt-4o-mini"
llm = ChatOpenAI(temperature=0, model=expt_llm)
## Create a retriever from the existing vector store
# Load the previously saved embedding model from a pickle file
# model_path = "/Model/embedding_model.pkl"
model_path = "embedding_model.pkl"
with open(model_path, 'rb') as f:
embedding_model = pickle.load(f)
print("Loaded embedding model successfully")
vectorstore = Chroma(
collection_name="solcoder-chroma",
embedding_function=embedding_model,
persist_directory="solcoder-db"
)
retriever = vectorstore.as_retriever()
# Code generation prompt
code_gen_prompt = ChatPromptTemplate(
[
(
"system",
"""<instructions> You are a coding assistant with expertise in Solana Blockchain ecosystem. \n
Here is a set of Solana development documentation based on a user question: \n ------- \n {context} \n ------- \n
Answer the user question based on the above provided documentation. Ensure any code you provide can be executed with all required imports and variables \n
defined. Structure your answer: 1) a prefix describing the code solution, 2) the imports, 3) the functioning code block. \n
Invoke the code tool to structure the output correctly. </instructions> \n Here is the user question:""",
),
("placeholder", "{messages}"),
]
)
# Data model
class code(BaseModel):
"""Schema for code solutions to questions about Solana development."""
prefix: str = Field(description="Description of the problem and approach")
imports: str = Field(description="Code block import statements")
code: str = Field(description="Code block not including import statements")
    language: str = Field(description="Programming language the code is implemented in")
class Config:
json_schema_extra = {
"example": {
"prefix": "To read the balance of an account from the Solana network, you can use the `@solana/web3.js` library.",
"imports": 'import { clusterApiUrl, Connection, PublicKey, LAMPORTS_PER_SOL,} from "@solana/web3.js";',
"code":"""const connection = new Connection(clusterApiUrl("devnet"), "confirmed");
const wallet = new PublicKey("nicktrLHhYzLmoVbuZQzHUTicd2sfP571orwo9jfc8c");
const balance = await connection.getBalance(wallet);
console.log(`Balance: ${balance / LAMPORTS_PER_SOL} SOL`);""",
"language":"typescript"
}
}
# expt_llm = "codestral"
# llm = ChatOllama(temperature=0, model=expt_llm)
# Post-processing
def format_docs(docs):
return "\n\n".join(doc.page_content for doc in docs)
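# Bind the structured `code` schema to the model; include_raw=True keeps the raw
# message and any parsing error alongside the parsed result for error handling.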
structured_llm = llm.with_structured_output(code, include_raw=True)
# Optional: Check for errors in case tool use is flaky
def check_llm_output(tool_output):
"""Check for parse error or failure to call the tool"""
# Error with parsing
if tool_output["parsing_error"]:
# Report back output and parsing errors
print("Parsing error!")
raw_output = str(tool_output["raw"].content)
error = tool_output["parsing_error"]
raise ValueError(
f"Error parsing your output! Be sure to invoke the tool. Output: {raw_output}. \n Parse error: {error}"
)
# Tool was not invoked
elif not tool_output["parsed"]:
print("Failed to invoke tool!")
raise ValueError(
"You did not use the provided tool! Be sure to invoke the tool to structure the output."
)
return tool_output
# Chain with output check
code_chain_raw = (
code_gen_prompt | structured_llm | check_llm_output
)
def insert_errors(inputs):
"""Insert errors for tool parsing in the messages"""
# Get errors
error = inputs["error"]
messages = inputs["messages"]
messages += [
(
"assistant",
f"Retry. You are required to fix the parsing errors: {error} \n\n You must invoke the provided tool.",
)
]
return {
"messages": messages,
"context": inputs["context"],
}
# This will be run as a fallback chain
fallback_chain = insert_errors | code_chain_raw
N = 3 # Max re-tries
code_gen_chain_re_try = code_chain_raw.with_fallbacks(
fallbacks=[fallback_chain] * N, exception_key="error"
)
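# Each fallback run receives the raised exception under the "error" key, which
# insert_errors folds back into the message history before retrying.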
def parse_output(solution):
"""When we add 'include_raw=True' to structured output,
it will return a dict w 'raw', 'parsed', 'parsing_error'."""
return solution["parsed"]
# Optional: With re-try to correct for failure to invoke tool
code_gen_chain = code_gen_chain_re_try | parse_output
# No re-try
# code_gen_chain = code_gen_prompt | structured_llm | parse_output
### Create State
class GraphState(TypedDict):
"""
Represents the state of our graph.
Attributes:
error : Binary flag for control flow to indicate whether test error was tripped
messages : With user question, error messages, reasoning
generation : Code solution
iterations : Number of tries
"""
error: str
messages: List
generation: List
iterations: int
### HELPER FUNCTIONS
def check_node_typescript_installation():
"""Check if Node.js and TypeScript are properly installed"""
try:
# Check Node.js
node_version = subprocess.run(["node", "--version"],
capture_output=True,
text=True)
if node_version.returncode != 0:
return False, "Node.js is not installed or not in PATH"
# Check TypeScript
tsc_version = subprocess.run(["npx", "tsc", "--version"],
capture_output=True,
text=True)
if tsc_version.returncode != 0:
return False, "TypeScript is not installed. Please run 'npm install -g typescript'"
return True, "Environment OK"
except Exception as e:
return False, f"Error checking environment: {str(e)}"
def create_temp_package_json():
"""Create a temporary package.json file for Node.js execution"""
package_json = {
"name": "temp-code-execution",
"version": "1.0.0",
"type": "module",
"dependencies": {
"typescript": "^4.9.5"
}
}
with open("package.json", "w") as f:
json.dump(package_json, f)
def run_javascript_code(code, is_typescript=False):
"""Execute JavaScript or TypeScript code using Node.js"""
# Check environment first
env_ok, env_message = check_node_typescript_installation()
if not env_ok:
return f"Environment Error: {env_message}"
try:
# Create necessary files
create_temp_package_json()
if is_typescript:
# For TypeScript, we need to compile first
with open("temp_code.ts", "w") as f:
f.write(code)
# Compile TypeScript
compile_process = subprocess.run(
["npx", "tsc", "temp_code.ts", "--module", "ES2020", "--target", "ES2020"],
capture_output=True,
text=True
)
            if compile_process.returncode != 0:
                # Surface the compilation failure to the caller
                return compile_process
            # Run the compiled JavaScript
            file_to_run = "temp_code.js"
else:
# For JavaScript, write directly to .js file
with open("temp_code.js", "w") as f:
f.write(code)
file_to_run = "temp_code.js"
# Execute the code using Node.js
result = subprocess.run(
["node", file_to_run],
capture_output=True,
text=True
)
# Clean up temporary files
cleanup_files = ["temp_code.js", "temp_code.ts", "package.json"]
for file in cleanup_files:
if os.path.exists(file):
os.remove(file)
# return result.stderr if result.stderr else result.stdout
return result
except Exception as e:
return f"Error: {e}"
def run_rust_code(code):
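    """Compile a Rust snippet with rustc and run the resulting binary,
    returning the program output or any compile/runtime errors."""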
with open('code.rs', 'w') as file:
file.write(code)
compile_process = subprocess.Popen(['rustc', 'code.rs'],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True)
compile_output, compile_errors = compile_process.communicate()
if compile_process.returncode != 0:
return f"Compilation Error: {compile_errors}"
run_process = subprocess.Popen(['./code'],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True)
run_output, run_errors = run_process.communicate()
return run_output if not run_errors else run_errors
### Parameter
# Max tries
max_iterations = 3
# Reflection toggle: set flag = "reflect" to route failed attempts through the reflect node
# flag = 'reflect'
flag = "do not reflect"
### Nodes
def generate(state: GraphState):
"""
Generate a code solution
Args:
state (dict): The current graph state
Returns:
state (dict): New key added to state, generation
"""
print("---GENERATING CODE SOLUTION---")
# State
messages = state["messages"]
iterations = state["iterations"]
error = state["error"]
question = state['messages'][-1][1]
# We have been routed back to generation with an error
if error == "yes":
messages += [
(
"user",
"Now, try again. Invoke the code tool to structure the output with a prefix, imports, and code block:",
)
]
    # Retrieve documentation relevant to the question and format it as context
    retrieved_docs = retriever.invoke(question)
    formatted_docs = format_docs(retrieved_docs)
    # Solution
    code_solution = code_gen_chain.invoke(
        {"context": formatted_docs, "messages": messages}
    )
messages += [
(
"assistant",
f"{code_solution.prefix} \n Imports: {code_solution.imports} \n Code: {code_solution.code}",
)
]
# Increment
iterations = iterations + 1
return {"generation": code_solution, "messages": messages, "iterations": iterations}
def code_check(state: GraphState):
"""
Check code
Args:
state (dict): The current graph state
Returns:
state (dict): New key added to state, error
"""
print("---CHECKING CODE---")
# State
messages = state["messages"]
code_solution = state["generation"]
iterations = state["iterations"]
# Get solution components
imports = code_solution.imports
code = code_solution.code
language = code_solution.language
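    # For Python, sanity-check the solution by exec()-ing the imports and then the full code block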
if language.lower()=="python":
# Check imports
try:
exec(imports)
except Exception as e:
print("---CODE IMPORT CHECK: FAILED---")
error_message = [("user", f"Your solution failed the import test: {e}")]
messages += error_message
return {
"generation": code_solution,
"messages": messages,
"iterations": iterations,
"error": "yes",
}
# Check execution
try:
exec(imports + "\n" + code)
except Exception as e:
print("---CODE BLOCK CHECK: FAILED---")
error_message = [("user", f"Your solution failed the code execution test: {e}")]
messages += error_message
return {
"generation": code_solution,
"messages": messages,
"iterations": iterations,
"error": "yes",
}
if language.lower()=="javascript":
        full_code = imports + "\n" + code
        result = run_javascript_code(full_code, is_typescript=False)
        # run_javascript_code returns a plain string on environment errors
        error_output = result if isinstance(result, str) else result.stderr
        if error_output:
            print("---JS CODE BLOCK CHECK: FAILED---")
            print(f"This is the error: {error_output}")
            error_message = [("user", f"Your JavaScript solution failed the code execution test: {error_output}")]
            messages += error_message
return {
"generation": code_solution,
"messages": messages,
"iterations": iterations,
"error": "yes",
}
if language.lower()=="typescript":
        full_code = imports + "\n" + code
        result = run_javascript_code(full_code, is_typescript=True)
        # run_javascript_code returns a plain string on environment errors
        error_output = result if isinstance(result, str) else result.stderr
        if error_output:
            print("---TS CODE BLOCK CHECK: FAILED---")
            print(f"This is the error: {error_output}")
            error_message = [("user", f"Your TypeScript solution failed the code execution test: {error_output}")]
            messages += error_message
return {
"generation": code_solution,
"messages": messages,
"iterations": iterations,
"error": "yes",
}
if language.lower()=="rust":
full_code = imports + "\n" + code
with open('code.rs', 'w') as file:
file.write(full_code)
compile_process = subprocess.Popen(['rustc', 'code.rs'],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True)
compile_output, compile_errors = compile_process.communicate()
        if compile_process.returncode != 0:
            # return f"Compilation Error: {compile_errors}"
            print("---RUST CODE BLOCK CHECK: COMPILATION FAILED---")
            print(f"This is the error: {compile_errors}")
            error_message = [("user", f"Your Rust solution failed the code compilation test: {compile_errors}")]
messages += error_message
return {
"generation": code_solution,
"messages": messages,
"iterations": iterations,
"error": "yes",
}
run_process = subprocess.Popen(['./code'],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True)
run_output, run_errors = run_process.communicate()
        if run_process.returncode != 0:
            print("---RUST CODE BLOCK CHECK: RUN FAILED---")
            print(f"This is the error: {run_errors}")
            error_message = [("user", f"Your Rust solution failed the code run test: {run_errors}")]
messages += error_message
return {
"generation": code_solution,
"messages": messages,
"iterations": iterations,
"error": "yes",
}
# return run_output if not run_errors else run_errors
elif language.lower() not in ["rust", "python", "typescript", "javascript"]:
# Can't test the code
print("---CANNOT TEST CODE: CODE NOT IN EXPECTED LANGUAGE---")
return {
"generation": code_solution,
"messages": messages,
"iterations": iterations,
"error": "no",
}
# No errors
print("---NO CODE TEST FAILURES---")
return {
"generation": code_solution,
"messages": messages,
"iterations": iterations,
"error": "no",
}
def reflect(state: GraphState):
"""
Reflect on errors
Args:
state (dict): The current graph state
Returns:
state (dict): New key added to state, generation
"""
print("---REFLECTING ON CODE SOLUTION ERRORS---")
# State
messages = state["messages"]
iterations = state["iterations"]
code_solution = state["generation"]
question = state['messages'][-1][1]
    # Retrieve documentation for context and prompt the model to reflect on the errors
    retrieved_docs = retriever.invoke(question)
    formatted_docs = format_docs(retrieved_docs)
    # Add reflection
    reflections = code_gen_chain.invoke(
        {"context": formatted_docs, "messages": messages}
    )
messages += [("assistant", f"Here are reflections on the error: {reflections}")]
return {"generation": code_solution, "messages": messages, "iterations": iterations}
### Edges
def decide_to_finish(state: GraphState):
"""
Determines whether to finish.
Args:
state (dict): The current graph state
Returns:
str: Next node to call
"""
error = state["error"]
iterations = state["iterations"]
if error == "no" or iterations == max_iterations:
print("---DECISION: FINISH---")
return "end"
else:
print("---DECISION: RE-TRY SOLUTION---")
if flag == "reflect":
return "reflect"
else:
return "generate"
def get_runnable():
workflow = StateGraph(GraphState)
# Define the nodes
workflow.add_node("generate", generate) # generation solution
workflow.add_node("check_code", code_check) # check code
workflow.add_node("reflect", reflect) # reflect
# Build graph
workflow.add_edge(START, "generate")
workflow.add_edge("generate", "check_code")
workflow.add_conditional_edges(
"check_code",
decide_to_finish,
{
"end": END,
"reflect": "reflect",
"generate": "generate",
},
)
workflow.add_edge("reflect", "generate")
# Remove the checkpointer for now since it's causing issues
code_assistant_app = workflow.compile()
# memory = AsyncSqliteSaver.from_conn_string(":memory:")
# code_assistant_app = workflow.compile(checkpointer=memory)
return code_assistant_app
# if __name__ == "__main__":
# graph = get_runnable()
# prompt = "How do I read from the solana network?"
# print(f'{graph.invoke({"messages": [("user", prompt)], "iterations": 0, "error": ""})}')