# You can find the code for Chainlit Python streaming here: https://docs.chainlit.io/concepts/streaming/python
import os
import chainlit as cl  # importing chainlit for our app
from typing import Annotated, List
from typing_extensions import TypedDict
from dotenv import load_dotenv
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from langchain_core.documents import Document
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain_cohere import CohereRerank
from langgraph.graph import START, StateGraph, END
from langchain_core.messages import HumanMessage
from langchain_core.tools import tool
from langgraph.prebuilt.tool_node import ToolNode
from langgraph.graph.message import add_messages
from vectorstore import VectorStore
load_dotenv()
# API keys: load_dotenv() above pulls OPENAI_API_KEY and COHERE_API_KEY from
# .env into the environment, so the manual re-export below is normally
# unnecessary (embeddings run locally via HuggingFace):
# OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
# os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY
# COHERE_API_KEY = os.getenv("COHERE_API_KEY")
# os.environ["COHERE_API_KEY"] = COHERE_API_KEY
# ------- Models/Tools ------- #
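# Local HuggingFace embedding model; normalize_embeddings=True makes
# inner-product similarity search behave like cosine similarity.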
embed_model = HuggingFaceEmbeddings(
model_name="Snowflake/snowflake-arctic-embed-l",
model_kwargs={'device': 'cpu'},
encode_kwargs={'normalize_embeddings': True}
)
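# Small chat model used both for RAG generation and as the agent;
# temperature=0 for stable, repeatable answers.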
llm_sml = ChatOpenAI(
model="gpt-4o-mini",
temperature=0,
)
# ------- Prompts ------- #
rag_prompt = ChatPromptTemplate.from_template("""\
You are a helpful assistant who answers questions based on the provided context. You must only use the provided context. Do NOT use your own knowledge.
If you don't know the answer, say so.
### Question
{question}
### Context
{context}
""")
# load documents and create vector store
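# VectorStore is a project-local wrapper (see vectorstore.py);
# load_chunks_as_documents is assumed here to be a static/class method that
# reads pre-chunked files from disk into Document objects.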
vectorstore = VectorStore(
collection_name="mg_alloy_collection_snowflake",
)
documents = VectorStore.load_chunks_as_documents("data/contextual_chunks")
vectorstore.add_documents(documents)
retriever = vectorstore.as_retriever(k=5)
# ------- State Schemas ------- #
class State(TypedDict):
question: str
context: List[Document]
response: str
# ------- Functions ------- #
def generate(state: State):
    # Stuff the reranked chunks into the RAG prompt and generate an answer
    docs_content = "\n\n".join(doc.page_content for doc in state["context"])
    messages = rag_prompt.format_messages(question=state["question"], context=docs_content)
    response = llm_sml.invoke(messages)
    return {"response": response.content}
def retrieve_adjusted(state: State):
    # Rerank the raw vector-store candidates with Cohere before generation;
    # the base retriever supplies k=5 and CohereRerank's top_n (default 3)
    # decides how many documents survive the rerank.
    # (ContextualCompressionRetriever itself takes no search_kwargs.)
    compressor = CohereRerank(model="rerank-v3.5")
    compression_retriever = ContextualCompressionRetriever(
        base_compressor=compressor, base_retriever=retriever
    )
    retrieved_docs = compression_retriever.invoke(state["question"])
    return {"context": retrieved_docs}
# ------- Runnables ------- #
# retrieve graph
graph_builder = StateGraph(State)
graph_builder.add_node("retrieve", retrieve_adjusted)
graph_builder.add_node("generate", generate)
graph_builder.add_edge(START, "retrieve")
graph_builder.add_edge("retrieve", "generate")
graph_builder.add_edge("generate", END)
graph = graph_builder.compile()
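# Quick smoke test of the RAG graph (hypothetical question):
#   graph.invoke({"question": "Which alloying elements improve magnesium creep resistance?"})["response"]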
@tool
def ai_rag_tool(question: str) -> dict:
    """Useful for when you need to answer questions about magnesium alloys. Input should be a fully formed question."""
    response = graph.invoke({"question": question})
    # ToolNode serializes this dict into the ToolMessage content
    return {
        "messages": [HumanMessage(content=response["response"])],
        "context": response["context"],
    }
# ------------------------------------------------ #
tool_belt = [
ai_rag_tool
]
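# AgentState accumulates the running chat: add_messages is a reducer that
# appends new messages instead of overwriting the list.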
class AgentState(TypedDict):
messages: Annotated[list, add_messages]
context: List[Document]
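# ToolNode executes whatever tool calls the agent emits and appends the
# results to the message history as ToolMessages.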
tool_node = ToolNode(tool_belt)
uncompiled_graph = StateGraph(AgentState)
# Bind the tool schemas to the model so it can emit tool calls; without
# bind_tools the agent would never route to the tool node.
model_with_tools = llm_sml.bind_tools(tool_belt)
def call_model(state):
    messages = state["messages"]
    response = model_with_tools.invoke(messages)
    return {
        "messages": [response],
        "context": state.get("context", []),
    }
uncompiled_graph.add_node("agent", call_model)
uncompiled_graph.add_node("action", tool_node)
uncompiled_graph.set_entry_point("agent")
def should_continue(state):
    # Keep routing to the tool node while the last model reply requests tools
    last_message = state["messages"][-1]
    if last_message.tool_calls:
        return "action"
    return END
uncompiled_graph.add_conditional_edges(
"agent",
should_continue
)
uncompiled_graph.add_edge("action", "agent")
compiled_graph = uncompiled_graph.compile()
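# The compiled agent loops agent -> action -> agent until the model replies
# without tool calls, at which point should_continue routes to END.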
# ------- Chainlit ------- #
@cl.on_chat_start
async def start():
    cl.user_session.set("graph", compiled_graph)
@cl.on_message
async def handle(message: cl.Message):
graph = cl.user_session.get("graph")
state = {"messages" : [HumanMessage(content=message.content)]}
response = await graph.ainvoke(state)
await cl.Message(content=response["messages"][-1].content).send()