yxc97's picture
Upload folder using huggingface_hub
62a2f1c verified
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from scipy.stats import entropy
class AdaptiveAugmentation:
    """Adaptive, data-driven augmentation helpers for HARCNet.

    Provides a per-sample geometric augmentation strength derived from
    feature variance and prediction entropy, and a MixUp routine whose
    Beta-distribution parameter is derived from the label entropy of the
    batch.
    """

    def __init__(self, alpha=0.5, beta=0.5, gamma=2.0):
        """
        Args:
            alpha: Weight of the variance term in the geometric strength.
            beta: Weight of the entropy term in the geometric strength.
            gamma: Scaling factor applied to the batch label entropy when
                choosing the Beta parameter for MixUp.
        """
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        # Kept for backward compatibility with callers that read it.  The
        # methods below follow the device of their input tensors instead.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def compute_variance(self, x):
        """Return the mean cross-channel variance of each sample.

        Args:
            x: Tensor of shape [B, C, H, W].

        Returns:
            Tensor of shape [B].
        """
        # The unbiased variance of a single element is undefined (NaN).  A
        # single-channel input has no cross-channel spread, so report zero.
        if x.size(1) < 2:
            return x.new_zeros(x.size(0))
        # Variance across channels at every spatial location: [B, 1, H, W].
        per_location = torch.var(x, dim=1, keepdim=True)
        return per_location.mean(dim=[1, 2, 3])

    def compute_entropy(self, probs):
        """Return the Shannon entropy (in nats) of each row of ``probs``.

        Args:
            probs: Tensor of shape [B, C] holding one distribution per row.

        Returns:
            Tensor of shape [B].
        """
        # Clamp away from zero so that log() stays finite.
        probs = torch.clamp(probs, min=1e-8, max=1.0)
        return -torch.sum(probs * torch.log(probs), dim=1)

    def get_geometric_strength(self, x, model=None, probs=None):
        """Compute the geometric augmentation strength of each sample.

        S_g(x_i) = alpha * Var(x_i) + beta * Entropy(x_i)

        Both terms are min-max normalized over the batch before being
        combined, so a term that is constant across the batch (including the
        placeholder used when neither ``probs`` nor ``model`` is given, or
        any batch of size one) contributes approximately zero.

        Args:
            x: Input batch of shape [B, C, H, W].
            model: Optional callable returning logits for ``x``; only used
                when ``probs`` is not supplied.
            probs: Optional class probabilities of shape [B, num_classes].

        Returns:
            Tensor of shape [B].
        """
        var = self.compute_variance(x)

        # Fall back to the model's own predictions when none were passed in.
        if probs is None and model is not None:
            with torch.no_grad():
                probs = F.softmax(model(x), dim=1)

        if probs is not None:
            ent = self.compute_entropy(probs)
        else:
            # Placeholder entropy when no predictions are available.
            ent = torch.ones_like(var)

        # Min-max normalize each term to [0, 1] across the batch.
        var = (var - var.min()) / (var.max() - var.min() + 1e-8)
        ent = (ent - ent.min()) / (ent.max() - ent.min() + 1e-8)

        return self.alpha * var + self.beta * ent

    def get_mixup_params(self, y, num_classes=100):
        """Draw the MixUp coefficient and a batch permutation.

        lambda ~ Beta(a, a) with a = clip(gamma * Entropy(y), 0.1, 2.0),
        where Entropy(y) is the entropy of the class frequencies in the batch.

        Args:
            y: Integer class labels of shape [B].
            num_classes: Total number of classes.

        Returns:
            Tuple ``(lam, index)``: the mixing coefficient and a random
            permutation of ``range(B)`` located on the same device as ``y``.
        """
        y_onehot = F.one_hot(y, num_classes=num_classes).float()

        # The mean of the one-hot rows is the empirical class distribution.
        class_freq = y_onehot.mean(dim=0, keepdim=True)
        batch_entropy = self.compute_entropy(class_freq).item()

        # Bound the Beta parameter to [0.1, 2.0].
        alpha = max(0.1, min(self.gamma * batch_entropy, 2.0))
        lam = np.random.beta(alpha, alpha)

        # Create the permutation on the labels' device.  Using a globally
        # chosen device breaks indexing whenever the batch lives elsewhere
        # (for example CPU tensors on a machine that has CUDA).
        index = torch.randperm(y.size(0), device=y.device)
        return lam, index

    def apply_mixup(self, x, y, num_classes=100):
        """Apply MixUp with the adaptive coefficient.

        Args:
            x: Input batch whose first dimension is the batch dimension.
            y: Integer class labels of shape [B].
            num_classes: Total number of classes.

        Returns:
            Tuple ``(mixed_x, y_a, y_b, lam)`` where ``mixed_x`` is
            ``lam * x + (1 - lam) * x[perm]``, ``y_a`` is ``y`` and ``y_b``
            is ``y[perm]``.
        """
        lam, index = self.get_mixup_params(y, num_classes)
        # Index each tensor with a permutation on that tensor's own device.
        mixed_x = lam * x + (1 - lam) * x[index.to(x.device)]
        y_a, y_b = y, y[index]
        return mixed_x, y_a, y_b, lam
