"""Public Python API for polyreactivity prediction."""
from __future__ import annotations
from pathlib import Path
from typing import Iterable
import copy
import joblib
import pandas as pd
from sklearn.preprocessing import StandardScaler
from .config import Config, load_config
from .features.pipeline import FeaturePipeline, FeaturePipelineState, build_feature_pipeline
def predict_batch( # noqa: ANN003
    records: Iterable[dict],
    *,
    config: Config | str | Path | None = None,
    backend: str | None = None,
    plm_model: str | None = None,
    weights: str | Path | None = None,
    heavy_only: bool = True,
    batch_size: int = 8,
    device: str | None = None,
    cache_dir: str | None = None,
) -> pd.DataFrame:
    """Predict polyreactivity scores for a batch of sequences.

    Args:
        records: One mapping per antibody. Recognised keys: ``id`` (or
            ``sequence_id`` as fallback), ``heavy_seq``/``heavy`` and
            ``light_seq``/``light`` chain sequences.
        config: A ``Config`` instance, a path to a YAML config, or ``None``
            to use the config embedded in the artifact (falling back to
            ``configs/default.yaml``).
        backend: Optional feature-backend type override.
        plm_model: Optional PLM model-name override.
        weights: Path to a joblib artifact with ``model`` and
            ``feature_state`` entries. Required unless ``records`` is empty.
        heavy_only: When true, light-chain sequences are blanked before
            featurization.
        batch_size: Forwarded to the feature pipeline.
        device: Optional device override.
        cache_dir: Optional embedding-cache directory override.

    Returns:
        DataFrame with columns ``id``, ``score`` (predicted probability)
        and ``pred`` (binary call at a 0.5 threshold).

    Raises:
        ValueError: If weights are missing/malformed, the records yield an
            empty frame, or no heavy-chain sequence is provided.
    """
    records_list = list(records)
    if not records_list:
        # Nothing to score: return an empty frame with the output schema.
        return pd.DataFrame(columns=["id", "score", "pred"])
    artifact = _load_artifact(weights)
    config = _resolve_predict_config(config, artifact)
    # Apply per-call overrides on top of the resolved config.
    if backend:
        config.feature_backend.type = backend
    if plm_model:
        config.feature_backend.plm_model_name = plm_model
    if device:
        config.device = device
    if cache_dir:
        config.feature_backend.cache_dir = cache_dir
    pipeline = _restore_pipeline(config, artifact)
    trained_model = artifact["model"]
    frame = pd.DataFrame(records_list)
    if frame.empty:
        raise ValueError("Prediction requires at least one record.")
    if "id" not in frame.columns:
        # Fall back to `sequence_id`, then to positional indices. The
        # fallback must be a Series: `frame.get` hands back the raw default
        # otherwise, and a bare `range` has no `.astype` (crash in the
        # previous implementation).
        id_source = frame.get("sequence_id")
        if id_source is None:
            id_source = pd.Series(range(len(frame)), index=frame.index)
        frame["id"] = id_source.astype(str)
    _ensure_chain_column(frame, "heavy_seq", "heavy")
    _ensure_chain_column(frame, "light_seq", "light")
    if heavy_only:
        frame["light_seq"] = ""
    if frame["heavy_seq"].str.len().eq(0).all():
        raise ValueError("No heavy chain sequences provided for prediction.")
    features = pipeline.transform(frame, heavy_only=heavy_only, batch_size=batch_size)
    # NOTE(review): assumes the trained model's predict_proba returns a 1-D
    # array of positive-class probabilities — confirm against training code.
    scores = trained_model.predict_proba(features)
    preds = (scores >= 0.5).astype(int)
    return pd.DataFrame(
        {
            "id": frame["id"].astype(str),
            "score": scores,
            "pred": preds,
        }
    )
def _resolve_predict_config(config: Config | str | Path | None, artifact: dict) -> Config:
    """Resolve the effective Config: explicit arg > artifact copy > default file."""
    if config is None:
        artifact_config = artifact.get("config")
        if isinstance(artifact_config, Config):
            return copy.deepcopy(artifact_config)
        return load_config("configs/default.yaml")
    if isinstance(config, (str, Path)):
        return load_config(config)
    # Deep-copy so per-call overrides never mutate the caller's Config.
    return copy.deepcopy(config)
def _ensure_chain_column(frame: pd.DataFrame, column: str, alias: str) -> None:
    """Guarantee ``column`` exists as a string Series, sourcing from ``alias`` or ''."""
    if column not in frame.columns:
        source = frame.get(alias)
        if source is None:
            source = pd.Series([""] * len(frame), index=frame.index)
        frame[column] = source
    frame[column] = frame[column].fillna("").astype(str)
def _load_artifact(weights: str | Path | None) -> dict:
if weights is None:
msg = "Prediction requires a path to model weights"
raise ValueError(msg)
artifact = joblib.load(weights)
if not isinstance(artifact, dict):
msg = "Model artifact must be a dictionary"
raise ValueError(msg)
return artifact
def _restore_pipeline(config: Config, artifact: dict) -> FeaturePipeline:
    """Rebuild the feature pipeline from ``config`` and restore its saved state.

    Raises ValueError when the artifact carries no usable feature state.
    """
    pipeline = build_feature_pipeline(config)
    saved_state = artifact.get("feature_state")
    if not isinstance(saved_state, FeaturePipelineState):
        raise ValueError("Model artifact is missing feature pipeline state")
    pipeline.load_state(saved_state)
    # NOTE(review): presumably PLM-style backends require a scaler instance
    # downstream even when the saved state lacked one — confirm.
    if pipeline.backend.type in {"plm", "concat"} and pipeline._plm_scaler is None:
        pipeline._plm_scaler = StandardScaler()
    return pipeline