""" Copyright (C) 2021 Microsoft Corporation """ import os import sys from collections import Counter import json import statistics as stat from datetime import datetime import multiprocessing from itertools import repeat from functools import partial import tqdm import math import torch from torchvision import transforms import numpy as np import matplotlib.pyplot as plt import matplotlib.patches as patches from fitz import Rect from PIL import Image sys.path.append("../detr") import util.misc as utils from ms_datasets.coco_eval import CocoEvaluator import postprocess import grits from grits import grits_con, grits_top, grits_loc structure_class_names = [ 'table', 'table column', 'table row', 'table column header', 'table projected row header', 'table spanning cell', 'no object' ] structure_class_map = {k: v for v, k in enumerate(structure_class_names)} structure_class_thresholds = { "table": 0.5, "table column": 0.5, "table row": 0.5, "table column header": 0.5, "table projected row header": 0.5, "table spanning cell": 0.5, "no object": 10 } normalize = transforms.Compose([ transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]) def objects_to_cells(bboxes, labels, scores, page_tokens, structure_class_names, structure_class_thresholds, structure_class_map): bboxes, scores, labels = postprocess.apply_class_thresholds(bboxes, labels, scores, structure_class_names, structure_class_thresholds) table_objects = [] for bbox, score, label in zip(bboxes, scores, labels): table_objects.append({'bbox': bbox, 'score': score, 'label': label}) table = {'objects': table_objects, 'page_num': 0} table_class_objects = [obj for obj in table_objects if obj['label'] == structure_class_map['table']] if len(table_class_objects) > 1: table_class_objects = sorted(table_class_objects, key=lambda x: x['score'], reverse=True) try: table_bbox = list(table_class_objects[0]['bbox']) except: table_bbox = (0,0,1000,1000) tokens_in_table = [token for token in page_tokens if postprocess.iob(token['bbox'], table_bbox) >= 0.5] # Determine the table cell structure from the objects table_structures, cells, confidence_score = postprocess.objects_to_cells(table, table_objects, tokens_in_table, structure_class_names, structure_class_thresholds) return table_structures, cells, confidence_score def cells_to_adjacency_pair_list(cells, key='cell_text'): # Index the cells by their grid coordinates cell_nums_by_coordinates = dict() for cell_num, cell in enumerate(cells): for row_num in cell['row_nums']: for column_num in cell['column_nums']: cell_nums_by_coordinates[(row_num, column_num)] = cell_num # Count the number of unique rows and columns row_nums = set() column_nums = set() for cell in cells: for row_num in cell['row_nums']: row_nums.add(row_num) for column_num in cell['column_nums']: column_nums.add(column_num) num_rows = len(row_nums) num_columns = len(column_nums) # For each cell, determine its next neighbors # - For every row the cell occupies, what is the first cell to the right with text that # also occupies that row # - For every column the cell occupies, what is the first cell below with text that # also occupies that column adjacency_list = [] adjacency_bboxes = [] for cell1_num, cell1 in enumerate(cells): # Skip blank cells if cell1['cell_text'] == '': continue adjacent_cell_props = {} max_column = max(cell1['column_nums']) max_row = max(cell1['row_nums']) # For every column the cell occupies... 
def cells_to_adjacency_pair_list(cells, key='cell_text'):
    # Index the cells by their grid coordinates
    cell_nums_by_coordinates = dict()
    for cell_num, cell in enumerate(cells):
        for row_num in cell['row_nums']:
            for column_num in cell['column_nums']:
                cell_nums_by_coordinates[(row_num, column_num)] = cell_num

    # Count the number of unique rows and columns
    row_nums = set()
    column_nums = set()
    for cell in cells:
        for row_num in cell['row_nums']:
            row_nums.add(row_num)
        for column_num in cell['column_nums']:
            column_nums.add(column_num)
    num_rows = len(row_nums)
    num_columns = len(column_nums)

    # For each cell, determine its next neighbors
    # - For every row the cell occupies, what is the first cell to the right with text that
    #   also occupies that row
    # - For every column the cell occupies, what is the first cell below with text that
    #   also occupies that column
    adjacency_list = []
    adjacency_bboxes = []
    for cell1_num, cell1 in enumerate(cells):
        # Skip blank cells
        if cell1['cell_text'] == '':
            continue

        adjacent_cell_props = {}
        max_column = max(cell1['column_nums'])
        max_row = max(cell1['row_nums'])

        # For every column the cell occupies...
        for column_num in cell1['column_nums']:
            # Start from the next row and stop when we encounter a non-blank cell
            # This cell is considered adjacent
            for current_row in range(max_row+1, num_rows):
                cell2_num = cell_nums_by_coordinates[(current_row, column_num)]
                cell2 = cells[cell2_num]
                if not cell2['cell_text'] == '':
                    adj_bbox = [(max(cell1['bbox'][0], cell2['bbox'][0])+min(cell1['bbox'][2], cell2['bbox'][2]))/2-3,
                                cell1['bbox'][3],
                                (max(cell1['bbox'][0], cell2['bbox'][0])+min(cell1['bbox'][2], cell2['bbox'][2]))/2+3,
                                cell2['bbox'][1]]
                    adjacent_cell_props[cell2_num] = ('V', current_row - max_row - 1, adj_bbox)
                    break

        # For every row the cell occupies...
        for row_num in cell1['row_nums']:
            # Start from the next column and stop when we encounter a non-blank cell
            # This cell is considered adjacent
            for current_column in range(max_column+1, num_columns):
                cell2_num = cell_nums_by_coordinates[(row_num, current_column)]
                cell2 = cells[cell2_num]
                if not cell2['cell_text'] == '':
                    adj_bbox = [cell1['bbox'][2],
                                (max(cell1['bbox'][1], cell2['bbox'][1])+min(cell1['bbox'][3], cell2['bbox'][3]))/2-3,
                                cell2['bbox'][0],
                                (max(cell1['bbox'][1], cell2['bbox'][1])+min(cell1['bbox'][3], cell2['bbox'][3]))/2+3]
                    adjacent_cell_props[cell2_num] = ('H', current_column - max_column - 1, adj_bbox)
                    break

        for adjacent_cell_num, props in adjacent_cell_props.items():
            cell2 = cells[adjacent_cell_num]
            adjacency_list.append((cell1['cell_text'], cell2['cell_text'], props[0], props[1]))
            adjacency_bboxes.append(props[2])

    return adjacency_list, adjacency_bboxes

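# Illustrative sketch (not called anywhere in this module): how the adjacency-pair
# extraction behaves on a toy 2x2 grid whose bottom-left cell is blank. The cell
# dicts below are hypothetical but follow the keys this function expects
# ('cell_text', 'row_nums', 'column_nums', 'bbox').
def _example_adjacency_pairs():
    toy_cells = [
        {'cell_text': 'Header A', 'row_nums': [0], 'column_nums': [0], 'bbox': [0, 0, 50, 20]},
        {'cell_text': 'Header B', 'row_nums': [0], 'column_nums': [1], 'bbox': [50, 0, 100, 20]},
        {'cell_text': '',         'row_nums': [1], 'column_nums': [0], 'bbox': [0, 20, 50, 40]},
        {'cell_text': '1.0',      'row_nums': [1], 'column_nums': [1], 'bbox': [50, 20, 100, 40]},
    ]
    adjacency_list, adjacency_bboxes = cells_to_adjacency_pair_list(toy_cells)
    # Expected pairs: ('Header A', 'Header B', 'H', 0) and ('Header B', '1.0', 'V', 0);
    # the blank cell contributes no relations of its own.
    return adjacency_list, adjacency_bboxes
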
def cells_to_adjacency_pair_list_with_blanks(cells, key='cell_text'):
    # Index the cells by their grid coordinates
    cell_nums_by_coordinates = dict()
    for cell_num, cell in enumerate(cells):
        for row_num in cell['row_nums']:
            for column_num in cell['column_nums']:
                cell_nums_by_coordinates[(row_num, column_num)] = cell_num

    # Count the number of unique rows and columns
    row_nums = set()
    column_nums = set()
    for cell in cells:
        for row_num in cell['row_nums']:
            row_nums.add(row_num)
        for column_num in cell['column_nums']:
            column_nums.add(column_num)
    num_rows = len(row_nums)
    num_columns = len(column_nums)

    # For each cell, determine its next neighbors
    # - For every row the cell occupies, what is the next cell to the right
    # - For every column the cell occupies, what is the next cell below
    adjacency_list = []
    adjacency_bboxes = []
    for cell1_num, cell1 in enumerate(cells):
        adjacent_cell_props = {}
        max_column = max(cell1['column_nums'])
        max_row = max(cell1['row_nums'])

        # For every column the cell occupies...
        for column_num in cell1['column_nums']:
            # The cell in the next row is adjacent
            current_row = max_row + 1
            if current_row >= num_rows:
                continue
            cell2_num = cell_nums_by_coordinates[(current_row, column_num)]
            cell2 = cells[cell2_num]
            adj_bbox = [(max(cell1['bbox'][0], cell2['bbox'][0])+min(cell1['bbox'][2], cell2['bbox'][2]))/2-3,
                        cell1['bbox'][3],
                        (max(cell1['bbox'][0], cell2['bbox'][0])+min(cell1['bbox'][2], cell2['bbox'][2]))/2+3,
                        cell2['bbox'][1]]
            adjacent_cell_props[cell2_num] = ('V', current_row - max_row - 1, adj_bbox)

        # For every row the cell occupies...
        for row_num in cell1['row_nums']:
            # The cell in the next column is adjacent
            current_column = max_column + 1
            if current_column >= num_columns:
                continue
            cell2_num = cell_nums_by_coordinates[(row_num, current_column)]
            cell2 = cells[cell2_num]
            adj_bbox = [cell1['bbox'][2],
                        (max(cell1['bbox'][1], cell2['bbox'][1])+min(cell1['bbox'][3], cell2['bbox'][3]))/2-3,
                        cell2['bbox'][0],
                        (max(cell1['bbox'][1], cell2['bbox'][1])+min(cell1['bbox'][3], cell2['bbox'][3]))/2+3]
            adjacent_cell_props[cell2_num] = ('H', current_column - max_column - 1, adj_bbox)

        for adjacent_cell_num, props in adjacent_cell_props.items():
            cell2 = cells[adjacent_cell_num]
            adjacency_list.append((cell1['cell_text'], cell2['cell_text'], props[0], props[1]))
            adjacency_bboxes.append(props[2])

    return adjacency_list, adjacency_bboxes


def dar_con(true_adjacencies, pred_adjacencies):
    """
    Directed adjacency relations (DAR) metric, which uses exact match
    between adjacent cell text content.
    """
    true_c = Counter()
    true_c.update([elem for elem in true_adjacencies])

    pred_c = Counter()
    pred_c.update([elem for elem in pred_adjacencies])

    num_true_positives = (sum(true_c.values()) - sum((true_c - pred_c).values()))

    fscore, precision, recall = grits.compute_fscore(num_true_positives,
                                                     len(true_adjacencies),
                                                     len(pred_adjacencies))

    return recall, precision, fscore


def dar_con_original(true_cells, pred_cells):
    """
    Original DAR metric, where blank cells are disregarded.
    """
    true_adjacencies, _ = cells_to_adjacency_pair_list(true_cells)
    pred_adjacencies, _ = cells_to_adjacency_pair_list(pred_cells)

    return dar_con(true_adjacencies, pred_adjacencies)


def dar_con_new(true_cells, pred_cells):
    """
    New version of DAR metric where blank cells count.
    """
    true_adjacencies, _ = cells_to_adjacency_pair_list_with_blanks(true_cells)
    pred_adjacencies, _ = cells_to_adjacency_pair_list_with_blanks(pred_cells)

    return dar_con(true_adjacencies, pred_adjacencies)

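# Illustrative sketch (not called anywhere in this module): DAR compares the two
# adjacency multisets by exact match. With the hypothetical lists below, one of the
# two ground truth relations is recovered, so recall is 0.5, precision is 1.0, and
# the F-score works out to 2/3, assuming grits.compute_fscore implements the usual
# precision/recall/F-score definitions.
def _example_dar_con():
    true_adjacencies = [('Header A', 'Header B', 'H', 0), ('Header B', '1.0', 'V', 0)]
    pred_adjacencies = [('Header A', 'Header B', 'H', 0)]
    recall, precision, fscore = dar_con(true_adjacencies, pred_adjacencies)
    return recall, precision, fscore
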
""" metrics = {} # Compute grids/matrices for comparison true_relspan_grid = np.array(grits.cells_to_relspan_grid(true_cells)) true_bbox_grid = np.array(grits.cells_to_grid(true_cells, key='bbox')) true_text_grid = np.array(grits.cells_to_grid(true_cells, key='cell_text'), dtype=object) pred_relspan_grid = np.array(grits.cells_to_relspan_grid(pred_cells)) pred_bbox_grid = np.array(grits.cells_to_grid(pred_cells, key='bbox')) pred_text_grid = np.array(grits.cells_to_grid(pred_cells, key='cell_text'), dtype=object) # Compute GriTS_Top (topology) (metrics['grits_top'], metrics['grits_precision_top'], metrics['grits_recall_top'], metrics['grits_top_upper_bound']) = grits_top(true_relspan_grid, pred_relspan_grid) # Compute GriTS_Loc (location) (metrics['grits_loc'], metrics['grits_precision_loc'], metrics['grits_recall_loc'], metrics['grits_loc_upper_bound']) = grits_loc(true_bbox_grid, pred_bbox_grid) # Compute GriTS_Con (text content) (metrics['grits_con'], metrics['grits_precision_con'], metrics['grits_recall_con'], metrics['grits_con_upper_bound']) = grits_con(true_text_grid, pred_text_grid) # Compute content accuracy metrics['acc_con'] = int(metrics['grits_con'] == 1) if mode == 'grits-all': # Compute grids/matrices for comparison true_cell_dilatedbbox_grid = np.array(grits.output_to_dilatedbbox_grid(true_bboxes, true_labels, true_scores)) pred_cell_dilatedbbox_grid = np.array(grits.output_to_dilatedbbox_grid(pred_bboxes, pred_labels, pred_scores)) # Compute GriTS_RawLoc (location using unprocessed bounding boxes) (metrics['grits_rawloc'], metrics['grits_precision_rawloc'], metrics['grits_recall_rawloc'], metrics['grits_rawloc_upper_bound']) = grits_loc(true_cell_dilatedbbox_grid, pred_cell_dilatedbbox_grid) # Compute original DAR (directed adjacency relations) metric (metrics['dar_recall_con_original'], metrics['dar_precision_con_original'], metrics['dar_con_original']) = dar_con_original(true_cells, pred_cells) # Compute updated DAR (directed adjacency relations) metric (metrics['dar_recall_con'], metrics['dar_precision_con'], metrics['dar_con']) = dar_con_new(true_cells, pred_cells) return metrics def compute_statistics(structures, cells): statistics = {} statistics['num_rows'] = len(structures['rows']) statistics['num_columns'] = len(structures['columns']) statistics['num_cells'] = len(cells) statistics['num_spanning_cells'] = len([cell for cell in cells if len(cell['row_nums']) > 1 or len(cell['column_nums']) > 1]) header_rows = set() for cell in cells: if cell['header']: header_rows = header_rows.union(set(cell['row_nums'])) statistics['num_header_rows'] = len(header_rows) row_heights = [float(row['bbox'][3]-row['bbox'][1]) for row in structures['rows']] if len(row_heights) >= 2: statistics['row_height_coefficient_of_variation'] = stat.stdev(row_heights) / stat.mean(row_heights) else: statistics['row_height_coefficient_of_variation'] = 0 column_widths = [float(column['bbox'][2]-column['bbox'][0]) for column in structures['columns']] if len(column_widths) >= 2: statistics['column_width_coefficient_of_variation'] = stat.stdev(column_widths) / stat.mean(column_widths) else: statistics['column_width_coefficient_of_variation'] = 0 return statistics # for output bounding box post-processing def box_cxcywh_to_xyxy(x): x_c, y_c, w, h = x.unbind(1) b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)] return torch.stack(b, dim=1) def rescale_bboxes(out_bbox, size): img_w, img_h = size b = box_cxcywh_to_xyxy(out_bbox) b = b * torch.tensor([img_w, img_h, img_w, img_h], 
def get_bbox_decorations(data_type, label):
    if label == 0:
        if data_type == 'detection':
            return 'brown', 0.05, 3, '//'
        else:
            return 'brown', 0, 3, None
    elif label == 1:
        return 'red', 0.15, 2, None
    elif label == 2:
        return 'blue', 0.15, 2, None
    elif label == 3:
        return 'magenta', 0.2, 3, '//'
    elif label == 4:
        return 'cyan', 0.2, 4, '//'
    elif label == 5:
        return 'green', 0.2, 4, '\\\\'

    return 'gray', 0, 0, None


def compute_metrics_summary(sample_metrics, mode):
    """
    Compute a summary of the table structure recognition metrics averaged over
    all samples, reported separately for simple tables, complex tables (those
    with at least one spanning cell), and all tables together.
    """
    metrics_summary = {}

    metric_names = ['acc_con', 'grits_top', 'grits_con', 'grits_loc']
    if mode == 'grits-all':
        metric_names += ['grits_rawloc', 'dar_con_original', 'dar_con']

    simple_samples = [entry for entry in sample_metrics if entry['num_spanning_cells'] == 0]
    metrics_summary['simple'] = {'num_tables': len(simple_samples)}
    if len(simple_samples) > 0:
        for metric_name in metric_names:
            metrics_summary['simple'][metric_name] = np.mean([elem[metric_name] for elem in simple_samples])

    complex_samples = [entry for entry in sample_metrics if entry['num_spanning_cells'] > 0]
    metrics_summary['complex'] = {'num_tables': len(complex_samples)}
    if len(complex_samples) > 0:
        for metric_name in metric_names:
            metrics_summary['complex'][metric_name] = np.mean([elem[metric_name] for elem in complex_samples])

    metrics_summary['all'] = {'num_tables': len(sample_metrics)}
    if len(sample_metrics) > 0:
        for metric_name in metric_names:
            metrics_summary['all'][metric_name] = np.mean([elem[metric_name] for elem in sample_metrics])

    return metrics_summary


def print_metrics_line(name, metrics_dict, key, min_length=18):
    if len(name) < min_length:
        name = ' '*(min_length-len(name)) + name
    try:
        print("{}: {:.4f}".format(name, metrics_dict[key]))
    except KeyError:
        # The metric was not computed for this subset of tables
        print("{}: --".format(name))

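# Illustrative sketch (not called anywhere in this module): summarizing per-sample
# metrics. Each entry only needs 'num_spanning_cells' (to split simple vs. complex
# tables) plus the metric keys being averaged; the values and the mode string below
# are made up, and any mode other than 'grits-all' skips the extra metrics.
def _example_metrics_summary():
    sample_metrics = [
        {'num_spanning_cells': 0, 'acc_con': 1, 'grits_top': 0.99, 'grits_con': 0.98, 'grits_loc': 0.97},
        {'num_spanning_cells': 2, 'acc_con': 0, 'grits_top': 0.95, 'grits_con': 0.90, 'grits_loc': 0.93},
    ]
    return compute_metrics_summary(sample_metrics, mode='grits')
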
""" print('-' * 100) for table_type in ['simple', 'complex', 'all']: metrics = metrics_summary[table_type] print("Results on {} tables ({} total):".format(table_type, metrics['num_tables'])) print_metrics_line("Accuracy_Con", metrics, 'acc_con') print_metrics_line("GriTS_Top", metrics, 'grits_top') print_metrics_line("GriTS_Con", metrics, 'grits_con') print_metrics_line("GriTS_Loc", metrics, 'grits_loc') if all: print_metrics_line("GriTS_RawLoc", metrics, 'grits_rawloc') print_metrics_line("DAR_Con (original)", metrics, 'dar_con_original') print_metrics_line("DAR_Con", metrics, 'dar_con') print('-' * 50) def eval_tsr_sample(target, pred_logits, pred_bboxes, mode): true_img_size = list(reversed(target['orig_size'].tolist())) true_bboxes = target['boxes'] true_bboxes = [elem.tolist() for elem in rescale_bboxes(true_bboxes, true_img_size)] true_labels = target['labels'].tolist() true_scores = [1 for elem in true_labels] img_words_filepath = target["img_words_path"] with open(img_words_filepath, 'r') as f: true_page_tokens = json.load(f) true_table_structures, true_cells, _ = objects_to_cells(true_bboxes, true_labels, true_scores, true_page_tokens, structure_class_names, structure_class_thresholds, structure_class_map) m = pred_logits.softmax(-1).max(-1) pred_labels = list(m.indices.detach().cpu().numpy()) pred_scores = list(m.values.detach().cpu().numpy()) pred_bboxes = [elem.tolist() for elem in rescale_bboxes(pred_bboxes, true_img_size)] _, pred_cells, _ = objects_to_cells(pred_bboxes, pred_labels, pred_scores, true_page_tokens, structure_class_names, structure_class_thresholds, structure_class_map) metrics = compute_metrics(mode, true_bboxes, true_labels, true_scores, true_cells, pred_bboxes, pred_labels, pred_scores, pred_cells) statistics = compute_statistics(true_table_structures, true_cells) metrics.update(statistics) metrics['id'] = target["img_path"].split('/')[-1].split('.')[0] return metrics def visualize(args, target, pred_logits, pred_bboxes): img_filepath = target["img_path"] img_filename = img_filepath.split("/")[-1] bboxes_out_filename = img_filename.replace(".jpg", "_bboxes.jpg") bboxes_out_filepath = os.path.join(args.debug_save_dir, bboxes_out_filename) img = Image.open(img_filepath) img_size = img.size m = pred_logits.softmax(-1).max(-1) pred_labels = list(m.indices.detach().cpu().numpy()) pred_scores = list(m.values.detach().cpu().numpy()) pred_bboxes = pred_bboxes.detach().cpu() pred_bboxes = [elem.tolist() for elem in rescale_bboxes(pred_bboxes, img_size)] fig,ax = plt.subplots(1) ax.imshow(img, interpolation='lanczos') for bbox, label, score in zip(pred_bboxes, pred_labels, pred_scores): if ((args.data_type == 'structure' and not label > 5) or (args.data_type == 'detection' and not label > 1) and score > 0.5): color, alpha, linewidth, hatch = get_bbox_decorations(args.data_type, label) # Fill rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1], linewidth=linewidth, alpha=alpha, edgecolor='none',facecolor=color, linestyle=None) ax.add_patch(rect) # Hatch rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1], linewidth=1, alpha=0.4, edgecolor=color,facecolor='none', linestyle='--',hatch=hatch) ax.add_patch(rect) # Edge rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1], linewidth=linewidth, edgecolor=color,facecolor='none', linestyle="--") ax.add_patch(rect) fig.set_size_inches((15, 15)) plt.axis('off') plt.savefig(bboxes_out_filepath, bbox_inches='tight', dpi=100) if args.data_type == 'structure': img_words_filepath = 
def visualize(args, target, pred_logits, pred_bboxes):
    img_filepath = target["img_path"]
    img_filename = img_filepath.split("/")[-1]

    bboxes_out_filename = img_filename.replace(".jpg", "_bboxes.jpg")
    bboxes_out_filepath = os.path.join(args.debug_save_dir, bboxes_out_filename)

    img = Image.open(img_filepath)
    img_size = img.size

    m = pred_logits.softmax(-1).max(-1)
    pred_labels = list(m.indices.detach().cpu().numpy())
    pred_scores = list(m.values.detach().cpu().numpy())
    pred_bboxes = pred_bboxes.detach().cpu()
    pred_bboxes = [elem.tolist() for elem in rescale_bboxes(pred_bboxes, img_size)]

    fig, ax = plt.subplots(1)
    ax.imshow(img, interpolation='lanczos')

    for bbox, label, score in zip(pred_bboxes, pred_labels, pred_scores):
        # Only draw objects of the relevant classes whose score clears the threshold
        # (parentheses added so the score check applies to both data types)
        if (((args.data_type == 'structure' and not label > 5)
             or (args.data_type == 'detection' and not label > 1))
                and score > 0.5):
            color, alpha, linewidth, hatch = get_bbox_decorations(args.data_type, label)
            # Fill
            rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1],
                                     linewidth=linewidth, alpha=alpha,
                                     edgecolor='none', facecolor=color, linestyle=None)
            ax.add_patch(rect)
            # Hatch
            rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1],
                                     linewidth=1, alpha=0.4,
                                     edgecolor=color, facecolor='none',
                                     linestyle='--', hatch=hatch)
            ax.add_patch(rect)
            # Edge
            rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1],
                                     linewidth=linewidth,
                                     edgecolor=color, facecolor='none', linestyle="--")
            ax.add_patch(rect)

    fig.set_size_inches((15, 15))
    plt.axis('off')
    plt.savefig(bboxes_out_filepath, bbox_inches='tight', dpi=100)

    if args.data_type == 'structure':
        img_words_filepath = os.path.join(args.table_words_dir,
                                          img_filename.replace(".jpg", "_words.json"))
        cells_out_filename = img_filename.replace(".jpg", "_cells.jpg")
        cells_out_filepath = os.path.join(args.debug_save_dir, cells_out_filename)

        with open(img_words_filepath, 'r') as f:
            tokens = json.load(f)

        _, pred_cells, _ = objects_to_cells(pred_bboxes, pred_labels, pred_scores,
                                            tokens, structure_class_names,
                                            structure_class_thresholds, structure_class_map)

        fig, ax = plt.subplots(1)
        ax.imshow(img, interpolation='lanczos')

        for cell in pred_cells:
            bbox = cell['bbox']
            if cell['header']:
                alpha = 0.3
            else:
                alpha = 0.125
            rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1],
                                     linewidth=1, alpha=alpha,
                                     edgecolor='none', facecolor="magenta")
            ax.add_patch(rect)
            rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1],
                                     linewidth=1, alpha=0.08,
                                     edgecolor="magenta", facecolor='none',
                                     linestyle="--", hatch='///')
            ax.add_patch(rect)
            rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1],
                                     linewidth=1,
                                     edgecolor="magenta", facecolor='none', linestyle="--")
            ax.add_patch(rect)

        fig.set_size_inches((15, 15))
        plt.axis('off')
        plt.savefig(cells_out_filepath, bbox_inches='tight', dpi=100)

    plt.close('all')

img_filename.replace(".jpg", "_words.json")) target["img_words_path"] = img_words_filepath targets_collection += targets if batch_num % args.eval_step == 0 or batch_num == num_batches: arguments = zip(targets_collection, pred_logits_collection, pred_bboxes_collection, repeat(args.mode)) with multiprocessing.Pool(args.eval_pool_size) as pool: metrics = pool.starmap_async(eval_tsr_sample, arguments).get() tsr_metrics += metrics pred_logits_collection = [] pred_bboxes_collection = [] targets_collection = [] # gather the stats from all processes metric_logger.synchronize_between_processes() print("Averaged stats:", metric_logger) if coco_evaluator is not None: coco_evaluator.synchronize_between_processes() # accumulate predictions from all images if coco_evaluator is not None: coco_evaluator.accumulate() coco_evaluator.summarize() stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()} if coco_evaluator is not None: if 'bbox' in postprocessors.keys(): stats['coco_eval_bbox'] = coco_evaluator.coco_eval['bbox'].stats.tolist() if args.data_type == "structure": # Save sample-level metrics for more analysis if len(args.metrics_save_filepath) > 0: with open(args.metrics_save_filepath, 'w') as outfile: json.dump(tsr_metrics, outfile) # Compute metrics averaged over all samples metrics_summary = compute_metrics_summary(tsr_metrics, args.mode) # Print summary of metrics print_metrics_summary(metrics_summary) print("Total time taken for {} samples: {}".format(len(base_ds), datetime.now() - st_time)) return stats, coco_evaluator def eval_coco(args, model, criterion, postprocessors, data_loader_test, dataset_test, device): """ Use this function to do COCO evaluation. Default implementation runs it on the test set. """ pubmed_stats, coco_evaluator = evaluate(args, model, criterion, postprocessors, data_loader_test, dataset_test, device) print("COCO metrics summary: AP50: {:.3f}, AP75: {:.3f}, AP: {:.3f}, AR: {:.3f}".format( pubmed_stats['coco_eval_bbox'][1], pubmed_stats['coco_eval_bbox'][2], pubmed_stats['coco_eval_bbox'][0], pubmed_stats['coco_eval_bbox'][8]))