import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class AdaptiveAugmentation:
    """
    Implements adaptive data-driven augmentation for HARCNet.
    Dynamically adjusts geometric and MixUp augmentations based on the data distribution.
    """

    def __init__(self, alpha=0.5, beta=0.5, gamma=2.0):
        """
        Args:
            alpha: Weight for variance component in geometric augmentation
            beta: Weight for entropy component in geometric augmentation
            gamma: Scaling factor for MixUp interpolation
        """
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    def compute_variance(self, x):
        """Compute variance across feature dimensions, reduced to one scalar per sample."""
        # x is expected to be a batch of feature maps, e.g. [B, C, H, W]; the
        # channel-wise variance is then averaged down to a single value per sample.
        var = torch.var(x, dim=1, keepdim=True)
        return var.mean(dim=[1, 2, 3])

    def compute_entropy(self, probs):
        """Compute the entropy of each probability distribution in the batch."""
        # Clamp to avoid log(0) before taking the elementwise logarithm.
        probs = torch.clamp(probs, min=1e-8, max=1.0)
        log_probs = torch.log(probs)
        entropy_val = -torch.sum(probs * log_probs, dim=1)
        return entropy_val
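
    # Worked example (for intuition, not part of the original code): a uniform
    # distribution over 10 classes has entropy ln(10) ≈ 2.303 nats, while a
    # near one-hot distribution has entropy ≈ 0, so higher values flag less
    # confident predictions.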

    def get_geometric_strength(self, x, model=None, probs=None):
        """
        Compute geometric augmentation strength from sample variance and entropy:
        S_g(x_i) = α·Var(x_i) + β·Entropy(x_i)
        """
        var = self.compute_variance(x)

        # If class probabilities are not provided, derive them from the model.
        if probs is None and model is not None:
            with torch.no_grad():
                logits = model(x)
                probs = F.softmax(logits, dim=1)

        if probs is not None:
            ent = self.compute_entropy(probs)
        else:
            # Without any predictions, fall back to a constant entropy term
            # (which contributes nothing after the min-max normalization below).
            ent = torch.ones_like(var)

        # Min-max normalize both terms to [0, 1] within the batch.
        var = (var - var.min()) / (var.max() - var.min() + 1e-8)
        ent = (ent - ent.min()) / (ent.max() - ent.min() + 1e-8)

        strength = self.alpha * var + self.beta * ent
        return strength

    def get_mixup_params(self, y, num_classes=100):
        """
        Generate MixUp parameters based on label entropy:
        λ ~ Beta(γ·Entropy(y), γ·Entropy(y))
        """
        y_onehot = F.one_hot(y, num_classes=num_classes).float()

        # Entropy of the empirical label distribution of the batch.
        batch_entropy = self.compute_entropy(y_onehot.mean(dim=0, keepdim=True)).item()

        # Scale by gamma and clamp to a stable range for the Beta distribution.
        alpha = self.gamma * batch_entropy
        alpha = max(0.1, min(alpha, 2.0))

        lam = np.random.beta(alpha, alpha)

        # Random permutation pairing each sample with a mixing partner;
        # keep the index tensor on the same device as the labels.
        batch_size = y.size(0)
        index = torch.randperm(batch_size, device=y.device)

        return lam, index
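
    # Worked example (for intuition, not part of the original code): a batch
    # whose labels are spread evenly over C classes has entropy ln(C), so with
    # gamma = 2.0 the Beta parameter saturates at the 2.0 cap and λ concentrates
    # around 0.5 (strong mixing). A batch dominated by one class has entropy
    # near 0, the parameter falls to the 0.1 floor, and λ lands near 0 or 1
    # (weak mixing).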

    def apply_mixup(self, x, y, num_classes=100):
        """Apply MixUp augmentation with adaptive coefficient"""
        lam, index = self.get_mixup_params(y, num_classes)
        mixed_x = lam * x + (1 - lam) * x[index]
        y_a, y_b = y, y[index]
        return mixed_x, y_a, y_b, lam
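

# Usage sketch (illustrative only): one common way the MixUp outputs are folded
# into the classification loss. `model` and `criterion` are assumptions for
# demonstration, not part of HARCNet itself.
#
#   aug = AdaptiveAugmentation(alpha=0.5, beta=0.5, gamma=2.0)
#   mixed_x, y_a, y_b, lam = aug.apply_mixup(x, y, num_classes=100)
#   logits = model(mixed_x)
#   loss = lam * criterion(logits, y_a) + (1 - lam) * criterion(logits, y_b)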


class TemporalConsistencyRegularization:
    """
    Implements decayed temporal consistency regularization for HARCNet.
    Reduces noise in pseudo-labels by incorporating past predictions.
    """

    def __init__(self, memory_size=5, decay_rate=2.0, consistency_weight=0.1):
        """
        Args:
            memory_size: Number of past predictions to store (K)
            decay_rate: Controls the decay of weights for past predictions (τ)
            consistency_weight: Weight for the consistency loss (λ_consistency)
        """
        self.memory_size = memory_size
        self.decay_rate = decay_rate
        self.consistency_weight = consistency_weight
        self.prediction_history = {}

    def compute_decay_weights(self):
        """
        Compute exponentially decaying weights:
        ω_k = e^(-k/τ) / Σ_k e^(-k/τ)
        """
        # k = 1 (the most recent past prediction) receives the largest weight.
        weights = torch.exp(-torch.arange(1, self.memory_size + 1) / self.decay_rate)
        return weights / weights.sum()
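
    # Worked example (for intuition, not part of the original code): with
    # memory_size=5 and decay_rate=2.0 the normalized weights are roughly
    # [0.429, 0.260, 0.158, 0.096, 0.058] for k = 1..5, i.e. the most recent
    # stored prediction contributes a bit under half of the aggregate.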

    def update_history(self, indices, predictions):
        """Update the stored prediction history for each sample."""
        for i, idx in enumerate(indices):
            idx = idx.item()
            if idx not in self.prediction_history:
                self.prediction_history[idx] = []

            # Store a detached copy so the history does not retain the graph.
            self.prediction_history[idx].append(predictions[i].detach())

            # Keep only the most recent `memory_size` predictions.
            if len(self.prediction_history[idx]) > self.memory_size:
                self.prediction_history[idx].pop(0)

    def get_aggregated_predictions(self, indices):
        """
        Get aggregated predictions for each sample using the decay weights:
        ỹ_i = Σ_k ω_k · ŷ_i^(t-k)
        """
        weights = self.compute_decay_weights().to(indices.device)
        aggregated_preds = []

        for idx in indices:
            idx = idx.item()
            history = self.prediction_history.get(idx, [])
            if len(history) == 0:
                # No stored predictions for this sample yet.
                aggregated_preds.append(None)
                continue

            # Use as many weights as there are stored predictions, renormalized.
            history_len = len(history)
            available_weights = weights[:history_len]
            available_weights = available_weights / available_weights.sum()

            # history is ordered oldest -> newest; pair the newest prediction
            # with the largest weight (k = 1 in the decay formula).
            weighted_sum = torch.zeros_like(history[0])
            for j, pred in enumerate(reversed(history)):
                weighted_sum += available_weights[j] * pred

            aggregated_preds.append(weighted_sum)

        return aggregated_preds

    def compute_consistency_loss(self, current_preds, indices):
        """
        Compute the consistency loss between current and aggregated past predictions:
        L_consistency(x_i) = ||ŷ_i^(t) - Σ_k ω_k · ŷ_i^(t-k)||^2_2
        """
        aggregated_preds = self.get_aggregated_predictions(indices)
        loss = 0.0
        valid_samples = 0

        for i, agg_pred in enumerate(aggregated_preds):
            if agg_pred is not None:
                # Mean squared error between current and aggregated predictions.
                sample_loss = F.mse_loss(current_preds[i], agg_pred)
                loss += sample_loss
                valid_samples += 1

        # Average over samples that already have a history; otherwise return a
        # zero loss on the correct device.
        if valid_samples > 0:
            return loss / valid_samples
        return torch.tensor(0.0, device=current_preds.device)
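

if __name__ == "__main__":
    # Minimal smoke test (illustrative only). The toy linear model, random data,
    # and batch size below are assumptions for demonstration and are not part of
    # HARCNet; they simply exercise the two components above end to end.
    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 100)).to(device)

    aug = AdaptiveAugmentation(alpha=0.5, beta=0.5, gamma=2.0)
    tcr = TemporalConsistencyRegularization(memory_size=5, decay_rate=2.0,
                                            consistency_weight=0.1)

    x = torch.randn(8, 3, 32, 32, device=device)
    y = torch.randint(0, 100, (8,), device=device)
    indices = torch.arange(8, device=device)  # stable per-sample ids (e.g. dataset indices)

    for step in range(3):  # revisit the same samples, as across training epochs
        strength = aug.get_geometric_strength(x, model=model)
        mixed_x, y_a, y_b, lam = aug.apply_mixup(x, y, num_classes=100)

        probs = F.softmax(model(mixed_x), dim=1)
        consistency_loss = tcr.compute_consistency_loss(probs, indices)
        tcr.update_history(indices, probs)

        print(f"step {step}: lam = {lam:.3f}, "
              f"mean geometric strength = {float(strength.mean()):.3f}, "
              f"weighted consistency loss = "
              f"{float(tcr.consistency_weight * consistency_loss):.6f}")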