from multiprocessing import Process, Manager

from tqdm import tqdm


def worker_main(work_queue, result_queue, process_func, config):
    """Worker loop: processes items from the work queue until it receives a None sentinel."""
    while True:
        item = work_queue.get()
        if item is None:
            # Sentinel received: tell the collector this worker is done, then exit.
            result_queue.put(None)
            break
        try:
            results, cost = process_func(config, item)
            result_queue.put((results, cost))
        except Exception as e:
            item_info = item.get('idx', item.get('id', 'unknown item'))
            print(f"Error processing item {item_info}: {e}")
            # Report the failure as an empty result rather than a bare None, so the
            # collector still advances the progress bar and does not mistake a
            # failed item for a worker-exit sentinel.
            result_queue.put((None, 0))
        finally:
            work_queue.task_done()
def run_parallel_evaluation(dataset, process_func, config, num_workers, description):
    """
    Runs parallel evaluation on the given dataset and returns the results.

    Args:
        dataset (list or datasets.Dataset): Data to evaluate.
        process_func (callable): Function to process each data item.
        config (dict): Configuration passed through to process_func.
        num_workers (int): Number of worker processes to use.
        description (str): Description to display on the tqdm progress bar.

    Returns:
        tuple: (list of evaluation results, total cost)
    """
    manager = Manager()
    work_queue = manager.Queue()
    result_queue = manager.Queue()

    # Add data to the work queue
    dataset_list = list(dataset) if not isinstance(dataset, list) else dataset
    for data in dataset_list:
        work_queue.put(data)

    # Add one termination sentinel per worker
    for _ in range(num_workers):
        work_queue.put(None)

    # Start the worker processes
    processes = []
    for _ in range(num_workers):
        p = Process(target=worker_main, args=(work_queue, result_queue, process_func, config))
        p.start()
        processes.append(p)

    # Show a progress bar and collect results until every worker has signalled completion
    process_results = []
    process_cost = 0
    completed_workers = 0
    with tqdm(total=len(dataset_list), desc=description) as pbar:
        while completed_workers < num_workers:
            result_item = result_queue.get()
            if result_item is None:
                completed_workers += 1
            else:
                results, cost = result_item
                if results is not None:
                    process_results.append(results)
                process_cost += cost if cost is not None else 0
                pbar.update(1)

    # Wait for all worker processes to exit
    for p in processes:
        p.join()

    # Defensively drain any results still left on the queue
    while not result_queue.empty():
        result_item = result_queue.get_nowait()
        if result_item is not None:
            results, cost = result_item
            if results is not None:
                process_results.append(results)
            process_cost += cost if cost is not None else 0

    return process_results, process_cost
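

# Usage sketch (illustrative, not part of the original module): `evaluate_item`
# is a hypothetical process_func, and the `config` keys below are made up for
# the demo. The function must live at module top level so worker processes can
# import it under the "spawn" start method, and the call must sit behind the
# `if __name__ == "__main__"` guard for the same reason.
def evaluate_item(config, item):
    """Hypothetical process_func: returns (result, cost) for one dataset item."""
    prediction = item['question'].upper()  # placeholder for real model inference
    return {'idx': item['idx'], 'prediction': prediction}, 0.0


if __name__ == "__main__":
    demo_dataset = [{'idx': i, 'question': f'question {i}'} for i in range(10)]
    results, total_cost = run_parallel_evaluation(
        demo_dataset,
        evaluate_item,
        config={'model': 'demo'},
        num_workers=2,
        description="Demo evaluation",
    )
    print(f"Collected {len(results)} results, total cost {total_cost}")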