sri-doc / token_classification.py
kavg's picture
fixed entity merging issue
ccb5ac8
raw
history blame contribute delete
No virus
8.39 kB
import numpy as np
from preprocess import normalize_box
import copy
def classifyTokens(model, input_ids, attention_mask, bbox, offset_mapping):
    """Run the token-classification model and return its predicted class ids.

    The model is called with ``input_ids``, ``bbox`` and ``attention_mask`` as
    keyword arguments. The arg-max over the last axis of its ``logits`` is
    taken, size-1 axes are squeezed away, and the result is returned as a
    (nested) Python list of class ids.

    ``offset_mapping`` is accepted for interface compatibility but not used.
    """
    outputs = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        bbox=bbox,
    )
    # One winning class id per token position.
    best_class_ids = outputs.logits.argmax(-1)
    return best_class_ids.squeeze().tolist()
def compare_boxes(b1, b2):
    """Return whether two boxes hold exactly the same coordinates.

    Both arguments may be any iterable of numbers (list, tuple, array). Each
    is materialised into a NumPy array and compared element-wise with
    ``np.array_equal``, so a list and a tuple with equal values match, while
    boxes of different lengths never do.
    """
    first, second = (np.array(list(box)) for box in (b1, b2))
    return np.array_equal(first, second)
def mergable(w1, w2, threshold=7):
    """Return True when two labelled words may be merged into one entity.

    The words must carry the same ``'label'``. In addition, either their
    ``box[1]`` values or their ``box[-1]`` values must differ by less than
    ``threshold``. For an ``[x0, y0, x1, y1]`` box these are the top and
    bottom edges, so the words roughly share a line.

    Args:
        w1, w2: dicts with ``'label'`` and ``'box'`` keys.
        threshold: strict upper bound on the coordinate difference. The
            default of 7 keeps the original hard-coded behaviour.

    Returns:
        bool
    """
    if w1['label'] != w2['label']:
        return False
    # Short-circuits exactly like the original: bottom edges are only compared
    # when the top edges are too far apart.
    return (
        abs(w1['box'][1] - w2['box'][1]) < threshold
        or abs(w1['box'][-1] - w2['box'][-1]) < threshold
    )
