ece / test_resnet-cifar_logits.py
jordyvl's picture
fix for equal mass binning
0a06e4b
"""
This test script loads real probabilistic predictions from a ResNet fine-tuned on CIFAR.
There are a number of logits-groundtruth pickles available @ https://github.com/markus93/NN_calibration/tree/master/logits
[The files seem to have moved from Git LFS to SharePoint:]
https://tartuulikool-my.sharepoint.com/:f:/g/personal/markus93_ut_ee/EmW0xbhcic5Ou0lRbTrySOUBF2ccSsN7lo6lvSfuG1djew?e=l0TErb
See https://github.com/markus93/NN_calibration/blob/master/logits/Readme.txt to decode the [model_dataset] filenames
As a follow-up, one could apply temperature scaling and measure ECE again after calibration.
"""
import sys
import numpy as np
import scipy.stats as stats
from scipy.special import softmax
import pickle
from sklearn.model_selection import train_test_split
from matplotlib import pyplot as plt
from ece import create_bins, discretize_into_bins, ECE
# Open file with pickled variables
def unpickle_probs(file, verbose=0, normalize=True, *, val_size=5000, random_state=15):
    """Load validation and test predictions plus labels from a logits pickle.

    The pickle must contain a pair ``(y1, y2)`` in one of two layouts:

    * pre-split: ``y1 == (val_scores, val_labels)`` and
      ``y2 == (test_scores, test_labels)``;
    * unsplit (pretrained models): ``y1`` holds all scores and ``y2`` all
      labels. A seeded random split (``train_test_split``) keeps ``val_size``
      samples for validation and puts the remainder in the test set.

    Security: ``pickle.load`` can execute arbitrary code; only load files
    from a trusted source.

    Args:
        file: path to the pickle file.
        verbose: when truthy, print the shapes of the four arrays.
        normalize: when True, apply a softmax over the last axis, turning
            logits into probabilities.
        val_size: number of validation samples for the unsplit layout
            (default 5000, the previously hard-coded value).
        random_state: seed of the random split for the unsplit layout.

    Returns:
        ``((val_probs, val_labels), (test_probs, test_labels))`` with both
        label arrays flattened to 1-D.

    Raises:
        ValueError: unsplit layout and ``val_size`` is not strictly between 0
            and the number of samples.
    """
    with open(file, "rb") as f:
        y1, y2 = pickle.load(f)  # trusted input only (see docstring)
    if isinstance(y1, tuple):
        y_probs_val, y_val = y1
        y_probs_test, y_test = y2
    else:
        n_samples = len(y2)
        if not 0 < val_size < n_samples:
            raise ValueError(
                f"val_size={val_size} must be between 1 and {n_samples - 1} "
                f"for an unsplit pickle with {n_samples} samples"
            )
        # The "train" part of the split (val_size samples) becomes the
        # validation set; the remaining samples form the test set.
        y_probs_val, y_probs_test, y_val, y_test = train_test_split(
            y1,
            y2.reshape(-1, 1),
            test_size=n_samples - val_size,
            random_state=random_state,
        )
    if normalize:
        y_probs_val = softmax(y_probs_val, -1)
        y_probs_test = softmax(y_probs_test, -1)
    if verbose:
        print("y_probs_val:", y_probs_val.shape)  # validation probabilities
        print("y_true_val:", y_val.shape)  # validation labels
        print("y_probs_test:", y_probs_test.shape)  # test probabilities
        print("y_true_test:", y_test.shape)  # test labels
    return ((y_probs_val, y_val.ravel()), (y_probs_test, y_test.ravel()))
def unpickle_structured_probs(valpath=None, testpath=None):
    """Load validation/test scores of a structured-prediction model as log-scores.

    Each pickle must hold a 4-tuple. Its 1st element (scores) and 3rd element
    (labels) are used; the other two are ignored. The stored scores were
    exponentiated, so ``np.log`` maps them back to log space ("structured
    logits"). Zero-valued scores become ``-inf``.

    Security: ``pickle.load`` can execute arbitrary code; only load files
    from a trusted source.

    Args:
        valpath: validation pickle path. None selects the machine-specific
            default (a BERT CoNLL-2003 validation file).
        testpath: test pickle path. None selects the machine-specific default
            (the matching test file).

    Returns:
        ``((X_val, y_val), (X_test, y_test))``, the same nesting as
        ``unpickle_probs``. Labels are returned unchanged.
    """
    # Machine-specific fallbacks, used only when no path is supplied.
    if valpath is None:
        valpath = "/home/jordy/code/gordon/arkham/arkham/StructuredCalibration/models/jordyvl/bert-base-cased_conll2003-sm-first-ner_validation_UTY.pickle"
    if testpath is None:
        testpath = "/home/jordy/code/gordon/arkham/arkham/StructuredCalibration/models/jordyvl/bert-base-cased_conll2003-sm-first-ner_test_UTY.pickle"
    with open(valpath, "rb") as f:
        X_val, _, y_val, _ = pickle.load(f)
    with open(testpath, "rb") as f:
        X_test, _, y_test, _ = pickle.load(f)
    # Stored scores were exponentiated (for other purposes); undo that here.
    X_val = np.log(X_val)
    X_test = np.log(X_test)
    # structured logits
    return ((X_val, y_val), (X_test, y_test))
"""
ALTERNATE equal mass binning
"""
# Define data types used by the alternative equal-mass binning helpers.
from typing import List, Tuple, NewType, TypeVar

