John Graham Reynolds committed
Commit 29cf982 · 1 Parent(s): 8df66b4
add chain for reformatting inputs and augmenting the question with relevant context
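The new chain.py (below) reads every setting from a chain_config.yaml file via mlflow.models.ModelConfig. For orientation, here is a minimal sketch of the config shape the code appears to expect, written as the Python dict ModelConfig would return once loaded; the key names are taken from the code's .get() calls, but every value is an invented placeholder and is not part of this commit.

# Hypothetical shape of chain_config.yaml, shown as the loaded dict.
# Key names mirror the .get() calls in chain.py; all values are made up.
example_config = {
    "databricks_resources": {
        "vector_search_endpoint_name": "<vector-search-endpoint>",
        "llm_endpoint_name": "<llm-serving-endpoint>",
    },
    "llm_config": {
        "llm_prompt_template": "You are a glossary assistant. Use this context:\n{context}",
        "llm_parameters": {"temperature": 0.0, "max_tokens": 500},
    },
    "retriever_config": {
        "embedding_model": "BAAI/bge-large-en-v1.5",  # placeholder; the code's comments only mention "bge-large"
        "vector_search_index": "<catalog.schema.index_name>",
        "parameters": {"k": 5},
        "chunk_template": "Term: {name}\nDefinition: {description}\n\n",
        "schema": {"primary_key": "name", "description": "description"},
    },
}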
chain.py
ADDED
@@ -0,0 +1,176 @@
import os
import mlflow
import streamlit as st
from operator import itemgetter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_databricks.vectorstores import DatabricksVectorSearch
from langchain_community.chat_models import ChatDatabricks
from langchain_core.runnables import RunnableLambda, RunnablePassthrough, RunnableBranch
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate, ChatPromptTemplate, MessagesPlaceholder
from langchain_core.messages import HumanMessage, AIMessage

# ## Enable MLflow Tracing
# mlflow.langchain.autolog()

class ChainBuilder:

    def __init__(self):
        # Load the chain's configuration from yaml
        self.model_config = mlflow.models.ModelConfig(development_config="chain_config.yaml")
        self.databricks_resources = self.model_config.get("databricks_resources")
        self.llm_config = self.model_config.get("llm_config")
        self.retriever_config = self.model_config.get("retriever_config")
        self.vector_search_schema = self.retriever_config.get("schema")

    # Return the string contents of the most recent message from the user
    def extract_user_query_string(self, chat_messages_array):
        return chat_messages_array[-1]["content"]

    # Return the chat history, which is everything before the last question
    def extract_chat_history(self, chat_messages_array):
        return chat_messages_array[:-1]

    # ** working logic for querying glossary embeddings
    # Same embedding model we used to create embeddings of terms.
    # Make sure we cache this so that it doesn't redownload each time, hindering Space start time if sleeping.
    # Try adding this st caching decorator to ensure the embeddings class gets cached after downloading the entirety of the model.
    # Does this cache to the given folder though? It does appear to populate the folder as expected after being run.
    @st.cache_resource  # will this work here? https://docs.streamlit.io/develop/concepts/architecture/caching
    def load_embedding_model(self):
        embeddings = HuggingFaceEmbeddings(model_name=self.retriever_config.get("embedding_model"), cache_folder="./langchain_cache/")  # this cache isn't working because we're in the Docker container
        # TODO: update this to read from a presaved cache of bge-large
        return embeddings

    def get_retriever(self):
        embeddings = self.load_embedding_model()
        # Instantiate the vector store for similarity search in our chain.
        # Need to make this a function and decorate it with @st.experimental_memo as above?
        # We only call this initially, when the Space starts and builds the chain. Can we expedite this process for users when opening up this Space?
        # @st.cache_data  # TODO add this in
        vector_search_as_retriever = DatabricksVectorSearch(
            endpoint=self.databricks_resources.get("vector_search_endpoint_name"),
            index_name=self.retriever_config.get("vector_search_index"),
            embedding=embeddings,
            text_column="name",
            columns=["name", "description"],
        ).as_retriever(search_kwargs=self.retriever_config.get("parameters"))
        return vector_search_as_retriever

    # # *** TODO: Evaluate this block as it relates to the "RAG Studio Review App" ***
    # # Enable the RAG Studio Review App to properly display retrieved chunks and the evaluation suite to measure the retriever
    # mlflow.models.set_retriever_schema(
    #     primary_key=self.vector_search_schema.get("primary_key"),
    #     text_column=vector_search_schema.get("chunked_terms"),
    #     # doc_uri=vector_search_schema.get("definition")
    #     other_columns=[vector_search_schema.get("definition")],
    #     # Review App uses `doc_uri` to display chunks from the same document in a single view
    # )

    # Method to format the terms and definitions returned by the retriever into the prompt
    # TODO: double check the contents here
    def format_context(self, retrieved_terms):
        chunk_template = self.retriever_config.get("chunk_template")
        chunk_contents = [
            chunk_template.format(
                name=term.page_content,
                description=term.metadata[self.vector_search_schema.get("description")],
            )
            for term in retrieved_terms
        ]
        return "".join(chunk_contents)

    def get_prompt(self):
        # Prompt template for generation
        prompt = ChatPromptTemplate.from_messages(
            [
                ("system", self.llm_config.get("llm_prompt_template")),
                # *** Note: this chain does not compress the history, so very long conversations can overflow the context window.
                # TODO: at some point we need to chop this history down to a fixed number of recent messages
                MessagesPlaceholder(variable_name="formatted_chat_history"),
                # User's most recent question
                ("user", "{question}"),
            ]
        )
        return prompt

    # Format the conversation history to fit into the prompt template above.
    # **** TODO: after only a few statements this will likely overflow the context window
    def format_chat_history_for_prompt(self, chat_messages_array):
        history = self.extract_chat_history(chat_messages_array)
        formatted_chat_history = []
        if len(history) > 0:
            for chat_message in history:
                if chat_message["role"] == "user":
                    formatted_chat_history.append(HumanMessage(content=chat_message["content"]))
                elif chat_message["role"] == "assistant":
                    formatted_chat_history.append(AIMessage(content=chat_message["content"]))
        return formatted_chat_history

    def get_query_rewrite_prompt(self):
        # Prompt template for query rewriting from chat history. This will translate a query such as "how does it work?"
        # after a question like "what is spark?" into "how does spark work?"
        query_rewrite_template = """Based on the chat history below, we want you to generate a query for an external data source to retrieve relevant information so
        that we can better answer the question. The query should be in natural language. The external data source uses similarity search to search for relevant
        information in a vector space. So, the query should be similar to the relevant information semantically. Answer with only the query. Do not add explanation.

        Chat history: {chat_history}

        Question: {question}"""

        query_rewrite_prompt = PromptTemplate(
            template=query_rewrite_template,
            input_variables=["chat_history", "question"],
        )
        return query_rewrite_prompt

    @st.cache_resource
    def get_model(self):
        # Foundation model for generation
        model = ChatDatabricks(
            endpoint=self.databricks_resources.get("llm_endpoint_name"),
            extra_params=self.llm_config.get("llm_parameters"),
        )
        return model

    @st.cache_resource
    def build_chain(self):
        model = self.get_model()
        prompt = self.get_prompt()
        vector_search_as_retriever = self.get_retriever()
        query_rewrite_prompt = self.get_query_rewrite_prompt()

        # RAG chain
        chain = (
            {
                # Set 'question' to the result of: grabbing the ["messages"] component of the dict we 'invoke()' or 'stream()', then passing it to extract_user_query_string()
                "question": itemgetter("messages") | RunnableLambda(self.extract_user_query_string),
                "chat_history": itemgetter("messages") | RunnableLambda(self.extract_chat_history),
                "formatted_chat_history": itemgetter("messages")
                | RunnableLambda(self.format_chat_history_for_prompt),
            }
            | RunnablePassthrough()  # passes elements unchanged through the chain to the next link in the chain
            | {
                "context": RunnableBranch(  # Only rewrite the question if there is a chat history - RunnableBranch() is essentially an LCEL if statement
                    (
                        lambda x: len(x["chat_history"]) > 0,  # https://python.langchain.com/api_reference/core/runnables/langchain_core.runnables.branch.RunnableBranch.html
                        query_rewrite_prompt | model | StrOutputParser(),  # rewrite the question with context
                    ),
                    itemgetter("question"),  # else, just ask the question
                )
                | vector_search_as_retriever  # set 'context' to the result of passing either the base question or the reformatted question to the retriever for semantic search
                | RunnableLambda(self.format_context),
                "formatted_chat_history": itemgetter("formatted_chat_history"),
                "question": itemgetter("question"),
            }
            | prompt  # 'context', 'formatted_chat_history', and 'question' are passed to the prompt
            | model  # the prompt is passed to the model
            | StrOutputParser()
        )

        return chain


# ## Tell MLflow logging where to find your chain.
# mlflow.models.set_model(model=chain)
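For reference, a minimal sketch of how the finished chain might be driven, e.g. from the Space's Streamlit app. The import of ChainBuilder and the sample question are assumptions, not part of this commit; the {"messages": [...]} payload shape comes from the chain's own comments about invoke() and stream().

# Hypothetical driver code, not part of this commit.
from chain import ChainBuilder

builder = ChainBuilder()
rag_chain = builder.build_chain()

# chain.py pulls the question and chat history out of the "messages" list,
# so the chain is invoked with an OpenAI-style messages payload.
answer = rag_chain.invoke(
    {
        "messages": [
            {"role": "user", "content": "What is a vector search index?"},  # invented example question
        ]
    }
)
print(answer)

# Token-by-token streaming works the same way through LCEL:
# for chunk in rag_chain.stream({"messages": [...]}):
#     print(chunk, end="")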