Spaces:
Sleeping
Sleeping
| import operator | |
| import os | |
| import time | |
| from typing import Optional | |
| from langchain.chat_models import init_chat_model | |
| from langchain_community.document_loaders import WikipediaLoader, ArxivLoader, YoutubeLoader | |
| from langchain_community.tools import TavilySearchResults | |
| from langchain_core.messages import HumanMessage | |
| from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder | |
| from langgraph.graph import add_messages, START, END, StateGraph | |
| from langchain_core.tools import tool | |
| from langgraph.prebuilt import ToolNode | |
| from pydantic import SecretStr | |
| from langchain_custom import WikipediaTableLoader | |
| from typing_extensions import TypedDict, Annotated | |
class State(TypedDict):
    """Shared LangGraph state flowing between the Plan/Agent/tools/Answer nodes."""
    # Conversation history; the add_messages reducer appends new messages
    # to the existing list instead of overwriting it.
    messages: Annotated[list, add_messages]
    # Optional type of extra task content — presumably a MIME-ish hint; TODO confirm against caller.
    content_type: Optional[str]
    # Optional extra task content itself (printed by make_plan when present) — TODO confirm against caller.
    content: Optional[str]
    # Log of visited nodes ("Plan"/"Agent"/"Answer"); operator.add concatenates
    # each node's contribution, so its length counts rounds (see should_continue).
    aggregate: Annotated[list, operator.add]
    # graph_state: str
def get_llm():
    """Create the chat model used by the agent.

    Returns:
        A LangChain chat model (Gemini 2.0 Flash via the Google GenAI
        provider). Credentials are read from the environment by the
        provider integration, not passed here.

    Note:
        Earlier revisions used Groq (llama-3.3-70b-versatile) or Azure
        OpenAI; switch ``model``/``model_provider`` here to change
        providers. The previous ``os.getenv("GROQ_API_KEY")`` call was
        dead code (its result was discarded) and has been removed.
    """
    return init_chat_model("gemini-2.0-flash", model_provider="google_genai")
