import uuid
from tempfile import NamedTemporaryFile
from typing import Tuple, List, Optional, Dict
from langchain.agents import AgentExecutor, AgentType, initialize_agent
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
from langchain.callbacks.base import BaseCallbackHandler
from langchain.chains import LLMChain
from langchain.chains import RetrievalQA
from langchain.chat_models import (
AzureChatOpenAI,
ChatOpenAI,
ChatAnthropic,
ChatAnyscale,
)
from langchain.document_loaders import PyPDFLoader
from langchain.embeddings import AzureOpenAIEmbeddings, OpenAIEmbeddings
from langchain.llms.base import BaseLLM
from langchain.memory import ConversationBufferMemory
from langchain.prompts import MessagesPlaceholder, ChatPromptTemplate
from langchain.retrievers import EnsembleRetriever
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain.schema import Document, BaseRetriever
from langchain.schema.chat_history import BaseChatMessageHistory
from langchain.schema.runnable import RunnablePassthrough
from langchain.storage import InMemoryStore
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.tools.base import BaseTool
from langchain.vectorstores import FAISS
from langchain_core.messages import SystemMessage

from defaults import DEFAULT_CHUNK_SIZE, DEFAULT_CHUNK_OVERLAP, DEFAULT_RETRIEVER_K
from qagen import get_rag_qa_gen_chain
from summarize import get_rag_summarization_chain


def get_agent(
tools: list[BaseTool],
chat_history: BaseChatMessageHistory,
llm: BaseLLM,
    callbacks: Optional[List[BaseCallbackHandler]] = None,
):
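    """Build a conversational tool-calling agent around the given chat history.

    Wires an OpenAIFunctionsAgent to the supplied history via a
    ConversationBufferMemory, then wraps the executor in a small runnable
    that accepts a plain string and returns the agent's final answer.
    """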
memory_key = "agent_history"
system_message = SystemMessage(
content=(
"Do your best to answer the questions. "
"Feel free to use any tools available to look up "
"relevant information, only if necessary"
),
)
prompt = OpenAIFunctionsAgent.create_prompt(
system_message=system_message,
extra_prompt_messages=[MessagesPlaceholder(variable_name=memory_key)],
)
agent = OpenAIFunctionsAgent(llm=llm, tools=tools, prompt=prompt)
# agent_memory = AgentTokenBufferMemory(
# chat_memory=chat_history,
# memory_key=memory_key,
# llm=llm,
# )
agent_memory = ConversationBufferMemory(
chat_memory=chat_history,
return_messages=True,
memory_key=memory_key,
)
agent_executor = AgentExecutor(
agent=agent,
tools=tools,
memory=agent_memory,
verbose=True,
return_intermediate_steps=False,
callbacks=callbacks,
)
return (
{"input": RunnablePassthrough()}
| agent_executor
| (lambda output: output["output"])
    )


def get_doc_agent(
tools: list[BaseTool],
llm: Optional[BaseLLM] = None,
agent_type: AgentType = AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION,
):
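    """Build a zero-shot ReAct agent for answering questions about a document.

    Defaults to a streaming GPT-4 Turbo preview model when no LLM is supplied.
    The returned runnable accepts a plain string and returns the final answer.
    """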
if llm is None:
llm = ChatOpenAI(
model_name="gpt-4-1106-preview",
temperature=0.0,
streaming=True,
)
prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"""
You assist a chatbot with answering questions about a document.
If necessary, break up incoming questions into multiple parts,
and use the tools provided to answer smaller questions before
answering the larger question.
""",
),
("user", "{input}"),
MessagesPlaceholder(variable_name="agent_scratchpad"),
],
)
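    # Caveat: `prompt` is forwarded through initialize_agent's **kwargs; depending
    # on the LangChain version and agent type it may be silently ignored in favor
    # of the agent's own default prompt.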
agent_executor = initialize_agent(
tools,
llm,
agent=agent_type,
verbose=True,
memory=None,
handle_parsing_errors=True,
prompt=prompt,
)
return (
{"input": RunnablePassthrough()}
| agent_executor
| (lambda output: output["output"])
    )


def get_runnable(
use_document_chat: bool,
document_chat_chain_type: str,
llm,
retriever,
memory,
chat_prompt,
summarization_prompt,
):
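    """Select the runnable for the current session.

    Returns a plain LLMChain when document chat is disabled; otherwise picks a
    retrieval chain based on `document_chat_chain_type` (Q&A generation,
    summarization, or a stock RetrievalQA chain type such as "stuff").
    """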
if not use_document_chat:
return LLMChain(
prompt=chat_prompt,
llm=llm,
memory=memory,
) | (lambda output: output["text"])
if document_chat_chain_type == "Q&A Generation":
return get_rag_qa_gen_chain(
retriever,
llm,
)
elif document_chat_chain_type == "Summarization":
return get_rag_summarization_chain(
summarization_prompt,
retriever,
llm,
)
else:
return RetrievalQA.from_chain_type(
llm=llm,
chain_type=document_chat_chain_type,
retriever=retriever,
output_key="output_text",
        ) | (lambda output: output["output_text"])


def get_llm(
provider: str,
model: str,
provider_api_key: str,
temperature: float,
max_tokens: int,
azure_available: bool,
azure_dict: dict[str, str],
):
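    """Instantiate a streaming chat model for the selected provider.

    Azure OpenAI takes precedence when configured; otherwise `provider_api_key`
    selects between OpenAI, Anthropic, and Anyscale Endpoints. Returns None
    when no usable credentials are provided.
    """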
if azure_available and provider == "Azure OpenAI":
return AzureChatOpenAI(
azure_endpoint=azure_dict["AZURE_OPENAI_BASE_URL"],
openai_api_version=azure_dict["AZURE_OPENAI_API_VERSION"],
deployment_name=azure_dict["AZURE_OPENAI_DEPLOYMENT_NAME"],
openai_api_key=azure_dict["AZURE_OPENAI_API_KEY"],
openai_api_type="azure",
model_version=azure_dict["AZURE_OPENAI_MODEL_VERSION"],
temperature=temperature,
streaming=True,
max_tokens=max_tokens,
)
elif provider_api_key:
if provider == "OpenAI":
return ChatOpenAI(
model_name=model,
openai_api_key=provider_api_key,
temperature=temperature,
streaming=True,
max_tokens=max_tokens,
)
elif provider == "Anthropic":
return ChatAnthropic(
model=model,
anthropic_api_key=provider_api_key,
temperature=temperature,
streaming=True,
max_tokens_to_sample=max_tokens,
)
elif provider == "Anyscale Endpoints":
return ChatAnyscale(
model_name=model,
anyscale_api_key=provider_api_key,
temperature=temperature,
streaming=True,
max_tokens=max_tokens,
)
    return None


def get_texts_and_multiretriever(
uploaded_file_bytes: bytes,
openai_api_key: str,
chunk_size: int = DEFAULT_CHUNK_SIZE,
chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
k: int = DEFAULT_RETRIEVER_K,
azure_kwargs: Optional[Dict[str, str]] = None,
use_azure: bool = False,
) -> Tuple[List[Document], BaseRetriever]:
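    """Load an uploaded PDF and build an ensemble retriever over it.

    Combines a MultiVectorRetriever (small child chunks embedded, large parent
    chunks returned) with a MultiQueryRetriever, weighted equally. Returns the
    split documents along with the ensemble retriever.
    """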
with NamedTemporaryFile() as temp_file:
temp_file.write(uploaded_file_bytes)
temp_file.seek(0)
loader = PyPDFLoader(temp_file.name)
documents = loader.load()
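    # Two-level split for the MultiVectorRetriever: embed small (400-char)
    # child chunks for precise similarity matching, but return the large
    # parent chunks they came from so the LLM sees full context.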
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=10000,
chunk_overlap=0,
)
child_text_splitter = RecursiveCharacterTextSplitter(chunk_size=400)
texts = text_splitter.split_documents(documents)
id_key = "doc_id"
text_ids = [str(uuid.uuid4()) for _ in texts]
sub_texts = []
for i, text in enumerate(texts):
_id = text_ids[i]
_sub_texts = child_text_splitter.split_documents([text])
for _text in _sub_texts:
_text.metadata[id_key] = _id
sub_texts.extend(_sub_texts)
embeddings_kwargs = {"openai_api_key": openai_api_key}
    if use_azure and azure_kwargs:
        azure_kwargs = dict(azure_kwargs)  # copy so the caller's dict is not mutated
        azure_kwargs["azure_endpoint"] = azure_kwargs.pop("openai_api_base")
        embeddings_kwargs.update(azure_kwargs)
        embeddings = AzureOpenAIEmbeddings(**embeddings_kwargs)
else:
embeddings = OpenAIEmbeddings(**embeddings_kwargs)
store = InMemoryStore()
# MultiVectorRetriever
multivectorstore = FAISS.from_documents(sub_texts, embeddings)
multivector_retriever = MultiVectorRetriever(
vectorstore=multivectorstore,
docstore=store,
id_key=id_key,
)
multivector_retriever.docstore.mset(list(zip(text_ids, texts)))
# multivector_retriever.k = k
multiquery_text_splitter = RecursiveCharacterTextSplitter(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
)
# MultiQueryRetriever
multiquery_texts = multiquery_text_splitter.split_documents(documents)
multiquerystore = FAISS.from_documents(multiquery_texts, embeddings)
    multiquery_retriever = MultiQueryRetriever.from_llm(
        retriever=multiquerystore.as_retriever(search_kwargs={"k": k}),
        llm=ChatOpenAI(openai_api_key=openai_api_key),
    )
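    # Fuse both retrievers' results with equal weight (EnsembleRetriever reranks
    # via reciprocal rank fusion).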
ensemble_retriever = EnsembleRetriever(
retrievers=[multiquery_retriever, multivector_retriever],
weights=[0.5, 0.5],
)
    return multiquery_texts, ensemble_retriever


class StreamHandler(BaseCallbackHandler):
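    """Stream LLM tokens into a UI container as they arrive.

    `container` is expected to expose a `markdown` method (e.g. a Streamlit
    container or placeholder).
    """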
def __init__(self, container, initial_text=""):
self.container = container
        self.text = initial_text

    def on_llm_new_token(self, token: str, **kwargs) -> None:
self.text += token
self.container.markdown(self.text)