import torch
import torch.nn as nn
from torch.utils.data import Sampler
from collections import defaultdict
import random
import logging

logger = logging.getLogger(__name__)


class EarlyExitClassifier(nn.Module):
    def __init__(self, input_dim=27, hidden_dim=64):
        """
        input_dim: matches the number of scalar features in the offline
        feat_cols_single list (27). The actual features are:
        [s1_mid, s2_mid, margin_mid, z_margin_mid, mad_margin_mid,
         p1_mid, H_mid, gini_mid,
         topk_mean, topk_std, topk_cv, topk_kurt, topk_med,
         s1_over_mean, s1_over_med,
         p1, p2, shape_H, shape_gini, slope, z1, s1_over_sk,
         tail_mean, head5_mean,
         mask_ratio, mask_len, mask_runs]
        A 4-dim modality embedding is concatenated on top, so
        total_input_dim = input_dim + 4.
        """
        super().__init__()

        # Learned embedding for the two modality ids (0 and 1).
        self.modality_emb = nn.Embedding(2, 4)

        total_input_dim = input_dim + 4

        # Small MLP head mapping the concatenated features to a single
        # exit/no-exit logit.
        self.mlp = nn.Sequential(
            nn.Linear(total_input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, scalar_feats, modality_idx):
        """
        scalar_feats: [B, input_dim]
        modality_idx: [B] with values in {0, 1}
        Returns raw logits of shape [B, 1]; no activation is applied here.
        """
        mod_feat = self.modality_emb(modality_idx)      # [B, 4]
        x = torch.cat([scalar_feats, mod_feat], dim=1)  # [B, input_dim + 4]
        logits = self.mlp(x)
        return logits
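

# A minimal training-step sketch (illustrative only; the loss choice and the
# random data below are assumptions, not taken from this file). The head
# outputs one raw logit per sample, so nn.BCEWithLogitsLoss is the natural
# pairing:
#
#   model = EarlyExitClassifier(input_dim=27)
#   criterion = nn.BCEWithLogitsLoss()
#   feats = torch.randn(8, 27)                    # [B, input_dim]
#   modality = torch.randint(0, 2, (8,))          # [B] in {0, 1}
#   labels = torch.randint(0, 2, (8, 1)).float()  # [B, 1] binary targets
#   loss = criterion(model(feats, modality), labels)
#   loss.backward()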


class HomogeneousBatchSampler(Sampler):
    """
    Groups sample indices by global_dataset_name so that each batch is drawn
    from a single sub-dataset whenever possible, which makes in-batch
    contrastive learning meaningful.
    """

    def __init__(self, dataset, batch_size, drop_last=False):
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.groups = defaultdict(list)

        # Note: this loads every item once up front; if the dataset exposes
        # the dataset name more cheaply (e.g. via metadata), prefer that.
        logger.info("Grouping data by dataset source for Homogeneous Sampling...")
        try:
            for idx in range(len(dataset)):
                item = dataset[idx]
                d_name = item.get("global_dataset_name", "unknown")
                self.groups[d_name].append(idx)
        except Exception as e:
            logger.warning(
                f"Error grouping dataset: {e}. "
                "Falling back to simple index chunking (NOT HOMOGENEOUS)."
            )
            # Discard any partial grouping so indices are not duplicated
            # between a real group and the fallback group.
            self.groups.clear()
            self.groups["all"] = list(range(len(dataset)))

        logger.info(f"Grouped data into {len(self.groups)} datasets.")

    def __iter__(self):
        batch_list = []
        for _, indices in self.groups.items():
            random.shuffle(indices)  # reshuffle within each group every epoch
            for i in range(0, len(indices), self.batch_size):
                batch = indices[i : i + self.batch_size]
                if len(batch) < self.batch_size and self.drop_last:
                    continue
                # Skip singleton batches: they are useless for in-batch
                # contrast and would crash BatchNorm1d in training mode.
                if len(batch) < 2:
                    continue
                batch_list.append(batch)

        # Shuffle across groups so consecutive batches mix dataset sources.
        random.shuffle(batch_list)
        for batch in batch_list:
            yield batch

    def __len__(self):
        # Mirrors the skipping logic in __iter__ (assuming batch_size >= 2):
        # with drop_last only full batches count; otherwise a leftover batch
        # counts only if it holds at least 2 samples.
        count = 0
        for indices in self.groups.values():
            if self.drop_last:
                count += len(indices) // self.batch_size
            else:
                remainder = len(indices) % self.batch_size
                full = len(indices) // self.batch_size
                count += full + (1 if remainder >= 2 else 0)
        return count
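

if __name__ == "__main__":
    # Smoke test with a synthetic dataset (shapes and field names here are
    # assumptions for illustration only). HomogeneousBatchSampler plugs into
    # DataLoader via the batch_sampler argument, which bypasses DataLoader's
    # own batching and shuffling.
    from torch.utils.data import DataLoader, Dataset

    class _ToyDataset(Dataset):
        def __init__(self, n=20):
            # Two interleaved pseudo-datasets, "ds_a" and "ds_b".
            self.names = ["ds_a" if i % 2 == 0 else "ds_b" for i in range(n)]

        def __len__(self):
            return len(self.names)

        def __getitem__(self, idx):
            return {
                "global_dataset_name": self.names[idx],
                "scalar_feats": torch.randn(27),
                "modality_idx": torch.tensor(idx % 2),
            }

    ds = _ToyDataset()
    sampler = HomogeneousBatchSampler(ds, batch_size=4)
    loader = DataLoader(ds, batch_sampler=sampler)

    model = EarlyExitClassifier(input_dim=27)
    model.eval()  # eval mode so BatchNorm1d tolerates any batch size
    for batch in loader:
        logits = model(batch["scalar_feats"], batch["modality_idx"])
        assert logits.shape == (batch["modality_idx"].shape[0], 1)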