|
|
|
|
|
from langchain_core.pydantic_v1 import BaseModel, Field |
|
from typing import List |
|
from typing import Literal |
|
from langchain.prompts import ChatPromptTemplate |
|
from langchain_core.utils.function_calling import convert_to_openai_function |
|
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser |
|
|
|
|
|
# Maps each retrieval index name to the source identifiers it can serve.
ROUTING_INDEX = {
    "Vector": ["IPCC", "IPBES", "IPOS"],
    "OpenAlex": ["OpenAlex"],
}

# Flat list of every source identifier, in routing-table order.
POSSIBLE_SOURCES = [
    source_name
    for index_sources in ROUTING_INDEX.values()
    for source_name in index_sources
]
|
|
|
|
|
|
|
class QueryDecomposition(BaseModel):
    """
    Decompose the user query into smaller parts to think step by step to answer this question
    Act as a simple planning agent
    """

    # This class is converted with convert_to_openai_function in
    # make_query_decomposition_chain. The docstring above and the description
    # below presumably end up in the function schema shown to the model, so
    # their wording is kept as is - treat any edit to them as a prompt change.
    questions: List[str] = Field(
        description="""
        Think step by step to answer this question, and provide one or several search engine questions in English for knowledge that you need.
        Suppose that the user is looking for information about climate change, energy, biodiversity, nature, and everything we can find the IPCC reports and scientific literature
        - If it's already a standalone and explicit question, just return the reformulated question for the search engine
        - If you need to decompose the question, output a list of maximum 2 to 3 questions
        """
    )
|
|
|
|
|
class Location(BaseModel):
    # Both fields are required (Ellipsis default) plain strings.
    # NOTE(review): "adresses" in the first description is misspelled; it is
    # left untouched because the description text is part of the schema.
    country: str = Field(
        ...,
        description="The country if directly mentioned or inferred from the location (cities, regions, adresses), ex: France, USA, ...",
    )
    location: str = Field(
        ...,
        description="The specific place if mentioned (cities, regions, addresses), ex: Marseille, New York, Wisconsin, ...",
    )
