File size: 5,740 Bytes
5f58699
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
"""Utility helpers for dataset loading."""

from __future__ import annotations

import logging
from typing import Iterable, Sequence

import pandas as pd

# Canonical column order emitted by ensure_columns(); any extra columns are
# appended after these four.
EXPECTED_COLUMNS = ("id", "heavy_seq", "light_seq", "label")
# Columns that standardize_frame() may attach in addition to the canonical set.
OPTIONAL_COLUMNS = ("source", "is_test")

LOGGER = logging.getLogger("polyreact.data")

# Per-target alias candidates tried in order when renaming incoming columns to
# the canonical schema; caller-supplied aliases are prepended in
# standardize_frame() and therefore take precedence.
_DEFAULT_ALIASES: dict[str, Sequence[str]] = {
    "id": ("id", "sequence_id", "antibody_id", "uid"),
    "heavy_seq": ("heavy_seq", "heavy", "heavy_chain", "H", "H_chain"),
    "light_seq": ("light_seq", "light", "light_chain", "L", "L_chain"),
    "label": ("label", "polyreactive", "is_polyreactive", "class", "target"),
}

# Fallback mapping from raw label values to binary {0, 1} targets.
# NOTE(review): bool is an int subclass in Python, so the True/False entries
# hash identically to the 1/0 keys and merely re-assign the same values —
# harmless, but the literal has 12 effective entries, not 14.
DEFAULT_LABEL_MAP: dict[str | int | float | bool, int] = {
    1: 1,
    0: 0,
    "1": 1,
    "0": 0,
    True: 1,
    False: 0,
    "true": 1,
    "false": 0,
    "polyreactive": 1,
    "non-polyreactive": 0,
    "poly": 1,
    "non": 0,
    "positive": 1,
    "negative": 0,
}


def _normalize_label_key(value: object) -> object:
    if isinstance(value, str):
        trimmed = value.strip().lower()
        if trimmed in {
            "polyreactive",
            "non-polyreactive",
            "poly",
            "non",
            "positive",
            "negative",
            "high",
            "low",
            "pos",
            "neg",
            "1",
            "0",
            "true",
            "false",
        }:
            return trimmed
        if trimmed.isdigit():
            return trimmed
    return value


def ensure_columns(frame: pd.DataFrame, *, heavy_only: bool = True) -> pd.DataFrame:
    """Validate and coerce dataframe columns to the canonical format."""

    result = frame.copy()

    # "id", "heavy_seq" and "label" must always be present; raise on the
    # first one missing to match the canonical column order.
    for required in ("id", "heavy_seq", "label"):
        if required not in result.columns:
            msg = f"Required column '{required}' missing from dataframe"
            raise KeyError(msg)

    # Guarantee a light-chain column; heavy-chain-only mode always blanks it.
    if heavy_only or "light_seq" not in result.columns:
        result["light_seq"] = ""

    result["id"] = result["id"].astype(str)
    for seq_column in ("heavy_seq", "light_seq"):
        result[seq_column] = result[seq_column].fillna("").astype(str)
    result["label"] = result["label"].astype(int)

    # Canonical columns first, then any extras in their original order.
    extras = [col for col in result.columns if col not in EXPECTED_COLUMNS]
    return result[[*EXPECTED_COLUMNS, *extras]]


def standardize_frame(
    frame: pd.DataFrame,
    *,
    source: str,
    heavy_only: bool = True,
    column_aliases: dict[str, Sequence[str]] | None = None,
    label_map: dict[str | int | float | bool, int] | None = None,
    is_test: bool | None = None,
) -> pd.DataFrame:
    """Rename columns using aliases and coerce labels to integers."""

    # Caller-provided aliases are tried before the built-in defaults.
    merged_aliases = dict(_DEFAULT_ALIASES)
    for target, extra in (column_aliases or {}).items():
        merged_aliases[target] = (*extra, *merged_aliases.get(target, ()))

    renames: dict[str, str] = {}
    for canonical, options in merged_aliases.items():
        if canonical in frame.columns:
            continue  # already canonical — nothing to rename
        match = next(
            (col for col in options if col in frame.columns and col not in renames),
            None,
        )
        if match is not None:
            renames[match] = canonical

    result = frame.rename(columns=renames).copy()

    if "light_seq" not in result.columns:
        result["light_seq"] = ""

    # Map raw labels to {0, 1}; unmapped values become null and are rejected.
    lookup = label_map or DEFAULT_LABEL_MAP
    result["label"] = result["label"].map(
        lambda raw: lookup.get(_normalize_label_key(raw))
    )
    if result["label"].isnull().any():
        msg = "Label column contains unmapped or missing values"
        raise ValueError(msg)

    result["source"] = source
    if is_test is not None:
        result["is_test"] = bool(is_test)

    return ensure_columns(result, heavy_only=heavy_only)


def deduplicate_sequences(
    frames: Iterable[pd.DataFrame],
    *,
    heavy_only: bool = True,
    key_columns: Sequence[str] | None = None,
    keep_intra_frames: set[int] | None = None,
) -> list[pd.DataFrame]:
    """Remove duplicate entries across multiple dataframes with configurable keys."""

    if key_columns is None:
        key_columns = ["heavy_seq"] if heavy_only else ["heavy_seq", "light_seq"]
    intra_allowed = keep_intra_frames or set()

    global_seen: set[tuple[str, ...]] = set()
    results: list[pd.DataFrame] = []

    for frame_idx, frame in enumerate(frames):
        # Fall back to the heavy chain when none of the requested keys exist.
        columns = [col for col in key_columns if col in frame.columns] or ["heavy_seq"]
        allow_repeats = frame_idx in intra_allowed

        keep_flags: list[bool] = []
        local_seen: set[tuple[str, ...]] = set()
        for row in frame[columns].itertuples(index=False, name=None):
            key = tuple(_normalise_key_value(cell) for cell in row)
            # Cross-frame duplicates are always dropped; intra-frame duplicates
            # are dropped unless this frame index opted out.
            if key in global_seen or (not allow_repeats and key in local_seen):
                keep_flags.append(False)
            else:
                keep_flags.append(True)
                local_seen.add(key)
        global_seen.update(local_seen)

        kept = frame.loc[keep_flags].reset_index(drop=True)
        dropped = len(frame) - len(kept)
        if dropped:
            origin = "<unknown>"
            if "source" in frame.columns and not frame["source"].empty:
                origin = str(frame["source"].iloc[0])
            LOGGER.info("Removed %s duplicate sequences from %s", dropped, origin)
        results.append(kept)
    return results


def _normalise_key_value(value: object) -> str:
    if value is None or (isinstance(value, float) and pd.isna(value)):
        return ""
    return str(value).strip()