def convert_data(data, tokenizer, img_size, chunk_size=512):
    """Encode labelled entities into flat, token-level model inputs.

    Each entry of ``data`` is tokenized on its own, without special tokens.
    Every token gets the box of the OCR word(s) whose normalized characters it
    consumes, scaled to a 0-1000 grid. Tokens of a non-"other" entity are
    tagged ``B-<LABEL>`` / ``I-<LABEL>`` and recorded as an entity span. Spans
    are then localised per window of ``chunk_size`` tokens.

    Args:
        data: list of dicts with ``"id"``, ``"text"``, ``"label"`` and
            ``"words"`` keys. ``"words"`` is a list of ``{"text", "box"}`` OCR
            words in the order they appear in ``"text"``. Entries with empty
            text, or whose text tokenizes to nothing, are skipped. ``data``
            is not modified.
        tokenizer: called with ``return_offsets_mapping=True``. Its
            ``_tokenizer.normalizer.normalize_str`` is used to measure OCR
            words in normalized characters (presumably a fast Hugging Face
            tokenizer -- TODO confirm for other backends).
        img_size: ``(width, height)`` used to scale pixel boxes to 0-1000.
        chunk_size: window length, in tokens, used to localise entity spans.

    Returns:
        dict with flat lists ``"input_ids"``, ``"bbox"``, ``"labels"``,
        ``"attention_mask"`` and ``"entities"``. Each entity is
        ``{"start", "end", "label"}`` with ``end`` exclusive. The dict is
        empty when no tokens were produced.

    Raises:
        ValueError: if the first box-bearing token of the document consumes
            no OCR word, so there is no box to give it.
        IndexError: if a line's ``"words"`` run out before its text does.
    """
    # Token id emitted for a bare word-boundary piece that covers no
    # characters of its own. NOTE(review): presumably the standalone "▁"
    # piece of the XLM-R / LayoutXLM vocabulary -- confirm for the model used.
    space_token_id = 6

    def normalize_bbox(bbox, size):
        # Scale a pixel box to the 0-1000 grid.
        return [
            int(1000 * bbox[0] / size[0]),
            int(1000 * bbox[1] / size[1]),
            int(1000 * bbox[2] / size[0]),
            int(1000 * bbox[3] / size[1]),
        ]

    def simplify_bbox(bbox):
        # Reduce a flat coordinate list (4-value box or longer polygon) to
        # [x0, y0, x1, y1]; assumes alternating x, y values.
        return [
            min(bbox[0::2]),
            min(bbox[1::2]),
            max(bbox[2::2]),
            max(bbox[3::2]),
        ]

    def merge_bbox(bbox_list):
        # Smallest box enclosing every box in bbox_list.
        x0, y0, x1, y1 = list(zip(*bbox_list))
        return [min(x0), min(y0), max(x1), max(y1)]

    def fill_placeholder_boxes(boxes):
        # Replace None placeholders with a zero-size box. Use the top-left
        # corner of the next real box. If there is none, use the top-right
        # corner of the previous one. If the line has no real box at all,
        # use [0, 0, 0, 0].
        filled = list(boxes)
        following = None
        for i in range(len(filled) - 1, -1, -1):
            if filled[i] is not None:
                following = filled[i]
            elif following is not None:
                filled[i] = [following[0], following[1], following[0], following[1]]
        preceding = None
        for i, box in enumerate(filled):
            if box is not None:
                preceding = box
            elif preceding is not None:
                filled[i] = [preceding[2], preceding[1], preceding[2], preceding[1]]
            else:
                filled[i] = [0, 0, 0, 0]
        return filled

    tokenized_doc = {"input_ids": [], "bbox": [], "labels": [], "attention_mask": []}
    entities = []
    # OCR boxes given to the previous token. Intentionally carried across
    # lines, as before.
    last_box = None

    for line in data:
        if len(line["text"]) == 0:
            continue
        tokenized_inputs = tokenizer(
            line["text"],
            add_special_tokens=False,
            return_offsets_mapping=True,
            return_attention_mask=True,
        )
        if len(tokenized_inputs["input_ids"]) == 0:
            continue  # nothing to label (e.g. whitespace-only text)

        ocr_words = line["words"]
        next_word = 0  # index instead of pop(0): input stays intact, O(1) per word
        text_length = 0
        ocr_length = 0
        bbox = []
        for token_id, offset in zip(tokenized_inputs["input_ids"], tokenized_inputs["offset_mapping"]):
            if token_id == space_token_id:
                bbox.append(None)  # filled in after the line is done
                continue
            text_length += offset[1] - offset[0]
            # Consume OCR words until they cover as many normalized
            # characters as the tokens seen so far.
            tmp_box = []
            while ocr_length < text_length:
                ocr_word = ocr_words[next_word]
                next_word += 1
                ocr_length += len(
                    tokenizer._tokenizer.normalizer.normalize_str(ocr_word["text"].strip())
                )
                tmp_box.append(simplify_bbox(ocr_word["box"]))
            if not tmp_box:
                # An earlier OCR word already covers this token, so the token
                # continues that word and reuses its box.
                if last_box is None:
                    raise ValueError(
                        f"first token of entity {line['id']!r} has no OCR word to take a box from"
                    )
                tmp_box = last_box
            bbox.append(normalize_bbox(merge_bbox(tmp_box), img_size))
            last_box = tmp_box
        bbox = fill_placeholder_boxes(bbox)

        start = len(tokenized_doc["input_ids"])
        if line["label"] == "other":
            labels = ["O"] * len(bbox)
        else:
            tag = line["label"].upper()
            labels = [f"B-{tag}"] + [f"I-{tag}"] * (len(bbox) - 1)
            entities.append(
                {
                    "start": start,
                    "end": start + len(tokenized_inputs["input_ids"]),
                    "label": tag,
                }
            )

        tokenized_doc["input_ids"].extend(tokenized_inputs["input_ids"])
        tokenized_doc["bbox"].extend(bbox)
        tokenized_doc["labels"].extend(labels)
        tokenized_doc["attention_mask"].extend(tokenized_inputs["attention_mask"])

    # Localise entity spans per window. An entity is kept only if it lies
    # entirely inside one window; `end` is exclusive, so it may equal the
    # window end. NOTE(review): windows are concatenated back into flat
    # lists, so entity offsets from the second window on are window-local
    # while token lists are document-global -- confirm consumers expect this.
    output = {}
    for index in range(0, len(tokenized_doc["input_ids"]), chunk_size):
        window_end = index + chunk_size
        item = {key: values[index:window_end] for key, values in tokenized_doc.items()}
        item["entities"] = [
            dict(entity, start=entity["start"] - index, end=entity["end"] - index)
            for entity in entities
            if index <= entity["start"] < window_end and entity["end"] <= window_end
        ]
        for key, values in item.items():
            output.setdefault(key, []).extend(values)
    return output
