|
from langchain.agents import tool |
|
from typing import Literal |
|
import json |
|
from PIL import Image |
|
|
|
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, ToolMessage |
|
from langgraph.graph import END, MessagesState |
|
|
|
from render_mermaid import render_mermaid |
|
from langchain_community.document_loaders import GithubFileLoader |
|
|
|
|
|
from prompts import * |
|
from constants import file_extensions |
|
from __init__ import llm, llm_structured |
|
|
|
|
|
class GraphState(MessagesState):
    """Shared LangGraph state for the codebase-to-Mermaid pipeline in this file.

    Inherits the ``messages`` channel from ``MessagesState``.
    """

    # Planner output, stored as "# Plan\n<planner reply>" by get_plan_for_codebase
    # and parsed by parse_plan.
    working_knowledge: str

    # Every loaded file path (each document's metadata["path"]).
    all_files: list[str]

    # Paths still queued for explore_file; router routes to Mermaid generation
    # once this list is empty.
    remaining_files: list[str]

    # Paths whose summaries have already been appended to explored_summaries.
    explored_files: list[str]

    # Accumulated "* File Path: ...\n\tSummary: ..." text built by explore_file.
    explored_summaries: str

    # Summary cache: file path -> JSON string {"FilePath": ..., "Summary": ...}.
    document_summaries_store: dict

    # Loaded source documents; code here reads .page_content and .metadata["path"]
    # (presumably the output of load_github_codebase -- verify against graph setup).
    documents: list

    # Rendered diagram image set by extract_mermaid_and_generate_graph.
    final_graph: Image.Image
|
|
|
|
|
def load_github_codebase(
    repo: str,
    branch: str,
    github_api_url: str = "https://api.github.com",
):
    """Load the files of a GitHub repository branch that have a supported extension.

    Args:
        repo: Repository identifier passed to ``GithubFileLoader`` (presumably
            "owner/name" -- confirm against the loader's docs).
        branch: Branch to read files from.
        github_api_url: Base URL of the GitHub REST API. The default is the public
            API; override it for a GitHub Enterprise instance.

    Returns:
        The documents returned by ``GithubFileLoader.load()``. Only paths ending
        in one of ``file_extensions`` are included, decoded as UTF-8.
    """
    # Build the suffix tuple once, not on every file_filter invocation.
    suffixes = tuple(file_extensions)

    loader = GithubFileLoader(
        repo=repo,
        branch=branch,
        github_api_url=github_api_url,
        file_filter=lambda file_path: file_path.endswith(suffixes),
        encoding="utf-8",
    )
    return loader.load()
|
|
|
|
|
def get_file_content_summary(file_path: str, state: GraphState):
    """Returns the functional summary of a file. Please note that the file_path should not be null.

    Args:
        file_path: The path of the file for which the summary is required.
        state: Graph state providing ``documents`` and the
            ``document_summaries_store`` cache.

    Returns:
        A JSON string ``{"FilePath": ..., "Summary": ...}``. A cached summary is
        returned unchanged; a freshly generated one is cached before returning.

    Raises:
        ValueError: If no loaded document has ``metadata["path"] == file_path``.
    """
    summary = check_summary_in_store(file_path, state)
    if summary:
        return summary

    doc_content = next(
        (
            document.page_content
            for document in state["documents"]
            if document.metadata["path"] == file_path
        ),
        None,
    )
    if doc_content is None:
        # Nothing to summarize: fail with a clear message rather than an
        # UnboundLocalError further down.
        raise ValueError(f"No loaded document matches file path {file_path!r}")

    summary = llm.invoke(
        [SystemMessage(content=summarizer_prompt), HumanMessage(content=doc_content)]
    ).content
    summary = json.dumps({"FilePath": file_path, "Summary": summary})
    save_summary_in_store(file_path, summary, state)
    return summary
