Spaces:
Sleeping
Sleeping
| import os | |
| from dotenv import load_dotenv | |
| from langgraph.prebuilt import ToolNode | |
| from typing import TypedDict, Annotated, Literal | |
| from langchain.chat_models import init_chat_model | |
| from langgraph.graph import add_messages, StateGraph, START, END | |
| from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage | |
| from tools import ( | |
| default_file_reader, | |
| image_reader, | |
| excel_column_reader, | |
| excel_find_column_values_sum, | |
| wiki_search, | |
| archive_search, | |
| get_ioc_code, | |
| check_commutativity, | |
| audio_to_text, | |
| video_to_text | |
| ) | |
load_dotenv()
# WARNING(review): setting CURL_CA_BUNDLE to '' disables TLS certificate
# verification for curl/requests-based HTTP clients process-wide — confirm
# this is intentional (e.g. a corporate-proxy workaround) and not shipped
# to production as-is.
os.environ['CURL_CA_BUNDLE'] = ''
class AgentState(TypedDict):
    """Graph state shared by all nodes: the running conversation history."""

    # Annotated with the add_messages reducer so values returned by nodes
    # are appended to the history rather than replacing it.
    messages: Annotated[list[AnyMessage], add_messages]
def start_agent(question: str, question_filepath: str):
    """Build and run a tool-using LangGraph agent for a single question.

    Args:
        question: The user question to answer.
        question_filepath: Path to a file attached to the question. When
            truthy, file-handling tools are added to the toolset and the
            path is appended to the prompt.

    Returns:
        The content of the agent's final message, or None when a required
        prompt file (system_prompt.txt) is missing.
    """
    chat = init_chat_model("claude-3-5-sonnet-20241022", model_provider="anthropic", temperature=0)

    # Base toolset; file-specific tools are only offered when a file exists.
    tools = [wiki_search, archive_search, get_ioc_code, check_commutativity, video_to_text]
    if question_filepath:
        # If a file is provided, add the file-reading tools as well.
        tools = tools + [default_file_reader, image_reader, excel_column_reader,
                         excel_find_column_values_sum, audio_to_text]
    chat_with_tools = chat.bind_tools(tools)

    try:
        with open("system_prompt.txt", 'r', encoding="utf-8") as sp_file:
            system_prompt = sp_file.read()
    except FileNotFoundError:
        print("Error: unable to open system_prompt.txt")
        return None

    # Single construction path for the initial prompt (previously duplicated
    # across an if/else that differed only in the human message text).
    user_content = (
        f"{question} File located at: {question_filepath}"
        if question_filepath else f"{question}"
    )
    messages = [
        SystemMessage(system_prompt),
        HumanMessage(content=user_content),
    ]

    def assistant(state: AgentState):
        # One LLM step; the add_messages reducer appends the response.
        return {
            **state,
            "messages": [chat_with_tools.invoke(state["messages"])],
        }

    def validate_answer_format(state: AgentState):
        # Final node: ask the model to re-check its FINAL ANSWER against the
        # user question and the formatting requirements before terminating.
        try:
            with open("final_answer_validation_prompt.txt", 'r', encoding="utf-8") as favp_file:
                final_answer_validation_prompt = favp_file.read()
        except FileNotFoundError:
            # Returning None applies no state update, so the last assistant
            # message stands as the final answer without validation.
            print("Error: unable to open final_answer_validation_prompt.txt")
            return None
        # Build the validation prompt locally instead of appending to
        # state["messages"] in place: in-place mutation bypasses the
        # add_messages reducer and its persistence is undefined.
        validation_messages = state["messages"] + [
            HumanMessage(content=f"Verify your FINAL ANSWER again so it meets user question requirements: {question}"),
            HumanMessage(content=f"Verify your FINAL ANSWER again so it meets these requirements: {final_answer_validation_prompt}. "
                                 f"Do not use any tool here, just validate format of the final answer."),
        ]
        return {
            **state,
            "messages": [chat_with_tools.invoke(validation_messages)],
        }

    def custom_tool_condition(state: AgentState, messages_key: str = "messages") -> Literal["tools", "validate"]:
        # Like langgraph's tools_condition, but routes to "validate" (instead
        # of END) when the last AI message carries no tool calls.
        if isinstance(state, list):
            ai_message = state[-1]
        elif isinstance(state, dict) and (messages := state.get(messages_key, [])):
            ai_message = messages[-1]
        elif messages := getattr(state, messages_key, []):
            ai_message = messages[-1]
        else:
            raise ValueError(f"No messages found in input state to tool_edge: {state}")
        if hasattr(ai_message, "tool_calls") and len(ai_message.tool_calls) > 0:
            return "tools"
        return "validate"

    initial_state = AgentState(
        messages=messages,
    )

    # Wire the graph: assistant -> (tools loop | validate) -> END.
    builder = StateGraph(AgentState)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))
    builder.add_node("validate", validate_answer_format)
    builder.add_edge(START, "assistant")
    builder.add_conditional_edges("assistant", custom_tool_condition)
    builder.add_edge("tools", "assistant")
    builder.add_edge("validate", END)
    agent = builder.compile()

    response = agent.invoke(initial_state)
    return response['messages'][-1].content