File size: 1,335 Bytes
4da642e
 
 
 
 
0571449
4da642e
 
0571449
 
 
 
 
 
 
 
 
4da642e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from kpe import KPE
import utils
import os
from sentence_transformers import SentenceTransformer
import ranker
from huggingface_hub import hf_hub_download

class KpeRanker:
    """Extract keyphrases from text and optionally rank them.

    Wraps a trained ``KPE`` keyphrase extractor and a multilingual
    ``SentenceTransformer`` that ``ranker.get_sorted_keywords`` uses to order
    the extracted phrases.
    """

    # Hugging Face Hub location of the trained KPE checkpoint. The repo id is
    # the one encoded in the ``models--ahdsoft--persian-keyphrase-extraction-model``
    # cache directory name used by LOCAL_MODEL_PATH below.
    MODEL_REPO_ID = "ahdsoft/persian-keyphrase-extraction-model"
    MODEL_FILENAME = "trained_model_10000.pt"
    # Checked first so a pre-provisioned checkpoint is used without
    # contacting the Hub.
    LOCAL_MODEL_PATH = (
        "/root/.cache/huggingface/hub/"
        "models--ahdsoft--persian-keyphrase-extraction-model/trained_model_10000.pt"
    )

    def __init__(self):
        """Load the keyphrase-extraction and ranking models on the CPU.

        Uses the checkpoint at ``LOCAL_MODEL_PATH`` when that file exists;
        otherwise downloads ``MODEL_FILENAME`` from ``MODEL_REPO_ID`` and uses
        the local path returned by ``hf_hub_download``.
        """
        trained_model_addr = self._resolve_model_path()
        self.kpe = KPE(
            trained_kpe_model=trained_model_addr,
            flair_ner_model="flair/ner-english-ontonotes-large",
            device="cpu",
        )
        self.ranker_transformer = SentenceTransformer(
            "paraphrase-multilingual-mpnet-base-v2", device="cpu"
        )

    @classmethod
    def _resolve_model_path(cls):
        """Return a local filesystem path to the trained KPE checkpoint."""
        if os.path.isfile(cls.LOCAL_MODEL_PATH):
            return cls.LOCAL_MODEL_PATH
        # hf_hub_download places the file under
        # <cache>/models--<org>--<name>/snapshots/<revision>/, not at
        # LOCAL_MODEL_PATH, so the returned path must be used rather than
        # assuming where the file landed.
        return hf_hub_download(repo_id=cls.MODEL_REPO_ID, filename=cls.MODEL_FILENAME)

    def extract(self, text, count, using_ner, return_sorted):
        """Extract up to ``count`` keyphrases from ``text``.

        Args:
            text: Input text; normalized with ``utils.normalize`` first.
            count: Maximum number of keyphrases to return.
            using_ner: Forwarded to ``KPE.extract`` as ``using_ner``.
            return_sorted: If true, order the phrases with
                ``ranker.get_sorted_keywords``; otherwise keep the extractor's
                order and pair each phrase with a constant score of 1.

        Returns:
            The first ``count`` entries. Each entry is a ``(keyphrase, 1)``
            tuple when ``return_sorted`` is false. When it is true, entries
            are whatever ``ranker.get_sorted_keywords`` yields (presumably
            ``(keyphrase, score)`` pairs; confirm against ``ranker``).
        """
        text = utils.normalize(text)
        kps = self.kpe.extract(text, using_ner=using_ner)
        if return_sorted:
            kps = ranker.get_sorted_keywords(self.ranker_transformer, text, kps)
        else:
            kps = [(kp, 1) for kp in kps]
        # Slicing already handles count >= len(kps); no length check needed.
        return kps[:count]