|
from langchain_community.chat_models import ChatOllama |
|
from langgraph.graph import MessagesState, StateGraph, START, END |
|
from langchain_core.messages import SystemMessage, HumanMessage |
|
from langchain_community.tools import DuckDuckGoSearchRun |
|
from langchain_core.tools import tool |
|
from langgraph.prebuilt import ToolNode |
|
from langchain_community.document_loaders import WikipediaLoader |
|
from langgraph.prebuilt import tools_condition |
|
from langchain_huggingface import HuggingFaceEndpoint |
|
from langchain_huggingface import ChatHuggingFace |
|
from langchain.llms import HuggingFaceHub |
|
import os |
|
from huggingface_hub import login |
|
from dotenv import load_dotenv |
|
load_dotenv()

# HF_TOKEN must be present: assigning None into os.environ raises an opaque
# TypeError, so fail fast with an actionable message instead.
_hf_token = os.getenv("HF_TOKEN")
if _hf_token is None:
    raise RuntimeError("HF_TOKEN is not set; add it to your environment or .env file.")
os.environ["HUGGINGFACEHUB_API_TOKEN"] = _hf_token
|
|
|
@tool
def use_search_tool(query: str) -> str:
    """Search the web with DuckDuckGo.

    Args:
        query (str): The search query.

    Returns:
        str: The raw search result text.
    """
    # Return the plain string: ToolNode wraps a tool's return value in a
    # ToolMessage itself, so wrapping it in a {"messages": ...} state-update
    # dict here would leak that dict into the message content.
    return DuckDuckGoSearchRun(verbose=0).run(query)
|
|
|
@tool
def use_wikipedia_tool(query: str) -> str:
    """Fetch a summary from Wikipedia.

    Args:
        query (str): The topic to search on Wikipedia.

    Returns:
        str: The page content of up to two matching Wikipedia articles, or an
            apology message when nothing was found.
    """
    docs = WikipediaLoader(query=query, load_max_docs=2).load()
    if not docs:
        return f"Sorry, I couldn't find any information on '{query}' in Wikipedia."
    # The declared return type is str and ToolNode turns the value into a
    # ToolMessage, so join the Document objects' text rather than returning
    # the raw list wrapped in a state-update dict.
    return "\n\n".join(doc.page_content for doc in docs)
|
|
|
def build_agent():
    """Build and compile the tool-using agent graph.

    Returns:
        A compiled LangGraph graph: an assistant node calls a Hugging Face
        chat model, and `tools_condition` routes to a ToolNode whenever the
        model requests a tool call, otherwise to END.
    """
    # HuggingFaceHub is deprecated; use the already-imported
    # HuggingFaceEndpoint wrapped in ChatHuggingFace so the model speaks the
    # chat interface that tool calling requires.
    endpoint = HuggingFaceEndpoint(
        repo_id="openai-community/gpt2-medium",
        task="text-generation",
        temperature=0.7,
        max_new_tokens=100,
    )

    tools = [use_wikipedia_tool, use_search_tool]

    # Bind the tools to the model: without bind_tools the LLM can never emit
    # a tool call, so the "tools" branch below would be unreachable.
    llm = ChatHuggingFace(llm=endpoint, verbose=True).bind_tools(tools)

    system_template = (
        "You are a helpful assistant tasked with answering questions using a set of tools. "
        """Now, I will ask you a question. Report your thoughts, and finish your answer with the following template: 
FINAL ANSWER: [YOUR FINAL ANSWER]. 
YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string. 
Your answer should only start with "FINAL ANSWER: ", then follows with the answer. """
    )

    def call_model(state: MessagesState):
        """Call the LLM with the system prompt prepended to the history."""
        messages = [SystemMessage(content=system_template)] + state["messages"]
        response = llm.invoke(messages)
        return {"messages": response}

    workflow = StateGraph(MessagesState)
    workflow.add_node("Assistant", call_model)
    workflow.add_node("tools", ToolNode(tools))
    workflow.add_edge(START, "Assistant")
    # tools_condition already routes to END when the model makes no tool
    # call; the original extra add_edge(..., END) created a second,
    # unconditional path to END in parallel with the conditional one.
    workflow.add_conditional_edges("Assistant", tools_condition)
    workflow.add_edge("tools", "Assistant")
    return workflow.compile()
|
|
|
if __name__ == "__main__":
    graph = build_agent()
    # A MessagesState graph must be invoked with a state dict, not a bare
    # message object. (Also avoid shadowing the builtin `input`.)
    user_message = HumanMessage(content="Hello, how are you?")
    response = graph.invoke({"messages": [user_message]})
    print(response)
|
|
|
|
|
|
|
|