"""
Copyright (C) 2021 Microsoft Corporation
"""
import os
import sys
from collections import Counter
import json
import statistics as stat
from datetime import datetime
import multiprocessing
from itertools import repeat
from functools import partial
import tqdm
import math
import torch
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from fitz import Rect
from PIL import Image
sys.path.append("../detr")
import util.misc as utils
from ms_datasets.coco_eval import CocoEvaluator
import postprocess
import grits
from grits import grits_con, grits_top, grits_loc
structure_class_names = [
'table', 'table column', 'table row', 'table column header',
'table projected row header', 'table spanning cell', 'no object'
]
structure_class_map = {k: v for v, k in enumerate(structure_class_names)}
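# Minimum confidence score required to keep a detected object of each class.
# "no object" is given a threshold above 1 (softmax scores are at most 1),
# so detections of that class are effectively always discarded.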
structure_class_thresholds = {
"table": 0.5,
"table column": 0.5,
"table row": 0.5,
"table column header": 0.5,
"table projected row header": 0.5,
"table spanning cell": 0.5,
"no object": 10
}
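# Standard ImageNet channel means and standard deviations, matching the
# normalization used to train the DETR backbone.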
normalize = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
def objects_to_cells(bboxes, labels, scores, page_tokens, structure_class_names,
                     structure_class_thresholds, structure_class_map):
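    """
    Convert the raw detected objects (bounding boxes, labels, scores) for one
    table into its cell structure, using the page's text tokens to assign
    text content to the cells.
    """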
bboxes, scores, labels = postprocess.apply_class_thresholds(bboxes, labels, scores,
structure_class_names,
structure_class_thresholds)
table_objects = []
for bbox, score, label in zip(bboxes, scores, labels):
table_objects.append({'bbox': bbox, 'score': score, 'label': label})
table = {'objects': table_objects, 'page_num': 0}
table_class_objects = [obj for obj in table_objects if obj['label'] == structure_class_map['table']]
if len(table_class_objects) > 1:
table_class_objects = sorted(table_class_objects, key=lambda x: x['score'], reverse=True)
    try:
        table_bbox = list(table_class_objects[0]['bbox'])
    except IndexError:
        # No table was detected; fall back to a default page-sized box
        table_bbox = [0, 0, 1000, 1000]
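    # Keep only the tokens whose overlap with the table box (intersection over
    # the token box area) is at least 50%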
tokens_in_table = [token for token in page_tokens if postprocess.iob(token['bbox'], table_bbox) >= 0.5]
# Determine the table cell structure from the objects
table_structures, cells, confidence_score = postprocess.objects_to_cells(table, table_objects, tokens_in_table,
structure_class_names,
structure_class_thresholds)
return table_structures, cells, confidence_score
def cells_to_adjacency_pair_list(cells, key='cell_text'):
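    """
    Convert a list of cells into the directed adjacency relations
    (text pairs plus direction and distance) used by the DAR metric.
    Blank cells are skipped: each non-blank cell is paired with the nearest
    non-blank cell below it and to its right.
    """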
# Index the cells by their grid coordinates
cell_nums_by_coordinates = dict()
for cell_num, cell in enumerate(cells):
for row_num in cell['row_nums']:
for column_num in cell['column_nums']:
cell_nums_by_coordinates[(row_num, column_num)] = cell_num
# Count the number of unique rows and columns
row_nums = set()
column_nums = set()
for cell in cells:
for row_num in cell['row_nums']:
row_nums.add(row_num)
for column_num in cell['column_nums']:
column_nums.add(column_num)
num_rows = len(row_nums)
num_columns = len(column_nums)
# For each cell, determine its next neighbors
# - For every row the cell occupies, what is the first cell to the right with text that
# also occupies that row
# - For every column the cell occupies, what is the first cell below with text that
# also occupies that column
adjacency_list = []
adjacency_bboxes = []
for cell1_num, cell1 in enumerate(cells):
# Skip blank cells
        if cell1[key] == '':
continue
adjacent_cell_props = {}
max_column = max(cell1['column_nums'])
max_row = max(cell1['row_nums'])
# For every column the cell occupies...
for column_num in cell1['column_nums']:
# Start from the next row and stop when we encounter a non-blank cell
# This cell is considered adjacent
for current_row in range(max_row+1, num_rows):
cell2_num = cell_nums_by_coordinates[(current_row, column_num)]
cell2 = cells[cell2_num]
                if cell2[key] != '':
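                    # Record a thin vertical strip spanning the gap between the
                    # two cells, used to visualize the adjacency relation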
adj_bbox = [(max(cell1['bbox'][0], cell2['bbox'][0])+min(cell1['bbox'][2], cell2['bbox'][2]))/2-3,
cell1['bbox'][3],
(max(cell1['bbox'][0], cell2['bbox'][0])+min(cell1['bbox'][2], cell2['bbox'][2]))/2+3,
cell2['bbox'][1]]
adjacent_cell_props[cell2_num] = ('V', current_row - max_row - 1,
adj_bbox)
break
# For every row the cell occupies...
for row_num in cell1['row_nums']:
# Start from the next column and stop when we encounter a non-blank cell
# This cell is considered adjacent
for current_column in range(max_column+1, num_columns):
cell2_num = cell_nums_by_coordinates[(row_num, current_column)]
cell2 = cells[cell2_num]
                if cell2[key] != '':
adj_bbox = [cell1['bbox'][2],
(max(cell1['bbox'][1], cell2['bbox'][1])+min(cell1['bbox'][3], cell2['bbox'][3]))/2-3,
cell2['bbox'][0],
(max(cell1['bbox'][1], cell2['bbox'][1])+min(cell1['bbox'][3], cell2['bbox'][3]))/2+3]
adjacent_cell_props[cell2_num] = ('H', current_column - max_column - 1,
adj_bbox)
break
for adjacent_cell_num, props in adjacent_cell_props.items():
cell2 = cells[adjacent_cell_num]
            adjacency_list.append((cell1[key], cell2[key], props[0], props[1]))
adjacency_bboxes.append(props[2])
return adjacency_list, adjacency_bboxes
def cells_to_adjacency_pair_list_with_blanks(cells, key='cell_text'):
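    """
    Variant of cells_to_adjacency_pair_list in which blank cells are not
    skipped: every cell is paired with its immediate neighbors below and to
    the right.
    """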
# Index the cells by their grid coordinates
cell_nums_by_coordinates = dict()
for cell_num, cell in enumerate(cells):
for row_num in cell['row_nums']:
for column_num in cell['column_nums']:
cell_nums_by_coordinates[(row_num, column_num)] = cell_num
# Count the number of unique rows and columns
row_nums = set()
column_nums = set()
for cell in cells:
for row_num in cell['row_nums']:
row_nums.add(row_num)
for column_num in cell['column_nums']:
column_nums.add(column_num)
num_rows = len(row_nums)
num_columns = len(column_nums)
# For each cell, determine its next neighbors
# - For every row the cell occupies, what is the next cell to the right
# - For every column the cell occupies, what is the next cell below
adjacency_list = []
adjacency_bboxes = []
for cell1_num, cell1 in enumerate(cells):
adjacent_cell_props = {}
max_column = max(cell1['column_nums'])
max_row = max(cell1['row_nums'])
# For every column the cell occupies...
for column_num in cell1['column_nums']:
# The cell in the next row is adjacent
current_row = max_row + 1
if current_row >= num_rows:
continue
cell2_num = cell_nums_by_coordinates[(current_row, column_num)]
cell2 = cells[cell2_num]
adj_bbox = [(max(cell1['bbox'][0], cell2['bbox'][0])+min(cell1['bbox'][2], cell2['bbox'][2]))/2-3,
cell1['bbox'][3],
(max(cell1['bbox'][0], cell2['bbox'][0])+min(cell1['bbox'][2], cell2['bbox'][2]))/2+3,
cell2['bbox'][1]]
adjacent_cell_props[cell2_num] = ('V', current_row - max_row - 1,
adj_bbox)
# For every row the cell occupies...
for row_num in cell1['row_nums']:
# The cell in the next column is adjacent
current_column = max_column + 1
if current_column >= num_columns:
continue
cell2_num = cell_nums_by_coordinates[(row_num, current_column)]
cell2 = cells[cell2_num]
adj_bbox = [cell1['bbox'][2],
(max(cell1['bbox'][1], cell2['bbox'][1])+min(cell1['bbox'][3], cell2['bbox'][3]))/2-3,
cell2['bbox'][0],
(max(cell1['bbox'][1], cell2['bbox'][1])+min(cell1['bbox'][3], cell2['bbox'][3]))/2+3]
adjacent_cell_props[cell2_num] = ('H', current_column - max_column - 1,
adj_bbox)
for adjacent_cell_num, props in adjacent_cell_props.items():
cell2 = cells[adjacent_cell_num]
            adjacency_list.append((cell1[key], cell2[key], props[0], props[1]))
adjacency_bboxes.append(props[2])
return adjacency_list, adjacency_bboxes
def dar_con(true_adjacencies, pred_adjacencies):
"""
Directed adjacency relations (DAR) metric, which uses exact match
between adjacent cell text content.
"""
    true_c = Counter(true_adjacencies)
    pred_c = Counter(pred_adjacencies)
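    # True positives = size of the multiset intersection of the ground truth
    # and predicted adjacency relations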
num_true_positives = (sum(true_c.values()) - sum((true_c - pred_c).values()))
fscore, precision, recall = grits.compute_fscore(num_true_positives,
len(true_adjacencies),
len(pred_adjacencies))
return recall, precision, fscore
def dar_con_original(true_cells, pred_cells):
"""
Original DAR metric, where blank cells are disregarded.
"""
true_adjacencies, _ = cells_to_adjacency_pair_list(true_cells)
pred_adjacencies, _ = cells_to_adjacency_pair_list(pred_cells)
return dar_con(true_adjacencies, pred_adjacencies)
def dar_con_new(true_cells, pred_cells):
"""
New version of DAR metric where blank cells count.
"""
true_adjacencies, _ = cells_to_adjacency_pair_list_with_blanks(true_cells)
pred_adjacencies, _ = cells_to_adjacency_pair_list_with_blanks(pred_cells)
return dar_con(true_adjacencies, pred_adjacencies)
def compute_metrics(mode, true_bboxes, true_labels, true_scores, true_cells,
pred_bboxes, pred_labels, pred_scores, pred_cells):
"""
Compute the collection of table structure recognition metrics given
the ground truth and predictions as input.
- bboxes, labels, and scores are required to compute GriTS_RawLoc, which
is GriTS_Loc but on unprocessed bounding boxes, compared with the dilated
ground truth bounding boxes the model is trained on.
- Otherwise, only true_cells and pred_cells are needed.
"""
metrics = {}
# Compute grids/matrices for comparison
true_relspan_grid = np.array(grits.cells_to_relspan_grid(true_cells))
true_bbox_grid = np.array(grits.cells_to_grid(true_cells, key='bbox'))
true_text_grid = np.array(grits.cells_to_grid(true_cells, key='cell_text'), dtype=object)
pred_relspan_grid = np.array(grits.cells_to_relspan_grid(pred_cells))
pred_bbox_grid = np.array(grits.cells_to_grid(pred_cells, key='bbox'))
pred_text_grid = np.array(grits.cells_to_grid(pred_cells, key='cell_text'), dtype=object)
# Compute GriTS_Top (topology)
(metrics['grits_top'],
metrics['grits_precision_top'],
metrics['grits_recall_top'],
metrics['grits_top_upper_bound']) = grits_top(true_relspan_grid,
pred_relspan_grid)
# Compute GriTS_Loc (location)
(metrics['grits_loc'],
metrics['grits_precision_loc'],
metrics['grits_recall_loc'],
metrics['grits_loc_upper_bound']) = grits_loc(true_bbox_grid,
pred_bbox_grid)
# Compute GriTS_Con (text content)
(metrics['grits_con'],
metrics['grits_precision_con'],
metrics['grits_recall_con'],
metrics['grits_con_upper_bound']) = grits_con(true_text_grid,
pred_text_grid)
# Compute content accuracy
metrics['acc_con'] = int(metrics['grits_con'] == 1)
if mode == 'grits-all':
# Compute grids/matrices for comparison
true_cell_dilatedbbox_grid = np.array(grits.output_to_dilatedbbox_grid(true_bboxes, true_labels, true_scores))
pred_cell_dilatedbbox_grid = np.array(grits.output_to_dilatedbbox_grid(pred_bboxes, pred_labels, pred_scores))
# Compute GriTS_RawLoc (location using unprocessed bounding boxes)
(metrics['grits_rawloc'],
metrics['grits_precision_rawloc'],
metrics['grits_recall_rawloc'],
metrics['grits_rawloc_upper_bound']) = grits_loc(true_cell_dilatedbbox_grid,
pred_cell_dilatedbbox_grid)
# Compute original DAR (directed adjacency relations) metric
(metrics['dar_recall_con_original'], metrics['dar_precision_con_original'],
metrics['dar_con_original']) = dar_con_original(true_cells, pred_cells)
# Compute updated DAR (directed adjacency relations) metric
(metrics['dar_recall_con'], metrics['dar_precision_con'],
metrics['dar_con']) = dar_con_new(true_cells, pred_cells)
return metrics
def compute_statistics(structures, cells):
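    """
    Compute summary statistics for a table: row, column, cell, spanning cell,
    and header row counts, plus the coefficient of variation of the row
    heights and column widths.
    """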
statistics = {}
statistics['num_rows'] = len(structures['rows'])
statistics['num_columns'] = len(structures['columns'])
statistics['num_cells'] = len(cells)
statistics['num_spanning_cells'] = len([cell for cell in cells if len(cell['row_nums']) > 1
or len(cell['column_nums']) > 1])
header_rows = set()
for cell in cells:
if cell['header']:
header_rows = header_rows.union(set(cell['row_nums']))
statistics['num_header_rows'] = len(header_rows)
row_heights = [float(row['bbox'][3]-row['bbox'][1]) for row in structures['rows']]
if len(row_heights) >= 2:
statistics['row_height_coefficient_of_variation'] = stat.stdev(row_heights) / stat.mean(row_heights)
else:
statistics['row_height_coefficient_of_variation'] = 0
column_widths = [float(column['bbox'][2]-column['bbox'][0]) for column in structures['columns']]
if len(column_widths) >= 2:
statistics['column_width_coefficient_of_variation'] = stat.stdev(column_widths) / stat.mean(column_widths)
else:
statistics['column_width_coefficient_of_variation'] = 0
return statistics
# for output bounding box post-processing
def box_cxcywh_to_xyxy(x):
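    """Convert boxes from [center_x, center_y, width, height] to [x0, y0, x1, y1]."""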
x_c, y_c, w, h = x.unbind(1)
b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
return torch.stack(b, dim=1)
def rescale_bboxes(out_bbox, size):
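    """Scale normalized [cx, cy, w, h] boxes to absolute [x0, y0, x1, y1] pixel coordinates."""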
img_w, img_h = size
b = box_cxcywh_to_xyxy(out_bbox)
b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
return b
def get_bbox_decorations(data_type, label):
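    """Return the (color, alpha, linewidth, hatch) styling for drawing a class label."""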
if label == 0:
if data_type == 'detection':
return 'brown', 0.05, 3, '//'
else:
return 'brown', 0, 3, None
elif label == 1:
return 'red', 0.15, 2, None
elif label == 2:
return 'blue', 0.15, 2, None
elif label == 3:
return 'magenta', 0.2, 3, '//'
elif label == 4:
return 'cyan', 0.2, 4, '//'
elif label == 5:
return 'green', 0.2, 4, '\\\\'
return 'gray', 0, 0, None
def compute_metrics_summary(sample_metrics, mode):
"""
Print a formatted summary of the table structure recognition metrics
averaged over all samples.
"""
metrics_summary = {}
metric_names = ['acc_con', 'grits_top', 'grits_con', 'grits_loc']
if mode == 'grits-all':
metric_names += ['grits_rawloc', 'dar_con_original', 'dar_con']
simple_samples = [entry for entry in sample_metrics if entry['num_spanning_cells'] == 0]
metrics_summary['simple'] = {'num_tables': len(simple_samples)}
if len(simple_samples) > 0:
for metric_name in metric_names:
metrics_summary['simple'][metric_name] = np.mean([elem[metric_name] for elem in simple_samples])
complex_samples = [entry for entry in sample_metrics if entry['num_spanning_cells'] > 0]
metrics_summary['complex'] = {'num_tables': len(complex_samples)}
if len(complex_samples) > 0:
for metric_name in metric_names:
metrics_summary['complex'][metric_name] = np.mean([elem[metric_name] for elem in complex_samples])
metrics_summary['all'] = {'num_tables': len(sample_metrics)}
if len(sample_metrics) > 0:
for metric_name in metric_names:
metrics_summary['all'][metric_name] = np.mean([elem[metric_name] for elem in sample_metrics])
return metrics_summary
def print_metrics_line(name, metrics_dict, key, min_length=18):
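    """Print one right-aligned metric name and its value, or '--' if the metric is missing."""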
if len(name) < min_length:
name = ' '*(min_length-len(name)) + name
    try:
        print("{}: {:.4f}".format(name, metrics_dict[key]))
    except KeyError:
        print("{}: --".format(name))
def print_metrics_summary(metrics_summary, show_all=False):
"""
Print a formatted summary of the table structure recognition metrics
averaged over all samples.
"""
print('-' * 100)
for table_type in ['simple', 'complex', 'all']:
metrics = metrics_summary[table_type]
print("Results on {} tables ({} total):".format(table_type, metrics['num_tables']))
print_metrics_line("Accuracy_Con", metrics, 'acc_con')
print_metrics_line("GriTS_Top", metrics, 'grits_top')
print_metrics_line("GriTS_Con", metrics, 'grits_con')
print_metrics_line("GriTS_Loc", metrics, 'grits_loc')
        if show_all:
print_metrics_line("GriTS_RawLoc", metrics, 'grits_rawloc')
print_metrics_line("DAR_Con (original)", metrics, 'dar_con_original')
print_metrics_line("DAR_Con", metrics, 'dar_con')
print('-' * 50)
def eval_tsr_sample(target, pred_logits, pred_bboxes, mode):
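    """
    Compute the table structure recognition metrics for a single sample from
    its ground truth annotations and the model's raw predictions.
    """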
true_img_size = list(reversed(target['orig_size'].tolist()))
true_bboxes = target['boxes']
true_bboxes = [elem.tolist() for elem in rescale_bboxes(true_bboxes, true_img_size)]
true_labels = target['labels'].tolist()
    true_scores = [1] * len(true_labels)
img_words_filepath = target["img_words_path"]
with open(img_words_filepath, 'r') as f:
true_page_tokens = json.load(f)
true_table_structures, true_cells, _ = objects_to_cells(true_bboxes, true_labels, true_scores,
true_page_tokens, structure_class_names,
structure_class_thresholds, structure_class_map)
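    # For each object query, keep the most probable class and use its softmax
    # probability as the prediction score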
m = pred_logits.softmax(-1).max(-1)
pred_labels = list(m.indices.detach().cpu().numpy())
pred_scores = list(m.values.detach().cpu().numpy())
pred_bboxes = [elem.tolist() for elem in rescale_bboxes(pred_bboxes, true_img_size)]
_, pred_cells, _ = objects_to_cells(pred_bboxes, pred_labels, pred_scores,
true_page_tokens, structure_class_names,
structure_class_thresholds, structure_class_map)
metrics = compute_metrics(mode, true_bboxes, true_labels, true_scores, true_cells,
pred_bboxes, pred_labels, pred_scores, pred_cells)
statistics = compute_statistics(true_table_structures, true_cells)
metrics.update(statistics)
    metrics['id'] = os.path.splitext(os.path.basename(target["img_path"]))[0]
return metrics
def visualize(args, target, pred_logits, pred_bboxes):
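    """
    Save debug visualizations of the predicted bounding boxes (and, for
    structure models, the reconstructed cells) to args.debug_save_dir.
    """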
img_filepath = target["img_path"]
img_filename = img_filepath.split("/")[-1]
bboxes_out_filename = img_filename.replace(".jpg", "_bboxes.jpg")
bboxes_out_filepath = os.path.join(args.debug_save_dir, bboxes_out_filename)
img = Image.open(img_filepath)
img_size = img.size
m = pred_logits.softmax(-1).max(-1)
pred_labels = list(m.indices.detach().cpu().numpy())
pred_scores = list(m.values.detach().cpu().numpy())
pred_bboxes = pred_bboxes.detach().cpu()
pred_bboxes = [elem.tolist() for elem in rescale_bboxes(pred_bboxes, img_size)]
    fig, ax = plt.subplots(1)
ax.imshow(img, interpolation='lanczos')
for bbox, label, score in zip(pred_bboxes, pred_labels, pred_scores):
        # Only draw sufficiently confident detections of real object classes
        if (((args.data_type == 'structure' and label <= 5)
                or (args.data_type == 'detection' and label <= 1))
            and score > 0.5):
color, alpha, linewidth, hatch = get_bbox_decorations(args.data_type,
label)
# Fill
rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1],
linewidth=linewidth, alpha=alpha,
edgecolor='none',facecolor=color,
linestyle=None)
ax.add_patch(rect)
# Hatch
rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1],
linewidth=1, alpha=0.4,
edgecolor=color,facecolor='none',
linestyle='--',hatch=hatch)
ax.add_patch(rect)
# Edge
rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1],
linewidth=linewidth,
edgecolor=color,facecolor='none',
linestyle="--")
ax.add_patch(rect)
fig.set_size_inches((15, 15))
plt.axis('off')
plt.savefig(bboxes_out_filepath, bbox_inches='tight', dpi=100)
if args.data_type == 'structure':
img_words_filepath = os.path.join(args.table_words_dir, img_filename.replace(".jpg", "_words.json"))
cells_out_filename = img_filename.replace(".jpg", "_cells.jpg")
cells_out_filepath = os.path.join(args.debug_save_dir, cells_out_filename)
with open(img_words_filepath, 'r') as f:
tokens = json.load(f)
_, pred_cells, _ = objects_to_cells(pred_bboxes, pred_labels, pred_scores,
tokens, structure_class_names,
structure_class_thresholds, structure_class_map)
        fig, ax = plt.subplots(1)
ax.imshow(img, interpolation='lanczos')
for cell in pred_cells:
bbox = cell['bbox']
if cell['header']:
alpha = 0.3
else:
alpha = 0.125
rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1], linewidth=1,
edgecolor='none',facecolor="magenta", alpha=alpha)
ax.add_patch(rect)
rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1], linewidth=1,
edgecolor="magenta",facecolor='none',linestyle="--",
alpha=0.08, hatch='///')
ax.add_patch(rect)
rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1], linewidth=1,
edgecolor="magenta",facecolor='none',linestyle="--")
ax.add_patch(rect)
fig.set_size_inches((15, 15))
plt.axis('off')
plt.savefig(cells_out_filepath, bbox_inches='tight', dpi=100)
plt.close('all')
@torch.no_grad()
def evaluate(args, model, criterion, postprocessors, data_loader, base_ds, device):
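    """
    Run COCO evaluation over the data loader and, for table structure
    recognition, compute the GriTS and related metrics as well.
    """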
st_time = datetime.now()
model.eval()
criterion.eval()
metric_logger = utils.MetricLogger(delimiter=" ")
metric_logger.add_meter('class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
header = 'Test:'
iou_types = tuple(k for k in ('segm', 'bbox') if k in postprocessors.keys())
coco_evaluator = CocoEvaluator(base_ds, iou_types)
if args.data_type == "structure":
tsr_metrics = []
pred_logits_collection = []
pred_bboxes_collection = []
targets_collection = []
num_batches = len(data_loader)
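    # Log progress roughly every 1% of batches, but no more often than every
    # args.eval_step batches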
print_every = max(args.eval_step, int(math.ceil(num_batches / 100)))
batch_num = 0
for samples, targets in metric_logger.log_every(data_loader, print_every, header):
batch_num += 1
samples = samples.to(device)
for t in targets:
for k, v in t.items():
                if k != 'img_path':
t[k] = v.to(device)
outputs = model(samples)
if args.debug:
for target, pred_logits, pred_boxes in zip(targets, outputs['pred_logits'], outputs['pred_boxes']):
visualize(args, target, pred_logits, pred_boxes)
loss_dict = criterion(outputs, targets)
weight_dict = criterion.weight_dict
# reduce losses over all GPUs for logging purposes
loss_dict_reduced = utils.reduce_dict(loss_dict)
loss_dict_reduced_scaled = {k: v * weight_dict[k]
for k, v in loss_dict_reduced.items() if k in weight_dict}
loss_dict_reduced_unscaled = {f'{k}_unscaled': v
for k, v in loss_dict_reduced.items()}
metric_logger.update(loss=sum(loss_dict_reduced_scaled.values()),
**loss_dict_reduced_scaled,
**loss_dict_reduced_unscaled)
metric_logger.update(class_error=loss_dict_reduced['class_error'])
orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
results = postprocessors['bbox'](outputs, orig_target_sizes)
res = {target['image_id'].item(): output for target, output in zip(targets, results)}
if coco_evaluator is not None:
coco_evaluator.update(res)
if args.data_type == "structure":
pred_logits_collection += list(outputs['pred_logits'].detach().cpu())
pred_bboxes_collection += list(outputs['pred_boxes'].detach().cpu())
for target in targets:
for k, v in target.items():
                    if k != 'img_path':
target[k] = v.cpu()
img_filepath = target["img_path"]
img_filename = img_filepath.split("/")[-1]
img_words_filepath = os.path.join(args.table_words_dir, img_filename.replace(".jpg", "_words.json"))
target["img_words_path"] = img_words_filepath
targets_collection += targets
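            # Periodically compute the per-sample TSR metrics for the
            # accumulated predictions in parallel, then clear the buffers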
if batch_num % args.eval_step == 0 or batch_num == num_batches:
arguments = zip(targets_collection, pred_logits_collection, pred_bboxes_collection,
repeat(args.mode))
with multiprocessing.Pool(args.eval_pool_size) as pool:
metrics = pool.starmap_async(eval_tsr_sample, arguments).get()
tsr_metrics += metrics
pred_logits_collection = []
pred_bboxes_collection = []
targets_collection = []
# gather the stats from all processes
metric_logger.synchronize_between_processes()
print("Averaged stats:", metric_logger)
if coco_evaluator is not None:
coco_evaluator.synchronize_between_processes()
# accumulate predictions from all images
if coco_evaluator is not None:
coco_evaluator.accumulate()
coco_evaluator.summarize()
stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
if coco_evaluator is not None:
if 'bbox' in postprocessors.keys():
stats['coco_eval_bbox'] = coco_evaluator.coco_eval['bbox'].stats.tolist()
if args.data_type == "structure":
# Save sample-level metrics for more analysis
if len(args.metrics_save_filepath) > 0:
with open(args.metrics_save_filepath, 'w') as outfile:
json.dump(tsr_metrics, outfile)
# Compute metrics averaged over all samples
metrics_summary = compute_metrics_summary(tsr_metrics, args.mode)
# Print summary of metrics
        print_metrics_summary(metrics_summary, show_all=(args.mode == 'grits-all'))
print("Total time taken for {} samples: {}".format(len(base_ds), datetime.now() - st_time))
return stats, coco_evaluator
def eval_coco(args, model, criterion, postprocessors, data_loader_test, dataset_test, device):
"""
    Run COCO evaluation. By default this is run on the test set.
"""
pubmed_stats, coco_evaluator = evaluate(args, model, criterion, postprocessors,
data_loader_test, dataset_test,
device)
print("COCO metrics summary: AP50: {:.3f}, AP75: {:.3f}, AP: {:.3f}, AR: {:.3f}".format(
pubmed_stats['coco_eval_bbox'][1], pubmed_stats['coco_eval_bbox'][2],
pubmed_stats['coco_eval_bbox'][0], pubmed_stats['coco_eval_bbox'][8]))