# Copyright (c) OpenMMLab. All rights reserved. | |
import glob
import os.path as osp
import re
from typing import Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
from mmengine.evaluator import BaseMetric
from mmengine.logging import MMLogger
from rapidfuzz.distance import Levenshtein
from shapely.geometry import Point

from mmocr.registry import METRICS
# TODO: CTW1500 read pair | |
class E2EPointMetric(BaseMetric):
    r"""Point metric for textspotting. Proposed in SPTS.

    Each predicted text instance is represented by a single point. A
    prediction matches the ground truth instance whose polygon center is
    closest to it, provided that instance is not ignored and the
    transcriptions are equal (case-insensitive). Precision, recall and hmean
    are evaluated over a range of text score thresholds and the best result
    is reported.

    Args:
        text_score_thrs (dict): Best text score threshold searching
            space. Defaults to dict(start=0.8, stop=1, step=0.01).
        word_spotting (bool): Whether to work in word spotting mode. Defaults
            to False.
        lexicon_path (str, optional): Lexicon path for word spotting, which
            points to a lexicon file or a directory. Defaults to None.
        lexicon_mapping (tuple, optional): The rule to map test image name to
            its corresponding lexicon file. Only effective when lexicon path
            is a directory. Defaults to ('(.*).jpg', r'\1.txt').
        pair_path (str, optional): Pair path for word spotting, which points
            to a pair file or a directory. Defaults to None.
        pair_mapping (tuple, optional): The rule to map test image name to
            its corresponding pair file. Only effective when pair path is a
            directory. Defaults to ('(.*).jpg', r'\1.txt').
        match_dist_thr (float, optional): Matching distance threshold for
            word spotting. Predictions whose best lexicon match is at or
            beyond this normalized distance are discarded entirely.
            Defaults to None.
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, self.default_prefix
            will be used instead. Defaults to None.
    """

    default_prefix: Optional[str] = 'e2e_icdar'

    def __init__(self,
                 text_score_thrs: Dict = dict(start=0.8, stop=1, step=0.01),
                 word_spotting: bool = False,
                 lexicon_path: Optional[str] = None,
                 lexicon_mapping: Tuple[str, str] = ('(.*).jpg', r'\1.txt'),
                 pair_path: Optional[str] = None,
                 pair_mapping: Tuple[str, str] = ('(.*).jpg', r'\1.txt'),
                 match_dist_thr: Optional[float] = None,
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None) -> None:
        super().__init__(collect_device=collect_device, prefix=prefix)
        # Candidate score thresholds searched in ``compute_metrics``.
        self.text_score_thrs = np.arange(**text_score_thrs)
        self.word_spotting = word_spotting
        self.match_dist_thr = match_dist_thr
        if lexicon_path:
            self.lexicon_mapping = lexicon_mapping
            self.pair_mapping = pair_mapping
            self.lexicons = self._read_lexicon(lexicon_path)
            self.pairs = self._read_pair(pair_path)

    def _read_lexicon(self,
                      lexicon_path: str) -> Union[List[str], Dict[str, List]]:
        """Load a lexicon from a file or a directory of per-image files.

        Args:
            lexicon_path (str): Path to a ``.txt`` lexicon file or to a
                directory containing one lexicon file per image.

        Returns:
            list[str] or dict: The stripped lexicon words, or — for a
            directory — a dict mapping each file basename to its word list.
        """
        if lexicon_path.endswith('.txt'):
            # Context manager closes the handle deterministically; the
            # original code relied on garbage collection.
            with open(lexicon_path, 'r', encoding='utf-8') as f:
                lexicon = [line.strip() for line in f.read().splitlines()]
        else:
            lexicon = {}
            for file in glob.glob(osp.join(lexicon_path, '*.txt')):
                lexicon[osp.basename(file)] = self._read_lexicon(file)
        return lexicon

    def _read_pair(self, pair_path: str) -> Dict:
        """Load word pairs mapping uppercased lexicon words to their ground
        truth spellings.

        Args:
            pair_path (str): Path to a ``.txt`` pair file or to a directory
                containing one pair file per image.

        Returns:
            dict: ``{UPPERCASED_WORD: gt_word}``, or — for a directory — a
            dict mapping each file basename to such a dict.
        """
        pairs = {}
        if pair_path.endswith('.txt'):
            with open(pair_path, 'r', encoding='utf-8') as f:
                pair_lines = f.read().splitlines()
            for line in pair_lines:
                line = line.strip()
                word = line.split(' ')[0].upper()
                # Everything after the first token is the gt spelling; it
                # may itself contain spaces.
                pairs[word] = line[len(word) + 1:]
        else:
            for file in glob.glob(osp.join(pair_path, '*.txt')):
                pairs[osp.basename(file)] = self._read_pair(file)
        return pairs

    def poly_center(self, poly_pts) -> np.ndarray:
        """Return the mean (x, y) of a flat ``[x0, y0, x1, y1, ...]``
        polygon point sequence."""
        poly_pts = np.array(poly_pts).reshape(-1, 2)
        return poly_pts.mean(0)

    def process(self, data_batch: Sequence[Dict],
                data_samples: Sequence[Dict]) -> None:
        """Process one batch of data samples and predictions. The processed
        results should be stored in ``self.results``, which will be used to
        compute the metrics when all batches have been processed.

        Args:
            data_batch (Sequence[Dict]): A batch of data from dataloader.
            data_samples (Sequence[Dict]): A batch of outputs from
                the model.
        """
        for data_sample in data_samples:
            pred_instances = data_sample.get('pred_instances')
            pred_points = pred_instances.get('points')
            text_scores = pred_instances.get('text_scores')
            if isinstance(text_scores, torch.Tensor):
                text_scores = text_scores.cpu().numpy()
            text_scores = np.array(text_scores, dtype=np.float32)
            pred_texts = pred_instances.get('texts')

            gt_instances = data_sample.get('gt_instances')
            gt_polys = gt_instances.get('polygons')
            gt_ignore_flags = gt_instances.get('ignored')
            gt_texts = gt_instances.get('texts')
            if isinstance(gt_ignore_flags, torch.Tensor):
                gt_ignore_flags = gt_ignore_flags.cpu().numpy()
            # Each gt polygon is reduced to its center point for matching.
            gt_points = [self.poly_center(poly) for poly in gt_polys]
            if self.word_spotting:
                gt_ignore_flags, gt_texts = self._word_spotting_filter(
                    gt_ignore_flags, gt_texts)

            # Drop predictions that cannot pass even the lowest threshold.
            pred_ignore_flags = text_scores < self.text_score_thrs.min()
            text_scores = text_scores[~pred_ignore_flags]
            pred_texts = self._get_true_elements(pred_texts,
                                                ~pred_ignore_flags)
            pred_points = self._get_true_elements(pred_points,
                                                  ~pred_ignore_flags)

            result = dict(
                # reserved for image-level lexicons
                gt_img_name=osp.basename(data_sample.get('img_path', '')),
                text_scores=text_scores,
                pred_points=pred_points,
                gt_points=gt_points,
                pred_texts=pred_texts,
                gt_texts=gt_texts,
                gt_ignore_flags=gt_ignore_flags)
            self.results.append(result)

    def _get_true_elements(self, array: List, flags: np.ndarray) -> List:
        """Select the elements of ``array`` whose flag is True."""
        return [array[i] for i in self._true_indexes(flags)]

    def compute_metrics(self, results: List[Dict]) -> Dict:
        """Compute the metrics from processed results.

        Args:
            results (list[dict]): The processed results of each batch.

        Returns:
            dict: The computed metrics. The keys are the names of the metrics,
            and the values are corresponding results.
        """
        logger: MMLogger = MMLogger.get_current_instance()
        best_eval_results = dict(hmean=-1)

        num_thres = len(self.text_score_thrs)
        # Per-threshold counters accumulated over all images.
        num_preds = np.zeros(
            num_thres, dtype=int)  # the number of points actually predicted
        num_tp = np.zeros(num_thres, dtype=int)  # number of true positives
        num_gts = np.zeros(num_thres, dtype=int)  # number of valid gts

        for result in results:
            text_scores = result['text_scores']
            pred_points = result['pred_points']
            gt_points = result['gt_points']
            gt_texts = result['gt_texts']
            pred_texts = result['pred_texts']
            gt_ignore_flags = result['gt_ignore_flags']
            gt_img_name = result['gt_img_name']

            # Correct the words with lexicon
            pred_dist_flags = np.zeros(len(pred_texts), dtype=bool)
            if hasattr(self, 'lexicons'):
                for i, pred_text in enumerate(pred_texts):
                    # If it's an image-level lexicon
                    if isinstance(self.lexicons, dict):
                        lexicon_name = self._map_img_name(
                            gt_img_name, self.lexicon_mapping)
                        pair_name = self._map_img_name(gt_img_name,
                                                       self.pair_mapping)
                        pred_texts[i], match_dist = self._match_word(
                            pred_text, self.lexicons[lexicon_name],
                            self.pairs[pair_name])
                    else:
                        pred_texts[i], match_dist = self._match_word(
                            pred_text, self.lexicons, self.pairs)
                    if (self.match_dist_thr
                            and match_dist >= self.match_dist_thr):
                        # won't even count this as a prediction
                        pred_dist_flags[i] = True

            # Filter out predictions by the text score threshold
            for i, text_score_thr in enumerate(self.text_score_thrs):
                pred_ignore_flags = pred_dist_flags | (
                    text_scores < text_score_thr)
                filtered_pred_texts = self._get_true_elements(
                    pred_texts, ~pred_ignore_flags)
                filtered_pred_points = self._get_true_elements(
                    pred_points, ~pred_ignore_flags)
                gt_matched = np.zeros(len(gt_texts), dtype=bool)
                num_gt = len(gt_texts) - np.sum(gt_ignore_flags)
                # NOTE(review): images without valid gts contribute neither
                # gts nor predictions (so their predictions are not counted
                # as false positives) — this mirrors the original behavior.
                if num_gt == 0:
                    continue

                num_gts[i] += num_gt
                for pred_text, pred_point in zip(filtered_pred_texts,
                                                 filtered_pred_points):
                    # Match against the nearest gt center only.
                    dists = [
                        Point(pred_point).distance(Point(gt_point))
                        for gt_point in gt_points
                    ]
                    min_idx = np.argmin(dists)
                    if gt_texts[min_idx] == '###' or gt_ignore_flags[min_idx]:
                        continue
                    if not gt_matched[min_idx] and (
                            pred_text.upper() == gt_texts[min_idx].upper()):
                        gt_matched[min_idx] = True
                        num_tp[i] += 1
                    num_preds[i] += 1

        for i, text_score_thr in enumerate(self.text_score_thrs):
            if num_preds[i] == 0 or num_tp[i] == 0:
                recall, precision, hmean = 0, 0, 0
            else:
                recall = num_tp[i] / num_gts[i]
                precision = num_tp[i] / num_preds[i]
                hmean = 2 * recall * precision / (recall + precision)
            eval_results = dict(
                precision=precision, recall=recall, hmean=hmean)
            logger.info(f'text score threshold: {text_score_thr:.2f}, '
                        f'recall: {eval_results["recall"]:.4f}, '
                        f'precision: {eval_results["precision"]:.4f}, '
                        f'hmean: {eval_results["hmean"]:.4f}\n')
            if eval_results['hmean'] > best_eval_results['hmean']:
                best_eval_results = eval_results
        return best_eval_results

    def _map_img_name(self, img_name: str, mapping: Tuple[str, str]) -> str:
        """Map the image name to the another one based on mapping."""
        return re.sub(mapping[0], mapping[1], img_name)

    def _true_indexes(self, array: np.ndarray) -> np.ndarray:
        """Get indexes of True elements from a 1D boolean array."""
        return np.where(array)[0]

    def _word_spotting_filter(self, gt_ignore_flags: np.ndarray,
                              gt_texts: List[str]
                              ) -> Tuple[np.ndarray, List[str]]:
        """Filter out gt instances that cannot be in a valid dictionary, and
        do some simple preprocessing to texts.

        Args:
            gt_ignore_flags (np.ndarray): 1D boolean ignore flags, updated
                in place for words rejected by ``_include_in_dict``.
            gt_texts (list[str]): Ground truth transcriptions, normalized
                in place for the words that are kept.

        Returns:
            tuple: The (possibly modified) ``gt_ignore_flags`` and
            ``gt_texts``.
        """
        for i in range(len(gt_texts)):
            if gt_ignore_flags[i]:
                continue
            text = gt_texts[i]
            # Strip possessive suffixes, surrounding hyphens and punctuation.
            if text[-2:] in ["'s", "'S"]:
                text = text[:-2]
            text = text.strip('-')
            for char in "'!?.:,*\"()·[]/":
                text = text.replace(char, ' ')
            text = text.strip()
            gt_ignore_flags[i] = not self._include_in_dict(text)
            if not gt_ignore_flags[i]:
                gt_texts[i] = text
        return gt_ignore_flags, gt_texts

    def _include_in_dict(self, text: str) -> bool:
        """Check if the text could be in a valid dictionary: at least three
        characters, no spaces, and only characters from the accepted Latin /
        Greek / hyphen ranges."""
        if len(text) != len(text.replace(' ', '')) or len(text) < 3:
            return False
        not_allowed = '×÷·'
        valid_ranges = [(ord(u'a'), ord(u'z')), (ord(u'A'), ord(u'Z')),
                        (ord(u'À'), ord(u'ƿ')), (ord(u'DŽ'), ord(u'ɿ')),
                        (ord(u'Ά'), ord(u'Ͽ')), (ord(u'-'), ord(u'-'))]
        for char in text:
            code = ord(char)
            if (not_allowed.find(char) != -1):
                return False
            valid = any(code >= r[0] and code <= r[1] for r in valid_ranges)
            if not valid:
                return False
        return True

    def _match_word(self,
                    text: str,
                    lexicons: List[str],
                    pairs: Optional[Dict[str, str]] = None
                    ) -> Tuple[str, float]:
        """Match the text with the lexicons and pairs.

        Args:
            text (str): Predicted transcription.
            lexicons (list[str]): Candidate lexicon words.
            pairs (dict, optional): Maps uppercased lexicon words to their
                ground truth spellings. Assumed to contain every lexicon
                word when provided. Defaults to None.

        Returns:
            tuple(str, float): The best-matching word (mapped through
            ``pairs`` when given) and its normalized edit distance in
            [0, 1], or ('', 100) when ``lexicons`` is empty.
        """
        text = text.upper()
        matched_word = ''
        matched_dist = 100
        for lexicon in lexicons:
            lexicon = lexicon.upper()
            # Normalized distance only; the original also computed the raw
            # Levenshtein distance and immediately discarded it (dead code).
            norm_dist = Levenshtein.normalized_distance(text, lexicon)
            if norm_dist < matched_dist:
                matched_dist = norm_dist
                if pairs:
                    matched_word = pairs[lexicon]
                else:
                    matched_word = lexicon
        return matched_word, matched_dist