|
"""Task abstract class for evaluation and results.""" |
|
|
|
import logging |
|
from abc import ABC, abstractmethod |
|
from enum import Enum |
|
from importlib.metadata import version |
|
from typing import Any, List, Literal, Optional |
|
|
|
import datasets |
|
from pydantic import BaseModel, model_validator |
|
|
|
|
|
try:
    from ..modality import Modality
except (ImportError, ValueError):
    # The package-relative import is unavailable (for example, this module
    # is being loaded outside of its parent package), so fall back to a
    # local stand-in enum. Only import-resolution failures are caught so
    # that genuine errors raised inside the modality module still surface.
    # ValueError is included because Python versions before 3.9 raise it
    # (rather than ImportError) for a relative import that goes beyond the
    # top-level package.

    class Modality(Enum):
        """Data modality, either DNA or protein sequence."""

        PROTEIN = "protein"
        DNA = "dna"
|
|
|
|
|
# Configure the root logger at INFO level as an import-time side effect.
# logging.basicConfig does nothing if the root logger already has handlers.
logging.basicConfig(level=logging.INFO)

# Closed set of task-type identifiers accepted by TaskMetadata.type.
TaskType = Literal[
    "classification",
    "pair_classification",
    "clustering",
    "eds",
    "bigene_mining",
    "retrieval",
]
|
|
|
|
|
class TaskMetric(BaseModel):
    """A single named metric and its numeric value."""

    # Metric identifier; TaskResult's validator compares this against the
    # task's primary_metric_id.
    id: str
    # Label for the metric (presumably shown in reports/UIs -- confirm).
    display_name: str
    # Optional free-text explanation of the metric; omitted by default.
    description: Optional[str] = None
    # Metric value; defaults to 0.0 when not supplied.
    value: float = 0.0
|
|
|
|
|
class LayerResult(BaseModel):
    """The metrics recorded for one model layer."""

    # Integer identifier of the layer.
    layer_number: int
    # String label for the layer.
    layer_display_name: str
    # All metrics reported for this layer.
    metrics: List[TaskMetric]
|
|
|
|
|
class DGEBModel(BaseModel):
    """Descriptive metadata about an evaluated model."""

    # Model name (presumably its Hugging Face hub identifier -- confirm).
    hf_name: str
    # Number of layers in the model.
    num_layers: int
    # Number of parameters in the model.
    num_params: int
    # Embedding dimensionality (assumed to be the output width -- confirm).
    embed_dim: int
|
|
|
|
|
class Dataset(BaseModel):
    """Reference to a dataset, identified by a path and a pinned revision."""

    # Passed as the first argument to ``datasets.load_dataset``.
    path: str
    # Passed as the ``revision`` keyword to ``datasets.load_dataset``.
    revision: str

    def load(self) -> datasets.DatasetDict:
        """Load the referenced dataset and return it as a ``DatasetDict``.

        Returns:
            The object returned by ``datasets.load_dataset`` for this path
            and revision.

        Raises:
            ValueError: If the loaded object is not a
                ``datasets.DatasetDict``.
        """
        loaded = datasets.load_dataset(self.path, revision=self.revision)
        if isinstance(loaded, datasets.DatasetDict):
            return loaded
        raise ValueError(
            f"Dataset {self.path} is not a datasets.DatasetDict object."
        )
|
|
|
|
|
class TaskMetadata(BaseModel):
    """Static description of an evaluation task."""

    # Task identifier.
    id: str
    # Label for the task (presumably shown in reports/UIs -- confirm).
    display_name: str
    # Free-text description of the task.
    description: str
    # Sequence modality the task operates on.
    modality: Modality
    # One of the task kinds enumerated by TaskType.
    type: TaskType

    # Dataset references associated with the task.
    datasets: List[Dataset]
    # Id of the metric that every layer result must report; enforced by
    # TaskResult's validator.
    primary_metric_id: str
|
|
|
|
|
|
|
class TaskResult(BaseModel):
    """Results of evaluating one model on one task, reported per layer."""

    # Version string of the installed "dgeb" distribution (see from_dict).
    dgeb_version: str
    # Metadata of the task that was evaluated.
    task: "TaskMetadata"

    # Metadata of the model that was evaluated.
    model: DGEBModel
    # One entry per evaluated layer.
    results: List[LayerResult]

    @model_validator(mode="after")
    def check_valid_primary_metric(self):
        """Ensure every layer result reports the task's primary metric.

        Returns:
            The validated instance, unchanged.

        Raises:
            ValueError: If any entry in ``results`` has no metric whose id
                equals ``task.primary_metric_id`` (this includes a layer
                with an empty metrics list).
        """
        primary_id = self.task.primary_metric_id
        for layer_result in self.results:
            if not any(metric.id == primary_id for metric in layer_result.metrics):
                raise ValueError(
                    f"Primary metric {primary_id} not found in results.metrics"
                )
        return self

    @staticmethod
    def from_dict(
        task_metadata: "TaskMetadata",
        layer_results: dict,
        model_metadata: DGEBModel,
    ) -> "TaskResult":
        """Build a ``TaskResult`` from a raw mapping of per-layer metrics.

        Args:
            task_metadata: Metadata of the evaluated task.
            layer_results: Mapping with a ``"layers"`` key whose value maps
                each layer key (anything ``int()`` accepts) to a mapping of
                metric id -> metric value. It is a plain mapping, not a
                ``LayerResult`` instance.
            model_metadata: Metadata of the evaluated model.

        Returns:
            A ``TaskResult`` stamped with the installed ``dgeb`` version,
            holding one ``LayerResult`` per entry under ``"layers"``. Each
            metric's display name is set to its id.

        Raises:
            KeyError: If ``layer_results`` has no ``"layers"`` key.
        """
        return TaskResult(
            dgeb_version=version("dgeb"),
            task=task_metadata,
            model=model_metadata,
            results=[
                LayerResult(
                    layer_number=int(layer),
                    layer_display_name=str(layer),
                    metrics=[
                        TaskMetric(id=metric, display_name=metric, value=value)
                        for metric, value in metrics.items()
                    ],
                )
                for layer, metrics in layer_results["layers"].items()
            ],
        )
|
|
|
|
|
|
|
class Task(ABC):
    """Abstract base class for evaluation tasks.

    Concrete subclasses must implement ``run``.
    """

    # Metadata describing the task.
    metadata: TaskMetadata

    @abstractmethod
    def run(self, model: Any, layers: Optional[List[int]] = None) -> TaskResult:
        """Run the task against ``model`` and return its results.

        Args:
            model: The model to evaluate; its required interface is defined
                by the concrete task.
            layers: Optional list of layer numbers (presumably the layers
                to evaluate; the meaning of ``None`` is left to the
                implementation -- confirm).

        Returns:
            The ``TaskResult`` produced by the evaluation.
        """
|
|