from langchain_openai import ChatOpenAI
from langchain_ollama import ChatOllama
from langchain_groq import ChatGroq
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
# from langchain_community.vectorstores import Chroma
from langchain_chroma import Chroma
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
import pickle
from langchain_core.prompts import ChatPromptTemplate
from pydantic import BaseModel, Field
from typing import List
from typing_extensions import TypedDict
from langgraph.graph import END, StateGraph, START
import subprocess
import time
import re
import json
import os
from dotenv import load_dotenv

load_dotenv()

# Silence HuggingFace tokenizers fork warnings
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# llm = ChatOllama(model="codestral")
expt_llm = "gpt-4o-mini"
llm = ChatOpenAI(temperature=0, model=expt_llm)
## Create retrieval from existing store
# Load an existing (saved) embedding model from a pickle file
model_path = "/Model/embedding_model.pkl"
with open(model_path, 'rb') as f:
    embedding_model = pickle.load(f)
print("Loaded embedding model successfully")

# Load the existing vectorstore
vectorstore = Chroma(
    collection_name="solcoder-chroma",
    embedding_function=embedding_model,
    persist_directory="solcoder-db",
)
retriever = vectorstore.as_retriever()
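
# Hedged sanity check (commented out so it does not run at import time): query the
# retriever directly to confirm the persisted "solcoder-chroma" collection returns
# documents. The query string below is illustrative, not part of the pipeline.
# docs = retriever.invoke("How do I read an account balance on Solana?")
# print(f"Retrieved {len(docs)} documents")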
# Code generation prompt
code_gen_prompt = ChatPromptTemplate(
    [
        (
            "system",
            """<instructions> You are a coding assistant with expertise in the Solana Blockchain ecosystem. \n
    Here is a set of Solana development documentation relevant to the user question: \n ------- \n {context} \n ------- \n
    Answer the user question based on the documentation provided above. Ensure any code you provide can be executed with all required imports and variables \n
    defined. Structure your answer: 1) a prefix describing the code solution, 2) the imports, 3) the functioning code block. \n
    Invoke the code tool to structure the output correctly. </instructions> \n Here is the user question:""",
        ),
        ("placeholder", "{messages}"),
    ]
)
# Data model
class code(BaseModel):
    """Schema for code solutions to questions about Solana development."""

    prefix: str = Field(description="Description of the problem and approach")
    imports: str = Field(description="Code block import statements")
    code: str = Field(description="Code block not including import statements")
    language: str = Field(description="Programming language the code is implemented in")

    class Config:
        json_schema_extra = {
            "example": {
                "prefix": "To read the balance of an account from the Solana network, you can use the `@solana/web3.js` library.",
                "imports": 'import { clusterApiUrl, Connection, PublicKey, LAMPORTS_PER_SOL } from "@solana/web3.js";',
                "code": """const connection = new Connection(clusterApiUrl("devnet"), "confirmed");
const wallet = new PublicKey("nicktrLHhYzLmoVbuZQzHUTicd2sfP571orwo9jfc8c");
const balance = await connection.getBalance(wallet);
console.log(`Balance: ${balance / LAMPORTS_PER_SOL} SOL`);""",
                "language": "typescript",
            }
        }
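
# Hedged illustration (commented out): constructing the schema by hand to show the
# fields the LLM must populate via structured output. All values are made up.
# sample_solution = code(
#     prefix="Reads an account balance on Solana devnet.",
#     imports='import { Connection, clusterApiUrl } from "@solana/web3.js";',
#     code='const connection = new Connection(clusterApiUrl("devnet"), "confirmed");',
#     language="typescript",
# )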
# expt_llm = "codestral"
# llm = ChatOllama(temperature=0, model=expt_llm)

# Post-processing
def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)


# LLM with structured output (the prompt is piped in below, in code_chain_raw)
structured_llm = llm.with_structured_output(code, include_raw=True)


# Optional: Check for errors in case tool use is flaky
def check_llm_output(tool_output):
    """Check for a parse error or failure to call the tool."""
    # Error with parsing
    if tool_output["parsing_error"]:
        # Report back the raw output and the parsing error
        print("Parsing error!")
        raw_output = str(tool_output["raw"].content)
        error = tool_output["parsing_error"]
        raise ValueError(
            f"Error parsing your output! Be sure to invoke the tool. Output: {raw_output}. \n Parse error: {error}"
        )
    # Tool was not invoked
    elif not tool_output["parsed"]:
        print("Failed to invoke tool!")
        raise ValueError(
            "You did not use the provided tool! Be sure to invoke the tool to structure the output."
        )
    return tool_output


# Chain with output check
code_chain_raw = code_gen_prompt | structured_llm | check_llm_output
def insert_errors(inputs):
    """Insert tool-parsing errors into the messages."""
    # Get errors
    error = inputs["error"]
    messages = inputs["messages"]
    messages += [
        (
            "assistant",
            f"Retry. You are required to fix the parsing errors: {error} \n\n You must invoke the provided tool.",
        )
    ]
    return {
        "messages": messages,
        "context": inputs["context"],
    }


# This will be run as a fallback chain
fallback_chain = insert_errors | code_chain_raw
N = 3  # Max re-tries
code_gen_chain_re_try = code_chain_raw.with_fallbacks(
    fallbacks=[fallback_chain] * N, exception_key="error"
)


def parse_output(solution):
    """When 'include_raw=True' is passed to structured output,
    it returns a dict with 'raw', 'parsed', and 'parsing_error'."""
    return solution["parsed"]


# Optional: With re-try to correct for failure to invoke the tool
code_gen_chain = code_gen_chain_re_try | parse_output

# No re-try
# code_gen_chain = code_gen_prompt | structured_llm | parse_output
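
# Hedged example (commented out): invoking the chain directly, outside the graph.
# The question and the empty context string are placeholders for illustration.
# example_solution = code_gen_chain.invoke(
#     {"context": "", "messages": [("user", "How do I create a keypair with @solana/web3.js?")]}
# )
# print(example_solution.prefix)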
### Create State
class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        error : Binary flag ("yes"/"no") for control flow, indicating whether a test error was tripped
        messages : User question, error messages, and reasoning
        generation : Code solution
        iterations : Number of tries
    """

    error: str
    messages: List
    generation: code
    iterations: int
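
# Hedged illustration of the initial state handed to the graph; "generation" is
# filled in by the generate node, so it is omitted here. Values are placeholders.
# initial_state = {
#     "messages": [("user", "How do I read from the Solana network?")],
#     "iterations": 0,
#     "error": "",
# }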
### HELPER FUNCTIONS
def check_node_typescript_installation():
    """Check if Node.js and TypeScript are properly installed."""
    try:
        # Check Node.js
        node_version = subprocess.run(
            ["node", "--version"], capture_output=True, text=True
        )
        if node_version.returncode != 0:
            return False, "Node.js is not installed or not in PATH"
        # Check TypeScript
        tsc_version = subprocess.run(
            ["npx", "tsc", "--version"], capture_output=True, text=True
        )
        if tsc_version.returncode != 0:
            return False, "TypeScript is not installed. Please run 'npm install -g typescript'"
        return True, "Environment OK"
    except Exception as e:
        return False, f"Error checking environment: {str(e)}"


def create_temp_package_json():
    """Create a temporary package.json file for Node.js execution."""
    package_json = {
        "name": "temp-code-execution",
        "version": "1.0.0",
        "type": "module",
        "dependencies": {
            "typescript": "^4.9.5"
        },
    }
    with open("package.json", "w") as f:
        json.dump(package_json, f)
def run_javascript_code(code, is_typescript=False):
    """Execute JavaScript with Node.js, or compile-check TypeScript with tsc."""
    # Check environment first
    env_ok, env_message = check_node_typescript_installation()
    if not env_ok:
        return f"Environment Error: {env_message}"
    try:
        # Create necessary files
        create_temp_package_json()
        if is_typescript:
            # For TypeScript, only compile (type-check); the compiler result is
            # returned so the caller can inspect the diagnostics and exit code
            with open("temp_code.ts", "w") as f:
                f.write(code)
            result = subprocess.run(
                ["npx", "tsc", "temp_code.ts", "--module", "ES2020", "--target", "ES2020"],
                capture_output=True,
                text=True,
            )
        else:
            # For JavaScript, write directly to a .js file and execute it with Node.js
            with open("temp_code.js", "w") as f:
                f.write(code)
            result = subprocess.run(
                ["node", "temp_code.js"],
                capture_output=True,
                text=True,
            )
        # Clean up temporary files
        cleanup_files = ["temp_code.js", "temp_code.ts", "package.json"]
        for file in cleanup_files:
            if os.path.exists(file):
                os.remove(file)
        return result
    except Exception as e:
        return f"Error: {e}"
def run_rust_code(code):
    """Compile Rust code with rustc and run the resulting binary."""
    with open('code.rs', 'w') as file:
        file.write(code)
    compile_process = subprocess.Popen(
        ['rustc', 'code.rs'],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
    )
    compile_output, compile_errors = compile_process.communicate()
    if compile_process.returncode != 0:
        return f"Compilation Error: {compile_errors}"
    run_process = subprocess.Popen(
        ['./code'],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
    )
    run_output, run_errors = run_process.communicate()
    return run_output if not run_errors else run_errors
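
# Hedged smoke test (assumes rustc is installed); commented out so it does not
# run on import.
# print(run_rust_code('fn main() { println!("hello from rustc"); }'))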
### Parameter
# Max tries
max_iterations = 3
# Reflect
# flag = 'reflect'
flag = "do not reflect"
### Nodes
def generate(state: GraphState):
    """
    Generate a code solution.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, generation
    """
    print("---GENERATING CODE SOLUTION---")

    # State
    messages = state["messages"]
    iterations = state["iterations"]
    error = state["error"]
    question = state["messages"][-1][1]

    # We have been routed back to generation with an error
    if error == "yes":
        messages += [
            (
                "user",
                "Now, try again. Invoke the code tool to structure the output with a prefix, imports, and code block:",
            )
        ]

    # Retrieve documentation relevant to the question
    retrieved_docs = retriever.invoke(question)
    formatted_docs = format_docs(retrieved_docs)

    # Solution
    code_solution = code_gen_chain.invoke(
        {"context": formatted_docs, "messages": messages}
    )
    messages += [
        (
            "assistant",
            f"{code_solution.prefix} \n Imports: {code_solution.imports} \n Code: {code_solution.code}",
        )
    ]

    # Increment
    iterations = iterations + 1
    return {"generation": code_solution, "messages": messages, "iterations": iterations}
def code_check(state: GraphState):
    """
    Check the generated code.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, error
    """
    print("---CHECKING CODE---")

    # State
    messages = state["messages"]
    code_solution = state["generation"]
    iterations = state["iterations"]

    # Get solution components
    imports = code_solution.imports
    code = code_solution.code
    language = code_solution.language

    if language.lower() == "python":
        # Check imports
        try:
            exec(imports)
        except Exception as e:
            print("---CODE IMPORT CHECK: FAILED---")
            error_message = [("user", f"Your solution failed the import test: {e}")]
            messages += error_message
            return {
                "generation": code_solution,
                "messages": messages,
                "iterations": iterations,
                "error": "yes",
            }
        # Check execution
        try:
            exec(imports + "\n" + code)
        except Exception as e:
            print("---CODE BLOCK CHECK: FAILED---")
            error_message = [("user", f"Your solution failed the code execution test: {e}")]
            messages += error_message
            return {
                "generation": code_solution,
                "messages": messages,
                "iterations": iterations,
                "error": "yes",
            }

    if language.lower() == "javascript":
        full_code = imports + "\n" + code
        result = run_javascript_code(full_code, is_typescript=False)
        # run_javascript_code returns a plain string on environment/setup errors
        if isinstance(result, str):
            error_output = result
        elif result.returncode != 0:
            error_output = result.stderr or result.stdout
        else:
            error_output = ""
        if error_output:
            print("---JS CODE BLOCK CHECK: FAILED---")
            print(f"This is the error: {error_output}")
            error_message = [("user", f"Your JavaScript solution failed the code execution test: {error_output}")]
            messages += error_message
            return {
                "generation": code_solution,
                "messages": messages,
                "iterations": iterations,
                "error": "yes",
            }

    if language.lower() == "typescript":
        full_code = imports + "\n" + code
        result = run_javascript_code(full_code, is_typescript=True)
        # tsc reports diagnostics on stdout, so check the exit code rather than stderr
        if isinstance(result, str):
            error_output = result
        elif result.returncode != 0:
            error_output = result.stdout or result.stderr
        else:
            error_output = ""
        if error_output:
            print("---TS CODE BLOCK CHECK: FAILED---")
            print(f"This is the error: {error_output}")
            error_message = [("user", f"Your TypeScript solution failed the code compilation test: {error_output}")]
            messages += error_message
            return {
                "generation": code_solution,
                "messages": messages,
                "iterations": iterations,
                "error": "yes",
            }

    if language.lower() == "rust":
        full_code = imports + "\n" + code
        with open('code.rs', 'w') as file:
            file.write(full_code)
        compile_process = subprocess.Popen(
            ['rustc', 'code.rs'],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
        )
        compile_output, compile_errors = compile_process.communicate()
        # communicate() consumes the pipes, so check the return code, not .stderr
        if compile_process.returncode != 0:
            print("---RUST CODE BLOCK CHECK: COMPILATION FAILED---")
            print(f"This is the error: {compile_errors}")
            error_message = [("user", f"Your Rust solution failed the code compilation test: {compile_errors}")]
            messages += error_message
            return {
                "generation": code_solution,
                "messages": messages,
                "iterations": iterations,
                "error": "yes",
            }
        run_process = subprocess.Popen(
            ['./code'],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
        )
        run_output, run_errors = run_process.communicate()
        if run_process.returncode != 0:
            print("---RUST CODE BLOCK CHECK: RUN FAILED---")
            print(f"This is the error: {run_errors}")
            error_message = [("user", f"Your Rust solution failed the code run test: {run_errors}")]
            messages += error_message
            return {
                "generation": code_solution,
                "messages": messages,
                "iterations": iterations,
                "error": "yes",
            }

    if language.lower() not in ["rust", "python", "typescript", "javascript"]:
        # Can't test the code
        print("---CANNOT TEST CODE: CODE NOT IN EXPECTED LANGUAGE---")
        return {
            "generation": code_solution,
            "messages": messages,
            "iterations": iterations,
            "error": "no",
        }

    # No errors
    print("---NO CODE TEST FAILURES---")
    return {
        "generation": code_solution,
        "messages": messages,
        "iterations": iterations,
        "error": "no",
    }
def reflect(state: GraphState):
    """
    Reflect on errors.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, generation
    """
    print("---REFLECTING ON CODE SOLUTION ERRORS---")

    # State
    messages = state["messages"]
    iterations = state["iterations"]
    code_solution = state["generation"]
    question = state["messages"][-1][1]

    # Retrieve documentation relevant to the question
    retrieved_docs = retriever.invoke(question)
    formatted_docs = format_docs(retrieved_docs)

    # Add reflection
    reflections = code_gen_chain.invoke(
        {"context": formatted_docs, "messages": messages}
    )
    messages += [("assistant", f"Here are reflections on the error: {reflections}")]
    return {"generation": code_solution, "messages": messages, "iterations": iterations}
### Edges
def decide_to_finish(state: GraphState):
    """
    Determines whether to finish.

    Args:
        state (dict): The current graph state

    Returns:
        str: Next node to call
    """
    error = state["error"]
    iterations = state["iterations"]
    if error == "no" or iterations == max_iterations:
        print("---DECISION: FINISH---")
        return "end"
    else:
        print("---DECISION: RE-TRY SOLUTION---")
        if flag == "reflect":
            return "reflect"
        else:
            return "generate"
def get_runnable():
    workflow = StateGraph(GraphState)

    # Define the nodes
    workflow.add_node("generate", generate)  # generate solution
    workflow.add_node("check_code", code_check)  # check code
    workflow.add_node("reflect", reflect)  # reflect

    # Build graph
    workflow.add_edge(START, "generate")
    workflow.add_edge("generate", "check_code")
    workflow.add_conditional_edges(
        "check_code",
        decide_to_finish,
        {
            "end": END,
            "reflect": "reflect",
            "generate": "generate",
        },
    )
    workflow.add_edge("reflect", "generate")

    # Remove the checkpointer for now since it's causing issues
    code_assistant_app = workflow.compile()
    # memory = AsyncSqliteSaver.from_conn_string(":memory:")
    # code_assistant_app = workflow.compile(checkpointer=memory)
    return code_assistant_app
# if __name__ == "__main__":
#     graph = get_runnable()
#     prompt = "How do I read from the solana network?"
#     print(f'{graph.invoke({"messages": [("user", prompt)], "iterations": 0, "error": ""})}')
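
# Hedged alternative (assumes the graph compiles as above): stream node-by-node
# updates instead of a single invoke call.
# graph = get_runnable()
# for event in graph.stream(
#     {"messages": [("user", "How do I transfer SOL between two wallets?")], "iterations": 0, "error": ""}
# ):
#     print(event)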