| import os |
| from typing import List, Dict, Any, Optional |
| from dotenv import load_dotenv |
|
|
| |
# Load environment variables from a local .env file (if present) before the
# API-key constants below read them with os.getenv.
load_dotenv()
|
|
| from langchain.agents import AgentType, initialize_agent, Tool |
| from langchain.memory import ConversationBufferWindowMemory, ConversationSummaryBufferMemory |
| from langchain_core.messages import BaseMessage, HumanMessage, AIMessage |
| from langchain_google_genai import ChatGoogleGenerativeAI |
| from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint |
| from langchain_groq import ChatGroq |
| from langchain_community.tools.tavily_search import TavilySearchResults |
| from langchain_community.document_loaders import WikipediaLoader, ArxivLoader |
| from langchain_core.tools import tool |
| from langchain.prompts import PromptTemplate |
| from langchain.chains import LLMChain |
|
|
| |
# API keys for the supported LLM and search providers, read from the
# environment. Any of these may be None; each is validated at the point of
# use (the tool functions and LangChainAgent._get_llm raise/return errors
# when the needed key is missing).
GOOGLE_API_KEY = os.getenv('GOOGLE_API_KEY')
HUGGINGFACE_API_TOKEN = os.getenv('HUGGINGFACE_API_TOKEN')
GROQ_API_KEY = os.getenv('GROQ_API_KEY')
TAVILY_API_KEY = os.getenv('TAVILY_API_KEY')
|
|
|
|
@tool
def calculator_tool(expression: str) -> str:
    """Perform mathematical calculations and evaluate expressions

    Args:
        expression: A mathematical expression to evaluate (e.g., "2+2", "25*34", "sqrt(16)", "sin(0.5)")

    Returns:
        Result of the mathematical expression, or an "Error: ..." string.
    """
    try:
        import math
        import re

        expression = expression.strip()

        # '^' is commonly typed for exponentiation; Python uses '**'.
        expression = expression.replace('^', '**')

        # Map the allowed function/constant names straight into the eval
        # namespace instead of textually rewriting the expression. The old
        # replace-chain was order-sensitive and corrupted input (e.g. 'log'
        # -> 'math.log' turned 'log10' into 'math.log10', after which the
        # 'log10' replacement produced 'math.math.log10'; replacing bare
        # 'e' also mangled scientific notation like '2e3').
        allowed_names = {
            "sqrt": math.sqrt,
            "sin": math.sin,
            "cos": math.cos,
            "tan": math.tan,
            "log": math.log,
            "ln": math.log,
            "log10": math.log10,
            "pi": math.pi,
            "e": math.e,
            "pow": pow,
            "abs": abs,
            "round": round,
            "min": min,
            "max": max,
        }

        # Reject any character outside numbers, operators, parentheses,
        # commas, and identifier characters.
        if not re.fullmatch(r"[0-9A-Za-z_+\-*/%.(), ]+", expression):
            return "Error: Invalid characters in expression. Use only numbers and basic math operations."

        # Reject any identifier that is not in the whitelist, so eval can
        # never reach an unexpected name.
        for name in re.findall(r"[A-Za-z_][A-Za-z_0-9]*", expression):
            if name not in allowed_names:
                return "Error: Invalid characters in expression. Use only numbers and basic math operations."

        # Builtins are disabled; only the whitelisted names are visible.
        result = eval(expression, {"__builtins__": {}}, allowed_names)
        return str(result)

    except ZeroDivisionError:
        return "Error: Cannot divide by zero"
    except Exception as e:
        return f"Error: {str(e)}"
|
|
@tool
def wikipedia_search_tool(query: str) -> str:
    """Search Wikipedia for information on any topic

    Args:
        query: The search query for Wikipedia

    Returns:
        Formatted Wikipedia search results, a no-results message, or an
        "Error: ..." string.
    """
    try:
        search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
        # An empty observation confuses the agent; say explicitly that
        # nothing was found instead of returning "".
        if not search_docs:
            return f"No Wikipedia results found for: {query}"
        # Use .get for 'source' too, so one document with unusual metadata
        # cannot turn an otherwise-good result set into an error message.
        formatted_results = "\n\n---\n\n".join([
            f'Source: {doc.metadata.get("source", "")}\nPage: {doc.metadata.get("page", "")}\n\nContent:\n{doc.page_content[:2000]}...'
            for doc in search_docs
        ])
        return formatted_results
    except Exception as e:
        return f"Error searching Wikipedia: {str(e)}"
