# metrics/event_based_metrics.py
import numpy as np
import matplotlib.pyplot as plt
# Event-based metrics
# Args:
#   y_true (ndarray): Ground truth labels (0 or 1).
#   y_pred (ndarray): Predicted labels (0 or 1).
#   tolerance (int): Slack, in samples, added around each true event when matching.
#   overlap_threshold (float): Minimum overlap rate for a match (default 0.7).
#   switch (bool): If True, invert the predicted labels before matching.
# Returns: (F, IOU, counted_events, fake_events, undetected_events)
def event_metrics(y_true, y_pred, tolerance, overlap_threshold=0.7, switch=False):
    if switch:
        # Invert the predictions (0 <-> 1) before binarizing
        y_pred = (1 - y_pred) > 0
    else:
        y_pred = y_pred > 0
# Create empty list for storing true events
true_events = []
    # Initialize start index
start = None
for i, label in enumerate(y_true):
if label == 1 and start is None:
start = i
elif label == 0 and start is not None:
true_events.append((start, i - 1))
start = None
if start is not None:
true_events.append((start, len(y_true) - 1))
    # Extract predicted events the same way
    pred_events = []
    start = None
for i, label in enumerate(y_pred):
if label == 1 and start is None:
start = i
elif label == 0 and start is not None:
pred_events.append((start, i - 1))
start = None
if start is not None:
        pred_events.append((start, len(y_pred) - 1))
    # Match predicted events against true events
    # Initialize true positive, false positive, and false negative counts
    tp, fp, fn = 0, 0, 0
counted_events = []
fake_events = []
undetected_events = []
    # Working copy of predicted events not yet matched to a true event
    pred_check = pred_events[:]
iou_list = []
for true_event in true_events:
tp_event = 0
for pred_event in pred_events:
lower_bound = true_event[0] - tolerance
upper_bound = true_event[1] + tolerance
            # Compute the overlap rate only when the predicted event lies
            # entirely within the tolerance-extended bounds of the true event
            if lower_bound <= pred_event[0] and upper_bound >= pred_event[1]:
                overlap_start = max(true_event[0], pred_event[0])
                overlap_end = min(true_event[1], pred_event[1])
                overlap_length = overlap_end - overlap_start + 1
                true_length = true_event[1] - true_event[0] + 1
                pred_length = pred_event[1] - pred_event[0] + 1
                overlap_rate = overlap_length / min(true_length, pred_length)
                # Count the detection only if the overlap rate clears the threshold
                if overlap_rate >= overlap_threshold:
                    union_start = min(true_event[0], pred_event[0])
                    union_end = max(true_event[1], pred_event[1])
                    union_length = union_end - union_start + 1
                    iou = overlap_length / union_length
                    iou_list.append(iou)
                    # True positive: mark this predicted event as matched
                    if pred_event in pred_check:
                        pred_check.remove(pred_event)
                    # Each true event contributes at most one true positive
                    if tp_event == 0:
                        tp_event = 1
                        counted_events.append((true_event[0], true_event[1]))
        # False negative: a true event with no matching predicted event
if tp_event == 0:
fn += 1
undetected_events.append((true_event[0], true_event[1]))
tp += tp_event
    # False positive: unmatched predicted events longer than the tolerance
    for pred_event in pred_check:
        if pred_event[1] - pred_event[0] > tolerance:
            fp += 1
            fake_events.append((pred_event[0], pred_event[1]))
    # Degenerate case: no events at all counts as a perfect score
    if tp == 0 and fn == 0 and fp == 0:
        F = 1
else:
# Calculation of F-Score
P = tp / (tp + fp) if (tp + fp) != 0 else 0
R = tp / (tp + fn) if (tp + fn) != 0 else 0
F = 2 * P * R / (P + R) if (P + R) != 0 else 0
    # Mean IoU over all matched events
    if not iou_list:
        IOU = 0
    else:
        IOU = np.mean(iou_list)
return F, IOU, counted_events, fake_events, undetected_events
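

# Aside (a sketch, not part of the original file): the two run-extraction
# loops above are identical in shape, and the same (start, end) pairs can
# be computed with numpy alone. The helper name below is made up for
# illustration.
def _extract_events_np(labels):
    # Pad with zeros so runs touching either array boundary are closed off
    labels = np.asarray(labels).astype(int)
    padded = np.concatenate(([0], labels, [0]))
    diff = np.diff(padded)
    starts = np.flatnonzero(diff == 1)      # 0 -> 1 transitions: run starts
    ends = np.flatnonzero(diff == -1) - 1   # 1 -> 0 transitions: inclusive ends
    return list(zip(starts, ends))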
# Visualize matched, spurious, and missed events against the label sequences
def event_visualization(y_true, y_pred, counted_events, fake_events, undetected_events):
    # Create the plot
plt.figure(figsize=(10, 6))
plt.plot(range(len(y_true)), y_true, label='True Label')
plt.plot(range(len(y_pred)), y_pred, label='Predicted Label')
    # Shade each event span; label only the first of each type so the
    # legend does not repeat entries
    for i, event in enumerate(counted_events):
        plt.axvspan(event[0], event[1], alpha=0.3, color='green',
                    label='Overlap event' if i == 0 else None)
    for i, event in enumerate(fake_events):
        plt.axvspan(event[0], event[1], alpha=0.3, color='red',
                    label='Fake event' if i == 0 else None)
    for i, event in enumerate(undetected_events):
        plt.axvspan(event[0], event[1], alpha=0.3, color='blue',
                    label='Undetected event' if i == 0 else None)
# Add labels and title
plt.xlabel('Index')
plt.ylabel('Label')
plt.title('Overlapping Events Visualization')
plt.legend()
plt.grid(True)
plt.show()
# Reference: https://doi.org/10.3390/app6060162
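

if __name__ == "__main__":
    # Minimal usage sketch; the label arrays below are synthetic examples,
    # not data from the original repository.
    y_true = np.array([0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0])
    y_pred = np.array([0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0])
    F, IOU, counted, fake, undetected = event_metrics(
        y_true, y_pred, tolerance=1, overlap_threshold=0.7
    )
    print(f"F-score: {F:.3f}, mean IoU: {IOU:.3f}")
    event_visualization(y_true, y_pred, counted, fake, undetected)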