"""
Segmentation metric code adapted from the code for XView2: A Strong Baseline:
Xview2_Strong_Baseline/legacy/xview2_metrics.py
Xview2_Strong_Baseline/legacy/create_masks.py
"""
# add python path
# import sys
# import os
# sys.path.append('/deep/u/emily712/aicc-win24-geo-vlm/videollava/')
import json
import string
import numpy as np
import cv2
from collections import defaultdict, Counter
from nltk.tokenize import word_tokenize
from shapely.geometry import Polygon
from pathlib import Path
from sklearn.metrics import f1_score
from tqdm import tqdm


def compute_tp_fn_fp(pred: np.ndarray, targ: np.ndarray, c: int):
    """
    Computes the number of TPs, FNs, and FPs between a prediction (pred) and a target (targ) for the desired class (c).
    Args:
        pred (np.ndarray): predicted class map
        targ (np.ndarray): target class map
        c (int): positive class
    Returns:
        list: [TP, FN, FP] counts for class c
    """
TP = np.logical_and(pred == c, targ == c).sum()
FN = np.logical_and(pred != c, targ == c).sum()
FP = np.logical_and(pred == c, targ != c).sum()
return [TP, FN, FP]
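
# Illustrative usage (a minimal sketch; the toy arrays below are made up for illustration):
# pred = np.array([[1, 0], [1, 1]])
# targ = np.array([[1, 1], [0, 1]])
# compute_tp_fn_fp(pred, targ, c=1)  # -> [2, 1, 1]: 2 TPs, 1 FN (class-1 pixel predicted as 0), 1 FP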


def accuracy_precision_recall(answer_path, dataset, ignore_punctuation=True, verbose=True):
    # answer_path can be a dict of results or a path to a JSON answers file
    if isinstance(answer_path, dict):
        results = answer_path
    else:
        with open(answer_path) as json_data:
            results = json.load(json_data)
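    # Each entry in `results` maps an example id to a dict with "predicted",
    # "ground_truth", and (optionally) "task" fields.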
task_total = defaultdict(int)
task_tp = defaultdict(int)
binary_classification = defaultdict(bool)
binary_fp = defaultdict(int)
binary_fn = defaultdict(int)
# Dictionary of dictionaries. Key: task. Value: {class: count}
ground_truths = defaultdict(dict)
values = defaultdict(list)
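    # Only results whose task starts with one of these prefixes are evaluated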
accepted_tasks = [
"temporal_question_answering",
"region_based_question_answering",
"temporal_region_based_question_answering",
"question_answering",
"temporal_referring_expression",
"rural_urban",
"comp",
"presence",
"count",
"change_to_what",
"smallest_change",
"change_or_not",
"change_ratio",
"largest_change",
"change_ratio_types",
"increase_or_not",
"decrease_or_not"
]
for result in results.values():
if "task" in result and not any(result["task"].startswith(task) for task in accepted_tasks):
continue
# Clean predicted string if necessary
result["predicted"] = result["predicted"].lower()
result["ground_truth"] = result["ground_truth"].lower()
if ignore_punctuation:
result["predicted"] = ''.join(ch for ch in result["predicted"] if ch not in string.punctuation)
result["ground_truth"] = ''.join(ch for ch in result["ground_truth"] if ch not in string.punctuation)
if verbose:
values["predicted"].append(result["predicted"])
values["ground_truth"].append(result["ground_truth"])
values["correct_incorrect"].append("Correct" if result["predicted"] == result["ground_truth"] else "Incorrect")
if "task" not in result:
result["task"] = dataset
# True positive
if result["predicted"] == result["ground_truth"]:
task_tp[result["task"]] += 1
task_total[result["task"]] += 1
# If binary classification (yes/no question), calculate precision and recall metrics
binary_classification[result["task"]] = binary_classification[result["task"]] or (result["ground_truth"] in ["yes", "no"])
if binary_classification[result["task"]]:
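            # "yes" is treated as the positive class: a non-"no" prediction on a "no"
            # ground truth counts as a false positive, and a non-"yes" prediction on a
            # "yes" ground truth counts as a false negative.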
if result["predicted"] != "no" and result["ground_truth"] == "no":
binary_fp[result["task"]] += 1
if result["predicted"] != "yes" and result["ground_truth"] == "yes":
binary_fn[result["task"]] += 1
# Update ground truth counts for the task
task = result["task"]
class_label = result["ground_truth"]
ground_truths[task][class_label] = ground_truths[task].get(class_label, 0) + 1
    # Print tab-separated values
    if verbose:
        max_len = max(len(v) for v in values["ground_truth"]) + 5
        print("Predicted".ljust(max_len) + "\t" + "Ground Truth".ljust(max_len) + "\tCorrect/Incorrect")
        for i in range(len(values["predicted"])):
            print(values["predicted"][i].ljust(max_len) + "\t" + values["ground_truth"][i].ljust(max_len) + "\t" + values["correct_incorrect"][i])
total_tp = 0
total_predictions = 0
for task in task_tp:
acc_string = "Accuracy"
if ignore_punctuation:
acc_string += " (ignoring punctuation)"
print(f"{acc_string} for {task}: {round((task_tp[task] / task_total[task]), 4) * 100}%")
if binary_classification[task]:
if (task_tp[task] + binary_fp[task]) > 0:
print(f"Precision (ignoring punctuation) for {task}: {round((task_tp[task] / (task_tp[task] + binary_fp[task])), 3) * 100}%")
if (task_tp[task] + binary_fn[task]) > 0:
print(f"Recall (ignoring punctuation) for {task}: {round((task_tp[task] / (task_tp[task] + binary_fn[task])), 3) * 100}%")
majority_class = max(ground_truths[task], key=ground_truths[task].get)
majority_class_percentage = (ground_truths[task][majority_class] / task_total[task]) * 100
print(f"Majority class for {task}: {majority_class}, Percentage: {round(majority_class_percentage, 4)}%")
total_tp += task_tp[task]
total_predictions += task_total[task]
if total_predictions == 0:
print("No predictions made.")
else:
total_accuracy = (total_tp / total_predictions) * 100
print(f"Overall Accuracy: {round(total_accuracy, 3)}%")


# For testing accuracy/precision/recall on a saved answers file without re-running inference
if __name__ == '__main__':
root_dir = '/deep/u/jirvin16/aicc/aicc-win24-geo-vlm/videollava/scripts/geovlm/eval/QFabric/answers/'
answer_path = root_dir + "video-llava-7b-8bit-lora-final-no-metadata-zero-gc-acc8-freq-no-geochat-checkpoint-8000_qfabric_test_aux_data_test_prompt_strategy_interleave_chronological_prefix_True_load_8bit_True_load_4bit_False_delete_system_prompt_False.json"
accuracy_precision_recall(answer_path, dataset="qfabric", ignore_punctuation=True, verbose=False)
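
    # Illustrative sketch: accuracy_precision_recall also accepts an in-memory dict of
    # results (keyed by example id) instead of a path, e.g. for quick sanity checks.
    # The ids and values below are made up for illustration.
    # example_results = {
    #     "0": {"task": "presence", "predicted": "Yes.", "ground_truth": "yes"},
    #     "1": {"task": "presence", "predicted": "no", "ground_truth": "yes"},
    # }
    # accuracy_precision_recall(example_results, dataset="qfabric", ignore_punctuation=True, verbose=True)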