class TemporalConsistencyRegularization:
    """Decayed temporal consistency regularization for HARCNet.

    Keeps a short per-sample history of past predictions and penalizes the
    distance between the current prediction and an exponentially decayed
    average of that history.
    """

    def __init__(self, memory_size=5, decay_rate=2.0, consistency_weight=0.1):
        """
        Args:
            memory_size: Number of past predictions kept per sample (K).
            decay_rate: Time constant tau of the exponential decay.
            consistency_weight: Weight of the consistency loss.  It is stored
                here but not applied by ``compute_consistency_loss``.
        """
        self.memory_size = memory_size
        self.decay_rate = decay_rate
        self.consistency_weight = consistency_weight
        # Maps sample index -> list of past predictions, oldest first.
        self.prediction_history = {}

    def compute_decay_weights(self):
        """Return the normalized exponential decay weights.

        w_k = exp(-k / tau) / sum_j exp(-j / tau), for k = 1..K

        Element 0 corresponds to k = 1 (the most recent past prediction) and
        is the largest weight.
        """
        weights = torch.exp(-torch.arange(1, self.memory_size + 1) / self.decay_rate)
        return weights / weights.sum()

    def update_history(self, indices, predictions):
        """Append the current predictions to each sample's history.

        Args:
            indices: Iterable of sample identifiers (ints or scalar tensors),
                aligned with the rows of ``predictions``.
            predictions: Tensor whose i-th row is the prediction for
                ``indices[i]``.
        """
        for row, idx in enumerate(indices):
            history = self.prediction_history.setdefault(int(idx), [])
            # clone() is required: a detached row is still a view that shares
            # storage with the whole batch tensor.  Storing the view would
            # keep every batch alive and let later in-place edits of
            # ``predictions`` silently alter the history.
            history.append(predictions[row].detach().clone())
            # Keep only the most recent K predictions.
            overflow = len(history) - self.memory_size
            if overflow > 0:
                del history[:overflow]

    def get_aggregated_predictions(self, indices):
        """Return the decay-weighted average of past predictions per sample.

        y_agg_i = sum_k w_k * y_i^(t-k)

        When fewer than K predictions are stored, the weights of the
        available steps are renormalized to sum to one.

        Args:
            indices: Iterable of sample identifiers (ints or scalar tensors).

        Returns:
            List aligned with ``indices``; each entry is a tensor, or
            ``None`` when the sample has no stored history.
        """
        weights = self.compute_decay_weights()
        aggregated_preds = []
        for idx in indices:
            history = self.prediction_history.get(int(idx))
            if not history:
                aggregated_preds.append(None)
                continue

            # Use at most as many entries as there are weights.
            count = min(len(history), weights.numel())
            recent = history[len(history) - count:]

            # ``weights`` is ordered newest-first (k = 1 first) while the
            # history is stored oldest-first, so flip the weights so that the
            # most recent prediction receives the largest weight.
            step_weights = weights[:count].flip(0)
            step_weights = (step_weights / step_weights.sum()).to(recent[0].device)

            weighted_sum = torch.zeros_like(recent[0])
            for weight, pred in zip(step_weights, recent):
                weighted_sum += weight * pred
            aggregated_preds.append(weighted_sum)
        return aggregated_preds

    def compute_consistency_loss(self, current_preds, indices):
        """Return the mean squared error against the aggregated history.

        L_consistency(x_i) = || y_i^(t) - sum_k w_k * y_i^(t-k) ||^2

        The per-sample MSE is averaged over the samples that have a history.
        Samples without history are skipped.

        Args:
            current_preds: Tensor whose i-th row is the current prediction
                for ``indices[i]``.
            indices: Iterable of sample identifiers (ints or scalar tensors).

        Returns:
            Scalar tensor; zero when no sample in the batch has history.
        """
        aggregated_preds = self.get_aggregated_predictions(indices)
        sample_losses = [
            F.mse_loss(current_preds[row], agg_pred)
            for row, agg_pred in enumerate(aggregated_preds)
            if agg_pred is not None
        ]
        if not sample_losses:
            # Nothing to compare against yet.
            return torch.tensor(0.0).to(current_preds.device)
        return torch.stack(sample_losses).mean()