"""Task abstract class for evaluation and results."""
import logging
from abc import ABC, abstractmethod
from enum import Enum
from importlib.metadata import version
from typing import Any, Dict, List, Literal, Optional

import datasets
from pydantic import BaseModel, model_validator
# HACK: prefer the canonical Modality defined in modality.py. When that module
# cannot be imported (the original note mentions getting the leaderboard
# working), fall back to a local copy of the enum.
try:
    from ..modality import Modality
except (ImportError, ValueError):
    # ImportError: module not found, or no known parent package.
    # ValueError: relative import beyond the top-level package (Python < 3.9).
    # Any other exception is a real bug in modality.py and should surface.
    #
    # This fallback SHOULD MATCH the code in modality.py exactly.
    # TODO: can we read the file and run that code instead of duplicating it?

    class Modality(Enum):
        """Data modality, either DNA or protein sequence."""

        PROTEIN = "protein"
        DNA = "dna"
# Set up root logging at INFO level when this module is imported.
logging.basicConfig(level=logging.INFO)

# Closed set of string identifiers allowed for a task's type.
TaskType = Literal[
    "classification", "pair_classification",
    "clustering", "eds",
    "bigene_mining", "retrieval",
]
class TaskMetric(BaseModel):
    """One named metric and its value."""

    # Metric identifier.
    id: str
    # Name to show for the metric.
    display_name: str
    # Optional free-text description; absent by default.
    description: Optional[str] = None
    # Metric value; 0.0 when not provided.
    value: float = 0.0
class LayerResult(BaseModel):
    """Metrics recorded for a single layer."""

    # Numeric index of the layer.
    layer_number: int
    # Label to show for the layer.
    layer_display_name: str
    # All metrics recorded for this layer.
    metrics: List[TaskMetric]
class DGEBModel(BaseModel):
    """Descriptive metadata about an evaluated model."""

    # Model name (presumably the Hugging Face identifier -- TODO confirm).
    hf_name: str
    # Number of layers in the model.
    num_layers: int
    # Number of parameters in the model.
    num_params: int
    # Embedding dimension.
    embed_dim: int
class Dataset(BaseModel):
    """A dataset reference identified by path and revision."""

    path: str
    revision: str

    def load(self) -> datasets.DatasetDict:
        """Load the referenced dataset with ``datasets.load_dataset``.

        Returns:
            The loaded ``datasets.DatasetDict``.

        Raises:
            ValueError: If the loaded object is not a ``datasets.DatasetDict``.
        """
        loaded = datasets.load_dataset(self.path, revision=self.revision)
        if isinstance(loaded, datasets.DatasetDict):
            return loaded
        raise ValueError(
            f"Dataset {self.path} is not a datasets.DatasetDict object."
        )
class TaskMetadata(BaseModel):
    """Static description of a task: identity, modality, type and datasets."""

    id: str
    display_name: str
    description: str
    modality: Modality
    type: TaskType
    # Datasets used by the task. Each entry holds the path and revision that
    # `Dataset.load()` passes to `datasets.load_dataset()`.
    datasets: List[Dataset]
    # Id of the task's primary metric; TaskResult.check_valid_primary_metric
    # looks for it among each layer's metrics.
    primary_metric_id: str
# tasks.py
class TaskResult(BaseModel):
    """Per-layer results of one model on one task."""

    # Installed version of the "dgeb" package that produced the result.
    dgeb_version: str
    task: "TaskMetadata"
    # TODO: Convert model to ModelMetadata
    model: DGEBModel
    results: List[LayerResult]

    @model_validator(mode="after")
    def check_valid_primary_metric(self):
        """Validate that every layer result contains the task's primary metric.

        Registered as an "after" model validator so it runs whenever a
        TaskResult is constructed or validated.

        Returns:
            The validated instance.

        Raises:
            ValueError: If any entry of ``results`` has no metric whose ``id``
                equals ``task.primary_metric_id`` (including an entry with an
                empty metrics list).
        """
        primary_id = self.task.primary_metric_id
        for result in self.results:
            if not any(metric.id == primary_id for metric in result.metrics):
                raise ValueError(
                    f"Primary metric {self.task.primary_metric_id} not found in results.metrics"
                )
        return self

    @staticmethod
    def from_dict(
        task_metadata: "TaskMetadata",
        layer_results: Dict[str, Any],
        model_metadata: DGEBModel,
    ) -> "TaskResult":
        """Build a TaskResult from a raw nested results mapping.

        Args:
            task_metadata: Metadata of the task that was run.
            layer_results: Mapping with a ``"layers"`` key whose value maps
                each layer (convertible with ``int()``) to a mapping of
                metric id -> metric value.
            model_metadata: Metadata of the evaluated model.

        Returns:
            A TaskResult stamped with the installed "dgeb" package version,
            with one LayerResult per entry of ``layer_results["layers"]``.
            Each metric id doubles as its display name.
        """
        return TaskResult(
            dgeb_version=version("dgeb"),
            task=task_metadata,
            model=model_metadata,
            results=[
                LayerResult(
                    layer_number=int(layer),
                    layer_display_name=str(layer),
                    metrics=[
                        TaskMetric(id=metric, display_name=metric, value=value)
                        for metric, value in metrics.items()
                    ],
                )
                for layer, metrics in layer_results["layers"].items()
            ],
        )
# move to model.py?
class Task(ABC):
    """Abstract base class for tasks; subclasses must implement ``run``."""

    # Metadata describing the task.
    metadata: "TaskMetadata"

    # using Any instead of "BioSeqTransformer" to avoid installing all deps in leaderboard
    @abstractmethod
    def run(
        self, model: Any, layers: Optional[List[int]] = None
    ) -> "TaskResult":
        """Run the task on ``model`` and return its result.

        Args:
            model: The model to evaluate (typed ``Any`` to avoid importing
                the model class here).
            layers: Optional list of layer indices; how ``None`` is handled
                is up to the implementing subclass.

        Returns:
            The TaskResult produced by the concrete task.
        """