File size: 3,033 Bytes
129cd69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
from typing import Dict, Tuple, Union

from langchain.chains.query_constructor.ir import (
    Comparator,
    Comparison,
    Operation,
    Operator,
    StructuredQuery,
    Visitor,
)


class ElasticsearchTranslator(Visitor):
    """Convert structured-query IR nodes into Elasticsearch filter clauses.

    Logical operations are rendered as ``bool`` queries whose occurrence type
    is ``must`` (AND), ``should`` (OR) or ``must_not`` (NOT). Comparisons are
    rendered as ``range``, ``match`` or ``term`` queries aimed at sub-fields of
    the ``metadata`` object.
    """

    # Comparators this translator knows how to render; the inherited
    # ``_validate_func`` check is run against this list before translation.
    allowed_comparators = [
        Comparator.EQ,
        Comparator.GT,
        Comparator.GTE,
        Comparator.LT,
        Comparator.LTE,
        Comparator.CONTAIN,
        Comparator.LIKE,
    ]

    # Logical operators this translator knows how to render.
    allowed_operators = [Operator.AND, Operator.OR, Operator.NOT]

    def _format_func(self, func: Union[Operator, Comparator]) -> str:
        """Validate ``func`` and return its Elasticsearch query keyword.

        Args:
            func: The IR operator or comparator to translate.

        Returns:
            The bool-occurrence name (for operators) or the query / range
            keyword (for comparators).
        """
        self._validate_func(func)
        es_keyword_for = {
            # Bool-query occurrence types for logical operators.
            Operator.AND: "must",
            Operator.OR: "should",
            Operator.NOT: "must_not",
            # Exact equality becomes a ``term`` query.
            Comparator.EQ: "term",
            # Ordering comparators become bounds inside a ``range`` query.
            Comparator.GT: "gt",
            Comparator.GTE: "gte",
            Comparator.LT: "lt",
            Comparator.LTE: "lte",
            # Text containment / similarity both use a ``match`` query.
            Comparator.CONTAIN: "match",
            Comparator.LIKE: "match",
        }
        return es_keyword_for[func]

    def visit_operation(self, operation: Operation) -> Dict:
        """Render a logical operation as an Elasticsearch ``bool`` query.

        Args:
            operation: IR node whose arguments are translated recursively.

        Returns:
            ``{"bool": {<occurrence>: [<child clauses>]}}``.
        """
        # Children are translated before the operator is validated, keeping
        # the same evaluation order as a single-expression implementation.
        child_clauses = [child.accept(self) for child in operation.arguments]
        occurrence = self._format_func(operation.operator)
        return {"bool": {occurrence: child_clauses}}

    def visit_comparison(self, comparison: Comparison) -> Dict:
        """Render a single comparison as a range, match or term query.

        Args:
            comparison: IR node holding the attribute, comparator and value.

        Returns:
            A one-key dict holding the Elasticsearch query clause.
        """
        comparator = comparison.comparator
        value = comparison.value
        # ElasticsearchStore keeps document metadata under a ``metadata``
        # object, so every filter has to address a sub-field of it.
        target = f"metadata.{comparison.attribute}"
        keyword = self._format_func(comparator)

        ordering_comparators = (
            Comparator.GT,
            Comparator.GTE,
            Comparator.LT,
            Comparator.LTE,
        )
        if comparator in ordering_comparators:
            return {"range": {target: {keyword: value}}}

        if comparator in (Comparator.CONTAIN, Comparator.LIKE):
            match_options = {"query": value}
            if comparator == Comparator.LIKE:
                # LIKE tolerates small spelling differences.
                match_options["fuzziness"] = "AUTO"
            return {keyword: {target: match_options}}

        # Exact string matches go against the ``.keyword`` sub-field; this
        # assumes the index maps string metadata with a keyword sub-field.
        if isinstance(value, str):
            target += ".keyword"
        return {keyword: {target: value}}

    def visit_structured_query(
        self, structured_query: StructuredQuery
    ) -> Tuple[str, dict]:
        """Split a structured query into its text and search kwargs.

        Args:
            structured_query: The parsed query, optionally carrying a filter.

        Returns:
            ``(query_text, kwargs)``. ``kwargs`` contains a one-element
            ``"filter"`` list when a filter is present, else it is empty.
        """
        search_kwargs: dict = {}
        if structured_query.filter is not None:
            search_kwargs["filter"] = [structured_query.filter.accept(self)]
        return structured_query.query, search_kwargs