"""Retrieval-augmented Solana code assistant built on LangChain and LangGraph.

The module loads a pickled embedding model and a persisted Chroma collection,
builds a structured-output code generation chain with retry fallbacks, and
exposes :func:`get_runnable`, which compiles the LangGraph workflow.
"""

from langchain_openai import ChatOpenAI
from langchain_ollama import ChatOllama
from langchain_groq import ChatGroq
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
# from langchain_community.vectorstores import Chroma
from langchain_chroma import Chroma
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
import pickle
from langchain_core.prompts import ChatPromptTemplate
from pydantic import BaseModel, Field
from typing import List, Optional, Tuple, Union
from typing_extensions import TypedDict
from langgraph.graph import END, StateGraph, START
import subprocess
import time
import re
import json
import os
from dotenv import load_dotenv

load_dotenv()

# Set before any tokenizer is used.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# llm = ChatOllama(model="codestral")
expt_llm = "gpt-4o-mini"
llm = ChatOpenAI(temperature=0, model=expt_llm)

## Create retrieval from existing store
# Load an existing (saved) embedding model from a pickle file.
# SECURITY NOTE: pickle.load executes arbitrary code from the file; only load
# a pickle that this project produced itself.
# model_path = "/Model/embedding_model.pkl"
model_path = "embedding_model.pkl"
with open(model_path, "rb") as f:
    embedding_model = pickle.load(f)
print("Loaded embedding model successfully")

# Open the existing persisted vector store.
vectorstore = Chroma(
    collection_name="solcoder-chroma",
    embedding_function=embedding_model,
    persist_directory="solcoder-db-cpu",
)
retriever = vectorstore.as_retriever()

# Code generation prompt: system instructions with the retrieved documentation
# in {context}, followed by the running conversation in {messages}.
code_gen_prompt = ChatPromptTemplate(
    [
        (
            "system",
            " You are a coding assistant with expertise in Solana Blockchain ecosystem. \n"
            " Here is a set of Solana development documentation based on a user question: \n"
            " ------- \n {context} \n ------- \n"
            " Answer the user question based on the above provided documentation."
            " Ensure any code you provide can be executed with all required imports and variables \n"
            " defined. Structure your answer: 1) a prefix describing the code solution,"
            " 2) the imports, 3) the functioning code block. \n"
            " Invoke the code tool to structure the output correctly. \n"
            " Here is the user question:",
        ),
        ("placeholder", "{messages}"),
    ]
)


# Data model
# NOTE: the lowercase class name is kept on purpose; the prompt above refers
# to it as "the code tool".
class code(BaseModel):
    """Schema for code solutions to questions about Solana development."""

    prefix: str = Field(description="Description of the problem and approach")
    imports: str = Field(description="Code block import statements")
    code: str = Field(description="Code block not including import statements")
    language: str = Field(description="programming language the code is implemented")

    class Config:
        json_schema_extra = {
            "example": {
                "prefix": "To read the balance of an account from the Solana network, you can use the `@solana/web3.js` library.",
                "imports": 'import { clusterApiUrl, Connection, PublicKey, LAMPORTS_PER_SOL,} from "@solana/web3.js";',
                "code": (
                    'const connection = new Connection(clusterApiUrl("devnet"), "confirmed");\n'
                    'const wallet = new PublicKey("nicktrLHhYzLmoVbuZQzHUTicd2sfP571orwo9jfc8c");\n'
                    "const balance = await connection.getBalance(wallet);\n"
                    "console.log(`Balance: ${balance / LAMPORTS_PER_SOL} SOL`);"
                ),
                "language": "typescript",
            }
        }


# expt_llm = "codestral"
# llm = ChatOllama(temperature=0, model=expt_llm)


# Post-processing
def format_docs(docs):
    """Join the ``page_content`` of each retrieved document with blank lines."""
    return "\n\n".join(doc.page_content for doc in docs)


# Prompt piped into the LLM; the output is a dict with the keys 'raw',
# 'parsed' and 'parsing_error' because include_raw=True.
structured_llm = code_gen_prompt | llm.with_structured_output(code, include_raw=True)


# Optional: Check for errors in case tool use is flaky
def check_llm_output(tool_output):
    """Check for parse error or failure to call the tool.

    Args:
        tool_output (dict): Structured-output result with the keys 'raw',
            'parsed' and 'parsing_error'.

    Returns:
        dict: ``tool_output`` unchanged when a parsed result is present.

    Raises:
        ValueError: If parsing failed or the tool was not invoked.
    """
    # Error with parsing
    if tool_output["parsing_error"]:
        # Report back output and parsing errors
        print("Parsing error!")
        raw_output = str(tool_output["raw"].content)
        error = tool_output["parsing_error"]
        raise ValueError(
            f"Error parsing your output! Be sure to invoke the tool. Output: {raw_output}. \n Parse error: {error}"
        )

    # Tool was not invoked
    if not tool_output["parsed"]:
        print("Failed to invoke tool!")
        raise ValueError(
            "You did not use the provided tool! Be sure to invoke the tool to structure the output."
        )
    return tool_output


