""" |
|
Segmentation metric code dapted from code for XView2: A Strong Baseline |
|
Xview2_Strong_Baseline/legacy/xview2_metrics.py |
|
Xview2_Strong_Baseline/legacy/create_masks.py |
|
""" |
|
|
|
|
|
|
|
|
|
|
|
import json |
|
import string |
|
import numpy as np |
|
import cv2 |
|
from collections import defaultdict, Counter |
|
from nltk.tokenize import word_tokenize |
|
from shapely.geometry import Polygon |
|
from pathlib import Path |
|
from sklearn.metrics import f1_score |
|
from tqdm import tqdm |
|
|
|
|
|
def compute_tp_fn_fp(pred: np.ndarray, targ: np.ndarray, c: int):
    """
    Computes the number of TPs, FNs, and FPs between a prediction (pred) and a target (targ) for the desired class (c).

    Args:
        pred (np.ndarray): prediction
        targ (np.ndarray): target
        c (int): positive class
    """
    TP = np.logical_and(pred == c, targ == c).sum()
    FN = np.logical_and(pred != c, targ == c).sum()
    FP = np.logical_and(pred == c, targ != c).sum()
    return [TP, FN, FP]
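
# Illustrative sketch (not part of the original metrics code): the counts returned
# by compute_tp_fn_fp are typically folded into a per-class F1 score. The arrays
# below are hypothetical placeholders.
#
#   pred = np.array([[0, 1], [1, 2]])
#   targ = np.array([[0, 1], [2, 2]])
#   tp, fn, fp = compute_tp_fn_fp(pred, targ, c=1)
#   f1 = 2 * tp / (2 * tp + fn + fp) if (2 * tp + fn + fp) > 0 else 0.0
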
def accuracy_precision_recall(answer_path, dataset, ignore_punctuation=True, verbose=True):
    """
    Computes and prints per-task accuracy, precision/recall for yes/no tasks, the
    majority-class baseline, and overall accuracy for a results JSON.

    Args:
        answer_path: path to a results JSON file, or an already-loaded results dict
        dataset (str): task label assigned to results that carry no "task" field
        ignore_punctuation (bool): strip punctuation before comparing predictions to ground truths
        verbose (bool): print every prediction / ground truth pair
    """
    if isinstance(answer_path, dict):
        results = answer_path
    else:
        with open(answer_path) as json_data:
            results = json.load(json_data)

    # Per-task counters for accuracy, binary precision/recall, and class frequencies.
    task_total = defaultdict(int)
    task_tp = defaultdict(int)

    binary_classification = defaultdict(bool)
    binary_fp = defaultdict(int)
    binary_fn = defaultdict(int)

    ground_truths = defaultdict(dict)

    values = defaultdict(list)
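
    # `results` is assumed (based on the accesses below) to map example ids to dicts like:
    # {
    #     "<example_id>": {
    #         "task": "presence",          # optional; falls back to `dataset`
    #         "predicted": "yes",
    #         "ground_truth": "no"
    #     },
    #     ...
    # }
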
    accepted_tasks = [
        "temporal_question_answering",
        "region_based_question_answering",
        "temporal_region_based_question_answering",
        "question_answering",
        "temporal_referring_expression",
        "rural_urban",
        "comp",
        "presence",
        "count",
        "change_to_what",
        "smallest_change",
        "change_or_not",
        "change_ratio",
        "largest_change",
        "change_ratio_types",
        "increase_or_not",
        "decrease_or_not"
    ]

    for result in results.values():
        # Skip results whose task does not match any accepted task prefix.
        if "task" in result and not any(result["task"].startswith(task) for task in accepted_tasks):
            continue

        # Normalize predictions and ground truths before exact-match comparison.
        result["predicted"] = result["predicted"].lower()
        result["ground_truth"] = result["ground_truth"].lower()
        if ignore_punctuation:
            result["predicted"] = ''.join(ch for ch in result["predicted"] if ch not in string.punctuation)
            result["ground_truth"] = ''.join(ch for ch in result["ground_truth"] if ch not in string.punctuation)
        if verbose:
            values["predicted"].append(result["predicted"])
            values["ground_truth"].append(result["ground_truth"])
            values["correct_incorrect"].append("Correct" if result["predicted"] == result["ground_truth"] else "Incorrect")
        if "task" not in result:
            result["task"] = dataset

        # Exact-match accuracy counts.
        if result["predicted"] == result["ground_truth"]:
            task_tp[result["task"]] += 1
        task_total[result["task"]] += 1

        # Treat a task as binary once any yes/no ground truth is seen, and track its FPs/FNs.
        binary_classification[result["task"]] = binary_classification[result["task"]] or (result["ground_truth"] in ["yes", "no"])
        if binary_classification[result["task"]]:
            if result["predicted"] != "no" and result["ground_truth"] == "no":
                binary_fp[result["task"]] += 1
            if result["predicted"] != "yes" and result["ground_truth"] == "yes":
                binary_fn[result["task"]] += 1

        # Track ground-truth class frequencies for the majority-class baseline.
        task = result["task"]
        class_label = result["ground_truth"]
        ground_truths[task][class_label] = ground_truths[task].get(class_label, 0) + 1

    if verbose and values["predicted"]:
        # Column width based on the longest prediction/ground truth, so the table stays aligned.
        max_len = max(len(v) for v in values["predicted"] + values["ground_truth"]) + 5
        print("Predicted".ljust(max_len) + "\t" + "Ground Truth".ljust(max_len) + "\tCorrect/Incorrect")
        for pred, gt, correct in zip(values["predicted"], values["ground_truth"], values["correct_incorrect"]):
            print(pred.ljust(max_len) + "\t" + gt.ljust(max_len) + "\t" + correct)

    total_tp = 0
    total_predictions = 0
    suffix = " (ignoring punctuation)" if ignore_punctuation else ""
    # Iterate over task_total (not task_tp) so tasks with zero correct predictions are still reported.
    for task in task_total:
        # Round after scaling to a percentage so the printed value is not distorted.
        print(f"Accuracy{suffix} for {task}: {round(task_tp[task] / task_total[task] * 100, 2)}%")

        if binary_classification[task]:
            if (task_tp[task] + binary_fp[task]) > 0:
                print(f"Precision{suffix} for {task}: {round(task_tp[task] / (task_tp[task] + binary_fp[task]) * 100, 1)}%")
            if (task_tp[task] + binary_fn[task]) > 0:
                print(f"Recall{suffix} for {task}: {round(task_tp[task] / (task_tp[task] + binary_fn[task]) * 100, 1)}%")

        # Majority-class baseline for this task.
        majority_class = max(ground_truths[task], key=ground_truths[task].get)
        majority_class_percentage = ground_truths[task][majority_class] / task_total[task] * 100
        print(f"Majority class for {task}: {majority_class}, Percentage: {round(majority_class_percentage, 4)}%")

        total_tp += task_tp[task]
        total_predictions += task_total[task]

    if total_predictions == 0:
        print("No predictions made.")
    else:
        total_accuracy = (total_tp / total_predictions) * 100
        print(f"Overall Accuracy: {round(total_accuracy, 3)}%")
if __name__ == '__main__':
    root_dir = '/deep/u/jirvin16/aicc/aicc-win24-geo-vlm/videollava/scripts/geovlm/eval/QFabric/answers/'
    answer_path = root_dir + "video-llava-7b-8bit-lora-final-no-metadata-zero-gc-acc8-freq-no-geochat-checkpoint-8000_qfabric_test_aux_data_test_prompt_strategy_interleave_chronological_prefix_True_load_8bit_True_load_4bit_False_delete_system_prompt_False.json"
    accuracy_precision_recall(answer_path, dataset="qfabric", ignore_punctuation=True, verbose=False)