File size: 1,483 Bytes
4da642e
 
 
 
 
0571449
4da642e
 
0571449
84b31d5
 
 
5199291
 
 
84b31d5
07cbdb5
84b31d5
07cbdb5
84b31d5
07cbdb5
0571449
5199291
 
4da642e
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from kpe import KPE
import utils
import os
from sentence_transformers import SentenceTransformer
import ranker
from huggingface_hub import hf_hub_download

class KpeRanker:
    """Keyphrase extractor with optional transformer-based ranking.

    Wraps a ``KPE`` extractor and a ``SentenceTransformer`` model, both
    created with ``device='cpu'``, behind a single :meth:`extract` call.

    Configuration is read from environment variables at construction time:

    * ``MODEL_NAME`` -- file name of the trained KPE model (required).
    * ``MODEL_REPO`` -- Hugging Face repo id the model is downloaded from;
      required only when the model file is not already present in ``./``.
    * ``MODEL_TOKEN`` -- token forwarded to ``hf_hub_download``.
    * ``NER_MODEL`` -- forwarded to ``KPE`` as ``flair_ner_model``.
    * ``TRANSFORMER_MODEL`` -- forwarded to ``SentenceTransformer``.
    """

    def __init__(self):
        """Load the KPE model (downloading it if necessary) and the ranker.

        Raises:
            ValueError: If ``MODEL_NAME`` is unset or empty, or if the model
                file is missing locally and ``MODEL_REPO`` is unset or empty.
        """
        model_name = os.environ.get("MODEL_NAME")
        model_repo = os.environ.get("MODEL_REPO")
        model_token = os.environ.get("MODEL_TOKEN")
        ner_model = os.environ.get("NER_MODEL")
        transformer_model = os.environ.get("TRANSFORMER_MODEL")

        # Fail early with an actionable message instead of the opaque
        # TypeError that os.path.join raises when handed None.
        if not model_name:
            raise ValueError(
                "Environment variable MODEL_NAME must be set to the trained "
                "KPE model file name."
            )

        local_dir = "./"
        model_path = os.path.join(local_dir, model_name)
        if not os.path.isfile(model_path):
            # The repo id only matters when the file has to be fetched.
            if not model_repo:
                raise ValueError(
                    f"Model file {model_path!r} was not found and environment "
                    "variable MODEL_REPO is not set, so it cannot be downloaded."
                )
            hf_hub_download(
                repo_id=model_repo,
                filename=model_name,
                local_dir=local_dir,
                token=model_token,
            )

        self.kpe = KPE(
            trained_kpe_model=model_path,
            flair_ner_model=ner_model,
            device='cpu',
        )
        self.ranker_transformer = SentenceTransformer(transformer_model, device='cpu')

    def extract(self, text, count, using_ner, return_sorted):
        """Extract at most ``count`` keyphrases from ``text``.

        Args:
            text: Input text; it is passed through ``utils.normalize`` before
                extraction.
            count: Maximum number of results to return. Negative values are
                treated as zero.
            using_ner: Forwarded to ``KPE.extract`` as ``using_ner``.
            return_sorted: When truthy, the keyphrases are ordered by
                ``ranker.get_sorted_keywords``; otherwise each keyphrase is
                paired with a constant score of ``1`` in extraction order.

        Returns:
            The (possibly truncated) keyphrase collection. In the unsorted
            case this is a list of ``(keyphrase, 1)`` tuples; in the sorted
            case it is whatever ``ranker.get_sorted_keywords`` returns
            (presumably scored pairs as well -- confirm against ``ranker``).
        """
        text = utils.normalize(text)
        kps = self.kpe.extract(text, using_ner=using_ner)
        if return_sorted:
            kps = ranker.get_sorted_keywords(self.ranker_transformer, text, kps)
        else:
            kps = [(kp, 1) for kp in kps]

        # A negative count would make the slice drop items from the tail
        # rather than cap the result, so clamp it to zero first.
        limit = max(count, 0)
        if len(kps) > limit:
            kps = kps[:limit]
        return kps