"""Task abstract class for evaluation and results."""
import logging
from abc import ABC, abstractmethod
from enum import Enum
from importlib.metadata import version
from typing import Any, Dict, List, Literal, Optional
import datasets
from pydantic import BaseModel, model_validator
# HACK: import Modality from the package when available; the leaderboard runs
# this file standalone, so fall back to an inline definition on failure.
try:
    from ..modality import Modality
except ImportError:
    # Fallback so the leaderboard works without the package installed.
    # This definition MUST match the one in modality.py exactly.
    class Modality(Enum):
        """Data modality, either DNA or protein sequence."""
        PROTEIN = "protein"
        DNA = "dna"
logging.basicConfig(level=logging.INFO)
TaskType = Literal[
"classification",
"pair_classification",
"clustering",
"eds",
"bigene_mining",
"retrieval",
]
class TaskMetric(BaseModel):
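    """A single named metric reported by a task."""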
id: str
display_name: str
description: Optional[str] = None
value: float = 0.0
class LayerResult(BaseModel):
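    """Metrics computed from the embeddings of a single model layer."""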
layer_number: int
layer_display_name: str
metrics: List[TaskMetric]
class DGEBModel(BaseModel):
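    """Basic metadata about the evaluated model."""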
hf_name: str
num_layers: int
num_params: int
embed_dim: int
class Dataset(BaseModel):
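    """A Hugging Face dataset pinned to a specific revision."""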
path: str
revision: str
def load(self) -> datasets.DatasetDict:
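        """Load the dataset from the Hugging Face Hub at the pinned revision."""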
ds = datasets.load_dataset(self.path, revision=self.revision)
if not isinstance(ds, datasets.DatasetDict):
raise ValueError(
f"Dataset {self.path} is not a datasets.DatasetDict object."
)
return ds
class TaskMetadata(BaseModel):
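    """Static description of a task: modality, task type, datasets, and primary metric."""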
id: str
display_name: str
description: str
modality: Modality
type: TaskType
# List of datasets used by the task.
# Each dataset is a dict of all arguments to pass to `datasets.load_dataset()`.
datasets: List[Dataset]
primary_metric_id: str
class TaskResult(BaseModel):
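    """The result of running a task on a model, with one LayerResult per evaluated layer."""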
dgeb_version: str
task: "TaskMetadata"
# TODO: Convert model to ModelMetadata
model: DGEBModel
results: List[LayerResult]
@model_validator(mode="after")
def check_valid_primary_metric(self):
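        """Validate that every layer's metrics include the task's primary metric."""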
for result in self.results:
if all(
metric.id != self.task.primary_metric_id for metric in result.metrics
):
raise ValueError(
f"Primary metric {self.task.primary_metric_id} not found in results.metrics"
)
return self
    @staticmethod
    def from_dict(
        task_metadata: "TaskMetadata",
        layer_results: Dict[str, Any],
        model_metadata: DGEBModel,
    ) -> "TaskResult":
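        """Build a TaskResult from a raw mapping of the form {"layers": {layer: {metric_id: value}}}."""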
return TaskResult(
dgeb_version=version("dgeb"),
task=task_metadata,
model=model_metadata,
            results=[
                LayerResult(
                    layer_number=int(layer),
                    layer_display_name=str(layer),
                    metrics=[
                        TaskMetric(id=metric, display_name=metric, value=value)
                        for metric, value in metrics.items()
                    ],
                )
                for layer, metrics in layer_results["layers"].items()
            ],
)
# TODO: consider moving Task to model.py.
class Task(ABC):
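    """Abstract base class for all DGEB evaluation tasks."""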
metadata: TaskMetadata
    # `model` is typed as Any instead of "BioSeqTransformer" to avoid pulling
    # the full set of model dependencies into the leaderboard.
@abstractmethod
def run(self, model: Any, layers: Optional[List[int]] = None) -> TaskResult:
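        """Run the task on `model`, optionally restricting evaluation to the given layer indices."""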
pass
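
# Illustrative sketch (not part of this module): a concrete task subclasses
# `Task`, fills in `metadata`, and builds a `TaskResult` in `run`. All names,
# ids, dataset paths, and numbers below are hypothetical placeholders.
#
#     class ExampleClassificationTask(Task):
#         metadata = TaskMetadata(
#             id="example_classification",
#             display_name="Example Classification",
#             description="Toy classification task over protein sequences.",
#             modality=Modality.PROTEIN,
#             type="classification",
#             datasets=[Dataset(path="org/example_dataset", revision="main")],
#             primary_metric_id="accuracy",
#         )
#
#         def run(self, model, layers=None):
#             # ... embed sequences with `model`, compute metrics per layer ...
#             layer_results = {"layers": {32: {"accuracy": 0.9}}}
#             model_meta = DGEBModel(
#                 hf_name="org/example-model", num_layers=33,
#                 num_params=650_000_000, embed_dim=1280,
#             )
#             return TaskResult.from_dict(self.metadata, layer_results, model_meta)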