File size: 3,982 Bytes
129cd69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
from typing import Any, Iterator, List, Optional, Sequence, Tuple, cast

from langchain_core.stores import BaseStore


class UpstashRedisStore(BaseStore[str, str]):
    """BaseStore implementation using Upstash Redis as the underlying store.

    Keys and values are strings. When ``namespace`` is set, every key is
    stored in Redis as ``"<namespace>/<key>"``; the namespace is stripped
    again before keys are handed back by ``yield_keys``.
    """

    def __init__(
        self,
        *,
        client: Any = None,
        url: Optional[str] = None,
        token: Optional[str] = None,
        ttl: Optional[int] = None,
        namespace: Optional[str] = None,
    ) -> None:
        """Initialize the UpstashRedisStore with HTTP API.

        Must provide either an Upstash Redis client or a url (plus token),
        but not both a client and a url.

        Args:
            client: An Upstash Redis instance
            url: UPSTASH_REDIS_REST_URL
            token: UPSTASH_REDIS_REST_TOKEN
            ttl: time to expire keys in seconds if provided,
                 if None keys will never expire
            namespace: if provided, all keys will be prefixed with this namespace

        Raises:
            ImportError: If the ``upstash_redis`` package is not installed.
            ValueError: If both a client and a url are given, or if no client
                is given and url/token are missing.
            TypeError: If ``client`` is not an ``upstash_redis.Redis`` or
                ``ttl`` is neither an int nor None.
        """
        try:
            from upstash_redis import Redis
        except ImportError as e:
            raise ImportError(
                "UpstashRedisStore requires the upstash_redis library to be installed. "
                "pip install upstash_redis"
            ) from e

        if client and url:
            raise ValueError(
                "Either an Upstash Redis client or a url must be provided, not both."
            )

        if client:
            if not isinstance(client, Redis):
                raise TypeError(
                    f"Expected Upstash Redis client, got {type(client).__name__}."
                )
            _client = client
        else:
            if not url or not token:
                raise ValueError(
                    "Either an Upstash Redis client or url and token must be provided."
                )
            _client = Redis(url=url, token=token)

        self.client = _client

        if not isinstance(ttl, int) and ttl is not None:
            raise TypeError(f"Expected int or None, got {type(ttl)} instead.")

        self.ttl = ttl
        self.namespace = namespace

    def _get_prefixed_key(self, key: str) -> str:
        """Get the key with the namespace prefix.

        Args:
            key (str): The original key.

        Returns:
            str: ``"<namespace>/<key>"`` when a namespace is set, otherwise
            ``key`` unchanged.
        """
        delimiter = "/"
        if self.namespace:
            return f"{self.namespace}{delimiter}{key}"
        return key

    def mget(self, keys: Sequence[str]) -> List[Optional[str]]:
        """Get the values associated with the given keys.

        Args:
            keys: Keys to look up, without the namespace prefix.

        Returns:
            One entry per input key, in the same order; ``None`` for keys
            that are not present (Redis MGET semantics). An empty ``keys``
            yields an empty list without contacting the server.
        """
        prefixed_keys = [self._get_prefixed_key(key) for key in keys]
        if not prefixed_keys:
            # MGET requires at least one key; Redis rejects a bare "MGET".
            return []
        return cast(
            List[Optional[str]],
            self.client.mget(*prefixed_keys),
        )

    def mset(self, key_value_pairs: Sequence[Tuple[str, str]]) -> None:
        """Set the given key-value pairs.

        Each pair is written with its own SET call (not atomic across pairs),
        using ``self.ttl`` as the expiry in seconds when it is not None.

        Args:
            key_value_pairs: ``(key, value)`` pairs; keys are namespaced
                before being written.
        """
        for key, value in key_value_pairs:
            self.client.set(self._get_prefixed_key(key), value, ex=self.ttl)

    def mdelete(self, keys: Sequence[str]) -> None:
        """Delete the given keys.

        Args:
            keys: Keys to delete, without the namespace prefix. An empty
                ``keys`` is a no-op and does not contact the server.
        """
        _keys = [self._get_prefixed_key(key) for key in keys]
        if not _keys:
            # DEL requires at least one key; Redis rejects a bare "DEL".
            return
        self.client.delete(*_keys)

    def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
        """Yield keys in the store.

        Args:
            prefix: A Redis glob pattern (e.g. ``"key*"``) matched against the
                key without its namespace. It is passed to SCAN ``MATCH``
                as-is, so a value without a wildcard only matches that exact
                key. ``None`` or ``""`` matches every key.

        Yields:
            Matching keys with the namespace prefix removed.
        """
        pattern = self._get_prefixed_key(prefix or "*")
        # Number of leading characters _get_prefixed_key adds ("<namespace>/"),
        # or 0 without a namespace; derived from the helper so the two can't
        # drift apart.
        strip = len(self._get_prefixed_key(""))

        cursor = 0
        while True:
            # SCAN is cursor based: keep iterating until the server hands
            # back cursor 0, which marks the end of the full iteration.
            cursor, keys = self.client.scan(cursor, match=pattern)
            for key in keys:
                yield key[strip:]
            if cursor == 0:
                break