# dgeb/dgeb.py
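"""Top-level DGEB entry points: the `DGEB` evaluation runner plus helpers for
discovering registered models and tasks."""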

import logging
import os
import traceback
from itertools import chain
from typing import Any, List

from rich.console import Console

from .eval_utils import set_all_seeds
from .modality import Modality
from .models import BioSeqTransformer
from .tasks.tasks import Task

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class DGEB:
    """DGEB class to run the evaluation pipeline."""

    def __init__(self, tasks: List[type[Task]], seed: int = 42):
        self.tasks = tasks
        set_all_seeds(seed)

    def print_selected_tasks(self):
        """Print the selected tasks."""
        console = Console()
        console.rule("[bold]Selected Tasks\n", style="grey15")
        for task in self.tasks:
            prefix = " - "
            name = f"{task.metadata.display_name}"
            category = f", [italic grey39]{task.metadata.type}[/]"
            console.print(f"{prefix}{name}{category}")
        console.print("\n")
    def run(
        self,
        model: BioSeqTransformer,
        output_folder: str = "results",
    ):
        """Run the evaluation pipeline on the selected tasks.

        Args:
            model: Encoder model to be used for evaluation.
            output_folder: Folder where the results will be saved. Defaults to
                'results'. Results are written in the format:
                `{output_folder}/{model_name}/{task_name}.json`.

        Returns:
            A list of task result objects, one for each task evaluated.
        """
        # Run selected tasks
        self.print_selected_tasks()
        results = []
        for task in self.tasks:
            logger.info(
                f"\n\n********************** Evaluating {task.metadata.display_name} **********************"
            )
            try:
                result = task().run(model)
            except Exception as e:
                # Log the failure and continue with the remaining tasks.
                logger.error(e)
                logger.error(traceback.format_exc())
                logger.error(f"Error running task {task}")
                continue
            results.append(result)
            save_path = get_output_folder(model.hf_name, task, output_folder)
            with open(save_path, "w") as f_out:
                f_out.write(result.model_dump_json(indent=2))
        return results
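

# The helpers below form a lightweight registry: models and tasks are
# discovered at runtime via BioSeqTransformer.__subclasses__() and
# Task.__subclasses__(), so importing a module that defines a subclass is
# enough to make it available to get_model() and get_all_tasks().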


def get_model(model_name: str, **kwargs: Any) -> BioSeqTransformer:
    """Instantiate a model by name from the registered BioSeqTransformer subclasses."""
    all_names = get_all_model_names()
    for cls in BioSeqTransformer.__subclasses__():
        if model_name in cls.MODEL_NAMES:
            return cls(model_name, **kwargs)
    raise ValueError(f"Model {model_name} not found in {all_names}.")


def get_all_model_names() -> List[str]:
    """Return the names of all registered models."""
    return list(
        chain.from_iterable(
            cls.MODEL_NAMES for cls in BioSeqTransformer.__subclasses__()
        )
    )


def get_all_task_names() -> List[str]:
    """Return the IDs of all registered tasks."""
    return [task.metadata.id for task in get_all_tasks()]


def get_tasks_by_name(tasks: List[str]) -> List[type[Task]]:
    """Resolve a list of task IDs to task classes."""
    return [_get_task(task) for task in tasks]


def get_tasks_by_modality(modality: Modality) -> List[type[Task]]:
    """Return all registered tasks with the given modality."""
    return [task for task in get_all_tasks() if task.metadata.modality == modality]


def get_all_tasks() -> List[type[Task]]:
    """Return all registered Task subclasses."""
    return Task.__subclasses__()


def _get_task(task_name: str) -> type[Task]:
    """Look up a single task class by its metadata ID."""
    logger.info(f"Getting task {task_name}")
    for task in get_all_tasks():
        if task.metadata.id == task_name:
            return task
    raise ValueError(
        f"Task {task_name} not found, available tasks are: {[task.metadata.id for task in get_all_tasks()]}"
    )


def get_output_folder(
    model_hf_name: str, task: type[Task], output_folder: str, create: bool = True
) -> str:
    """Build the JSON output path for a task's results under the model's folder."""
    output_folder = os.path.join(output_folder, os.path.basename(model_hf_name))
    # Create the output folder if it does not exist.
    if create and not os.path.exists(output_folder):
        os.makedirs(output_folder)
    return os.path.join(
        output_folder,
        f"{task.metadata.id}.json",
    )
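

# Minimal usage sketch (illustrative, not part of the published module). It
# assumes at least one BioSeqTransformer subclass and one Task subclass have
# been imported so the registries are non-empty, and simply picks the first
# registered model and task; swap in real IDs as needed.
if __name__ == "__main__":
    model_names = get_all_model_names()
    task_names = get_all_task_names()
    logger.info(f"Registered models: {model_names}")
    logger.info(f"Registered tasks: {task_names}")

    if model_names and task_names:
        model = get_model(model_names[0])
        tasks = get_tasks_by_name(task_names[:1])
        results = DGEB(tasks=tasks).run(model, output_folder="results")
        logger.info(f"Finished {len(results)} task(s).")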