|
|
|
|
|
def explore_file(state: GraphState):
    """Summarize the next queued file and record it in the exploration state.

    Pops one path from the end of ``remaining_files`` and obtains its summary
    via get_file_content_summary (cached or freshly generated). The summary is
    appended to ``explored_summaries`` and the path to ``explored_files``.

    A path is dropped from the queue without being summarized when it was
    already explored, or when it matches neither the summary cache nor any
    loaded document (e.g. a file the planning LLM invented).

    Returns:
        The same ``state`` dict, mutated in place.
    """
    file_path = state["remaining_files"].pop()

    # Skip duplicates before doing any work.
    if file_path in state["explored_files"]:
        return state

    # The queue comes from an LLM-generated plan and may name files that were
    # never loaded; summarizing those would fail, so drop them and move on.
    is_known = file_path in state["document_summaries_store"] or any(
        document.metadata["path"] == file_path for document in state["documents"]
    )
    if not is_known:
        return state

    summary_dict = json.loads(get_file_content_summary(file_path, state))
    state["explored_summaries"] += (
        f"* File Path: {summary_dict['FilePath']}\n"
        f"\tSummary: {summary_dict['Summary']}\n\n"
    )
    state["explored_files"].append(file_path)
    return state
|
|
|
|
|
@tool
def generate_final_mermaid_code():
    """Generate the final mermaid code for the codebase once all the files are explored and the working knowledge is complete."""
    # NOTE: LangChain's @tool uses the docstring above as the tool description
    # shown to the LLM, so it is runtime behavior -- edit it deliberately.
    # The returned sentinel is the exact string need_to_update_working_knowledge
    # compares message content against to route to Mermaid generation.
    return "generate_mermaid_code"
|
|
|
|
|
def check_summary_in_store(file_path: str, state: GraphState):
    """Look up a previously cached summary.

    Returns the stored value for ``file_path`` from
    ``state["document_summaries_store"]``, or ``None`` when it is absent.
    """
    return state["document_summaries_store"].get(file_path)
|
|
|
|
|
def save_summary_in_store(file_path: str, summary: str, state: GraphState):
    """Cache ``summary`` under ``file_path``, mutating the state's summary store in place."""
    store = state["document_summaries_store"]
    store[file_path] = summary
|
|
|
|
|
def get_all_filesnames_in_codebase(state: GraphState):
    """Get a list of all files (as filepaths) in the codebase.

    Returns:
        A fresh state update holding every document path in ``all_files``, the
        same paths queued in ``remaining_files``, and empty exploration
        bookkeeping (``explored_files``, ``explored_summaries``,
        ``document_summaries_store``).
    """
    filenames = [document.metadata["path"] for document in state["documents"]]

    return {
        "all_files": filenames,
        "explored_files": [],
        # Give the work queue its own list: explore_file pops from
        # remaining_files in place, which would otherwise also shrink all_files.
        "remaining_files": list(filenames),
        "explored_summaries": "",
        "document_summaries_store": {},
    }
|
|
|
|
|
def parse_plan(state: GraphState):
    """Parse the plan and return the next action.

    Sends the plan text from ``working_knowledge`` to the structured-output LLM
    and queues the returned file list for exploration. If the text contains the
    "File Exploration Plan" marker, only the text after its first occurrence is
    sent.

    Returns:
        A state update ``{"remaining_files": [...]}`` with at most 25 paths,
        duplicates removed, in the order the parser returned them.
    """
    knowledge = state["working_knowledge"]
    marker = "File Exploration Plan"
    # Take everything after the FIRST marker. A plain split(marker)[1] keeps only
    # the text up to any later mention of the marker, which can truncate or
    # empty the plan (e.g. an intro sentence that names the heading).
    if marker in knowledge:
        plan_working = knowledge.split(marker, 1)[1]
    else:
        plan_working = knowledge

    response = llm_structured.invoke(plan_parser.format(plan_list=plan_working))[
        "plan_list"
    ]

    # Drop repeated paths (keeping the first occurrence) so they don't use up the
    # 25-file exploration budget; slicing is already safe for shorter lists.
    return {"remaining_files": list(dict.fromkeys(response))[:25]}
