from langgraph.graph import StateGraph, START, END
from typing_extensions import TypedDict, Annotated, Literal, Optional
from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage
from langgraph.graph.message import add_messages
from langchain_mistralai import ChatMistralAI
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import ToolNode, tools_condition
from langchain_core.runnables.graph import MermaidDrawMethod
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_community.tools import WikipediaQueryRun
from langchain_community.utilities import WikipediaAPIWrapper
from langchain_aws import ChatBedrock
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_community.document_loaders import UnstructuredExcelLoader
# from langchain_google_vertexai import ChatVertexAI
# from langfuse.callback import CallbackHandler
import base64
import json
import time
import requests
# import boto3
from yt_dlp import YoutubeDL
import os
# from urllib.parse import urlparse, parse_qs
import re
from dotenv import load_dotenv

# Load env vars from .env file
load_dotenv()

# Initialize Langfuse CallbackHandler for LangGraph/LangChain (tracing)
# langfuse_handler = CallbackHandler()
######## STATE ########
class State(TypedDict):
    """
    A class representing the state of the agent.
    """
    question: str
    messages: Annotated[list[AnyMessage], add_messages]
    input_file: str
    downloaded_file: Optional[str]
    task_id: str
    web_search_node_result: AnyMessage
    thinking_node_result: AnyMessage
    vision_node_result: AnyMessage
    video_node_result: AnyMessage
    audio_node_result: AnyMessage
    code_node_result: AnyMessage
    excel_node_result: AnyMessage
    next_node: str
########################
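
# Note on reducers: only "messages" uses the add_messages reducer, so values
# returned under that key are appended to the running history; every other key
# (e.g. "next_node") is simply overwritten. A minimal sketch of the difference
# (hypothetical node return values, not part of the agent flow):
#   return {"messages": [HumanMessage(content="hi")]}  # appended to the history
#   return {"next_node": "thinking_node"}              # replaces the old value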
######## MODELS ########
def get_general_model():
    llm_provider = os.getenv("LLM_PROVIDER", "mistral")
    if llm_provider == "mistral":
        general_model = ChatMistralAI(
            model="mistral-large-2411",  # alternatives: "ministral-8b-latest", "mistral-small-latest"
            temperature=0,
            max_retries=2,
            api_key=os.getenv("MISTRAL_API_KEY")
        )
    elif llm_provider == "aws":
        general_model = ChatBedrock(
            model_id="arn:aws:bedrock:us-east-1:416545197702:inference-profile/us.amazon.nova-lite-v1:0",
            # provider="amazon",
            temperature=0,
            region_name="eu-west-3",
            aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID"),
            aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY")
        )
    else:
        # Fail loudly instead of returning an unbound variable
        raise ValueError(f"Unsupported LLM_PROVIDER: {llm_provider}")
    return general_model
def get_big_model():
    big_model = ChatMistralAI(
        model="mistral-medium-2505",
        temperature=0,
        max_retries=2,
        api_key=os.getenv("MISTRAL_API_KEY")
    )
    return big_model
def get_vision_model():
    vlm_provider = os.getenv("VLM_PROVIDER", "mistral")
    if vlm_provider == "openai":
        print("Spawning OpenAI VLM")
        vision_model = ChatOpenAI(
            model="gpt-4o",
            temperature=0,
            max_tokens=None,
            timeout=None,
            max_retries=2,
            api_key=os.getenv("OPENAI_API_KEY"),
        )
    elif vlm_provider == "mistral":
        print("Spawning Mistral VLM")
        vision_model = ChatMistralAI(
            model="pixtral-12b-2409",  # alternatives: "mistral-small-latest", "pixtral-large-latest"
            temperature=0,
            max_retries=2,
            api_key=os.getenv("MISTRAL_API_KEY")
        )
    else:
        # Fail loudly instead of returning an unbound variable
        raise ValueError(f"Unsupported VLM_PROVIDER: {vlm_provider}")
    return vision_model
def get_video_handler_model():
    video_handler_model = ChatGoogleGenerativeAI(
        model="gemini-2.0-flash",
        temperature=0,
        max_tokens=None,
        timeout=None,
        max_retries=2,
        # other params...
    )
    return video_handler_model

def get_audio_handler_model():
    audio_handler_model = ChatOpenAI(
        model="gpt-4o-audio-preview-2024-12-17",  # alternative: "gpt-4o-mini-audio-preview-2024-12-17"
        temperature=0,
        max_tokens=None,
        timeout=None,
        max_retries=2,
        api_key=os.getenv("OPENAI_API_KEY"),
    )
    return audio_handler_model
########################

######## Functions ########
def download_youtube_content(url: str, output_path: Optional[str] = None) -> Optional[str]:
    """
    Download YouTube content (single video or playlist) in MP4 format only.
    Args:
        url (str): URL of the YouTube video or playlist
        output_path (str, optional): Directory to save the downloads. Defaults to './downloads'
    Returns:
        Optional[str]: Path to the downloaded MP4 file, or None if no single MP4 was produced.
    """
    # Set default output path if none provided
    if output_path is None:
        output_path = os.path.join(os.getcwd(), 'downloads')
    # Create output directory if it doesn't exist
    os.makedirs(output_path, exist_ok=True)
    # Configure yt-dlp options for MP4 only
    ydl_opts = {
        'format': 'bestvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best',
        'merge_output_format': 'mp4',
        'ignoreerrors': True,
        'no_warnings': False,
        'extract_flat': False,
        # Disable all additional downloads
        'writesubtitles': False,
        'writethumbnail': False,
        'writeautomaticsub': False,
        'postprocessors': [{
            'key': 'FFmpegVideoConvertor',
            'preferedformat': 'mp4',  # note: this is yt-dlp's own spelling of the option
        }],
        # Clean-up options
        'keepvideo': False,
        'clean_infojson': True
    }
    ydl_opts['outtmpl'] = os.path.join(output_path, '%(title)s.%(ext)s')
    print("Detected single video URL. Downloading video...")
    try:
        with YoutubeDL(ydl_opts) as ydl:
            # Download content
            ydl.download([url])
        print(f"\nDownload completed successfully! Files saved to: {output_path}")
    except Exception as e:
        print(f"An error occurred: {str(e)}")
    result = os.listdir(output_path)
    video_file_names = [x for x in result if re.match(r".*\.mp4$", x)]
    if len(video_file_names) == 1:
        video_file_name = os.path.join(output_path, video_file_names.pop())
    else:
        video_file_name = None
    # Remove everything in the output directory except the video we keep
    for other_file in result:
        if os.path.join(output_path, other_file) != video_file_name:
            print(f"Removing file: {other_file}")
            os.remove(os.path.join(output_path, other_file))
    return video_file_name
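
# Minimal usage sketch (hypothetical URL, not from the task set):
#   path = download_youtube_content("https://www.youtube.com/watch?v=dQw4w9WgXcQ")
#   if path:
#       print(f"Video saved to {path}")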
web_search = DuckDuckGoSearchRun()
wikipedia_search = WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper())
def download_input_file(task_id: str) -> str:
    """
    Download the file specified in the state's input_file key.
    You only need the task_id to download the file.
    Args:
        task_id (str): The task_id of the file to download.
    Returns:
        str: The path to the downloaded file, or an empty string on failure.
    """
    output_path = os.path.join(os.getcwd(), 'downloads')
    api_url = os.getenv("DEFAULT_API_URL")
    # Create output directory if it doesn't exist
    os.makedirs(output_path, exist_ok=True)
    # Construct the full URL
    url = f"{api_url}/files/{task_id}"
    try:
        # Send a GET request to download the file
        response = requests.get(url, stream=True)
        response.raise_for_status()  # Raise an error for bad status codes
        headers = dict(response.headers)
        attachment = headers["content-disposition"]
        regex_result = re.search(r'filename="(.*)"', attachment)
        filename = regex_result.group(1)
        # Define the output file path
        output_file_path = os.path.join(output_path, filename)
        # Write the file to the output path
        with open(output_file_path, 'wb') as file:
            for chunk in response.iter_content(chunk_size=8192):
                file.write(chunk)
        print(f"File downloaded successfully and saved to: {output_file_path}")
        return output_file_path
    except requests.exceptions.RequestException as e:
        print(f"An error occurred while downloading the file: {str(e)}")
        return ""
########################

######## LLM associations ########
general_model = get_general_model()
big_model = get_big_model()
vision_model = get_vision_model()
video_handler_model = get_video_handler_model()
audio_handler_model = get_audio_handler_model()
########################

######## Nodes Definition ########
search_tools = [
    web_search,
    wikipedia_search,
]
download_file_tool = [download_input_file]
web_search_node_agent = general_model.bind_tools(search_tools, parallel_tool_calls=False)
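
# bind_tools only attaches the tools' JSON schemas so the model can emit tool
# calls; execution happens in the ToolNode wired up in build_web_search_graph()
# below, with tools_condition looping the flow back until no tool is requested.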
def thinking_node(state: State) -> dict:
    """
    A powerful node to answer general questions, reflection, maths, deduction, prediction.
    This node does not handle files
    This node does not handle images or pictures
    This node does not handle videos
    This node does not handle audio
    This node does not handle code
    Args:
        state (State): A dictionary containing the current state of the agent, including the 'question' key which holds the question to be answered.
    Returns:
        dict: A dictionary containing the response from the thinking node, with the key 'thinking_node_result' holding the list of messages generated by the general model.
    """
    prompt = f"""
    You are a powerful assistant that answers general questions, reflection, maths, deduction, prediction.
    1. You need to fully understand the question
    2. You must think hard about what is relevant in the question to make the best answer
    3. If there are calculations or maths, you need to verify twice before answering.
    4. Report your thought process in detail, explaining your reasoning step-by-step.
    Here is the question: {state['question']}
    Now provide your response immediately, without any preamble, in plain text (not markdown).
    """
    sys_msg = SystemMessage(content=prompt)
    # Only pass a previous result along if one exists; an empty string would be
    # coerced into an empty message.
    previous_result = state.get("thinking_node_result")
    thinking_node_response = [general_model.invoke([sys_msg] + ([previous_result] if previous_result else []))]
    thinking_node_response[-1].pretty_print()
    return {
        "thinking_node_result": thinking_node_response,
    }
def code_node(state: State) -> dict:
    """
    A powerful node to handle and understand code.
    This node does not handle images or pictures
    This node does not handle videos
    This node does not handle audio
    This node does not access the web
    Args:
        state (State): A dictionary containing the current state of the agent, including the 'question' key which holds the question to be answered.
    Returns:
        dict: A dictionary containing the response from the code node, with the key 'code_node_result' holding the list of messages generated by the general model.
    """
    with open(state["downloaded_file"], "r") as code_file:
        code = code_file.read()
    prompt = f"""
    You are a powerful assistant that handles and understands code.
    1. You need to fully understand the question.
    2. You must think hard about the code and predict the result to answer the question.
    3. Report your thought process in detail, explaining your reasoning step-by-step.
    Here is the question: {state['question']}
    Here is the code: {code}
    Now provide your response immediately, without any preamble, in plain text (not markdown).
    """
    sys_msg = SystemMessage(content=prompt)
    code_node_response = [general_model.invoke([sys_msg])]
    code_node_response[-1].pretty_print()
    return {
        "code_node_result": code_node_response,
    }
def web_search_node(state: State) -> dict:
    """
    A powerful node to answer questions and do research on the web based on the question provided in the state.
    This node does not handle files
    This node does not handle images or pictures
    This node does not handle videos
    This node does not handle audio
    This node does not handle code
    Args:
        state (State): A dictionary containing the current state of the agent, including the 'question' key which holds the question to be answered.
    Returns:
        dict: A dictionary containing the response from the web search node, with the key 'web_search_node_result' holding the list of messages generated by the general model.
    """
    prompt = f"""
    You are a powerful assistant that does research on the web in order to give the best answer to the question.
    1. You need to fully understand the question
    2. You must think hard about what is relevant in the question to make the best search with the right words
    3. You must use the best of the tools you have to answer the question precisely
    4. Report your thought process in detail, explaining your reasoning step-by-step.
    5. You must not change the way words or identifiers are written in the web search results.
    Here are the tools available:
    web_search:
        {web_search.description}
        Args:
            {web_search.args_schema}
        Returns:
            {web_search.response_format}
    wikipedia_search:
        {wikipedia_search.description}
        Args:
            {wikipedia_search.args_schema}
        Returns:
            {wikipedia_search.response_format}
    Here is the question: {state['question']}
    Now provide your response immediately, without any preamble, in plain text (not markdown).
    """
    sys_msg = SystemMessage(content=prompt)
    # Only pass a previous result along if one exists; an empty string would be
    # coerced into an empty message.
    previous_result = state.get("web_search_node_result")
    web_search_node_response = [web_search_node_agent.invoke([sys_msg] + ([previous_result] if previous_result else []))]
    web_search_node_response[-1].pretty_print()
    return {
        "web_search_node_result": web_search_node_response,
    }
def vision_node(state: State) -> dict:
    """
    Vision model that can analyze images and pictures and answer questions about them.
    This node does not handle videos.
    This node does not handle audio.
    This node does not handle code.
    Args:
        state (State): A dictionary containing the current state of the agent, including the 'question' key which holds the question to be answered and the 'input_file' key which holds the path to the image file.
    Returns:
        dict: A dictionary containing the response from the vision node, with the key 'vision_node_result' holding the list of messages generated by the vision model.
    """
    prompt = f"""
    You are a powerful vision assistant, you can analyze images and answer questions about the picture.
    1. You need to fully understand the question.
    2. You must think hard about what is relevant in the image to make the best answer to the question.
    3. Report your thought process in detail, explaining your reasoning step-by-step.
    Here is the question: {state['question']}
    Now provide your response immediately, without any preamble, in plain text (not markdown).
    """
    image_base64 = ""
    try:
        with open(state["downloaded_file"], "rb") as image_file:
            image_bytes = image_file.read()
            image_base64 = base64.b64encode(image_bytes).decode("utf-8")
        # Mistral expects a data-URL style image_url block
        mistral_image_handling = {
            "type": "image_url",
            "image_url": f"data:image/png;base64,{image_base64}",
        }
        # OpenAI expects an explicit base64 image block
        openai_image_handling = {
            "type": "image",
            "source_type": "base64",
            "mime_type": "image/png",  # or image/jpeg, etc.
            "data": image_base64,
        }
        vision_provider = os.getenv("VLM_PROVIDER", "mistral")
        if vision_provider == "openai":
            image_handling = openai_image_handling
        else:
            image_handling = mistral_image_handling
        message = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": prompt,
                    },
                    image_handling
                ]
            }
        ]
        vision_node_response = [vision_model.invoke(
            input=message,
            # config={
            #     "callbacks": [langfuse_handler]
            # }
        )]
        vision_node_response[-1].pretty_print()
        return {
            "vision_node_result": vision_node_response
        }
    except Exception as e:
        # Handle errors gracefully so the graph can keep running
        error_msg = f"Error analyzing the image: {str(e)}"
        print(error_msg)
        return {}
def video_node(state: State) -> dict:
    """
    Video handler model that can analyze videos and answer questions about them.
    This node does not handle images or pictures.
    This node does not handle audio.
    This node does not handle code.
    Args:
        state (State): A dictionary containing the current state of the agent, including the 'question' key which holds the question to be answered.
    Returns:
        dict: A dictionary containing the response from the video handler node, with the key 'video_node_result' holding the list of messages generated by the video handler model.
    """
    prompt = f"""
    You are a highly capable video analysis assistant. Your task is to watch and analyze the provided video content and answer the user's question as accurately and concisely as possible.
    1. You need to fully understand the question.
    2. Carefully observe the video, paying attention to relevant details, actions, and context.
    3. Focus on the user's question.
    4. If the question requires counting, identifying, or describing, be precise and clear in your response.
    5. If you are unsure, state what you can infer from the video.
    6. Do not make up information that is not visible or inferable from the video.
    Here is the question: {state['question']}
    Now provide your response immediately, without any preamble, in plain text (not markdown).
    """
    if re.search(r'youtube\.com', state["question"]):
        # Flexible pattern to match standard YouTube watch URLs
        regex_result = re.search(r"(?P<youtube_url>https://(?:www\.)?youtube\.com/watch\?v=[a-zA-Z0-9_-]+)", state["question"])
        if regex_result:
            video_url = regex_result.group("youtube_url")
            downloaded_video = download_youtube_content(url=video_url)
        else:
            # Fallback if the regex doesn't match
            print("Could not extract YouTube URL from question. Falling back to the downloaded file.")
            downloaded_video = state["downloaded_file"]
    else:
        downloaded_video = state["downloaded_file"]
    print(f"Downloaded video: {downloaded_video}")
    video_mime_type = "video/mp4"
    with open(downloaded_video, "rb") as video_file:
        encoded_video = base64.b64encode(video_file.read()).decode("utf-8")
    os.remove(downloaded_video)
    message = [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": prompt,
                },
                {
                    "type": "media",
                    "data": encoded_video,  # base64 string passed inline
                    "mime_type": video_mime_type,
                },
            ]
        }
    ]
    video_node_response = [video_handler_model.invoke(
        input=message,
        # config={
        #     "callbacks": [langfuse_handler]
        # }
    )]
    video_node_response[-1].pretty_print()
    return {
        "video_node_result": video_node_response
    }
def audio_node(state: State) -> dict:
    """
    Audio handler model that can analyze audio and answer questions about it.
    This node does not handle images or pictures.
    This node does not handle video.
    This node does not handle code.
    Args:
        state (State): A dictionary containing the current state of the agent, including the 'question' key which holds the question to be answered.
    Returns:
        dict: A dictionary containing the response from the audio handler node, with the key 'audio_node_result' holding the list of messages generated by the audio handler model.
    """
    prompt = f"""
    You are a highly capable audio analysis assistant. Your task is to listen to and analyze the provided audio content and answer the user's question as accurately and concisely as possible.
    1. You need to fully understand the question.
    2. Carefully listen to the audio, paying attention to relevant details, actions, and context.
    3. Focus on the user's question.
    4. If the question requires counting, identifying, or describing, be precise and clear in your response.
    5. If you are unsure, state what you can infer from the audio.
    6. Do not make up information that is not audible or inferable from the audio.
    Here is the question: {state['question']}
    Now provide your response immediately, without any preamble, in plain text (not markdown).
    """
    downloaded_audio = state["downloaded_file"]
    print(f"Downloaded audio: {downloaded_audio}")
    audio_format = re.search(r'\.(\w+)$', downloaded_audio).group(1)
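    # Note: the OpenAI audio-input payload expects a short format name such as
    # "wav" or "mp3", which is why the file extension is used directly here.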
    with open(downloaded_audio, "rb") as audio_file:
        encoded_audio = base64.b64encode(audio_file.read()).decode()
    os.remove(downloaded_audio)
    message = [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": prompt,
                },
                {
                    "type": "input_audio",
                    "input_audio": {
                        "data": encoded_audio,
                        "format": audio_format,
                    }
                },
            ]
        }
    ]
    audio_node_response = [audio_handler_model.invoke(
        input=message,
        # config={
        #     "callbacks": [langfuse_handler]
        # }
    )]
    audio_node_response[-1].pretty_print()
    return {
        "audio_node_result": audio_node_response
    }
def excel_node(state: State) -> dict:
    """
    Excel handler model that can analyze Excel files and answer questions about them.
    This node does not handle images or pictures.
    This node does not handle video.
    This node does not handle code.
    This node does not handle audio.
    Args:
        state (State): A dictionary containing the current state of the agent, including the 'question' key which holds the question to be answered.
    Returns:
        dict: A dictionary containing the response from the excel handler node, with the key 'excel_node_result' holding the list of messages generated by the excel handler model.
    """
    loader = UnstructuredExcelLoader(state["downloaded_file"], mode="elements")
    docs = loader.load()
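    # With mode="elements", UnstructuredExcelLoader emits one Document per
    # detected element and keeps an HTML rendering of each table in
    # doc.metadata["text_as_html"], which is what the prompt below points at.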
    prompt = f"""
    You are a powerful assistant that handles Excel files.
    1. You need to fully understand the question.
    2. You must analyze the Excel file to answer the question.
    3. If the question requires counting, identifying, or describing, be precise and clear in your response.
    4. Do not make up information that is not in the Excel file.
    Here is the question: {state['question']}
    Here is the Excel file loaded in a Document object: {docs}. You will find the HTML content of the file in the 'text_as_html' key.
    Now provide your response immediately, without any preamble, in plain text (not markdown).
    """
    response = big_model.invoke(
        input=prompt,
        # config={
        #     "callbacks": [langfuse_handler]
        # }
    )
    response.pretty_print()
    return {
        "excel_node_result": response
    }
def format_answer_node(state: State) -> dict:
    """
    Format answer node that formats the answer of the last node.
    This node does not handle images or pictures.
    This node does not handle video.
    This node does not handle audio.
    This node does not handle code.
    Args:
        state (State): A dictionary containing the 'question' key and the results of all other nodes.
    Returns:
        dict: A dictionary containing the formatted final answer under the 'messages' key.
    """
    prompt = """
    You are the best assistant for final answer formatting.
    1. You must not change the content of the response of the last node.
    2. You must fully understand the question
    3. You must return the answer by strictly following the format and the constraints
    4. Report your thought process in detail, explaining your reasoning step-by-step.
    5. Conclude your answer with the following template:
    FINAL ANSWER: [YOUR FINAL ANSWER]
    ## Response Format
    - If asked for a number:
        For example 'How many' or a question asking for a number result
        - Provide the number without commas, dollar signs, percent signs, or any units (unless specified).
        - Provide digits, not words
    - If asked for a string:
        - Write the string without articles (a, an, the).
        - Don't answer with a full sentence when a short version is enough.
        - Do not use abbreviations (e.g., for cities).
        - Write digits as words (e.g., "one" instead of "1") unless specified otherwise.
        - Start the first word with a capital letter.
    - If asked for a comma-separated list:
        - Apply the above rules for numbers and strings to each element in the list.
        - And take care of having a space after each comma.
    ## Constraints
    - You must not answer if the constraints above are not respected.
    - Your final answer should be provided in the format: FINAL ANSWER: [YOUR FINAL ANSWER]
    - Your final answer should be a number, a string, or a comma-separated list of numbers and/or strings, following the specified formatting rules.
    Now provide your response immediately, without any preamble, in plain text (not markdown).
    """
    nodes_response = [HumanMessage(content="Here are the results of the previous nodes")]
    question = [HumanMessage(content=state["question"])]
    for node_result in ["web_search_node_result", "vision_node_result", "video_node_result", "audio_node_result", "thinking_node_result", "code_node_result", "excel_node_result"]:
        result = state.get(node_result, "")
        if result:
            # Nodes return either a single message or a list of messages;
            # normalize everything to a plain string.
            if isinstance(result, list):
                content = "\n".join(getattr(m, "content", str(m)) for m in result)
            elif hasattr(result, "content"):
                content = result.content
            else:
                content = str(result)
            nodes_response.append(HumanMessage(content=content))
    sys_msg = SystemMessage(content=prompt)
    response = [general_model.invoke([sys_msg] + state["messages"] + question + nodes_response)]
    return {
        "messages": response,
    }
########################

######## Entry Node ########
def entry_node(state: State) -> dict:
    # System message
    system_prompt = f"""
    You are a powerful assistant that handles the user message and manages other nodes in order to provide the best answer to the question.
    You do not handle images or pictures
    You do not handle videos
    You do not handle audio
    You do not handle code
    You do not handle excel files
    1. You need to fully understand the subject of the question
    2. You need to understand the subject of the question with the question itself and the file extension
    Examples of extensions:
    - .py is for code
    - .wav or .mp3 is for audio
    - a youtube url is for video
    - a .jpg, .png, .jpeg is for image
    - a .xlsx or .xls is for excel
    3. You must think hard about what is relevant in the question to make the best choice for the next node
    4. You must not answer the question by yourself
    5. Report your thought process in detail, explaining your reasoning step-by-step.
    Here are the nodes you can choose:
    - thinking_node: {thinking_node.__doc__}
    - web_search_node: {web_search_node.__doc__}
    - vision_node: {vision_node.__doc__}
    - video_node: {video_node.__doc__}
    - audio_node: {audio_node.__doc__}
    - code_node: {code_node.__doc__}
    - excel_node: {excel_node.__doc__}
    Here is the question: {state['question']}
    Here is the file: {state.get("input_file", "no file to handle")}
    Now provide your response immediately.
    You must always respect this format in lower case: next node <the node name you choose>.
    """
    downloaded = ""
    # If there's an input file, download it directly
    if state.get("input_file", None):
        downloaded = download_input_file(state.get("task_id"))
    sys_msg = SystemMessage(content=system_prompt)
    entry_node_response = [general_model.invoke([sys_msg] + state["messages"])]
    entry_node_response[-1].pretty_print()
    regex_result = re.search(r'.*next.*(?P<next_node>thinking_node|web_search_node|vision_node|video_node|audio_node|code_node|excel_node)', entry_node_response[-1].content, re.IGNORECASE)
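    # The loose pattern above tolerates variations in the model's reply; e.g.
    # both "next node excel_node" and "Next node: excel_node" (hypothetical
    # outputs) resolve to the same route thanks to re.IGNORECASE.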
    next_node = "END"
    if regex_result:
        # Extract the node name and normalize it to lower case
        next_node = regex_result.group("next_node")
        next_node = next_node.lower()
    print(f"Next node to invoke: {next_node}")
    return {
        "next_node": next_node,
        "downloaded_file": downloaded
    }
########################

######## Build Graph ########
def build_web_search_graph():
    builder = StateGraph(State)
    builder.add_node("web_search_node", web_search_node)
    builder.add_node("tools", ToolNode(search_tools))
    builder.add_edge(START, "web_search_node")
    builder.add_conditional_edges(
        "web_search_node",
        tools_condition,
    )
    builder.add_edge("tools", "web_search_node")
    builder.add_edge("web_search_node", END)
    return builder.compile()
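
# tools_condition (used above) routes to the "tools" node whenever the last
# message in state contains tool calls, and to END otherwise, so the subgraph
# loops model -> tools -> model until the model stops requesting tools.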
def build_graph():
    builder = StateGraph(State)
    builder.add_node("entry_node", entry_node)
    builder.add_node("web_search_node", build_web_search_graph())
    builder.add_node("vision_node", vision_node)
    builder.add_node("video_node", video_node)
    builder.add_node("audio_node", audio_node)
    builder.add_node("code_node", code_node)
    builder.add_node("thinking_node", thinking_node)
    builder.add_node("excel_node", excel_node)
    builder.add_node("format_answer_node", format_answer_node)
    builder.add_edge(START, "entry_node")
    # Conditional routing from entry_node to specialized nodes
    builder.add_conditional_edges(
        "entry_node",
        lambda state: state["next_node"],
        {
            "web_search_node": "web_search_node",
            "vision_node": "vision_node",
            "video_node": "video_node",
            "audio_node": "audio_node",
            "code_node": "code_node",
            "excel_node": "excel_node",
            "thinking_node": "thinking_node"
        }
    )
    # After a specialized node, route to the answer formatter
    builder.add_edge("web_search_node", "format_answer_node")
    builder.add_edge("vision_node", "format_answer_node")
    builder.add_edge("video_node", "format_answer_node")
    builder.add_edge("audio_node", "format_answer_node")
    builder.add_edge("code_node", "format_answer_node")
    builder.add_edge("excel_node", "format_answer_node")
    builder.add_edge("thinking_node", "format_answer_node")
    builder.add_edge("format_answer_node", END)
    return builder.compile()
########################
if __name__ == "__main__":
    agent_graph = build_graph()
    # Save the Mermaid diagram as text instead of trying to render it as a PNG;
    # this avoids issues with Pyppeteer browser launching.
    # with open("graph.png", "wb") as f:
    #     f.write(agent_graph.get_graph(xray=True).draw_mermaid_png())
    # print("Graph saved as graph.png")
    # print(vision_node.__doc__)
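    # A text-based sketch of that idea (draw_mermaid() returns the Mermaid
    # source as a string, no browser needed):
    # with open("graph.mmd", "w") as f:
    #     f.write(agent_graph.get_graph(xray=True).draw_mermaid())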
with open("./responses.json", "r") as responses: | |
json_responses = json.loads(responses.read()) | |
# json_questions = [{ | |
# "question": "The attached Excel file contains the sales of menu items for a local fast-food chain. What were the total sales that the chain made from food (not including drinks)? Express your answer in USD with two decimal places.", | |
# "file_name": "7bd855d8-463d-4ed5-93ca-5fe35145f733.xlsx", | |
# "task_id": "7bd855d8-463d-4ed5-93ca-5fe35145f733" | |
# }] | |
with open("questions.json", "r") as questions: | |
json_questions = json.loads(questions.read()) | |
for input in json_questions: | |
question = input.get("question", "No question found") | |
file_name = input.get("file_name", "") | |
task_id = input.get("task_id", "") | |
print(f"QUESTION : {question}") | |
print(f"FILE: {file_name}") | |
user_prompt = [HumanMessage(content="Can you answer the question please ?")] | |
user_input = {"messages": user_prompt, "question": question, "input_file": file_name, "task_id": task_id} | |
messages = agent_graph.invoke( | |
input=user_input, | |
config={ | |
"recursion_limit": 10, | |
# "callbacks": [langfuse_handler] | |
} | |
) | |
for m in messages['messages']: | |
m.pretty_print() | |
try: | |
regex_result = re.search(r"FINAL ANSWER:\s*(?P<answer>.*)$", messages['messages'][-1].content) | |
answer = regex_result.group("answer") | |
except: | |
regex_result = re.search(r"\s*(?P<answer>.*)$", messages['messages'][-1].content) | |
answer = regex_result.group("answer") | |
print(answer) | |
if answer == json_responses.get(task_id, ""): | |
print("The answer is correct !") | |
else: | |
print("The answer is incorrect !") | |
print(f"Expected: {json_responses.get(task_id, '')}") | |
print(f"Got: {answer}") | |