from operator import itemgetter
from typing import Optional

# langchain imports
from langchain.chat_models import ChatOpenAI
from langchain.schema.runnable import RunnableMap, RunnablePassthrough
from langchain.prompts import ChatPromptTemplate
from langchain.prompts.prompt import PromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.messages import HumanMessage, SystemMessage, AIMessage
from langchain.callbacks.streamlit.streamlit_callback_handler import StreamlitCallbackHandler


def format_docs(docs):
    """Render retrieved documents as pseudo-XML <doc> blocks (content plus metadata)."""
    res = ""
    for doc in docs:
        escaped_page_content = doc.page_content.replace("\n", "\\n")
        res += "<doc>\n"
        res += f" <content>{escaped_page_content}</content>\n"
        for m in doc.metadata:
            res += f" <{m}>{doc.metadata[m]}</{m}>\n"
        res += "</doc>\n"
    return res
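
# For illustration, a retrieved chunk typically renders like the block below
# (the metadata keys vary with the loader; "source" and "page" are typical for
# PDF loaders). Keeping page metadata in the prompt is what lets the answer
# prompt cite page numbers.
# <doc>
#  <content>Serverless computing is ...</content>
#  <source>Mahmoudi_Nima_202202_PhD.pdf</source>
#  <page>12</page>
# </doc>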


def get_search_index(file_name="Mahmoudi_Nima_202202_PhD.pdf", index_folder="index"):
    # load the saved FAISS index, using OpenAI embeddings for query encoding
    from langchain.vectorstores import FAISS
    from langchain.embeddings.openai import OpenAIEmbeddings

    search_index = FAISS.load_local(
        folder_path=index_folder,
        index_name=file_name + ".index",
        embeddings=OpenAIEmbeddings(),
    )
    return search_index
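
# A minimal sketch (hypothetical helper, not part of the original module) of how
# an index that get_search_index() can load might be built. The PDF loader
# (requires pypdf) and the chunking parameters are assumptions for illustration;
# only the save folder and index_name need to match the loader above.
def build_search_index(file_name="Mahmoudi_Nima_202202_PhD.pdf", index_folder="index"):
    from langchain.document_loaders import PyPDFLoader
    from langchain.text_splitter import RecursiveCharacterTextSplitter
    from langchain.vectorstores import FAISS
    from langchain.embeddings.openai import OpenAIEmbeddings

    pages = PyPDFLoader(file_name).load()  # one Document per page, with page-number metadata
    chunks = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100).split_documents(pages)
    index = FAISS.from_documents(chunks, OpenAIEmbeddings())
    index.save_local(folder_path=index_folder, index_name=file_name + ".index")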


def convert_message(m):
    if m["role"] == "user":
        return HumanMessage(content=m["content"])
    elif m["role"] == "assistant":
        return AIMessage(content=m["content"])
    elif m["role"] == "system":
        return SystemMessage(content=m["content"])
    else:
        raise ValueError(f"Unknown role {m['role']}")
_condense_template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language. | |
Chat History: | |
{chat_history} | |
Follow Up Input: {input} | |
Standalone question:""" | |
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_condense_template) | |
_rag_template = """Answer the question based only on the following context, citing the page number(s) of the document(s) you used to answer the question: | |
{context} | |
Question: {question} | |
""" | |
ANSWER_PROMPT = ChatPromptTemplate.from_template(_rag_template) | |


def _format_chat_history(chat_history):
    def format_single_chat_message(m):
        if type(m) is HumanMessage:
            return "Human: " + m.content
        elif type(m) is AIMessage:
            return "Assistant: " + m.content
        elif type(m) is SystemMessage:
            return "System: " + m.content
        else:
            raise ValueError(f"Unknown message type {type(m)}")

    return "\n".join([format_single_chat_message(m) for m in chat_history])


def get_standalone_question_from_chat_history_chain():
    _inputs = RunnableMap(
        standalone_question=RunnablePassthrough.assign(
            chat_history=lambda x: _format_chat_history(x["chat_history"])
        )
        | CONDENSE_QUESTION_PROMPT
        | ChatOpenAI(temperature=0)
        | StrOutputParser(),
    )
    return _inputs
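
# Usage sketch (assumes OPENAI_API_KEY is set); the RunnableMap returns a dict
# with the condensed question under "standalone_question":
#   chain = get_standalone_question_from_chat_history_chain()
#   chain.invoke({"input": "how fast is it?",
#                 "chat_history": [HumanMessage(content="tell me about FaaS")]})
#   # -> {"standalone_question": "How fast is FaaS?"}  (illustrative output)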


def get_rag_chain(file_name="Mahmoudi_Nima_202202_PhD.pdf", index_folder="index", retrieval_cb=None):
    vectorstore = get_search_index(file_name, index_folder)
    retriever = vectorstore.as_retriever()

    if retrieval_cb is None:
        retrieval_cb = lambda x: x

    def context_update_fn(q):
        retrieval_cb([q])
        return q

    _inputs = RunnableMap(
        standalone_question=RunnablePassthrough.assign(
            chat_history=lambda x: _format_chat_history(x["chat_history"])
        )
        | CONDENSE_QUESTION_PROMPT
        | ChatOpenAI(temperature=0)
        | StrOutputParser(),
    )
    _context = {
        "context": itemgetter("standalone_question") | RunnablePassthrough(context_update_fn) | retriever | format_docs,
        "question": lambda x: x["standalone_question"],
    }
    conversational_qa_chain = _inputs | _context | ANSWER_PROMPT | ChatOpenAI()
    return conversational_qa_chain
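
# Data flow of the chain above: {"input", "chat_history"} -> condense to a
# standalone question -> retrieve documents for it (reporting the query as a
# side effect via retrieval_cb) -> format_docs -> ANSWER_PROMPT -> ChatOpenAI.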


# RAG fusion chain
# source1: https://youtu.be/GchC5WxeXGc?si=6i7J0rPZI7SNwFYZ
# source2: https://towardsdatascience.com/forget-rag-the-future-is-rag-fusion-1147298d8ad1
def reciprocal_rank_fusion(results: list[list], k=60):
    from langchain.load import dumps, loads

    fused_scores = {}
    for docs in results:
        # Assumes the docs are returned in sorted order of relevance
        for rank, doc in enumerate(docs):
            doc_str = dumps(doc)
            if doc_str not in fused_scores:
                fused_scores[doc_str] = 0
            fused_scores[doc_str] += 1 / (rank + k)

    reranked_results = [
        (loads(doc), score)
        for doc, score in sorted(fused_scores.items(), key=lambda x: x[1], reverse=True)
    ]
    return reranked_results
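
# Each document's fused score is sum(1 / (rank + k)) over the result lists that
# contain it, with rank 0-based here; e.g. a document ranked first in two lists
# scores 2/60, beating a document ranked first in only one list (1/60).
# k=60 follows the original RRF paper (Cormack et al., SIGIR 2009).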


def get_search_query_generation_chain():
    from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate

    prompt = ChatPromptTemplate(
        input_variables=['original_query'],
        messages=[
            SystemMessagePromptTemplate(
                prompt=PromptTemplate(
                    input_variables=[],
                    template='You are a helpful assistant that generates multiple search queries based on a single input query.'
                )
            ),
            HumanMessagePromptTemplate(
                prompt=PromptTemplate(
                    input_variables=['original_query'],
                    template='Generate multiple search queries related to: {original_query} \n OUTPUT (4 queries):'
                )
            )
        ]
    )

    generate_queries = (
        prompt |
        ChatOpenAI(temperature=0) |
        StrOutputParser() |
        (lambda x: x.split("\n"))
    )

    return generate_queries
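
# Illustrative output (actual model output varies); note that the plain newline
# split keeps any numbering the model emits inside each query string:
#   get_search_query_generation_chain().invoke({"original_query": "serverless computing"})
#   # -> ["1. What is serverless computing?", "2. ...", "3. ...", "4. ..."]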


def get_rag_fusion_chain(file_name="Mahmoudi_Nima_202202_PhD.pdf", index_folder="index", retrieval_cb=None):
    vectorstore = get_search_index(file_name, index_folder)
    retriever = vectorstore.as_retriever()
    query_generation_chain = get_search_query_generation_chain()
    _inputs = RunnableMap(
        standalone_question=RunnablePassthrough.assign(
            chat_history=lambda x: _format_chat_history(x["chat_history"])
        )
        | CONDENSE_QUESTION_PROMPT
        | ChatOpenAI(temperature=0)
        | StrOutputParser(),
    )

    if retrieval_cb is None:
        retrieval_cb = lambda x: x

    _context = {
        "context":
            RunnablePassthrough.assign(
                original_query=lambda x: x["standalone_question"]
            )
            | query_generation_chain
            | retrieval_cb
            | retriever.map()
            | reciprocal_rank_fusion
            | (lambda x: [item[0] for item in x])
            | format_docs,
        "question": lambda x: x["standalone_question"],
    }
    conversational_qa_chain = _inputs | _context | ANSWER_PROMPT | ChatOpenAI()
    return conversational_qa_chain
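
# Data flow of the fusion context above: standalone question -> 4 generated
# queries -> retriever.map() retrieves one ranked list per query ->
# reciprocal_rank_fusion merges the lists -> scores are dropped -> format_docs.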


####################
# Adding agent chain with OpenAI function calling

def get_search_tool_from_index(search_index, st_cb: Optional[StreamlitCallbackHandler] = None):
    from langchain.agents import tool
    from agent_helper import retry_and_streamlit_callback

    @tool("search")
    # assumed signature of the local agent_helper decorator, which adds retries
    # and wires the optional Streamlit callback into the tool invocation
    @retry_and_streamlit_callback(st_cb=st_cb, tool_name="search")
    def search(query: str) -> str:
        """Search the contents of the source document for the queries."""
        docs = search_index.similarity_search(query, k=5)
        return format_docs(docs)

    return search


def get_lc_oai_tools(file_name: str = "Mahmoudi_Nima_202202_PhD.pdf", index_folder: str = "index", st_cb: Optional[StreamlitCallbackHandler] = None):
    from langchain.tools.render import format_tool_to_openai_tool

    search_index = get_search_index(file_name, index_folder)
    lc_tools = [get_search_tool_from_index(search_index=search_index, st_cb=st_cb)]
    oai_tools = [format_tool_to_openai_tool(t) for t in lc_tools]
    return lc_tools, oai_tools


def get_agent_chain(file_name="Mahmoudi_Nima_202202_PhD.pdf", index_folder="index", callbacks=None, st_cb: Optional[StreamlitCallbackHandler] = None):
    if callbacks is None:
        callbacks = []

    from langchain.agents import AgentExecutor
    from langchain.agents.format_scratchpad.openai_tools import format_to_openai_tool_messages
    from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser
    from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder

    lc_tools, oai_tools = get_lc_oai_tools(file_name, index_folder, st_cb)

    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", "You are a helpful assistant. Use the search tool to answer the user's question, and cite only the page number (like [p1]) when you use information from the source document.\nchat history: {chat_history}"),
            ("user", "{input}"),
            MessagesPlaceholder(variable_name="agent_scratchpad"),
        ]
    )
    llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo-1106")

    agent = (
        {
            "input": lambda x: x["input"],
            "agent_scratchpad": lambda x: format_to_openai_tool_messages(
                x["intermediate_steps"]
            ),
            "chat_history": lambda x: _format_chat_history(x["chat_history"]),
        }
        | prompt
        | llm.bind(tools=oai_tools)
        | OpenAIToolsAgentOutputParser()
    )
    agent_executor = AgentExecutor(agent=agent, tools=lc_tools, verbose=True, callbacks=callbacks)
    return agent_executor
if __name__ == "__main__": | |
question_generation_chain = get_search_query_generation_chain() | |
print('='*50) | |
print('RAG Chain') | |
chain = get_rag_chain() | |
print(chain.invoke({'input': 'serverless computing', 'chat_history': []})) | |
print('='*50) | |
print('Question Generation Chain') | |
print(question_generation_chain.invoke({'original_query': 'serverless computing'})) | |
print('-'*50) | |
print('RAG Fusion Chain') | |
chain = get_rag_fusion_chain() | |
print(chain.invoke({'input': 'serverless computing', 'chat_history': []})) | |
agent_executor = get_agent_chain() | |
print( | |
agent_executor.invoke({ | |
"input": "based on the source document, compare FaaS with BaaS??", | |
"chat_history": [], | |
}) | |
) | |