"""Reformulation chain: detect the input language and rewrite the query in English."""

from langchain.output_parsers.structured import StructuredOutputParser, ResponseSchema
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough, RunnableLambda, RunnableBranch

from climateqa.engine.prompts import reformulation_prompt_template
from climateqa.engine.utils import pass_values, flatten_dict


# Fields the LLM is asked to return; the parser extracts them into a dict.
response_schemas = [
    ResponseSchema(name="language", description="The detected language of the input message"),
    ResponseSchema(name="question", description="The reformulated question always in English"),
]
output_parser = StructuredOutputParser.from_response_schemas(response_schemas)
format_instructions = output_parser.get_format_instructions()


def fallback_default_values(x):
    """Fill in defaults when the LLM reformulation is missing or empty.

    Args:
        x: The flattened chain state. Expected to contain ``"query"``
           (the original user message). It presumably also holds the
           parsed ``"question"`` and ``"language"`` keys; either of those
           may be missing, ``None`` or empty.

    Returns:
        The same dict, mutated in place. The changes are:

        - If there is no usable ``"question"``, it becomes the original
          ``"query"`` and ``"language"`` is forced to ``"english"``.
        - If there is a question but no usable ``"language"``,
          ``"language"`` defaults to ``"english"``.

    Raises:
        KeyError: If ``"query"`` is absent and a fallback question is needed.
    """
    if not x.get("question"):
        # No usable reformulation: reuse the raw query rather than sending
        # None/"" downstream.
        x["question"] = x["query"]
        x["language"] = "english"
    elif not x.get("language"):
        # A question was produced but language detection failed.
        x["language"] = "english"
    return x


def make_reformulation_chain(llm):
    """Build a runnable that reformulates a user query with ``llm``.

    Args:
        llm: A LangChain language-model runnable that supports ``.bind``.

    Returns:
        A runnable that takes a mapping with a ``"query"`` key. Its output
        is the dict produced by ``flatten_dict`` and then passed through
        ``fallback_default_values``. That dict is expected to contain
        ``"query"``, ``"question"`` and ``"language"`` — TODO confirm
        against ``flatten_dict``.
    """
    prompt = PromptTemplate(
        template=reformulation_prompt_template,
        input_variables=["query"],
        partial_variables={"format_instructions": format_instructions},
    )

    # Generation stops at a code fence. NOTE(review): this presumably relies
    # on the prompt template opening the fenced block itself — confirm.
    chain = prompt | llm.bind(stop=["```"]) | output_parser

    reformulation_chain = (
        # Run the LLM chain and forward the original inputs in parallel.
        # pass_values is a project helper — assumed to map "query" through.
        {"reformulation": chain, **pass_values(["query"])}
        # Required: RunnablePassthrough.__ror__ coerces the plain dict above
        # into a runnable; `dict | function` alone would raise TypeError.
        | RunnablePassthrough()
        # Presumably merges the nested "reformulation" result into the
        # top-level dict — verify against climateqa.engine.utils.
        | flatten_dict
        | fallback_default_values
    )
    return reformulation_chain