"""
Pair classification tasks.

These tasks evaluate distances between functionally relevant gene pairs.
For instance, distance thresholds distinguish between co-transcribed and
non-co-transcribed gene pairs.
"""

import logging
from collections import defaultdict

from dgeb.evaluators import PairClassificationEvaluator
from dgeb.modality import Modality
from dgeb.models import BioSeqTransformer
from dgeb.tasks import Dataset, Task, TaskMetadata, TaskResult

from ..eval_utils import paired_dataset

logger = logging.getLogger(__name__)


def run_pair_classification_task(
    model: BioSeqTransformer, metadata: TaskMetadata
) -> TaskResult:
    """Run a pair classification task with ``PairClassificationEvaluator``.

    The "train" split of the task's single dataset is loaded, its "Sequence"
    column is encoded with ``model``, and one evaluation is performed per
    entry of ``model.layers``.

    Args:
        model: Model used to encode the sequences; its ``layers`` and
            ``metadata`` attributes are also read.
        metadata: Task description; must list exactly one dataset.

    Returns:
        A ``TaskResult`` built from the per-layer evaluator outputs, which
        are stored under the "layers" key and keyed by layer.

    Raises:
        ValueError: If ``metadata.datasets`` does not contain exactly one
            dataset.
    """
    if len(metadata.datasets) != 1:
        raise ValueError("Pair classification tasks require 1 dataset.")

    train_split = metadata.datasets[0].load()["train"]
    all_embeddings = model.encode(train_split["Sequence"])

    # defaultdict lets the "layers" section be created on first assignment.
    results_by_section = defaultdict(dict)
    for layer_index, layer_name in enumerate(model.layers):
        # The second axis of the embeddings is indexed by layer position.
        # NOTE(review): presumably shaped (sequence, layer, dim) -- confirm
        # against BioSeqTransformer.encode.
        first_embeds, second_embeds, pair_labels = paired_dataset(
            train_split["Label"], all_embeddings[:, layer_index]
        )
        layer_scores = PairClassificationEvaluator(
            first_embeds, second_embeds, pair_labels
        )()
        results_by_section["layers"][layer_name] = layer_scores
        logger.info(
            f"Layer: {layer_name}, {metadata.display_name} classification results: {layer_scores}"
        )

    return TaskResult.from_dict(metadata, results_by_section, model.metadata)


class EcoliOperon(Task):
    """E.coli K-12 operonic pair classification task."""

    # The dataset is pinned to a fixed revision.
    metadata = TaskMetadata(
        id="ecoli_operonic_pair",
        display_name="E.coli Operonic Pair",
        description="Evaluate on E.coli K-12 operonic pair classification task.",
        type="pair_classification",
        modality=Modality.PROTEIN,
        datasets=[
            Dataset(
                path="tattabio/ecoli_operonic_pair",
                revision="a62c01143a842696fc8200b91c1acb825e8cb891",
            )
        ],
        primary_metric_id="top_ap",
    )

    def run(self, model: BioSeqTransformer) -> TaskResult:
        """Evaluate ``model`` on this task and return its ``TaskResult``."""
        return run_pair_classification_task(model=model, metadata=self.metadata)


class CyanoOperonPair(Task):
    """Cyano operonic pair classification task."""

    # The dataset is pinned to a fixed revision.
    metadata = TaskMetadata(
        id="cyano_operonic_pair",
        display_name="Cyano Operonic Pair",
        description="Evaluate on Cyano operonic pair classification task.",
        type="pair_classification",
        modality=Modality.PROTEIN,
        datasets=[
            Dataset(
                path="tattabio/cyano_operonic_pair",
                revision="eeb4cb71ec2a4ff688af9de7c0662123577d32ec",
            )
        ],
        primary_metric_id="top_ap",
    )

    def run(self, model: BioSeqTransformer) -> TaskResult:
        """Evaluate ``model`` on this task and return its ``TaskResult``."""
        return run_pair_classification_task(model=model, metadata=self.metadata)


class VibrioOperonPair(Task):
    """Vibrio operonic pair classification task."""

    # The dataset is pinned to a fixed revision.
    metadata = TaskMetadata(
        id="vibrio_operonic_pair",
        display_name="Vibrio Operonic Pair",
        description="Evaluate on Vibrio operonic pair classification task.",
        type="pair_classification",
        modality=Modality.PROTEIN,
        datasets=[
            Dataset(
                path="tattabio/vibrio_operonic_pair",
                revision="24781b12b45bf81a079a6164ef0d2124948c1878",
            )
        ],
        primary_metric_id="top_ap",
    )

    def run(self, model: BioSeqTransformer) -> TaskResult:
        """Evaluate ``model`` on this task and return its ``TaskResult``."""
        return run_pair_classification_task(model=model, metadata=self.metadata)