def get_graph(llm):
    """Build and compile the LangGraph agent graph.

    The graph runs Plan -> Agent, then loops Agent <-> tools while the
    model keeps requesting tool calls (bounded by a round budget), and
    finally produces a plain answer in the Answer node.

    Args:
        llm: A LangChain chat model. Used directly for the Plan and
            Answer nodes, and with bound tools for the Agent node.

    Returns:
        The compiled LangGraph application.
    """
    # The system prompt lives in an external markdown file so it can be
    # edited without touching code.
    with open('prompts/system_prompt.md', 'r', encoding='utf-8') as markdown_file:
        system_prompt = markdown_file.read()

    prompt_template = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            MessagesPlaceholder(variable_name="messages"),
        ]
    )

    # Imported locally to keep the optional retriever dependencies
    # scoped to graph construction.
    from langchain_community.retrievers import WikipediaRetriever
    from langchain_community.retrievers import TavilySearchAPIRetriever

    # Wikipedia retriever
    wiki_retriever = WikipediaRetriever()
    # Tavily retriever (top 3 results)
    tavily_retriever = TavilySearchAPIRetriever(k=3)

    # NOTE: tool docstrings below are the descriptions the LLM sees, so
    # they are kept verbatim.

    def multiply(a: int, b: int) -> int:
        """Multiply two numbers.
        Args:
            a: first int
            b: second int
        """
        print("\n-------------------- Tool (Multiplication) has been called --------------------\n")
        return a * b

    def add(a: int, b: int) -> int:
        """Add two numbers.
        Args:
            a: first int
            b: second int
        """
        print("\n-------------------- Tool (Addition) has been called --------------------\n")
        return a + b

    def subtract(a: int, b: int) -> int:
        """Subtract two numbers.
        Args:
            a: first int
            b: second int
        """
        print("\n-------------------- Tool (Subtraction) has been called --------------------\n")
        return a - b

    def divide(a: int, b: int) -> float:
        """Divide two numbers.
        Args:
            a: first int
            b: second int
        """
        print("\n-------------------- Tool (Division) has been called --------------------\n")
        if b == 0:
            raise ValueError("Cannot divide by zero.")
        return a / b

    def modulus(a: int, b: int) -> int:
        """Get the modulus of two numbers.
        Args:
            a: first int
            b: second int
        """
        print("\n-------------------- Tool (Modulus) has been called --------------------\n")
        return a % b

    # NOTE(review): retrieve/online_search/wiki_table_search are defined
    # but not registered in `tools` below — kept for manual use.
    def retrieve(query: str):
        """
        This function retrieves Wikipedia entries based on the query.
        """
        print("\n-------------------- Tool (Wikipedia) has been called --------------------\n")
        print("The query is: ", query)
        docs = wiki_retriever.invoke(query)
        serialized = "\n\n".join(
            f"\nContent:\n{doc.page_content}"
            for doc in docs
        )
        return serialized

    def wiki_search(query: str) -> str:
        """Search Wikipedia for a query and return maximum 2 results.
        Args:
            query: The search query."""
        print("\n-------------------- Tool (Wikipedia) has been called --------------------\n")
        search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
        parts: list[str] = []
        for doc in search_docs:
            parts.append(
                f'<Document source="{doc.metadata["source"]}" '
                f'title="{doc.metadata["title"]}" '
                f'page="{doc.metadata.get("page", "")}">\n'
                f'{doc.page_content}\n</Document>'
            )
            # Best-effort table extraction for the same article; failures
            # must not break the plain-text search result.
            try:
                print("---------------------------------")
                print("Loading tables from: ", doc.metadata["source"])
                print("---------------------------------")
                tables = WikipediaTableLoader(url=doc.metadata["source"], title=doc.metadata["title"]).load()
                for i, table in enumerate(tables):
                    parts.append(
                        f'<Document source="{table.metadata["source"]}" '
                        f'title="{table.metadata["title"]}" '
                        f'table_index="{i}">\n'
                        f'{table.page_content}\n</Document>'
                    )
            except Exception:
                pass
        formatted_search_docs = "\n\n---\n\n".join(parts)
        return formatted_search_docs

    def wiki_table_search(url: str, title: str) -> str:
        """Get Wikipedia tables for a given URL and title.
        Args:
            url: The Wikipedia URL.
            title: The title of the Wikipedia page."""
        print("\n-------------------- Tool (Wikipedia-Table) has been called --------------------\n")
        search_docs = WikipediaTableLoader(url=url, title=title).load()
        formatted_search_docs = "\n\n---\n\n".join(
            [
                f'<Document source="{doc.metadata["source"]}" title="{doc.metadata["title"]}" table_index={doc.metadata["table_index"]}/>\n{doc.page_content}\n</Document>'
                for doc in search_docs
            ])
        return formatted_search_docs

    def online_search(query: str):
        """
        This function does a web search based on the query.
        """
        print("\n-------------------- Tool (Tavily) has been called --------------------\n")
        print("The query is: ", query)
        # BUG FIX: TavilySearchResults.invoke returns plain dicts, which
        # have no .page_content; use the retriever, which returns Documents.
        docs = tavily_retriever.invoke(query)
        serialized = "\n\n".join(
            f"\nContent:\n{doc.page_content}"
            for doc in docs
        )
        return serialized

    def web_search(query: str) -> str:
        """Search Tavily for a query and return maximum 3 results.
        Args:
            query: The search query."""
        print("\n-------------------- Tool (Tavily) has been called --------------------\n")
        # TavilySearchResults yields dicts with url/title/content keys.
        search_docs = TavilySearchResults(max_results=3).invoke({'query': query})
        formatted_search_docs = "\n\n---\n\n".join(
            [
                f'URL: {doc["url"]}\nTitle= {doc["title"]}\nContent: {doc["content"]}'
                for doc in search_docs
            ])
        return formatted_search_docs

    def arvix_search(query: str) -> str:
        """Search Arxiv for a query and return maximum 3 result.
        Args:
            query: The search query."""
        # Banner made consistent with the other tools (was a bare print()).
        print("\n-------------------- Tool (ArXiv) has been called --------------------\n")
        search_docs = ArxivLoader(query=query, load_max_docs=3).load()
        # Abstract pages only: truncate each document to its first 1000 chars.
        formatted_search_docs = "\n\n---\n\n".join(
            [
                f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
                for doc in search_docs
            ])
        return formatted_search_docs

    def youtube_transcript(url: str) -> str:
        """Download a transcript of a YouTube video.
        Args:
            url: URL of the YouTube video."""
        print("\n-------------------- Tool (YouTube Transcript) has been called --------------------\n")
        loader = YoutubeLoader.from_youtube_url(
            url, add_video_info=False
        )
        docs = loader.load()
        transcript = "\n\n".join(
            f"\nContent:\n{doc.page_content}"
            for doc in docs
        )
        return transcript

    tools = [wiki_search, web_search, arvix_search, youtube_transcript, multiply, add, subtract, divide, modulus]
    tool_node = ToolNode(tools)
    llm_with_tools = llm.bind_tools(tools)

    def make_plan(state: State):
        """Ask the plain LLM to draft a plan for solving the question."""
        print("\n-------------------- Starting to create a plan --------------------\n")
        print("Waiting for 5 seconds...")
        time.sleep(5)  # crude throttle against provider rate limits
        if "content_type" in state:
            print("Content is: ", state["content"])
        # Build a new list instead of mutating state["messages"] in place;
        # the add_messages reducer applies the returned delta.
        messages = state["messages"] + [
            HumanMessage(content="Write a plan how to solve this question?")
        ]
        prompt = prompt_template.invoke(messages)
        response = llm.invoke(prompt)
        print("The plan is: ", response.content)
        return {"messages": [response], "aggregate": ["Plan"]}

    def call_model(state: State):
        """Agent step: let the tool-bound LLM answer or request tool calls."""
        print("\n-------------------- Agent has been called -----------------------------------\n")
        print("Waiting for 5 seconds...")
        time.sleep(5)  # crude throttle against provider rate limits
        messages = state["messages"] + [
            HumanMessage(content="Please provide me the answer to the question in detail.")
        ]
        prompt_answer = prompt_template.invoke(messages)
        response = llm_with_tools.invoke(prompt_answer)
        print("Agent has made a decision:\n", response.content, response.tool_calls)
        return {"messages": [response], "aggregate": ["Agent"]}

    def get_answer(state: State):
        """Final step: condense the conversation into a plain answer."""
        print("\n-------------------- Generating Answer -----------------------------------\n")
        print("Waiting for 5 seconds...")
        time.sleep(5)  # crude throttle against provider rate limits
        messages = state["messages"] + [
            HumanMessage(content="Please provide me just the plain answer to the question")
        ]
        prompt_answer = prompt_template.invoke(messages)
        response = llm.invoke(prompt_answer)
        print("The final answer is: ", response.content)
        return {"messages": [response], "aggregate": ["Answer"]}

    def should_continue(state: State):
        """Route after Agent: "tools" while tool calls are pending and the
        round budget (8 node visits, tracked via `aggregate`) remains,
        otherwise "Answer"."""
        print("\n-------------------- Decision of forwarding has been made --------------------\n")
        print("Waiting for 2 seconds...")
        time.sleep(2)
        messages = state["messages"]
        print("This is round: ", len(state["aggregate"]))
        print("The last message is: ", messages[-1])
        # The cap prevents endless tool-calling loops.
        if len(state["aggregate"]) < 8 and messages[-1].tool_calls:
            return "tools"
        return "Answer"

    # Build graph
    builder = StateGraph(State)
    builder.add_node("tools", tool_node)
    builder.add_node("Plan", make_plan)
    builder.add_node("Agent", call_model)
    builder.add_node("Answer", get_answer)
    # Logic: Plan once, then loop Agent <-> tools until done, then Answer.
    builder.add_edge(START, "Plan")
    builder.add_edge("Plan", "Agent")
    builder.add_conditional_edges("Agent", should_continue, ["tools", "Answer"])
    builder.add_edge("tools", "Agent")
    builder.add_edge("Answer", END)
    return builder.compile()