czhang17 committed
Commit 561b7e9 · 1 Parent(s): 81e036c

Fix 'TOO MANY OPEN FILES' problem while using Redis vector DB


Enhance RedisKVStorage: implement connection pooling and error handling. Refactor async methods to use context managers for Redis operations, improving resource management and error logging. Add batch processing for key operations to improve performance.

Files changed (2)
  1. lightrag/kg/redis_impl.py +148 -89
  2. lightrag/operate.py +12 -3
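
Reviewer note: below is a minimal usage sketch (not part of this commit) of the pattern the change enables, driving the storage through the new async context manager support so close() releases the pooled connections. The constructor fields shown (namespace, global_config, embedding_func) are assumptions about the BaseKVStorage dataclass and may differ from the actual definition.

# Usage sketch, not part of the commit. Constructor fields are assumptions.
import asyncio
from lightrag.kg.redis_impl import RedisKVStorage

async def main() -> None:
    # __aenter__/__aexit__ added by this commit guarantee close() runs, so the
    # connection pool is disconnected and file descriptors stay bounded.
    async with RedisKVStorage(
        namespace="text_chunks",   # assumed dataclass field
        global_config={},          # assumed dataclass field
        embedding_func=None,       # assumed dataclass field
    ) as kv:
        await kv.upsert({"doc-1": {"content": "hello"}})
        print(await kv.get_by_id("doc-1"))

asyncio.run(main())
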
lightrag/kg/redis_impl.py CHANGED
@@ -3,12 +3,14 @@ from typing import Any, final
 from dataclasses import dataclass
 import pipmaster as pm
 import configparser
+from contextlib import asynccontextmanager
 
 if not pm.is_installed("redis"):
     pm.install("redis")
 
 # aioredis is a deprecated library, replaced with redis
-from redis.asyncio import Redis
+from redis.asyncio import Redis, ConnectionPool
+from redis.exceptions import RedisError, ConnectionError
 from lightrag.utils import logger, compute_mdhash_id
 from lightrag.base import BaseKVStorage
 import json
@@ -17,6 +19,11 @@ import json
 config = configparser.ConfigParser()
 config.read("config.ini", "utf-8")
 
+# Constants for Redis connection pool
+MAX_CONNECTIONS = 50
+SOCKET_TIMEOUT = 5.0
+SOCKET_CONNECT_TIMEOUT = 3.0
+
 
 @final
 @dataclass
@@ -25,125 +32,177 @@ class RedisKVStorage(BaseKVStorage):
         redis_url = os.environ.get(
             "REDIS_URI", config.get("redis", "uri", fallback="redis://localhost:6379")
         )
-        self._redis = Redis.from_url(redis_url, decode_responses=True)
-        logger.info(f"Use Redis as KV {self.namespace}")
+        # Create a connection pool with limits
+        self._pool = ConnectionPool.from_url(
+            redis_url,
+            max_connections=MAX_CONNECTIONS,
+            decode_responses=True,
+            socket_timeout=SOCKET_TIMEOUT,
+            socket_connect_timeout=SOCKET_CONNECT_TIMEOUT
+        )
+        self._redis = Redis(connection_pool=self._pool)
+        logger.info(f"Initialized Redis connection pool for {self.namespace} with max {MAX_CONNECTIONS} connections")
+
+    @asynccontextmanager
+    async def _get_redis_connection(self):
+        """Safe context manager for Redis operations."""
+        try:
+            yield self._redis
+        except ConnectionError as e:
+            logger.error(f"Redis connection error in {self.namespace}: {e}")
+            raise
+        except RedisError as e:
+            logger.error(f"Redis operation error in {self.namespace}: {e}")
+            raise
+        except Exception as e:
+            logger.error(f"Unexpected error in Redis operation for {self.namespace}: {e}")
+            raise
+
+    async def close(self):
+        """Close the Redis connection pool to prevent resource leaks."""
+        if hasattr(self, '_redis') and self._redis:
+            await self._redis.close()
+            await self._pool.disconnect()
+            logger.debug(f"Closed Redis connection pool for {self.namespace}")
+
+    async def __aenter__(self):
+        """Support for async context manager."""
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        """Ensure Redis resources are cleaned up when exiting context."""
+        await self.close()
 
     async def get_by_id(self, id: str) -> dict[str, Any] | None:
-        data = await self._redis.get(f"{self.namespace}:{id}")
-        return json.loads(data) if data else None
+        async with self._get_redis_connection() as redis:
+            try:
+                data = await redis.get(f"{self.namespace}:{id}")
+                return json.loads(data) if data else None
+            except json.JSONDecodeError as e:
+                logger.error(f"JSON decode error for id {id}: {e}")
+                return None
 
     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
-        pipe = self._redis.pipeline()
-        for id in ids:
-            pipe.get(f"{self.namespace}:{id}")
-        results = await pipe.execute()
-        return [json.loads(result) if result else None for result in results]
+        async with self._get_redis_connection() as redis:
+            try:
+                pipe = redis.pipeline()
+                for id in ids:
+                    pipe.get(f"{self.namespace}:{id}")
+                results = await pipe.execute()
+                return [json.loads(result) if result else None for result in results]
+            except json.JSONDecodeError as e:
+                logger.error(f"JSON decode error in batch get: {e}")
+                return [None] * len(ids)
 
     async def filter_keys(self, keys: set[str]) -> set[str]:
-        pipe = self._redis.pipeline()
-        for key in keys:
-            pipe.exists(f"{self.namespace}:{key}")
-        results = await pipe.execute()
+        async with self._get_redis_connection() as redis:
+            pipe = redis.pipeline()
+            for key in keys:
+                pipe.exists(f"{self.namespace}:{key}")
+            results = await pipe.execute()
 
-        existing_ids = {keys[i] for i, exists in enumerate(results) if exists}
-        return set(keys) - existing_ids
+            existing_ids = {k for k, exists in zip(keys, results) if exists}  # zip keeps key/result pairing; sets are not indexable
+            return set(keys) - existing_ids
 
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
-        logger.info(f"Inserting {len(data)} to {self.namespace}")
         if not data:
             return
-        pipe = self._redis.pipeline()
-
-        for k, v in data.items():
-            pipe.set(f"{self.namespace}:{k}", json.dumps(v))
-        await pipe.execute()
-
-        for k in data:
-            data[k]["_id"] = k
-
-    async def index_done_callback(self) -> None:
-        # Redis handles persistence automatically
-        pass
+
+        logger.info(f"Inserting {len(data)} items to {self.namespace}")
+        async with self._get_redis_connection() as redis:
+            try:
+                pipe = redis.pipeline()
+                for k, v in data.items():
+                    pipe.set(f"{self.namespace}:{k}", json.dumps(v))
+                await pipe.execute()
+
+                for k in data:
+                    data[k]["_id"] = k
+            except (TypeError, ValueError) as e:  # json.dumps raises these; json has no JSONEncodeError
+                logger.error(f"JSON encode error during upsert: {e}")
+                raise
 
     async def delete(self, ids: list[str]) -> None:
-        """Delete entries with specified IDs
-
-        Args:
-            ids: List of entry IDs to be deleted
-        """
+        """Delete entries with specified IDs"""
         if not ids:
             return
 
-        pipe = self._redis.pipeline()
-        for id in ids:
-            pipe.delete(f"{self.namespace}:{id}")
+        async with self._get_redis_connection() as redis:
+            pipe = redis.pipeline()
+            for id in ids:
+                pipe.delete(f"{self.namespace}:{id}")
 
-        results = await pipe.execute()
-        deleted_count = sum(results)
-        logger.info(
-            f"Deleted {deleted_count} of {len(ids)} entries from {self.namespace}"
-        )
+            results = await pipe.execute()
+            deleted_count = sum(results)
+            logger.info(
+                f"Deleted {deleted_count} of {len(ids)} entries from {self.namespace}"
+            )
 
     async def delete_entity(self, entity_name: str) -> None:
-        """Delete an entity by name
-
-        Args:
-            entity_name: Name of the entity to delete
-        """
-
+        """Delete an entity by name"""
        try:
             entity_id = compute_mdhash_id(entity_name, prefix="ent-")
             logger.debug(
                 f"Attempting to delete entity {entity_name} with ID {entity_id}"
             )
 
-            # Delete the entity
-            result = await self._redis.delete(f"{self.namespace}:{entity_id}")
+            async with self._get_redis_connection() as redis:
+                result = await redis.delete(f"{self.namespace}:{entity_id}")
 
-            if result:
-                logger.debug(f"Successfully deleted entity {entity_name}")
-            else:
-                logger.debug(f"Entity {entity_name} not found in storage")
+                if result:
+                    logger.debug(f"Successfully deleted entity {entity_name}")
+                else:
+                    logger.debug(f"Entity {entity_name} not found in storage")
         except Exception as e:
             logger.error(f"Error deleting entity {entity_name}: {e}")
 
     async def delete_entity_relation(self, entity_name: str) -> None:
-        """Delete all relations associated with an entity
-
-        Args:
-            entity_name: Name of the entity whose relations should be deleted
-        """
+        """Delete all relations associated with an entity"""
         try:
-            # Get all keys in this namespace
-            cursor = 0
-            relation_keys = []
-            pattern = f"{self.namespace}:*"
-
-            while True:
-                cursor, keys = await self._redis.scan(cursor, match=pattern)
-
-                # For each key, get the value and check if it's related to entity_name
-                for key in keys:
-                    value = await self._redis.get(key)
-                    if value:
-                        data = json.loads(value)
-                        # Check if this is a relation involving the entity
-                        if (
-                            data.get("src_id") == entity_name
-                            or data.get("tgt_id") == entity_name
-                        ):
-                            relation_keys.append(key)
-
-                # Exit loop when cursor returns to 0
-                if cursor == 0:
-                    break
-
-            # Delete the relation keys
-            if relation_keys:
-                deleted = await self._redis.delete(*relation_keys)
-                logger.debug(f"Deleted {deleted} relations for {entity_name}")
-            else:
-                logger.debug(f"No relations found for entity {entity_name}")
+            async with self._get_redis_connection() as redis:
+                cursor = 0
+                relation_keys = []
+                pattern = f"{self.namespace}:*"
+
+                while True:
+                    cursor, keys = await redis.scan(cursor, match=pattern)
+
+                    # Process keys in batches
+                    pipe = redis.pipeline()
+                    for key in keys:
+                        pipe.get(key)
+                    values = await pipe.execute()
+
+                    for key, value in zip(keys, values):
+                        if value:
+                            try:
+                                data = json.loads(value)
+                                if (
+                                    data.get("src_id") == entity_name
+                                    or data.get("tgt_id") == entity_name
+                                ):
+                                    relation_keys.append(key)
+                            except json.JSONDecodeError:
+                                logger.warning(f"Invalid JSON in key {key}")
+                                continue
+
+                    if cursor == 0:
+                        break
+
+                # Delete relations in batches
+                if relation_keys:
+                    # Delete in chunks to avoid too many arguments
+                    chunk_size = 1000
+                    for i in range(0, len(relation_keys), chunk_size):
+                        chunk = relation_keys[i:i + chunk_size]
+                        deleted = await redis.delete(*chunk)
+                        logger.debug(f"Deleted {deleted} relations for {entity_name} (batch {i//chunk_size + 1})")
+                else:
+                    logger.debug(f"No relations found for entity {entity_name}")
 
         except Exception as e:
             logger.error(f"Error deleting relations for {entity_name}: {e}")
+
+    async def index_done_callback(self) -> None:
+        # Redis handles persistence automatically
+        pass
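
A quick way to confirm the pool actually bounds descriptor usage (a Linux-only diagnostic sketch, not part of the commit): count the entries under /proc/self/fd while issuing many concurrent reads. With the bounded ConnectionPool the count should plateau near MAX_CONNECTIONS instead of climbing until the OS reports "Too many open files". The kv object stands for an already-constructed RedisKVStorage.

# Diagnostic sketch (Linux-only, not part of the commit).
import asyncio
import os

def open_fd_count() -> int:
    # Each entry under /proc/self/fd is one open file descriptor of this process.
    return len(os.listdir("/proc/self/fd"))

async def hammer(kv, n: int = 500) -> None:
    before = open_fd_count()
    # 500 concurrent GETs share at most MAX_CONNECTIONS pooled sockets.
    await asyncio.gather(*(kv.get_by_id(f"doc-{i}") for i in range(n)))
    print(f"open fds before={before} after={open_fd_count()}")
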
lightrag/operate.py CHANGED
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import asyncio
+import traceback
 import json
 import re
 import os
@@ -994,6 +995,7 @@ async def mix_kg_vector_query(
 
         except Exception as e:
             logger.error(f"Error in get_kg_context: {str(e)}")
+            traceback.print_exc()
             return None
 
     async def get_vector_context():
@@ -1382,9 +1384,16 @@ async def _find_most_related_text_unit_from_entities(
                 all_text_units_lookup[c_id] = index
                 tasks.append((c_id, index, this_edges))
 
-    results = await asyncio.gather(
-        *[text_chunks_db.get_by_id(c_id) for c_id, _, _ in tasks]
-    )
+    # Process in batches of 25 tasks at a time to avoid overwhelming resources
+    batch_size = 25
+    results = []
+
+    for i in range(0, len(tasks), batch_size):
+        batch_tasks = tasks[i:i + batch_size]
+        batch_results = await asyncio.gather(
+            *[text_chunks_db.get_by_id(c_id) for c_id, _, _ in batch_tasks]
+        )
+        results.extend(batch_results)
 
     for (c_id, index, this_edges), data in zip(tasks, results):
         all_text_units_lookup[c_id] = {
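
The operate.py hunk hard-codes the batch size inline; the same pattern written as a reusable helper might look like the sketch below (gather_in_batches is a hypothetical name, not part of lightrag).

# Sketch of the batching pattern from the hunk above as a standalone helper.
# gather_in_batches is a hypothetical name, not part of lightrag.
import asyncio
from typing import Awaitable, Callable, Iterable, List, TypeVar

T = TypeVar("T")
R = TypeVar("R")

async def gather_in_batches(
    items: Iterable[T],
    make_coro: Callable[[T], Awaitable[R]],
    batch_size: int = 25,
) -> List[R]:
    """Run make_coro over items with at most batch_size coroutines in flight per batch."""
    pending = list(items)
    results: List[R] = []
    for i in range(0, len(pending), batch_size):
        batch = pending[i:i + batch_size]
        results.extend(await asyncio.gather(*(make_coro(x) for x in batch)))
    return results

# Equivalent call for the hunk above:
#   results = await gather_in_batches(tasks, lambda t: text_chunks_db.get_by_id(t[0]))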