import getpass
import html
import os
import sqlite3
from typing import Annotated, Union

from typing_extensions import TypedDict
from langchain_community.graphs import Neo4jGraph
from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.sqlite import SqliteSaver
from langgraph.checkpoint import base
from langgraph.graph import add_messages


# Shared in-memory checkpointer.
# Do not bind this inside `with SqliteSaver.from_conn_string(...)`: leaving
# that block closes the underlying connection, so a saver bound inside it is
# unusable afterwards. Here the connection stays open for the module's
# lifetime. check_same_thread=False lets the connection be used from threads
# other than the one that created it.
memory = SqliteSaver(sqlite3.connect(":memory:", check_same_thread=False))


def format_df(df):
    """
    Render a DataFrame (e.g. the generated plan) as an HTML table string.

    String cells are HTML-escaped and their newlines are turned into
    ``<br>`` so multi-line cells keep their line breaks. The table is
    rendered without the index and is prefixed with the ``css`` string.
    """
    def format_cell(cell):
        if isinstance(cell, str):
            # Encode special characters, but preserve line breaks
            return html.escape(cell).replace('\n', '<br>')
        return cell

    formatted_df = df.map(format_cell)
    # Cells are already escaped above, so pandas must not escape them again
    # (that would turn the inserted <br> tags into literal text).
    html_table = formatted_df.to_html(escape=False, index=False)
    # Add custom CSS to allow multiple lines and scrolling in cells
    # NOTE(review): this stylesheet string is effectively empty here; the
    # original <style> content may have been lost -- confirm.
    css = """ """
    return css + html_table


def format_doc(doc: dict) -> str:
    """
    Render a dict as Markdown-style lines, one ``**key**: value`` per entry.

    Every line ends with a newline; an empty dict yields ``""``.
    """
    return "".join(f"**{key}**: {value}\n" for key, value in doc.items())


def _set_env(var: str, value: str | None = None):
    """
    Ensure the environment variable ``var`` is set.

    A non-empty existing value is left untouched. Otherwise ``value`` is used
    when it is truthy; failing that, the user is prompted interactively
    (input hidden via getpass).
    """
    if os.environ.get(var):
        return
    if value:
        os.environ[var] = value
    else:
        os.environ[var] = getpass.getpass(f"{var}: ")


def init_app(openai_key: str | None = None, groq_key: str | None = None,
             langsmith_key: str | None = None):
    """
    Initialize the app with the user's API keys and set up proxy settings.

    Each key is only applied if the corresponding environment variable is
    not already set (see ``_set_env``); missing keys are prompted for.
    Also enables LangSmith tracing, sets the LangChain project name, and
    configures HTTP(S) proxies.
    """
    _set_env("GROQ_API_KEY", value=groq_key)
    _set_env("LANGSMITH_API_KEY", value=langsmith_key)
    _set_env("OPENAI_API_KEY", value=openai_key)
    os.environ["LANGSMITH_TRACING_V2"] = "true"
    os.environ["LANGCHAIN_PROJECT"] = "3GPP Test"
    os.environ["http_proxy"] = "185.46.212.98:80"
    os.environ["https_proxy"] = "185.46.212.98:80"
    os.environ["NO_PROXY"] = "thalescloud.io"


def clear_memory(memory, thread_id: str) -> None:
    """
    Reset the checkpointer state for ``thread_id``.

    Writes an empty checkpoint as the latest checkpoint of ``thread_id`` in
    the given ``memory`` saver. Previously this function replaced ``memory``
    with a fresh saver whose connection was already closed, so the caller's
    saver was never touched.

    NOTE(review): newer langgraph checkpoint savers also require a
    ``new_versions`` argument to ``put``; add it if the installed version
    needs it -- confirm.
    """
    checkpoint = base.empty_checkpoint()
    memory.put(
        config={"configurable": {"thread_id": thread_id}},
        checkpoint=checkpoint,
        metadata={},
    )


def get_model(model: str = "mixtral-8x7b-32768"):
    """
    Return the LLM client matching ``model``.

    ``"gpt-4o"`` is served through ChatOpenAI at the Thales Synapse
    endpoint; every other model name goes to ChatGroq.
    """
    if model == "gpt-4o":
        llm = ChatOpenAI(model=model, base_url="https://llm.synapse.thalescloud.io/")
    else:
        llm = ChatGroq(model=model)
    return llm


class ConfigSchema(TypedDict):
    graph: Neo4jGraph
    plan_method: str
    use_detailed_query: bool


class State(TypedDict):
    messages: Annotated[list, add_messages]  # merged with langgraph's add_messages reducer
    store_plan: list[str]
    current_plan_step: int
    valid_docs: list[str]


class DocRetrieverState(TypedDict):
    messages: Annotated[list, add_messages]  # merged with langgraph's add_messages reducer
    query: str
    docs: list[dict]
    cyphers: list[str]
    current_plan_step: int
    valid_docs: list[Union[str, dict]]


class HumanValidationState(TypedDict):
    human_validated: bool
    process_steps: list[str]


def update_doc_history(left: list | None, right: list | None) -> list:
    """
    Reducer for the 'docs_in_processing' field.

    ``left`` is a list of per-document histories (lists); ``right`` holds
    one new entry per document. ``right[i]`` is appended to the i-th
    history. New histories are created when ``right`` is longer than
    ``left`` (this used to raise IndexError). ``left`` is copied rather than
    mutated in place, so the previous state is left intact.

    Duplicates are still not handled correctly.
    TODO : make this work (reference : https://langchain-ai.github.io/langgraph/how-tos/subgraph/#custom-reducer-functions-to-manage-state)
    """
    # A falsy left (None / []) starts from a single empty history, as before.
    merged = [list(history) for history in left] if left else [[]]
    for i, entry in enumerate(right or []):
        if i >= len(merged):
            merged.append([])
        merged[i].append(entry)
    return merged


class DocProcessorState(TypedDict):
    valid_docs: list[Union[str, dict]]
    docs_in_processing: list
    process_steps: list[Union[str, dict]]
    current_process_step: int