import json
from collections import defaultdict

import numpy as np
from shapely.wkt import loads
from tqdm import tqdm

from infer_utils import create_mask


def clean_string(s):
    """Normalize a label: spaces to hyphens, periods removed, lowercased."""
    return s.replace(' ', '-').replace('.', '').lower()


def get_class_dict(dataset):
    if dataset == "qfabric":
        class_dict = {
            "temporal_region_based_question_answering: What is the development status in this region [bbox] in image N?":
            {
                "prior-construction": 1,
| "greenland ": 2, | |
| "land-cleared": 3, | |
| "excavation": 4, | |
| "materials-dumped": 5, | |
| "construction-started": 6, | |
| "construction-midway": 7, | |
| "construction-done": 8, | |
| "operational": 9 | |
| }, | |
| "region_based_question_answering: Identify the type of urban development that has occurred in this area [bbox].": | |
| { | |
| "residential": 10, | |
| "commercial": 11, | |
| "industrial": 12, | |
| "road": 13, | |
| "demolition": 14, | |
| "mega-projects": 15 | |
| } | |
| } | |
| elif dataset == "xbd": | |
| class_dict = { | |
| "classification: Classify the level of damage experienced by the building at location [bbox] in the second image. Choose from: No damage, Minor Damage, Major Damage, Destroyed.": | |
| { | |
| "no-damage": 1, | |
| "minor-damage": 2, | |
| "major-damage": 3, | |
| "destroyed": 4, | |
| } | |
| } | |
| else: | |
| raise ValueError(f"Dataset {dataset} should not be evaluated on segmentation classification.") | |
| return class_dict | |
def classification_segmentation(answer_path, dataset, per_class_f1=False, height=256, width=256):
    """
    Given the path to the answer file, rasterize each example's original polygon into
    predicted and ground-truth segmentation masks and accumulate per-pixel statistics.

    Returns a dict mapping each task prompt to a tuple of
    (per-class F1 scores, class-weighted per-pixel F1).
    """
    with open(answer_path) as f:
        results = json.load(f)
    classes = get_class_dict(dataset)
    class_stats = defaultdict(lambda: {'tp': 0, 'fp': 0, 'fn': 0, 'count': 0})
    # Accumulate per-pixel TP / FP / FN counts for every class label across all results.
    for result in tqdm(results.values()):
        if result['task'] not in classes:
            continue
        class_dict = classes[result['task']]
        predicted_class = clean_string(result['predicted'])
        try:
            ground_truth_class = clean_string(result["ground_truth"])
        except KeyError:
            ground_truth_class = clean_string(result["original_answer"])
        original_polygon = loads(result['original_input_polygon'])
        pred_msk = np.zeros((height, width), dtype='uint8')
        gt_msk = np.zeros((height, width), dtype='uint8')
        _msk = create_mask(original_polygon, im_size=(height, width))
        if predicted_class not in class_dict or ground_truth_class not in class_dict:
            continue
        pred_label = class_dict[predicted_class]
        gt_label = class_dict[ground_truth_class]
        pred_msk[_msk > 0] = pred_label
        gt_msk[_msk > 0] = gt_label
        for label in class_dict.values():
            pred_mask = (pred_msk == label)
            gt_mask = (gt_msk == label)
            tp = np.sum(pred_mask & gt_mask)
            fp = np.sum(pred_mask & ~gt_mask)
            fn = np.sum(~pred_mask & gt_mask)
            class_stats[label]['tp'] += tp
            class_stats[label]['fp'] += fp
            class_stats[label]['fn'] += fn
            class_stats[label]['count'] += np.sum(gt_mask)
    scores_dict = {}
    for task, class_info in classes.items():
        print(f"Task: {task}")
        class_f1_scores = {}
        weighted_f1_score = 0
        total_weight = 0
        tp = 0
        fp = 0
        fn = 0
        # Total ground-truth pixels over all classes belonging to this task
        # (constant within the task, so computed once outside the class loop).
        total_samples = sum(s['count'] for label, s in class_stats.items()
                            if label in class_info.values())
        for class_name, class_label in class_info.items():
            stats = class_stats[class_label]
            if stats['tp'] + stats['fp'] == 0 or stats['tp'] + stats['fn'] == 0:
                f1 = 0.0
            else:
                precision = stats['tp'] / (stats['tp'] + stats['fp'])
                recall = stats['tp'] / (stats['tp'] + stats['fn'])
                if precision + recall == 0:
                    f1 = 0.0
                else:
                    f1 = 2 * (precision * recall) / (precision + recall)
            class_f1_scores[class_name] = f1
            if stats['count'] > 0:
                # Weight each class by inverse prevalence so rare classes are not drowned out.
                prevalence_inv = total_samples / stats['count']
                weighted_f1_score += f1 * prevalence_inv
                total_weight += prevalence_inv
            tp += stats['tp']
            fp += stats['fp']
            fn += stats['fn']
        if tp + fp == 0 or tp + fn == 0:
            micro_f1 = 0.0
        else:
            micro_f1 = tp / (tp + 0.5 * (fp + fn))
        if total_weight > 0:
            weighted_f1_score /= total_weight
        else:
            weighted_f1_score = 0.0
        scores_dict[task] = (class_f1_scores, weighted_f1_score)
        print(f"Per-class F1 scores: {class_f1_scores}")
        if dataset == 'qfabric':
            print(f"Micro average F1 score: {micro_f1}")
        else:
            print(f"Weighted average F1 score: {weighted_f1_score}")
    return scores_dict
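

# A minimal usage sketch, not part of the original module: how this evaluator might be
# run from the command line. The argument names and defaults below are assumptions; the
# answer JSON is expected to map ids to dicts carrying 'task', 'predicted', 'ground_truth'
# (or 'original_answer'), and 'original_input_polygon' (a WKT string), as consumed above.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Per-pixel F1 for segmentation-style classification answers.")
    parser.add_argument("answer_path", help="Path to the answer JSON file.")
    parser.add_argument("--dataset", choices=["qfabric", "xbd"], default="xbd")
    parser.add_argument("--height", type=int, default=256)
    parser.add_argument("--width", type=int, default=256)
    args = parser.parse_args()

    scores = classification_segmentation(args.answer_path, args.dataset,
                                          height=args.height, width=args.width)
    # Summarize the weighted F1 per task; float() guards against numpy scalar types.
    print(json.dumps({task: float(f1) for task, (_, f1) in scores.items()}, indent=2))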