# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle

from ppocr.utils.utility import load_vqa_bio_label_maps


class VQASerTokenLayoutLMPostProcess(object):
    """Convert between text-label and text-index.

    Turns per-token class scores into label strings. With a batch the
    decoded predictions are returned next to the decoded ground truth;
    without one the predictions are aggregated per OCR segment.
    """

    def __init__(self, class_path, **kwargs):
        """Build the label lookup tables.

        Args:
            class_path: Forwarded to ``load_vqa_bio_label_maps``, which
                returns the ``(label2id_map, id2label_map)`` pair.
            **kwargs: Accepted for configuration compatibility; unused.
        """
        super().__init__()
        label2id_map, self.id2label_map = load_vqa_bio_label_maps(class_path)

        # Every "I-xxx" label is mapped onto the id of its "B-xxx" partner
        # so that both tags of one entity count as the same class when the
        # per-segment majority vote is taken in ``_infer``.
        self.label2id_map_for_draw = dict()
        for key in label2id_map:
            if key.startswith("I-"):
                self.label2id_map_for_draw[key] = label2id_map["B" + key[1:]]
            else:
                self.label2id_map_for_draw[key] = label2id_map[key]

        # Reverse table used for display: the "B-"/"I-" prefix is stripped,
        # every other label (e.g. "O") is kept unchanged.
        self.id2label_map_for_show = dict()
        for key, val in self.label2id_map_for_draw.items():
            if key.startswith(("B-", "I-")):
                self.id2label_map_for_show[val] = key[2:]
            else:
                self.id2label_map_for_show[val] = key

    def __call__(self, preds, batch=None, *args, **kwargs):
        """Decode model output.

        Args:
            preds: Model output. A tuple is reduced to its first element
                and a ``paddle.Tensor`` is converted with ``.numpy()``.
            batch: When not ``None``, ``batch[5]`` is used as the label
                array and the metric path is taken.
                NOTE(review): index 5 is fixed by the data pipeline, which
                is not visible here - confirm against the dataset config.
            *args: Ignored.
            **kwargs: Forwarded to ``_infer`` when ``batch`` is ``None``
                (it expects ``segment_offset_ids`` and ``ocr_infos``).

        Returns:
            The result of ``_metric`` if ``batch`` is given, otherwise the
            result of ``_infer``.
        """
        if isinstance(preds, tuple):
            preds = preds[0]
        if isinstance(preds, paddle.Tensor):
            preds = preds.numpy()

        if batch is not None:
            return self._metric(preds, batch[5])
        return self._infer(preds, **kwargs)

    def _metric(self, preds, label):
        """Decode predictions and labels, skipping ignored positions.

        Args:
            preds: Score array; the class is taken as ``argmax`` over
                axis 2.
            label: Label ids indexed as ``label[i, j]``; positions equal to
                ``-100`` are skipped in both outputs.

        Returns:
            ``(decode_out_list, label_decode_out_list)``: two lists holding
            one list of label strings per sample.
        """
        pred_idxs = preds.argmax(axis=2)
        num_samples, seq_len = pred_idxs.shape[:2]
        decode_out_list = [[] for _ in range(num_samples)]
        label_decode_out_list = [[] for _ in range(num_samples)]

        for i in range(num_samples):
            for j in range(seq_len):
                if label[i, j] != -100:
                    label_decode_out_list[i].append(
                        self.id2label_map[label[i, j]])
                    decode_out_list[i].append(
                        self.id2label_map[pred_idxs[i, j]])
        return decode_out_list, label_decode_out_list

    def _infer(self, preds, segment_offset_ids, ocr_infos):
        """Assign one predicted class to every OCR segment.

        Args:
            preds: Iterable of per-sample score arrays; the class of each
                token is taken as ``argmax`` over axis 1.
            segment_offset_ids: Per sample, the end offset of each segment
                in the token sequence. A segment starts where the previous
                one ended (the first one starts at 0).
            ocr_infos: Per sample, a list of dicts, one per segment. Each
                dict is updated in place with ``"pred_id"`` and ``"pred"``.

        Returns:
            List with the (mutated) ``ocr_info`` of every sample.
        """
        results = []

        for pred, segment_offset_id, ocr_info in zip(preds, segment_offset_ids,
                                                     ocr_infos):
            pred = np.argmax(pred, axis=1)
            pred = [self.id2label_map[idx] for idx in pred]

            for idx in range(len(segment_offset_id)):
                start_id = segment_offset_id[idx - 1] if idx else 0
                end_id = segment_offset_id[idx]

                curr_pred = pred[start_id:end_id]
                curr_pred = [self.label2id_map_for_draw[p] for p in curr_pred]

                if not curr_pred:
                    # Empty segment: fall back to class id 0.
                    pred_id = 0
                else:
                    # Majority vote over the tokens of the segment.
                    counts = np.bincount(curr_pred)
                    pred_id = np.argmax(counts)
                ocr_info[idx]["pred_id"] = int(pred_id)
                ocr_info[idx]["pred"] = self.id2label_map_for_show[int(
                    pred_id)]
            results.append(ocr_info)
        return results


class DistillationSerPostProcess(VQASerTokenLayoutLMPostProcess):
    """
    DistillationSerPostProcess

    Applies ``VQASerTokenLayoutLMPostProcess`` to the output of each named
    sub-model and returns the results keyed by model name.
    """

    def __init__(self, class_path, model_name=None, key=None, **kwargs):
        """Configure which sub-model outputs are decoded.

        Args:
            class_path: Forwarded to the parent constructor.
            model_name: A name or list of names used to index ``preds`` in
                ``__call__``. Defaults to ``["Student"]``; a fresh list is
                created per instance so instances never share it.
            key: If not ``None``, each selected prediction is additionally
                indexed with this key.
            **kwargs: Forwarded to the parent constructor.
        """
        super().__init__(class_path, **kwargs)
        if model_name is None:
            model_name = ["Student"]
        elif not isinstance(model_name, list):
            model_name = [model_name]
        self.model_name = model_name
        self.key = key

    def __call__(self, preds, batch=None, *args, **kwargs):
        """Decode the output of every configured sub-model.

        Args:
            preds: Mapping indexed by each entry of ``self.model_name``.
            batch: Forwarded to the parent ``__call__``.
            *args: Forwarded to the parent ``__call__``.
            **kwargs: Forwarded to the parent ``__call__``.

        Returns:
            Dict mapping each model name to the parent's decoded result.
        """
        output = dict()
        for name in self.model_name:
            pred = preds[name]
            if self.key is not None:
                pred = pred[self.key]
            # ``batch`` is passed positionally: passing it by keyword ahead
            # of ``*args`` would make any extra positional argument collide
            # with it ("multiple values for argument 'batch'").
            output[name] = super().__call__(pred, batch, *args, **kwargs)
        return output