Spaces:
Runtime error
Runtime error
File size: 5,707 Bytes
129cd69 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 |
import datetime
import warnings
from typing import Any, Literal, Optional, Sequence, Union
from langchain_core.utils import check_package_version
from typing_extensions import TypedDict
# lark is an optional dependency. When it is missing, install inert stand-ins
# so that this module still imports; get_parser() reports the problem later.
try:
    check_package_version("lark", gte_version="1.1.5")
    from lark import Lark, Transformer, v_args
except ImportError:
    Lark = object  # type: ignore
    Transformer = object  # type: ignore

    def v_args(*args: Any, **kwargs: Any) -> Any:  # type: ignore
        """Dummy decorator for when lark is not installed."""

        def _discard(_decorated: Any) -> None:
            # Whatever gets decorated is replaced by None; get_parser() relies
            # on this by checking ``QueryTransformer is None``.
            return None

        return _discard
from langchain.chains.query_constructor.ir import (
Comparator,
Comparison,
FilterDirective,
Operation,
Operator,
)
# Lark grammar for the filter query language handed to ``Lark`` in get_parser().
#
# * A program is a single function call ``name(arg, ...)``; arguments are
#   either nested function calls or literal values.
# * Literal values are signed ints, signed floats, dates (YYYY-MM-DD, with
#   optional surrounding quotes), lists ``[...]``, single- or double-quoted
#   strings, and the booleans true/false in three capitalisations.
# * ``DATE.2`` gives the DATE terminal priority 2 -- presumably so that a quoted
#   date is lexed as DATE rather than as a plain string.
# * Each ``-> name`` alias (and each rule name) corresponds to a method of the
#   same name on QueryTransformer.
# * Whitespace between tokens is ignored.
GRAMMAR = r"""
?program: func_call
?expr: func_call
| value
func_call: CNAME "(" [args] ")"
?value: SIGNED_INT -> int
| SIGNED_FLOAT -> float
| DATE -> date
| list
| string
| ("false" | "False" | "FALSE") -> false
| ("true" | "True" | "TRUE") -> true
args: expr ("," expr)*
DATE.2: /["']?(\d{4}-[01]\d-[0-3]\d)["']?/
string: /'[^']*'/ | ESCAPED_STRING
list: "[" [args] "]"
%import common.CNAME
%import common.ESCAPED_STRING
%import common.SIGNED_FLOAT
%import common.SIGNED_INT
%import common.WS
%ignore WS
"""
class ISO8601Date(TypedDict):
    """A date in ISO 8601 format (YYYY-MM-DD).

    This is the dictionary shape returned by ``QueryTransformer.date``.
    """

    # The date text with any surrounding quote characters removed.
    date: str
    # Constant tag identifying the dictionary as a date value.
    type: Literal["date"]
