Spaces:
Running
Running
import asyncio | |
from dataclasses import dataclass | |
from tqdm.asyncio import tqdm as tqdm_async | |
from graphgen.utils import create_event_loop | |
from graphgen.models.text.text_pair import TextPair | |
class BaseEvaluator:
    """
    Base class for evaluators that assign a numeric score to each TextPair.

    Subclasses implement `evaluate_single`. This base class runs it
    concurrently over a batch, bounded by `max_concurrent`. It also exposes
    aggregate statistics (average, min, max) over the resulting scores.
    """

    # Upper bound on how many `evaluate_single` coroutines run at once.
    max_concurrent: int = 100
    # Per-pair scores from the most recent `get_average_score` call, in the
    # same order as the pairs passed to it; None until that method has run.
    results: "list[float] | None" = None
    # Private bookkeeping: snapshot of the pairs `results` was computed for,
    # so `get_min_max_score` only reuses `results` for that same batch.
    _scored_pairs = None

    def evaluate(self, pairs: list[TextPair]) -> list[float]:
        """
        Score every pair synchronously.

        Drives `async_evaluate` to completion on the loop returned by
        `create_event_loop()`.

        :param pairs: the text pairs to score.
        :return: one score per pair, in the same order as `pairs`.
        """
        return create_event_loop().run_until_complete(self.async_evaluate(pairs))

    async def async_evaluate(self, pairs: list[TextPair]) -> list[float]:
        """
        Score every pair concurrently while displaying a tqdm progress bar.

        At most `self.max_concurrent` calls to `evaluate_single` are in
        flight at any time.

        :param pairs: the text pairs to score.
        :return: one score per pair, in the same order as `pairs`.
        """
        semaphore = asyncio.Semaphore(self.max_concurrent)

        async def score_at(index: int, pair):
            # Acquire the semaphore so that only max_concurrent evaluations overlap.
            async with semaphore:
                return index, await self.evaluate_single(pair)

        # asyncio.as_completed yields in *completion* order. Each coroutine
        # therefore reports its input index, and the score is written back to
        # that slot so the returned list stays aligned with `pairs`.
        scores = [None] * len(pairs)
        for next_done in tqdm_async(
            asyncio.as_completed(
                [score_at(index, pair) for index, pair in enumerate(pairs)]
            ),
            total=len(pairs),
        ):
            index, score = await next_done
            scores[index] = score
        return scores

    async def evaluate_single(self, pair: TextPair) -> float:
        """
        Score a single text pair. Subclasses must override this.

        :raises NotImplementedError: always, in the base class.
        """
        raise NotImplementedError()

    def get_average_score(self, pairs: list[TextPair]) -> float:
        """
        Evaluate a batch of pairs and return the mean score.

        The per-pair scores are stored on `self.results`. A later
        `get_min_max_score` call with the same pairs can then reuse them
        instead of evaluating again.

        :param pairs: the text pairs to score.
        :return: the arithmetic mean of the per-pair scores.
        :raises ValueError: if `pairs` is empty (the mean is undefined).
        """
        if not pairs:
            raise ValueError("Cannot compute an average score of an empty batch.")
        self.results = self.evaluate(pairs)
        self._scored_pairs = tuple(pairs)
        return sum(self.results) / len(self.results)

    def get_min_max_score(self, pairs: list[TextPair]) -> tuple[float, float]:
        """
        Return the (min, max) score over a batch of pairs.

        Reuses `self.results` only when it was produced by
        `get_average_score` for these same pair objects. Otherwise the batch
        is evaluated first, which also refreshes `self.results`.

        :param pairs: the text pairs to score.
        :return: a `(min_score, max_score)` tuple.
        :raises ValueError: if `pairs` is empty.
        """
        if not self._has_results_for(pairs):
            self.get_average_score(pairs)
        return min(self.results), max(self.results)

    def _has_results_for(self, pairs) -> bool:
        """Return True if `self.results` was computed for exactly these pair objects."""
        cached = self._scored_pairs
        if self.results is None or cached is None or len(cached) != len(pairs):
            return False
        # Identity, not equality: the cache is valid only for the very same
        # pair objects that were scored.
        return all(old is new for old, new in zip(cached, pairs))