def dfs(i, merged, width, height, visited, df_words):
    """Collect every word connected to ``df_words[i]`` in depth-first order.

    Two words ``a`` and ``b`` are connected when all of the following hold.
    Only the first entry of each word's ``'words'`` list is compared.

    * their ``box[1]`` or ``box[-1]`` values differ by less than
      ``int(0.01 * height)``,
    * they have the same ``'label'``, and
    * their ``box[0]`` or ``box[-2]`` values differ by less than
      ``int(0.08 * width)``.

    The visit order is exactly that of a recursive DFS that tries candidates
    by increasing index. An explicit stack is used instead of recursion, so
    large clusters cannot exceed Python's recursion limit.

    Args:
        i: index of the starting word (it is visited unconditionally).
        merged: list that visited words are appended to.
        width, height: page size used to derive the thresholds.
        visited: set of already-visited indices; updated in place.
        df_words: list of word dicts with ``'label'`` and ``'words'`` keys.

    Returns:
        ``merged`` (the same list object), extended with the cluster.
    """
    v_threshold = int(.01 * height)
    h_threshold = int(.08 * width)

    def connected(a, b):
        w1 = df_words[a]['words'][0]
        w2 = df_words[b]['words'][0]
        return (
            (abs(w1['box'][1] - w2['box'][1]) < v_threshold
             or abs(w1['box'][-1] - w2['box'][-1]) < v_threshold)
            and df_words[a]['label'] == df_words[b]['label']
            and (abs(w1['box'][0] - w2['box'][0]) < h_threshold
                 or abs(w1['box'][-2] - w2['box'][-2]) < h_threshold)
        )

    visited.add(i)
    merged.append(df_words[i])
    # Each frame is [node, next candidate index to examine]. Resuming at
    # that index reproduces the for-loop position of the recursive version.
    stack = [[i, 0]]
    while stack:
        frame = stack[-1]
        node, j = frame
        while j < len(df_words) and (j in visited or not connected(node, j)):
            j += 1
        if j >= len(df_words):
            stack.pop()
            continue
        frame[1] = j + 1
        visited.add(j)
        merged.append(df_words[j])
        stack.append([j, 0])
    return merged
