# Provenance: commit 1706d7c — "refactor: remove unimported packages" (Itz-Amethyst).
# The original first lines were non-Python page-scrape residue (avatar caption and
# commit metadata); converted to this comment so the module parses.
import os
from dotenv import load_dotenv
from tools.python_interpreter import CodeInterpreter
# NOTE(review): this CodeInterpreter instance is not referenced anywhere in this
# file — presumably kept for import-time side effects or for use by
# tools.python_interpreter.execute_code_lang; confirm before removing.
interpreter_instance = CodeInterpreter()
from tools.image import *
"""Langraph"""
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import ToolNode, tools_condition
from langchain_groq import ChatGroq
from langchain_huggingface import (
ChatHuggingFace,
HuggingFaceEndpoint,
HuggingFaceEmbeddings,
)
from langchain_community.vectorstores import SupabaseVectorStore
from langchain_core.messages import SystemMessage, HumanMessage
from langchain.tools.retriever import create_retriever_tool
from supabase.client import Client, create_client
# ------- Tools
from tools.browse import web_search, wiki_search, arxiv_search
from tools.document_process import save_and_read_file, analyze_csv_file, analyze_excel_file, extract_text_from_image, download_file_from_url
from tools.image_tools import analyze_image, generate_simple_image , transform_image, draw_on_image, combine_images
from tools.simple_math import multiply, add, subtract, divide, modulus, power, square_root
from tools.python_interpreter import execute_code_lang
# Load environment variables (Supabase credentials, API keys) from a .env file.
load_dotenv()

# Read the agent's system prompt from disk. Fail loudly with a clear message if
# the file is absent, since nothing downstream works without it.
# FIX: removed a leftover debug `print(system_prompt)` that dumped the whole
# prompt to stdout at import time.
try:
    with open("system_prompt.txt", "r", encoding="utf-8") as f:
        system_prompt = f.read()
except FileNotFoundError as err:
    raise FileNotFoundError(
        "system_prompt.txt must exist alongside this module"
    ) from err

# System message prepended to every conversation by the retriever node.
sys_msg = SystemMessage(content=system_prompt)
# build a retriever
# Build a retriever backed by a Supabase vector store of previously solved
# questions, using MPNet sentence embeddings.
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-mpnet-base-v2",
)  # dim=768

# Fail fast with an actionable message when credentials are missing, instead of
# letting create_client raise something opaque deep inside the supabase client.
_supabase_url = os.environ.get("SUPABASE_URL_HUGGING_FACE")
_supabase_key = os.environ.get("SUPABASE_SERVICE_ROLE_HUGGING_FACE")
if not _supabase_url or not _supabase_key:
    raise RuntimeError(
        "SUPABASE_URL_HUGGING_FACE and SUPABASE_SERVICE_ROLE_HUGGING_FACE "
        "environment variables must be set."
    )
supabase: Client = create_client(_supabase_url, _supabase_key)

vector_store = SupabaseVectorStore(
    client=supabase,
    embedding=embeddings,
    table_name="documents",
    query_name="match_documents_langchain",
)

# FIX: the original assigned the result back to the imported factory name
# (`create_retriever_tool = create_retriever_tool(...)`), shadowing the
# langchain helper. Bind the tool instance to a distinct name instead.
# NOTE(review): this tool is not included in the `tools` list below — confirm
# whether that omission is intentional.
retriever_tool = create_retriever_tool(
    retriever=vector_store.as_retriever(),
    name="Question Search",
    description="A tool to retrieve similar questions from a vector store.",
)
# All tools exposed to the LLM (via bind_tools) and to the graph's ToolNode.
tools = [
    # -- web / knowledge search
    web_search,
    wiki_search,
    arxiv_search,
    # -- arithmetic
    multiply,
    add,
    subtract,
    divide,
    modulus,
    power,
    square_root,
    # -- file and document handling
    save_and_read_file,
    download_file_from_url,
    extract_text_from_image,
    analyze_csv_file,
    analyze_excel_file,
    # -- code execution
    execute_code_lang,
    # -- image manipulation
    analyze_image,
    transform_image,
    draw_on_image,
    generate_simple_image,
    combine_images,
]
def build_graph(provider: str = "groq"):
    """Compile and return the agent's LangGraph state machine.

    Topology: START -> retriever -> assistant <-> tools. The retriever node
    injects the system prompt plus, when available, a similar previously
    answered question from the vector store; the assistant node is the
    tool-bound chat model.

    Args:
        provider: chat backend to use — "groq" or "huggingface".

    Returns:
        A compiled LangGraph runnable.

    Raises:
        ValueError: if `provider` is not a supported backend.
    """
    if provider == "groq":
        # Groq https://console.groq.com/docs/models
        llm = ChatGroq(model="qwen-qwq-32b", temperature=0)
        # llm = ChatGroq(model="deepseek-r1-distill-llama-70b", temperature=0)
    elif provider == "huggingface":
        llm = ChatHuggingFace(
            llm=HuggingFaceEndpoint(
                repo_id="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
                task="text-generation",  # chat-style usage still goes through text-generation
                max_new_tokens=1024,
                do_sample=False,
                repetition_penalty=1.03,
                temperature=0,
            ),
            verbose=True,
        )
    else:
        raise ValueError("Invalid provider. Choose 'groq' or 'huggingface'.")

    llm_with_tools = llm.bind_tools(tools)

    def assistant(state: MessagesState):
        """Invoke the tool-bound LLM on the accumulated message history."""
        reply = llm_with_tools.invoke(state["messages"])
        return {"messages": [reply]}

    def retriever(state: MessagesState):
        """Prepend the system prompt and, when the vector store has a hit,
        a reference message with the most similar solved question."""
        query = state["messages"][-1].content
        matches = vector_store.similarity_search(query, k=2)
        prefixed = [sys_msg] + state["messages"]
        if not matches:
            return {"messages": prefixed}
        reference = HumanMessage(
            content=f"Here I provide a similar question and answer for reference: \n\n{matches[0].page_content}",
        )
        return {"messages": prefixed + [reference]}

    builder = StateGraph(MessagesState)
    builder.add_node("retriever", retriever)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))
    builder.add_edge(START, "retriever")
    builder.add_edge("retriever", "assistant")
    builder.add_conditional_edges("assistant", tools_condition)
    builder.add_edge("tools", "assistant")
    return builder.compile()
if __name__ == "__main__":
    # Smoke-test the compiled graph on a single GAIA-style question.
    question = "How many studio albums were published by Mercedes Sosa between 2000 and 2009 (included)? You can use the latest 2022 version of english wikipedia."
    agent = build_graph(provider="groq")
    result = agent.invoke({"messages": [HumanMessage(content=question)]})
    for message in result["messages"]:
        message.pretty_print()