DocumentQA / DiT_Extractor / sentence_extractor.py
# Copyright (c) 2022, Lawrence Livermore National Security, LLC.
# All rights reserved.
# See the top-level LICENSE and NOTICE files for details.
# LLNL-CODE-838964
# SPDX-License-Identifier: Apache-2.0-with-LLVM-exception
import json
from tokenizers.pre_tokenizers import Whitespace
import base_utils
import spacy

def guess_sentences(tokens, text):
    sentence_delems = ('.', '?', ').', '!')
    sentences = []
    sentence = []
    maybe_delem = None
    for token in tokens:
        # Check the next token to see whether there is a space after the previous delimiter.
        if maybe_delem is not None:
            if maybe_delem[1][1] < token[1][0]:
                sentences.append(sentence)
                sentence = []
            maybe_delem = None
        sentence.append(token)
        if token[0] in sentence_delems:
            maybe_delem = token
    if sentence != []:
        sentences.append(sentence)
    return sentences
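
# Illustrative sketch (not from the original file): guess_sentences expects tokens in the
# (text, (start_char, end_char)) form produced by the Whitespace pre-tokenizer used below,
# and it only starts a new sentence when a delimiter is followed by whitespace. For example,
# the tokens of "Hi there. Bye" would be grouped roughly as:
#   [[('Hi', (0, 2)), ('there', (3, 8)), ('.', (8, 9))], [('Bye', (10, 13))]]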

def spacey_sentences(text):
    """Split text into sentences using spaCy's rule-based sentencizer."""
    nlp = spacy.blank('en')
    nlp.add_pipe('sentencizer')
    sentences = [s.text for s in nlp(text).sents]
    return sentences
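
# Hedged usage sketch: spacy.blank('en') plus the 'sentencizer' pipe splits on punctuation rules
# without loading a trained model, so spacey_sentences("First sentence. Second one.") should
# return something like ['First sentence.', 'Second one.'].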

def add_coords(sentences, all_coords):
    """Attach a bounding box to each token by unioning the boxes of its characters."""
    sentences_out = []
    for sentence in sentences:
        new_sentence = []
        for token in sentence:
            indexes = token[1]
            bbox = all_coords[indexes[0]]
            for i in range(indexes[0] + 1, indexes[1]):
                bbox = base_utils.union(bbox, all_coords[i])
            new_sentence.append((token[0], token[1], bbox))
        sentences_out.append(new_sentence)
    return sentences_out
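
# Shape note (inferred from the code above): each token arrives as (text, (start, end)) with
# character offsets from pre_tokenize_str, all_coords holds one bounding box per character,
# and base_utils.union is assumed to merge two boxes; each token leaves as (text, (start, end), bbox).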

def sentence_extract(document):
    """
    Convert an extracted PDF result (JSON) into token groups of at most 384 tokens,
    separated on sentence delimiter boundaries such as . ! ?
    """
    max_tokens = 384
    document_tree = json.load(open(document, 'r'))
    sections_per_page = {}
    for page_num, page in document_tree.items():
        # Tokenize per section (rectangular block detected by DiT).
        word_sections = []
        text_sections = []
        for section in page:
            text_sections.append(section['text'])
            all_text = ''
            all_coord = []
            if 'subelements' not in section:
                continue
            for subelement in section['subelements']:
                for char in subelement:
                    all_text += char[1]
                    all_coord.append(char[0])
                    # Check for irregular characters, e.g. "(cid:206)", "ff", "fi", etc.
                    # If the string isn't just 1 character, it's an irregular LTChar (character) from pdfminer.
                    # Instead of skipping them, create duplicate coordinates for the additional characters.
                    if len(char[1]) > 1:
                        bad_char_len = len(char[1])
                        dupe_coord_amt = bad_char_len - 1
                        for dupe_i in range(dupe_coord_amt):
                            all_coord.append(char[0])
            pre_tokenizer = Whitespace()
            sentences_pre_tok = spacey_sentences(all_text)
            sentences = []
            for sentence in sentences_pre_tok:
                tokenized = pre_tokenizer.pre_tokenize_str(sentence)
                sentences.append(tokenized)
            sentences = add_coords(sentences, all_coord)
            word_section = []
            t = 0
            for sentence in sentences:
                t += len(sentence)
                if t <= max_tokens:
                    # Update character indices when concatenating sentences.
                    if len(word_section) > 0:
                        last_word_obj = word_section[-1]
                        _, (_, char_idx_offset), _ = last_word_obj
                        sentence = [(w, (sc + char_idx_offset + 1, ec + char_idx_offset + 1), bbox) for w, (sc, ec), bbox in sentence]
                    word_section += sentence
                else:
                    word_sections.append(word_section)
                    word_section = sentence
                    t = len(sentence)
            word_sections.append(word_section)
        sections = {'text_sections': text_sections, 'word_sections': word_sections}
        sections_per_page[page_num] = sections
    return sections_per_page
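
# Expected input layout, inferred from the access pattern in sentence_extract above (the real
# files come from the upstream DiT/pdfminer extraction step, so field details may differ):
# {
#   "0": [                                  # page number -> list of detected sections
#     {
#       "text": "Section text ...",
#       "subelements": [                    # optional; one list per extracted text line
#         [[x0, y0, x1, y1], "T"],          # per-character [bbox, text] pairs
#         ...
#       ]
#     }
#   ]
# }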

def format_output_contexts(sections_per_page):
    """Flatten per-page sections into a dictionary of context records."""
    all_contexts = {}
    for page_idx in sections_per_page.keys():
        text_sections = sections_per_page[page_idx]['text_sections']
        word_sections = sections_per_page[page_idx]['word_sections']
        for text_section, word_section in zip(text_sections, word_sections):
            whitespaced_text = ' '.join([word[0] for word in word_section])
            words_info = []
            for word in word_section:
                words_info.append({'word_text': word[0], 'char_indices': word[1], 'word_bbox': word[2]})
            context_row = {'text': text_section, 'whitespaced_text': whitespaced_text, 'page_idx': int(page_idx), 'words_info': words_info}
            context_id = 'context_{0}'.format(len(all_contexts))
            all_contexts[context_id] = context_row
    return all_contexts
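
# Sketch of one output record, following the dict construction above (values are placeholders):
# all_contexts['context_0'] == {
#     'text': 'Original section text',
#     'whitespaced_text': 'Original section text',
#     'page_idx': 0,
#     'words_info': [{'word_text': 'Original', 'char_indices': (0, 8), 'word_bbox': [x0, y0, x1, y1]}, ...],
# }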

def get_contexts(json_input):
    """Extract contexts from a DiT/pdfminer JSON file and write them to 'contexts_<input name>'."""
    json_output = 'contexts_{0}'.format(json_input)
    sections_per_page = sentence_extract(json_input)
    all_contexts = format_output_contexts(sections_per_page)
    with open(json_output, 'w', encoding='utf8') as json_out:
        json.dump(all_contexts, json_out, ensure_ascii=False, indent=4)
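
# Minimal usage sketch (not part of the original module); the filename is a hypothetical example
# of the JSON produced by the upstream DiT extraction step.
if __name__ == '__main__':
    # Writes 'contexts_example_dit_output.json' in the working directory.
    get_contexts('example_dit_output.json')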