# Chain with output check.
# structured_llm already begins with code_gen_prompt, so the prompt must not be
# piped in a second time: doing so feeds the already-rendered prompt back into
# the template instead of the caller's {context}/{messages} mapping.
code_chain_raw = structured_llm | check_llm_output


def insert_errors(inputs):
    """Insert errors for tool parsing in the messages.

    Args:
        inputs (dict): Chain input with 'messages', 'context' and the
            exception stored under 'error'.

    Returns:
        dict: New chain input holding 'messages' and 'context'.
    """
    # Get errors
    error = inputs["error"]
    messages = inputs["messages"]
    # Extends the caller's list in place, so retry notes accumulate across
    # successive fallbacks and remain visible to the caller afterwards.
    messages += [
        (
            "assistant",
            f"Retry. You are required to fix the parsing errors: {error} \n\n You must invoke the provided tool.",
        )
    ]
    return {
        "messages": messages,
        "context": inputs["context"],
    }


# This will be run as a fallback chain
fallback_chain = insert_errors | code_chain_raw
N = 3  # Max re-tries
code_gen_chain_re_try = code_chain_raw.with_fallbacks(
    fallbacks=[fallback_chain] * N, exception_key="error"
)


def parse_output(solution):
    """Return the 'parsed' entry of a structured-output result.

    With ``include_raw=True`` the structured output is a dict with the keys
    'raw', 'parsed' and 'parsing_error'.
    """
    return solution["parsed"]


# Optional: With re-try to correct for failure to invoke tool
code_gen_chain = code_gen_chain_re_try | parse_output

# No re-try
# code_gen_chain = structured_llm | parse_output