|
|
@tool
def web_search_tool(query: str) -> str:
    """Search the web for current information using Tavily

    Args:
        query: The search query for web search

    Returns:
        Formatted web search results, a no-results message, or an
        "Error: ..." string.
    """
    try:
        if not TAVILY_API_KEY:
            return "Error: TAVILY_API_KEY not found in environment variables"

        search_results = TavilySearchResults(max_results=3, api_key=TAVILY_API_KEY).invoke(query)
        # An empty observation confuses the agent; say explicitly that
        # nothing was found instead of returning "".
        if not search_results:
            return f"No web results found for: {query}"
        formatted_results = "\n\n---\n\n".join([
            f'Source: {result.get("url", "")}\n\nContent:\n{result.get("content", "")}'
            for result in search_results
        ])
        return formatted_results
    except Exception as e:
        return f"Error searching web: {str(e)}"
|
|
@tool
def arxiv_search_tool(query: str) -> str:
    """Search ArXiv for academic papers and research

    Args:
        query: The search query for ArXiv

    Returns:
        Formatted ArXiv search results, a no-results message, or an
        "Error: ..." string.
    """
    try:
        search_docs = ArxivLoader(query=query, load_max_docs=3).load()
        # An empty observation confuses the agent; say explicitly that
        # nothing was found instead of returning "".
        if not search_docs:
            return f"No ArXiv results found for: {query}"
        # Use .get for 'source' too, so one document with unusual metadata
        # cannot turn an otherwise-good result set into an error message.
        formatted_results = "\n\n---\n\n".join([
            f'Source: {doc.metadata.get("source", "")}\nTitle: {doc.metadata.get("Title", "")}\n\nContent:\n{doc.page_content[:1500]}...'
            for doc in search_docs
        ])
        return formatted_results
    except Exception as e:
        return f"Error searching ArXiv: {str(e)}"
