# Standard Library import os import re import tempfile import string import glob import shutil import gc import sys import uuid import signal from pathlib import Path import subprocess from datetime import datetime from io import BytesIO from contextlib import contextmanager from langchain_huggingface import HuggingFacePipeline from typing import TypedDict, List, Optional, Dict, Any, Annotated, Literal, Union, Tuple, Set, Type import time from collections import Counter from pydantic import Field, BaseModel, Extra import hashlib import json import numpy as np import ast from concurrent.futures import ThreadPoolExecutor, as_completed from collections import Counter, defaultdict # Third-Party Packages import cv2 import requests import wikipedia import spacy import yt_dlp import librosa from PIL import Image from bs4 import BeautifulSoup from duckduckgo_search import DDGS from sentence_transformers import SentenceTransformer from transformers import BlipProcessor, BlipForQuestionAnswering, pipeline, AutoTokenizer from sklearn.feature_extraction.text import TfidfVectorizer from sklearn.metrics.pairwise import cosine_similarity from sklearn.cluster import KMeans from sklearn.preprocessing import StandardScaler import speech_recognition as sr from pydub import AudioSegment from pydub.silence import split_on_silence import nltk from nltk.corpus import words import pandas as pd # LangChain Ecosystem from langchain.docstore.document import Document from langchain.prompts import PromptTemplate from langchain_community.document_loaders import WikipediaLoader from langchain_huggingface import HuggingFaceEndpoint from langchain_community.retrievers import BM25Retriever from langchain.vectorstores import FAISS from langchain.embeddings import HuggingFaceEmbeddings from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain.schema import Document from langchain_community.tools import DuckDuckGoSearchRun from langchain_core.messages import AnyMessage, HumanMessage, AIMessage, BaseMessage, SystemMessage, ToolMessage from langchain_core.tools import BaseTool, StructuredTool, tool, render_text_description from langchain_core.documents import Document from langchain.callbacks.manager import ( AsyncCallbackManagerForToolRun, CallbackManagerForToolRun, ) # LangGraph from langgraph.graph import START, END, StateGraph from langgraph.prebuilt import ToolNode, tools_condition # PyTorch import torch from functools import partial from transformers import pipeline # Additional Utilities from datetime import datetime from urllib.parse import urljoin, urlparse import logging nlp = spacy.load("en_core_web_sm") # Ensure the word list is downloaded nltk.download('words', quiet=True) english_words = set(words.words()) logger = logging.getLogger(__name__) logging.basicConfig(filename='app.log', level=logging.INFO) # --- Model Configuration --- def create_llm_pipeline(): #model_id = "meta-llama/Llama-2-13b-chat-hf" #model_id = "meta-llama/Llama-3.3-70B-Instruct" #model_id = "mistralai/Mistral-Small-24B-Base-2501" model_id = "mistralai/Mistral-7B-Instruct-v0.3" #model_id = "Meta-Llama/Llama-2-7b-chat-hf" #model_id = "NousResearch/Nous-Hermes-2-Mistral-7B-DPO" #model_id = "TheBloke/Mistral-7B-Instruct-v0.1-GGUF" #model_id = "mistralai/Mistral-7B-Instruct-v0.2" #model_id = "Qwen/Qwen2-7B-Instruct" #model_id = "GSAI-ML/LLaDA-8B-Instruct" return pipeline( "text-generation", model=model_id, device_map="cpu", torch_dtype=torch.float16, max_new_tokens=1024, temperature=0.05, do_sample=False, repetition_penalty=1.2 ) # Define file extension sets for each category PICTURE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp'} AUDIO_EXTENSIONS = {'.mp3', '.wav', '.aac', '.flac', '.ogg', '.m4a', '.wma'} CODE_EXTENSIONS = {'.py', '.js', '.java', '.cpp', '.c', '.cs', '.rb', '.go', '.php', '.html', '.css', '.ts'} SPREADSHEET_EXTENSIONS = { '.xls', '.xlsx', '.xlsm', '.xlsb', '.xlt', '.xltx', '.xltm', '.ods', '.ots', '.csv', '.tsv', '.sxc', '.stc', '.dif', '.gsheet', '.numbers', '.numbers-tef', '.nmbtemplate', '.fods', '.123', '.wk1', '.wk2', '.wks', '.wku', '.wr1', '.gnumeric', '.gnm', '.xml', '.pmvx', '.pmdx', '.pmv', '.uos', '.txt' } def get_file_type(filename: str) -> str: if not filename or '.' not in filename or filename == '': return '' ext = filename.lower().rsplit('.', 1)[-1] dot_ext = f'.{ext}' if dot_ext in PICTURE_EXTENSIONS: return 'picture' elif dot_ext in AUDIO_EXTENSIONS: return 'audio' elif dot_ext in CODE_EXTENSIONS: return 'code' elif dot_ext in SPREADSHEET_EXTENSIONS: return 'spreadsheet' else: return 'unknown' def write_bytes_to_temp_dir(file_bytes: bytes, file_name: str) -> str: """ Writes bytes to a file in the system temporary directory using the provided file_name. Returns the full path to the saved file. The file will persist until manually deleted or the OS cleans the temp directory. """ temp_dir = "/tmp" # /tmp is always writable in Hugging Face Spaces os.makedirs(temp_dir, exist_ok=True) file_path = os.path.join(temp_dir, file_name) with open(file_path, 'wb') as f: f.write(file_bytes) print(f"File written to: {file_path}") return file_path def extract_final_answer(text: str) -> str: """ Extracts the answer after the last 'FINAL ANSWER:' (case-insensitive), removes any parenthetical immediately following a numeric answer, strips trailing punctuation, sorts comma-separated lists, and does not split numbers containing commas. Returns an empty string if marker not found. """ marker = "FINAL ANSWER:" idx = text.lower().rfind(marker.lower()) if idx == -1: return "" # Extract answer after marker result = text[idx + len(marker):].strip() if "pure vanilla extract" not in result: result = result.replace("vanilla extract", "pure vanilla extract") # Remove parenthetical immediately following a number at the start result = re.sub(r'^(\d+(?:\.\d+)?)\s*\(.*?\)', r'\1', result) # Remove trailing punctuation and whitespace result = result.rstrip(string.punctuation + " ") # Split on commas NOT between digits (i.e., not inside numbers) # This regex splits on commas not surrounded by digits (to avoid splitting numbers like 1,000) items = re.split(r',(?!\s*\d{3}\b)', result) # If we have a list, sort it if len(items) > 1: items = [item.strip() for item in items] # Try to sort numerically try: sorted_items = sorted( items, key=lambda x: float(re.sub(r'[^\d\.]', '', x)) # Remove non-numeric except . ) return ', '.join(sorted_items) except ValueError: # Fallback: sort alphabetically sorted_items = sorted(items, key=lambda x: x.lower()) return ', '.join(sorted_items) return result class AudioTranscriptionInput(BaseModel): """Input schema for AudioTranscriptionTool.""" file_path: str = Field(description="Path to the audio file to transcribe") engine: Optional[str] = Field(default="google", description="Speech recognition engine to use") language: Optional[str] = Field(default="en-US", description="Language of the audio") class AudioTranscriptionTool(BaseTool): """Tool for transcribing audio files using local speech recognition.""" name: str = "audio_transcription" description: str = """ Transcribes voice memo, audio files (mp3, wav, m4a, flac, etc.) to text using local speech recognition. Input should be a dictionary with 'file_path' key containing the path to the audio file. Optionally accepts 'engine' and 'language' parameters. Returns the transcribed text as a string. """ args_schema: type[BaseModel] = AudioTranscriptionInput class Config: arbitrary_types_allowed = True def __init__(self, **kwargs): """Initialize the AudioTranscriptionTool.""" super().__init__(**kwargs) self._init_speech_recognition() def _init_speech_recognition(self): """Initialize speech recognition components.""" try: import speech_recognition as sr from pydub import AudioSegment object.__setattr__(self, 'recognizer', sr.Recognizer()) object.__setattr__(self, 'sr', sr) object.__setattr__(self, 'AudioSegment', AudioSegment) except ImportError as e: raise ImportError( "Required libraries not found. Install with: " "pip install SpeechRecognition pydub" ) from e def _validate_audio_file(self, file_path: str) -> bool: """Validate that the audio file exists and has a supported format.""" if not os.path.exists(file_path): raise FileNotFoundError(f"Audio file not found: {file_path}") # Check file extension - pydub supports many formats supported_formats = {'.mp3', '.wav', '.m4a', '.flac', '.mp4', '.mpeg', '.mpga', '.webm', '.ogg', '.aac'} file_extension = Path(file_path).suffix.lower() if file_extension not in supported_formats: raise ValueError( f"Unsupported audio format: {file_extension}. " f"Supported formats: {', '.join(supported_formats)}" ) return True def _convert_to_wav(self, file_path: str) -> str: """Convert audio file to WAV format if needed.""" file_extension = Path(file_path).suffix.lower() if file_extension == '.wav': return file_path try: # Convert to WAV using pydub audio = self.AudioSegment.from_file(file_path) # Create temporary WAV file temp_wav = tempfile.NamedTemporaryFile(delete=False, suffix='.wav') audio.export(temp_wav.name, format="wav") return temp_wav.name except Exception as e: raise RuntimeError(f"Error converting audio file to WAV: {str(e)}") def _transcribe_audio(self, file_path: str, engine: str = "google", language: str = "en-US") -> str: """Transcribe audio file using local speech recognition.""" temp_wav_path = None try: # Convert to WAV if necessary wav_path = self._convert_to_wav(file_path) if wav_path != file_path: temp_wav_path = wav_path # Load audio file with self.sr.AudioFile(wav_path) as source: # Adjust for ambient noise self.recognizer.adjust_for_ambient_noise(source, duration=0.5) # Record the audio audio_data = self.recognizer.record(source) # Choose recognition engine if engine == "google": transcript = self.recognizer.recognize_google(audio_data, language=language) elif engine == "sphinx": transcript = self.recognizer.recognize_sphinx(audio_data, language=language) elif engine == "wit": # Note: requires WIT_AI_KEY environment variable wit_key = os.getenv('WIT_AI_KEY') if not wit_key: raise ValueError("WIT_AI_KEY environment variable required for Wit.ai engine") transcript = self.recognizer.recognize_wit(audio_data, key=wit_key) elif engine == "bing": # Note: requires BING_KEY environment variable bing_key = os.getenv('BING_KEY') if not bing_key: raise ValueError("BING_KEY environment variable required for Bing engine") transcript = self.recognizer.recognize_bing(audio_data, key=bing_key, language=language) else: # Default to Google transcript = self.recognizer.recognize_google(audio_data, language=language) return transcript except self.sr.UnknownValueError: return "Could not understand the audio - speech was unclear or inaudible" except self.sr.RequestError as e: return f"Error with speech recognition service: {str(e)}" except Exception as e: raise RuntimeError(f"Error transcribing audio: {str(e)}") finally: # Clean up temporary WAV file if temp_wav_path and os.path.exists(temp_wav_path): try: os.unlink(temp_wav_path) except OSError: pass # Ignore cleanup errors def _run(self, file_path: str, engine: str = "google", language: str = "en-US", **kwargs) -> str: """ Internal method required by LangChain BaseTool. Args: file_path: Path to the audio file to transcribe engine: Speech recognition engine to use language: Language of the audio Returns: str: Transcribed text from the audio file """ try: # Validate audio file self._validate_audio_file(file_path) # Transcribe audio transcript = self._transcribe_audio( file_path=file_path, engine=engine, language=language ) return transcript except Exception as e: error_msg = f"AudioTranscriptionTool error: {str(e)}" print(error_msg) return error_msg def run(self, tool_input: Dict[str, Any]) -> str: """ Main method to run the audio transcription tool. Args: tool_input: Dictionary containing 'file_path' and optional parameters Returns: str: Transcribed text from the audio file """ try: # Extract parameters from input file_path = tool_input.get('file_path') if not file_path: raise ValueError("file_path is required in tool_input") engine = tool_input.get('engine', 'google') language = tool_input.get('language', 'en-US') # Call the internal _run method return self._run(file_path=file_path, engine=engine, language=language) except Exception as e: error_msg = f"AudioTranscriptionTool error: {str(e)}" print(error_msg) return error_msg # Enhanced local transcription tool with multiple engine support class AdvancedAudioTranscriptionTool(BaseTool): """Advanced tool with support for multiple local transcription engines including Whisper.""" name: str = "advanced_audio_transcription" description: str = """ Advanced audio transcription tool supporting multiple engines including local Whisper. Supports engines: 'whisper' (local), 'google', 'sphinx', 'wit', 'bing'. Input should be a dictionary with 'file_path' key. Returns the transcribed text as a string. """ args_schema: type[BaseModel] = AudioTranscriptionInput class Config: arbitrary_types_allowed = True def __init__(self, **kwargs): """Initialize the AdvancedAudioTranscriptionTool.""" super().__init__(**kwargs) self._init_speech_recognition() self._init_whisper() def _init_speech_recognition(self): """Initialize speech recognition components.""" try: import speech_recognition as sr from pydub import AudioSegment object.__setattr__(self, 'recognizer', sr.Recognizer()) object.__setattr__(self, 'sr', sr) object.__setattr__(self, 'AudioSegment', AudioSegment) except ImportError as e: raise ImportError( "Required libraries not found. Install with: " "pip install SpeechRecognition pydub" ) from e def _init_whisper(self): """Initialize Whisper if available.""" try: import whisper object.__setattr__(self, 'whisper', whisper) except ImportError: object.__setattr__(self, 'whisper', None) print("Warning: OpenAI Whisper not installed. Install with 'pip install openai-whisper' for local Whisper support.") def _validate_audio_file(self, file_path: str) -> bool: """Validate that the audio file exists and has a supported format.""" if not os.path.exists(file_path): raise FileNotFoundError(f"Audio file not found: {file_path}") supported_formats = {'.mp3', '.wav', '.m4a', '.flac', '.mp4', '.mpeg', '.mpga', '.webm', '.ogg', '.aac'} file_extension = Path(file_path).suffix.lower() if file_extension not in supported_formats: raise ValueError( f"Unsupported audio format: {file_extension}. " f"Supported formats: {', '.join(supported_formats)}" ) return True def _transcribe_with_whisper(self, file_path: str, language: str = "en") -> str: """Transcribe using local Whisper model.""" if not self.whisper: raise RuntimeError("Whisper not installed. Install with 'pip install openai-whisper'") try: # Load the model (you can change model size: tiny, base, small, medium, large) model = self.whisper.load_model("base") # Transcribe the audio result = model.transcribe(file_path, language=language if language != "en-US" else "en") return result["text"].strip() except Exception as e: raise RuntimeError(f"Error with Whisper transcription: {str(e)}") def _convert_to_wav(self, file_path: str) -> str: """Convert audio file to WAV format if needed.""" file_extension = Path(file_path).suffix.lower() if file_extension == '.wav': return file_path try: audio = self.AudioSegment.from_file(file_path) temp_wav = tempfile.NamedTemporaryFile(delete=False, suffix='.wav') audio.export(temp_wav.name, format="wav") return temp_wav.name except Exception as e: raise RuntimeError(f"Error converting audio file to WAV: {str(e)}") def _transcribe_with_sr(self, file_path: str, engine: str = "google", language: str = "en-US") -> str: """Transcribe using speech_recognition library.""" temp_wav_path = None try: wav_path = self._convert_to_wav(file_path) if wav_path != file_path: temp_wav_path = wav_path with self.sr.AudioFile(wav_path) as source: self.recognizer.adjust_for_ambient_noise(source, duration=0.5) audio_data = self.recognizer.record(source) if engine == "google": transcript = self.recognizer.recognize_google(audio_data, language=language) elif engine == "sphinx": transcript = self.recognizer.recognize_sphinx(audio_data) elif engine == "wit": wit_key = os.getenv('WIT_AI_KEY') if not wit_key: raise ValueError("WIT_AI_KEY environment variable required for Wit.ai engine") transcript = self.recognizer.recognize_wit(audio_data, key=wit_key) elif engine == "bing": bing_key = os.getenv('BING_KEY') if not bing_key: raise ValueError("BING_KEY environment variable required for Bing engine") transcript = self.recognizer.recognize_bing(audio_data, key=bing_key, language=language) else: transcript = self.recognizer.recognize_google(audio_data, language=language) return transcript except self.sr.UnknownValueError: return "Could not understand the audio - speech was unclear or inaudible" except self.sr.RequestError as e: return f"Error with speech recognition service: {str(e)}" finally: if temp_wav_path and os.path.exists(temp_wav_path): try: os.unlink(temp_wav_path) except OSError: pass def _run(self, file_path: str, engine: str = "google", language: str = "en-US", **kwargs) -> str: """ Internal method required by LangChain BaseTool. Args: file_path: Path to the audio file to transcribe engine: Speech recognition engine to use language: Language of the audio Returns: str: Transcribed text from the audio file """ try: self._validate_audio_file(file_path) # Use local Whisper if specified if engine == "whisper": transcript = self._transcribe_with_whisper(file_path, language) else: # Use speech_recognition library transcript = self._transcribe_with_sr(file_path, engine, language) return transcript except Exception as e: error_msg = f"AdvancedAudioTranscriptionTool error: {str(e)}" print(error_msg) return error_msg def run(self, tool_input: Dict[str, Any]) -> str: """ Main method to run the advanced audio transcription tool. Args: tool_input: Dictionary containing 'file_path' and optional parameters Returns: str: Transcribed text from the audio file """ try: file_path = tool_input.get('file_path') if not file_path: raise ValueError("file_path is required in tool_input") engine = tool_input.get('engine', 'google') language = tool_input.get('language', 'en-US') # Call the internal _run method return self._run(file_path=file_path, engine=engine, language=language) except Exception as e: error_msg = f"AdvancedAudioTranscriptionTool error: {str(e)}" print(error_msg) return error_msg class ExcelReaderInput(BaseModel): """Input schema for ExcelReaderTool.""" file_path: str = Field(description="Path to the Excel file to read") class ExcelReaderTool(BaseTool): """Tool for reading Excel files and formatting them for LLM consumption.""" name: str = "excel_reader" description: str = ( "Reads an Excel file from the specified file path " "Use for running math operations on a table of sales data from a fast-food restaurant chain ONLY," ) args_schema: Type[BaseModel] = ExcelReaderInput def _run(self, file_path: str, run_manager: Optional[Any] = None) -> str: """ Execute the tool to read Excel file and return formatted table. Args: file_path: Path to the Excel file run_manager: Optional callback manager Returns: Formatted string representation of the Excel table """ try: # Validate file exists if not os.path.exists(file_path): return f"Error: File not found at path: {file_path}" # Validate file extension if not file_path.lower().endswith(('.xlsx', '.xls')): return f"Error: File must be an Excel file (.xlsx or .xls). Got: {file_path}" # Read Excel file - specifically Sheet1 try: df = pd.read_excel(file_path, sheet_name='Sheet1') except ValueError as e: if "Worksheet named 'Sheet1' not found" in str(e): # If Sheet1 doesn't exist, try reading the first sheet df = pd.read_excel(file_path, sheet_name=0) else: raise e # Check if dataframe is empty if df.empty: return "The Excel file contains no data in Sheet1." # Format the table for LLM consumption formatted_output = self._format_table_for_llm(df, file_path) return formatted_output except FileNotFoundError: return f"Error: File not found at path: {file_path}" except PermissionError: return f"Error: Permission denied accessing file: {file_path}" except Exception as e: return f"Error reading Excel file: {str(e)}" def _format_table_for_llm(self, df: pd.DataFrame, file_path: str) -> str: """ Format the pandas DataFrame into a readable string format for LLMs. Args: df: The pandas DataFrame containing the Excel data file_path: Original file path for reference Returns: Formatted string representation of the table """ output_lines = [] # Add header information #output_lines.append(f"EXCEL FILE DATA FROM: {os.path.basename(file_path)}") #output_lines.append(f"Sheet: Sheet1") #output_lines.append(f"Dimensions: {df.shape[0]} rows ร— {df.shape[1]} columns") #output_lines.append("-" * 60) # Add column information #output_lines.append("COLUMNS:") #for i, col in enumerate(df.columns, 1): # col_type = str(df[col].dtype) # non_null_count = df[col].count() # output_lines.append(f" {i}. {col} ({col_type}) - {non_null_count} non-null values") #output_lines.append("-" * 60) # Add table data in a clean format output_lines.append("TABLE DATA:") # Convert DataFrame to string with proper formatting # Handle potential NaN values and make it readable df_clean = df.fillna("N/A") # Replace NaN with readable placeholder # Create a formatted table string #table_str = df_clean.to_string(index=True, max_rows=None, max_cols=None) #output_lines.append(table_str) # Add summary statistics for numeric columns if they exist numeric_cols = df.select_dtypes(include=['number']).columns sums = df_clean[numeric_cols].sum() # Step 2: Define which columns are food and which are drink food_cols = [col for col in numeric_cols if col.lower() != 'soda'] drink_cols = [col for col in numeric_cols if col.lower() == 'soda'] # Step 3: Aggregate totals food_total = sums[food_cols].sum() drink_total = sums[drink_cols].sum() # Step 4: Format the results as dollars formatted_totals = { 'Food': f"{food_total:.2f}", 'Drink': f"{drink_total:.2f}" } # Step 5: Convert to string for display (optional) result_string = '\n'.join([f"{k}: {v}" for k, v in formatted_totals.items()]) # Convert to string for display #result_string = formatted.to_string() output_lines.append(result_string) #output_lines.append(df_clean[numeric_cols].sum()) if len(numeric_cols) > 0: output_lines.append("-" * 60) #output_lines.append("NUMERIC COLUMN SUMMARY:") #for col in numeric_cols: # stats = df[col].describe() # output_lines.append(f"\n{col}:") # output_lines.append(f" Count: {stats['count']}") # output_lines.append(f" Mean: {stats['mean']:.2f}") # output_lines.append(f" Min: {stats['min']}") # output_lines.append(f" Max: {stats['max']}") return "\n".join(output_lines) async def _arun(self, file_path: str, run_manager: Optional[Any] = None) -> str: """Async version of the tool (falls back to sync implementation).""" return self._run(file_path, run_manager) class PythonExecutorInput(BaseModel): """Input schema for PythonExecutor tool.""" file_path: str = Field(description="Path to the Python file to execute") class PythonExecutorTool(BaseTool): """Tool that executes a Python file and returns the result.""" name: str = "python_executor" description: str = "Executes a Python file from the given file path and returns the output" args_schema: Type[BaseModel] = PythonExecutorInput def _run( self, file_path: str, run_manager: Optional[Any] = None, ) -> str: """Execute the Python file and return the result.""" try: # Validate that the file exists if not os.path.exists(file_path): return f"Error: File '{file_path}' does not exist" # Validate that it's a Python file if not file_path.endswith('.py'): return f"Error: '{file_path}' is not a Python file (.py extension required)" # Execute the Python file result = subprocess.run( [sys.executable, file_path], capture_output=True, text=True, timeout=600 # 30 second timeout to prevent hanging ) # Prepare the output output_parts = [] if result.stdout: output_parts.append(f"STDOUT:\n{result.stdout}") if result.stderr: output_parts.append(f"STDERR:\n{result.stderr}") if result.returncode != 0: output_parts.append(f"Return code: {result.returncode}") if not output_parts: return "Script executed successfully with no output" return "\n\n".join(output_parts) except subprocess.TimeoutExpired: return "Error: Script execution timed out (30 seconds)" except Exception as e: return f"Error executing Python file: {str(e)}" async def _arun( self, file_path: str, run_manager: Optional[Any] = None, ) -> str: """Async version - delegates to sync implementation.""" return self._run(file_path, run_manager) class CommutativityAnalysisTool(BaseTool): """ A tool that analyzes an algebraic operation table to find elements involved in counter-examples to commutativity. This tool executes predefined Python code that: 1. Defines a set S = ['a', 'b', 'c', 'd', 'e'] 2. Defines an operation table as a dictionary of dictionaries 3. Finds elements where table[x][y] != table[y][x] (non-commutative pairs) 4. Returns a sorted, comma-separated list of all elements involved """ name: str = "commutativity_analysis" description: str = ( "Analyzes an algebraic operation table to find elements involved in " "counter-examples to commutativity. Returns a comma-separated list of " "elements where the operation is not commutative." "Provides a direct answer to question on commutativity analysis" ) return_direct: bool = False def _run( self, run_manager: Optional[CallbackManagerForToolRun] = None, **kwargs: Any ) -> str: """Execute the commutativity analysis synchronously.""" try: # Define the set and the operation table S = ['a', 'b', 'c', 'd', 'e'] # The operation table as a dictionary of dictionaries table = { 'a': {'a': 'a', 'b': 'b', 'c': 'c', 'd': 'b', 'e': 'd'}, 'b': {'a': 'b', 'b': 'c', 'c': 'a', 'd': 'e', 'e': 'c'}, 'c': {'a': 'c', 'b': 'a', 'c': 'b', 'd': 'b', 'e': 'a'}, 'd': {'a': 'b', 'b': 'e', 'c': 'b', 'd': 'e', 'e': 'd'}, 'e': {'a': 'd', 'b': 'b', 'c': 'a', 'd': 'd', 'e': 'c'} } # Find elements involved in counter-examples to commutativity involved = set() for x in S: for y in S: if table[x][y] != table[y][x]: involved.add(x) involved.add(y) # Output the result as a comma-separated, alphabetically sorted list result = "subset of S involved in any possible counter-examples: " + ', '.join(sorted(involved)) return result except Exception as e: return f"Error executing commutativity analysis: {str(e)}" async def _arun( self, run_manager: Optional[AsyncCallbackManagerForToolRun] = None, **kwargs: Any ) -> str: """Execute the commutativity analysis asynchronously.""" # For this simple computation, we can just call the synchronous version return self._run() class EnhancedDuckDuckGoSearchTool(BaseTool): name: str = "enhanced_search" description: str = ( "Performs a DuckDuckGo web search and retrieves actual content from the top web results. " "Input should be a search query string. " "Returns search results with extracted content from web pages, making it much more useful for answering questions. " "Use this tool when you need up-to-date information, details about current events, or when other tools do not provide sufficient or recent answers. " "Ideal for topics that require the latest news, recent developments, or information not covered in static sources." ) max_results: int = 3 max_chars_per_page: int = 18000 session: Any = None def model_post_init(self, __context: Any) -> None: super().model_post_init(__context) self.session = requests.Session() self.session.headers.update({ 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36', 'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8', 'Accept-Language': 'en-US,en;q=0.5', 'Accept-Encoding': 'gzip, deflate', 'Connection': 'keep-alive', 'Upgrade-Insecure-Requests': '1', }) def _search_duckduckgo(self, query_term: str) -> List[Dict]: # Renamed 'query' to 'query_term' for clarity """Perform DuckDuckGo search and return results.""" try: with DDGS() as ddgs: results = list(ddgs.text(query_term, max_results=self.max_results)) return results except Exception as e: logger.error(f"DuckDuckGo search failed: {e}") return [] def _extract_content_from_url(self, url: str, timeout: int = 10) -> Optional[str]: """Extract clean text content from a web page.""" try: if any(url.lower().endswith(ext) for ext in ['.pdf', '.doc', '.docx', '.xls', '.xlsx', '.ppt', '.pptx']): return "Content type not supported for extraction" response = self.session.get(url, timeout=timeout, allow_redirects=True) response.raise_for_status() content_type = response.headers.get('content-type', '').lower() if 'text/html' not in content_type: return "Non-HTML content detected" soup = BeautifulSoup(response.content, 'html.parser') for script_or_style in soup(["script", "style", "nav", "header", "footer", "aside", "form"]): script_or_style.decompose() main_content = None for selector in ['main', 'article', '.content', '#content', '.post', '.entry-content', '.entry']: # Added .entry-content main_content = soup.select_one(selector) if main_content: break if not main_content: main_content = soup.find('body') or soup text = main_content.get_text(separator='\n', strip=True) lines = [line.strip() for line in text.split('\n') if line.strip()] text = '\n'.join(lines) text = re.sub(r'\n{3,}', '\n\n', text) text = re.sub(r' {2,}', ' ', text) if len(text) > self.max_chars_per_page: text = text[:self.max_chars_per_page] + "\n[Content truncated...]" return text except requests.exceptions.Timeout: logger.warning(f"Page loading timed out for {url}") return "Page loading timed out" except requests.exceptions.RequestException as e: logger.warning(f"Failed to retrieve page {url}: {str(e)}") return f"Failed to retrieve page: {str(e)}" except Exception as e: logger.error(f"Content extraction failed for {url}: {e}") return "Failed to extract content from page" def _format_search_result(self, result: Dict, content: str) -> str: """Format a single search result with its content.""" title = result.get('title', 'No title') url = result.get('href', 'No URL') snippet = result.get('body', 'No snippet') return f"๐Ÿ” **{title}**\nURL: {url}\nSnippet: {snippet}\n\n๐Ÿ“„ **Page Content:**\n{content}\n---\n" def run(self, tool_input: Union[str, Dict]) -> str: query_str: Optional[str] = None if isinstance(tool_input, dict): if "query" in tool_input and isinstance(tool_input["query"], str): query_str = tool_input["query"] elif "input" in tool_input and isinstance(tool_input["input"], str): query_str = tool_input["input"] else: return "Invalid input: Dictionary received, but does not contain a recognizable string query under 'query' or 'input' keys." elif isinstance(tool_input, str): query_str = tool_input else: return f"Invalid input type: Expected a string or a dictionary, but got {type(tool_input).__name__}." # The misplaced docstring """Execute the enhanced search.""" was removed from here. # Use query_str consistently from now on if not query_str or not query_str.strip(): return "Please provide a search query." query_str = query_str.strip() # Apply strip to query_str logger.info(f"Searching for: {query_str}") # Use query_str search_results = self._search_duckduckgo(query_str) # Use query_str if not search_results: return f"No search results found for query: {query_str}" # Use query_str enhanced_results = [] processed_count = 0 for i, result in enumerate(search_results[:self.max_results]): url = result.get('href', '') if not url: continue logger.info(f"Processing result {i+1}: {url}") content = self._extract_content_from_url(url) if content and len(content.strip()) > 50: formatted_result = self._format_search_result(result, content) enhanced_results.append(formatted_result) processed_count += 1 time.sleep(0.5) # Consider making this configurable or adjusting based on use case if not enhanced_results: return f"Search completed but no content could be extracted from the pages for query: {query_str}" # Use query_str response = f"""๐Ÿ” **Enhanced Search Results for: "{query_str}"** Found {len(search_results)} results, successfully processed {processed_count} pages with content. {''.join(enhanced_results)} ๐Ÿ’ก **Summary:** Retrieved and processed content from {processed_count} web pages to provide comprehensive information about your search query. """ # Use query_str if len(response) > 18000: # This limit is arbitrary; consider if it should relate to self.max_chars_per_page response = response[:18000] + "\n[Response truncated to prevent memory issues]" return response def _run(self, query_or_tool_input: Union[str, Dict]) -> str: # Updated to reflect run's input """Required by BaseTool interface. Handles various input types.""" # This _run method now correctly passes the input to the run method, # which is designed to handle both string and dictionary inputs. return self.run(query_or_tool_input) # --- Agent State Definition --- class AgentState(TypedDict): messages: Annotated[List[AnyMessage], lambda x, y: x + y] done: bool = False # Default value of False question: str task_id: str input_file: Optional[bytes] file_type: Optional[str] context: List[Document] # Using LangChain's Document class file_path: Optional[str] youtube_url: Optional[str] answer: Optional[str] frame_answers: Optional[list] def fetch_page_with_tables(page_title): """ Fetches Wikipedia page content and extracts all tables as readable text. Returns a tuple: (main_text, [table_texts]) """ # Fetch the page object page = wikipedia.page(page_title) main_text = page.content # Get the HTML for table extraction html = page.html() soup = BeautifulSoup(html, 'html.parser') tables = soup.find_all('table') table_texts = [] for table in tables: rows = table.find_all('tr') table_lines = [] for row in rows: cells = row.find_all(['th', 'td']) cell_texts = [cell.get_text(strip=True) for cell in cells] if cell_texts: # Format as Markdown table row table_lines.append(" | ".join(cell_texts)) if table_lines: table_text = "\n".join(table_lines) table_texts.append(table_text) return main_text, table_texts class WikipediaSearchToolWithFAISS(BaseTool): name: str = "wikipedia_semantic_search_all_candidates_strong_entity_priority_list_retrieval" description: str = ( "Fetches content from multiple Wikipedia pages based on intelligent NLP query processing " "of various search candidates, with strong prioritization of query entities. It then performs " "entity-focused semantic search across all fetched content to find the most relevant information, " "with improved retrieval for lists like discographies. Uses spaCy for named entity " "recognition and query enhancement. Input should be a search query or topic. " "Note: Uses the current live version of Wikipedia." ) embedding_model_name: str = "all-MiniLM-L6-v2" chunk_size: int = 4000 chunk_overlap: int = 250 # Maintained moderate overlap top_k_results: int = 3 spacy_model: str = "en_core_web_sm" # Increased multiplier to fetch more candidates per semantic query variant semantic_search_candidate_multiplier: int = 1 # Was 2, increased to 3, consider 4 if still problematic def __init__(self, **kwargs): super().__init__(**kwargs) try: self._nlp = spacy.load(self.spacy_model) print(f"Loaded spaCy model: {self.spacy_model}") self._embedding_model = HuggingFaceEmbeddings(model_name=self.embedding_model_name) # Refined separators for better handling of Wikipedia lists and sections self._text_splitter = RecursiveCharacterTextSplitter( chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap, separators=[ "\n\n== ", "\n\n=== ", "\n\n==== ", # Section headers (keep with following content) "\n\n\n", "\n\n", # Multiple newlines (paragraph breaks) "\n* ", "\n- ", "\n# ", # List items "\n", ". ", "! ", "? ", # Sentence breaks after newline, common punctuation " ", "" # Word and character level ] ) except OSError as e: print(f"Error loading spaCy model '{self.spacy_model}': {e}") print("Try running: python -m spacy download en_core_web_sm") self._nlp = None self._embedding_model = None self._text_splitter = None except Exception as e: print(f"Error initializing WikipediaSearchToolWithFAISS components: {e}") self._nlp = None self._embedding_model = None self._text_splitter = None def _extract_entities_and_keywords(self, query: str) -> Tuple[List[str], List[str], str]: if not self._nlp: return [], [], query doc = self._nlp(query) main_entities = [ent.text for ent in doc.ents if ent.label_ in ["PERSON", "ORG", "GPE", "EVENT", "WORK_OF_ART"]] keywords = [token.lemma_.lower() for token in doc if token.pos_ in ["NOUN", "PROPN", "ADJ"] and not token.is_stop and not token.is_punct and len(token.text) > 2] main_entities = list(dict.fromkeys(main_entities)) keywords = list(dict.fromkeys(keywords)) processed_tokens = [token.lemma_ for token in doc if not token.is_stop and not token.is_punct and token.text.strip()] processed_query = " ".join(processed_tokens) return main_entities, keywords, processed_query def _generate_search_candidates(self, query: str, main_entities: List[str], keywords: List[str], processed_query: str) -> List[str]: candidates_set = set() entity_prefix = main_entities[0] if main_entities else None for me in main_entities: candidates_set.add(me) candidates_set.add(query) if processed_query and processed_query != query: candidates_set.add(processed_query) if entity_prefix and keywords: first_entity_lower = entity_prefix.lower() for kw in keywords[:3]: if kw not in first_entity_lower and len(kw) > 2: candidates_set.add(f"{entity_prefix} {kw}") keyword_combo_short = " ".join(k for k in keywords[:2] if k not in first_entity_lower and len(k)>2) if keyword_combo_short: candidates_set.add(f"{entity_prefix} {keyword_combo_short}") if len(main_entities) > 1: candidates_set.add(" ".join(main_entities[:2])) if keywords: keyword_combo = " ".join(keywords[:2]) if entity_prefix: candidate_to_add = f"{entity_prefix} {keyword_combo}" if not any(c.lower() == candidate_to_add.lower() for c in candidates_set): candidates_set.add(candidate_to_add) elif not main_entities: candidates_set.add(keyword_combo) ordered_candidates = [] for me in main_entities: if me not in ordered_candidates: ordered_candidates.append(me) for c in list(candidates_set): if c and c.strip() and c not in ordered_candidates: ordered_candidates.append(c) print(f"Generated {len(ordered_candidates)} search candidates for Wikipedia page lookup (entity-prioritized): {ordered_candidates}") return ordered_candidates def _smart_wikipedia_search(self, query_text: str, main_entities_from_query: List[str], keywords_from_query: List[str], processed_query_text: str) -> List[Tuple[str, str]]: candidates = self._generate_search_candidates(query_text, main_entities_from_query, keywords_from_query, processed_query_text) found_pages_data: List[Tuple[str, str]] = [] processed_page_titles: Set[str] = set() for i, candidate_query in enumerate(candidates): print(f"\nProcessing candidate {i+1}/{len(candidates)} for page: '{candidate_query}'") page_object = None final_page_title = None is_candidate_entity_focused = any(me.lower() in candidate_query.lower() for me in main_entities_from_query) if main_entities_from_query else False try: try: page_to_load = candidate_query suggest_mode = True # Default to auto_suggest=True if is_candidate_entity_focused and main_entities_from_query: try: # Attempt precise match first for entity-focused candidates temp_page = wikipedia.page(page_to_load, auto_suggest=False, redirect=True) suggest_mode = False # Flag that precise match worked except (wikipedia.exceptions.PageError, wikipedia.exceptions.DisambiguationError): print(f" - auto_suggest=False failed for entity-focused '{page_to_load}', trying with auto_suggest=True.") # Fallthrough to auto_suggest=True below if this fails if suggest_mode: # If not attempted or failed with auto_suggest=False temp_page = wikipedia.page(page_to_load, auto_suggest=True, redirect=True) final_page_title = temp_page.title if is_candidate_entity_focused and main_entities_from_query: title_matches_main_entity = any(me.lower() in final_page_title.lower() for me in main_entities_from_query) if not title_matches_main_entity: print(f" ! Page title '{final_page_title}' (from entity-focused candidate '{candidate_query}') " f"does not strongly match main query entities: {main_entities_from_query}. Skipping.") continue if final_page_title in processed_page_titles: print(f" ~ Already processed '{final_page_title}'") continue page_object = temp_page print(f" โœ“ Direct hit/suggestion for '{candidate_query}' -> '{final_page_title}'") except wikipedia.exceptions.PageError: if i < max(2, len(candidates) // 3) : # Try Wikipedia search for a smaller, more promising subset of candidates print(f" - Direct access failed for '{candidate_query}'. Trying Wikipedia search...") search_results = wikipedia.search(candidate_query, results=1) if not search_results: print(f" - No Wikipedia search results for '{candidate_query}'.") continue search_result_title = search_results[0] try: temp_page = wikipedia.page(search_result_title, auto_suggest=False, redirect=True) # Search results are usually canonical final_page_title = temp_page.title if is_candidate_entity_focused and main_entities_from_query: # Still check against original intent title_matches_main_entity = any(me.lower() in final_page_title.lower() for me in main_entities_from_query) if not title_matches_main_entity: print(f" ! Page title '{final_page_title}' (from search for '{candidate_query}' -> '{search_result_title}') " f"does not strongly match main query entities: {main_entities_from_query}. Skipping.") continue if final_page_title in processed_page_titles: print(f" ~ Already processed '{final_page_title}'") continue page_object = temp_page print(f" โœ“ Found via search '{candidate_query}' -> '{search_result_title}' -> '{final_page_title}'") except (wikipedia.exceptions.PageError, wikipedia.exceptions.DisambiguationError) as e_sr: print(f" ! Error/Disambiguation for search result '{search_result_title}': {e_sr}") else: print(f" - Direct access failed for '{candidate_query}'. Skipping further search for this lower priority candidate.") except wikipedia.exceptions.DisambiguationError as de: print(f" ! Disambiguation for '{candidate_query}'. Options: {de.options[:1]}") if de.options: option_title = de.options[0] try: temp_page = wikipedia.page(option_title, auto_suggest=False, redirect=True) final_page_title = temp_page.title if is_candidate_entity_focused and main_entities_from_query: # Check against original intent title_matches_main_entity = any(me.lower() in final_page_title.lower() for me in main_entities_from_query) if not title_matches_main_entity: print(f" ! Page title '{final_page_title}' (from disamb. of '{candidate_query}' -> '{option_title}') " f"does not strongly match main query entities: {main_entities_from_query}. Skipping.") continue if final_page_title in processed_page_titles: print(f" ~ Already processed '{final_page_title}'") continue page_object = temp_page print(f" โœ“ Resolved disambiguation '{candidate_query}' -> '{option_title}' -> '{final_page_title}'") except Exception as e_dis_opt: print(f" ! Could not load disambiguation option '{option_title}': {e_dis_opt}") if page_object and final_page_title and (final_page_title not in processed_page_titles): # Extract main text main_text = page_object.content # Extract tables using BeautifulSoup try: html = page_object.html() soup = BeautifulSoup(html, 'html.parser') tables = soup.find_all('table') table_texts = [] for table in tables: rows = table.find_all('tr') table_lines = [] for row in rows: cells = row.find_all(['th', 'td']) cell_texts = [cell.get_text(strip=True) for cell in cells] if cell_texts: table_lines.append(" | ".join(cell_texts)) if table_lines: table_text = "\n".join(table_lines) table_texts.append(table_text) except Exception as e: print(f" !! Error extracting tables for '{final_page_title}': {e}") table_texts = [] # Combine main text and all table texts as separate chunks all_text_chunks = [main_text] + table_texts for chunk in all_text_chunks: found_pages_data.append((chunk, final_page_title)) processed_page_titles.add(final_page_title) print(f" -> Added page '{final_page_title}'. Main text length: {len(main_text)} | Tables extracted: {len(table_texts)}") except Exception as e: print(f" !! Unexpected error processing candidate '{candidate_query}': {e}") if not found_pages_data: print(f"\nCould not find any new, unique, entity-validated Wikipedia pages for query '{query_text}'.") else: print(f"\nFound {len(found_pages_data)} unique, validated page(s) for processing.") return found_pages_data def _enhance_semantic_search(self, query: str, vector_store, main_entities: List[str], keywords: List[str], processed_query: str) -> List[Document]: core_query_parts = set() core_query_parts.add(query) if processed_query != query: core_query_parts.add(processed_query) if keywords: core_query_parts.add(" ".join(keywords[:2])) section_phrases_templates = [] lower_query_terms = set(query.lower().split()) | set(k.lower() for k in keywords) section_keywords_map = { "discography": ["discography", "list of studio albums", "studio album titles and years", "albums by year", "album release dates", "official albums", "complete album list", "albums published"], "biography": ["biography", "life story", "career details", "background history"], "filmography": ["filmography", "list of films", "movie appearances", "acting roles"], } for section_term_key, specific_phrases_list in section_keywords_map.items(): # Check if the key (e.g., "discography") or any of its specific phrases (e.g. "list of studio albums") # are mentioned or implied by the query terms. if section_term_key in lower_query_terms or any(phrase_part in lower_query_terms for phrase_part in section_term_key.split()): section_phrases_templates.extend(specific_phrases_list) # Also check if phrases themselves are in query terms (e.g. query "list of albums by X") for phrase in specific_phrases_list: if phrase in query.lower(): # Check against original query for direct phrase matches section_phrases_templates.extend(specific_phrases_list) # Add all related if one specific is hit break section_phrases_templates = list(dict.fromkeys(section_phrases_templates)) # Deduplicate final_search_queries = set() if main_entities: entity_prefix = main_entities[0] final_search_queries.add(entity_prefix) for part in core_query_parts: final_search_queries.add(f"{entity_prefix} {part}" if entity_prefix.lower() not in part.lower() else part) for phrase_template in section_phrases_templates: final_search_queries.add(f"{entity_prefix} {phrase_template}") if "list of" in phrase_template or "history of" in phrase_template : final_search_queries.add(f"{phrase_template} of {entity_prefix}") else: final_search_queries.update(core_query_parts) final_search_queries.update(section_phrases_templates) deduplicated_queries = list(dict.fromkeys(sq for sq in final_search_queries if sq and sq.strip())) print(f"Generated {len(deduplicated_queries)} semantic search query variants (list-retrieval focused): {deduplicated_queries}") all_results_docs: List[Document] = [] seen_content_hashes: Set[int] = set() k_to_fetch = self.top_k_results * self.semantic_search_candidate_multiplier for search_query_variant in deduplicated_queries: try: results = vector_store.similarity_search_with_score(search_query_variant, k=k_to_fetch) print(f" Semantic search variant '{search_query_variant}' (k={k_to_fetch}) -> {len(results)} raw chunk(s) with scores.") for doc, score in results: # Assuming similarity_search_with_score returns (doc, score) content_hash = hash(doc.page_content[:250]) # Slightly more for hash uniqueness if content_hash not in seen_content_hashes: seen_content_hashes.add(content_hash) doc.metadata['retrieved_by_variant'] = search_query_variant doc.metadata['retrieval_score'] = float(score) # Store score all_results_docs.append(doc) except Exception as e: print(f" Error in semantic search for variant '{search_query_variant}': {e}") # Sort all collected unique results by score (FAISS L2 distance is lower is better) all_results_docs.sort(key=lambda x: x.metadata.get('retrieval_score', float('inf'))) print(f"Collected and re-sorted {len(all_results_docs)} unique chunks from all semantic query variants.") return all_results_docs[:self.top_k_results] def _run(self, query: str = None, search_query: str = None, **kwargs) -> str: if not self._nlp or not self._embedding_model or not self._text_splitter: print("ERROR: WikipediaSearchToolWithFAISS components not initialized properly.") return "Error: Wikipedia tool components not initialized properly. Please check server logs." if not query: query = search_query or kwargs.get('q') or kwargs.get('search_term') try: print(f"\n--- Running {self.name} for query: '{query}' ---") main_entities, keywords, processed_query = self._extract_entities_and_keywords(query) print(f"Initial NLP Analysis - Main Entities: {main_entities}, Keywords: {keywords}, Processed Query: '{processed_query}'") fetched_pages_data = self._smart_wikipedia_search(query, main_entities, keywords, processed_query) if not fetched_pages_data: return (f"Could not find any relevant, entity-validated Wikipedia pages for the query '{query}'. " f"Main entities sought: {main_entities}") all_page_titles = [title for _, title in fetched_pages_data] print(f"\nSuccessfully fetched content for {len(fetched_pages_data)} Wikipedia page(s): {', '.join(all_page_titles)}") all_documents: List[Document] = [] for page_content, page_title in fetched_pages_data: chunks = self._text_splitter.split_text(page_content) if not chunks: print(f"Warning: Could not split content from Wikipedia page '{page_title}' into chunks.") continue for i, chunk_text in enumerate(chunks): all_documents.append(Document(page_content=chunk_text, metadata={ "source_page_title": page_title, "original_query": query, "chunk_index": i # Add chunk index for potential debugging or ordering })) print(f"Split content from '{page_title}' into {len(chunks)} chunks.") if not all_documents: return (f"Could not process content into searchable chunks from the fetched Wikipedia pages " f"({', '.join(all_page_titles)}) for query '{query}'.") print(f"\nTotal document chunks from all pages: {len(all_documents)}") print("Creating FAISS index from content of all fetched pages...") try: vector_store = FAISS.from_documents(all_documents, self._embedding_model) print("FAISS index created successfully.") except Exception as e: return f"Error creating FAISS vector store: {e}" print(f"\nPerforming enhanced semantic search across all collected content...") try: relevant_docs = self._enhance_semantic_search(query, vector_store, main_entities, keywords, processed_query) except Exception as e: return f"Error during semantic search: {e}" if not relevant_docs: return (f"No relevant information found within Wikipedia page(s) '{', '.join(list(dict.fromkeys(all_page_titles)))}' " f"for your query '{query}' using entity-focused semantic search with list retrieval.") unique_sources_in_results = list(dict.fromkeys([doc.metadata.get('source_page_title', 'Unknown Source') for doc in relevant_docs])) result_header = (f"Found {len(relevant_docs)} relevant piece(s) of information from Wikipedia page(s) " f"'{', '.join(unique_sources_in_results)}' for your query '{query}':\n") nlp_summary = (f"[Original Query NLP: Main Entities: {', '.join(main_entities) if main_entities else 'None'}, " f"Keywords: {', '.join(keywords[:5]) if keywords else 'None'}]\n\n") result_details = [] for i, doc in enumerate(relevant_docs): source_info = doc.metadata.get('source_page_title', 'Unknown Source') variant_info = doc.metadata.get('retrieved_by_variant', 'N/A') score_info = doc.metadata.get('retrieval_score', 'N/A') detail = (f"Result {i+1} (source: '{source_info}', score: {score_info:.4f})\n" f"(Retrieved by: '{variant_info}')\n{doc.page_content}") result_details.append(detail) final_result = result_header + nlp_summary + "\n\n---\n\n".join(result_details) print(f"\nReturning {len(relevant_docs)} relevant chunks from {len(set(all_page_titles))} source page(s).") return final_result.strip() except Exception as e: import traceback print(f"Unexpected error in {self.name}: {traceback.format_exc()}") return f"An unexpected error occurred: {str(e)}" class EnhancedYoutubeScreenshotQA(BaseTool): name: str = "bird_species_screenshot_qa" description: str = ( "Use this tool to calculate the number of bird species on camera at any one time," "Input should be a dict with keys: 'youtube_url', 'question', and optional parameters. " "Example: {'youtube_url': 'https://youtube.com/watch?v=xyz', 'question': 'What animals are visible?'}" ) # Define Pydantic fields for the attributes we need to set device: Any = Field(default=None, exclude=True) processor_vqa: Any = Field(default=None, exclude=True) model_vqa: Any = Field(default=None, exclude=True) class Config: # Allow arbitrary types (needed for torch.device, model objects) arbitrary_types_allowed = True # Allow extra fields to be set extra = "allow" def __init__(self, **kwargs): super().__init__(**kwargs) # Initialize directories cache_dir = '/tmp/youtube_qa_cache/' video_dir = '/tmp/video/' frames_dir = '/tmp/video_frames/' # Initialize model and device self._initialize_model() # Create directories for dir_path in [cache_dir, video_dir, frames_dir]: os.makedirs(dir_path, exist_ok=True) def _get_config(self, key: str, default_value=None, input_data: Dict[str, Any] = None): """Get configuration value with fallback to defaults""" defaults = { 'frame_interval_seconds': 20, 'max_frames': 50, 'use_scene_detection': True, 'resize_frames': True, 'parallel_processing': True, 'cache_enabled': True, 'quality_threshold': 30.0, 'semantic_similarity_threshold': 0.8 } if input_data and key in input_data: return input_data[key] return defaults.get(key, default_value) def _initialize_model(self): """Initialize BLIP model for VQA with error handling""" try: self.device = torch.device("cpu") print(f"Using device: {self.device}") self.processor_vqa = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base") self.model_vqa = BlipForQuestionAnswering.from_pretrained( "Salesforce/blip-vqa-base" ).to(self.device) print("BLIP VQA model loaded successfully") except Exception as e: print(f"Error initializing VQA model: {str(e)}") raise def _get_video_hash(self, url: str) -> str: """Generate hash for video URL for caching""" return hashlib.md5(url.encode()).hexdigest() def _get_cache_path(self, video_hash: str, cache_type: str) -> str: """Get cache file path""" cache_dir = '/tmp/youtube_qa_cache/' return os.path.join(cache_dir, f"{video_hash}_{cache_type}") def _load_from_cache(self, cache_path: str, cache_enabled: bool = True) -> Optional[Any]: """Load data from cache""" if not cache_enabled or not os.path.exists(cache_path): return None try: with open(cache_path, 'r') as f: return json.load(f) except Exception as e: print(f"Error loading cache: {str(e)}") return None def _save_to_cache(self, cache_path: str, data: Any, cache_enabled: bool = True): """Save data to cache""" if not cache_enabled: return try: with open(cache_path, 'w') as f: json.dump(data, f) except Exception as e: print(f"Error saving cache: {str(e)}") def download_youtube_video(self, url: str, video_hash: str, cache_enabled: bool = True) -> Optional[str]: """Enhanced YouTube video download with anti-bot measures""" video_dir = '/tmp/video/' output_filename = f'{video_hash}.mp4' output_path = os.path.join(video_dir, output_filename) # Check cache if cache_enabled and os.path.exists(output_path): print(f"Using cached video: {output_path}") return output_path # Clean directory self._clean_directory(video_dir) try: # Enhanced yt-dlp options with anti-bot measures ydl_opts = { # Format selection - prefer lower quality to avoid restrictions 'format': 'best[height<=480][ext=mp4]/best[height<=720][ext=mp4]/best[ext=mp4]/best', 'outtmpl': output_path, 'quiet': False, # Changed to False for debugging 'no_warnings': False, 'merge_output_format': 'mp4', # Anti-bot headers and user agent 'http_headers': { 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36', 'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8', 'Accept-Language': 'en-us,en;q=0.5', 'Accept-Encoding': 'gzip,deflate', 'Accept-Charset': 'ISO-8859-1,utf-8;q=0.7,*;q=0.7', 'Connection': 'keep-alive', 'Upgrade-Insecure-Requests': '1', }, # Additional anti-detection measures 'extractor_args': { 'youtube': { 'skip': ['hls', 'dash'], # Skip some formats that might trigger detection 'player_skip': ['js'], # Skip JavaScript player } }, # Rate limiting 'sleep_interval': 1, 'max_sleep_interval': 5, 'sleep_interval_subtitles': 1, # Retry settings 'retries': 3, 'fragment_retries': 3, 'skip_unavailable_fragments': True, # Cookie handling (you can add browser cookies if needed) # 'cookiefile': '/path/to/cookies.txt', # Uncomment and set path if you have cookies # Additional options 'extract_flat': False, 'writesubtitles': False, 'writeautomaticsub': False, 'ignoreerrors': True, # Postprocessors 'postprocessors': [{ 'key': 'FFmpegVideoConvertor', 'preferedformat': 'mp4', }] } print(f"Attempting to download: {url}") # Try multiple download strategies strategies = [ # Strategy 1: Standard download ydl_opts, # Strategy 2: More conservative approach { **ydl_opts, 'format': 'worst[ext=mp4]/worst', # Try worst quality first 'sleep_interval': 2, 'max_sleep_interval': 10, }, # Strategy 3: Different user agent { **ydl_opts, 'http_headers': { 'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/16.1 Safari/605.1.15' }, 'format': 'best[height<=360][ext=mp4]/best[ext=mp4]/best', } ] last_error = None for i, strategy in enumerate(strategies, 1): try: print(f"Trying download strategy {i}/3...") with yt_dlp.YoutubeDL(strategy) as ydl: # Add some delay before download import time time.sleep(2) ydl.download([url]) if os.path.exists(output_path): print(f"Video downloaded successfully with strategy {i}: {output_path}") return output_path else: print(f"Strategy {i} completed but file not found") except Exception as e: last_error = e print(f"Strategy {i} failed: {str(e)}") if i < len(strategies): print(f"Trying next strategy...") # Add delay between strategies import time time.sleep(5) continue # If all strategies failed, try one more approach with cookies from browser print("All standard strategies failed. Trying with browser cookies...") try: cookie_strategy = { **ydl_opts, 'cookiesfrombrowser': ('chrome',), # Try to get cookies from Chrome 'format': 'worst[ext=mp4]/worst', } with yt_dlp.YoutubeDL(cookie_strategy) as ydl: ydl.download([url]) if os.path.exists(output_path): print(f"Video downloaded successfully with browser cookies: {output_path}") return output_path except Exception as e: print(f"Browser cookie strategy also failed: {str(e)}") print(f"All download strategies failed. Last error: {last_error}") return None except Exception as e: print(f"Error downloading YouTube video: {str(e)}") return None def _clean_directory(self, directory: str): """Clean directory contents""" if os.path.exists(directory): for filename in os.listdir(directory): file_path = os.path.join(directory, filename) try: if os.path.isfile(file_path) or os.path.islink(file_path): os.unlink(file_path) elif os.path.isdir(file_path): shutil.rmtree(file_path) except Exception as e: print(f'Failed to delete {file_path}. Reason: {e}') def _assess_frame_quality(self, frame: np.ndarray) -> float: """Assess frame quality using Laplacian variance (blur detection)""" try: gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) return cv2.Laplacian(gray, cv2.CV_64F).var() except Exception: return 0.0 def _detect_scene_changes(self, video_path: str, threshold: float = 30.0) -> List[int]: """Detect scene changes in video""" scene_frames = [] try: cap = cv2.VideoCapture(video_path) if not cap.isOpened(): return [] prev_frame = None frame_count = 0 while True: ret, frame = cap.read() if not ret: break if prev_frame is not None: # Calculate histogram difference hist1 = cv2.calcHist([prev_frame], [0, 1, 2], None, [8, 8, 8], [0, 256, 0, 256, 0, 256]) hist2 = cv2.calcHist([frame], [0, 1, 2], None, [8, 8, 8], [0, 256, 0, 256, 0, 256]) diff = cv2.compareHist(hist1, hist2, cv2.HISTCMP_CHISQR) if diff > threshold: scene_frames.append(frame_count) prev_frame = frame.copy() frame_count += 1 cap.release() return scene_frames except Exception as e: print(f"Error in scene detection: {str(e)}") return [] def smart_extract_frames(self, video_path: str, video_hash: str, input_data: Dict[str, Any] = None) -> List[str]: """Intelligently extract frames with quality filtering and scene detection""" cache_enabled = self._get_config('cache_enabled', True, input_data) cache_path = self._get_cache_path(video_hash, "frames_info.json") cached_info = self._load_from_cache(cache_path, cache_enabled) if cached_info: # Verify cached frames still exist existing_frames = [f for f in cached_info['frame_paths'] if os.path.exists(f)] if len(existing_frames) == len(cached_info['frame_paths']): print(f"Using {len(existing_frames)} cached frames") return existing_frames # Clean frames directory frames_dir = '/tmp/video_frames/' self._clean_directory(frames_dir) try: cap = cv2.VideoCapture(video_path) if not cap.isOpened(): print("Error: Could not open video.") return [] fps = cap.get(cv2.CAP_PROP_FPS) total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) frame_interval_seconds = self._get_config('frame_interval_seconds', 10, input_data) frame_interval = max(1, int(fps * frame_interval_seconds)) print(f"Video info: {total_frames} frames, {fps:.2f} fps") # Get scene change frames if enabled scene_frames = set() use_scene_detection = self._get_config('use_scene_detection', True, input_data) if use_scene_detection: scene_frames = set(self._detect_scene_changes(video_path)) print(f"Detected {len(scene_frames)} scene changes") extracted_frames = [] frame_count = 0 saved_count = 0 max_frames = self._get_config('max_frames', 50, input_data) while True: ret, frame = cap.read() if not ret or saved_count >= max_frames: break # Check if we should extract this frame should_extract = ( frame_count % frame_interval == 0 or frame_count in scene_frames ) if should_extract: # Assess frame quality quality = self._assess_frame_quality(frame) quality_threshold = self._get_config('quality_threshold', 30.0, input_data) if quality >= quality_threshold: # Resize frame if enabled resize_frames = self._get_config('resize_frames', True, input_data) if resize_frames: height, width = frame.shape[:2] if width > 800: scale = 800 / width new_width = 800 new_height = int(height * scale) frame = cv2.resize(frame, (new_width, new_height)) frame_filename = os.path.join( frames_dir, f"frame_{frame_count:06d}_q{quality:.1f}.jpg" ) if cv2.imwrite(frame_filename, frame): extracted_frames.append(frame_filename) saved_count += 1 print(f"Extracted frame {saved_count}/{max_frames} " f"(quality: {quality:.1f})") frame_count += 1 cap.release() # Cache frame information frame_info = { 'frame_paths': extracted_frames, 'extraction_time': time.time(), 'total_frames_processed': frame_count, 'frames_extracted': len(extracted_frames) } self._save_to_cache(cache_path, frame_info, cache_enabled) print(f"Successfully extracted {len(extracted_frames)} high-quality frames") return extracted_frames except Exception as e: print(f"Exception during frame extraction: {e}") return [] def _answer_question_on_frame(self, frame_path: str, question: str) -> Tuple[str, float]: """Answer question on single frame with confidence scoring""" try: image = Image.open(frame_path).convert('RGB') inputs = self.processor_vqa(image, question, return_tensors="pt").to(self.device) with torch.no_grad(): outputs = self.model_vqa.generate(**inputs, output_scores=True, return_dict_in_generate=True) answer = self.processor_vqa.decode(outputs.sequences[0], skip_special_tokens=True) # Calculate confidence (simplified - you might want to use actual model confidence) confidence = 1.0 # Placeholder - BLIP doesn't directly provide confidence return answer, confidence except Exception as e: print(f"Error processing frame {frame_path}: {str(e)}") return "Error processing this frame", 0.0 def _process_frames_parallel(self, frame_files: List[str], question: str, input_data: Dict[str, Any] = None) -> List[Tuple[str, str, float]]: """Process frames in parallel""" results = [] parallel_processing = self._get_config('parallel_processing', True, input_data) if parallel_processing: with ThreadPoolExecutor(max_workers=min(4, len(frame_files))) as executor: future_to_frame = { executor.submit(self._answer_question_on_frame, frame_path, question): frame_path for frame_path in frame_files } for future in as_completed(future_to_frame): frame_path = future_to_frame[future] try: answer, confidence = future.result() results.append((frame_path, answer, confidence)) print(f"Processed {os.path.basename(frame_path)}: {answer} (conf: {confidence:.2f})") except Exception as e: print(f"Error processing {frame_path}: {str(e)}") results.append((frame_path, "Error", 0.0)) else: for frame_path in frame_files: answer, confidence = self._answer_question_on_frame(frame_path, question) results.append((frame_path, answer, confidence)) print(f"Processed {os.path.basename(frame_path)}: {answer} (conf: {confidence:.2f})") return results def _cluster_similar_answers(self, answers: List[str], input_data: Dict[str, Any] = None) -> Dict[str, List[str]]: """Cluster semantically similar answers""" if len(answers) <= 1: return {answers[0]: answers} if answers else {} try: # First try with standard TF-IDF settings vectorizer = TfidfVectorizer( stop_words='english', lowercase=True, min_df=1, # Include words that appear in at least 1 document max_df=1.0 # Include words that appear in up to 100% of documents ) tfidf_matrix = vectorizer.fit_transform(answers) # Check if we have any features after TF-IDF if tfidf_matrix.shape[1] == 0: raise ValueError("No features after TF-IDF processing") # Calculate cosine similarity similarity_matrix = cosine_similarity(tfidf_matrix) # Cluster similar answers clusters = defaultdict(list) used = set() semantic_similarity_threshold = self._get_config('semantic_similarity_threshold', 0.8, input_data) for i, answer in enumerate(answers): if i in used: continue cluster_key = answer clusters[cluster_key].append(answer) used.add(i) # Find similar answers for j in range(i + 1, len(answers)): if j not in used and similarity_matrix[i][j] >= semantic_similarity_threshold: clusters[cluster_key].append(answers[j]) used.add(j) return dict(clusters) except (ValueError, Exception) as e: print(f"Error in semantic clustering: {str(e)}") # Fallback 1: Try without stop words filtering try: print("Attempting clustering without stop word filtering...") vectorizer_no_stop = TfidfVectorizer( lowercase=True, min_df=1, token_pattern=r'\b\w+\b' # Match any word ) tfidf_matrix = vectorizer_no_stop.fit_transform(answers) if tfidf_matrix.shape[1] > 0: similarity_matrix = cosine_similarity(tfidf_matrix) clusters = defaultdict(list) used = set() semantic_similarity_threshold = self._get_config('semantic_similarity_threshold', 0.8, input_data) for i, answer in enumerate(answers): if i in used: continue cluster_key = answer clusters[cluster_key].append(answer) used.add(i) for j in range(i + 1, len(answers)): if j not in used and similarity_matrix[i][j] >= semantic_similarity_threshold: clusters[cluster_key].append(answers[j]) used.add(j) return dict(clusters) except Exception as e2: print(f"Fallback clustering also failed: {str(e2)}") # Fallback 2: Simple string-based clustering print("Using simple string-based clustering...") return self._simple_string_cluster(answers) def _simple_string_cluster(self, answers: List[str]) -> Dict[str, List[str]]: """Simple string-based clustering fallback""" clusters = defaultdict(list) # Normalize answers for comparison normalized_answers = {} for answer in answers: normalized = answer.lower().strip() normalized_answers[answer] = normalized used = set() for i, answer in enumerate(answers): if answer in used: continue cluster_key = answer clusters[cluster_key].append(answer) used.add(answer) # Find similar answers using simple string similarity for j, other_answer in enumerate(answers[i+1:], i+1): if other_answer in used: continue # Check for exact match after normalization if normalized_answers[answer] == normalized_answers[other_answer]: clusters[cluster_key].append(other_answer) used.add(other_answer) # Alternatively, check if one string contains the other elif (normalized_answers[answer] in normalized_answers[other_answer] or normalized_answers[other_answer] in normalized_answers[answer]): clusters[cluster_key].append(other_answer) used.add(other_answer) return dict(clusters) def _analyze_temporal_patterns(self, results: List[Tuple[str, str, float]]) -> Dict[str, Any]: """Analyze temporal patterns in answers""" try: # Sort by frame number def get_frame_number(frame_path): match = re.search(r'frame_(\d+)', os.path.basename(frame_path)) return int(match.group(1)) if match else 0 sorted_results = sorted(results, key=lambda x: get_frame_number(x[0])) # Analyze answer changes over time answers_timeline = [result[1] for result in sorted_results] changes = [] for i in range(1, len(answers_timeline)): if answers_timeline[i] != answers_timeline[i-1]: changes.append({ 'frame_index': i, 'from_answer': answers_timeline[i-1], 'to_answer': answers_timeline[i] }) return { 'total_changes': len(changes), 'change_points': changes, 'stability_ratio': 1 - (len(changes) / max(1, len(answers_timeline) - 1)), 'answers_timeline': answers_timeline } except Exception as e: print(f"Error in temporal analysis: {str(e)}") return {'error': str(e)} def analyze_video_question(self, frame_files: List[str], question: str, input_data: Dict[str, Any] = None) -> Dict[str, Any]: """Comprehensive video question analysis""" if not frame_files: return { "final_answer": "No frames available for analysis.", "confidence": 0.0, "frame_count": 0, "error": "No valid frames found" } # Process all frames print(f"Processing {len(frame_files)} frames...") results = self._process_frames_parallel(frame_files, question, input_data) if not results: return { "final_answer": "Could not analyze any frames successfully.", "confidence": 0.0, "frame_count": 0, "error": "Frame processing failed" } # Extract answers and confidences answers = [result[1] for result in results if result[1] != "Error"] confidences = [result[2] for result in results if result[1] != "Error"] # Calculate statistical summary on numeric answers numeric_answers = [] for answer in answers: try: # Try to convert answer to float numeric_value = float(answer) numeric_answers.append(numeric_value) except (ValueError, TypeError): # Skip non-numeric answers pass if numeric_answers: stats = { "minimum": float(np.min(numeric_answers)), "maximum": float(np.max(numeric_answers)), "range": float(np.max(numeric_answers) - np.min(numeric_answers)), "mean": float(np.mean(numeric_answers)), "median": float(np.median(numeric_answers)), "count": len(numeric_answers), "data_type": "answers" } elif confidences: # Fallback to confidence statistics if no numeric answers stats = { "minimum": float(np.min(confidences)), "maximum": float(np.max(confidences)), "range": float(np.max(confidences) - np.min(confidences)), "mean": float(np.mean(confidences)), "median": float(np.median(confidences)), "count": len(confidences), "data_type": "confidences" } else: stats = { "minimum": 0.0, "maximum": 0.0, "range": 0.0, "mean": 0.0, "median": 0.0, "count": 0, "data_type": "none", "note": "No numeric results available for statistical summary" } if not answers: return { "final_answer": "All frame processing failed.", "confidence": 0.0, "frame_count": len(frame_files), "error": "No successful frame analysis" } # Cluster similar answers answer_clusters = self._cluster_similar_answers(answers, input_data) # Find most common cluster largest_cluster = max(answer_clusters.items(), key=lambda x: len(x[1])) most_common_answer = largest_cluster[0] # Calculate weighted confidence answer_counts = Counter(answers) total_answers = len(answers) frequency_confidence = answer_counts[most_common_answer] / total_answers avg_confidence = np.mean(confidences) if confidences else 0.0 final_confidence = (frequency_confidence * 0.7) + (avg_confidence * 0.3) # Temporal analysis temporal_analysis = self._analyze_temporal_patterns(results) return { "final_answer": most_common_answer, "confidence": final_confidence, "frame_count": len(frame_files), "successful_analyses": len(answers), "answer_distribution": dict(answer_counts), "semantic_clusters": {k: len(v) for k, v in answer_clusters.items()}, "temporal_analysis": temporal_analysis, "average_model_confidence": avg_confidence, "frequency_confidence": frequency_confidence, "statistical_summary": stats } def _run(self, youtube_url, question, **kwargs) -> str: """Enhanced main execution method""" question = "How many unique bird species are on camera?" input_data = { 'youtube_url': youtube_url, 'question': question } if not youtube_url or not question: return "Error: Input must include 'youtube_url' and 'question'." try: # Generate video hash for caching video_hash = self._get_video_hash(youtube_url) # Step 1: Download video print(f"Downloading YouTube video from {youtube_url}...") cache_enabled = self._get_config('cache_enabled', True, input_data) video_path = self.download_youtube_video(youtube_url, video_hash, cache_enabled) if not video_path or not os.path.exists(video_path): return "Error: Failed to download the YouTube video. This may be due to YouTube's anti-bot protection. Try using a different video or implement cookie authentication." # Step 2: Smart frame extraction print(f"Extracting frames with smart selection...") frame_files = self.smart_extract_frames(video_path, video_hash, input_data) if not frame_files: return "Error: Failed to extract frames from the video." # Step 3: Comprehensive analysis print(f"Analyzing {len(frame_files)} frames for question: '{question}'") analysis_result = self.analyze_video_question(frame_files, question, input_data) if analysis_result.get("error"): return f"Error: {analysis_result['error']}" # Format comprehensive result - Fixed the reference to stats result = f""" ๐Ÿ“Š **STATISTICAL SUMMARY**: โ€ข Minimum: {analysis_result['statistical_summary']['minimum']:.2f} โ€ข Maximum: {analysis_result['statistical_summary']['maximum']:.2f} โ€ข Mean: {analysis_result['statistical_summary']['mean']:.2f} โ€ข Median: {analysis_result['statistical_summary']['median']:.2f} โ€ข Range: {analysis_result['statistical_summary']['range']:.2f} """.strip() return result except Exception as e: return f"Error during video analysis: {str(e)}" # Initialize the enhanced tool def create_enhanced_youtube_qa_tool(**kwargs): """Factory function to create the enhanced tool with custom parameters""" return EnhancedYoutubeScreenshotQA(**kwargs) import os import json import hashlib import time import shutil import glob from typing import Dict, Any, List, Optional from concurrent.futures import ThreadPoolExecutor, as_completed import yt_dlp import speech_recognition as sr from pydantic import Field from pydantic.v1 import BaseModel from pydub import AudioSegment from pydub.silence import split_on_silence class BaseTool(BaseModel): name: str description: str class YouTubeTranscriptExtractor(BaseTool): name: str = "youtube_transcript_extractor" description: str = ( "Downloads a YouTube video and extracts the complete audio transcript using speech recognition. " "Use this tool for questions about what people say in YouTube videos. " "Input should be a dict with keys: 'youtube_url' and optional parameters. " "Optional parameters include 'language' (e.g., 'en-US'), " "'cookies_file_path' (path to a cookies TXT file for authentication), " "or 'cookies_from_browser' (string specifying browser for cookies, e.g., 'chrome', 'firefox:profileName', 'edge+keyringName:profileName::containerName'). " "Example: {'youtube_url': 'https://youtube.com/watch?v=xyz', 'language': 'en-US'} or " "{'youtube_url': '...', 'cookies_file_path': '/path/to/cookies.txt'} or " "{'youtube_url': '...', 'cookies_from_browser': 'chrome'}" ) recognizer: Any = Field(default=None, exclude=True) class Config: arbitrary_types_allowed = True extra = Extra.allow # Adjusted if pydantic v1 style def __init__(self, **kwargs: Any): super().__init__(**kwargs) self.cache_dir = '/tmp/youtube_transcript_cache/' self.audio_dir = '/tmp/audio/' self.chunks_dir = '/tmp/audio_chunks/' self.recognizer = sr.Recognizer() self.recognizer.energy_threshold = 4000 self.recognizer.pause_threshold = 0.8 for dir_path in [self.cache_dir, self.audio_dir, self.chunks_dir]: os.makedirs(dir_path, exist_ok=True) def _get_config(self, key: str, default_value: Any = None, input_data: Optional[Dict[str, Any]] = None) -> Any: defaults = { 'language': 'en-US', 'chunk_length_ms': 30000, 'silence_thresh': -40, 'audio_quality': 'best', 'cache_enabled': True, 'min_silence_len': 500, 'overlap_ms': 1000, 'cookies_file_path': None, # New: Path to a cookies file 'cookies_from_browser': None # New: Browser string e.g., "chrome", "firefox:profile_name" } if input_data and key in input_data: return input_data[key] return defaults.get(key, default_value) def _get_video_hash(self, url: str) -> str: return hashlib.md5(url.encode()).hexdigest() def _get_cache_path(self, video_hash: str, cache_type: str) -> str: return os.path.join(self.cache_dir, f"{video_hash}_{cache_type}") def _load_from_cache(self, cache_path: str, cache_enabled: bool = True) -> Optional[Any]: if not cache_enabled or not os.path.exists(cache_path): return None try: with open(cache_path, 'r', encoding='utf-8') as f: return json.load(f) except Exception as e: print(f"Error loading cache: {str(e)}") return None def _save_to_cache(self, cache_path: str, data: Any, cache_enabled: bool = True): if not cache_enabled: return try: with open(cache_path, 'w', encoding='utf-8') as f: json.dump(data, f, ensure_ascii=False, indent=2) except Exception as e: print(f"Error saving cache: {str(e)}") def _clean_directory(self, directory: str): if os.path.exists(directory): for filename in os.listdir(directory): file_path = os.path.join(directory, filename) try: if os.path.isfile(file_path) or os.path.islink(file_path): os.unlink(file_path) elif os.path.isdir(file_path): shutil.rmtree(file_path) except Exception as e: print(f'Failed to delete {file_path}. Reason: {e}') def download_youtube_audio(self, url: str, video_hash: str, input_data: Optional[Dict[str, Any]] = None) -> Optional[str]: audio_quality = self._get_config('audio_quality', 'best', input_data) output_filename = f'{video_hash}.wav' output_path = os.path.join(self.audio_dir, output_filename) cache_enabled = self._get_config('cache_enabled', True, input_data) if cache_enabled and os.path.exists(output_path): print(f"Using cached audio: {output_path}") return output_path self._clean_directory(self.audio_dir) cookies_file_path = self._get_config('cookies_file_path', None, input_data) cookies_from_browser_str = self._get_config('cookies_from_browser', None, input_data) try: ydl_opts: Dict[str, Any] = { 'format': 'bestaudio[ext=m4a]/bestaudio/best', 'outtmpl': os.path.join(self.audio_dir, f'{video_hash}.%(ext)s'), 'quiet': False, 'no_warnings': False, 'extract_flat': False, # Ensure this is false for actual downloads 'writethumbnail': False, 'writeinfojson': False, 'postprocessors': [{ 'key': 'FFmpegExtractAudio', 'preferredcodec': 'wav', 'preferredquality': '192' if audio_quality == 'best' else '128', }], 'http_headers': { 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36' }, 'nocheckcertificate': True, } if cookies_file_path: ydl_opts['cookiefile'] = cookies_file_path print(f"Using cookies from file: {cookies_file_path}") elif cookies_from_browser_str: parsed_browser, parsed_profile, parsed_keyring, parsed_container = None, None, None, None temp_str = cookies_from_browser_str if '::' in temp_str: main_part_before_container, parsed_container_val = temp_str.split('::', 1) parsed_container = parsed_container_val if parsed_container_val else None temp_str = main_part_before_container if ':' in temp_str: browser_keyring_part, parsed_profile_val = temp_str.split(':', 1) parsed_profile = parsed_profile_val if parsed_profile_val else None temp_str = browser_keyring_part if '+' in temp_str: parsed_browser_val, parsed_keyring_val = temp_str.split('+', 1) parsed_browser = parsed_browser_val parsed_keyring = parsed_keyring_val if parsed_keyring_val else None else: parsed_browser = temp_str if parsed_browser: # yt-dlp expects cookiesfrombrowser as a tuple: (BROWSER, PROFILE, KEYRING, CONTAINER) final_tuple: Tuple[Optional[str], ...] = ( parsed_browser, parsed_profile, parsed_keyring, parsed_container ) ydl_opts['cookiesfrombrowser'] = final_tuple print(f"Attempting to use cookies from browser spec '{cookies_from_browser_str}', parsed as: {final_tuple}") else: print(f"Invalid or empty browser name in cookies_from_browser string: '{cookies_from_browser_str}'") with yt_dlp.YoutubeDL(ydl_opts) as ydl: print(f"Downloading audio from: {url} with options: {ydl_opts}") ydl.download([url]) if os.path.exists(output_path): print(f"Audio downloaded successfully: {output_path}") return output_path else: possible_files = glob.glob(os.path.join(self.audio_dir, f'{video_hash}.*')) if possible_files: source_file = possible_files[0] if not source_file.endswith('.wav'): try: audio = AudioSegment.from_file(source_file) audio.export(output_path, format="wav") os.remove(source_file) print(f"Audio converted to WAV: {output_path}") return output_path except Exception as e: print(f"Error converting audio: {str(e)}") return None else: # Already a .wav, possibly due to postprocessor already creating it with a different ext pattern if source_file != output_path: # if names differ due to original extension shutil.move(source_file, output_path) print(f"Audio file found: {output_path}") return output_path print(f"No audio file found at expected path after download: {output_path}") return None except yt_dlp.utils.DownloadError as de: print(f"yt-dlp DownloadError: {str(de)}") if "Sign in to confirm you're not a bot" in str(de) and not (cookies_file_path or cookies_from_browser_str): print("Authentication required. Consider using 'cookies_file_path' or 'cookies_from_browser' options.") return None except Exception as e: print(f"Error downloading YouTube audio: {type(e).__name__} - {str(e)}") # Fallback attempt is removed as it's unlikely to succeed if the primary authenticated attempt fails due to bot detection return None def _split_audio_intelligent(self, audio_path: str, input_data: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]: self._clean_directory(self.chunks_dir) try: audio = AudioSegment.from_wav(audio_path) chunk_length_ms = self._get_config('chunk_length_ms', 30000, input_data) silence_thresh = self._get_config('silence_thresh', -40, input_data) min_silence_len = self._get_config('min_silence_len', 500, input_data) overlap_ms = self._get_config('overlap_ms', 1000, input_data) # Not used in current split_on_silence chunks = split_on_silence( audio, min_silence_len=min_silence_len, silence_thresh=silence_thresh, keep_silence=True ) processed_chunks: List[AudioSegment] = [] # type: ignore # Combine small chunks or re-chunk if silence splitting is ineffective temp_chunk: Optional[AudioSegment] = None # type: ignore for chunk in chunks: if temp_chunk is None: temp_chunk = chunk else: temp_chunk += chunk if len(temp_chunk) > chunk_length_ms / 2 or chunk == chunks[-1]: # Arbitrary threshold to combine small chunks processed_chunks.append(temp_chunk) temp_chunk = None if not processed_chunks or any(len(p_chunk) > chunk_length_ms * 1.5 for p_chunk in processed_chunks): # If still problematic print("Using time-based splitting due to ineffective silence splitting or overly large chunks...") processed_chunks = [] for i in range(0, len(audio), chunk_length_ms - overlap_ms): chunk_segment = audio[i:i + chunk_length_ms] if len(chunk_segment) > 1000: processed_chunks.append(chunk_segment) chunk_data = [] current_time_ms = 0 for i, chunk_segment in enumerate(processed_chunks): if len(chunk_segment) < 1000: continue chunk_filename = os.path.join(self.chunks_dir, f"chunk_{i:04d}.wav") chunk_segment.export(chunk_filename, format="wav") duration_s = len(chunk_segment) / 1000.0 start_time_s = current_time_ms / 1000.0 end_time_s = start_time_s + duration_s chunk_data.append({ 'filename': chunk_filename, 'index': i, 'start_time': start_time_s, 'duration': duration_s, 'end_time': end_time_s }) current_time_ms += len(chunk_segment) # Approximation, true timestamping is harder print(f"Split audio into {len(chunk_data)} chunks") return chunk_data except Exception as e: print(f"Error splitting audio: {str(e)}") try: # Fallback: single chunk audio = AudioSegment.from_wav(audio_path) duration = len(audio) / 1000.0 return [{'filename': audio_path, 'index': 0, 'start_time': 0, 'duration': duration, 'end_time': duration}] except: return [] def _transcribe_audio_chunk(self, chunk_info: Dict[str, Any], input_data: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: chunk_path = chunk_info['filename'] base_result = { 'start_time': chunk_info.get('start_time', 0), 'end_time': chunk_info.get('end_time', 0), 'duration': chunk_info.get('duration', 0), 'index': chunk_info.get('index', -1), 'success': False, 'confidence': 0.0 } try: language = self._get_config('language', 'en-US', input_data) with sr.AudioFile(chunk_path) as source: self.recognizer.adjust_for_ambient_noise(source, duration=0.2) # Shorter adjustment audio_data = self.recognizer.record(source) try: text = self.recognizer.recognize_google(audio_data, language=language) return {**base_result, 'text': text, 'confidence': 1.0, 'success': True} except sr.UnknownValueError: try: # Try without specific language text = self.recognizer.recognize_google(audio_data) return {**base_result, 'text': text, 'confidence': 0.8, 'success': True} # Lower confidence except sr.UnknownValueError: return {**base_result, 'text': '[INAUDIBLE]'} except sr.RequestError as e: return {**base_result, 'text': f'[RECOGNITION_ERROR: {str(e)}]', 'error': str(e)} except Exception as e: return {**base_result, 'text': f'[ERROR: {str(e)}]', 'error': str(e)} def _transcribe_chunks_parallel(self, chunk_data: List[Dict[str, Any]], input_data: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]: results = [] max_workers = min(os.cpu_count() or 1, 4) # Limit workers with ThreadPoolExecutor(max_workers=max_workers) as executor: future_to_chunk = { executor.submit(self._transcribe_audio_chunk, chunk_info, input_data): chunk_info for chunk_info in chunk_data } for future in as_completed(future_to_chunk): chunk_info = future_to_chunk[future] try: result = future.result() results.append(result) status = "Transcribed" if result['success'] else "Failed" preview = result['text'][:50] + "..." if len(result['text']) > 50 else result['text'] print(f"{status} chunk {result['index']}: {preview}") except Exception as e: print(f"Error processing chunk {chunk_info.get('index', '?')}: {str(e)}") results.append({ 'text': f'[PROCESSING_ERROR: {str(e)}]', 'confidence': 0.0, 'start_time': chunk_info.get('start_time', 0), 'end_time': chunk_info.get('end_time', 0), 'duration': chunk_info.get('duration', 0), 'index': chunk_info.get('index', 0), 'success': False, 'error': str(e) }) results.sort(key=lambda x: x['index']) return results def extract_transcript(self, audio_path: str, video_hash: str, input_data: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: cache_enabled = self._get_config('cache_enabled', True, input_data) cache_path = self._get_cache_path(video_hash, "transcript.json") cached_transcript = self._load_from_cache(cache_path, cache_enabled) if cached_transcript: print("Using cached transcript") return cached_transcript try: print("Splitting audio into chunks...") chunk_data = self._split_audio_intelligent(audio_path, input_data) if not chunk_data: return {'error': 'Failed to split audio', 'full_transcript': '', 'success_rate': 0.0} print(f"Transcribing {len(chunk_data)} audio chunks...") transcript_results = self._transcribe_chunks_parallel(chunk_data, input_data) successful_chunks = [r for r in transcript_results if r['success']] full_text = ' '.join([r['text'] for r in successful_chunks if r['text'] and '[INAUDIBLE]' not in r['text'] and 'ERROR' not in r['text']]).strip() total_c = len(transcript_results) successful_c = len(successful_chunks) success_rate = successful_c / total_c if total_c > 0 else 0.0 final_result = { 'full_transcript': full_text, 'word_count': len(full_text.split()), 'total_chunks': total_c, 'successful_chunks': successful_c, 'success_rate': success_rate, 'extraction_timestamp': time.time(), 'extraction_date': time.strftime('%Y-%m-%d %H:%M:%S'), 'detailed_results': transcript_results } self._save_to_cache(cache_path, final_result, cache_enabled) print(f"Transcript extraction completed. Success rate: {success_rate:.1%}") return final_result except Exception as e: print(f"Error during transcript extraction: {str(e)}") return {'error': str(e), 'full_transcript': '', 'success_rate': 0.0} def _run(self, youtube_url: str, **kwargs: Any) -> str: input_data = {'youtube_url': youtube_url, **kwargs} if not youtube_url: return "Error: youtube_url is required." try: video_hash = self._get_video_hash(youtube_url) print(f"Processing YouTube URL: {youtube_url} (Hash: {video_hash})") audio_path = self.download_youtube_audio(youtube_url, video_hash, input_data) if not audio_path or not os.path.exists(audio_path): return "Error: Failed to download YouTube audio. Check URL or authentication (cookies)." print("Extracting audio transcript...") transcript_result = self.extract_transcript(audio_path, video_hash, input_data) if transcript_result.get("error"): return f"Error: {transcript_result['error']}" main_transcript = transcript_result.get('full_transcript', '') if not main_transcript: return "Error: No transcript could be extracted." print(f"Transcript extracted. Word count: {transcript_result.get('word_count',0)}. Success: {transcript_result.get('success_rate',0):.1%}") return "TRANSCRIPT: " + main_transcript except Exception as e: print(f"Unhandled error in _run: {str(e)}") # For debugging return f"Error during transcript extraction: {str(e)}" # Factory function to create the tool def create_youtube_transcript_tool(**kwargs): """Factory function to create the transcript extraction tool with custom parameters""" return YouTubeTranscriptExtractor(**kwargs) # --- Model Configuration --- def create_llm_pipeline(): #model_id = "meta-llama/Llama-2-13b-chat-hf" #model_id = "meta-llama/Llama-3.3-70B-Instruct" #model_id = "mistralai/Mistral-Small-24B-Base-2501" model_id = "mistralai/Mistral-7B-Instruct-v0.3" #model_id = "Meta-Llama/Llama-2-7b-chat-hf" #model_id = "NousResearch/Nous-Hermes-2-Mistral-7B-DPO" #model_id = "TheBloke/Mistral-7B-Instruct-v0.1-GGUF" #model_id = "mistralai/Mistral-7B-Instruct-v0.2" #model_id = "Qwen/Qwen2-7B-Instruct" #model_id = "GSAI-ML/LLaDA-8B-Instruct" return pipeline( "text-generation", model=model_id, device_map="cpu", torch_dtype=torch.float16, max_new_tokens=1024, temperature=0.3, top_k=50, top_p=0.95 ) nlp = None # Set to None if not using spaCy, so the regex fallback is used in extract_entities # --- Agent State Definition --- class AgentState(TypedDict): messages: Annotated[List[AnyMessage], lambda x, y: x + y] done: bool = False # Default value of False question: str task_id: str input_file: Optional[bytes] file_type: Optional[str] context: List[Document] # Using LangChain's Document class file_path: Optional[str] youtube_url: Optional[str] answer: Optional[str] frame_answers: Optional[list] # --- Define Call LLM function --- # 3. Improved LLM call with memory management def call_llm_with_memory_management(state: AgentState, llm_model) -> AgentState: """Enhanced LLM call with better prompt engineering and hallucination prevention.""" print("Running call_llm with memory management...") #ipdb.set_trace() original_messages = messages_for_llm = state["messages"] # Context management - be more aggressive about truncation system_message_content = None if messages_for_llm and isinstance(messages_for_llm[0], SystemMessage): system_message_content = messages_for_llm[0] regular_messages = messages_for_llm[1:] else: regular_messages = messages_for_llm # Keep only the most recent messages (more aggressive) max_regular_messages = 6 # Reduced from 9 if len(regular_messages) > max_regular_messages: print(f"๐Ÿ”„ Truncating to {max_regular_messages} recent messages") regular_messages = regular_messages[-max_regular_messages:] # Reconstruct for LLM messages_for_llm = [] if system_message_content: messages_for_llm.append(system_message_content) messages_for_llm.extend(regular_messages) # Character limit check total_chars = sum(len(str(msg.content)) for msg in messages_for_llm) char_limit = 20000 if total_chars > char_limit: print(f"๐Ÿ“ Context too long ({total_chars} chars) - further truncation") while regular_messages and sum(len(str(m.content)) for m in regular_messages) > char_limit - (len(str(system_message_content.content)) if system_message_content else 0): regular_messages.pop(0) messages_for_llm = [] if system_message_content: messages_for_llm.append(system_message_content) messages_for_llm.extend(regular_messages) new_state = state.copy() try: if torch.cuda.is_available(): torch.cuda.empty_cache() print(f"๐Ÿค– Calling LLM with {len(messages_for_llm)} messages") # Convert to simple string format that the model can understand if len(messages_for_llm) == 2 and isinstance(messages_for_llm[0], SystemMessage) and isinstance(messages_for_llm[1], HumanMessage): # Initial query - use simple format system_content = messages_for_llm[0].content human_content = messages_for_llm[1].content formatted_input = f"{system_content}\n\nHuman: {human_content}\n\nAssistant:" else: # Ongoing conversation - build context formatted_input = "" # Add system message if present if system_message_content: formatted_input += f"{system_message_content.content}\n\n" # Add conversation messages for msg in regular_messages: if isinstance(msg, HumanMessage): formatted_input += f"Human: {msg.content}\n\n" elif isinstance(msg, AIMessage): formatted_input += f"Assistant: {msg.content}\n\n" elif isinstance(msg, ToolMessage): formatted_input += f"Tool Result: {msg.content}\n\n" # Add explicit instruction for immediate final answer if we have recent tool results if any(isinstance(msg, ToolMessage) for msg in regular_messages[-2:]): formatted_input += "Based on the tool results above, provide your FINAL ANSWER now.\n\n" formatted_input += "REMINDER ON ANSWER FORMAT: \n" formatted_input += "- Numbers: no commas, no units unless specified\n" formatted_input += "- Strings: no articles, no abbreviations, digits in plain text\n" formatted_input += "- Lists: comma-separated following above rules\n" formatted_input += "- Be extremely brief and concise" formatted_input += "Assistant:" print(f"Input preview: {formatted_input[:300]}...") llm_response_object = llm_model.invoke(formatted_input) # Process response and clean up hallucinated content if isinstance(llm_response_object, BaseMessage): raw_content = llm_response_object.content elif hasattr(llm_response_object, 'content'): raw_content = str(llm_response_object.content) else: raw_content = str(llm_response_object) # Clean up the response to prevent hallucinated follow-up questions cleaned_content = clean_llm_response(raw_content) ai_message_response = AIMessage(content=cleaned_content) print(f"๐Ÿ” LLM Response preview: {cleaned_content[:200]}...") final_messages = original_messages + [ai_message_response] new_state["messages"] = final_messages new_state.pop("done", None) except Exception as e: print(f"โŒ LLM call failed: {e}") error_message = AIMessage(content=f"Error: LLM call failed - {str(e)}") new_state["messages"] = original_messages + [error_message] new_state["done"] = True finally: if torch.cuda.is_available(): torch.cuda.empty_cache() return new_state def clean_llm_response(response_text: str) -> str: """ Clean LLM response to prevent hallucinated follow-up questions and conversations. Specifically handles ReAct format: Thought: -> Action: -> Action Input: """ if not response_text: return response_text print(f"Initial response: {response_text[:200]}...") # --- START MODIFICATION --- # Isolate the text generated by the assistant in the last turn. # This prevents parsing examples or instructions from the preamble. assistant_marker = "Assistant:" last_marker_idx = response_text.rfind(assistant_marker) text_to_process = response_text # Default to full text if marker not found if last_marker_idx != -1: # If "Assistant:" is found, process only the text after the last occurrence. text_to_process = response_text[last_marker_idx + len(assistant_marker):].strip() print(f"โ„น๏ธ Parsing content after last 'Assistant:': {text_to_process[:200]}...") else: # If "Assistant:" is not found, process the whole input. # This might occur if the input is already just the assistant's direct response # or if the prompt structure is different. print(f"โ„น๏ธ No 'Assistant:' marker found. Processing entire input as is.") # --- END MODIFICATION --- # Now, all subsequent operations use 'text_to_process' # Try to find a complete ReAct pattern in the assistant's actual output react_pattern = r'Thought:\s*(.*?)\s*Action:\s*([^\n\r]+)\s*Action Input:\s*(.*?)(?=\s*(?:Thought:|Action:|FINAL ANSWER:|$))' # Apply search to 'text_to_process' react_match = re.search(react_pattern, text_to_process, re.DOTALL | re.IGNORECASE) if react_match: thought_text = react_match.group(1).strip() action_name = react_match.group(2).strip() action_input = react_match.group(3).strip() # Clean up the action input - remove any trailing content that looks like instructions action_input_clean = re.sub(r'\s*(When you have|FINAL ANSWER|ANSWER FORMAT|IMPORTANT:).*$', '', action_input, flags=re.DOTALL | re.IGNORECASE) action_input_clean = action_input_clean.strip() react_sequence = f"Thought: {thought_text}\nAction: {action_name}\nAction Input: {action_input_clean}" print(f"๐Ÿ”ง Found ReAct pattern - Action: {action_name}, Input: {action_input_clean[:100]}...") # Check if there's a FINAL ANSWER after the action input (this would be hallucination) # Check in the remaining part of 'text_to_process' remaining_text_in_process = text_to_process[react_match.end():] final_answer_after = re.search(r'FINAL ANSWER:', remaining_text_in_process, re.IGNORECASE) if final_answer_after: print(f"๐Ÿšซ Removed hallucinated FINAL ANSWER after tool call") return react_sequence # If no ReAct pattern in 'text_to_process', check for standalone FINAL ANSWER # This variable will hold the text being processed for FINAL ANSWER and then for fallback. current_text_for_processing = text_to_process final_answer_match = re.search(r"FINAL ANSWER:\s*(.+?)(?=\n|$)", current_text_for_processing, re.IGNORECASE) if final_answer_match: answer_content = final_answer_match.group(1).strip() template_phrases = [ '[concise answer only]', '[concise answer - number/word/list only]', '[brief answer]', '[your answer here]', 'concise answer only', 'brief answer', 'your answer here' ] if any(phrase.lower() in answer_content.lower() for phrase in template_phrases): print(f"๐Ÿšซ Ignoring template FINAL ANSWER: {answer_content}") # Remove the template FINAL ANSWER and continue cleaning on the remainder of 'current_text_for_processing' current_text_for_processing = current_text_for_processing[:final_answer_match.start()].strip() # Fall through to the general cleanup section below else: # Keep everything from the start of 'current_text_for_processing' up to and including the real FINAL ANSWER line only cleaned_output = current_text_for_processing[:final_answer_match.end()] # Check if there's additional content after FINAL ANSWER in 'current_text_for_processing' remaining_after_final_answer = current_text_for_processing[final_answer_match.end():].strip() if remaining_after_final_answer: print(f"๐Ÿšซ Removed content after FINAL ANSWER: {remaining_after_final_answer[:100]}...") return cleaned_output.strip() # If no ReAct, or FINAL ANSWER was a template or not found, apply fallback cleaning to 'current_text_for_processing' lines = current_text_for_processing.split('\n') cleaned_lines = [] for i, line in enumerate(lines): # Stop if we see the model repeating system instructions if re.search(r'\[SYSTEM\]|\[HUMAN\]|\[ASSISTANT\]|\[TOOL\]', line, re.IGNORECASE): print(f"๐Ÿšซ Stopped at repeated system format: {line}") break # Stop if we see the model generating format instructions if re.search(r'CRITICAL INSTRUCTIONS|FORMAT for tool use|ANSWER FORMAT', line, re.IGNORECASE): print(f"๐Ÿšซ Stopped at repeated instructions: {line}") break # Stop if we see the model role-playing as a human asking questions if re.search(r'(what are|what is|how many|can you tell me)', line, re.IGNORECASE) and not line.strip().startswith(('Thought:', 'Action:', 'Action Input:')): # Make sure this isn't part of a legitimate thought process if i > 0 and not any(keyword in lines[i-1] for keyword in ['Thought:', 'Action:', 'need to']): print(f"๐Ÿšซ Stopped at hallucinated question: {line}") break cleaned_lines.append(line) cleaned = '\n'.join(cleaned_lines).strip() print(f"Final cleaned response (fallback): {cleaned[:200]}...") return cleaned def parse_react_output(state: AgentState) -> AgentState: """ Enhanced parsing with better FINAL ANSWER detection and flow control. """ print("Running parse_react_output...") #ipdb.set_trace() messages = state.get("messages", []) if not messages: print("No messages in state.") new_state = state.copy() new_state["done"] = True return new_state # DEBUG print(f"parse_react_output: Entry message count: {len(messages)}") if messages and hasattr(messages[-1], 'tool_calls'): print(f"parse_react_output: Number of tool calls in last AIMessage: {len(messages[-1].tool_calls)}") last_message = messages[-1] new_state = state.copy() if not isinstance(last_message, AIMessage): print("Last message is not an AIMessage instance.") return new_state content = last_message.content if not isinstance(content, str): content = str(content) # Look for FINAL ANSWER first - this should take absolute priority # Use a more precise regex to capture just the answer line final_answer_match = re.search(r"FINAL ANSWER:\s*([^\n\r]+)", content, re.IGNORECASE) if final_answer_match: final_answer_text = final_answer_match.group(1).strip() # Check if this is template text (not a real answer) template_phrases = [ '[concise answer only]', '[concise answer - number/word/list only]', '[brief answer]', '[your answer here]', 'concise answer only', 'brief answer', 'your answer here' ] # If it's template text, don't treat it as a final answer if any(phrase.lower() in final_answer_text.lower() for phrase in template_phrases): print(f"๐Ÿšซ Ignoring template FINAL ANSWER: '{final_answer_text}'") # Continue processing as if no final answer was found else: print(f"๐ŸŽฏ FINAL ANSWER found: '{final_answer_text}' - ENDING") # Store the answer in state for easy access new_state["answer"] = final_answer_text # Clean up the message content to just show the final answer clean_content = f"FINAL ANSWER: {final_answer_text}" updated_ai_message = AIMessage(content=clean_content, tool_calls=[]) new_state["messages"] = messages[:-1] + [updated_ai_message] new_state["done"] = True return new_state # If no FINAL ANSWER, look for tool calls action_match = re.search(r"Action:\s*([^\n]+)", content, re.IGNORECASE) action_input_match = re.search(r"Action Input:\s*(.+)", content, re.IGNORECASE | re.DOTALL) if action_match and action_input_match: tool_name = action_match.group(1).strip() tool_input_raw = action_input_match.group(1).strip() if tool_name.lower() == "none": print("Action is 'None' - treating as regular response") updated_ai_message = AIMessage(content=content, tool_calls=[]) new_state["messages"] = messages[:-1] + [updated_ai_message] new_state.pop("done", None) return new_state print(f"๐Ÿ”ง Tool call: {tool_name} with input: {tool_input_raw[:100]}...") # Parse tool arguments tool_args = {} try: trimmed_input = tool_input_raw.strip() if (trimmed_input.startswith('{') and trimmed_input.endswith('}')) or \ (trimmed_input.startswith('[') and trimmed_input.endswith(']')): tool_args = ast.literal_eval(trimmed_input) if not isinstance(tool_args, dict): tool_args = {"query": tool_input_raw} else: tool_args = {"query": tool_input_raw} except (ValueError, SyntaxError): tool_args = {"query": tool_input_raw} tool_call_id = str(uuid.uuid4()) parsed_tool_calls = [{"name": tool_name, "args": tool_args, "id": tool_call_id}] updated_ai_message = AIMessage(content=content, tool_calls=parsed_tool_calls) new_state["messages"] = messages[:-1] + [updated_ai_message] new_state.pop("done", None) return new_state # No tool call or final answer - treat as regular response print("No actionable content found - continuing conversation") updated_ai_message = AIMessage(content=content, tool_calls=[]) new_state["messages"] = messages[:-1] + [updated_ai_message] new_state.pop("done", None) # DEBUG print(f"parse_react_output: Exit message count: {len(new_state['messages'])}") return new_state # 4. Improved call_tool_with_memory_management to prevent duplicate processing def call_tool_with_memory_management(state: AgentState) -> AgentState: """Process tool calls with memory management, avoiding duplicates.""" print("Running call_tool with memory management...") # Clear CUDA cache before processing try: import torch if torch.cuda.is_available(): torch.cuda.empty_cache() print(f"๐Ÿงน Cleared CUDA cache. Memory: {torch.cuda.memory_allocated()/1024**2:.1f}MB") except ImportError: pass except Exception as e: print(f"Error clearing CUDA cache: {e}") # Check if we have parsed tool calls from the condition function if 'parsed_tool_calls' in state and state.get('parsed_tool_calls'): print("Executing parsed tool calls...") return execute_parsed_tool_calls(state) # Fallback to original OpenAI-style tool calls handling messages = state.get("messages", []) if not messages: print("No messages found in state.") return state last_message = messages[-1] if not hasattr(last_message, "tool_calls") or not last_message.tool_calls: print("No tool calls found in last message") return state # Avoid processing the same tool calls multiple times if hasattr(last_message, '_processed_tool_calls'): print("Tool calls already processed, skipping...") return state # Copy the messages to avoid mutating the original list new_messages = list(messages) print(f"Processing {len(last_message.tool_calls)} tool calls from last message") # Get file_path from state to pass to tools file_path_to_pass = state.get('file_path') for i, tool_call_item in enumerate(last_message.tool_calls): # Handle both dict and object-style tool calls if isinstance(tool_call_item, dict): tool_name = tool_call_item.get("name", "") raw_args = tool_call_item.get("args") tool_call_id = tool_call_item.get("id", str(uuid.uuid4())) elif hasattr(tool_call_item, "name") and hasattr(tool_call_item, "id"): tool_name = getattr(tool_call_item, "name", "") raw_args = getattr(tool_call_item, "args", None) tool_call_id = getattr(tool_call_item, "id", str(uuid.uuid4())) else: print(f"Skipping malformed tool call item: {tool_call_item}") continue print(f"Processing tool call {i+1}: {tool_name}") # Find the matching tool selected_tool = None for tool_instance in tools: if tool_instance.name.lower() == tool_name.lower(): selected_tool = tool_instance break if not selected_tool: tool_result = f"Error: Tool '{tool_name}' not found. Available tools: {', '.join(t.name for t in tools)}" print(f"Tool not found: {tool_name}") else: try: # Prepare the arguments for the tool.run() method tool_run_input_dict = {} if isinstance(raw_args, dict): tool_run_input_dict = raw_args.copy() elif raw_args is not None: tool_run_input_dict["query"] = str(raw_args) # Add file_path to the dictionary for the tool tool_run_input_dict['file_path'] = file_path_to_pass print(f"Executing {tool_name} with args: {tool_run_input_dict} ...") tool_result = selected_tool.run(tool_run_input_dict) #ipdb.set_trace() # Aggressive truncation to prevent memory issues if not isinstance(tool_result, str): tool_result = str(tool_result) max_length = 18000 if "wikipedia" in tool_name.lower() else 18000 if len(tool_result) > max_length: original_length = len(tool_result) tool_result = tool_result[:max_length] + f"... [Result truncated from {original_length} to {max_length} chars to prevent memory issues]" print(f"๐Ÿ“„ Truncated result to {max_length} characters") print(f"Tool result length: {len(tool_result)} characters") except Exception as e: tool_result = f"Error executing tool '{tool_name}': {str(e)}" print(f"Tool execution error: {e}") # Create tool message - ONLY ONE PER TOOL CALL tool_message = ToolMessage( content=tool_result, name=tool_name, tool_call_id=tool_call_id ) new_messages.append(tool_message) print(f"Added tool message for {tool_name}") # Mark the last message as processed to prevent re-processing if hasattr(last_message, '__dict__'): last_message._processed_tool_calls = True # Update the state new_state = state.copy() new_state["messages"] = new_messages # Clear CUDA cache after processing try: import torch if torch.cuda.is_available(): torch.cuda.empty_cache() print(f"๐Ÿงน Cleared CUDA cache post-processing. Memory: {torch.cuda.memory_allocated()/1024**2:.1f}MB") except ImportError: pass except Exception as e: print(f"Error clearing CUDA cache post-processing: {e}") return new_state # 3. Enhanced execute_parsed_tool_calls to prevent duplicate observations def execute_parsed_tool_calls(state: AgentState): """ Execute tool calls that were parsed from the Thought/Action/Action Input format. This is called by call_tool when parsed_tool_calls are present in state. """ # Tool name mappings tool_name_mappings = { 'wikipedia_semantic_search': 'wikipedia_tool', 'wikipedia': 'wikipedia_tool', 'search': 'enhanced_search', 'duckduckgo_search': 'enhanced_search', 'web_search': 'enhanced_search', 'enhanced_search': 'enhanced_search', 'youtube_screenshot_qa_tool': 'youtube_tool', 'youtube': 'youtube_tool', 'youtube_transcript_extractor': 'youtube_transcript_extractor', 'youtube_audio_tool': 'youtube_transcript_extractor' } # Create a lookup by tool names tools_by_name = {} for tool in tools: tools_by_name[tool.name.lower()] = tool # Copy messages to avoid mutation during iteration new_messages = list(state["messages"]) # Process each tool call ONCE for tool_call in state['parsed_tool_calls']: action = tool_call['action'] action_input = tool_call['action_input'] normalized_action = tool_call['normalized_action'] print(f"๐Ÿš€ Executing tool: {action} with input: {action_input}") # Find the tool instance tool_instance = None if normalized_action in tools_by_name: tool_instance = tools_by_name[normalized_action] elif normalized_action in tool_name_mappings: mapped_name = tool_name_mappings[normalized_action] if mapped_name in tools_by_name: tool_instance = tools_by_name[mapped_name] if tool_instance: try: # Pass file_path if the tool expects it if hasattr(tool_instance, 'run'): if 'file_path' in tool_instance.run.__code__.co_varnames: result = tool_instance.run(action_input, file_path=state.get('file_path')) else: result = tool_instance.run(action_input) else: result = str(tool_instance) # Truncate long results if len(result) > 6000: result = result[:6000] + "... [Result truncated due to length]" # Create a SINGLE observation message from langchain_core.messages import ToolMessage tool_message = ToolMessage( content=f"Observation: {result}", name=action, tool_call_id=str(uuid.uuid4()) ) new_messages.append(tool_message) print(f"โœ… Tool '{action}' executed successfully") except Exception as e: print(f"โŒ Error executing tool '{action}': {e}") from langchain_core.messages import ToolMessage error_message = ToolMessage( content=f"Observation: Error executing '{action}': {str(e)}", name=action, tool_call_id=str(uuid.uuid4()) ) new_messages.append(error_message) else: print(f"โŒ Tool '{action}' not found in available tools") available_tool_names = list(tools_by_name.keys()) from langchain_core.messages import ToolMessage error_message = ToolMessage( content=f"Observation: Tool '{action}' not found. Available tools: {', '.join(available_tool_names)}", name=action, tool_call_id=str(uuid.uuid4()) ) new_messages.append(error_message) # Update state with new messages and clear parsed tool calls new_state = state.copy() new_state["messages"] = new_messages new_state['parsed_tool_calls'] = [] # Clear to prevent re-execution return new_state # 1. Add loop detection to your AgentState def should_continue(state: AgentState) -> str: """Enhanced continuation logic with better limits.""" print("Running should_continue...") # Check done flag first if state.get("done", False): print("โœ… Done flag is True - ending") return "end" messages = state["messages"] # More aggressive message limit #if len(messages) > 20: # Reduced from 15 # print(f"โš ๏ธ Message limit reached ({len(messages)}/20) - forcing end") # return "end" # Check for repeated patterns (stuck in loop) if len(messages) >= 6: recent_contents = [str(msg.content)[:100] for msg in messages[-6:] if hasattr(msg, 'content')] if len(set(recent_contents)) < 3: # Too much repetition print("๐Ÿ”„ Detected repetitive pattern - ending") return "end" print(f"๐Ÿ“Š Continuing... ({len(messages)} messages so far)") return "continue" def route_after_parse_react(state: AgentState) -> str: """Determines the next step after parsing LLM output, prioritizing end state.""" if state.get("done", False): # Check if parse_react_output decided we are done return "end_processing" # Original logic: check for tool calls in the last message # Ensure messages list and last message exist before checking tool_calls messages = state.get("messages", []) if messages: last_message = messages[-1] if hasattr(last_message, 'tool_calls') and last_message.tool_calls: return "call_tool" return "call_llm" # --- Graph Construction --- # --- Graph Construction --- def create_memory_safe_workflow(): """Create a workflow with memory management and loop prevention.""" # These models are initialized here but might be better managed if they need to be released/reinitialized # like you attempt in run_agent. Consider passing them or managing their lifecycle carefully. hf_pipe = create_llm_pipeline() llm = HuggingFacePipeline(pipeline=hf_pipe) # vqa_model_name = "Salesforce/blip-vqa-base" # Not used in the provided graph logic directly # processor_vqa = BlipProcessor.from_pretrained(vqa_model_name) # Not used # model_vqa = BlipForQuestionAnswering.from_pretrained(vqa_model_name).to('cpu') # Not used workflow = StateGraph(AgentState) # Bind the llm_model to the call_llm_with_memory_management function bound_call_llm = partial(call_llm_with_memory_management, llm_model=llm) # Add nodes with memory-safe versions workflow.add_node("call_llm", bound_call_llm) # Use the bound version here workflow.add_node("parse_react_output", parse_react_output) workflow.add_node("call_tool", call_tool_with_memory_management) # Ensure this doesn't also need llm if it calls back directly # Set entry point workflow.set_entry_point("call_llm") # Add conditional edges workflow.add_conditional_edges( "call_llm", should_continue, { "continue": "parse_react_output", "end": END } ) workflow.add_conditional_edges( "parse_react_output", route_after_parse_react, { "call_tool": "call_tool", "call_llm": "call_llm", "end_processing": END } ) workflow.add_edge("call_tool", "call_llm") return workflow.compile() def count_english_words(text): # Remove punctuation, lowercase, split into words table = str.maketrans('', '', string.punctuation) words_in_text = text.translate(table).lower().split() return sum(1 for word in words_in_text if word in english_words) def fix_backwards_text(text): reversed_text = text[::-1] original_count = count_english_words(text) reversed_count = count_english_words(reversed_text) if reversed_count > original_count: return reversed_text else: return text # --- Run the Agent --- # Enhanced system prompt for better behavior def run_agent(agent, state: AgentState): """Enhanced agent initialization with better prompt and hallucination prevention.""" global WIKIPEDIA_TOOL, SEARCH_TOOL, YOUTUBE_TOOL, YOUTUBE_AUDIO_TOOL, AUDIO_TRANSCRIPTION_TOOL, EXCEL_TOOL, PYTHON_TOOL, COMMUTATIVITY_TOOL, tools # Initialize tools WIKIPEDIA_TOOL = WikipediaSearchToolWithFAISS() SEARCH_TOOL = EnhancedDuckDuckGoSearchTool(max_results=3, max_chars_per_page=18000) YOUTUBE_TOOL = EnhancedYoutubeScreenshotQA() YOUTUBE_AUDIO_TOOL = YouTubeTranscriptExtractor() AUDIO_TRANSCRIPTION_TOOL = AudioTranscriptionTool() EXCEL_TOOL = ExcelReaderTool() PYTHON_TOOL = PythonExecutorTool() COMMUTATIVITY_TOOL = CommutativityAnalysisTool() tools = [WIKIPEDIA_TOOL, SEARCH_TOOL, YOUTUBE_AUDIO_TOOL, YOUTUBE_TOOL, AUDIO_TRANSCRIPTION_TOOL, EXCEL_TOOL, PYTHON_TOOL, COMMUTATIVITY_TOOL] formatted_tools_description = render_text_description(tools) current_date_str = datetime.now().strftime("%Y-%m-%d") # Enhanced system prompt with stricter boundaries system_content = f"""You are an AI assistant with access to these tools: {formatted_tools_description} CRITICAL INSTRUCTIONS: 1. Answer ONLY the question asked by the human 2. Do NOT generate additional questions or continue conversations 3. Use tools ONLY when you need specific information you don't know 4. After using a tool, provide your FINAL ANSWER immediately 5. STOP after giving your FINAL ANSWER - do not continue 6. Do not repeat words in the question in the answer FORMAT for tool use: Thought: Action: Action Input: When you have the answer, immediately provide: FINAL ANSWER: [concise answer only] ANSWER FORMAT: - Numbers: no commas, no units unless specified - Questions on "how many" should be answered with a number ONLY - Strings: no articles, no abbreviations, digits in plain text - Lists: comma-separated either in ascending numeric order or alphabetical order as requested - Be extremely brief and concise - Do not provide additional context or explanations - Do not provide parentheticals IMPORTANT: You are responding to ONE question only. Do not ask follow-up questions or generate additional dialogue. Current date: {current_date_str} """ query = fix_backwards_text(state['question']) # Check for YouTube URLs yt_pattern = r"(https?://)?(www\.)?(youtube\.com|youtu\.be)/[^\s]+" if re.search(yt_pattern, query): url_match = re.search(r"(https?://[^\s]+)", query) if url_match: state['youtube_url'] = url_match.group(0) # Initialize messages system_message = SystemMessage(content=system_content) human_message = HumanMessage(content=query) state['messages'] = [system_message, human_message] state["done"] = False # Run the agent result = agent.invoke(state) # Cleanup if result.get("done"): #torch.cuda.empty_cache() #torch.cuda.ipc_collect() gc.collect() print("๐Ÿงน Released GPU memory after completion") return result["messages"]