|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import numpy as np |
|
import paddle |
|
from ppocr.utils.utility import load_vqa_bio_label_maps |
|
|
|
|
|
class VQASerTokenLayoutLMPostProcess(object):
    """Convert token-level class scores into label names.

    Three lookup tables are built at construction time:

    * ``id2label_map``: class id -> label name, as returned by
      ``load_vqa_bio_label_maps``.
    * ``label2id_map_for_draw``: label name -> class id, where every
      ``"I-xxx"`` label is redirected to the id of ``"B-xxx"``.
    * ``id2label_map_for_show``: the ids of ``label2id_map_for_draw`` ->
      label name with a leading ``"B-"`` / ``"I-"`` prefix removed.
    """

    def __init__(self, class_path, **kwargs):
        """Load the label maps and derive the drawing/display tables.

        Args:
            class_path: Forwarded unchanged to ``load_vqa_bio_label_maps``.
            **kwargs: Accepted and ignored.

        Raises:
            KeyError: If an ``"I-xxx"`` label has no ``"B-xxx"`` counterpart.
        """
        super(VQASerTokenLayoutLMPostProcess, self).__init__()
        label2id_map, self.id2label_map = load_vqa_bio_label_maps(class_path)

        # "I-xxx" borrows the id of "B-xxx" so both tags of one entity type
        # collapse to a single id.
        self.label2id_map_for_draw = dict()
        for name, class_id in label2id_map.items():
            if name.startswith("I-"):
                class_id = label2id_map["B" + name[1:]]
            self.label2id_map_for_draw[name] = class_id

        # Display names drop the two-character tag prefix; labels without
        # such a prefix (for example "O") are kept verbatim.
        self.id2label_map_for_show = {
            class_id: name[2:] if name.startswith(("B-", "I-")) else name
            for name, class_id in self.label2id_map_for_draw.items()
        }

    def __call__(self, preds, batch=None, *args, **kwargs):
        """Decode ``preds`` for evaluation or for inference.

        Args:
            preds: Class scores. A tuple is reduced to its first element and
                a ``paddle.Tensor`` is converted with ``.numpy()``.
            batch: When given, ``batch[5]`` is used as the label ids and the
                evaluation path is taken.
                NOTE(review): index 5 is fixed here; confirm it matches the
                data pipeline's batch layout.
            *args: Accepted and ignored.
            **kwargs: Forwarded to ``_infer`` when ``batch`` is ``None``;
                must then supply ``segment_offset_ids`` and ``ocr_infos``.

        Returns:
            The result of ``_metric`` when ``batch`` is given, otherwise the
            result of ``_infer``.
        """
        if isinstance(preds, tuple):
            preds = preds[0]
        if isinstance(preds, paddle.Tensor):
            preds = preds.numpy()

        if batch is not None:
            return self._metric(preds, batch[5])
        return self._infer(preds, **kwargs)

    def _metric(self, preds, label):
        """Decode predictions and labels, skipping positions labelled -100.

        Args:
            preds: Array indexed as (sample, token, class); the class with
                the highest score is taken along axis 2.
            label: Array-like of label ids indexed as (sample, token).
                Positions equal to ``-100`` are dropped from both outputs.

        Returns:
            A tuple ``(decode_out_list, label_decode_out_list)``; each is a
            list with one list of label names per sample.
        """
        pred_idxs = preds.argmax(axis=2)
        label = np.asarray(label)
        seq_len = pred_idxs.shape[1]

        decode_out_list = []
        label_decode_out_list = []
        for row, pred_row in enumerate(pred_idxs):
            # Only the first ``seq_len`` label columns pair with predictions.
            label_row = label[row, :seq_len]
            keep = label_row != -100
            decode_out_list.append(
                [self.id2label_map[idx] for idx in pred_row[keep].tolist()])
            label_decode_out_list.append(
                [self.id2label_map[idx] for idx in label_row[keep].tolist()])
        return decode_out_list, label_decode_out_list

    def _infer(self, preds, segment_offset_ids, ocr_infos):
        """Assign one class to every segment by majority vote over its tokens.

        Args:
            preds: Iterable of per-sample score arrays indexed as
                (token, class).
            segment_offset_ids: Per sample, the end token offset of each
                segment; a segment starts where the previous one ended and
                the first one starts at 0.
            ocr_infos: Per sample, an indexable collection of dicts, one per
                segment. The dicts are updated in place.

        Returns:
            A list holding the (mutated) entries of ``ocr_infos``; every
            segment dict gains ``"pred_id"`` (int) and ``"pred"`` (display
            name from ``id2label_map_for_show``).
        """
        results = []

        for pred, segment_offset_id, ocr_info in zip(preds, segment_offset_ids,
                                                     ocr_infos):
            token_labels = [
                self.id2label_map[idx] for idx in np.argmax(pred, axis=1)
            ]

            start_id = 0
            for idx, end_id in enumerate(segment_offset_id):
                votes = [
                    self.label2id_map_for_draw[name]
                    for name in token_labels[start_id:end_id]
                ]
                # A segment without tokens falls back to id 0; otherwise the
                # most frequent id wins (ties resolve to the smallest id).
                pred_id = int(np.argmax(np.bincount(votes))) if votes else 0

                ocr_info[idx]["pred_id"] = pred_id
                ocr_info[idx]["pred"] = self.id2label_map_for_show[pred_id]
                start_id = end_id
            results.append(ocr_info)
        return results
|
|
|
|
|
class DistillationSerPostProcess(VQASerTokenLayoutLMPostProcess):
    """Run the SER post-process on each selected sub-model output.

    ``preds`` is expected to be indexable by sub-model name; the parent
    post-process is applied to every name in ``model_name`` and the results
    are returned in a dict keyed by that name.
    """

    def __init__(self, class_path, model_name=["Student"], key=None, **kwargs):
        """Configure which sub-model outputs are decoded.

        Args:
            class_path: Forwarded to the parent constructor.
            model_name: A list of sub-model names, or a single value that is
                wrapped into a one-element list.
            key: Optional key used to index each sub-model output before
                decoding; ``None`` uses the output as is.
            **kwargs: Forwarded to the parent constructor.
        """
        super().__init__(class_path, **kwargs)
        if not isinstance(model_name, list):
            model_name = [model_name]
        # Store a copy so instances never alias the shared default list
        # (or a list the caller may mutate later).
        self.model_name = list(model_name)
        self.key = key

    def __call__(self, preds, batch=None, *args, **kwargs):
        """Decode the output of every configured sub-model.

        Args:
            preds: Mapping from sub-model name to its output.
            batch: Forwarded to the parent ``__call__``.
            *args: Forwarded to the parent ``__call__``.
            **kwargs: Forwarded to the parent ``__call__``.

        Returns:
            dict: sub-model name -> result of the parent ``__call__``.
        """
        output = dict()
        for name in self.model_name:
            pred = preds[name]
            if self.key is not None:
                pred = pred[self.key]
            # ``batch`` goes positionally: passing it as a keyword ahead of
            # ``*args`` would make the first extra positional argument bind
            # to ``batch`` as well and raise TypeError.
            output[name] = super().__call__(pred, batch, *args, **kwargs)
        return output
|
|