yangdx committed on
Commit
38074de
·
1 Parent(s): a940648

Fix race condition for health_check and ensure_workers

Browse files
Files changed (1) hide show
  1. lightrag/utils.py +44 -57
lightrag/utils.py CHANGED
@@ -289,9 +289,10 @@ def priority_limit_async_func_call(max_size: int, max_queue_size: int = 1000):
289
  def final_decro(func):
290
  queue = asyncio.PriorityQueue(maxsize=max_queue_size)
291
  tasks = set()
292
- lock = asyncio.Lock()
293
  counter = 0
294
  shutdown_event = asyncio.Event()
 
295
  worker_health_check_task = None
296
 
297
  # Track active future objects for cleanup
@@ -352,76 +353,62 @@ def priority_limit_async_func_call(max_size: int, max_queue_size: int = 1000):
352
  while not shutdown_event.is_set():
353
  await asyncio.sleep(5) # Check every 5 seconds
354
 
355
- async with lock:
356
- # Directly remove completed tasks from the tasks set
357
- tasks.difference_update({t for t in tasks if t.done()})
 
 
358
 
359
- # Create new workers if active tasks less than max_size for better performance
360
- active_tasks_count = len(tasks)
361
- workers_needed = max_size - active_tasks_count
362
- if workers_needed > 0:
363
- logger.info(
364
- f"limit_async: Creating {workers_needed} new workers"
365
- )
366
- for _ in range(workers_needed):
367
- task = asyncio.create_task(worker())
368
- tasks.add(task)
369
- task.add_done_callback(tasks.discard)
 
 
 
 
370
  except Exception as e:
371
  logger.error(f"limit_async: Error in health check: {str(e)}")
372
  finally:
373
  logger.warning("limit_async: Health check task exiting")
374
 
375
- # Ensure worker tasks are started
376
  async def ensure_workers():
377
- """Ensure worker tasks and health check are started"""
378
- nonlocal tasks, worker_health_check_task
379
 
380
- # Use timeout lock to prevent deadlock
381
- try:
382
- lock_acquired = False
383
- try:
384
- # Try to acquire the lock, wait up to 5 seconds
385
- lock_acquired = await asyncio.wait_for(lock.acquire(), timeout=5.0)
386
- except asyncio.TimeoutError:
387
- logger.error(
388
- "limit_async: Timeout acquiring lock in ensure_workers"
389
- )
390
- # Even if acquiring the lock times out, continue trying to create workers
391
 
392
- try:
393
- # Start the health check task (if not already started)
394
- if (
395
- worker_health_check_task is None
396
- or worker_health_check_task.done()
397
- ):
398
- worker_health_check_task = asyncio.create_task(health_check())
399
 
400
- # Directly remove completed tasks from the tasks set
401
- tasks.difference_update({t for t in tasks if t.done()})
 
402
 
403
- # Calculate the number of active tasks
404
- active_tasks_count = len(tasks)
405
 
406
- # If active tasks count is less than max_size, create new workers
407
- workers_needed = max_size - active_tasks_count
408
- if workers_needed > 0:
409
- for _ in range(workers_needed):
410
- task = asyncio.create_task(worker())
411
- tasks.add(task)
412
- task.add_done_callback(tasks.discard)
413
- finally:
414
- # Ensure the lock is released
415
- if lock_acquired:
416
- lock.release()
417
- except Exception as e:
418
- logger.error(f"limit_async: Error in ensure_workers: {str(e)}")
419
- # Even if an exception occurs, try to create at least one worker
420
- if not any(not t.done() for t in tasks):
421
  task = asyncio.create_task(worker())
422
  tasks.add(task)
423
  task.add_done_callback(tasks.discard)
424
 
 
 
 
 
 
 
425
  async def shutdown():
426
  """Gracefully shut down all workers and the queue"""
427
  logger.info("limit_async: Shutting down priority queue workers")
@@ -480,7 +467,7 @@ def priority_limit_async_func_call(max_size: int, max_queue_size: int = 1000):
480
  QueueFullError: If the queue is full and waiting times out
481
  Any exception raised by the decorated function
482
  """
483
- # Ensure workers are started
484
  await ensure_workers()
485
 
486
  # Create a future for the result
@@ -488,7 +475,7 @@ def priority_limit_async_func_call(max_size: int, max_queue_size: int = 1000):
488
  active_futures.add(future)
489
 
490
  nonlocal counter
491
- async with lock:
492
  current_count = counter
493
  counter += 1
494
 
 
289
  def final_decro(func):
290
  queue = asyncio.PriorityQueue(maxsize=max_queue_size)
291
  tasks = set()
292
+ initialization_lock = asyncio.Lock()
293
  counter = 0
294
  shutdown_event = asyncio.Event()
295
+ initialized = False # Global initialization flag
296
  worker_health_check_task = None
297
 
298
  # Track active future objects for cleanup
 
353
  while not shutdown_event.is_set():
354
  await asyncio.sleep(5) # Check every 5 seconds
355
 
356
+ # No longer acquire lock, directly operate on task set
357
+ # Use a copy of the task set to avoid concurrent modification
358
+ current_tasks = set(tasks)
359
+ done_tasks = {t for t in current_tasks if t.done()}
360
+ tasks.difference_update(done_tasks)
361
 
362
+ # Calculate active tasks count
363
+ active_tasks_count = len(tasks)
364
+ workers_needed = max_size - active_tasks_count
365
+
366
+ if workers_needed > 0:
367
+ logger.info(
368
+ f"limit_async: Creating {workers_needed} new workers"
369
+ )
370
+ new_tasks = set()
371
+ for _ in range(workers_needed):
372
+ task = asyncio.create_task(worker())
373
+ new_tasks.add(task)
374
+ task.add_done_callback(tasks.discard)
375
+ # Update task set in one operation
376
+ tasks.update(new_tasks)
377
  except Exception as e:
378
  logger.error(f"limit_async: Error in health check: {str(e)}")
379
  finally:
380
  logger.warning("limit_async: Health check task exiting")
381
 
 
382
  async def ensure_workers():
383
+ """Ensure worker threads and health check system are available
 
384
 
385
+ This function checks if the worker system is already initialized.
386
+ If not, it performs a one-time initialization of all worker threads
387
+ and starts the health check system.
388
+ """
389
+ nonlocal initialized, worker_health_check_task, tasks
 
 
 
 
 
 
390
 
391
+ if initialized:
392
+ return
 
 
 
 
 
393
 
394
+ async with initialization_lock:
395
+ if initialized:
396
+ return
397
 
398
+ logger.info("limit_async: Initializing worker system")
 
399
 
400
+ # Create initial worker tasks
401
+ for _ in range(max_size):
 
 
 
 
 
 
 
 
 
 
 
 
 
402
  task = asyncio.create_task(worker())
403
  tasks.add(task)
404
  task.add_done_callback(tasks.discard)
405
 
406
+ # Start health check
407
+ worker_health_check_task = asyncio.create_task(health_check())
408
+
409
+ initialized = True
410
+ logger.info("limit_async: Worker system initialized")
411
+
412
  async def shutdown():
413
  """Gracefully shut down all workers and the queue"""
414
  logger.info("limit_async: Shutting down priority queue workers")
 
467
  QueueFullError: If the queue is full and waiting times out
468
  Any exception raised by the decorated function
469
  """
470
+ # Ensure worker system is initialized
471
  await ensure_workers()
472
 
473
  # Create a future for the result
 
475
  active_futures.add(future)
476
 
477
  nonlocal counter
478
+ async with initialization_lock:
479
  current_count = counter
480
  counter += 1
481