# Spaces: Runtime error
"""LLM Chain for turning a user text query into a structured query.""" | |
from __future__ import annotations | |
import json | |
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union, cast | |
from langchain_core.exceptions import OutputParserException | |
from langchain_core.language_models import BaseLanguageModel | |
from langchain_core.output_parsers import BaseOutputParser | |
from langchain_core.prompts import BasePromptTemplate | |
from langchain_core.prompts.few_shot import FewShotPromptTemplate | |
from langchain_core.runnables import Runnable | |
from langchain.chains.llm import LLMChain | |
from langchain.chains.query_constructor.ir import ( | |
Comparator, | |
Comparison, | |
FilterDirective, | |
Operation, | |
Operator, | |
StructuredQuery, | |
) | |
from langchain.chains.query_constructor.parser import get_parser | |
from langchain.chains.query_constructor.prompt import ( | |
DEFAULT_EXAMPLES, | |
DEFAULT_PREFIX, | |
DEFAULT_SCHEMA_PROMPT, | |
DEFAULT_SUFFIX, | |
EXAMPLE_PROMPT, | |
EXAMPLES_WITH_LIMIT, | |
PREFIX_WITH_DATA_SOURCE, | |
SCHEMA_WITH_LIMIT_PROMPT, | |
SUFFIX_WITHOUT_DATA_SOURCE, | |
USER_SPECIFIED_EXAMPLE_PROMPT, | |
) | |
from langchain.chains.query_constructor.schema import AttributeInfo | |
from langchain.output_parsers.json import parse_and_check_json_markdown | |
class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):
    """Output parser that turns raw LLM text into a ``StructuredQuery``.

    The text is handed to ``parse_and_check_json_markdown`` and is required to
    yield the keys ``query`` and ``filter``; ``limit`` is optional. Any other
    keys in the parsed payload are ignored.
    """

    ast_parse: Callable
    """Callable that parses the raw ``filter`` value into the internal
    representation of the query language."""

    def parse(self, text: str) -> StructuredQuery:
        """Parse LLM output into a ``StructuredQuery``.

        Args:
            text: Raw model output expected to contain the JSON request.

        Returns:
            The structured query built from the ``query``, ``filter`` and
            (when truthy) ``limit`` entries.

        Raises:
            OutputParserException: If anything goes wrong while parsing. The
                original exception is attached as ``__cause__``.
        """
        try:
            expected_keys = ["query", "filter"]
            allowed_keys = ["query", "filter", "limit"]
            parsed = parse_and_check_json_markdown(text, expected_keys)
            # An empty query is replaced by a single space.
            if len(parsed["query"]) == 0:
                parsed["query"] = " "
            # "NO_FILTER" and any falsy filter both mean "no filter at all".
            if parsed["filter"] == "NO_FILTER" or not parsed["filter"]:
                parsed["filter"] = None
            else:
                parsed["filter"] = self.ast_parse(parsed["filter"])
            # A missing or falsy limit is dropped instead of being forwarded.
            if not parsed.get("limit"):
                parsed.pop("limit", None)
            return StructuredQuery(
                **{k: v for k, v in parsed.items() if k in allowed_keys}
            )
        except Exception as e:
            # Chain the cause so the underlying traceback is not lost.
            raise OutputParserException(
                f"Parsing text\n{text}\n raised following error:\n{e}"
            ) from e

    @classmethod
    def from_components(
        cls,
        allowed_comparators: Optional[Sequence[Comparator]] = None,
        allowed_operators: Optional[Sequence[Operator]] = None,
        allowed_attributes: Optional[Sequence[str]] = None,
        fix_invalid: bool = False,
    ) -> StructuredQueryOutputParser:
        """Create a structured query output parser from components.

        Args:
            allowed_comparators: allowed comparators
            allowed_operators: allowed operators
            allowed_attributes: allowed attributes
            fix_invalid: If True, the filter is parsed with an unrestricted
                parser and the result is then passed through
                ``fix_filter_directive`` together with the allowed
                comparators/operators/attributes. If False, the restrictions
                are given directly to ``get_parser``.

        Returns:
            a structured query output parser
        """
        ast_parse: Callable
        if fix_invalid:

            def ast_parse(raw_filter: str) -> Optional[FilterDirective]:
                filter = cast(Optional[FilterDirective], get_parser().parse(raw_filter))
                fixed = fix_filter_directive(
                    filter,
                    allowed_comparators=allowed_comparators,
                    allowed_operators=allowed_operators,
                    allowed_attributes=allowed_attributes,
                )
                return fixed

        else:
            ast_parse = get_parser(
                allowed_comparators=allowed_comparators,
                allowed_operators=allowed_operators,
                allowed_attributes=allowed_attributes,
            ).parse
        return cls(ast_parse=ast_parse)