|
|
|
class QueryAnalysis(BaseModel):
    """
    Analyzing the user query to extract topics, sources and date
    Also do query expansion to get alternative search queries
    Also provide simple keywords to feed a search engine
    """

    # NOTE(review): the docstring above also mentions topics, dates, query
    # expansion and keywords, but `sources` is the only field declared here.
    # The wording is kept unchanged because this class is converted with
    # convert_to_openai_function in make_query_rewriter_chain and the text
    # presumably reaches the model.

    # Required field; the allowed literals match ROUTING_INDEX["Vector"].
    sources: List[Literal["IPCC", "IPBES", "IPOS"]] = Field(
        ...,
        description="""
        Given a user question choose which documents would be most relevant for answering their question,
        - IPCC is for questions about climate change, energy, impacts, and everything we can find the IPCC reports
        - IPBES is for questions about biodiversity and nature
        - IPOS is for questions about the ocean and deep sea mining
        """,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def make_query_decomposition_chain(llm):
    """Assemble the chain that splits a user message into search questions.

    Args:
        llm: Model object exposing ``bind``; it is bound to the
            ``QueryDecomposition`` function schema.

    Returns:
        The composed runnable ``prompt | bound llm | JsonOutputFunctionsParser()``.
        It is invoked with a mapping that provides the ``input`` template
        variable.
    """
    function_schemas = [convert_to_openai_function(QueryDecomposition)]

    # function_call names the schema explicitly, so the model is asked to
    # answer through that function rather than with free text.
    bound_llm = llm.bind(
        functions=function_schemas,
        function_call={"name": "QueryDecomposition"},
    )

    messages = [
        ("system", "You are a helpful assistant, you will analyze, translate and reformulate the user input message using the function provided"),
        ("user", "input: {input}"),
    ]
    prompt = ChatPromptTemplate.from_messages(messages)

    return prompt | bound_llm | JsonOutputFunctionsParser()
|
|
|
|
|
def make_query_rewriter_chain(llm):
    """Assemble the chain that analyses one question with ``QueryAnalysis``.

    Args:
        llm: Model object exposing ``bind``; it is bound to the
            ``QueryAnalysis`` function schema.

    Returns:
        The composed runnable ``prompt | bound llm | JsonOutputFunctionsParser()``.
        It is invoked with a mapping that provides the ``input`` template
        variable.
    """
    function_schemas = [convert_to_openai_function(QueryAnalysis)]

    # function_call names the schema explicitly, so the model is asked to
    # answer through that function rather than with free text.
    bound_llm = llm.bind(
        functions=function_schemas,
        function_call={"name": "QueryAnalysis"},
    )

    messages = [
        ("system", "You are a helpful assistant, you will analyze, translate and reformulate the user input message using the function provided"),
        ("user", "input: {input}"),
    ]
    prompt = ChatPromptTemplate.from_messages(messages)

    return prompt | bound_llm | JsonOutputFunctionsParser()
|
|
|
|
|
def make_query_transform_node(llm, k_final=15):
    """Build the node that decomposes a query and routes each sub-question.

    Args:
        llm: Model handed to ``make_query_decomposition_chain`` and
            ``make_query_rewriter_chain``.
        k_final: Not used by this node; kept so existing callers that pass it
            keep working.

    Returns:
        A ``transform_query(state)`` callable.
    """
    decomposition_chain = make_query_decomposition_chain(llm)
    rewriter_chain = make_query_rewriter_chain(llm)

    # Single source of truth for "which sources may the rewriter pick" and
    # for the fallback; reading it from the routing table keeps the check,
    # the fallback and the manual-mode default from drifting apart.
    default_sources = ROUTING_INDEX["Vector"]

    def transform_query(state):
        """Split ``state["query"]`` into questions and assign them to indexes.

        Reads from ``state``:
            - ``query``: the user message to decompose (required).
            - ``sources_auto``: when present and neither ``None`` nor
              ``False``, sources chosen by the rewriter chain are used.
            - ``sources_input``: sources used otherwise; defaults to
              ``ROUTING_INDEX["Vector"]`` when missing or ``None``.

        Returns:
            A dict with ``remaining_questions`` (a list of
            ``{"question", "sources", "index"}`` dicts, one per question and
            per routing index that shares at least one source with it) and
            ``n_questions`` (the length of that list).
        """
        print("---- Transform query ----")

        sources_auto = state.get("sources_auto")
        auto_mode = sources_auto is not None and sources_auto is not False

        sources_input = state.get("sources_input")
        if sources_input is None:
            sources_input = default_sources

        decomposition_output = decomposition_chain.invoke({"input": state["query"]})

        routed_questions = []
        for question in decomposition_output["questions"]:
            # NOTE(review): the rewriter is still called when auto mode is
            # off even though its answer is then ignored; kept so the number
            # of model calls stays the same as before.
            analysis_output = rewriter_chain.invoke({"input": question})

            if auto_mode:
                sources = (analysis_output or {}).get("sources")
                # Fall back to every default source when the model returned
                # nothing or named a source outside the allowed list.
                if not sources or not all(source in default_sources for source in sources):
                    sources = list(default_sources)
            else:
                sources = sources_input

            wanted = set(sources)
            for index, index_sources in ROUTING_INDEX.items():
                # Walk the routing table's own list so the selected sources
                # come out in a stable order rather than set order.
                selected_sources = [source for source in index_sources if source in wanted]
                if selected_sources:
                    routed_questions.append(
                        {"question": question, "sources": selected_sources, "index": index}
                    )

        return {
            "remaining_questions": routed_questions,
            "n_questions": len(routed_questions),
        }

    return transform_query