Data = List[Tuple[float, float]]  # List of (predicted_probability, true_label).
Bins = List[float]  # List of bin boundaries, excluding 0.0, but including 1.0.
BinnedData = List[Data]  # binned_data[i] contains the data in bin i.
T = TypeVar('T')
eps = 1e-6  # small numerical tolerance; not referenced by the helpers below


def split(sequence: List[T], parts: int) -> List[List[T]]:
    """Partition ``sequence`` into exactly ``parts`` contiguous, non-empty chunks.

    Chunk sizes differ by at most one. The first ``len(sequence) % parts``
    chunks hold one extra element, the same layout as ``numpy.array_split``.
    Concatenating the chunks reproduces ``sequence``.

    A fixed ``ceil(len / parts)`` stride is deliberately avoided. It can yield
    fewer chunks than requested (10 items into 6 parts gives 5 chunks of 2),
    which would silently reduce the number of bins built from the result.

    Raises:
        AssertionError: if ``parts`` is not in ``1..len(sequence)``.
    """
    assert parts > 0, "need at least one part"
    assert parts <= len(sequence), "more bins than probabilities"
    base, extra = divmod(len(sequence), parts)
    chunks = []
    start = 0
    for i in range(parts):
        # The first `extra` chunks absorb the remainder, one element each.
        stop = start + base + (1 if i < extra else 0)
        chunks.append(sequence[start:stop])
        start = stop
    return chunks
def get_equal_bins(probs: List[float], n_bins: int = 10) -> Bins:
    """Return upper bin edges chosen so each bin holds roughly as many points.

    The sorted probabilities are cut into (at most) ``n_bins`` contiguous
    groups by ``split``. Each interior edge is the midpoint between the last
    value of one group and the first value of the next, and 1.0 closes the
    final bin. Duplicate edges, which appear when many probabilities tie, are
    removed, so the result may contain fewer than ``n_bins`` edges.
    """
    groups = split(sorted(probs), n_bins)
    edges = [
        (left[-1] + right[0]) / 2.0
        for left, right in zip(groups[:-1], groups[1:])
    ]
    edges.append(1.0)
    # Deduplicate, then return the edges in ascending order.
    return sorted(set(edges))
def histedges_equalN(x, nbin):
    """Return ``nbin + 1`` histogram edges giving bins of roughly equal population.

    The sorted data are linearly interpolated at ``nbin + 1`` evenly spaced
    positions from 0 to ``len(x)``. Positions beyond the last index are
    clamped to ``max(x)`` by ``np.interp``.
    """
    n_points = len(x)
    positions = np.linspace(0, n_points, nbin + 1)
    indices = np.arange(n_points)
    sorted_values = np.sort(x)
    return np.interp(positions, indices, sorted_values)
'''
bin_upper_edges = histedges_equalN(P, n_bins)
#n, bins, patches = plt.hist(x, histedges_equalN(x, 10))
'''
def test_equalmass_binning(P, Y, n_bins=10, show_plot=False):
    """Compare equal-mass and equal-range binning, then print ECE for both.

    Args:
        P: predicted probabilities. When 2-D (presumably samples x classes,
            TODO confirm against ece.ECE), the top-1 probability per row is
            binned with both schemes and the per-bin counts are printed.
        Y: ground-truth labels, passed unchanged to ``ECE()._compute``.
        n_bins: number of bins for both schemes (default 10).
        show_plot: when True, plot a histogram of the equal-mass bin indices
            and display it.
    """
    kwargs = dict(
        n_bins=n_bins,
        scheme="equal-mass",
        bin_range=None,
        proxy="upper-edge",  # alternative: proxy="center"
        p=1,
        detail=True,
    )
    if P.ndim == 2:
        # Top-1 probability per row, a confidence in [0, 1].
        p_max = np.max(P, -1)
        eqr_bins = create_bins(n_bins=n_bins, scheme="equal-range", bin_range=None, P=p_max)
        eqm_bins = create_bins(n_bins=n_bins, scheme="equal-mass", bin_range=None, P=p_max)
        # Cross-check option: get_equal_bins(p_max, n_bins) builds equal-mass edges independently.
        # Bin index per sample with right-closed intervals: edges[i-1] < p <= edges[i].
        eqr_hist = np.digitize(p_max, eqr_bins, right=True)
        eqm_hist = np.digitize(p_max, eqm_bins, right=True)
        # Result unused; kept as a smoke test of discretize_into_bins on the equal-mass edges.
        discretize_into_bins(np.expand_dims(p_max, 0), eqm_bins)
        # Compare bin *counts*. A chi-square on the per-sample index arrays is not
        # meaningful, and scipy's power_divergence rejects inputs whose sums differ.
        # Equal-mass counts should come out roughly uniform.
        print("eqr bin counts:", np.bincount(eqr_hist, minlength=np.size(eqr_bins) + 1))
        print("eqm bin counts:", np.bincount(eqm_hist, minlength=np.size(eqm_bins) + 1))
        if show_plot:
            plt.hist(eqm_hist, color="blue", label="equal-mass")
            plt.legend()
            plt.show()
    res = ECE()._compute(P, Y, **kwargs)
    print(f"eqm ECE: {res['ECE']}")
    # Same settings, only the binning scheme differs.
    res = ECE()._compute(P, Y, **{**kwargs, "scheme": "equal-range"})
    print(f"eqr ECE: {res['ECE']}")
if __name__ == "__main__":
FILE_PATH = sys.argv[1] if len(sys.argv) > 1 else "resnet110_c10_logits.p"
(p_val, y_val), (p_test, y_test) = unpickle_probs(FILE_PATH, False, True)
test_equalmass_binning(p_val, y_val)
# do on val