### Create State
class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        error : "yes" / "no" flag for control flow to indicate whether test error was tripped
        messages : With user question, error messages, reasoning
        generation : Code solution (a ``code`` instance)
        iterations : Number of tries
    """

    error: str
    messages: List
    generation: code
    iterations: int


### HELPER FUNCTIONS

# Scratch files created in the current working directory by run_javascript_code.
_JS_TEMP_FILES = ("temp_code.js", "temp_code.ts", "package.json")


def check_node_typescript_installation():
    """Check if Node.js and TypeScript are properly installed.

    Returns:
        tuple: ``(ok, message)`` where ``ok`` is a bool and ``message``
        describes the problem, or is "Environment OK".
    """
    try:
        # Check Node.js
        node_version = subprocess.run(
            ["node", "--version"], capture_output=True, text=True
        )
        if node_version.returncode != 0:
            return False, "Node.js is not installed or not in PATH"

        # Check TypeScript
        tsc_version = subprocess.run(
            ["npx", "tsc", "--version"], capture_output=True, text=True
        )
        if tsc_version.returncode != 0:
            return False, "TypeScript is not installed. Please run 'npm install -g typescript'"

        return True, "Environment OK"
    except (OSError, subprocess.SubprocessError) as e:
        # A missing executable raises FileNotFoundError, a subclass of OSError.
        return False, f"Error checking environment: {e}"


def create_temp_package_json():
    """Create a temporary package.json file for Node.js execution.

    NOTE(review): the file is written to the current working directory and
    overwrites any package.json already there.
    """
    package_json = {
        "name": "temp-code-execution",
        "version": "1.0.0",
        "type": "module",
        "dependencies": {"typescript": "^4.9.5"},
    }
    with open("package.json", "w") as f:
        json.dump(package_json, f)


def _remove_js_temp_files():
    """Best-effort removal of the scratch files used for JS/TS execution."""
    for path in _JS_TEMP_FILES:
        try:
            os.remove(path)
        except OSError:
            # A missing or undeletable scratch file must not mask the result.
            pass


def run_javascript_code(code, is_typescript=False):
    """Execute JavaScript or TypeScript code using Node.js.

    Args:
        code (str): Source code to run.
        is_typescript (bool): Compile with ``tsc`` before running.

    Returns:
        subprocess.CompletedProcess | str: The Node.js run result, or the
        ``tsc`` result when compilation exits non-zero. A string starting with
        "Environment Error:" or "Error:" is returned when the toolchain is
        missing or an exception occurred.
    """
    # Check environment first
    env_ok, env_message = check_node_typescript_installation()
    if not env_ok:
        return f"Environment Error: {env_message}"

    try:
        # Create necessary files
        create_temp_package_json()

        if is_typescript:
            # For TypeScript, we need to compile first
            with open("temp_code.ts", "w") as f:
                f.write(code)
            compile_process = subprocess.run(
                ["npx", "tsc", "temp_code.ts", "--module", "ES2020", "--target", "ES2020"],
                capture_output=True,
                text=True,
            )
            # Only stop here on a failed compile; previously this returned
            # unconditionally, so the compiled file never ran and the scratch
            # files were never removed.
            if compile_process.returncode != 0:
                return compile_process
        else:
            # For JavaScript, write directly to .js file
            with open("temp_code.js", "w") as f:
                f.write(code)

        # Both paths end up with temp_code.js; execute it using Node.js
        return subprocess.run(
            ["node", "temp_code.js"], capture_output=True, text=True
        )
    except Exception as e:
        return f"Error: {e}"
    finally:
        # Clean up temporary files on every exit path
        _remove_js_temp_files()


def _compile_and_run_rust(source):
    """Write ``source`` to code.rs, compile it with rustc and run ./code.

    Returns:
        tuple: ``(compiled, executed)`` as ``CompletedProcess`` objects;
        ``executed`` is None when compilation exits non-zero.
    """
    with open("code.rs", "w") as file:
        file.write(source)
    compiled = subprocess.run(["rustc", "code.rs"], capture_output=True, text=True)
    if compiled.returncode != 0:
        return compiled, None
    executed = subprocess.run(["./code"], capture_output=True, text=True)
    return compiled, executed


def run_rust_code(code):
    """Compile and run Rust code.

    Returns:
        str: "Compilation Error: ..." when rustc exits non-zero; otherwise the
        program's stderr if it wrote any, else its stdout.
    """
    compiled, executed = _compile_and_run_rust(code)
    if executed is None:
        return f"Compilation Error: {compiled.stderr}"
    return executed.stdout if not executed.stderr else executed.stderr


def _execution_failure(
    outcome: Union[str, "subprocess.CompletedProcess"]
) -> Optional[str]:
    """Return error text for a run_javascript_code result, or None on success.

    Handles both return shapes: an error string, or a ``CompletedProcess``.
    """
    if isinstance(outcome, str):
        return outcome
    if outcome.stderr:
        return outcome.stderr
    if outcome.returncode != 0:
        # Presumably the diagnostics are on stdout when stderr is empty (tsc
        # appears to report this way) -- TODO confirm.
        return outcome.stdout or f"process exited with status {outcome.returncode}"
    return None


### Parameter

# Max tries
max_iterations = 3
# Reflect
# flag = 'reflect'
flag = "do not reflect"


### Nodes
def generate(state: GraphState):
    """
    Generate a code solution

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, generation
    """
    print("---GENERATING CODE SOLUTION---")

    # State (copy the list so the incoming state object is not mutated)
    messages = list(state["messages"])
    iterations = state.get("iterations", 0)
    error = state.get("error", "")
    # The retrieval query is the text of the most recent message, read before
    # the retry instruction below is appended.
    question = messages[-1][1]

    # We have been routed back to generation with an error
    if error == "yes":
        messages.append(
            (
                "user",
                "Now, try again. Invoke the code tool to structure the output with a prefix, imports, and code block:",
            )
        )

    retrieved_docs = retriever.invoke(question)
    formated_docs = format_docs(retrieved_docs)

    # Solution
    code_solution = code_gen_chain.invoke(
        {"context": formated_docs, "messages": messages}
    )
    messages.append(
        (
            "assistant",
            f"{code_solution.prefix} \n Imports: {code_solution.imports} \n Code: {code_solution.code}",
        )
    )

    # Increment
    return {
        "generation": code_solution,
        "messages": messages,
        "iterations": iterations + 1,
    }


def code_check(state: GraphState):
    """
    Check code

    Python solutions are executed in-process; JavaScript/TypeScript run through
    Node.js; Rust is compiled with rustc and run. Other languages are passed
    through untested.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, error
    """
    print("---CHECKING CODE---")

    # State
    messages = list(state["messages"])
    code_solution = state["generation"]
    iterations = state["iterations"]

    def outcome(error, feedback=None):
        """Build the node result, appending feedback as a user message."""
        if feedback is not None:
            messages.append(("user", feedback))
        return {
            "generation": code_solution,
            "messages": messages,
            "iterations": iterations,
            "error": error,
        }

    # Get solution components
    imports = code_solution.imports
    language = code_solution.language.lower()
    full_code = imports + "\n" + code_solution.code

    if language == "python":
        # SECURITY NOTE: this executes LLM-generated code in-process.
        # Each exec gets a fresh namespace so the generated code cannot clobber
        # this module's globals, and functions it defines can see its imports.
        # Check imports
        try:
            exec(imports, {"__name__": "__main__"})
        except Exception as e:
            print("---CODE IMPORT CHECK: FAILED---")
            return outcome("yes", f"Your solution failed the import test: {e}")

        # Check execution
        try:
            exec(full_code, {"__name__": "__main__"})
        except Exception as e:
            print("---CODE BLOCK CHECK: FAILED---")
            return outcome(
                "yes", f"Your solution failed the code execution test: {e}"
            )

    elif language in ("javascript", "typescript"):
        is_typescript = language == "typescript"
        failure = _execution_failure(
            run_javascript_code(full_code, is_typescript=is_typescript)
        )
        if failure:
            tag = "TS" if is_typescript else "JS"
            print(f"---{tag} CODE BLOCK CHECK: FAILED---")
            print(f"This is the error:{failure}")
            return outcome(
                "yes",
                f"Your {language} solution failed the code execution test: {failure}",
            )

    elif language == "rust":
        compiled, executed = _compile_and_run_rust(full_code)
        if executed is None:
            # Judge the compile by its exit status and report the captured text.
            print("---RUST CODE BLOCK CHECK: COMPILATION FAILED---")
            print(f"This is the error:{compiled.stderr}")
            return outcome(
                "yes",
                f"Your rust solution failed the code compilation test: {compiled.stderr}",
            )
        if executed.returncode != 0 or executed.stderr:
            run_errors = (
                executed.stderr
                or f"process exited with status {executed.returncode}"
            )
            print("---RUST CODE BLOCK CHECK: RUN FAILED---")
            print(f"This is the error:{run_errors}")
            return outcome(
                "yes", f"Your rust solution failed the code run test: {run_errors}"
            )

    else:
        # Can't test the code
        print("---CANNOT TEST CODE: CODE NOT IN EXPECTED LANGUAGE---")
        return outcome("no")

    # No errors
    print("---NO CODE TEST FAILURES---")
    return outcome("no")


def reflect(state: GraphState):
    """
    Reflect on errors

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, generation
    """
    print("---REFLECTING ON CODE SOLUTION ERRORS---")

    # State (copy the list so the incoming state object is not mutated)
    messages = list(state["messages"])
    iterations = state["iterations"]
    code_solution = state["generation"]
    question = messages[-1][1]

    # Prompt reflection
    retrieved_docs = retriever.invoke(question)
    formated_docs = format_docs(retrieved_docs)

    # Add reflection
    reflections = code_gen_chain.invoke(
        {"context": formated_docs, "messages": messages}
    )
    messages.append(
        ("assistant", f"Here are reflections on the error: {reflections}")
    )
    return {"generation": code_solution, "messages": messages, "iterations": iterations}


### Edges
def decide_to_finish(state: GraphState):
    """
    Determines whether to finish.

    Args:
        state (dict): The current graph state

    Returns:
        str: Next node to call
    """
    error = state["error"]
    iterations = state["iterations"]

    # ">=" rather than "==" so a count that overshoots the limit still stops.
    if error == "no" or iterations >= max_iterations:
        print("---DECISION: FINISH---")
        return "end"

    print("---DECISION: RE-TRY SOLUTION---")
    if flag == "reflect":
        return "reflect"
    return "generate"


def get_runnable():
    """Build and compile the LangGraph workflow.

    Only the ``generate`` node is wired in (START -> generate -> END); the
    code-check / reflect loop is kept below as commented-out wiring.

    Returns:
        The compiled graph, built without a checkpointer.
    """
    workflow = StateGraph(GraphState)

    # Define the nodes
    workflow.add_node("generate", generate)  # generation solution
    # workflow.add_node("check_code", code_check)  # check code
    # workflow.add_node("reflect", reflect)  # reflect

    # Build graph
    workflow.add_edge(START, "generate")
    workflow.add_edge("generate", END)
    # workflow.add_edge("generate", "check_code")
    # workflow.add_conditional_edges(
    #     "check_code",
    #     decide_to_finish,
    #     {
    #         "end": END,
    #         "reflect": "reflect",
    #         "generate": "generate",
    #     },
    # )
    # workflow.add_edge("reflect", "generate")

    # Remove the checkpointer for now since it's causing issues
    code_assistant_app = workflow.compile()
    # memory = AsyncSqliteSaver.from_conn_string(":memory:")
    # code_assistant_app = workflow.compile(checkpointer=memory)
    return code_assistant_app


# if __name__ == "__main__":
#     graph = get_runnable()
#     prompt = "How do I read from the solana network?"
#     print(f'{graph.invoke({"messages": [("user", prompt)], "iterations": 0, "error": ""})}')