|
import logging |
|
import os |
|
import traceback |
|
from itertools import chain |
|
from typing import Any, List |
|
|
|
from rich.console import Console |
|
|
|
from .eval_utils import set_all_seeds |
|
from .modality import Modality |
|
from .models import BioSeqTransformer |
|
from .tasks.tasks import Task |
|
|
|
# NOTE(review): basicConfig is called at import time, which configures the root
# logger as a module side effect — consider moving this to the application
# entry point so library consumers keep control of logging configuration.
logging.basicConfig(level=logging.INFO)

# Module-level logger, named after this module per stdlib convention.
logger = logging.getLogger(__name__)
|
|
|
|
|
class DGEB:
    """Entry point for running the DGEB evaluation pipeline over a set of tasks."""

    def __init__(self, tasks: List[type[Task]], seed: int = 42):
        self.tasks = tasks
        set_all_seeds(seed)

    def print_selected_tasks(self):
        """Render the list of selected tasks to the terminal."""
        console = Console()
        console.rule("[bold]Selected Tasks\n", style="grey15")
        for task in self.tasks:
            # Single formatted line per task: " - <display name>, <type>"
            console.print(
                f" - {task.metadata.display_name}, [italic grey39]{task.metadata.type}[/]"
            )
        console.print("\n")

    def run(self, model, output_folder: str = "results"):
        """Evaluate `model` on every selected task and persist the results.

        Args:
            model: Model to be used for evaluation.
            output_folder: Root directory for saved results. Defaults to
                'results'; each task is written to
                `{output_folder}/{model_name}/{model_revision}/{task_name}.json`.

        Returns:
            A list of MTEBResults objects, one per task that ran successfully.
            Tasks that raise are logged and skipped rather than aborting the run.
        """
        self.print_selected_tasks()
        completed = []

        for task in self.tasks:
            banner = f"\n\n********************** Evaluating {task.metadata.display_name} **********************"
            logger.info(banner)

            try:
                result = task().run(model)
            except Exception as err:
                # Best-effort evaluation: log the failure and move on.
                logger.error(err)
                logger.error(traceback.format_exc())
                logger.error(f"Error running task {task}")
                continue

            completed.append(result)

            destination = get_output_folder(model.hf_name, task, output_folder)
            with open(destination, "w") as handle:
                handle.write(result.model_dump_json(indent=2))
        return completed
|
|
|
|
|
def get_model(model_name: str, **kwargs: Any) -> BioSeqTransformer:
    """Instantiate the model class registered under `model_name`.

    Scans the direct subclasses of BioSeqTransformer for one whose
    MODEL_NAMES contains `model_name` and constructs it.

    Args:
        model_name: HuggingFace-style model identifier to look up.
        **kwargs: Forwarded verbatim to the model class constructor.

    Returns:
        A constructed model instance (note: an instance, not a class —
        the previous `type[BioSeqTransformer]` annotation was incorrect).

    Raises:
        ValueError: If no registered model class claims `model_name`.
    """
    for cls in BioSeqTransformer.__subclasses__():
        if model_name in cls.MODEL_NAMES:
            return cls(model_name, **kwargs)
    # Only build the full name list on the failure path; it is not needed
    # for a successful lookup.
    raise ValueError(f"Model {model_name} not found in {get_all_model_names()}.")
|
|
|
|
|
def get_all_model_names() -> List[str]:
    """Collect the supported model names from every BioSeqTransformer subclass."""
    names: List[str] = []
    for model_cls in BioSeqTransformer.__subclasses__():
        names.extend(model_cls.MODEL_NAMES)
    return names
|
|
|
|
|
def get_all_task_names() -> List[str]:
    """Return the metadata id of every registered task."""
    names: List[str] = []
    for task_cls in get_all_tasks():
        names.append(task_cls.metadata.id)
    return names
|
|
|
|
|
def get_tasks_by_name(tasks: List[str]) -> List[type[Task]]:
    """Resolve each task id in `tasks` to its registered Task class."""
    resolved = []
    for task_name in tasks:
        resolved.append(_get_task(task_name))
    return resolved
|
|
|
|
|
def get_tasks_by_modality(modality: Modality) -> List[type[Task]]:
    """Filter the registered tasks down to those matching `modality`."""
    selected = []
    for candidate in get_all_tasks():
        if candidate.metadata.modality == modality:
            selected.append(candidate)
    return selected
|
|
|
|
|
def get_all_tasks() -> List[type[Task]]:
    """Return every direct subclass of Task, i.e. all registered tasks."""
    # __subclasses__ already yields a fresh list per call; copy defensively anyway.
    return list(Task.__subclasses__())
|
|
|
|
|
def _get_task(task_name: str) -> type[Task]:
    """Look up a single registered task class by its metadata id.

    Raises:
        ValueError: If no registered task has `task_name` as its id.
    """
    logger.info(f"Getting task {task_name}")
    found = next(
        (candidate for candidate in get_all_tasks() if candidate.metadata.id == task_name),
        None,
    )
    if found is not None:
        return found

    raise ValueError(
        f"Task {task_name} not found, available tasks are: {[task.metadata.id for task in get_all_tasks()]}"
    )
|
|
|
|
|
def get_output_folder(
    model_hf_name: str, task: type[Task], output_folder: str, create: bool = True
):
    """Build the path where a task's results JSON should be written.

    The layout is `{output_folder}/{model_basename}/{task_id}.json`, where
    `model_basename` is the final path component of the HF model name
    (e.g. "org/model" -> "model").

    Args:
        model_hf_name: HuggingFace model identifier.
        task: Task class whose `metadata.id` names the output file.
        output_folder: Root directory for results.
        create: When True (default), create the model subdirectory if missing.

    Returns:
        The full path to the task's JSON results file.
    """
    model_dir = os.path.join(output_folder, os.path.basename(model_hf_name))

    if create:
        # exist_ok avoids the check-then-create race the previous
        # `os.path.exists` + `os.makedirs` pair was subject to.
        os.makedirs(model_dir, exist_ok=True)
    return os.path.join(
        model_dir,
        f"{task.metadata.id}.json",
    )
|
|