# Uploaded by mrmft
# Commit 4da642e: "adding project source"
from flair.data import Sentence
from flair.models import SequenceTagger
from NERDA.models import NERDA
from hazm import word_tokenize
import flair
import utils
class KPE:
    """Keyphrase extractor that merges a NERDA token tagger with a flair NER tagger.

    The NERDA model tags tokens with ``B-KEYWORD`` / ``I-KEYWORD`` / ``O``;
    the flair ``SequenceTagger`` supplies named entities.  ``extract`` runs
    both and merges their outputs into one ordered, duplicate-free list.
    """

    def __init__(self, trained_kpe_model, flair_ner_model, device='cpu') -> None:
        """Build the NERDA extractor and load the flair NER tagger.

        Args:
            trained_kpe_model: Passed to ``NERDA.load_network_from_file``.
            flair_ner_model: Passed to ``SequenceTagger.load``.
            device: Handed to ``NERDA`` and assigned to ``flair.device``.
        """
        self.extractor_model = NERDA(
            tag_scheme=['B-KEYWORD', 'I-KEYWORD'],
            tag_outside='O',
            transformer='xlm-roberta-large',
            max_len=512,
            device=device)
        # Set the flair device before the tagger is loaded below.
        # NOTE(review): flair usually holds a torch.device here; a plain
        # string is assigned as in the original code -- confirm this works
        # with the installed flair version.
        flair.device = device
        self.extractor_model.load_network_from_file(trained_kpe_model)
        self.ner_tagger = SequenceTagger.load(flair_ner_model)
        # Entities carrying one of these tags are discarded by extract().
        self.IGNORE_TAGS = {'ORDINAL', 'DATE', 'CARDINAL'}

    @staticmethod
    def combine_keywords_nes(init_keywords, nes):
        """Merge model keywords with named entities.

        For every keyword, ``utils.fuzzy_subword_match(kw, nes)`` is asked
        for an index into the remaining entities.  When it returns anything
        other than ``-1`` the entity at that index replaces the keyword in
        the output and is removed from the pool; otherwise the keyword is
        kept as-is.  Entities still in the pool afterwards are appended
        unless they are already present in the result.

        Args:
            init_keywords: Keywords from the extractor model, in order.
            nes: Named-entity strings; duplicates are dropped.  The
                caller's list is not modified.

        Returns:
            list: Keywords (possibly replaced by matching entities)
            followed by the unmatched entities.
        """
        # init_keywords = list(set(init_keywords))
        # Deduplicate in first-seen order.  A plain set() would make both
        # the matched index and the order of the leftover entities depend
        # on string hash randomization, i.e. vary between runs.
        nes = list(dict.fromkeys(nes))
        print('nes before combined ', nes)
        combined_keywords = []
        for kw in init_keywords:
            matched_index = utils.fuzzy_subword_match(kw, nes)
            if matched_index != -1:
                print(kw, nes[matched_index])
                combined_keywords.append(nes[matched_index])
                # Each entity may stand in for at most one keyword.
                del nes[matched_index]
            else:
                combined_keywords.append(kw)
        print('nes after combined ', nes)
        combined_keywords.extend([n for n in nes if n not in combined_keywords])
        return combined_keywords

    def extract(self, txt, using_ner=True):
        """Extract an ordered, duplicate-free keyword list from ``txt``.

        Args:
            txt: Input text.
            using_ner: When true, the flair tagger is run and its entities
                (except those tagged with one of ``IGNORE_TAGS``) are
                merged into the result.

        Returns:
            list: Keywords in first-occurrence order.
        """
        # predict NER tags
        if using_ner:
            # The Sentence is only needed for the flair tagger.
            sentence = Sentence(txt)
            self.ner_tagger.predict(sentence)
            nes = [
                entity.text
                for entity in sentence.get_spans('ner')
                if entity.tag not in self.IGNORE_TAGS
            ]
        else:
            nes = []
        # remove puncs
        nes = list(map(utils.remove_puncs, nes))
        print('nes ', nes)
        # The whole text is fed as a single sentence, split on whitespace.
        sentences, tags_conf = self.extractor_model.predict_text(
            txt,
            sent_tokenize=lambda text: [text],
            word_tokenize=lambda text: text.split(),
            return_confidence=True)
        init_keywords = utils.get_ne_from_iob_output(sentences, tags_conf)
        init_keywords = list(map(utils.remove_puncs, init_keywords))
        print('init keywords : ', init_keywords)
        # combine ner response and init keywords
        merged_keywords = self.combine_keywords_nes(init_keywords, nes)
        # Drop duplicates while keeping first-occurrence order.
        return list(dict.fromkeys(merged_keywords))