import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from scipy.stats import entropy


class AdaptiveAugmentation:
    """
    Implements adaptive data-driven augmentation for HARCNet.
    Dynamically adjusts geometric and MixUp augmentations based on data distribution.
    """

    def __init__(self, alpha=0.5, beta=0.5, gamma=2.0):
        """
        Args:
            alpha: Weight for variance component in geometric augmentation
            beta: Weight for entropy component in geometric augmentation
            gamma: Scaling factor for MixUp interpolation
        """
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def compute_variance(self, x):
        """Compute variance across feature dimensions"""
        # x shape: [B, C, H, W]
        # Compute variance across channels for each spatial location
        var = torch.var(x, dim=1, keepdim=True)  # [B, 1, H, W]
        return var.mean(dim=[1, 2, 3])  # [B]

    def compute_entropy(self, probs):
        """Compute entropy of probability distributions"""
        # probs shape: [B, C] where C is number of classes
        # Ensure valid probability distribution
        probs = torch.clamp(probs, min=1e-8, max=1.0)
        log_probs = torch.log(probs)
        entropy_val = -torch.sum(probs * log_probs, dim=1)  # [B]
        return entropy_val

    def get_geometric_strength(self, x, model=None, probs=None):
        """
        Compute geometric augmentation strength based on sample variance and entropy
        S_g(x_i) = α·Var(x_i) + β·Entropy(x_i)
        """
        var = self.compute_variance(x)

        # If model predictions are provided, use them for entropy calculation
        if probs is None and model is not None:
            with torch.no_grad():
                logits = model(x)
                probs = F.softmax(logits, dim=1)

        if probs is not None:
            ent = self.compute_entropy(probs)
        else:
            # Default entropy if no predictions available
            ent = torch.ones_like(var)

        # Normalize to [0, 1] range
        var = (var - var.min()) / (var.max() - var.min() + 1e-8)
        ent = (ent - ent.min()) / (ent.max() - ent.min() + 1e-8)

        strength = self.alpha * var + self.beta * ent
        return strength

    def get_mixup_params(self, y, num_classes=100):
        """
        Generate MixUp parameters based on label entropy
        λ ~ Beta(γ·Entropy(y), γ·Entropy(y))
        """
        # Convert labels to one-hot encoding
        y_onehot = F.one_hot(y, num_classes=num_classes).float()

        # Compute entropy of the ground-truth label distribution (across the batch)
        batch_entropy = self.compute_entropy(y_onehot.mean(dim=0, keepdim=True)).item()

        # Generate mixup coefficient from a Beta distribution
        alpha = self.gamma * batch_entropy
        alpha = max(0.1, min(alpha, 2.0))  # Bound alpha between 0.1 and 2.0
        lam = np.random.beta(alpha, alpha)

        # Generate a random permutation for mixing, on the same device as the labels
        batch_size = y.size(0)
        index = torch.randperm(batch_size, device=y.device)
        return lam, index

    def apply_mixup(self, x, y, num_classes=100):
        """Apply MixUp augmentation with adaptive coefficient"""
        lam, index = self.get_mixup_params(y, num_classes)
        mixed_x = lam * x + (1 - lam) * x[index]
        y_a, y_b = y, y[index]
        return mixed_x, y_a, y_b, lam
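

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): how AdaptiveAugmentation could be wired
# into a single supervised training step. The helper name, the optimizer, and
# the small strength-scaled Gaussian jitter standing in for the geometric
# transforms are assumptions made for this example, not part of HARCNet.
# ---------------------------------------------------------------------------
def _example_augmented_step(model, optimizer, x, y, augmenter, num_classes=100):
    """One illustrative training step using AdaptiveAugmentation."""
    # Per-sample geometric strength in [0, 1]; here it merely scales a small
    # random perturbation, whereas a real pipeline would scale crops/rotations.
    strength = augmenter.get_geometric_strength(x, model=model)  # [B]
    x_geo = x + 0.1 * strength.view(-1, 1, 1, 1) * torch.randn_like(x)

    # Adaptive MixUp on top of the geometrically augmented batch
    mixed_x, y_a, y_b, lam = augmenter.apply_mixup(x_geo, y, num_classes)

    logits = model(mixed_x)
    loss = lam * F.cross_entropy(logits, y_a) + (1 - lam) * F.cross_entropy(logits, y_b)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()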


class TemporalConsistencyRegularization:
    """
    Implements decayed temporal consistency regularization for HARCNet.
    Reduces noise in pseudo-labels by incorporating past predictions.
    """

    def __init__(self, memory_size=5, decay_rate=2.0, consistency_weight=0.1):
        """
        Args:
            memory_size: Number of past predictions to store (K)
            decay_rate: Controls the decay of weights for past predictions (τ)
            consistency_weight: Weight for consistency loss (λ_consistency)
        """
        self.memory_size = memory_size
        self.decay_rate = decay_rate
        self.consistency_weight = consistency_weight
        self.prediction_history = {}  # Store past predictions for each sample

    def compute_decay_weights(self):
        """
        Compute exponentially decaying weights
        ω_k = e^(-k/τ) / Σ(e^(-k/τ))
        """
        # Weight for the prediction k steps in the past; the most recent (k=1) is largest
        weights = torch.exp(-torch.arange(1, self.memory_size + 1) / self.decay_rate)
        return weights / weights.sum()

    def update_history(self, indices, predictions):
        """Update prediction history for each sample"""
        for i, idx in enumerate(indices):
            idx = idx.item()
            if idx not in self.prediction_history:
                self.prediction_history[idx] = []
            # Add current prediction to history (oldest first, newest last)
            self.prediction_history[idx].append(predictions[i].detach())
            # Keep only the most recent K predictions
            if len(self.prediction_history[idx]) > self.memory_size:
                self.prediction_history[idx].pop(0)

    def get_aggregated_predictions(self, indices):
        """
        Get aggregated predictions for each sample using decay weights
        ŷ_i = Σ(ω_k · ŷ_i^(t-k))
        """
        weights = self.compute_decay_weights()
        aggregated_preds = []
        for idx in indices:
            idx = idx.item()
            # Available history may be shorter than memory_size
            history = self.prediction_history.get(idx, [])
            history_len = len(history)
            if history_len > 0:
                # Truncate and renormalize the decay weights; history is stored
                # oldest-first, so the most recent prediction (k=1) must be
                # paired with the largest weight.
                available_weights = weights[:history_len].to(history[0].device)
                available_weights = available_weights / available_weights.sum()
                # Compute the weighted sum over past predictions
                weighted_sum = torch.zeros_like(history[0])
                for k, pred in enumerate(reversed(history)):
                    weighted_sum += available_weights[k] * pred
                aggregated_preds.append(weighted_sum)
            else:
                # No history for this sample, return None
                aggregated_preds.append(None)
        return aggregated_preds

    def compute_consistency_loss(self, current_preds, indices):
        """
        Compute consistency loss between current and aggregated past predictions
        L_consistency(x_i) = ||ŷ_i^(t) - Σ(ω_k · ŷ_i^(t-k))||^2_2
        """
        aggregated_preds = self.get_aggregated_predictions(indices)
        loss = 0.0
        valid_samples = 0
        for i, agg_pred in enumerate(aggregated_preds):
            if agg_pred is not None:
                # Compute MSE between current and aggregated predictions
                sample_loss = F.mse_loss(current_preds[i], agg_pred)
                loss += sample_loss
                valid_samples += 1
        # Return average loss if there are valid samples
        if valid_samples > 0:
            return loss / valid_samples
        # Return zero loss if no valid samples
        return torch.tensor(0.0, device=current_preds.device)
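

# ---------------------------------------------------------------------------
# Minimal smoke test / usage sketch, not part of HARCNet itself. It assumes a
# toy convolutional classifier and random CIFAR-100-shaped tensors purely to
# show the intended call pattern; the shapes, class count, and throwaway model
# below are illustrative assumptions, not the original training setup.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)

    # Toy batch: 8 RGB images (32x32) with 100-class labels and stable sample ids
    x = torch.randn(8, 3, 32, 32)
    y = torch.randint(0, 100, (8,))
    sample_indices = torch.arange(8)

    # Throwaway classifier standing in for the real backbone
    model = nn.Sequential(
        nn.Conv2d(3, 16, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(16, 100),
    )

    # Adaptive augmentation: per-sample geometric strength and adaptive MixUp
    augmenter = AdaptiveAugmentation(alpha=0.5, beta=0.5, gamma=2.0)
    strength = augmenter.get_geometric_strength(x, model=model)
    mixed_x, y_a, y_b, lam = augmenter.apply_mixup(x, y, num_classes=100)
    print("geometric strength:", tuple(strength.shape), "mixup lambda:", round(float(lam), 3))

    # Temporal consistency: accumulate two rounds of predictions, then penalize
    # deviation of the current predictions from the decayed history
    tcr = TemporalConsistencyRegularization(memory_size=5, decay_rate=2.0)
    for _ in range(2):
        probs = F.softmax(model(x), dim=1)
        tcr.update_history(sample_indices, probs)
    current_probs = F.softmax(model(mixed_x), dim=1)
    consistency_loss = tcr.compute_consistency_loss(current_probs, sample_indices)
    print("consistency loss:", float(consistency_loss))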