from flair.data import Sentence
from flair.models import SequenceTagger
from NERDA.models import NERDA
from hazm import word_tokenize
import flair
import utils


class KPE:
    """Keyphrase extractor that merges a fine-tuned IOB tagger's keywords
    with named entities found by a flair NER model.
    """

    def __init__(self, trained_kpe_model, flair_ner_model, device='cpu') -> None:
        """Load both models onto *device*.

        Args:
            trained_kpe_model: path to the saved NERDA network weights.
            flair_ner_model: path or identifier for the flair SequenceTagger.
            device: torch device string, e.g. 'cpu' or 'cuda'.
        """
        self.extractor_model = NERDA(
            tag_scheme=['B-KEYWORD', 'I-KEYWORD'],
            tag_outside='O',
            transformer='xlm-roberta-large',
            max_len=512,
            device=device)
        flair.device = device
        self.extractor_model.load_network_from_file(trained_kpe_model)
        self.ner_tagger = SequenceTagger.load(flair_ner_model)
        # Entity tags that are never useful as keyphrases (numbers/dates).
        self.IGNORE_TAGS = {'ORDINAL', 'DATE', 'CARDINAL'}

    @staticmethod
    def combine_keywords_nes(init_keywords, nes):
        """Merge model keywords with NER entities.

        Each keyword that fuzzily matches a named entity is replaced by that
        entity (the entity is consumed so it cannot match twice); remaining
        unmatched entities are appended at the end.

        Args:
            init_keywords: keywords from the IOB extractor, in order.
            nes: named-entity surface strings (may contain duplicates).

        Returns:
            list of combined keyword strings, keyword order preserved.
        """
        # Deduplicate entities; set() loses order, which is acceptable here
        # because matched entities inherit the keyword ordering.
        nes = list(set(nes))
        print('nes before combined ', nes)
        combined_keywords = []
        for kw in init_keywords:
            matched_index = utils.fuzzy_subword_match(kw, nes)
            if matched_index != -1:
                print(kw, nes[matched_index])
                # Prefer the NER surface form over the raw keyword.
                combined_keywords.append(nes[matched_index])
                # Consume the entity so it is not matched or re-appended.
                del nes[matched_index]
            else:
                combined_keywords.append(kw)
        print('nes after combined ', nes)
        # Append leftover entities that were not already produced above.
        combined_keywords.extend([n for n in nes if n not in combined_keywords])
        return combined_keywords

    def extract(self, txt, using_ner=True):
        """Extract keyphrases from *txt*.

        Args:
            txt: input document as a single string.
            using_ner: when True, merge flair named entities into the result.

        Returns:
            de-duplicated list of keyphrases, first-seen order preserved.
        """
        if using_ner:
            # Only build the (heavyweight) flair Sentence when NER is used.
            sentence = Sentence(txt)
            self.ner_tagger.predict(sentence)
            nes = [entity.text
                   for entity in sentence.get_spans('ner')
                   if entity.tag not in self.IGNORE_TAGS]
        else:
            nes = []

        # Remove punctuation from entity strings before matching.
        nes = list(map(utils.remove_puncs, nes))
        print('nes ', nes)

        # Treat the whole text as one sentence; whitespace tokenization
        # matches how the extractor was trained.
        sentences, tags_conf = self.extractor_model.predict_text(
            txt,
            sent_tokenize=lambda txt: [txt],
            word_tokenize=lambda txt: txt.split(),
            return_confidence=True)
        init_keywords = utils.get_ne_from_iob_output(sentences, tags_conf)
        init_keywords = list(map(utils.remove_puncs, init_keywords))
        print('init keywords : ', init_keywords)

        # Merge NER response with extractor keywords.
        merged_keywords = self.combine_keywords_nes(init_keywords, nes)

        # Deduplicate while keeping first-seen order (dicts preserve
        # insertion order in Python 3.7+).
        return list(dict.fromkeys(merged_keywords))