|
|
class LangChainAgent:
    """Multi-purpose LangChain agent with various capabilities.

    Wraps a chat LLM (Groq, Google Gemini, or a HuggingFace endpoint) in a
    ReAct-style agent equipped with calculator, Wikipedia, web, and ArXiv
    search tools, plus a summarizing conversation memory.
    """

    def __init__(self, provider: str = "groq"):
        """Initialize the LangChain agent with specified LLM provider.

        Args:
            provider: LLM backend to use: "groq", "google", or "huggingface".

        Raises:
            ValueError: If the provider is unknown or its API key is missing.
        """
        self.provider = provider
        self.llm = self._get_llm(provider)
        self.tools = self._initialize_tools()
        # Keeps recent turns verbatim and summarizes older ones once the
        # 2000-token budget is exceeded; the same LLM writes the summaries.
        self.memory = ConversationSummaryBufferMemory(
            llm=self.llm,
            memory_key="chat_history",
            return_messages=True,
            max_token_limit=2000,
            moving_summary_buffer="The human and AI are having a conversation about various topics."
        )
        self.agent = self._create_agent()

    def _get_llm(self, provider: str):
        """Return a configured chat model for the given provider.

        Raises:
            ValueError: If the provider name is invalid or the matching
                API-key environment variable is not set.
        """
        if provider == "groq":
            if not GROQ_API_KEY:
                raise ValueError("GROQ_API_KEY not found in environment variables")
            return ChatGroq(
                model="llama-3.3-70b-versatile",
                temperature=0.1,
                max_tokens=8192,
                api_key=GROQ_API_KEY,
                streaming=False
            )
        elif provider == "google":
            if not GOOGLE_API_KEY:
                raise ValueError("GOOGLE_API_KEY not found in environment variables")
            return ChatGoogleGenerativeAI(
                model="gemini-1.5-flash",
                temperature=0,
                max_tokens=2048,
                google_api_key=GOOGLE_API_KEY
            )
        elif provider == "huggingface":
            if not HUGGINGFACE_API_TOKEN:
                raise ValueError("HUGGINGFACE_API_TOKEN not found in environment variables")
            return ChatHuggingFace(
                llm=HuggingFaceEndpoint(
                    repo_id="microsoft/DialoGPT-medium",
                    temperature=0,
                    max_length=2048,
                    huggingfacehub_api_token=HUGGINGFACE_API_TOKEN
                ),
            )
        else:
            raise ValueError("Invalid provider. Choose 'groq', 'google' or 'huggingface'.")

    def _initialize_tools(self) -> List[Tool]:
        """Return the list of tools the agent can call."""
        return [
            calculator_tool,
            wikipedia_search_tool,
            web_search_tool,
            arxiv_search_tool,
        ]

    def _create_agent(self):
        """Create the LangChain agent executor, or None on failure.

        Returning None (instead of raising) lets __call__ fall back to a
        direct LLM response when agent construction fails.
        """
        try:
            # NOTE(review): ZERO_SHOT_REACT_DESCRIPTION builds a prompt
            # without a chat_history slot, so the memory passed here may
            # never reach the prompt — confirm, or consider
            # CONVERSATIONAL_REACT_DESCRIPTION if history should be used.
            return initialize_agent(
                tools=self.tools,
                llm=self.llm,
                agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
                memory=self.memory,
                verbose=True,
                handle_parsing_errors=True,
                max_iterations=4,
                early_stopping_method="generate",
                return_intermediate_steps=False
            )
        except Exception as e:
            print(f"Error creating agent: {e}")

            return None

    def _determine_approach(self, question: str) -> str:
        """Classify the question to pick a prompt template.

        Returns one of 'calculation', 'academic', 'research', or 'general'.
        The specific academic keywords are checked BEFORE the generic
        research keywords: previously research was checked first and, since
        'research' appeared in both lists, the academic branch's own
        trigger was unreachable.
        """
        question_lower = question.lower()

        math_keywords = ['calculate', 'compute', 'add', 'subtract', 'multiply', 'divide', 'math', 'equation', '+', '-', '*', '/', '%']
        if any(keyword in question_lower for keyword in math_keywords):
            return 'calculation'

        academic_keywords = ['paper', 'study', 'research', 'academic', 'scientific', 'arxiv', 'journal']
        if any(keyword in question_lower for keyword in academic_keywords):
            return 'academic'

        research_keywords = ['search', 'find', 'information', 'what is', 'who is', 'when', 'where', 'how', 'why']
        if any(keyword in question_lower for keyword in research_keywords):
            return 'research'

        return 'general'

    def __call__(self, question: str) -> str:
        """Process a question and return an answer.

        Routes the question through the agent with an approach-specific
        prompt; falls back to a direct LLM call if the agent is unavailable
        or raises.
        """
        try:
            print(f"Processing question: {question[:100]}...")

            # No agent executor (construction failed) -> plain LLM answer.
            if self.agent is None:
                print("Agent not available, using direct LLM response")
                try:
                    response = self.llm.invoke([HumanMessage(content=question)])
                    return response.content
                except Exception as llm_error:
                    return f"Error: Unable to process question. {str(llm_error)}"

            approach = self._determine_approach(question)
            print(f"Selected approach: {approach}")

            # Wrap the question in a template nudging the agent toward the
            # right tools for the detected approach.
            if approach == 'calculation':
                enhanced_question = f"""
            You are a mathematical assistant. Solve this problem step by step:
            
            {question}
            
            IMPORTANT: Use the calculator_tool for ALL mathematical calculations, even simple ones.
            Examples:
            - For "25 * 34", use: calculator_tool("25 * 34")
            - For "sqrt(16)", use: calculator_tool("sqrt(16)")
            - For "2 + 2", use: calculator_tool("2 + 2")
            
            Always show your work and use the tools provided.
            """
            elif approach == 'research':
                enhanced_question = f"""
            You are a research assistant. Provide comprehensive information about:
            
            {question}
            
            IMPORTANT: Use the appropriate search tools to gather information:
            - wikipedia_search_tool("your search query") for general knowledge
            - web_search_tool("your search query") for current information
            - arxiv_search_tool("your search query") for academic papers
            
            Always cite your sources and provide detailed explanations.
            """
            elif approach == 'academic':
                enhanced_question = f"""
            You are an academic research assistant. Find scholarly information about:
            
            {question}
            
            IMPORTANT: Use research tools to find information:
            - arxiv_search_tool("your search query") for academic papers
            - wikipedia_search_tool("your search query") for background information
            
            Provide citations and summarize key findings.
            """
            else:
                enhanced_question = f"""
            You are a helpful assistant. Answer this question comprehensively:
            
            {question}
            
            IMPORTANT: Use the appropriate tools as needed:
            - calculator_tool("expression") for mathematical calculations
            - wikipedia_search_tool("query") for general information
            - web_search_tool("query") for current information
            - arxiv_search_tool("query") for academic research
            
            Always use tools when they can help provide better answers.
            """

            result = self.agent.run(enhanced_question)

            print(f"Generated answer: {str(result)[:200]}...")
            return str(result)

        except Exception as e:
            error_msg = f"Error processing question: {str(e)}"
            print(error_msg)

            try:
                # Last resort: answer without the agent/tools.
                fallback_result = self.llm.invoke([HumanMessage(content=question)])
                return fallback_result.content
            except Exception:
                # Report the original agent error, not the fallback one.
                return f"Error: Unable to process question. {str(e)}"

    def reset_memory(self):
        """Reset the conversation memory."""
        self.memory.clear()
|
|
| |
def test_langchain_agent():
    """Smoke-test the LangChain agent against a few sample questions."""
    print("Testing LangChain Agent with Groq Llama...")
    agent = LangChainAgent(provider="groq")

    sample_questions = (
        "What is 25 * 34 + 100?",
        "Who was Albert Einstein and what were his major contributions?",
        "Search for recent developments in artificial intelligence",
        "What is quantum computing?",
    )

    for q in sample_questions:
        print(f"\nQuestion: {q}")
        print("-" * 50)
        print(f"Answer: {agent(q)}")
        print("=" * 80)
        # Fresh memory per question so the runs stay independent.
        agent.reset_memory()
|
|
# Script entry point: run the smoke tests when executed directly.
if __name__ == "__main__":
    test_langchain_agent()
|
|