def createEntities(model, predictions, input_ids, ocr_df, tokenizer, img_size, bbox):
    """Turn per-token predictions into merged, labelled entities and encode them.

    Steps:

    1. Build one record per OCR row.
    2. Drop special-token positions from the predictions and token boxes.
    3. Tag each word with the prediction of the tokens whose box exactly
       equals the word's ``normalize_box`` result.
    4. Cluster tagged words with ``dfs``.
    5. Turn multi-word clusters, and lone ``ANSWER`` words, into entities.
    6. Encode those entities with ``convert_data``.

    Args:
        model: provides ``config.id2label``, mapping predicted class ids to
            label strings such as ``"O"`` or ``"B-QUESTION"``.
        predictions: predicted class ids, one per position of ``input_ids[0]``.
        input_ids: batched token ids; only ``input_ids[0]`` is used
            (presumably a 1 x seq tensor -- TODO confirm).
        ocr_df: DataFrame of OCR words with ``text``, ``left``, ``top``,
            ``width`` and ``height`` columns.
        tokenizer: provides ``cls_token_id``, ``sep_token_id`` and
            ``pad_token_id``; also forwarded to ``convert_data``.
        img_size: ``(width, height)`` of the page.
        bbox: token boxes; ``bbox.squeeze().tolist()`` must give one box per
            position of ``input_ids[0]``.

    Returns:
        ``(output, merged_words)``. ``output`` is the ``convert_data``
        encoding of the merged entities. ``merged_words`` are the entities
        themselves, dicts with ``text``, ``box``, ``label``, ``id`` and
        ``words``.

    Raises:
        ValueError: if, after dropping special tokens, the numbers of token
            predictions and token boxes differ.
    """
    width, height = img_size

    # One record per OCR row: pixel box plus the normalized box used to find
    # the word's tokens.
    words = []
    for _, row in ocr_df.iterrows():
        word_box = [row['left'], row['top'], row['left'] + row['width'], row['top'] + row['height']]
        words.append({
            'word_text': row['text'],
            'word_box': word_box,
            'normalized_box': normalize_box(word_box, width, height),
        })

    # Drop special-token positions so labels and boxes refer to real tokens only.
    raw_input_ids = input_ids[0].tolist()
    token_boxes = bbox.squeeze().tolist()
    special_tokens = {tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id}
    token_labels = [
        model.config.id2label[prediction]
        for i, prediction in enumerate(predictions)
        if raw_input_ids[i] not in special_tokens
    ]
    actual_boxes = [
        box for i, box in enumerate(token_boxes) if raw_input_ids[i] not in special_tokens
    ]
    if len(actual_boxes) != len(token_labels):
        raise ValueError(
            f"{len(token_labels)} token predictions but {len(actual_boxes)} token boxes"
        )

    # Index token positions by box once, instead of comparing every word
    # against every token.
    tokens_by_box = {}
    for i, box in enumerate(actual_boxes):
        tokens_by_box.setdefault(tuple(box), []).append(i)

    for word in words:
        token_positions = tokens_by_box.get(tuple(word['normalized_box']), [])
        # Strip the two-character "B-"/"I-" prefix; "O" stays as is.
        word_labels = [
            'O' if token_labels[i] == 'O' else token_labels[i][2:]
            for i in token_positions
        ]
        if not word_labels:
            word['word_tagging'] = 'O'
        elif word_labels[0] != 'O':
            word['word_tagging'] = word_labels[0]
        else:
            # First token untagged: fall back to the last token's tag.
            word['word_tagging'] = word_labels[-1]

    filtered_words = [
        {
            'id': i,
            'text': word['word_text'],
            'label': word['word_tagging'],
            'box': word['word_box'],
            'words': [{'box': word['word_box'], 'text': word['word_text']}],
        }
        for i, word in enumerate(words)
        if word['word_tagging'] != 'O'
    ]

    clusters = []
    visited = set()
    for i in range(len(filtered_words)):
        if i not in visited:
            clusters.append(dfs(i, [], width, height, visited, filtered_words))

    merged_words = []
    for i, cluster in enumerate(clusters):
        # Keep multi-word clusters, and single words only when they are answers.
        if len(cluster) > 1 or cluster[0]['label'] == 'ANSWER':
            boxes = [member['box'] for member in cluster]
            merged_words.append({
                'text': " ".join(member['text'] for member in cluster),
                # Union of all member boxes (DFS order is not spatial order),
                # padded by 5 horizontally and 10 vertically.
                'box': [
                    min(b[0] for b in boxes) - 5,
                    min(b[1] for b in boxes) - 10,
                    max(b[2] for b in boxes) + 5,
                    max(b[3] for b in boxes) + 10,
                ],
                'label': cluster[0]['label'],
                # Continue numbering after the last filtered word's id.
                'id': filtered_words[-1]['id'] + i + 1,
                'words': [{'box': member['box'], 'text': member['text']} for member in cluster],
            })

    # Hand convert_data a deep copy, so nothing it does to its input can
    # reach the returned merged_words.
    output = convert_data(copy.deepcopy(merged_words), tokenizer, img_size)
    return output, merged_words