# agent.py
import os
from dotenv import load_dotenv
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import tools_condition
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader, TextLoader, PyMuPDFLoader
from langchain_community.utilities import WikipediaAPIWrapper
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.tools import tool
from langchain_core.documents import Document
from sentence_transformers import SentenceTransformer
from langchain.embeddings.base import Embeddings
from typing import List, Union, TypedDict, Annotated
import numpy as np
import yaml
import pandas as pd
import uuid
import requests
import json
import re
import time
import operator
from functools import reduce
from youtube_transcript_api import YouTubeTranscriptApi
from youtube_transcript_api._errors import TranscriptsDisabled, VideoUnavailable
from docx import Document as DocxDocument
import openpyxl
from io import StringIO
from transformers import BertTokenizer, BertModel
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig
import torch
import torch.nn.functional as F
from langchain_community.chat_models import ChatOpenAI
from langchain_community.tools import Tool
from huggingface_hub import InferenceClient, login
from langchain_community.llms import HuggingFaceHub
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
#from langchain.agents import initialize_agent
#from langchain.agents import AgentType
from langchain_community.vectorstores import FAISS
from langchain.tools.retriever import create_retriever_tool
#from langchain_community.tools import create_retriever_tool
import gradio as gr

load_dotenv()
def calculator(inputs: Union[str, dict]):
    """
    Perform mathematical operations based on the operation provided.
    Supports both binary (a, b) operations and list operations.
    """
    # If input is a JSON string, parse it
    if isinstance(inputs, str):
        try:
            inputs = json.loads(inputs)
        except Exception as e:
            return f"Invalid input format: {e}"
    if not isinstance(inputs, dict):
        return "Invalid input: expected a dict or a JSON object string."
    # Handle list-based operations like SUM
    if "list" in inputs:
        nums = inputs.get("list", [])
        op = inputs.get("operation", "").lower()
        if not isinstance(nums, list) or not all(isinstance(n, (int, float)) for n in nums):
            return "Invalid list input. Must be a list of numbers."
        if op == "sum":
            return sum(nums)
        elif op == "multiply":
            return reduce(operator.mul, nums, 1)
        else:
            return f"Unsupported list operation: {op}"
    # Handle basic two-number operations
    a = inputs.get("a")
    b = inputs.get("b")
    operation = inputs.get("operation", "").lower()
    if a is None or b is None or not isinstance(a, (int, float)) or not isinstance(b, (int, float)):
        return "Both 'a' and 'b' must be numbers."
    if operation == "add":
        return a + b
    elif operation == "subtract":
        return a - b
    elif operation == "multiply":
        return a * b
    elif operation == "divide":
        if b == 0:
            return "Error: Division by zero"
        return a / b
    elif operation == "modulus":
        return a % b
    else:
        return f"Unknown operation: {operation}"
def wiki_search(query: str) -> str:
    """Search Wikipedia for a query and return up to 2 results."""
    search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{doc.metadata.get("source", "Wikipedia")}" page="{doc.metadata.get("page", "")}">\n{doc.page_content}\n</Document>'
        for doc in search_docs
    )
    return formatted_search_docs
def wikidata_query(query: str) -> str:
    """
    Run a SPARQL query on Wikidata and return results.
    """
    endpoint_url = "https://query.wikidata.org/sparql"
    headers = {
        "Accept": "application/sparql-results+json"
    }
    response = requests.get(endpoint_url, headers=headers, params={"query": query})
    response.raise_for_status()  # Surface HTTP errors instead of failing on .json()
    data = response.json()
    return json.dumps(data, indent=2)
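# Illustrative usage (a minimal SPARQL query; wd:Q146 is "house cat", wdt:P31 is "instance of"):
#   wikidata_query("""
#       SELECT ?item ?itemLabel WHERE {
#         ?item wdt:P31 wd:Q146 .
#         SERVICE wikibase:label { bd:serviceParam wikibase:language "en". }
#       } LIMIT 3
#   """)
# returns the SPARQL result set as pretty-printed JSON.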
def web_search(query: str) -> str:
    """Search Tavily for a query and return up to 3 results."""
    tavily_key = os.getenv("TAVILY_API_KEY")
    if not tavily_key:
        return "Error: Tavily API key not set."
    search_tool = TavilySearchResults(tavily_api_key=tavily_key, max_results=3)
    # TavilySearchResults returns a list of dicts with "url" and "content" keys
    search_docs = search_tool.invoke({"query": query})
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{doc.get("url", "")}">\n{doc.get("content", "")}\n</Document>'
        for doc in search_docs
    )
    return formatted_search_docs
def arxiv_search(query: str) -> str:
    """Search arXiv for a query and return up to 3 results.
    Args:
        query: The search query."""
    search_docs = ArxivLoader(query=query, load_max_docs=3).load()
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{doc.metadata.get("source", doc.metadata.get("Title", ""))}" page="{doc.metadata.get("page", "")}">\n{doc.page_content[:1000]}\n</Document>'
        for doc in search_docs
    )
    return formatted_search_docs
def analyze_attachment(file_path: str) -> str:
    """
    Analyzes attachments including PY, PDF, TXT, DOCX, and XLSX files and returns text content.
    Args:
        file_path: Local path to the attachment.
    """
    if not os.path.exists(file_path):
        return f"File not found: {file_path}"
    try:
        ext = file_path.lower()
        if ext.endswith(".pdf"):
            loader = PyMuPDFLoader(file_path)
            documents = loader.load()
            content = "\n\n".join(doc.page_content for doc in documents)
        elif ext.endswith(".txt") or ext.endswith(".py"):
            # Both .txt and .py are plain text files
            with open(file_path, "r", encoding="utf-8") as file:
                content = file.read()
        elif ext.endswith(".docx"):
            doc = DocxDocument(file_path)
            content = "\n".join(para.text for para in doc.paragraphs)
        elif ext.endswith(".xlsx"):
            wb = openpyxl.load_workbook(file_path, data_only=True)
            content = ""
            for sheet in wb:
                content += f"Sheet: {sheet.title}\n"
                for row in sheet.iter_rows(values_only=True):
                    content += "\t".join(str(cell) if cell is not None else "" for cell in row) + "\n"
        else:
            return "Unsupported file format. Please use PY, PDF, TXT, DOCX, or XLSX."
        return content[:3000]  # Limit output size for readability
    except Exception as e:
        return f"An error occurred while processing the file: {str(e)}"
def get_youtube_transcript(url: str) -> str:
    """
    Fetch transcript text from a YouTube video.
    Args:
        url (str): Full YouTube video URL.
    Returns:
        str: Transcript text as a single string.
    Raises:
        ValueError: If no transcript is available or URL is invalid.
    """
    try:
        # Extract video ID
        video_id = extract_video_id(url)
        transcript = YouTubeTranscriptApi.get_transcript(video_id)
        # Combine all transcript text
        full_text = " ".join(entry['text'] for entry in transcript)
        return full_text
    except (TranscriptsDisabled, VideoUnavailable) as e:
        raise ValueError(f"Transcript not available: {e}")
    except Exception as e:
        raise ValueError(f"Failed to fetch transcript: {e}")
def extract_video_id(url: str) -> str:
    """
    Extract the video ID from a YouTube URL.
    """
    match = re.search(r"(?:v=|youtu\.be/)([A-Za-z0-9_-]{11})", url)
    if not match:
        raise ValueError("Invalid YouTube URL")
    return match.group(1)
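# Illustrative usage (any valid 11-character video ID works):
#   extract_video_id("https://www.youtube.com/watch?v=dQw4w9WgXcQ")  -> "dQw4w9WgXcQ"
#   extract_video_id("https://youtu.be/dQw4w9WgXcQ")                 -> "dQw4w9WgXcQ"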
# -----------------------------
# Load configuration from YAML
# -----------------------------
with open("config.yaml", "r") as f:
    config = yaml.safe_load(f)

provider = config["provider"]
model_config = config["models"][provider]
#prompt_path = config["system_prompt_path"]
enabled_tool_names = config["tools"]
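# A config.yaml of roughly this shape is assumed (values are illustrative,
# based on the keys read above and the model used later in this file):
#
#   provider: huggingface
#   models:
#     huggingface:
#       repo_id: HuggingFaceH4/zephyr-7b-beta
#       task: text-generation
#       temperature: 0.7
#   tools:
#     - math
#     - wiki_search
#     - web_search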
# -----------------------------
# Load system prompt
# -----------------------------
with open("system_prompt.txt", "r", encoding="utf-8") as f:
    system_prompt = f.read()

# System message
sys_msg = SystemMessage(content=system_prompt)
# -----------------------------
# Map tool names to functions
# -----------------------------
tool_map = {
    "math": calculator,
    "wiki_search": wiki_search,
    "web_search": web_search,
    "arxiv_search": arxiv_search,
    "get_youtube_transcript": get_youtube_transcript,
    "extract_video_id": extract_video_id,
    "analyze_attachment": analyze_attachment,
    "wikidata_query": wikidata_query
}

# Define which tools are enabled (note: this overrides the list read from config.yaml)
enabled_tool_names = [
    "math",
    "wiki_search",
    "web_search",
    "arxiv_search",
    "get_youtube_transcript",
    "extract_video_id",
    "analyze_attachment",
    "wikidata_query"
]

# Build the tool list, skipping any names that are not registered
tools = []
for name in enabled_tool_names:
    if name not in tool_map:
        print(f"❌ Tool not found: {name}")
        continue
    tools.append(tool_map[name])
# -----------------------------
# Prepare Documents
# -----------------------------
import faiss

# 1. Type-Checked State for Gradio
class ChatState(TypedDict):
    messages: Annotated[
        List[str],
        gr.State(render=False),
        "Stores chat history as list of strings"
    ]
# 2. Content Processing Utilities
def process_content(raw_content) -> str:
    """Convert any input to a clean string"""
    if isinstance(raw_content, list):
        return " ".join(str(item) for item in raw_content)
    return str(raw_content)

def reverse_text(text: str) -> str:
    """Fix reversed text patterns (e.g. a question supplied backwards, which
    typically starts with '.' or ',' once reversed)."""
    return text[::-1].replace("\\", "").strip() if text.startswith(('.', ',')) else text
# 3. Unified Document Creation
def create_documents(data_source: str, data: list) -> list:
    """Handle both Gradio chat and JSON questions"""
    docs = []
    for item in data:
        content = ""
        # Process different data sources
        if data_source == "json":
            raw_question = item.get("question", "")
            content = raw_question  # Adjust as per your content processing logic
        else:
            print(f"Skipping invalid data source: {data_source}")
            continue
        # Ensure metadata type safety
        metadata = {
            "task_id": str(item.get("task_id", "")),
            "level": str(item.get("Level", "")),
            "file_name": str(item.get("file_name", ""))
        }
        # Only append non-empty content
        if content.strip():
            docs.append(Document(page_content=content, metadata=metadata))
        else:
            print(f"Skipping invalid entry with empty content: {item}")
    return docs
# Path to your data.json
file_path = "/home/wendy/Downloads/data.json"

def load_data(file_path: str) -> list[dict]:
    """Safe JSON data loading with error handling"""
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"Data file not found: {file_path}")
    if not file_path.endswith('.json'):
        raise ValueError("Invalid file format. Only JSON files supported")
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            return json.load(f)
    except json.JSONDecodeError:
        raise ValueError("Invalid JSON format in data file")
    except Exception as e:
        raise RuntimeError(f"Error loading data: {str(e)}")
# 4. Vector Store Integration
# Custom FAISS wrapper (optional, if you still want it)
class MyVector_Store:
    def __init__(self, index: faiss.Index):
        self.index = index

    def save_local(self, path: str):
        faiss.write_index(self.index, path)

    @classmethod
    def load_local(cls, path: str):
        index = faiss.read_index(path)
        return cls(index)
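# Illustrative usage (384 is the embedding size of all-MiniLM-L6-v2; the path is hypothetical):
#   index = faiss.IndexFlatL2(384)
#   store = MyVector_Store(index)
#   store.save_local("my_index.faiss")
#   restored = MyVector_Store.load_local("my_index.faiss")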
# -----------------------------
# Process JSON data and create documents
# -----------------------------
data = []  # default so later steps still work if loading fails
try:
    with open(file_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    print(data)
except FileNotFoundError as e:
    print(f"Error: {e}")
except json.JSONDecodeError as e:
    print(f"Error decoding JSON: {e}")

docs = create_documents("json", data)
texts = [doc.page_content for doc in docs]
# -----------------------------
# Initialize embedding model
# -----------------------------
embedding_model = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2"
)
# -----------------------------
# Create FAISS index and save it
# -----------------------------
def initialize_vector_store():
    embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    index_path = "/home/wendy/my_hf_agent_course_projects/faiss_index"
    if os.path.exists(os.path.join(index_path, "index.faiss")):
        try:
            return FAISS.load_local(
                index_path,
                embedding_model,
                allow_dangerous_deserialization=True
            )
        except Exception as e:
            print(f"Error loading index: {e}")
    # Fallback: build a new index from the documents created above
    print("Building new vector store...")
    vector_store = FAISS.from_documents(docs, embedding_model)
    vector_store.save_local(index_path)
    return vector_store

# Initialize at module level
loaded_store = initialize_vector_store()
retriever = loaded_store.as_retriever()
# -----------------------------
# Create LangChain Retriever Tool
# -----------------------------
question_retriever_tool = create_retriever_tool(
    retriever=retriever,
    name="Question_Search",
    description="A tool to retrieve documents related to a user's question."
)
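# Illustrative usage: retriever tools take the query text as input, e.g.
#   question_retriever_tool.run("How many albums did Mercedes Sosa release between 2000 and 2009?")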
# -----------------------------
# Load HuggingFace LLM
# -----------------------------
llm = HuggingFaceEndpoint(
    repo_id="HuggingFaceH4/zephyr-7b-beta",
    task="text-generation",
    huggingfacehub_api_token=os.getenv("HF_TOKEN"),
    temperature=0.7,
    max_new_tokens=512
)
# -------------------------------
# Step 8: Use the Planner, Classifier, and Decision Logic
# -------------------------------
# NOTE: redefined later in this file; at import time the later definition wins.
def process_question(question):
    # Step 1: Planner detects the intent and candidate tools
    intent, matched_tools = planner(question, tools)
    print(f"Detected intent: {intent}")
    # Step 2: Classify the task (based on question)
    task_type = task_classifier(question)
    print(f"Task type: {task_type}")
    # Step 3: Use the classifier and planner to decide on the next task
    state = {"question": question, "last_response": ""}
    next_task, _ = decide_task(state)
    print(f"Next task: {next_task}")
    # Step 4: Use node skipper logic (skip if needed)
    if node_skipper(state):
        print(f"Skipping task: {next_task}")
        return "Task skipped."
    # Step 5: Execute task (with error handling)
    try:
        if task_type == "wiki_search":
            response = wiki_search(question)
        elif task_type == "math":
            response = calculator(question)
        else:
            response = "Default answer logic"
        # Step 6: Final response formatting
        return generate_final_answer(state, {task_type: response})
    except Exception as e:
        print(f"Error executing task: {e}")
        return "Sorry, I encountered an error processing your request."

# Run the process
#question = "How many albums did Mercedes Sosa release between 2000 and 2009?"
#response = agent.invoke(question)
#print("Final Response:", response)
# Named retriever_node so it does not shadow the retriever created above.
def retriever_node(state: MessagesState, k: int = 4):
    """
    Retrieves documents from the vector store using similarity scores,
    applies a dynamic threshold filter, and returns updated message state.
    Args:
        state (MessagesState): Current message state including the user's query.
        k (int): Number of top results to retrieve from the vector store.
    Returns:
        dict: Updated messages state including relevant documents or fallback message.
    """
    query = state["messages"][0].content.strip()
    results = vector_store.similarity_search_with_score(query, k=k)
    # Determine dynamic similarity threshold
    if any(keyword in query.lower() for keyword in ["who", "what", "where", "when", "why", "how"]):
        threshold = 0.75
    else:
        threshold = 0.8
    filtered = [doc for doc, score in results if score < threshold]
    if not filtered:
        response_msg = HumanMessage(content="No relevant documents found.")
    else:
        content = "\n\n".join(doc.page_content for doc in filtered)
        response_msg = HumanMessage(content=f"Here are relevant reference documents:\n\n{content}")
    return {"messages": [sys_msg] + state["messages"] + [response_msg]}
# ----------------------------------------------------------------
# LLM Loader
# ----------------------------------------------------------------
def get_llm(provider: str, config: dict):
    if provider == "google":
        return ChatGoogleGenerativeAI(
            model=config.get("model"),
            temperature=config.get("temperature", 0.7),
            google_api_key=config.get("api_key")  # Optional: if needed
        )
    elif provider == "groq":
        return ChatGroq(
            model=config.get("model"),
            temperature=config.get("temperature", 0.7),
            groq_api_key=config.get("api_key")  # Optional: if needed
        )
    elif provider == "huggingface":
        return ChatHuggingFace(
            llm=HuggingFaceEndpoint(
                endpoint_url=config.get("url"),
                temperature=config.get("temperature", 0.7),
                huggingfacehub_api_token=config.get("api_key")  # Optional
            )
        )
    else:
        raise ValueError(f"Invalid provider: {provider}")
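# Illustrative usage (model names and endpoint URL are placeholders, not verified defaults):
#   llm = get_llm("groq", {"model": "llama-3.1-8b-instant", "api_key": os.getenv("GROQ_API_KEY")})
#   llm = get_llm("huggingface", {"url": "https://<your-endpoint>", "api_key": os.getenv("HF_TOKEN")})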
# ----------------------------------------------------------------
# Planning & Execution Logic
# ----------------------------------------------------------------
def planner(question: str, tools: list) -> tuple:
    """
    Select the best-matching tool(s) for a question based on keyword-based
    intent detection and tool metadata. Returns the detected intent and the
    matched tools.
    """
    question = question.lower().strip()
    # Define intent-based keywords
    intent_keywords = {
        "math": ["calculate", "evaluate", "add", "subtract", "multiply", "divide", "modulus", "plus", "minus", "times"],
        "wiki_search": ["who is", "what is", "define", "explain", "tell me about", "overview of"],
        "web_search": ["search", "find", "look up", "google", "latest news", "current info"],
        "arxiv_search": ["arxiv", "research paper", "scientific paper", "preprint"],
        "get_youtube_transcript": ["youtube", "watch", "play video", "show me a video"],
        "extract_video_id": ["analyze video", "summarize video", "video content"],
        "data_analysis": ["analyze", "plot", "graph", "data", "visualize"],
        "wikidata_query": ["wikidata", "sparql", "run sparql", "query wikidata"],
        "default": ["why", "how", "difference between", "compare", "what happens", "reason for", "cause of", "effect of"]
    }
    # Step 1: Identify intent
    detected_intent = None
    for intent, keywords in intent_keywords.items():
        if any(keyword in question for keyword in keywords):
            detected_intent = intent
            break
    # Step 2: Match tools by intent (tools here are plain functions, so fall
    # back to __name__ / __doc__ when there is no .name or .description)
    matched_tools = []
    if detected_intent:
        for tool in tools:
            name = getattr(tool, "name", getattr(tool, "__name__", "")).lower()
            description = (getattr(tool, "description", "") or (tool.__doc__ or "")).lower()
            if detected_intent in name or detected_intent in description:
                matched_tools.append(tool)
    # Step 3: Fallback to general-purpose/default tools if no match found
    if not matched_tools:
        matched_tools = [
            tool for tool in tools
            if "default" in getattr(tool, "name", getattr(tool, "__name__", "")).lower()
            or "qa" in (getattr(tool, "description", "") or (tool.__doc__ or "")).lower()
        ]
    return detected_intent, matched_tools if matched_tools else [tools[0]]
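# Illustrative usage (with the plain-function tools defined above):
#   intent, matched = planner("summarize the arxiv preprint on attention", tools)
#   -> intent == "arxiv_search", matched == [arxiv_search]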
def task_classifier(question: str) -> str:
    """
    Classifies the question into one of the predefined task categories.
    """
    question = question.lower().strip()
    # Context-aware intent patterns
    if any(phrase in question for phrase in [
        "calculate", "how much is", "what is the result of", "evaluate", "solve"
    ]) or any(op in question for op in ["add", "subtract", "multiply", "divide", "modulus", "plus", "minus", "times"]):
        return "math"
    elif any(phrase in question for phrase in [
        "who is", "what is", "define", "explain", "tell me about", "give me an overview of"
    ]):
        return "wiki_search"
    elif any(phrase in question for phrase in [
        "search", "find", "look up", "google", "get the latest", "current news", "trending"
    ]):
        return "web_search"
    elif any(phrase in question for phrase in [
        "arxiv", "latest research", "scientific paper", "research paper", "preprint"
    ]):
        return "arxiv_search"
    elif any(phrase in question for phrase in [
        "youtube", "watch", "play the video", "show me a video"
    ]):
        return "get_youtube_transcript"
    elif any(phrase in question for phrase in [
        "analyze video", "summarize video", "what happens in the video", "video content"
    ]):
        return "video_analysis"
    elif any(phrase in question for phrase in [
        "analyze", "visualize", "plot", "graph", "inspect data", "explore dataset"
    ]):
        return "data_analysis"
    elif any(phrase in question for phrase in [
        "sparql", "wikidata", "query wikidata", "run sparql", "wikidata query"
    ]):
        return "wikidata_query"
    return "default"
def select_tool_and_run(question: str, tools: dict):
    # Step 1: Classify intent
    intent = task_classifier(question)
    # Map intent to tool names (keys in the tools dict, i.e. tool_map above)
    intent_tool_map = {
        "math": "math",
        "wiki_search": "wiki_search",
        "web_search": "web_search",
        "arxiv_search": "arxiv_search",
        "get_youtube_transcript": "get_youtube_transcript",
        "extract_video_id": "extract_video_id",
        "analyze_attachment": "analyze_attachment",
        "wikidata_query": "wikidata_query",
        "default": "default"
    }
    # Get the corresponding tool name, defaulting to "default" if no match
    tool_name = intent_tool_map.get(intent, "default")
    # Retrieve the tool from the tools dictionary
    tool_func = tools.get(tool_name)
    if not tool_func:
        return f"Tool not found for intent '{intent}'"
    # Step 2: Run the tool
    try:
        # If the tool needs JSON or structured data
        try:
            parsed_input = json.loads(question)
        except json.JSONDecodeError:
            parsed_input = question  # fallback to raw input if not JSON
        # Run the selected tool
        print(f"Running tool: {tool_name} with input: {parsed_input}")
        return tool_func(parsed_input)
    except Exception as e:
        return f"Error while running tool '{tool_name}': {str(e)}"
# Function to extract a math operation from the question
def extract_math_from_question(question: str):
    question = question.lower()
    # Map natural language to symbols
    ops = {
        "add": "+", "plus": "+",
        "subtract": "-", "minus": "-",
        "multiply": "*", "times": "*",
        "divide": "/", "divided by": "/",
        "modulus": "%", "mod": "%"
    }
    for word, symbol in ops.items():
        question = re.sub(rf"\b{word}\b", symbol, question)
    # Extract a math expression like "12 + 5"
    match = re.search(r'(\d+)\s*([\+\-\*/%])\s*(\d+)', question)
    if match:
        num1 = int(match.group(1))
        symbol = match.group(2)  # renamed so it does not shadow the operator module
        num2 = int(match.group(3))
        return {
            "a": num1,
            "b": num2,
            "operation": {
                "+": "add",
                "-": "subtract",
                "*": "multiply",
                "/": "divide",
                "%": "modulus"
            }[symbol]
        }
    return None
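# Illustrative usage:
#   extract_math_from_question("What is 12 plus 5?")
#   -> {"a": 12, "b": 5, "operation": "add"}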
# The task order can also include the tools for each task
priority_order = [
    {"task": "math", "tool": "math"},
    {"task": "wiki_search", "tool": "wiki_search"},
    {"task": "web_search", "tool": "web_search"},
    {"task": "arxiv_search", "tool": "arxiv_search"},
    {"task": "wikidata_query", "tool": "wikidata_query"},
    {"task": "retriever", "tool": "retriever"},
    {"task": "get_youtube_transcript", "tool": "get_youtube_transcript"},
    {"task": "extract_video_id", "tool": "extract_video_id"},
    {"task": "analyze_attachment", "tool": "analyze_attachment"},
    {"task": "default", "tool": "default"}  # Fallback
]
def decide_task(state: dict) -> tuple:
    """Decides which task to perform based on the current state.
    Returns a (task, tool) pair."""
    # Get the detected intent and candidate tools from the planner
    intent, matched_tools = planner(state["question"], tools)
    print(f"Detected intent: {intent}")  # Debugging
    if not matched_tools:
        print("❌ No valid tools were returned from the planner.")
        return "default", None
    task = intent or "default"
    if len(matched_tools) > 1:
        print("⚠️ Multiple tools found. Deciding based on priority.")
        return prioritize_tasks(task, matched_tools)
    print(f"Decided on task: {task}")  # Debugging
    return task, matched_tools[0]

def prioritize_tasks(task: str, matched_tools: list) -> tuple:
    """Prioritize among the matched tools using priority_order."""
    for priority in priority_order:
        for tool in matched_tools:
            name = getattr(tool, "name", getattr(tool, "__name__", ""))
            if priority["task"] in name:
                print(f"✅ Prioritizing task: {task} with tool: {priority['tool']}")
                return task, tool
    # If no priority match is found, return the first matched tool
    return task, matched_tools[0]
def process_question(question: str):
    """Process the question and route it to the appropriate tool."""
    task_type, tool = decide_task({"question": question})
    print(f"Next task: {task_type} with tool: {tool}")
    if node_skipper({"question": question}):
        print(f"Skipping task: {task_type}")
        return "Task skipped."
    if tool is None:
        return "Sorry, no tool is available for this question."
    try:
        # Tools here are plain functions, so call them directly
        response = tool(question)
        return generate_final_answer({"question": question}, {task_type: response})
    except Exception as e:
        print(f"❌ Error: {e}")
        return f"Sorry, I encountered an error: {str(e)}"
def call_llm(state):
    messages = state["messages"]
    response = llm.invoke(messages)
    return {"messages": messages + [response]}
from langchain.schema import AIMessage
from typing import Optional
from langchain_core.messages import BaseMessage

class AgentState(TypedDict):
    messages: List[BaseMessage]  # Chat history
    input: str                   # Original input
    intent: str                  # Derived or predicted intent
    result: Optional[str]        # Optional result

def tool_dispatcher(state: AgentState) -> AgentState:
    last_msg = state["messages"][-1]
    # Make sure it's an AI message with tool_calls
    if isinstance(last_msg, AIMessage) and last_msg.tool_calls:
        tool_call = last_msg.tool_calls[0]
        tool_name = tool_call["name"]
        tool_input = tool_call["args"]  # Adjust based on your actual schema
        tool_func = tool_map.get(tool_name)
        if tool_func is None:
            # No registered tool under this name; record the miss instead of crashing
            return {**state, "result": f"Tool not found: {tool_name}"}
        # If args is a dict and your tool expects unpacked values:
        if isinstance(tool_input, dict):
            result = tool_func.invoke(tool_input) if hasattr(tool_func, "invoke") else tool_func(**tool_input)
        else:
            result = tool_func.invoke(tool_input) if hasattr(tool_func, "invoke") else tool_func(tool_input)
        # You can choose to append this to messages, or just save the result
        return {
            **state,
            "result": result,
            # Optionally add: "messages": state["messages"] + [ToolMessage(...)]
        }
    # No tool call detected, return state unchanged
    return state
# Decide what to do next: if tool call → call_tool, else → end
def should_call_tool(state):
    last_msg = state["messages"][-1]
    if isinstance(last_msg, AIMessage) and last_msg.tool_calls:
        return "call_tool"
    return "end"
# To store previously asked questions and timestamps (simulating state persistence)
recent_questions = {}

def node_skipper(state: dict) -> bool:
    """
    Determines whether to skip the task based on the state.
    This could include:
    1. Repeated or similar questions
    2. Irrelevant or empty questions
    3. Tasks that have already been processed recently
    """
    question = state.get("question", "").strip()
    if not question:
        print("❌ Skipping: Empty or invalid question.")
        return True
    # 1. Skip if the question was asked recently (within a 5-minute window)
    if question in recent_questions:
        time_since_last_ask = time.time() - recent_questions[question]
        if time_since_last_ask < 300:  # 5-minute threshold
            print(f"❌ Skipping: The question was asked recently. Time since last ask: {time_since_last_ask:.2f} seconds.")
            return True
    # 2. Skip if the question is irrelevant or not meaningful enough
    irrelevant_keywords = ["blah", "nothing", "invalid", "nonsense"]
    if any(keyword in question.lower() for keyword in irrelevant_keywords):
        print("❌ Skipping: Irrelevant or nonsense question.")
        return True
    # 3. Skip if a response has already been given for this question
    if state.get("last_response"):
        print("❌ Skipping: Task has already been processed recently.")
        return True
    # 4. Skip trivial math questions (like "What is 2 + 2?")
    trivial_math = ["2 + 2", "1 + 1", "3 + 3"]
    if any(trivial_question in question for trivial_question in trivial_math):
        print(f"❌ Skipping trivial math question: {question}")
        return True
    # 5. Skip based on external factors (e.g., current time, system load, etc.)
    # Example: avoid processing tasks at night if that's part of the business logic
    current_hour = time.localtime().tm_hour
    if current_hour >= 22 or current_hour < 6:
        print("❌ Skipping: It's night time, not processing tasks.")
        return True
    # If none of the conditions matched, don't skip the task
    return False
# Update recent questions (for simulating repeated question check)
def update_recent_questions(question: str):
    """Update the recent questions dictionary with the current timestamp."""
    recent_questions[question] = time.time()
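# Illustrative flow (a sketch): consult the skipper, then record the question.
#   state = {"question": "What is 12 + 7?", "last_response": ""}
#   if not node_skipper(state):
#       update_recent_questions(state["question"])
#       # ... run the task ...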
def generate_final_answer(state: dict, task_results: dict) -> str:
    """Generate a final answer based on the results of the task."""
    if "wiki_search" in task_results:
        return f"📚 Wiki Summary:\n{task_results['wiki_search']}"
    elif "math" in task_results:
        return f"🧮 Math Result: {task_results['math']}"
    elif "retriever" in task_results:
        return f"🔍 Retrieved Info: {task_results['retriever']}"
    else:
        return "🤖 Unable to generate a specific answer."
def answer_question(question: str) -> str:
    """Process a single question and return the answer."""
    print(f"Processing question: {question[:50]}...")  # Debugging: show first 50 chars
    # Wrap the question in a HumanMessage and run it through the compiled graph
    messages = [HumanMessage(content=question)]
    response = graph.invoke({"messages": messages})  # Assumes `graph` was compiled via build_graph
    answer = response['messages'][-1].content
    # Strip a fixed 14-character prefix (e.g. a leading "FINAL ANSWER: " tag)
    return answer[14:]
def process_all_tasks(tasks: list):
    """Process a list of tasks."""
    results = {}
    for task in tasks:
        question = task.get("question", "").strip()
        if not question:
            print(f"Skipping task with missing or empty 'question': {task}")
            continue
        print(f"\n🟢 Processing Task: {task['task_id']} - Question: {question}")
        # Call the existing process_question logic
        response = process_question(question)
        print(f"✅ Response: {response}")
        results[task['task_id']] = response
    return results
## Langgraph
# Build graph function

# Persist the index built above (save_local returns None, so keep the store itself)
loaded_store.save_local("faiss_index")
vector_store = loaded_store

provider = "huggingface"
model_config = {
    "repo_id": "HuggingFaceH4/zephyr-7b-beta",
    "task": "text-generation",
    "temperature": 0.7,
    "max_new_tokens": 512,
    "huggingfacehub_api_token": os.getenv("HF_TOKEN")
}
# Get LLM (note: this overrides the multi-provider get_llm defined earlier;
# it expects the repo_id-style model_config defined above)
def get_llm(provider: str, config: dict):
    if provider == "huggingface":
        return HuggingFaceEndpoint(
            repo_id=config["repo_id"],
            task=config["task"],
            huggingfacehub_api_token=config["huggingfacehub_api_token"],
            temperature=config["temperature"],
            max_new_tokens=config["max_new_tokens"]
        )
    else:
        raise ValueError(f"Unsupported provider: {provider}")
def assistant(state: dict):
    # Assumes a tool-bound LLM is available as llm_with_tools (see build_graph)
    return {
        "messages": [llm_with_tools.invoke(state["messages"])]
    }

# Note: this overrides the tools_condition imported from langgraph.prebuilt
def tools_condition(state: dict) -> str:
    if "use tool" in state["messages"][-1].content.lower():
        return "tools"
    else:
        return "END"
from langchain_core.runnables import RunnableLambda

def build_graph(vector_store, provider: str, model_config: dict) -> StateGraph:
    # Get LLM
    llm = get_llm(provider, model_config)
    # Define available tools
    tools = [
        wiki_search, calculator, web_search, arxiv_search,
        get_youtube_transcript, extract_video_id, analyze_attachment, wikidata_query
    ]
    # Tool mapping (global if needed elsewhere); tools are plain functions,
    # so fall back to __name__ when there is no .name attribute
    global tool_map
    tool_map = {getattr(t, "name", t.__name__): t for t in tools}
    # Bind tools only if the LLM supports it
    if hasattr(llm, "bind_tools"):
        llm_with_tools = llm.bind_tools(tools)
    else:
        llm_with_tools = llm  # fallback for non-tool-aware models
    sys_msg = SystemMessage(content="You are a helpful assistant.")
    # Define nodes as runnables
    retriever = RunnableLambda(lambda state: {
        **state,
        "retrieved_docs": vector_store.similarity_search(state["input"])
    })
    assistant = RunnableLambda(lambda state: {
        **state,
        "messages": [sys_msg] + state["messages"]
    })
    # Wrap the LLM call so it reads and writes graph state
    call_llm = RunnableLambda(lambda state: {
        **state,
        "messages": state["messages"] + [llm_with_tools.invoke(state["messages"])]
    })
    # Start building the graph
    builder = StateGraph(AgentState)
    builder.add_node("retriever", retriever)
    builder.add_node("assistant", assistant)
    builder.add_node("call_llm", call_llm)
    builder.add_node("call_tool", tool_dispatcher)
    builder.add_node("end", lambda state: state)  # Explicit end node
    # Define graph flow
    builder.set_entry_point("retriever")
    builder.add_edge("retriever", "assistant")
    builder.add_edge("assistant", "call_llm")
    builder.add_conditional_edges("call_llm", should_call_tool, {
        "call_tool": "call_tool",
        "end": "end"  # must point to an actual "end" node
    })
    builder.add_edge("call_tool", "call_llm")  # loop back after tool call
    return builder.compile()
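# Illustrative wiring (a sketch, using the module-level store and config defined above):
#   graph = build_graph(loaded_store, provider, model_config)
#   final_state = graph.invoke({
#       "messages": [HumanMessage(content="What is 2 plus 3?")],
#       "input": "What is 2 plus 3?",
#       "intent": "",
#       "result": None,
#   })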