import uuid
from tempfile import NamedTemporaryFile
from typing import Tuple, List, Optional, Dict

from langchain.agents import AgentExecutor, AgentType, initialize_agent
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
from langchain.callbacks.base import BaseCallbackHandler
from langchain.chains import LLMChain, RetrievalQA
from langchain.chat_models import (
    AzureChatOpenAI,
    ChatOpenAI,
    ChatAnthropic,
    ChatAnyscale,
)
from langchain.document_loaders import PyPDFLoader
from langchain.embeddings import AzureOpenAIEmbeddings, OpenAIEmbeddings
from langchain.llms.base import BaseLLM
from langchain.memory import ConversationBufferMemory
from langchain.prompts import MessagesPlaceholder
from langchain.retrievers import EnsembleRetriever
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain.schema import Document, BaseRetriever
from langchain.schema.chat_history import BaseChatMessageHistory
from langchain.schema.runnable import RunnablePassthrough
from langchain.storage import InMemoryStore
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.tools.base import BaseTool
from langchain.vectorstores import FAISS
from langchain_core.messages import SystemMessage

from defaults import DEFAULT_CHUNK_SIZE, DEFAULT_CHUNK_OVERLAP, DEFAULT_RETRIEVER_K
from qagen import get_rag_qa_gen_chain
from summarize import get_rag_summarization_chain


def get_agent(
    tools: list[BaseTool],
    chat_history: BaseChatMessageHistory,
    llm: BaseLLM,
    callbacks,
):
    """Build an OpenAI-functions agent that streams its final answer."""
    memory_key = "agent_history"
    system_message = SystemMessage(
        content=(
            "Do your best to answer the questions. "
            "Feel free to use any tools available to look up "
            "relevant information, but only if necessary."
        ),
    )
    prompt = OpenAIFunctionsAgent.create_prompt(
        system_message=system_message,
        extra_prompt_messages=[MessagesPlaceholder(variable_name=memory_key)],
    )
    agent = OpenAIFunctionsAgent(llm=llm, tools=tools, prompt=prompt)

    # Expose prior turns to the agent through the placeholder defined above.
    agent_memory = ConversationBufferMemory(
        chat_memory=chat_history,
        return_messages=True,
        memory_key=memory_key,
    )

    agent_executor = AgentExecutor(
        agent=agent,
        tools=tools,
        memory=agent_memory,
        verbose=True,
        return_intermediate_steps=False,
        callbacks=callbacks,
    )
    return (
        {"input": RunnablePassthrough()}
        | agent_executor
        | (lambda output: output["output"])
    )
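

# A minimal usage sketch (hypothetical wiring, not part of this module):
# `DuckDuckGoSearchRun` requires the `duckduckgo-search` package, and
# `st.empty()` assumes a Streamlit app is driving the stream handler.
#
#     from langchain.memory import ChatMessageHistory
#     from langchain.tools import DuckDuckGoSearchRun
#
#     agent_chain = get_agent(
#         tools=[DuckDuckGoSearchRun()],
#         chat_history=ChatMessageHistory(),
#         llm=ChatOpenAI(streaming=True),
#         callbacks=[StreamHandler(st.empty())],
#     )
#     answer = agent_chain.invoke("Who maintains LangChain?")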


def get_doc_agent(
    tools: list[BaseTool],
    llm: Optional[BaseLLM] = None,
    agent_type: AgentType = AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION,
):
    """Build a ReAct-style agent dedicated to answering document questions."""
    if llm is None:
        llm = ChatOpenAI(
            model_name="gpt-4-1106-preview",
            temperature=0.0,
            streaming=True,
        )
    # `initialize_agent` has no `prompt` parameter, so a ChatPromptTemplate
    # passed as a keyword argument is silently dropped. Custom instructions
    # have to reach the underlying agent through `agent_kwargs` instead.
    system_message_prefix = (
        "You assist a chatbot with answering questions about a document. "
        "If necessary, break up incoming questions into multiple parts, "
        "and use the tools provided to answer smaller questions before "
        "answering the larger question."
    )
    agent_executor = initialize_agent(
        tools,
        llm,
        agent=agent_type,
        verbose=True,
        memory=None,
        handle_parsing_errors=True,
        agent_kwargs={"system_message_prefix": system_message_prefix},
    )
    return (
        {"input": RunnablePassthrough()}
        | agent_executor
        | (lambda output: output["output"])
    )
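

# Sketch: expose a per-document QA chain as a tool for this agent. The
# `qa_chain` object and the tool name/description are illustrative
# placeholders, not fixed API of this module:
#
#     from langchain.agents import Tool
#
#     doc_agent = get_doc_agent(
#         tools=[
#             Tool.from_function(
#                 func=qa_chain.invoke,
#                 name="document_qa",
#                 description="Answers a question about the uploaded document.",
#             ),
#         ],
#     )
#     doc_agent.invoke("List the paper's main claims, then critique each.")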


def get_runnable(
    use_document_chat: bool,
    document_chat_chain_type: str,
    llm,
    retriever,
    memory,
    chat_prompt,
    summarization_prompt,
):
    """Select the chain that backs the chat: plain LLM or document-aware."""
    if not use_document_chat:
        # Plain conversation: no retrieval, just the chat prompt plus memory.
        return LLMChain(
            prompt=chat_prompt,
            llm=llm,
            memory=memory,
        ) | (lambda output: output["text"])

    if document_chat_chain_type == "Q&A Generation":
        return get_rag_qa_gen_chain(
            retriever,
            llm,
        )
    elif document_chat_chain_type == "Summarization":
        return get_rag_summarization_chain(
            summarization_prompt,
            retriever,
            llm,
        )
    else:
        # Any other value ("stuff", "map_reduce", "refine", ...) is passed
        # through as a RetrievalQA chain type.
        return RetrievalQA.from_chain_type(
            llm=llm,
            chain_type=document_chat_chain_type,
            retriever=retriever,
            output_key="output_text",
        ) | (lambda output: output["output_text"])
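

# Sketch of a document-chat configuration; the prompt objects are assumed
# to be ChatPromptTemplates built elsewhere in the app:
#
#     chain = get_runnable(
#         use_document_chat=True,
#         document_chat_chain_type="map_reduce",
#         llm=llm,
#         retriever=retriever,
#         memory=ConversationBufferMemory(return_messages=True),
#         chat_prompt=chat_prompt,
#         summarization_prompt=summarization_prompt,
#     )
#     answer = chain.invoke("What are the document's main findings?")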


def get_llm(
    provider: str,
    model: str,
    provider_api_key: str,
    temperature: float,
    max_tokens: int,
    azure_available: bool,
    azure_dict: dict[str, str],
):
    """Instantiate a streaming chat model for the selected provider."""
    if azure_available and provider == "Azure OpenAI":
        return AzureChatOpenAI(
            azure_endpoint=azure_dict["AZURE_OPENAI_BASE_URL"],
            openai_api_version=azure_dict["AZURE_OPENAI_API_VERSION"],
            deployment_name=azure_dict["AZURE_OPENAI_DEPLOYMENT_NAME"],
            openai_api_key=azure_dict["AZURE_OPENAI_API_KEY"],
            openai_api_type="azure",
            model_version=azure_dict["AZURE_OPENAI_MODEL_VERSION"],
            temperature=temperature,
            streaming=True,
            max_tokens=max_tokens,
        )
    elif provider_api_key:
        if provider == "OpenAI":
            return ChatOpenAI(
                model_name=model,
                openai_api_key=provider_api_key,
                temperature=temperature,
                streaming=True,
                max_tokens=max_tokens,
            )
        elif provider == "Anthropic":
            return ChatAnthropic(
                model=model,
                anthropic_api_key=provider_api_key,
                temperature=temperature,
                streaming=True,
                max_tokens_to_sample=max_tokens,
            )
        elif provider == "Anyscale Endpoints":
            return ChatAnyscale(
                model_name=model,
                anyscale_api_key=provider_api_key,
                temperature=temperature,
                streaming=True,
                max_tokens=max_tokens,
            )
    # No usable provider configuration or API key.
    return None
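

# Example call; the `azure_dict` keys mirror the fields consumed above, and
# the environment-variable name is only an assumption about the caller:
#
#     import os
#
#     llm = get_llm(
#         provider="OpenAI",
#         model="gpt-3.5-turbo",
#         provider_api_key=os.environ["OPENAI_API_KEY"],
#         temperature=0.7,
#         max_tokens=1000,
#         azure_available=False,
#         azure_dict={},
#     )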


def get_texts_and_multiretriever(
    uploaded_file_bytes: bytes,
    openai_api_key: str,
    chunk_size: int = DEFAULT_CHUNK_SIZE,
    chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
    k: int = DEFAULT_RETRIEVER_K,
    azure_kwargs: Optional[Dict[str, str]] = None,
    use_azure: bool = False,
) -> Tuple[List[Document], BaseRetriever]:
    """Split an uploaded PDF and build an ensemble retriever over it."""
    # PyPDFLoader reads from a path, so spill the uploaded bytes to a
    # temporary file first.
    with NamedTemporaryFile() as temp_file:
        temp_file.write(uploaded_file_bytes)
        temp_file.seek(0)
        loader = PyPDFLoader(temp_file.name)
        documents = loader.load()

    # Parent/child splitters for the multi-vector retriever: large parent
    # chunks are stored whole; small child chunks are embedded and searched.
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=10000,
        chunk_overlap=0,
    )
    child_text_splitter = RecursiveCharacterTextSplitter(chunk_size=400)

    texts = text_splitter.split_documents(documents)
    id_key = "doc_id"

    # Tag every child chunk with its parent's id so retrieval can map back
    # to the full parent document.
    text_ids = [str(uuid.uuid4()) for _ in texts]
    sub_texts = []
    for i, text in enumerate(texts):
        _id = text_ids[i]
        _sub_texts = child_text_splitter.split_documents([text])
        for _text in _sub_texts:
            _text.metadata[id_key] = _id
        sub_texts.extend(_sub_texts)

    embeddings_kwargs = {"openai_api_key": openai_api_key}
    if use_azure and azure_kwargs:
        # Work on a copy so the caller's dict is not mutated;
        # AzureOpenAIEmbeddings expects `azure_endpoint` rather than the
        # legacy `openai_api_base` key.
        azure_kwargs = dict(azure_kwargs)
        azure_kwargs["azure_endpoint"] = azure_kwargs.pop("openai_api_base")
        embeddings_kwargs.update(azure_kwargs)
        embeddings = AzureOpenAIEmbeddings(**embeddings_kwargs)
    else:
        embeddings = OpenAIEmbeddings(**embeddings_kwargs)
    store = InMemoryStore()

    # Retriever 1: multi-vector (search child chunks, return parents).
    multivectorstore = FAISS.from_documents(sub_texts, embeddings)
    multivector_retriever = MultiVectorRetriever(
        vectorstore=multivectorstore,
        docstore=store,
        id_key=id_key,
    )
    multivector_retriever.docstore.mset(list(zip(text_ids, texts)))

    # Retriever 2: multi-query over conventionally sized chunks. Pass the
    # API key explicitly so the query-generation LLM does not depend on an
    # environment variable.
    multiquery_text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    multiquery_texts = multiquery_text_splitter.split_documents(documents)
    multiquerystore = FAISS.from_documents(multiquery_texts, embeddings)
    multiquery_retriever = MultiQueryRetriever.from_llm(
        retriever=multiquerystore.as_retriever(search_kwargs={"k": k}),
        llm=ChatOpenAI(openai_api_key=openai_api_key),
    )

    # Blend both retrievers with equal weight.
    ensemble_retriever = EnsembleRetriever(
        retrievers=[multiquery_retriever, multivector_retriever],
        weights=[0.5, 0.5],
    )
    return multiquery_texts, ensemble_retriever
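

# Sketch: build the ensemble retriever from an uploaded PDF and query it.
# `uploaded_file` stands in for e.g. a Streamlit `st.file_uploader` result:
#
#     texts, retriever = get_texts_and_multiretriever(
#         uploaded_file_bytes=uploaded_file.getvalue(),
#         openai_api_key=openai_api_key,
#     )
#     docs = retriever.get_relevant_documents("What does section 3 claim?")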


class StreamHandler(BaseCallbackHandler):
    """Streams LLM tokens into a container, e.g. a Streamlit placeholder."""

    def __init__(self, container, initial_text=""):
        self.container = container
        self.text = initial_text

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        # Append each token and re-render the accumulated text in place.
        self.text += token
        self.container.markdown(self.text)
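

# Sketch: attach the handler to a Streamlit placeholder so tokens render
# live. `st` is assumed to be the Streamlit module imported by the caller:
#
#     placeholder = st.empty()
#     llm = ChatOpenAI(streaming=True, callbacks=[StreamHandler(placeholder)])
#     llm.predict("Tell me a short story.")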
|