Spaces:
Runtime error
Runtime error
| import datetime | |
| import warnings | |
| from typing import Any, Literal, Optional, Sequence, Union | |
| from langchain_core.utils import check_package_version | |
| from typing_extensions import TypedDict | |
| try: | |
| check_package_version("lark", gte_version="1.1.5") | |
| from lark import Lark, Transformer, v_args | |
| except ImportError: | |
| def v_args(*args: Any, **kwargs: Any) -> Any: # type: ignore | |
| """Dummy decorator for when lark is not installed.""" | |
| return lambda _: None | |
| Transformer = object # type: ignore | |
| Lark = object # type: ignore | |
| from langchain.chains.query_constructor.ir import ( | |
| Comparator, | |
| Comparison, | |
| FilterDirective, | |
| Operation, | |
| Operator, | |
| ) | |
| GRAMMAR = r""" | |
| ?program: func_call | |
| ?expr: func_call | |
| | value | |
| func_call: CNAME "(" [args] ")" | |
| ?value: SIGNED_INT -> int | |
| | SIGNED_FLOAT -> float | |
| | DATE -> date | |
| | list | |
| | string | |
| | ("false" | "False" | "FALSE") -> false | |
| | ("true" | "True" | "TRUE") -> true | |
| args: expr ("," expr)* | |
| DATE.2: /["']?(\d{4}-[01]\d-[0-3]\d)["']?/ | |
| string: /'[^']*'/ | ESCAPED_STRING | |
| list: "[" [args] "]" | |
| %import common.CNAME | |
| %import common.ESCAPED_STRING | |
| %import common.SIGNED_FLOAT | |
| %import common.SIGNED_INT | |
| %import common.WS | |
| %ignore WS | |
| """ | |
| class ISO8601Date(TypedDict): | |
| """A date in ISO 8601 format (YYYY-MM-DD).""" | |
| date: str | |
| type: Literal["date"] | |
| class QueryTransformer(Transformer): | |
| """Transforms a query string into an intermediate representation.""" | |
| def __init__( | |
| self, | |
| *args: Any, | |
| allowed_comparators: Optional[Sequence[Comparator]] = None, | |
| allowed_operators: Optional[Sequence[Operator]] = None, | |
| allowed_attributes: Optional[Sequence[str]] = None, | |
| **kwargs: Any, | |
| ): | |
| super().__init__(*args, **kwargs) | |
| self.allowed_comparators = allowed_comparators | |
| self.allowed_operators = allowed_operators | |
| self.allowed_attributes = allowed_attributes | |
| def program(self, *items: Any) -> tuple: | |
| return items | |
| def func_call(self, func_name: Any, args: list) -> FilterDirective: | |
| func = self._match_func_name(str(func_name)) | |
| if isinstance(func, Comparator): | |
| if self.allowed_attributes and args[0] not in self.allowed_attributes: | |
| raise ValueError( | |
| f"Received invalid attributes {args[0]}. Allowed attributes are " | |
| f"{self.allowed_attributes}" | |
| ) | |
| return Comparison(comparator=func, attribute=args[0], value=args[1]) | |
| elif len(args) == 1 and func in (Operator.AND, Operator.OR): | |
| return args[0] | |
| else: | |
| return Operation(operator=func, arguments=args) | |
| def _match_func_name(self, func_name: str) -> Union[Operator, Comparator]: | |
| if func_name in set(Comparator): | |
| if self.allowed_comparators is not None: | |
| if func_name not in self.allowed_comparators: | |
| raise ValueError( | |
| f"Received disallowed comparator {func_name}. Allowed " | |
| f"comparators are {self.allowed_comparators}" | |
| ) | |
| return Comparator(func_name) | |
| elif func_name in set(Operator): | |
| if self.allowed_operators is not None: | |
| if func_name not in self.allowed_operators: | |
| raise ValueError( | |
| f"Received disallowed operator {func_name}. Allowed operators" | |
| f" are {self.allowed_operators}" | |
| ) | |
| return Operator(func_name) | |
| else: | |
| raise ValueError( | |
| f"Received unrecognized function {func_name}. Valid functions are " | |
| f"{list(Operator) + list(Comparator)}" | |
| ) | |
| def args(self, *items: Any) -> tuple: | |
| return items | |
| def false(self) -> bool: | |
| return False | |
| def true(self) -> bool: | |
| return True | |
| def list(self, item: Any) -> list: | |
| if item is None: | |
| return [] | |
| return list(item) | |
| def int(self, item: Any) -> int: | |
| return int(item) | |
| def float(self, item: Any) -> float: | |
| return float(item) | |
| def date(self, item: Any) -> ISO8601Date: | |
| item = str(item).strip("\"'") | |
| try: | |
| datetime.datetime.strptime(item, "%Y-%m-%d") | |
| except ValueError: | |
| warnings.warn( | |
| "Dates are expected to be provided in ISO 8601 date format " | |
| "(YYYY-MM-DD)." | |
| ) | |
| return {"date": item, "type": "date"} | |
| def string(self, item: Any) -> str: | |
| # Remove escaped quotes | |
| return str(item).strip("\"'") | |
| def get_parser( | |
| allowed_comparators: Optional[Sequence[Comparator]] = None, | |
| allowed_operators: Optional[Sequence[Operator]] = None, | |
| allowed_attributes: Optional[Sequence[str]] = None, | |
| ) -> Lark: | |
| """ | |
| Returns a parser for the query language. | |
| Args: | |
| allowed_comparators: Optional[Sequence[Comparator]] | |
| allowed_operators: Optional[Sequence[Operator]] | |
| Returns: | |
| Lark parser for the query language. | |
| """ | |
| # QueryTransformer is None when Lark cannot be imported. | |
| if QueryTransformer is None: | |
| raise ImportError( | |
| "Cannot import lark, please install it with 'pip install lark'." | |
| ) | |
| transformer = QueryTransformer( | |
| allowed_comparators=allowed_comparators, | |
| allowed_operators=allowed_operators, | |
| allowed_attributes=allowed_attributes, | |
| ) | |
| return Lark(GRAMMAR, parser="lalr", transformer=transformer, start="program") | |