import json
import os
import time

from PIL import Image
from langchain.agents import tool
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, ToolMessage
from langgraph.graph import MessagesState
from langchain_community.document_loaders import GithubFileLoader

from render_mermaid import render_mermaid
from prompts import *
from constants import file_extensions
from __init__ import llm, llm_structured

class GraphState(MessagesState):
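    """Shared LangGraph state: extends MessagesState (which provides the
    `messages` channel) with the bookkeeping fields used while exploring
    the repository."""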
working_knowledge: str
all_files: list[str]
remaining_files: list[str]
explored_files: list[str]
explored_summaries: str
document_summaries_store: dict
documents: list
final_graph: Image.Image
def load_github_codebase(repo: str, branch: str):
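    """Load every file matching `file_extensions` from a GitHub repo as documents.

    Assumes GithubFileLoader can resolve a GitHub token on its own (typically
    via the GITHUB_PERSONAL_ACCESS_TOKEN environment variable), e.g.:
        documents = load_github_codebase("owner/repo", "main")  # hypothetical repo
    """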
loader = GithubFileLoader(
repo=repo, # the repo name
branch=branch, # the branch name
github_api_url="https://api.github.com",
file_filter=lambda file_path: file_path.endswith(tuple(file_extensions)),
# file_filter=lambda filepath: True,
encoding="utf-8",
)
documents = loader.load()
return documents
def get_file_content_summary(file_path: str, state: GraphState):
"""Returns the functional summary of a file. Please note that the file_path should not be null.
Args:
file_path: The path of the file for which the summary is required."""
summary = check_summary_in_store(file_path, state)
if summary:
return summary
for document in state["documents"]:
if document.metadata["path"] == file_path:
doc_content = document.page_content
break
# print(content)
summary = llm.invoke(
[SystemMessage(content=summarizer_prompt), HumanMessage(content=doc_content)]
).content
summary = json.dumps({"FilePath": file_path, "Summary": summary})
save_summary_in_store(file_path, summary, state)
return summary
def explore_file(state: GraphState):
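    """Pop the next file off the exploration queue and fold its summary into the state."""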
    file_path = state["remaining_files"].pop()
    if file_path in state["explored_files"]:
        # Already summarized on an earlier pass; skip the duplicate work.
        return state
    summary_dict = json.loads(get_file_content_summary(file_path, state))
knowledge_str = f"""* File Path: {summary_dict['FilePath']}\n\tSummary: {summary_dict['Summary']}\n\n"""
state["explored_summaries"] += knowledge_str
state["explored_files"].append(file_path)
return state
@tool
def generate_final_mermaid_code():
"""Generate the final mermaid code for the codebase once all the files are explored and the working knowledge is complete."""
return "generate_mermaid_code"
def check_summary_in_store(file_path: str, state: GraphState):
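    """Return the cached summary for file_path, or None if it has not been summarized yet."""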
if file_path in state["document_summaries_store"]:
return state["document_summaries_store"][file_path]
return None
def save_summary_in_store(file_path: str, summary: str, state: GraphState):
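    """Cache a file's summary so repeat visits skip the LLM call."""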
state["document_summaries_store"][file_path] = summary
def get_all_filesnames_in_codebase(state: GraphState):
"""Get a list of all files (as filepaths) in the codebase."""
    filenames = [document.metadata["path"] for document in state["documents"]]
    return {
        "all_files": filenames,
        "explored_files": [],
        # Copy the list: remaining_files is consumed by pop() and must not
        # mutate all_files through a shared reference.
        "remaining_files": list(filenames),
        "explored_summaries": "",
        "document_summaries_store": {},
    }
def parse_plan(state: GraphState):
"""Parse the plan and return the next action."""
if "File Exploration Plan" in state["working_knowledge"]:
plan_working = state["working_knowledge"].split("File Exploration Plan")[1]
else:
plan_working = state["working_knowledge"]
response = llm_structured.invoke(plan_parser.format(plan_list=plan_working))[
"plan_list"
]
    # Cap the exploration queue so the agent loop stays bounded.
    if len(response) > 25:
        response = response[:25]
return {"remaining_files": response}
def router(state: GraphState):
"""Route the conversation to the appropriate node based on the current state of the conversation."""
    if state["remaining_files"]:
        return "explore_file"
    return "generate_mermaid_code"
def get_plan_for_codebase(state: GraphState):
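    """Entry node: index the codebase, then ask the LLM for a file exploration plan."""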
new_state = get_all_filesnames_in_codebase(state)
planner_content = "# File Structure\n" + str(new_state["all_files"])
plan = llm.invoke(
[SystemMessage(content=planner_prompt), HumanMessage(content=planner_content)]
)
knowledge_str = f"""# Plan\n{plan.content}"""
new_state["working_knowledge"] = knowledge_str
return new_state
def final_mermaid_code_generation(state: GraphState):
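    """Draft the final mermaid diagram from the plan plus all completed file summaries."""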
final_graph_content = (
"# Disjoint Codebase Understanding\n"
+ state["working_knowledge"]
+ "\n\n# Completed Explorations\n"
+ state["explored_summaries"]
)
response = llm.invoke(
[
SystemMessage(content=final_graph_prompt),
HumanMessage(content=final_graph_content),
]
)
return {"messages": [response]}
def extract_mermaid_and_generate_graph(state: GraphState):
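    """Extract clean mermaid code from the last message, render it to a PNG, and return the image."""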
mermaid_code = state["messages"][-1].content
if "mermaid" in mermaid_code:
mermaid_code = mermaid_code.split("mermaid")[-1]
response = llm.invoke(
[SystemMessage(content=mermaid_extracter), HumanMessage(content=mermaid_code)]
).content
response = response.split("```mermaid")[-1].split("```")[0]
    # Save the rendered diagram to a timestamped file.
    os.makedirs("mermaid", exist_ok=True)  # defensive: ensure the output directory exists
    file_name = f"mermaid/{int(time.time())}.png"
    render_mermaid(response, file_name)
# Read image to return as output
img = Image.open(file_name)
return {"messages": [AIMessage(response)], "final_graph": img}
def need_to_update_working_knowledge(state: GraphState):
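    """Post-tool router: emit the diagram, fold tool output into working knowledge, or hand back to the agent."""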
messages = state["messages"]
last_message = messages[-1]
    # The generate_final_mermaid_code tool returns a sentinel string that ends exploration.
    if last_message.content == "generate_mermaid_code":
        return "generate_mermaid_code"
    # Any other tool result is folded into the working knowledge first.
    if isinstance(last_message, ToolMessage):
        return "tools_knowledge_update"
# Otherwise, we continue with the agent
return "agent"