File size: 2,297 Bytes
5f58699
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
"""Loader for the Boughter et al. 2020 dataset."""

from __future__ import annotations

from typing import Iterable

import numpy as np
import pandas as pd

from .utils import LOGGER, standardize_frame

# Column-name aliases handed to ``standardize_frame`` through its
# ``column_aliases`` argument when loading this dataset.
# NOTE(review): the keys look like the canonical column names and each tuple
# like the alternative spellings accepted in the source file -- confirm
# against ``standardize_frame``, which is defined outside this module.
_COLUMN_ALIASES: dict[str, tuple[str, ...]] = {
    "id": ("sequence_id",),
    "heavy_seq": ("heavy", "heavy_chain"),
    "light_seq": ("light", "light_chain"),
    "label": ("polyreactive",),
}


def _find_flag_columns(columns: Iterable[str]) -> list[str]:
    """Return the column names that contain the substring ``"flag"``.

    The match is case-insensitive and ignores spaces inside the name, so
    ``"ELISA Flag"`` and ``"F LAG"`` both qualify.  Matching names are
    returned unmodified and in the order they were supplied.
    """
    return [
        name
        for name in columns
        if "flag" in name.replace(" ", "").lower()
    ]


def _apply_flag_policy(
    frame: pd.DataFrame,
    flag_columns: list[str],
    *,
    min_nonspecific_flags: int = 4,
) -> pd.DataFrame:
    """Filter and relabel rows by the number of raised flag columns.

    Every column in ``flag_columns`` is coerced to numeric; entries that are
    missing or cannot be parsed count as 0.  A flag is "raised" when its
    value is strictly positive.  Rows are then handled by their flag count:

    * 0 flags: kept, ``label`` = 0.
    * ``min_nonspecific_flags`` or more: kept, ``label`` = 1.
    * anything in between: dropped (the number dropped is logged).

    Args:
        frame: Table containing the flag columns.
        flag_columns: Names of the columns in ``frame`` to count.
        min_nonspecific_flags: Smallest flag count that labels a row as
            nonspecific.  Must be at least 1 so the two kept groups cannot
            overlap.  Defaults to 4.

    Returns:
        ``frame`` itself, untouched, when ``flag_columns`` is empty.
        Otherwise a filtered copy whose ``flags_total``, ``label`` and
        ``polyreactive`` columns are (over)written with the flag count and
        the derived 0/1 label.

    Raises:
        ValueError: If ``min_nonspecific_flags`` is smaller than 1.
    """
    if min_nonspecific_flags < 1:
        raise ValueError(
            f"min_nonspecific_flags must be >= 1, got {min_nonspecific_flags}"
        )
    if not flag_columns:
        return frame

    # Non-numeric and missing entries become NaN and then 0, i.e. "not raised".
    flag_values = (
        frame[flag_columns]
        .apply(pd.to_numeric, errors="coerce")
        .fillna(0.0)
    )
    flags_total = (flag_values > 0).astype(int).sum(axis=1)

    specific_mask = flags_total == 0
    nonspecific_mask = flags_total >= min_nonspecific_flags
    keep_mask = specific_mask | nonspecific_mask

    dropped = int((~keep_mask).sum())
    if dropped:
        # The excluded band is 1 .. threshold-1; with the default threshold
        # this renders as "1-3", matching the fixed text used previously.
        LOGGER.info(
            "Dropped %s mildly polyreactive sequences (1-%s ELISA flags)",
            dropped,
            min_nonspecific_flags - 1,
        )

    # Copy so the new columns are never written into the caller's frame.
    filtered = frame.loc[keep_mask].copy()
    filtered["flags_total"] = flags_total.loc[keep_mask].astype(int)
    filtered["label"] = np.where(nonspecific_mask.loc[keep_mask], 1, 0)
    filtered["polyreactive"] = filtered["label"]
    return filtered


def load_dataframe(path_or_url: str, heavy_only: bool = True) -> pd.DataFrame:
    """Read the Boughter CSV and return it in the canonical format.

    Steps, in order:

    1. Read ``path_or_url`` with ``pandas.read_csv``.
    2. Apply the flag policy to any columns whose name contains "flag".
    3. If a ``label`` column is present, keep only rows labelled 0 or 1.
    4. Pass the result to ``standardize_frame`` with
       ``source="boughter2020"`` and ``is_test=False``.
    5. Carry ``flags_total`` over to the standardized frame when the
       standardizer did not keep it.

    Args:
        path_or_url: Location of the CSV, forwarded to ``pandas.read_csv``.
        heavy_only: Forwarded unchanged to ``standardize_frame``.

    Returns:
        The frame produced by ``standardize_frame``, with a ``flags_total``
        column added when one was computed and is not already present.
    """
    raw = pd.read_csv(path_or_url)
    frame = _apply_flag_policy(raw, _find_flag_columns(raw.columns))

    if "label" in frame.columns:
        has_binary_label = frame["label"].isin({0, 1})
        frame = frame[has_binary_label].copy()

    canonical = standardize_frame(
        frame,
        source="boughter2020",
        heavy_only=heavy_only,
        column_aliases=_COLUMN_ALIASES,
        is_test=False,
    )

    flags_available = "flags_total" in frame.columns
    flags_missing = "flags_total" not in canonical.columns
    if flags_available and flags_missing:
        # NOTE(review): this is a positional copy, so it assumes
        # standardize_frame keeps every row in its original order -- confirm
        # against that helper.
        canonical["flags_total"] = frame["flags_total"].to_numpy(dtype=int)
    return canonical