Fixed bugs with multi LLMs
Browse files

- app.py +7 -57
- climateqa/engine/rag.py +8 -17
- climateqa/engine/utils.py +23 -6
app.py
CHANGED

@@ -146,88 +146,38 @@ async def chat(query,history,audience,sources,reports):
     if len(reports) == 0:
         reports = []
 
-
     retriever = ClimateQARetriever(vectorstore=vectorstore,sources = sources,min_size = 200,reports = reports,k_summary = 3,k_total = 15,threshold=0.5)
     rag_chain = make_rag_chain(retriever,llm)
-
-    # gradio_format = make_pairs([a.content for a in history]) + [(query, "")]
-    # history = history + [(query,"")]
-    # print(history)
-    # print(gradio_format)
-
-    # # reset memory
-    # memory.clear()
-    # for message in history:
-    #     memory.chat_memory.add_message(message)
 
     inputs = {"query": query,"audience": audience_prompt}
     result = rag_chain.astream_log(inputs) #{"callbacks":[MyCustomAsyncHandler()]})
     # result = rag_chain.stream(inputs)
 
-
-
-
-    final_output_path_id = "/streamed_output/-"
+    path_reformulation = "/logs/reformulation/final_output"
+    path_retriever = "/logs/find_documents/final_output"
+    path_answer = "/logs/answer/streamed_output_str/-"
 
     docs_html = ""
     output_query = ""
     output_language = ""
     gallery = []
 
-    # for output in result:
-
-    #     if "language" in output:
-    #         output_language = output["language"]
-    #     if "question" in output:
-    #         output_query = output["question"]
-    #     if "docs" in output:
-
-    #         try:
-    #             docs = output['docs'] # List[Document]
-    #             docs_html = []
-    #             for i, d in enumerate(docs, 1):
-    #                 docs_html.append(make_html_source(d, i))
-    #             docs_html = "".join(docs_html)
-    #         except TypeError:
-    #             print("No documents found")
-    #             continue
-
-    #     if "answer" in output:
-    #         new_token = output["answer"] # str
-    #         time.sleep(0.03)
-    #         answer_yet = history[-1][1] + new_token
-    #         answer_yet = parse_output_llm_with_sources(answer_yet)
-    #         history[-1] = (query,answer_yet)
-
-    #         yield history,docs_html,output_query,output_language,gallery
-
-
-
-    # async def fallback_iterator(iterable):
-    #     async for item in iterable:
-    #         try:
-    #             yield item
-    #         except Exception as e:
-    #             print(f"Error in fallback iterator: {e}")
-    #             raise gr.Error(f"ClimateQ&A Error: {e}\nThe error has been noted, try another question and if the error remains, you can contact us :)")
-
     try:
         async for op in result:
 
-
             op = op.ops[0]
             # print("ITERATION",op)
 
-            if op['path'] ==
+            if op['path'] == path_reformulation: # reformulated question
                 try:
                     output_language = op['value']["language"] # str
                     output_query = op["value"]["question"]
                 except Exception as e:
                     raise gr.Error(f"ClimateQ&A Error: {e} - The error has been noted, try another question and if the error remains, you can contact us :)")
 
-            elif op['path'] ==
+            elif op['path'] == path_retriever: # documents
                 try:
-                    docs = op['value']['
+                    docs = op['value']['docs'] # List[Document]
                     docs_html = []
                     for i, d in enumerate(docs, 1):
                         docs_html.append(make_html_source(d, i))

@@ -237,7 +187,7 @@ async def chat(query,history,audience,sources,reports):
                     print("op: ",op)
                     continue
 
-            elif op['path'] ==
+            elif op['path'] == path_answer: # final answer
                 new_token = op['value'] # str
                 time.sleep(0.01)
                 answer_yet = history[-1][1] + new_token
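The three `path_*` constants mirror how LangChain's `astream_log` names its log entries: each sub-chain tagged with a `run_name` (via `rename_chain` in `climateqa/engine/utils.py` below) shows up under `/logs/<run_name>/...` in the stream of JSONPatch ops, so the app can no longer rely on the single hardcoded `/streamed_output/-` path. A minimal runnable sketch of the same matching pattern, assuming only langchain-core and using toy RunnableLambda steps in place of the real reformulation and answer chains:

import asyncio
from langchain_core.runnables import RunnableLambda

# Toy stand-ins for the real sub-chains; the run_name is what determines
# the "/logs/<run_name>/..." paths emitted by astream_log.
reformulate = RunnableLambda(
    lambda q: {"question": q, "language": "English"}
).with_config({"run_name": "reformulation"})

answer = RunnableLambda(
    lambda d: f"({d['language']}) answer to: {d['question']}"
).with_config({"run_name": "answer"})

chain = reformulate | answer

path_reformulation = "/logs/reformulation/final_output"
path_answer = "/logs/answer/final_output"

async def main():
    async for patch in chain.astream_log("what is climate change?"):
        for op in patch.ops:  # each op is a dict with "op", "path", "value"
            if op["path"] == path_reformulation:
                print("reformulated:", op["value"])
            elif op["path"] == path_answer:
                print("final answer:", op["value"])

asyncio.run(main())

The real app also matches `/logs/answer/streamed_output_str/-`, a path that appears only for runs whose streamed chunks are strings (the LLM call), which is how the answer arrives token by token.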
climateqa/engine/rag.py
CHANGED

@@ -8,8 +8,7 @@ from langchain_core.prompts.base import format_document
 
 from climateqa.engine.reformulation import make_reformulation_chain
 from climateqa.engine.prompts import answer_prompt_template,answer_prompt_without_docs_template,answer_prompt_images_template
-from climateqa.engine.utils import pass_values, flatten_dict
-
+from climateqa.engine.utils import pass_values, flatten_dict,prepare_chain,rename_chain
 
 DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")
 

@@ -44,21 +43,13 @@ def make_rag_chain(retriever,llm):
     prompt_without_docs = ChatPromptTemplate.from_template(answer_prompt_without_docs_template)
 
     # ------- CHAIN 0 - Reformulation
-
-    reformulation = (
-        {"reformulation":reformulation_chain,**pass_values(["audience","query"])}
-        | RunnablePassthrough()
-        | flatten_dict
-    )
-
+    reformulation = make_reformulation_chain(llm)
+    reformulation = prepare_chain(reformulation,"reformulation")
 
     # ------- CHAIN 1
     # Retrieved documents
-    find_documents =
-
-        **pass_values(["question","audience","language","query"])
-    } | RunnablePassthrough()
-
+    find_documents = {"docs": itemgetter("question") | retriever} | RunnablePassthrough()
+    find_documents = prepare_chain(find_documents,"find_documents")
 
     # ------- CHAIN 2
     # Construct inputs for the llm

@@ -69,15 +60,15 @@ def make_rag_chain(retriever,llm):
 
     # ------- CHAIN 3
     # Bot answer
-
+    llm_final = rename_chain(llm,"answer")
 
     answer_with_docs = {
-        "answer": input_documents | prompt |
+        "answer": input_documents | prompt | llm_final | StrOutputParser(),
         **pass_values(["question","audience","language","query","docs"]),
     }
 
     answer_without_docs = {
-        "answer": prompt_without_docs |
+        "answer": prompt_without_docs | llm_final | StrOutputParser(),
        **pass_values(["question","audience","language","query","docs"]),
     }
 
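The `**pass_values([...])` idiom kept in `answer_with_docs` relies on LCEL coercing a plain dict into a RunnableParallel: the "answer" key is computed while the listed keys are copied through from the input unchanged. A small self-contained sketch, with a hypothetical `fake_answer` standing in for `input_documents | prompt | llm_final | StrOutputParser()`:

from operator import itemgetter
from langchain_core.runnables import RunnableLambda, RunnableParallel

# pass_values as defined in climateqa/engine/utils.py below
def pass_values(x):
    if not isinstance(x, list):
        x = [x]
    return {k: itemgetter(k) for k in x}

# Hypothetical stand-in for the real answer pipeline
fake_answer = RunnableLambda(lambda d: f"An answer about {d['question']}.")

answer_with_docs = RunnableParallel({
    "answer": fake_answer,                        # computed branch
    **pass_values(["question", "language"]),      # passed-through branches
})

print(answer_with_docs.invoke({"question": "sea level rise", "language": "English"}))
# {'answer': 'An answer about sea level rise.',
#  'question': 'sea level rise', 'language': 'English'}

Every branch of the parallel dict receives the same input dict, so renaming the bare `llm` to `llm_final` with `rename_chain` is what lets app.py pick the answer tokens out of the log stream by path.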
climateqa/engine/utils.py
CHANGED

@@ -1,10 +1,29 @@
-
-from typing import Any, Dict, Iterable, Tuple, Union
 from operator import itemgetter
+from typing import Any, Dict, Iterable, Tuple
+from langchain_core.runnables import RunnablePassthrough
+
 
 def pass_values(x):
-    if not isinstance(x,list):
-
+    if not isinstance(x, list):
+        x = [x]
+    return {k: itemgetter(k) for k in x}
+
+
+def prepare_chain(chain,name):
+    chain = propagate_inputs(chain)
+    chain = rename_chain(chain,name)
+    return chain
+
+
+def propagate_inputs(chain):
+    chain_with_values = {
+        "outputs": chain,
+        "inputs": RunnablePassthrough()
+    } | RunnablePassthrough() | flatten_dict
+    return chain_with_values
+
+def rename_chain(chain,name):
+    return chain.with_config({"run_name":name})
 
 
 # Drawn from langchain utils and modified to remove the parent key

@@ -48,5 +67,3 @@ def flatten_dict(
     """
     flat_dict = {k: v for k, v in _flatten_dict(nested_dict, parent_key, sep)}
     return flat_dict
-
-
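`prepare_chain` is what replaces the old hand-rolled `{"reformulation": ..., **pass_values([...])}` blocks in rag.py: `propagate_inputs` runs the wrapped chain and an identity `RunnablePassthrough` in parallel, then `flatten_dict` merges the two result dicts into one flat dict, so a step's outputs travel alongside its original inputs. A runnable toy, assuming a simplified one-level `flatten_dict` in place of the recursive one defined further down in this file:

from langchain_core.runnables import RunnableLambda, RunnablePassthrough

# Simplified one-level flatten_dict: drops the "outputs"/"inputs" parent
# keys and merges their contents (the real helper recurses).
def flatten_dict(nested):
    flat = {}
    for value in nested.values():
        flat.update(value)
    return flat

def propagate_inputs(chain):
    # Run the chain and a passthrough in parallel, then flatten, so the
    # chain's outputs and its original inputs end up side by side.
    return {
        "outputs": chain,
        "inputs": RunnablePassthrough(),
    } | RunnablePassthrough() | flatten_dict

reformulate = RunnableLambda(
    lambda d: {"question": d["query"].rstrip("?"), "language": "English"}
)

step = propagate_inputs(reformulate)
print(step.invoke({"query": "what is climate change?", "audience": "experts"}))
# {'question': 'what is climate change', 'language': 'English',
#  'query': 'what is climate change?', 'audience': 'experts'}

This merged dict is why app.py can read both `op['value']["language"]` and `op["value"]["question"]` from a single `final_output` patch.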