# agent.py
import os
import re
import json
import time
import uuid
import operator
from functools import reduce
from io import StringIO
from typing import List, Union

import numpy as np
import pandas as pd
import requests
import yaml
import openpyxl
import torch
import torch.nn.functional as F
from docx import Document as DocxDocument
from dotenv import load_dotenv
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig
from transformers import BertTokenizer, BertModel
from huggingface_hub import InferenceClient, login
from youtube_transcript_api import YouTubeTranscriptApi
from youtube_transcript_api._errors import TranscriptsDisabled, VideoUnavailable

from langgraph.graph import START, END, StateGraph, MessagesState
from langgraph.prebuilt import ToolNode, tools_condition
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader, TextLoader, PyMuPDFLoader
from langchain_community.utilities import WikipediaAPIWrapper
from langchain_community.vectorstores import FAISS
from langchain_community.chat_models import ChatOpenAI
from langchain_community.tools import Tool
from langchain_community.llms import HuggingFaceHub
from langchain_core.documents import Document
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.tools import tool
from langchain.tools.retriever import create_retriever_tool
from langchain.embeddings.base import Embeddings
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from supabase.client import Client, create_client
#from langchain.agents import initialize_agent
#from langchain.agents import AgentType

load_dotenv()
def calculator(inputs: Union[str, dict]):
    """
    Perform mathematical operations based on the operation provided.
    Supports both binary (a, b) operations and list operations.
    """
    # If input is a JSON string, parse it
    if isinstance(inputs, str):
        try:
            inputs = json.loads(inputs)
        except Exception as e:
            return f"Invalid input format: {e}"
    # Handle list-based operations like SUM
    if "list" in inputs:
        nums = inputs.get("list", [])
        op = inputs.get("operation", "").lower()
        if not isinstance(nums, list) or not all(isinstance(n, (int, float)) for n in nums):
            return "Invalid list input. Must be a list of numbers."
        if op == "sum":
            return sum(nums)
        elif op == "multiply":
            return reduce(operator.mul, nums, 1)
        else:
            return f"Unsupported list operation: {op}"
    # Handle basic two-number operations
    a = inputs.get("a")
    b = inputs.get("b")
    operation = inputs.get("operation", "").lower()
    if a is None or b is None or not isinstance(a, (int, float)) or not isinstance(b, (int, float)):
        return "Both 'a' and 'b' must be numbers."
    if operation == "add":
        return a + b
    elif operation == "subtract":
        return a - b
    elif operation == "multiply":
        return a * b
    elif operation == "divide":
        if b == 0:
            return "Error: Division by zero"
        return a / b
    elif operation == "modulus":
        return a % b
    else:
        return f"Unknown operation: {operation}"
def wiki_search(query: str) -> str:
    """Search Wikipedia for a query and return up to 2 results."""
    search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
    formatted_search_docs = "\n\n---\n\n".join(
        [
            f'<Document source="{doc.metadata.get("source", "Wikipedia")}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
            for doc in search_docs
        ]
    )
    return formatted_search_docs
def wikidata_query(query: str) -> str:
    """
    Run a SPARQL query on Wikidata and return results.
    """
    endpoint_url = "https://query.wikidata.org/sparql"
    headers = {
        "Accept": "application/sparql-results+json"
    }
    response = requests.get(endpoint_url, headers=headers, params={"query": query})
    data = response.json()
    return json.dumps(data, indent=2)
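# Example (illustrative): a minimal SPARQL query counting items that are
# instances of "human" (Q5) on Wikidata.
#   wikidata_query("SELECT (COUNT(?item) AS ?count) WHERE { ?item wdt:P31 wd:Q5 }")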
def web_search(query: str) -> str:
    """Search Tavily for a query and return up to 3 results."""
    tavily_key = os.getenv("TAVILY_API_KEY")
    if not tavily_key:
        return "Error: Tavily API key not set."
    search_tool = TavilySearchResults(tavily_api_key=tavily_key, max_results=3)
    search_docs = search_tool.invoke(query)  # invoke() takes the query string directly
    formatted_search_docs = "\n\n---\n\n".join(
        [
            # Tavily returns plain dicts with "url" and "content" keys, not Documents
            f'<Document source="{doc.get("url", "")}"/>\n{doc.get("content", "")}\n</Document>'
            for doc in search_docs
        ])
    return formatted_search_docs
def arxiv_search(query: str) -> str:
    """Search Arxiv for a query and return up to 3 results.
    Args:
        query: The search query."""
    search_docs = ArxivLoader(query=query, load_max_docs=3).load()
    formatted_search_docs = "\n\n---\n\n".join(
        [
            # ArxivLoader metadata has no "source" key, so fall back to the paper title
            f'<Document source="{doc.metadata.get("source", doc.metadata.get("Title", ""))}"/>\n{doc.page_content[:1000]}\n</Document>'
            for doc in search_docs
        ])
    return formatted_search_docs
def analyze_attachment(file_path: str) -> str:
    """
    Analyzes attachments including PY, PDF, TXT, DOCX, and XLSX files and returns text content.
    Args:
        file_path: Local path to the attachment.
    """
    if not os.path.exists(file_path):
        return f"File not found: {file_path}"
    try:
        ext = file_path.lower()
        if ext.endswith(".pdf"):
            loader = PyMuPDFLoader(file_path)
            documents = loader.load()
            content = "\n\n".join([doc.page_content for doc in documents])
        elif ext.endswith(".txt") or ext.endswith(".py"):
            # Both .txt and .py are plain text files
            with open(file_path, "r", encoding="utf-8") as file:
                content = file.read()
        elif ext.endswith(".docx"):
            doc = DocxDocument(file_path)
            content = "\n".join([para.text for para in doc.paragraphs])
        elif ext.endswith(".xlsx"):
            wb = openpyxl.load_workbook(file_path, data_only=True)
            content = ""
            for sheet in wb:
                content += f"Sheet: {sheet.title}\n"
                for row in sheet.iter_rows(values_only=True):
                    content += "\t".join([str(cell) if cell is not None else "" for cell in row]) + "\n"
        else:
            return "Unsupported file format. Please use PY, PDF, TXT, DOCX, or XLSX."
        return content[:3000]  # Limit output size for readability
    except Exception as e:
        return f"An error occurred while processing the file: {str(e)}"
def get_youtube_transcript(url: str) -> str:
    """
    Fetch transcript text from a YouTube video.
    Args:
        url (str): Full YouTube video URL.
    Returns:
        str: Transcript text as a single string.
    Raises:
        ValueError: If no transcript is available or URL is invalid.
    """
    try:
        # Extract video ID
        video_id = extract_video_id(url)
        transcript = YouTubeTranscriptApi.get_transcript(video_id)
        # Combine all transcript text
        full_text = " ".join([entry['text'] for entry in transcript])
        return full_text
    except (TranscriptsDisabled, VideoUnavailable) as e:
        raise ValueError(f"Transcript not available: {e}")
    except Exception as e:
        raise ValueError(f"Failed to fetch transcript: {e}")

def extract_video_id(url: str) -> str:
    """
    Extract the video ID from a YouTube URL.
    """
    match = re.search(r"(?:v=|youtu\.be/)([A-Za-z0-9_-]{11})", url)
    if not match:
        raise ValueError("Invalid YouTube URL")
    return match.group(1)
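# Example (illustrative): both common URL shapes yield the same 11-character ID.
#   extract_video_id("https://www.youtube.com/watch?v=dQw4w9WgXcQ")  -> "dQw4w9WgXcQ"
#   extract_video_id("https://youtu.be/dQw4w9WgXcQ")                 -> "dQw4w9WgXcQ"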
# -----------------------------
# Load configuration from YAML
# -----------------------------
with open("config.yaml", "r") as f:
    config = yaml.safe_load(f)

provider = config["provider"]
model_config = config["models"][provider]
#prompt_path = config["system_prompt_path"]
enabled_tool_names = config["tools"]
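# A config.yaml matching the keys read above might look like this (assumed
# shape, not shipped with this file):
#
#   provider: huggingface
#   models:
#     huggingface:
#       url: "https://api-inference.huggingface.co/models/HuggingFaceH4/zephyr-7b-beta"
#       temperature: 0.7
#   tools:
#     - math
#     - wiki_search
#     - web_search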
# -----------------------------
# Load system prompt
# -----------------------------
with open("system_prompt.txt", "r", encoding="utf-8") as f:
    system_prompt = f.read()

# System message
sys_msg = SystemMessage(content=system_prompt)
# -----------------------------
# Map tool names to functions
# -----------------------------
tool_map = {
    "math": calculator,
    "wiki_search": wiki_search,
    "web_search": web_search,
    "arxiv_search": arxiv_search,
    "get_youtube_transcript": get_youtube_transcript,
    "extract_video_id": extract_video_id,
    "analyze_attachment": analyze_attachment,
    "wikidata_query": wikidata_query
}

# Build the enabled tool list from the names read out of config.yaml,
# skipping any name that has no matching function
tools = []
for name in enabled_tool_names:
    if name not in tool_map:
        print(f"❌ Tool not found: {name}")
        continue
    tools.append(tool_map[name])
# -------------------------------
# Step 2: Load the JSON file or tasks (Replace this part if you're loading tasks dynamically)
# -------------------------------
from fastapi import FastAPI, Request

app = FastAPI()

@app.post("/questions")  # route path is an assumption; the original function had no route decorator
async def start_questions(request: Request):
    data = await request.json()
    questions = data.get("questions", [])
    docs = []
    for task in questions:
        question_text = task.get("question", "").strip()
        if not question_text:
            continue
        task["id"] = str(uuid.uuid4())
        docs.append(Document(page_content=question_text, metadata=task))
    return {"message": f"Loaded {len(docs)} questions", "docs": [doc.page_content for doc in docs]}
# -------------------------------
# Step 4: Set up BERT Embeddings and FAISS VectorStore
# -------------------------------
# -----------------------------
# 1. Define Custom BERT Embedding Model
# -----------------------------
class BERTEmbeddings(Embeddings):
    def __init__(self, model_name='bert-base-uncased'):
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        self.model = BertModel.from_pretrained(model_name)
        self.model.eval()  # Set model to eval mode

    def embed_documents(self, texts):
        inputs = self.tokenizer(texts, return_tensors='pt', padding=True, truncation=True)
        with torch.no_grad():
            outputs = self.model(**inputs)
        # Mean-pool the token embeddings into one vector per text
        embeddings = outputs.last_hidden_state.mean(dim=1)
        embeddings = F.normalize(embeddings, p=2, dim=1)  # Normalize for cosine similarity
        return embeddings.cpu().numpy()

    def embed_query(self, text):
        return self.embed_documents([text])[0]
# -----------------------------
# 2. Initialize Embedding Model
# -----------------------------
embedding_model = BERTEmbeddings()
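# Because embed_documents() L2-normalizes its output, cosine similarity between
# two embeddings reduces to a plain dot product (illustrative):
#   emb = embedding_model.embed_documents(["folk music", "tango"])
#   similarity = float(np.dot(emb[0], emb[1]))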
# -----------------------------
# 3. Prepare Documents
# -----------------------------
docs = [
    Document(page_content="Mercedes Sosa released many albums between 2000 and 2009.", metadata={"id": 1}),
    Document(page_content="She was a prominent Argentine folk singer.", metadata={"id": 2}),
    Document(page_content="Her album 'Al Despertar' was released in 1998.", metadata={"id": 3}),
    Document(page_content="She continued releasing music well into the 2000s.", metadata={"id": 4}),
]

# -----------------------------
# 4. Create FAISS Vector Store
# -----------------------------
vector_store = FAISS.from_documents(docs, embedding_model)
vector_store.save_local("faiss_index")
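# The saved index can be reloaded later without re-embedding (illustrative;
# recent langchain versions require opting in to pickle deserialization):
#   vector_store = FAISS.load_local(
#       "faiss_index", embedding_model, allow_dangerous_deserialization=True
#   )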
# -----------------------------
# 5. Create LangChain Retriever Tool
# -----------------------------
retriever = vector_store.as_retriever()

question_retriever_tool = create_retriever_tool(
    retriever=retriever,
    name="Question_Search",
    description="A tool to retrieve documents related to a user's question."
)

# Define the LLM before using it
#llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo")  # or "gpt-4"
#llm = ChatMistralAI(model="mistral-7b-instruct-v0.1")
llm = HuggingFaceEndpoint(
    repo_id="HuggingFaceH4/zephyr-7b-beta",
    task="text-generation",
    huggingfacehub_api_token=os.getenv("HF_TOKEN"),  # Hugging Face API token from the environment
    temperature=0.7,
    max_new_tokens=512
)
# No longer required, since LangGraph replaces the LangChain agent here
#agent = initialize_agent(
#    tools=tools,
#    llm=llm,
#    agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
#    verbose=True
#)
# -------------------------------
# Step 8: Use the Planner, Classifier, and Decision Logic
# -------------------------------
# NOTE: this name is redefined by the LangGraph-based process_question further
# below, which is the version that actually runs.
def process_question(question):
    # Step 1: Planner generates the task sequence
    tasks = planner(question, tools)
    print(f"Tasks to perform: {tasks}")
    # Step 2: Classify the task (based on question)
    task_type = task_classifier(question)
    print(f"Task type: {task_type}")
    # Step 3: Use the classifier and planner to decide on the next task or node
    state = {"question": question, "last_response": ""}
    next_task = decide_task(state)
    print(f"Next task: {next_task}")
    # Step 4: Use node skipper logic (skip if needed)
    skip = node_skipper(state)
    if skip:
        print(f"Skipping to {skip}")
        return skip  # Or move directly to generating answer
    # Step 5: Execute task (with error handling)
    try:
        if task_type == "wiki_search":
            response = wiki_search(question)
        elif task_type == "math":
            response = calculator(question)
        else:
            response = "Default answer logic"
        # Step 6: Final response formatting
        final_response = generate_final_answer(state, {task_type: response})
        return final_response
    except Exception as e:
        print(f"Error executing task: {e}")
        return "Sorry, I encountered an error processing your request."
# Run the process
#question = "How many albums did Mercedes Sosa release between 2000 and 2009?"
#response = agent.invoke(question)
#print("Final Response:", response)

def retriever_node(state: MessagesState):
    """Retriever node using similarity scores for filtering
    (named retriever_node so it does not shadow the retriever defined above)"""
    query = state["messages"][0].content
    results = vector_store.similarity_search_with_score(query, k=4)  # top 4 matches
    # Dynamically adjust threshold based on query complexity
    # (FAISS returns L2 distances, so lower scores are better matches)
    threshold = 0.75 if "who" in query else 0.8
    filtered = [doc for doc, score in results if score < threshold]
    # Provide a default message if no documents found
    if not filtered:
        example_msg = HumanMessage(content="No relevant documents found.")
    else:
        content = "\n\n".join(doc.page_content for doc in filtered)
        example_msg = HumanMessage(
            content=f"Here are relevant reference documents:\n\n{content}"
        )
    return {"messages": [sys_msg] + state["messages"] + [example_msg]}
# ----------------------------------------------------------------
# LLM Loader
# ----------------------------------------------------------------
def get_llm(provider: str, config: dict):
    if provider == "google":
        from langchain_google_genai import ChatGoogleGenerativeAI
        return ChatGoogleGenerativeAI(model=config["model"], temperature=config["temperature"])
    elif provider == "groq":
        from langchain_groq import ChatGroq
        return ChatGroq(model=config["model"], temperature=config["temperature"])
    elif provider == "huggingface":
        from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
        # HuggingFaceEndpoint takes endpoint_url, not url
        return ChatHuggingFace(
            llm=HuggingFaceEndpoint(endpoint_url=config["url"], temperature=config["temperature"])
        )
    else:
        raise ValueError(f"Invalid provider: {provider}")
# ----------------------------------------------------------------
# Planning & Execution Logic
# ----------------------------------------------------------------
def planner(question: str, tools: list) -> tuple:
    """
    Select the best-matching tool(s) for a question based on keyword-based intent detection and tool metadata.
    Returns the detected intent and matched tools.
    """
    question = question.lower().strip()
    # Define intent-based keywords
    intent_keywords = {
        "math": ["calculate", "evaluate", "add", "subtract", "multiply", "divide", "modulus", "plus", "minus", "times"],
        "wiki_search": ["who is", "what is", "define", "explain", "tell me about", "overview of"],
        "web_search": ["search", "find", "look up", "google", "latest news", "current info"],
        "arxiv_search": ["arxiv", "research paper", "scientific paper", "preprint"],
        "get_youtube_transcript": ["youtube", "watch", "play video", "show me a video"],
        "extract_video_id": ["analyze video", "summarize video", "video content"],
        "data_analysis": ["analyze", "plot", "graph", "data", "visualize"],
        "wikidata_query": ["wikidata", "sparql", "run sparql", "query wikidata"],
        "default": ["why", "how", "difference between", "compare", "what happens", "reason for", "cause of", "effect of"]
    }
    # Step 1: Identify intent
    detected_intent = None
    for intent, keywords in intent_keywords.items():
        if any(keyword in question for keyword in keywords):
            detected_intent = intent
            break
    # Step 2: Match tools by intent
    matched_tools = []
    if detected_intent:
        for tool in tools:
            name = getattr(tool, "name", "").lower()
            description = getattr(tool, "description", "").lower()
            if detected_intent in name or detected_intent in description:
                matched_tools.append(tool)
    # Step 3: Fallback to general-purpose/default tools if no match found
    if not matched_tools:
        matched_tools = [
            tool for tool in tools
            if "default" in getattr(tool, "name", "").lower()
            or "qa" in getattr(tool, "description", "").lower()
        ]
    return detected_intent, matched_tools if matched_tools else [tools[0]]
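# Example (illustrative), assuming `tools` is the enabled tool list built above:
#   intent, selected = planner("calculate 12 plus 5", tools)
#   -> intent == "math"; `selected` holds the tool(s) whose name or description
#      mentions the detected intent, else the first tool as a fallback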
def task_classifier(question: str) -> str:
    """
    Classifies the question into one of the predefined task categories.
    """
    question = question.lower().strip()
    # Context-aware intent patterns
    if any(phrase in question for phrase in [
        "calculate", "how much is", "what is the result of", "evaluate", "solve"
    ]) or any(op in question for op in ["add", "subtract", "multiply", "divide", "modulus", "plus", "minus", "times"]):
        return "math"
    elif any(phrase in question for phrase in [
        "who is", "what is", "define", "explain", "tell me about", "give me an overview of"
    ]):
        return "wiki_search"
    elif any(phrase in question for phrase in [
        "search", "find", "look up", "google", "get the latest", "current news", "trending"
    ]):
        return "web_search"
    elif any(phrase in question for phrase in [
        "arxiv", "latest research", "scientific paper", "research paper", "preprint"
    ]):
        return "arxiv_search"
    elif any(phrase in question for phrase in [
        "youtube", "watch", "play the video", "show me a video"
    ]):
        return "get_youtube_transcript"
    elif any(phrase in question for phrase in [
        "analyze video", "summarize video", "what happens in the video", "video content"
    ]):
        return "video_analysis"
    elif any(phrase in question for phrase in [
        "analyze", "visualize", "plot", "graph", "inspect data", "explore dataset"
    ]):
        return "data_analysis"
    elif any(phrase in question for phrase in [
        "sparql", "wikidata", "query wikidata", "run sparql", "wikidata query"
    ]):
        return "wikidata_query"
    return "default"
def select_tool_and_run(question: str, tools: dict):
    # Step 1: Classify intent
    intent = task_classifier(question)
    # Map intent to tool names (the keys of the tools dict, e.g. tool_map)
    intent_tool_map = {
        "math": "math",  # → calculator
        "wiki_search": "wiki_search",
        "web_search": "web_search",
        "arxiv_search": "arxiv_search",
        "get_youtube_transcript": "get_youtube_transcript",
        "extract_video_id": "extract_video_id",
        "analyze_attachment": "analyze_attachment",
        "wikidata_query": "wikidata_query",
        "default": "default"
    }
    # Get the corresponding tool name, defaulting to "default" if no match
    tool_name = intent_tool_map.get(intent, "default")
    # Retrieve the tool from the tools dictionary
    tool_func = tools.get(tool_name)
    if not tool_func:
        return f"Tool not found for intent '{intent}'"
    # Step 2: Run the tool
    try:
        # If the input is JSON, pass the parsed structure; otherwise pass the raw string
        try:
            parsed_input = json.loads(question)
        except json.JSONDecodeError:
            parsed_input = question  # fallback to raw input if not JSON
        print(f"Running tool: {tool_name} with input: {parsed_input}")  # log the tool name and input
        return tool_func(parsed_input)
    except Exception as e:
        return f"Error while running tool '{tool_name}': {str(e)}"
# Function to extract a math operation from the question
def extract_math_from_question(question: str):
    question = question.lower()
    # Map natural language to symbols
    ops = {
        "add": "+", "plus": "+",
        "subtract": "-", "minus": "-",
        "multiply": "*", "times": "*",
        "divide": "/", "divided by": "/",
        "modulus": "%", "mod": "%"
    }
    for word, symbol in ops.items():
        question = re.sub(rf"\b{word}\b", symbol, question)
    # Extract a math expression like "12 + 5"
    match = re.search(r'(\d+)\s*([\+\-\*/%])\s*(\d+)', question)
    if match:
        num1 = int(match.group(1))
        op_symbol = match.group(2)  # renamed so it doesn't shadow the operator module
        num2 = int(match.group(3))
        return {
            "a": num1,
            "b": num2,
            "operation": {
                "+": "add",
                "-": "subtract",
                "*": "multiply",
                "/": "divide",
                "%": "modulus"
            }[op_symbol]
        }
    return None
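# Example (illustrative):
#   extract_math_from_question("what is 12 plus 5")
#   -> {"a": 12, "b": 5, "operation": "add"}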
# Example tool set (adjust these to match your actual tool names)
intent_tool_map = {
    "math": "math",  # → calculator
    "wiki_search": "wiki_search",
    "web_search": "web_search",
    "arxiv_search": "arxiv_search",
    "get_youtube_transcript": "get_youtube_transcript",
    "extract_video_id": "extract_video_id",
    "analyze_attachment": "analyze_attachment",
    "wikidata_query": "wikidata_query",
    "default": "default"  # fallback
}

# The task order can also include the tools for each task
priority_order = [
    {"task": "math", "tool": "math"},
    {"task": "wiki_search", "tool": "wiki_search"},
    {"task": "web_search", "tool": "web_search"},
    {"task": "arxiv_search", "tool": "arxiv_search"},
    {"task": "wikidata_query", "tool": "wikidata_query"},
    {"task": "retriever", "tool": "retriever"},
    {"task": "get_youtube_transcript", "tool": "get_youtube_transcript"},
    {"task": "extract_video_id", "tool": "extract_video_id"},
    {"task": "analyze_attachment", "tool": "analyze_attachment"},
    {"task": "default", "tool": "default"}  # Fallback
]
def decide_task(state: dict) -> tuple:
    """Decides which task to perform based on the current state.
    Returns a (task_name, tool) pair."""
    # The planner returns the detected intent plus any matching tool functions
    intent, matched_tools = planner(state["question"], tools)
    tasks = [intent] if intent else []
    print(f"Available tasks: {tasks}")  # Debugging: show all possible tasks
    # Check if the tasks list is empty or invalid
    if not tasks:
        print("❌ No valid tasks were returned from the planner.")
        tasks = ["default"]  # Fall back to the default task
    # If there are multiple tasks, prioritize based on priority_order
    if len(tasks) > 1:
        print("⚠️ Multiple tasks found. Deciding based on priority.")
    task, tool = prioritize_tasks(tasks)
    print(f"Decided on task: {task}")  # Debugging: show the final task
    return task, tool

def prioritize_tasks(tasks: list) -> tuple:
    """Prioritize task names based on the priority_order mapping, including tools."""
    # Sort tasks based on priority_order mapping
    for priority in priority_order:
        # Check if any task matches the priority task type
        for task in tasks:
            if priority["task"] in task:
                print(f"✅ Prioritizing task: {task} with tool: {priority['tool']}")
                # Look the tool up by name; tool_map has no "default" or "retriever"
                # entries, so wiki_search serves as the fallback (assumption)
                tool = tool_map.get(priority["tool"], wiki_search)
                return task, tool
    # If no priority task is found, return the first task with the fallback tool
    return tasks[0], wiki_search
def process_question(question: str):
    """Process the question and route it to the appropriate tool."""
    # Get the detected intent and candidate tools from the planner
    tasks = planner(question, tools)
    print(f"Tasks to perform: {tasks}")
    task_type, tool_fn = decide_task({"question": question})
    print(f"Next task: {task_type} with tool: {tool_fn}")
    if node_skipper({"question": question}):
        print(f"Skipping task: {task_type}")
        return "Task skipped."
    try:
        # The selected tools are plain functions, so call them directly
        # (the branches for wiki_search, math, and retriever were identical)
        response = tool_fn(question)
        return generate_final_answer({"question": question}, {task_type: response})
    except Exception as e:
        print(f"❌ Error: {e}")
        return f"Sorry, I encountered an error: {str(e)}"
# To store previously asked questions and timestamps (simulating state persistence)
recent_questions = {}

def node_skipper(state: dict) -> bool:
    """
    Determines whether to skip the task based on the state.
    This could include:
    1. Repeated or similar questions
    2. Irrelevant or empty questions
    3. Tasks that have already been processed recently
    """
    question = state.get("question", "").strip()
    if not question:
        print("❌ Skipping: Empty or invalid question.")
        return True  # Skip if no valid question
    # 1. Skip if the question has already been asked recently (within a given time window)
    # Here, we're using a simple example with a 5-minute window (300 seconds).
    if question in recent_questions:
        last_asked_time = recent_questions[question]
        time_since_last_ask = time.time() - last_asked_time
        if time_since_last_ask < 300:  # 5-minute threshold
            print(f"❌ Skipping: The question has been asked recently. Time since last ask: {time_since_last_ask:.2f} seconds.")
            return True  # Skip if the question was asked within the last 5 minutes
    # 2. Skip if the question is irrelevant or not meaningful enough
    irrelevant_keywords = ["blah", "nothing", "invalid", "nonsense"]
    if any(keyword in question.lower() for keyword in irrelevant_keywords):
        print("❌ Skipping: Irrelevant or nonsense question.")
        return True  # Skip if the question contains irrelevant keywords
    # 3. Skip if the task has already been completed for this question (based on a unique task identifier)
    if "last_response" in state and state["last_response"]:
        print("❌ Skipping: Task has already been processed recently.")
        return True  # Skip if a response has already been given
    # 4. Skip based on a condition related to the task itself
    # Example: Skip math-related tasks if the result is already known or trivial
    if "math" in state.get("question", "").lower():
        # If math is trivial (like "What is 2+2?")
        trivial_math = ["2 + 2", "1 + 1", "3 + 3"]
        if any(trivial_question in question for trivial_question in trivial_math):
            print(f"❌ Skipping trivial math question: {question}")
            return True  # Skip if the math question is trivial
    # 5. Skip based on external factors (e.g., current time, system load, etc.)
    # Example: Avoid processing tasks at night if that's part of the business logic
    current_hour = time.localtime().tm_hour
    if current_hour >= 22 or current_hour < 6:
        print("❌ Skipping: It's night time, not processing tasks.")
        return True  # Skip tasks during night time (e.g., between 10 PM and 6 AM)
    # If none of the conditions matched, don't skip the task
    return False
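# Example outcomes (illustrative; assumes the call happens outside the
# 10 PM–6 AM window, where everything is skipped):
#   node_skipper({"question": ""})                       -> True   (empty question)
#   node_skipper({"question": "math: what is 2 + 2?"})   -> True   (trivial math)
#   node_skipper({"question": "Who was Mercedes Sosa?"}) -> False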
# Update recent questions (for simulating the repeated-question check)
def update_recent_questions(question: str):
    """Update the recent questions dictionary with the current timestamp."""
    recent_questions[question] = time.time()

def generate_final_answer(state: dict, task_results: dict) -> str:
    """Generate a final answer based on the results of the task."""
    if "wiki_search" in task_results:
        return f"📚 Wiki Summary:\n{task_results['wiki_search']}"
    elif "math" in task_results:
        return f"🧮 Math Result: {task_results['math']}"
    elif "retriever" in task_results:
        return f"🔍 Retrieved Info: {task_results['retriever']}"
    else:
        return "🤖 Unable to generate a specific answer."
def answer_question(question: str) -> str:
    """Process a single question and return the answer."""
    print(f"Processing question: {question[:50]}...")  # Debugging: show first 50 chars
    # Wrap the question in a HumanMessage from langchain_core
    messages = [HumanMessage(content=question)]
    response = agent.invoke({"messages": messages})  # `agent` is the compiled graph built below
    # Extract the answer from the response
    answer = response['messages'][-1].content
    return answer[14:]  # strip a 14-character prefix, e.g. "FINAL ANSWER: " (assumption)

def process_all_tasks(tasks: list):
    """Process a list of tasks."""
    results = {}
    for task in tasks:
        question = task.get("question", "").strip()
        if not question:
            print(f"Skipping task with missing or empty 'question': {task}")
            continue
        print(f"\n🟢 Processing Task: {task['task_id']} - Question: {question}")
        # Call the existing process_question logic
        response = process_question(question)
        print(f"✅ Response: {response}")
        results[task['task_id']] = response
    return results
## Langgraph
# Build graph function
# NOTE: these values are hard-coded here and override the ones read from config.yaml above
provider = "huggingface"
model_config = {
    "repo_id": "HuggingFaceH4/zephyr-7b-beta",
    "task": "text-generation",
    "temperature": 0.7,
    "max_new_tokens": 512,
    "huggingfacehub_api_token": os.getenv("HF_TOKEN")
}
def build_graph(provider, model_config):
    from langchain_core.messages import SystemMessage, HumanMessage
    from langgraph.graph import StateGraph, END
    from langgraph.prebuilt import ToolNode  # ToolNode lives in langgraph.prebuilt, not langgraph.graph
    # The module-level vector_store (FAISS index built above) is used directly
    # Step 1: Get LLM
    def get_llm(provider: str, config: dict):
        if provider == "huggingface":
            from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
            # Wrap the endpoint in ChatHuggingFace so bind_tools() is available
            return ChatHuggingFace(
                llm=HuggingFaceEndpoint(
                    repo_id=config["repo_id"],
                    task=config["task"],
                    huggingfacehub_api_token=config["huggingfacehub_api_token"],
                    temperature=config["temperature"],
                    max_new_tokens=config["max_new_tokens"]
                )
            )
        else:
            raise ValueError(f"Unsupported provider: {provider}")
    llm = get_llm(provider, model_config)
    # Step 2: Define tools
    tools = [
        wiki_search,
        calculator,
        web_search,
        arxiv_search,
        get_youtube_transcript,
        extract_video_id,
        analyze_attachment,
        wikidata_query
    ]
    # Step 3: Bind tools to LLM
    llm_with_tools = llm.bind_tools(tools)
    # Step 4: Build stateful graph logic
    sys_msg = SystemMessage(content="You are a helpful assistant.")
    def retriever(state: MessagesState):
        user_query = state["messages"][0].content
        similar_docs = vector_store.similarity_search(user_query)
        if not similar_docs:
            # Fall back to Wikipedia when the local index has nothing relevant
            wiki_result = wiki_search(user_query)
            return {
                "messages": [
                    sys_msg,
                    state["messages"][0],
                    HumanMessage(content=f"Using Wikipedia search:\n\n{wiki_result}")
                ]
            }
        else:
            return {
                "messages": [
                    sys_msg,
                    state["messages"][0],
                    HumanMessage(content=f"Reference:\n\n{similar_docs[0].page_content}")
                ]
            }
    def assistant(state: MessagesState):
        return {"messages": [llm_with_tools.invoke(state["messages"])]}
    def tools_condition(state: MessagesState) -> str:
        # Route to the tool node on an explicit cue, otherwise finish the run
        if "use tool" in state["messages"][-1].content.lower():
            return "tools"
        else:
            return END  # must be the END sentinel, not the string "END"
    # Step 5: Define LangGraph StateGraph
    builder = StateGraph(MessagesState)  # MessagesState provides the "messages" key with an append reducer
    builder.add_node("retriever", retriever)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))
    builder.set_entry_point("retriever")
    builder.add_edge("retriever", "assistant")
    builder.add_conditional_edges("assistant", tools_condition)
    builder.add_edge("tools", "assistant")
    graph = builder.compile()
    return graph

# call build_graph AFTER it's defined
agent = build_graph(provider, model_config)

# Now you can use the agent like this (example question; the original line
# referenced an undefined `question` variable):
if __name__ == "__main__":
    question = "How many albums did Mercedes Sosa release between 2000 and 2009?"
    result = agent.invoke({"messages": [HumanMessage(content=question)]})
    print(result["messages"][-1].content)