mrmft's picture
adding project source
4da642e
raw
history blame
No virus
933 Bytes
from kpe import KPE
import utils
import os
from sentence_transformers import SentenceTransformer
import ranker
class KpeRanker:
    """Extract keyphrases from text and optionally rank them.

    Combines a trained ``KPE`` extractor (configured with a Flair NER model)
    with a ``SentenceTransformer`` model that ``ranker.get_sorted_keywords``
    uses to order the extracted keyphrases.
    """

    def __init__(self, device='cpu', trained_kpe_model=None):
        """Load the keyphrase extractor and the ranking transformer.

        Args:
            device: Device string passed to both ``KPE`` and
                ``SentenceTransformer``. Defaults to ``'cpu'``, the value
                previously hard-coded here.
            trained_kpe_model: Path to the trained KPE checkpoint. When
                ``None``, uses ``trained_model/trained_model_10000.pt``
                located next to this source file.
        """
        if trained_kpe_model is None:
            trained_kpe_model = os.path.join(
                os.path.dirname(os.path.abspath(__file__)),
                'trained_model',
                'trained_model_10000.pt',
            )
        self.kpe = KPE(
            trained_kpe_model=trained_kpe_model,
            flair_ner_model='flair/ner-english-ontonotes-large',
            device=device,
        )
        self.ranker_transformer = SentenceTransformer(
            'paraphrase-multilingual-mpnet-base-v2', device=device
        )

    def extract(self, text, count, using_ner, return_sorted):
        """Extract keyphrases from ``text`` and return at most ``count`` of them.

        Args:
            text: Input text; normalized with ``utils.normalize`` before
                extraction.
            count: Maximum number of items to return, or ``None`` for no
                limit. Must not be negative.
            using_ner: Forwarded unchanged to ``KPE.extract``.
            return_sorted: If true, the keyphrases are reordered by
                ``ranker.get_sorted_keywords`` using the ranking transformer.
                Otherwise extractor order is kept and each keyphrase is
                paired with a placeholder score of ``1``.

        Returns:
            A list holding the first ``count`` items. In the unsorted case the
            items are ``(keyphrase, 1)`` tuples. In the sorted case they are
            whatever ``ranker.get_sorted_keywords`` yields (presumably
            ``(keyphrase, score)`` pairs, matching the unsorted shape; verify
            against ``ranker``).

        Raises:
            ValueError: If ``count`` is negative.
        """
        # Validate before doing any expensive model work. A negative count
        # would otherwise slice off the tail (kps[:-k]) instead of limiting.
        if count is not None and count < 0:
            raise ValueError(f'count must be non-negative or None, got {count!r}')

        text = utils.normalize(text)
        kps = self.kpe.extract(text, using_ner=using_ner)
        if return_sorted:
            kps = ranker.get_sorted_keywords(self.ranker_transformer, text, kps)
        else:
            # Keep the (keyphrase, score) pair shape even without ranking.
            kps = [(kp, 1) for kp in kps]
        # Slicing already handles len(kps) <= count, and count=None keeps all.
        return kps[:count]