def fix_filter_directive(
    filter: Optional[FilterDirective],
    *,
    allowed_comparators: Optional[Sequence[Comparator]] = None,
    allowed_operators: Optional[Sequence[Operator]] = None,
    allowed_attributes: Optional[Sequence[str]] = None,
) -> Optional[FilterDirective]:
    """Remove the parts of a filter directive that are not allowed.

    Args:
        filter: Filter directive to fix.
        allowed_comparators: allowed comparators. A falsy value means no
            restriction on comparators.
        allowed_operators: allowed operators. A falsy value means no
            restriction on operators.
        allowed_attributes: allowed attributes. A falsy value means no
            restriction on attributes.

    Returns:
        The fixed filter directive, or None when nothing valid remains.
    """
    has_restrictions = bool(
        allowed_comparators or allowed_operators or allowed_attributes
    )
    # Nothing to enforce, or nothing to inspect: hand the input straight back.
    if not has_restrictions or not filter:
        return filter

    if isinstance(filter, Comparison):
        if allowed_comparators and filter.comparator not in allowed_comparators:
            return None
        if allowed_attributes and filter.attribute not in allowed_attributes:
            return None
        return filter

    if not isinstance(filter, Operation):
        # Unknown directive kinds are passed through untouched.
        return filter

    if allowed_operators and filter.operator not in allowed_operators:
        return None

    # Fix every argument recursively, keeping only those that survive.
    surviving = []
    for argument in filter.arguments:
        repaired = fix_filter_directive(
            argument,
            allowed_comparators=allowed_comparators,
            allowed_operators=allowed_operators,
            allowed_attributes=allowed_attributes,
        )
        if repaired is not None:
            surviving.append(repaired)

    if not surviving:
        return None
    # An AND/OR with a single remaining argument collapses to that argument.
    if len(surviving) == 1 and filter.operator in (Operator.AND, Operator.OR):
        return surviving[0]
    return Operation(
        operator=filter.operator,
        arguments=surviving,
    )
def _format_attribute_info(info: Sequence[Union[AttributeInfo, dict]]) -> str: | |
info_dicts = {} | |
for i in info: | |
i_dict = dict(i) | |
info_dicts[i_dict.pop("name")] = i_dict | |
return json.dumps(info_dicts, indent=4).replace("{", "{{").replace("}", "}}") | |
def construct_examples(input_output_pairs: Sequence[Tuple[str, dict]]) -> List[dict]:
    """Build few-shot example dicts from (query, structured output) pairs.

    Args:
        input_output_pairs: Sequence of ``(user_query, output_dict)`` pairs.

    Returns:
        One dict per pair with the keys ``i`` (1-based position),
        ``user_query`` and ``structured_request`` (the output serialized as
        indented JSON with curly braces doubled).
    """

    def _escape_braces(payload: dict) -> str:
        # Serialize first, then double every brace in the JSON text.
        rendered = json.dumps(payload, indent=4)
        return rendered.replace("{", "{{").replace("}", "}}")

    return [
        {
            "i": position,
            "user_query": user_query,
            "structured_request": _escape_braces(expected),
        }
        for position, (user_query, expected) in enumerate(
            input_output_pairs, start=1
        )
    ]
def get_query_constructor_prompt(
    document_contents: str,
    attribute_info: Sequence[Union[AttributeInfo, dict]],
    *,
    examples: Optional[Sequence] = None,
    allowed_comparators: Sequence[Comparator] = tuple(Comparator),
    allowed_operators: Sequence[Operator] = tuple(Operator),
    enable_limit: bool = False,
    schema_prompt: Optional[BasePromptTemplate] = None,
    **kwargs: Any,
) -> BasePromptTemplate:
    """Build the few-shot prompt used for query construction.

    Args:
        document_contents: The contents of the document to be queried.
        attribute_info: A list of AttributeInfo objects (or equivalent dicts)
            describing the attributes of the document.
        examples: Optional examples. When the first example is a tuple, the
            examples are treated as (input, output) pairs and converted with
            ``construct_examples``. When omitted or empty, built-in default
            examples are used.
        allowed_comparators: Sequence of allowed comparators.
        allowed_operators: Sequence of allowed operators.
        enable_limit: Whether to enable the limit operator. Defaults to False.
        schema_prompt: Prompt for describing query schema. Should have string
            input variables allowed_comparators and allowed_operators.
        **kwargs: Additional named params to pass to FewShotPromptTemplate init.

    Returns:
        A prompt template that can be used to construct queries.
    """
    # Fall back to the built-in schema prompt matching the limit setting.
    if not schema_prompt:
        schema_prompt = (
            SCHEMA_WITH_LIMIT_PROMPT if enable_limit else DEFAULT_SCHEMA_PROMPT
        )

    attribute_str = _format_attribute_info(attribute_info)
    comparators_text = " | ".join(allowed_comparators)
    operators_text = " | ".join(allowed_operators)
    schema = schema_prompt.format(
        allowed_comparators=comparators_text,
        allowed_operators=operators_text,
    )

    user_supplied_pairs = bool(examples) and isinstance(examples[0], tuple)
    if user_supplied_pairs:
        # User-specified pairs: the data source description goes in the prefix.
        examples = construct_examples(examples)
        example_prompt = USER_SPECIFIED_EXAMPLE_PROMPT
        prefix = PREFIX_WITH_DATA_SOURCE.format(
            schema=schema, content=document_contents, attributes=attribute_str
        )
        suffix = SUFFIX_WITHOUT_DATA_SOURCE.format(i=len(examples) + 1)
    else:
        # Ready-made or default examples: the data source goes in the suffix.
        if not examples:
            examples = EXAMPLES_WITH_LIMIT if enable_limit else DEFAULT_EXAMPLES
        example_prompt = EXAMPLE_PROMPT
        prefix = DEFAULT_PREFIX.format(schema=schema)
        suffix = DEFAULT_SUFFIX.format(
            i=len(examples) + 1, content=document_contents, attributes=attribute_str
        )

    return FewShotPromptTemplate(
        examples=list(examples),
        example_prompt=example_prompt,
        input_variables=["query"],
        suffix=suffix,
        prefix=prefix,
        **kwargs,
    )
