Spaces:
Runtime error
Runtime error
import numpy as np | |
import paddle | |
def load_vqa_bio_label_maps(label_map_path):
    """Build BIO tag <-> id maps from a plain-text class list.

    The file is read as UTF-8, one class name per line. Every class ``X``
    expands to the two tags ``B-X`` and ``I-X`` (in that order); a line
    that is exactly ``O`` is kept as the single tag ``O``. If no ``O`` line
    is present, ``O`` is prepended so it receives id 0. Ids follow the
    order of the resulting tag list.

    Blank (whitespace-only) lines are skipped; otherwise they would yield
    the meaningless tags ``"B-"`` and ``"I-"``.

    Args:
        label_map_path: Path to the class-list text file.

    Returns:
        ``(label2id_map, id2label_map)``: a dict mapping tag string -> int id
        and a dict mapping int id -> tag string.
    """
    with open(label_map_path, "r", encoding="utf-8") as fin:
        stripped = [line.strip() for line in fin]
    # Drop empty lines so they do not become "B-"/"I-" tags.
    class_names = [name for name in stripped if name]
    if "O" not in class_names:
        class_names.insert(0, "O")

    labels = []
    for name in class_names:
        if name == "O":
            labels.append("O")
        else:
            labels.extend(("B-" + name, "I-" + name))

    label2id_map = {label: idx for idx, label in enumerate(labels)}
    id2label_map = dict(enumerate(labels))
    return label2id_map, id2label_map
class VQASerTokenLayoutLMPostProcess(object):
    """Convert between text-label and text-index for token-classification output.

    Decodes per-token class scores either into BIO tag strings (metric
    mode) or into one predicted class per OCR segment (inference mode).
    """

    def __init__(self, class_path, **kwargs):
        """Load the tag maps and derive the drawing / display maps.

        Args:
            class_path: Path to the class-list file read by
                ``load_vqa_bio_label_maps``.
            **kwargs: Accepted and ignored.
        """
        super(VQASerTokenLayoutLMPostProcess, self).__init__()
        label2id_map, self.id2label_map = load_vqa_bio_label_maps(class_path)

        # Give "I-X" the id of its "B-X" so both tags of one class share an
        # id; segment voting and drawing operate on these shared ids.
        self.label2id_map_for_draw = dict()
        for key in label2id_map:
            if key.startswith("I-"):
                self.label2id_map_for_draw[key] = label2id_map["B" + key[1:]]
            else:
                self.label2id_map_for_draw[key] = label2id_map[key]

        # Shared id -> bare class name ("B-X"/"I-X" -> "X"; "O" stays "O").
        self.id2label_map_for_show = dict()
        for key, val in self.label2id_map_for_draw.items():
            if key.startswith(("B-", "I-")):
                self.id2label_map_for_show[val] = key[2:]
            else:
                self.id2label_map_for_show[val] = key

    def __call__(self, preds, batch=None, *args, **kwargs):
        """Decode model output.

        Args:
            preds: Per-token class scores, a ``paddle.Tensor`` or numpy array
                indexed as (batch, seq_len, num_classes).
            batch: If given, ``batch[1]`` is taken as the gold token labels
                and metric-mode output is returned.
            **kwargs: Forwarded to ``_infer`` when ``batch`` is None; it
                expects ``attention_masks``, ``segment_offset_ids`` and
                ``ocr_infos``.

        Returns:
            ``_metric`` output when ``batch`` is given, else ``_infer`` output.
        """
        if isinstance(preds, paddle.Tensor):
            preds = preds.numpy()
        if batch is not None:
            label = batch[1]
            # _metric indexes labels like a numpy array and uses the values
            # as dict keys, so convert a Tensor the same way as preds.
            if isinstance(label, paddle.Tensor):
                label = label.numpy()
            return self._metric(preds, label)
        return self._infer(preds, **kwargs)

    def _metric(self, preds, label):
        """Decode predictions and gold labels into aligned tag sequences.

        Args:
            preds: Numpy scores indexed as (batch, seq_len, num_classes).
            label: Gold class ids indexed as ``label[i, j]``; positions whose
                value is -100 are skipped in both outputs.

        Returns:
            ``(decode_out_list, label_decode_out_list)``: for each sample, the
            predicted and gold tag strings at the non-skipped positions.
        """
        pred_idxs = preds.argmax(axis=2)
        batch_size, seq_len = pred_idxs.shape
        decode_out_list = [[] for _ in range(batch_size)]
        label_decode_out_list = [[] for _ in range(batch_size)]
        for i in range(batch_size):
            for j in range(seq_len):
                if label[i, j] != -100:
                    label_decode_out_list[i].append(self.id2label_map[label[i, j]])
                    decode_out_list[i].append(self.id2label_map[pred_idxs[i, j]])
        return decode_out_list, label_decode_out_list

    def _infer(self, preds, attention_masks, segment_offset_ids, ocr_infos):
        """Assign one predicted class to each OCR segment by majority vote.

        For each sample, segment ``idx`` covers tokens
        ``[segment_offset_id[idx - 1], segment_offset_id[idx])``, where the
        first segment starts at token 0. Token predictions in that span are
        mapped to their shared class id and the most frequent id wins; ties
        go to the smallest id because ``np.argmax`` returns the first maximum.
        A segment with no tokens is assigned the "O" class.

        ``ocr_info[idx]`` is updated in place with ``"pred_id"`` (int) and
        ``"pred"`` (display name). ``attention_masks`` is only zipped with the
        other inputs; its values are not used.

        Returns:
            A list holding each sample's (mutated) ``ocr_info``.
        """
        # "O" need not be id 0: the loader keeps the file's position of "O"
        # when the class file lists it explicitly.
        other_id = self.label2id_map_for_draw.get("O", 0)
        results = []
        for pred, _, segment_offset_id, ocr_info in zip(
            preds, attention_masks, segment_offset_ids, ocr_infos
        ):
            token_ids = np.argmax(pred, axis=1)
            draw_ids = [
                self.label2id_map_for_draw[self.id2label_map[tid]] for tid in token_ids
            ]
            start_id = 0
            for idx, end_id in enumerate(segment_offset_id):
                curr_pred = draw_ids[start_id:end_id]
                if curr_pred:
                    pred_id = np.argmax(np.bincount(curr_pred))
                else:
                    pred_id = other_id
                ocr_info[idx]["pred_id"] = int(pred_id)
                ocr_info[idx]["pred"] = self.id2label_map_for_show[int(pred_id)]
                start_id = end_id
            results.append(ocr_info)
        return results