|
|
|
|
|
def router(state: GraphState):
    """Pick the next node: keep exploring while files are queued, otherwise build the diagram."""
    if state["remaining_files"] == []:
        return "generate_mermaid_code"
    return "explore_file"
|
|
|
|
|
def get_plan_for_codebase(state: GraphState):
    """Initialize exploration bookkeeping and ask the planner LLM for a plan.

    Returns the fresh fields from get_all_filesnames_in_codebase, plus
    ``working_knowledge`` containing the planner's reply under a "# Plan" heading.
    """
    updates = get_all_filesnames_in_codebase(state)
    file_listing = "# File Structure\n" + str(updates["all_files"])

    planner_reply = llm.invoke(
        [
            SystemMessage(content=planner_prompt),
            HumanMessage(content=file_listing),
        ]
    )

    updates["working_knowledge"] = f"# Plan\n{planner_reply.content}"
    return updates
|
|
|
|
|
def final_mermaid_code_generation(state: GraphState):
    """Ask the LLM for the final diagram from the plan and all file summaries.

    Returns a state update appending the raw LLM reply to ``messages``.
    """
    sections = [
        "# Disjoint Codebase Understanding\n",
        state["working_knowledge"],
        "\n\n# Completed Explorations\n",
        state["explored_summaries"],
    ]
    prompt_messages = [
        SystemMessage(content=final_graph_prompt),
        HumanMessage(content="".join(sections)),
    ]
    return {"messages": [llm.invoke(prompt_messages)]}
|
|
|
|
|
import time |
|
|
|
|
|
def extract_mermaid_and_generate_graph(state: GraphState):
    """Extract Mermaid code from the latest message, render it to PNG, and load it.

    When the last message contains a ```mermaid fence, the text before it is
    dropped. The result is passed to the extraction LLM, and the code body is
    taken from its reply. That code is rendered with render_mermaid into
    ``mermaid/<unix-seconds>.png`` and loaded with Pillow.

    Returns:
        A state update with the extracted code as an AIMessage and the
        rendered image under ``final_graph``.
    """
    import os

    mermaid_code = state["messages"][-1].content
    # Drop any preamble before the fenced diagram. Splitting on the bare word
    # "mermaid" and keeping the last piece would cut into a diagram whose
    # labels mention mermaid (e.g. a node for render_mermaid.py).
    fence = "```mermaid"
    if fence in mermaid_code:
        mermaid_code = mermaid_code.split(fence, 1)[1]

    response = llm.invoke(
        [SystemMessage(content=mermaid_extracter), HumanMessage(content=mermaid_code)]
    ).content
    # Keep the body of the last ```mermaid block (or the whole reply if it is
    # unfenced), stopping at the closing fence.
    response = response.split("```mermaid")[-1].split("```")[0]

    # Make sure the output directory exists before rendering into it.
    os.makedirs("mermaid", exist_ok=True)
    file_name = f"mermaid/{int(time.time())}.png"
    render_mermaid(response, file_name)

    img = Image.open(file_name)
    # Image.open is lazy and keeps the file open; load() reads the pixel data
    # now, so the image kept in state does not rely on a lingering file handle.
    img.load()
    return {"messages": [AIMessage(response)], "final_graph": img}
|
|
|
|
|
def need_to_update_working_knowledge(state: GraphState):
    """Choose the next node from the most recent message.

    Returns "generate_mermaid_code" when that message's content is exactly the
    "generate_mermaid_code" sentinel, "tools_knowledge_update" when it is a
    ToolMessage, and "agent" otherwise.
    """
    latest = state["messages"][-1]

    if latest.content == "generate_mermaid_code":
        route = "generate_mermaid_code"
    elif isinstance(latest, ToolMessage):
        route = "tools_knowledge_update"
    else:
        route = "agent"
    return route
|
|