def load_query_constructor_chain(
    llm: BaseLanguageModel,
    document_contents: str,
    attribute_info: Sequence[Union[AttributeInfo, dict]],
    examples: Optional[List] = None,
    allowed_comparators: Sequence[Comparator] = tuple(Comparator),
    allowed_operators: Sequence[Operator] = tuple(Operator),
    enable_limit: bool = False,
    schema_prompt: Optional[BasePromptTemplate] = None,
    **kwargs: Any,
) -> LLMChain:
    """Build an LLMChain that converts a user query into a structured query.

    Args:
        llm: BaseLanguageModel to use for the chain.
        document_contents: The contents of the document to be queried.
        attribute_info: Sequence of attributes in the document.
        examples: Optional list of examples to use for the chain.
        allowed_comparators: Sequence of allowed comparators. Defaults to all
            Comparators.
        allowed_operators: Sequence of allowed operators. Defaults to all
            Operators.
        enable_limit: Whether to enable the limit operator. Defaults to False.
        schema_prompt: Prompt for describing query schema. Should have string
            input variables allowed_comparators and allowed_operators.
        **kwargs: Arbitrary named params to pass to LLMChain.

    Returns:
        A LLMChain that can be used to construct queries.
    """
    prompt = get_query_constructor_prompt(
        document_contents,
        attribute_info,
        examples=examples,
        allowed_comparators=allowed_comparators,
        allowed_operators=allowed_operators,
        enable_limit=enable_limit,
        schema_prompt=schema_prompt,
    )
    # Attribute names come from either AttributeInfo objects or plain dicts.
    allowed_attributes = [
        entry.name if isinstance(entry, AttributeInfo) else entry["name"]
        for entry in attribute_info
    ]
    output_parser = StructuredQueryOutputParser.from_components(
        allowed_comparators=allowed_comparators,
        allowed_operators=allowed_operators,
        allowed_attributes=allowed_attributes,
    )
    # For backwards compatibility.
    prompt.output_parser = output_parser
    return LLMChain(llm=llm, prompt=prompt, output_parser=output_parser, **kwargs)
def load_query_constructor_runnable(
    llm: BaseLanguageModel,
    document_contents: str,
    attribute_info: Sequence[Union[AttributeInfo, dict]],
    *,
    examples: Optional[Sequence] = None,
    allowed_comparators: Sequence[Comparator] = tuple(Comparator),
    allowed_operators: Sequence[Operator] = tuple(Operator),
    enable_limit: bool = False,
    schema_prompt: Optional[BasePromptTemplate] = None,
    fix_invalid: bool = False,
    **kwargs: Any,
) -> Runnable:
    """Build a runnable pipeline ``prompt | llm | output_parser``.

    Args:
        llm: BaseLanguageModel to use for the chain.
        document_contents: The contents of the document to be queried.
        attribute_info: Sequence of attributes in the document.
        examples: Optional list of examples to use for the chain.
        allowed_comparators: Sequence of allowed comparators. Defaults to all
            Comparators.
        allowed_operators: Sequence of allowed operators. Defaults to all
            Operators.
        enable_limit: Whether to enable the limit operator. Defaults to False.
        schema_prompt: Prompt for describing query schema. Should have string
            input variables allowed_comparators and allowed_operators.
        fix_invalid: Whether to fix invalid filter directives by ignoring
            invalid operators, comparators and attributes.
        **kwargs: Additional named params to pass to FewShotPromptTemplate init.

    Returns:
        A Runnable that can be used to construct queries.
    """
    prompt = get_query_constructor_prompt(
        document_contents,
        attribute_info,
        examples=examples,
        allowed_comparators=allowed_comparators,
        allowed_operators=allowed_operators,
        enable_limit=enable_limit,
        schema_prompt=schema_prompt,
        **kwargs,
    )
    # Attribute names come from either AttributeInfo objects or plain dicts.
    allowed_attributes = [
        entry.name if isinstance(entry, AttributeInfo) else entry["name"]
        for entry in attribute_info
    ]
    output_parser = StructuredQueryOutputParser.from_components(
        allowed_comparators=allowed_comparators,
        allowed_operators=allowed_operators,
        allowed_attributes=allowed_attributes,
        fix_invalid=fix_invalid,
    )
    pipeline = prompt | llm | output_parser
    return pipeline