@v_args(inline=True)
class QueryTransformer(Transformer):
    """Transforms a query string into an intermediate representation.

    Each public method is the callback for the grammar rule or alias of the
    same name in ``GRAMMAR``. Function calls become ``Comparison`` or
    ``Operation`` objects; literals become plain Python values.

    Args:
        allowed_comparators: If not None, the only comparators a query may use.
        allowed_operators: If not None, the only operators a query may use.
        allowed_attributes: If non-empty, the only attribute names that may
            appear as the first argument of a comparator.
    """

    def __init__(
        self,
        *args: Any,
        allowed_comparators: Optional[Sequence[Comparator]] = None,
        allowed_operators: Optional[Sequence[Operator]] = None,
        allowed_attributes: Optional[Sequence[str]] = None,
        **kwargs: Any,
    ):
        super().__init__(*args, **kwargs)
        self.allowed_comparators = allowed_comparators
        self.allowed_operators = allowed_operators
        self.allowed_attributes = allowed_attributes

    def program(self, *items: Any) -> tuple:
        """Return the top-level items of the program unchanged."""
        return items

    def func_call(
        self, func_name: Any, args: Optional[Sequence[Any]]
    ) -> FilterDirective:
        """Build the filter directive for ``func_name(args...)``.

        Args:
            func_name: Name token of the called function.
            args: The already-transformed arguments, or None when the call
                was written with empty parentheses.

        Returns:
            A ``Comparison`` for comparator names; for operator names an
            ``Operation``, except that a single-argument ``and``/``or``
            collapses to that argument.

        Raises:
            ValueError: If the name is unknown or disallowed, the call has
                too few arguments, or the attribute is not allowed.
        """
        func = self._match_func_name(str(func_name))
        # The optional ``[args]`` in the grammar can hand us None for an empty
        # call such as ``and()`` (the ``list`` rule guards the same case).
        call_args = () if args is None else args

        if isinstance(func, Comparator):
            if len(call_args) < 2:
                raise ValueError(
                    f"Comparator {func_name} requires an attribute and a value "
                    f"but received {len(call_args)} argument(s)."
                )
            attribute, value = call_args[0], call_args[1]
            if self.allowed_attributes and attribute not in self.allowed_attributes:
                raise ValueError(
                    f"Received invalid attributes {attribute}. Allowed attributes are "
                    f"{self.allowed_attributes}"
                )
            return Comparison(comparator=func, attribute=attribute, value=value)

        if not call_args:
            raise ValueError(
                f"Operator {func_name} requires at least one argument but "
                f"received none."
            )
        # and(x) / or(x) are no-ops around a single operand, so unwrap them.
        if len(call_args) == 1 and func in (Operator.AND, Operator.OR):
            return call_args[0]
        return Operation(operator=func, arguments=call_args)

    def _match_func_name(self, func_name: str) -> Union[Operator, Comparator]:
        """Resolve a function name to its ``Comparator`` or ``Operator``.

        Raises:
            ValueError: If the name is neither a comparator nor an operator,
                or if it is excluded by the configured allow-lists.
        """
        if func_name in set(Comparator):
            if (
                self.allowed_comparators is not None
                and func_name not in self.allowed_comparators
            ):
                raise ValueError(
                    f"Received disallowed comparator {func_name}. Allowed "
                    f"comparators are {self.allowed_comparators}"
                )
            return Comparator(func_name)
        if func_name in set(Operator):
            if (
                self.allowed_operators is not None
                and func_name not in self.allowed_operators
            ):
                raise ValueError(
                    f"Received disallowed operator {func_name}. Allowed operators"
                    f" are {self.allowed_operators}"
                )
            return Operator(func_name)
        raise ValueError(
            f"Received unrecognized function {func_name}. Valid functions are "
            f"{list(Operator) + list(Comparator)}"
        )

    def args(self, *items: Any) -> tuple:
        """Collect the arguments of a call or list into a tuple."""
        return items

    def false(self) -> bool:
        """Value of the ``false`` literal."""
        return False

    def true(self) -> bool:
        """Value of the ``true`` literal."""
        return True

    def list(self, item: Any) -> list:
        """Convert a list literal's arguments to a list (``[]`` when empty)."""
        if item is None:
            return []
        return list(item)

    def int(self, item: Any) -> int:
        """Convert an integer token to ``int``."""
        return int(item)

    def float(self, item: Any) -> float:
        """Convert a float token to ``float``."""
        return float(item)

    def date(self, item: Any) -> ISO8601Date:
        """Convert a date token to an ``ISO8601Date`` dictionary.

        Surrounding quotes are removed. A value that is not a real calendar
        date only triggers a warning; it is still returned as given.
        """
        item = str(item).strip("\"'")
        try:
            datetime.datetime.strptime(item, "%Y-%m-%d")
        except ValueError:
            warnings.warn(
                "Dates are expected to be provided in ISO 8601 date format "
                "(YYYY-MM-DD)."
            )
        return {"date": item, "type": "date"}

    def string(self, item: Any) -> str:
        """Return the text of a string literal without its delimiters.

        Only the single enclosing pair of matching quotes is removed, so
        quote characters that belong to the content are preserved. Backslash
        escapes inside the literal are left as written.
        """
        text = str(item)
        if len(text) >= 2 and text[0] == text[-1] and text[0] in "\"'":
            return text[1:-1]
        # Not wrapped in a matching pair: keep the previous lenient behaviour.
        return text.strip("\"'")
def get_parser(
    allowed_comparators: Optional[Sequence[Comparator]] = None,
    allowed_operators: Optional[Sequence[Operator]] = None,
    allowed_attributes: Optional[Sequence[str]] = None,
) -> Lark:
    """
    Build a Lark parser for the query language.

    The parser runs a QueryTransformer while parsing, so its result is the
    transformed intermediate representation rather than a raw parse tree.

    Args:
        allowed_comparators: Comparators a query may use; None means no limit.
        allowed_operators: Operators a query may use; None means no limit.
        allowed_attributes: Attribute names a comparison may reference; None
            means no limit.

    Returns:
        Lark parser for the query language.

    Raises:
        ImportError: If lark is not installed.
    """
    # The fallback v_args decorator turns QueryTransformer into None when lark
    # could not be imported.
    lark_missing = QueryTransformer is None
    if lark_missing:
        raise ImportError(
            "Cannot import lark, please install it with 'pip install lark'."
        )

    restrictions = {
        "allowed_comparators": allowed_comparators,
        "allowed_operators": allowed_operators,
        "allowed_attributes": allowed_attributes,
    }
    query_transformer = QueryTransformer(**restrictions)
    return Lark(
        GRAMMAR,
        parser="lalr",
        transformer=query_transformer,
        start="program",
    )
|