# TCube_Merging / utils/tools.py
# (Header residue from the hosting page: uploaded by razaimam45,
# commit a96891a (verified), "Upload 108 files".)
import os
import time
import random
import numpy as np
import shutil
from enum import Enum
import torch
import torchvision.transforms as transforms
# from t_cube import get_logits
def set_random_seed(seed):
    """Seed every random number generator used here with the same *seed*.

    Covers Python's ``random`` module, NumPy's global generator, the torch
    CPU generator and the generators of all CUDA devices, in that order.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed,
                   torch.cuda.manual_seed_all):
        seeder(seed)
class Summary(Enum):
    """Selects what :meth:`AverageMeter.summary` renders."""
    NONE = 0     # an empty string
    AVERAGE = 1  # '{name} {avg:.3f}'
    SUM = 2      # '{name} {sum:.3f}'
    COUNT = 3    # '{name} {count:.3f}'


class AverageMeter(object):
    """Track the latest value plus a running, sample-weighted sum and mean.

    ``fmt`` is a format spec (e.g. ``':6.2f'``) used by ``str(meter)`` for
    both the latest value and the average; ``summary_type`` picks what
    :meth:`summary` reports.
    """

    def __init__(self, name, fmt=':f', summary_type=Summary.AVERAGE):
        self.name = name
        self.fmt = fmt
        self.summary_type = summary_type
        self.reset()

    def reset(self):
        """Forget everything seen so far."""
        self.val = self.avg = self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record *val* as observed *n* times and refresh the mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        template = '{name} {val%s} ({avg%s})' % (self.fmt, self.fmt)
        return template.format(**vars(self))

    def summary(self):
        """Render the statistic selected by ``summary_type``.

        Raises:
            ValueError: if ``summary_type`` is not a ``Summary`` member.
        """
        templates = (
            (Summary.NONE, ''),
            (Summary.AVERAGE, '{name} {avg:.3f}'),
            (Summary.SUM, '{name} {sum:.3f}'),
            (Summary.COUNT, '{name} {count:.3f}'),
        )
        for kind, template in templates:
            if self.summary_type is kind:
                return template.format(**vars(self))
        raise ValueError('invalid summary type %r' % self.summary_type)
class ProgressMeter(object):
    """Print one-line progress reports for a list of meters.

    Each progress line starts with ``prefix`` and a ``[batch/total]`` counter
    whose batch field is right-aligned to the width of the total.
    """

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print the counter for *batch* and ``str()`` of each meter, tab-separated."""
        head = self.prefix + self.batch_fmtstr.format(batch)
        print('\t'.join([head] + [str(m) for m in self.meters]))

    def display_summary(self):
        """Print `` *`` followed by every meter's ``summary()``, space-separated."""
        print(' '.join([" *", *(m.summary() for m in self.meters)]))

    def _get_batch_fmtstr(self, num_batches):
        """Build a template such as ``'[{:3d}/100]'`` for *num_batches* = 100."""
        width = len(str(num_batches // 1))
        field = '{:%dd}' % width
        return '[%s/%s]' % (field, field.format(num_batches))
def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy, in percent, for each k in *topk*.

    Args:
        output: scores indexed as ``output[sample, class]``.
        target: ground-truth class index for each of the ``target.size(0)``
            samples.
        topk: the k values to report.

    Returns:
        A list with one 1-element tensor per k: the percentage of samples
        whose target is among their k highest-scoring classes.
    """
    with torch.no_grad():
        # Rank max(topk) classes. Ranking only the top-1 (as this function used
        # to) made every k > 1 silently report top-1 accuracy. Clamp to the
        # number of classes so e.g. Acc@5 on a 3-class output doesn't crash.
        maxk = min(max(topk), output.size(1))
        batch_size = target.size(0)
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()  # row r: every sample's r-th best class
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            # For k > maxk, slicing simply takes all ranked rows.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
from sklearn.metrics import precision_score, recall_score, f1_score
def macro_prf(output, target):
    """Return ``[precision, recall, F1]``, macro-averaged over classes, in percent.

    Predicted labels are ``output.argmax(dim=1)``. Ill-defined per-class
    scores count as 0 (``zero_division=0``).
    """
    y_pred = output.argmax(dim=1).cpu().numpy()
    y_true = target.cpu().numpy()
    scores = (
        metric(y_true, y_pred, average='macro', zero_division=0)
        for metric in (precision_score, recall_score, f1_score)
    )
    return [score * 100 for score in scores]
def load_model_weight(load_path, model, device, args):
    """Load a training checkpoint from *load_path* into *model*.

    The checkpoint must be a dict with ``'state_dict'`` and ``'epoch'`` keys
    and may hold ``'best_acc1'``. The fixed prompt token buffers
    ``token_prefix`` / ``token_suffix`` are dropped before loading. If a
    strict load into *model* fails, the weights are loaded non-strictly into
    ``model.prompt_generator`` instead.

    Args:
        load_path: file written by ``torch.save``.
        model: module that receives the weights.
        device: ``map_location`` for ``torch.load``.
        args: namespace; ``args.start_epoch`` is set from the checkpoint.

    If *load_path* is not a file, a message is printed and nothing is loaded.
    """
    if not os.path.isfile(load_path):
        print("=> no checkpoint found at '{}'".format(load_path))
        return

    print("=> loading checkpoint '{}'".format(load_path))
    checkpoint = torch.load(load_path, map_location=device)
    state_dict = checkpoint['state_dict']
    # Ignore fixed token vectors
    state_dict.pop("token_prefix", None)
    state_dict.pop("token_suffix", None)
    args.start_epoch = checkpoint['epoch']

    # NOTE(review): best_acc1 is computed but neither returned nor stored.
    best_acc1 = checkpoint.get('best_acc1', torch.tensor(0))
    # `device is not 'cpu'` compared object identity with a string literal;
    # compare the device's string form instead (also handles torch.device).
    if str(device) != 'cpu':
        # best_acc1 may be from a checkpoint from a different GPU
        best_acc1 = torch.as_tensor(best_acc1).to(device)

    try:
        model.load_state_dict(state_dict)
    except RuntimeError:
        # Strict load failed (missing/unexpected keys or shape mismatch);
        # presumably the checkpoint holds only prompt-generator weights.
        # TODO: implement this method for the generator class
        model.prompt_generator.load_state_dict(state_dict, strict=False)
    print("=> loaded checkpoint '{}' (epoch {})"
          .format(load_path, checkpoint['epoch']))
    del checkpoint
    torch.cuda.empty_cache()
def validate(val_loader, model, criterion, args, output_mask=None):
    """Evaluate *model* on *val_loader*, printing progress; return mean Acc@1.

    Args:
        val_loader: iterable of ``(images, target)`` batches; ``len()`` sizes
            the progress counter.
        model: called as ``model(images)``; switched to eval mode here.
        criterion: loss, called as ``criterion(output, target)``.
        args: reads ``args.gpu`` (device index or None) and ``args.print_freq``.
        output_mask: optional index/mask applied to the second (class)
            dimension of the model output before loss and accuracy. ``None``
            or an empty mask means no masking.

    Returns:
        ``top1.avg``: the sample-weighted mean top-1 accuracy in percent.
    """
    batch_time = AverageMeter('Time', ':6.3f', Summary.NONE)
    losses = AverageMeter('Loss', ':.4e', Summary.NONE)
    top1 = AverageMeter('Acc@1', ':6.2f', Summary.AVERAGE)
    top5 = AverageMeter('Acc@5', ':6.2f', Summary.AVERAGE)
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, top1, top5],
        prefix='Test: ')

    # Plain truthiness (`if output_mask:`) raises for multi-element tensors
    # and arrays, so test for "present and non-empty" explicitly.
    if output_mask is None:
        use_mask = False
    elif isinstance(output_mask, torch.Tensor):
        use_mask = output_mask.numel() > 0
    else:
        use_mask = np.size(output_mask) > 0

    # switch to evaluate mode
    model.eval()
    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
            if torch.cuda.is_available():
                target = target.cuda(args.gpu, non_blocking=True)

            # compute output
            with torch.cuda.amp.autocast():
                output = model(images)
                if use_mask:
                    output = output[:, output_mask]
                loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)
    progress.display_summary()
    return top1.avg
import matplotlib.pyplot as plt
def plot_img(image, save_path='saved_plot.png', target=None, predicted=None):
    """Save *image*, min-max scaled to [0, 1], as a 3x3-inch figure without axes.

    Args:
        image: a ``torch.Tensor`` in channels-first layout (singleton dims,
            e.g. a batch of one, are squeezed away; 2-D grayscale tensors are
            also accepted), or an array ``plt.imshow`` can display as-is.
        save_path: output file path.
        target, predicted: currently unused (the title that showed them is
            disabled); kept for call compatibility.
    """
    if isinstance(image, torch.Tensor):
        image_array = image.detach().to('cpu').squeeze()
        if image_array.dim() == 3:
            # CHW -> HWC, the layout imshow expects for colour images
            image_array = image_array.permute(1, 2, 0)
        image_array = image_array.numpy()
    else:
        image_array = image

    lo, hi = image_array.min(), image_array.max()
    if hi > lo:
        image_array = (image_array - lo) / (hi - lo)
    else:
        # Constant image: avoid 0/0, which would fill the image with NaN.
        image_array = np.zeros(np.shape(image_array))

    plt.figure(figsize=(3, 3), tight_layout=True)
    plt.imshow(image_array)
    plt.axis('off')
    plt.savefig(save_path)
    plt.close()
from torchvision.transforms import ToPILImage
from PIL import Image
# Shared converter from tensors/arrays to PIL images.
to_pil = ToPILImage()


def plot_pil_img(image, save_path='saved_plot.png'):
    """Write *image* to *save_path* via PIL.

    Anything that is not already a ``PIL.Image.Image`` is first passed through
    the module-level ``ToPILImage`` transform.
    """
    pil_image = image if isinstance(image, Image.Image) else to_pil(image)
    pil_image.save(save_path)
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import pearsonr
def plot_entropy_vs_mi(
    entropies: np.ndarray,
    mi_values: np.ndarray,
    agreement_diff: np.ndarray = None,
    entropy_thresh: float = None,
    mi_thresh: float = None,
    figsize: tuple = (4.5, 4.5),
    save_path: str = 'mi_vs_entropy.png',
):
    """
    Plot MI vs. Predictive Entropy with optional coloring by agreement.

    Draws a seaborn JointGrid (scatter + grey histogram/KDE marginals), a
    dashed regression line, optional threshold lines, and Pearson's rho in
    the title.

    Args:
        entropies: consensus predictive entropy values (torch tensor or
            array-like).
        mi_values: mutual information values, one per entry of ``entropies``.
        agreement_diff (optional): difference in predictions (L1), used as
            the scatter hue.
        entropy_thresh (float, optional): vertical threshold line.
        mi_thresh (float, optional): horizontal threshold line.
        figsize (tuple): only ``figsize[0]`` is used, as the grid height.
        save_path (str): where to save the figure; its directory is created.
    """
    def to_np(x):
        # Accept torch tensors as well as the annotated NumPy arrays; calling
        # .cpu() directly on an ndarray would raise AttributeError.
        return x.detach().cpu().numpy() if hasattr(x, 'cpu') else np.asarray(x)

    entropies = to_np(entropies)
    mi_values = to_np(mi_values)
    if agreement_diff is not None:
        agreement_diff = to_np(agreement_diff)

    # pearsonr needs at least two points
    corr = pearsonr(entropies, mi_values)[0] if len(entropies) > 1 else float('nan')

    # Create joint plot
    g = sns.JointGrid(
        x=entropies,
        y=mi_values,
        height=figsize[0],
        ratio=4,
        space=0.15
    )

    # Scatter with hue if available
    if agreement_diff is not None:
        cmap = sns.color_palette("coolwarm", as_cmap=True)
        g.plot_joint(
            sns.scatterplot,
            hue=agreement_diff,
            palette=cmap,
            s=18,
            linewidth=0.3,
            edgecolor="black",
            alpha=0.8
        )
        # Drop the hue legend for a cleaner plot; guard in case seaborn
        # did not create one.
        if g.ax_joint.legend_ is not None:
            g.ax_joint.legend_.remove()
    else:
        g.plot_joint(sns.scatterplot, s=20, color='tab:blue', alpha=0.7)

    # Marginals
    g.plot_marginals(sns.histplot, kde=True, color='grey', alpha=0.5)

    # Regression
    sns.regplot(
        x=entropies,
        y=mi_values,
        scatter=False,
        ax=g.ax_joint,
        color='black',
        line_kws={"linestyle": "--", "linewidth": 1}
    )

    # Thresholds
    if entropy_thresh is not None:
        g.ax_joint.axvline(entropy_thresh, ls='--', color='grey', lw=1)
    if mi_thresh is not None:
        g.ax_joint.axhline(mi_thresh, ls='--', color='grey', lw=1)

    # Label the top-left (high-MI, low-entropy) region.
    x_text = np.percentile(entropies, 5)
    y_text = np.percentile(mi_values, 95)
    g.ax_joint.text(x_text, y_text, 'High MI\nLow Entropy',
                    fontsize=10, fontweight='bold', color='black')

    # Labels and title
    g.set_axis_labels('Self-Entropy', 'Mutual Information', fontsize=11)
    g.ax_joint.set_title(f'Pearson ρ = {corr:.2f}', fontsize=12)
    g.ax_joint.tick_params(labelsize=9)
    plt.tight_layout()

    if os.path.dirname(save_path):
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
    plt.savefig(save_path, dpi=300)
    plt.close()
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
# Method identifier -> human-readable name (presumably for plot labels).
# plot_delta_performance iterates these keys to find the competing methods.
method_names = dict(
    model_ensemble='Model Ensemble',
    wise_ft='Model Souping',
    tcube='Entropy-based',
    tcube_MI_bmm='Mutual Information',
)
def plot_delta_performance(
    dyn_v_stat_plot: dict,
    dyn_key: str = 'tcube_MI_bmm',
    figsize: tuple = (3, 3),
    save_path: str = 'delta_performance.png'
):
    """Bar-plot the margin of *dyn_key* over the best competing method.

    One bar is drawn per entry of ``dyn_v_stat_plot['conditions']``. Its height
    is ``dyn_v_stat_plot[dyn_key]`` minus the element-wise maximum over
    ``dyn_v_stat_plot[k]`` for every key ``k`` of the module-level
    ``method_names`` other than *dyn_key*. X tick labels are left blank.

    The figure is saved to *save_path* (creating its directory) at 300 dpi
    and closed; ``(fig, ax)`` is returned.
    """
    sns.set_style('white')
    conditions = np.array(dyn_v_stat_plot['conditions'])
    fig, ax = plt.subplots(1, 1, figsize=figsize, constrained_layout=True)

    # --- Δ Accuracy: dynamic method minus its strongest competitor ---
    dynamic_scores = np.array(dyn_v_stat_plot[dyn_key])
    competitor_scores = np.vstack(
        [dyn_v_stat_plot[name] for name in method_names if name != dyn_key]
    )
    delta = dynamic_scores - competitor_scores.max(axis=0)

    positions = np.arange(len(conditions))
    ax.bar(
        x=positions, height=delta, width=1.0,
        color=sns.color_palette("rocket", n_colors=len(delta)),
        linewidth=0, edgecolor=None, alpha=0.85,
    )
    ax.axhline(0, color='grey', linewidth=1)

    ax.set_ylabel(r'$\Delta$ (%)', fontsize=10)
    ax.set_xlabel('Distribution Shifts', fontsize=10)
    ax.set_xticks(positions)
    ax.set_xticklabels([''] * len(conditions))
    ax.tick_params(axis='x', length=3, width=1)
    ax.tick_params(axis='y', labelsize=9)

    for side, visible in (('top', False), ('right', False),
                          ('left', True), ('bottom', True)):
        ax.spines[side].set_visible(visible)
    ax.grid(False)

    out_dir = os.path.dirname(save_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close(fig)
    return fig, ax
import matplotlib.pyplot as plt
import seaborn as sns
import torch
def plot_lambda_histogram(
    lambda_dict: dict,
    bins: int = 50,
    figsize: tuple = (3, 3),
    save_path: str = None
):
    """
    Plot a histogram, with a dashed black KDE curve, of the sample-wise
    interpolation coefficients of a single condition.

    Bars are shaded along a "Blues" gradient. X ticks sit at six evenly spaced
    values between the data minimum and maximum, ticks point outward, and all
    four spines are drawn. The figure is shown with ``plt.show()``.

    Args:
        lambda_dict (dict): one-entry dict, e.g. {'clean': tensor([...])}
        bins (int): number of histogram bins
        figsize (tuple): figure size in inches (w, h)
        save_path (str): optional path to save the figure; when None (the
            default) nothing is written to disk

    Returns:
        fig, ax

    Raises:
        ValueError: if lambda_dict does not hold exactly one entry or that
            entry is not a torch.Tensor.
    """
    # Validate single key
    if len(lambda_dict) != 1:
        raise ValueError("lambda_dict must contain exactly one key.")
    condition, data = next(iter(lambda_dict.items()))
    if not isinstance(data, torch.Tensor):
        raise ValueError(f"lambda_dict['{condition}'] must be a torch.Tensor")

    # Prepare data
    values = data.detach().cpu().numpy().ravel()

    # Aesthetics setup
    sns.set_style("white")
    fig, ax = plt.subplots(figsize=figsize)

    # One gradient colour per bin, applied to the bars below.
    bar_colors = sns.color_palette("Blues", n_colors=bins)

    # Plot histogram
    plot = sns.histplot(
        values,
        bins=bins,
        ax=ax,
        edgecolor=None,
        alpha=0.85,
        kde=True,
        linewidth=0  # no bar outlines
    )
    if plot.lines:
        kde_line = plot.lines[0]
        kde_line.set_color('black')
        kde_line.set_linestyle('--')
        kde_line.set_linewidth(0.5)
    for patch, color in zip(plot.patches, bar_colors):
        patch.set_facecolor(color)

    # Labels
    ax.set_xlabel("Coefficient", fontsize=9)
    ax.set_ylabel("Frequency", fontsize=9)

    # Six labelled x ticks spanning the data; outward ticks on both axes
    ax.set_xticks(np.round(np.linspace(values.min(), values.max(), num=6), 2))
    ax.tick_params(axis='x', labelsize=8)
    ax.tick_params(
        axis='x', which='both',
        bottom=True, top=False,
        length=4, direction='out'
    )
    ax.tick_params(
        axis='y', which='both',
        left=True, right=False,
        length=4, direction='out',
        labelsize=8
    )

    # Make all borders visible
    for spine in ['top', 'right', 'bottom', 'left']:
        ax.spines[spine].set_visible(True)

    plt.tight_layout()
    # save_path defaults to None; os.path.dirname(None) used to raise TypeError.
    if save_path:
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        fig.savefig(save_path, dpi=300, bbox_inches="tight")
    plt.show()
    return fig, ax
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import pearsonr
def _jp_to_numpy(x):
    """Return *x* as a NumPy array, detaching and moving torch tensors to the CPU."""
    if hasattr(x, 'cpu'):
        return x.detach().cpu().numpy()
    return np.asarray(x)


def _jp_limits(values, fallback):
    """Return axis limits: the 1st/99th percentile of *values*, to clip outliers.

    With fewer than two points, the min/max of the finite entries of
    *fallback* is used. If *fallback* has none, ``(None, None)`` is returned,
    which leaves matplotlib's autoscaled limits unchanged.
    """
    if len(values) > 1:
        low, high = np.percentile(values, [1, 99])
        return low, high
    finite = fallback[np.isfinite(fallback)]
    if finite.size == 0:
        return None, None
    return finite.min(), finite.max()


def _jp_draw_panel(fig, top_spec, joint_spec, right_spec, x, y, xlim, ylim,
                   rho_template="$\\rho$={rho:.2f}", rho_fontsize=12):
    """Draw one joint panel: top histogram, scatter + dashed fit, right histogram.

    Pearson's rho is annotated when there are at least two points.
    Returns the central (joint) axes so the caller can label it.
    """
    # Top marginal over the scatter's x-range
    ax_marg_x = fig.add_subplot(top_spec)
    sns.histplot(x, bins=25, kde=True, ax=ax_marg_x, color='grey', alpha=0.4)
    ax_marg_x.set_xlim(*xlim)
    ax_marg_x.axis('off')  # remove all spines & ticks

    # Joint scatter with a dashed linear fit
    ax_joint = fig.add_subplot(joint_spec)
    sns.scatterplot(
        x=x, y=y,
        s=25, color='violet',
        edgecolor='k', linewidth=0.2, alpha=0.7,
        ax=ax_joint
    )
    if len(x) > 1:  # a fit needs at least two points (splits can be empty)
        sns.regplot(
            x=x, y=y, scatter=False, ax=ax_joint,
            line_kws={'linestyle': '--', 'color': 'black', 'linewidth': 1.25}
        )
    ax_joint.set_xlim(*xlim)
    ax_joint.set_ylim(*ylim)
    ax_joint.set_xticklabels([])
    ax_joint.set_yticklabels([])

    # Right marginal over the scatter's y-range
    ax_marg_y = fig.add_subplot(right_spec)
    sns.histplot(
        y=y, bins=25, kde=True,
        ax=ax_marg_y, color='grey', alpha=0.4,
        orientation='horizontal'
    )
    ax_marg_y.set_ylim(*ylim)
    ax_marg_y.axis('off')

    # Annotate Pearson rho
    if len(x) > 1:
        rho, _ = pearsonr(x, y)
        ax_joint.text(
            0.05, 0.90, rho_template.format(rho=rho),
            transform=ax_joint.transAxes,
            fontsize=rho_fontsize,
            bbox=dict(boxstyle="round,pad=0.2", fc="white", ec="none", alpha=0.6)
        )
    return ax_joint


def _jp_plot_by_correctness(x_all, y_all, correct_pt, correct_ft, figsize,
                            save_path, xlabel):
    """Shared body of the *_by_correctness plots: five panels, one per split."""
    x_all = _jp_to_numpy(x_all)
    y_all = _jp_to_numpy(y_all)
    # Cast to bool so `~` is a logical NOT; on integer 0/1 flags it yields
    # -1/-2 (both truthy), corrupting the TrueFalse/FalseTrue/FalseFalse masks.
    cpt = _jp_to_numpy(correct_pt).astype(bool)
    cft = _jp_to_numpy(correct_ft).astype(bool)
    masks = {
        'Entire Set': np.ones_like(x_all, dtype=bool),
        'TrueTrue': cpt & cft,
        'TrueFalse': cpt & ~cft,
        'FalseTrue': ~cpt & cft,
        'FalseFalse': ~cpt & ~cft,
    }

    fig = plt.figure(figsize=figsize)
    gs = fig.add_gridspec(
        2, 10,
        width_ratios=[4, 1] * 5,
        height_ratios=[0.2, 1],
        wspace=0.075,
        hspace=0.2
    )
    for i, (label, mask) in enumerate(masks.items()):
        xe, ym = x_all[mask], y_all[mask]
        valid = np.isfinite(xe) & np.isfinite(ym)
        xe, ym = xe[valid], ym[valid]
        ax_joint = _jp_draw_panel(
            fig, gs[0, 2 * i], gs[1, 2 * i], gs[1, 2 * i + 1], xe, ym,
            xlim=_jp_limits(xe, x_all), ylim=_jp_limits(ym, y_all),
        )
        ax_joint.set_xlabel(xlabel, fontsize=14)  # x label on every panel
        if i == 0:  # y label only on the first panel
            ax_joint.set_ylabel(
                r"$\mathbf{\sigma\left(\mathrm{JS}(P_{pt},P_{ft})\right)}$", fontsize=11)
        ax_joint.set_title(label, fontsize=14)

    plt.tight_layout()
    os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close(fig)


def plot_entropy_vs_mi_by_correctness(
    entropies: np.ndarray,
    mi_values: np.ndarray,
    correct_pt: np.ndarray,
    correct_ft: np.ndarray,
    figsize: tuple = (20, 4),
    save_path: str = 'mi_vs_entropy_by_correctness_all.png',
):
    """
    Scatter *mi_values* (y) against *entropies* (x) in five joint-plot panels.

    The panels are the entire set and the TrueTrue / TrueFalse / FalseTrue /
    FalseFalse splits. The first word refers to ``correct_pt``, the second to
    ``correct_ft`` (per-sample correctness flags, cast to bool). The axis
    labels assume x is the entropy ratio H(P_ft)/(H(P_ft)+H(P_pt)) and y is
    sigma(JS(P_pt, P_ft)).

    Each panel clips its axes to the 1st-99th percentile, uses violet
    scatter with grey marginal histograms, and annotates Pearson rho. Tick
    labels are hidden.
    """
    # Values are plotted unmodified. Blending a random share of x into y (as
    # this function previously did) would inflate the annotated rho and make
    # the figure non-reproducible.
    _jp_plot_by_correctness(
        entropies, mi_values, correct_pt, correct_ft, figsize, save_path,
        xlabel=r"$\mathbf{\frac{H(P_{ft})}{H(P_{ft})+H(P_{pt})}}$",
    )


def plot_Xentropy_vs_mi_by_correctness(
    x_entropies: np.ndarray,
    mi_values: np.ndarray,
    correct_pt: np.ndarray,
    correct_ft: np.ndarray,
    figsize: tuple = (20, 4),
    save_path: str = 'mi_vs_entropy_by_correctness_all.png',
):
    """
    Same five-panel plot as ``plot_entropy_vs_mi_by_correctness``, with
    *x_entropies* on the x axis.

    The x label assumes x is the cross-entropy ratio
    CE(P_ft,Y)/(CE(P_ft,Y)+CE(P_pt,Y)), and the y label assumes y is
    sigma(JS(P_pt, P_ft)).
    """
    # Values are plotted unmodified (no random blending of x into y).
    _jp_plot_by_correctness(
        x_entropies, mi_values, correct_pt, correct_ft, figsize, save_path,
        xlabel=r"$\mathbf{\frac{CE(P_{ft},Y)}{CE(P_{ft},Y)+CE(P_{pt},Y)}}$",
    )


def plot_xentropy_vs_mi_entire(
    x_entropies: np.ndarray,
    mi_values: np.ndarray,
    figsize: tuple = (5, 5),
    save_path: str = 'xent_vs_mi_entire.png',
):
    """
    Plot a single joint panel of *mi_values* (y) vs. *x_entropies* (x) for the
    entire set: top histogram, central scatter + regression, right histogram.

    Non-finite pairs are dropped and the axes are clipped to the 1st-99th
    percentile. Histograms are grey and the scatter is violet. Pearson rho
    is annotated and tick labels are hidden.
    """
    xe = _jp_to_numpy(x_entropies)
    ym = _jp_to_numpy(mi_values)
    # Values are plotted unmodified (no random blending of x into y).
    finite = np.isfinite(xe) & np.isfinite(ym)
    xe, ym = xe[finite], ym[finite]
    xlim, ylim = _jp_limits(xe, xe), _jp_limits(ym, ym)

    # 2x2 grid: width ratios 4:1, height ratios 0.2:1
    fig = plt.figure(figsize=figsize)
    gs = fig.add_gridspec(
        2, 2,
        width_ratios=[4, 1],
        height_ratios=[0.2, 1],
        wspace=0.05,
        hspace=0.05
    )
    ax_joint = _jp_draw_panel(
        fig, gs[0, 0], gs[1, 0], gs[1, 1], xe, ym, xlim, ylim,
        rho_template="$\\rho$ = {rho:.2f}", rho_fontsize=10,
    )
    ax_joint.set_xlabel(
        r"$\mathbf{\frac{CE(P_{ft},Y)}{CE(P_{ft},Y)+CE(P_{pt},Y)}}$", fontsize=14)
    ax_joint.set_ylabel(
        r"$\mathbf{\sigma\left(\mathrm{JS}(P_{pt},P_{ft})\right)}$", fontsize=11)

    plt.tight_layout()
    os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close(fig)
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
def plot_stacked_ce_vs_mi_bins(
    mi_values,
    ce_values_pt,
    ce_values_ft,
    bins: int = 12,
    figsize: tuple = (10, 5),
    save_path: str = 'ce_vs_mi_stacked_bins.png',
):
    """
    Plot the mean cross-entropy (CE) of the pretrained and fine-tuned models,
    stacked, as a function of binned Mutual Information (MI).

    MI is min-max normalised to [0, 1] and split into *bins* equal-width bins.
    Each bar shows the mean pretrained CE at the bottom ("Reds" palette) with
    the mean fine-tuned CE stacked on top ("Blues" palette). Empty bins get
    NaN means.

    Args:
        mi_values (array-like): Mutual information per sample.
        ce_values_pt (array-like): Cross-entropy for pretrained model per sample.
        ce_values_ft (array-like): Cross-entropy for fine-tuned model per sample.
        bins (int): Number of bins.
        figsize (tuple): Figure size.
        save_path (str): Path to save the plot (its directory is created).
    """
    # Convert to numpy
    def to_np(x):
        return x.cpu().numpy() if hasattr(x, 'cpu') else np.asarray(x)

    mi = to_np(mi_values).ravel()
    mi_min, mi_max = mi.min(), mi.max()
    mi_range = mi_max - mi_min
    if mi_range == 0:
        # Constant MI: 0/0 would turn every value and bin edge into NaN.
        mi = np.zeros_like(mi, dtype=float)
    else:
        mi = (mi - mi_min) / mi_range
    ce_pt = to_np(ce_values_pt).ravel()
    ce_ft = to_np(ce_values_ft).ravel()

    # Bin edges and digitize. right=True gives (lo, hi] bins; the clip folds
    # the minimum (index -1) into the first bin.
    edges = np.linspace(mi.min(), mi.max(), bins + 1)
    bin_idx = np.digitize(mi, edges, right=True) - 1
    bin_idx = np.clip(bin_idx, 0, bins - 1)

    # Compute mean CE per bin for both models
    mean_pt = []
    mean_ft = []
    for i in range(bins):
        mask = (bin_idx == i)
        mean_pt.append(ce_pt[mask].mean() if mask.any() else np.nan)
        mean_ft.append(ce_ft[mask].mean() if mask.any() else np.nan)

    # Prepare labels
    labels = [f"({edges[i]:.2f},{edges[i+1]:.2f}]" for i in range(bins)]

    # Colors
    bottom_colors = sns.color_palette("Reds", bins)
    top_colors = sns.color_palette("Blues", bins)

    # Plot
    plt.figure(figsize=figsize)
    x = np.arange(bins)
    plt.bar(x, mean_pt, color=bottom_colors, label='CE Pretrained')
    plt.bar(x, mean_ft, bottom=mean_pt, color=top_colors, label='CE Fine-tuned')

    # Labels and aesthetics
    plt.xticks(x, labels, rotation=45, ha='right', fontsize=10)
    plt.xlabel("Mutual Information Bins", fontsize=12)
    plt.ylabel("Cross-Entropy Loss (CE)", fontsize=12)
    plt.legend(loc='upper right')
    sns.despine(trim=True)
    plt.tight_layout()

    # Save
    os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
    plt.savefig(save_path, dpi=300)
    plt.close()
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import pearsonr
def plot_ce_vs_mi_by_correctness(
    ce_pt: np.ndarray,
    ce_ft: np.ndarray,
    mi_values: np.ndarray,
    correct_pt: np.ndarray,
    correct_ft: np.ndarray,
    figsize: tuple = (20, 4),
    save_path: str = 'ce_vs_mi_by_correctness.png',
):
    """
    Plot CE vs. Mutual Information across 5 subsets: All, TT, TF, FT, FF.

    In each panel, pretrained CE vs. MI is drawn in red and fine-tuned CE vs.
    MI in blue. Each series gets a scatter and, when it has at least two
    points, a dashed regression line and a Pearson rho annotation. The split
    names refer to ``correct_pt`` (first word) and ``correct_ft`` (second
    word), which are per-sample correctness flags cast to bool.
    """
    # helper to numpy
    def to_np(x):
        return x.detach().cpu().numpy() if hasattr(x, 'cpu') else np.asarray(x)

    ce_pt = to_np(ce_pt)
    ce_ft = to_np(ce_ft)
    mi = to_np(mi_values)
    # Cast to bool so `~` is a logical NOT. On integer 0/1 flags it gives
    # -1/-2 (both truthy), which would corrupt the TF/FT/FF masks.
    cpt = to_np(correct_pt).astype(bool)
    cft = to_np(correct_ft).astype(bool)
    masks = {
        'All': np.ones_like(mi, dtype=bool),
        'TrueTrue': cpt & cft,
        'TrueFalse': cpt & ~cft,
        'FalseTrue': ~cpt & cft,
        'FalseFalse': ~cpt & ~cft,
    }

    # colors
    color_pt = 'tab:red'
    color_ft = 'tab:blue'

    fig, axs = plt.subplots(1, 5, figsize=figsize, sharey=False)
    for ax, (label, mask) in zip(axs, masks.items()):
        y = mi[mask]
        series = (
            (ce_pt[mask], color_pt, 'pt', 0.90),
            (ce_ft[mask], color_ft, 'ft', 0.80),
        )
        for x, color, tag, text_y in series:
            ax.scatter(x, y, c=color, s=20, alpha=0.7, edgecolor='k', linewidth=0.2)
            # A fit and a correlation need at least two points; splits such
            # as FalseTrue can be empty.
            if len(x) > 1:
                sns.regplot(x=x, y=y, scatter=False, ax=ax,
                            line_kws={'color': color, 'linestyle': '--', 'linewidth': 1.5})
                rho, _ = pearsonr(x, y)
                ax.text(0.05, text_y, f"$\\rho_{{{tag}}}={rho:.2f}$",
                        transform=ax.transAxes, color=color,
                        fontsize=10, bbox=dict(boxstyle="round,pad=0.2", fc="white", alpha=0.6, ec="none"))

        ax.set_title(label, fontsize=12)
        ax.set_xlabel('Cross-Entropy Error', fontsize=11)
        if label == 'All':
            ax.set_ylabel('Mutual Information (JSD)', fontsize=11)
        else:
            ax.set_ylabel('')
        ax.tick_params(labelsize=9)

    plt.tight_layout()
    os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
    fig.savefig(save_path, dpi=300)
    plt.close(fig)
import torch
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
# def plot_case_study_mosaic(
# clip_pt, clip_ft, dataloader, args,
# n_per_cat=5,
# figsize=(12, 8),
# save_path=None
# ):
# """
# Build a mosaic with 4 rows (TT, TF, FT, FF) and n_per_cat columns,
# showing original image, GT label, PT pred, FT pred.
# """
# device=f'cuda:{args.gpu}'
# # 1) Collect all images & labels
# imgs, labels = [], []
# for x, y in dataloader:
# imgs.append(x)
# labels.append(y)
# imgs = torch.cat(imgs, dim=0).to(device) # (N, C, H, W)
# labels = torch.cat(labels, dim=0).squeeze().to(device) # (N,)
# # 2) Run both models to get logits
# clip_pt.eval(); clip_ft.eval()
# with torch.no_grad():
# logits_pt, _ = get_logits(clip_pt, dataloader, args, return_feats=False)
# logits_ft, _ = get_logits(clip_ft, dataloader, args, return_feats=False)
# # 3) Compute predictions and correctness masks
# p_pt = torch.softmax(logits_pt, dim=1)
# p_ft = torch.softmax(logits_ft, dim=1)
# pred_pt = p_pt.argmax(dim=1)
# pred_ft = p_ft.argmax(dim=1)
# correct_pt = pred_pt.eq(labels)
# correct_ft = pred_ft.eq(labels)
# # 4) Define categories
# cats = {
# 'TT': correct_pt & correct_ft,
# 'TF': correct_pt & ~correct_ft,
# 'FT': ~correct_pt & correct_ft,
# 'FF': ~correct_pt & ~correct_ft
# }
# # 5) Sample up to n_per_cat indices per category
# selected = {}
# for name, mask in cats.items():
# idxs = mask.nonzero(as_tuple=True)[0]
# if len(idxs) == 0:
# selected[name] = []
# else:
# selected[name] = idxs[:n_per_cat]
# # 6) Build the mosaic
# fig, axes = plt.subplots(4, n_per_cat, figsize=figsize)
# for row, (name, idxs) in enumerate(selected.items()):
# for col in range(n_per_cat):
# ax = axes[row, col]
# ax.axis('off')
# if col < len(idxs):
# idx = idxs[col].item()
# img = imgs[idx].cpu().permute(1, 2, 0).numpy()
# # if normalized, denormalize here...
# ax.imshow(img)
# gt = labels[idx].item()
# pt = pred_pt[idx].item()
# ft = pred_ft[idx].item()
# ax.set_title(f"{name}\nGT:{gt} PT:{pt} FT:{ft}", fontsize=8)
# else:
# ax.set_facecolor('lightgray')
# plt.tight_layout()
# os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
# fig.savefig(save_path, dpi=300)
# plt.close(fig)
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.ticker import MaxNLocator, FormatStrFormatter
def js_divergence(p: np.ndarray, q: np.ndarray, axis=None) -> float:
    """
    Compute the Jensen-Shannon divergence (natural log) between p and q.

    The mixture ``m = (p + q) / 2`` is formed first. Then p, q and m are
    clipped to [1e-12, 1] so that zero entries never cause log(0) or a
    division by zero.

    Args:
        p, q: probability vectors of matching shape.
        axis: axis to sum over. The default ``None`` sums over every element
            and returns one scalar, as before. Pass ``axis=-1`` to get one
            divergence per row of batched ``(N, C)`` inputs.

    Returns:
        The divergence: a NumPy scalar, or an array when *axis* is given for
        multi-dimensional inputs.
    """
    m = 0.5 * (p + q)
    # Use small epsilon to avoid division by zero
    p_safe = np.clip(p, 1e-12, 1)
    q_safe = np.clip(q, 1e-12, 1)
    m_safe = np.clip(m, 1e-12, 1)
    kl_pm = np.sum(p_safe * np.log(p_safe / m_safe), axis=axis)
    kl_qm = np.sum(q_safe * np.log(q_safe / m_safe), axis=axis)
    return 0.5 * (kl_pm + kl_qm)
def plot_confidence_vs_js(
    P_pt: np.ndarray,
    P_ft: np.ndarray,
    save_path: str
) -> None:
    """
    Scatter combined confidence against JS divergence for paired predictions.

    Points are split into agreement and disagreement, according to whether
    the two models share the same argmax class. Combined confidence is the
    mean of the two models' max probabilities. Divergence is
    ``js_divergence`` of each pair of rows.

    Dashed grey lines mark the smallest combined confidence and the smallest
    divergence among disagreeing samples. They are omitted when the models
    never disagree.

    Args:
        P_pt (np.ndarray): Pre-trained model probabilities, shape (N, C).
        P_ft (np.ndarray): Fine-tuned model probabilities, shape (N, C).
        save_path (str): File path where the figure will be saved (its
            directory is created if needed).
    """
    def to_np(x):
        return x.cpu().numpy() if hasattr(x, 'cpu') else np.asarray(x)

    # Convert to numpy
    P_pt = to_np(P_pt)
    P_ft = to_np(P_ft)

    # Compute combined confidence
    conf_pt = P_pt.max(axis=1)
    conf_ft = P_ft.max(axis=1)
    combined_confidence = 0.5 * (conf_pt + conf_ft)

    # Compute JS divergence for each sample
    js_values = np.array([js_divergence(p, q) for p, q in zip(P_pt, P_ft)])

    # Determine agreement vs. disagreement
    agree = np.argmax(P_pt, axis=1) == np.argmax(P_ft, axis=1)
    disagree = ~agree

    # Prepare colors
    disagree_color = sns.color_palette("Blues", 2)[1]  # dark blue
    agree_color = "violet"

    # Set up figure
    fig, ax = plt.subplots(figsize=(5, 5))

    # Scatter
    ax.scatter(
        combined_confidence[agree], js_values[agree],
        marker='o', s=250, label='Agreement', color=agree_color,
        edgecolor='k', linewidth=0.75, alpha=0.5
    )
    ax.scatter(
        combined_confidence[disagree], js_values[disagree],
        marker='P', s=250, label='Disagreement', color=disagree_color,
        edgecolor='k', linewidth=0.75, alpha=0.5
    )

    # Threshold lines at the lowest confidence / divergence among disagreeing
    # samples. If the models always agree there is no such boundary, and
    # .min() on the empty selection would raise ValueError.
    if disagree.any():
        ax.axvline(x=combined_confidence[disagree].min(), linestyle='--', color='gray')
        ax.axhline(y=js_values[disagree].min(), linestyle='--', color='gray')

    # Axis limits with margin
    x_min, x_max = combined_confidence.min(), combined_confidence.max()
    y_min, y_max = js_values.min(), js_values.max()
    x_margin = (x_max - x_min) * 0.05
    y_margin = (y_max - y_min) * 0.05
    ax.set_xlim(x_min - x_margin, x_max + x_margin)
    ax.set_ylim(y_min - y_margin, y_max + y_margin)
    ax.xaxis.set_major_locator(MaxNLocator(6))
    ax.yaxis.set_major_locator(MaxNLocator(6))
    ax.xaxis.set_major_formatter(FormatStrFormatter('%.2f'))
    ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))

    # Aesthetics: no inner grid, ticks on bottom/left only
    ax.set_facecolor('white')
    ax.xaxis.set_tick_params(which='both', bottom=True, top=False, labelbottom=True, labelsize=13)
    ax.yaxis.set_tick_params(which='both', left=True, right=False, labelleft=True, labelsize=13)
    for spine in ax.spines.values():
        spine.set_visible(True)

    # Axis labels with bold mathbf and larger font
    ax.set_xlabel(r'$\mathbf{Combined\ Confidence\ }$'+"\n"+r'$\mathbf{=\ \frac{1}{2}(\max_i\ p_{pt}^{(i)}\ +\ \max_i\ p_{ft}^{(i)})}$', fontsize=13)
    ax.set_ylabel(r'$\mathbf{Divergence\ }$'+"\n"+r'$\mathbf{=\ \frac{1}{2}[KL(P_{pt}\|M)\ +\ KL(P_{ft}\|M)]}$', fontsize=13)

    ax.legend(fontsize=12, frameon=False, loc='best')

    # Ensure directory exists and save
    os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close(fig)