File size: 2,619 Bytes
129cd69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
from datetime import datetime
from typing import Dict, Tuple, Union

from langchain.chains.query_constructor.ir import (
    Comparator,
    Comparison,
    Operation,
    Operator,
    StructuredQuery,
    Visitor,
)


class WeaviateTranslator(Visitor):
    """Turn query-constructor IR nodes into Weaviate-style filter dictionaries.

    Each ``visit_*`` method handles one IR node type (``Operation``,
    ``Comparison``, ``StructuredQuery``) and returns plain Python
    dictionaries built around the operator names in the Weaviate GraphQL
    filter reference.
    """

    allowed_operators = [Operator.AND, Operator.OR]
    """Subset of allowed logical operators."""

    allowed_comparators = [
        Comparator.EQ,
        Comparator.NE,
        Comparator.GTE,
        Comparator.LTE,
        Comparator.LT,
        Comparator.GT,
    ]
    """Subset of allowed comparators."""

    def _format_func(self, func: Union[Operator, Comparator]) -> str:
        """Return the Weaviate operator name for an IR operator or comparator.

        Args:
            func: The logical operator or comparator to translate.

        Returns:
            The matching Weaviate operator name, e.g. ``"And"`` or ``"Equal"``.

        Raises:
            KeyError: If ``func`` has no entry in the name tables below.
        """
        # NOTE(review): _validate_func is not defined in this class; it is
        # presumably inherited from Visitor and rejects members outside the
        # allowed_* lists -- confirm against the base class.
        self._validate_func(func)
        # https://weaviate.io/developers/weaviate/api/graphql/filters
        logical_names = {
            Operator.AND: "And",
            Operator.OR: "Or",
        }
        comparison_names = {
            Comparator.EQ: "Equal",
            Comparator.NE: "NotEqual",
            Comparator.GTE: "GreaterThanEqual",
            Comparator.LTE: "LessThanEqual",
            Comparator.LT: "LessThan",
            Comparator.GT: "GreaterThan",
        }
        return {**logical_names, **comparison_names}[func]

    def visit_operation(self, operation: Operation) -> Dict:
        """Translate a logical operation and, recursively, its arguments.

        Args:
            operation: The IR node combining several sub-filters.

        Returns:
            A dict with the operator name under ``"operator"`` and the
            translated arguments, in order, under ``"operands"``.
        """
        # Translate the children first, then the operator itself.
        operands = []
        for argument in operation.arguments:
            operands.append(argument.accept(self))
        return {
            "operator": self._format_func(operation.operator),
            "operands": operands,
        }

    def visit_comparison(self, comparison: Comparison) -> Dict:
        """Translate one attribute comparison into a filter clause.

        The key that carries the value depends on the value's Python type:
        ``valueBoolean``, ``valueNumber``, ``valueInt``, ``valueDate`` (for a
        ``{"type": "date", "date": "YYYY-MM-DD"}`` dict) or, for anything
        else, ``valueText``.

        Args:
            comparison: The IR node holding attribute, comparator and value.

        Returns:
            A dict with ``"path"``, ``"operator"`` and one typed value key.
        """
        raw = comparison.value
        # bool must be tested before int: bool is a subclass of int.
        if isinstance(raw, bool):
            value_key, payload = "valueBoolean", raw
        elif isinstance(raw, float):
            value_key, payload = "valueNumber", raw
        elif isinstance(raw, int):
            value_key, payload = "valueInt", raw
        elif isinstance(raw, dict) and raw.get("type") == "date":
            # ISO 8601 timestamp, formatted as RFC3339; the parsed date has no
            # time component, so the rendered time is midnight.
            parsed = datetime.strptime(raw["date"], "%Y-%m-%d")
            value_key, payload = "valueDate", parsed.strftime("%Y-%m-%dT%H:%M:%SZ")
        else:
            value_key, payload = "valueText", raw
        return {
            "path": [comparison.attribute],
            "operator": self._format_func(comparison.comparator),
            value_key: payload,
        }

    def visit_structured_query(
        self, structured_query: StructuredQuery
    ) -> Tuple[str, dict]:
        """Split a structured query into its query string and search kwargs.

        Args:
            structured_query: The IR node holding the query and optional filter.

        Returns:
            A ``(query, kwargs)`` tuple. ``kwargs`` is empty when there is no
            filter; otherwise it holds the translated filter under
            ``"where_filter"``.
        """
        condition = structured_query.filter
        kwargs = {} if condition is None else {"where_filter": condition.accept(self)}
        return structured_query.query, kwargs