Spaces:
Runtime error
Runtime error
from kpe import KPE | |
import utils | |
import os | |
from sentence_transformers import SentenceTransformer | |
import ranker | |
from huggingface_hub import hf_hub_download | |
class KpeRanker:
    """Keyphrase extractor with optional sentence-embedding based ranking.

    Configuration is read from environment variables in ``__init__``:

    - ``MODEL_NAME`` (required): file name of the trained KPE model, looked up
      in the current directory.
    - ``MODEL_REPO``: Hugging Face Hub repo to download ``MODEL_NAME`` from when
      the file is not present locally (required only in that case).
    - ``MODEL_TOKEN``: optional Hub access token used for that download.
    - ``NER_MODEL``: forwarded to ``KPE`` as ``flair_ner_model``. Whether an
      unset value (``None``) is accepted depends on ``KPE``; TODO confirm.
    - ``TRANSFORMER_MODEL``: model passed to ``SentenceTransformer`` and used
      for ranking.
    """

    def __init__(self, device="cpu"):
        """Load the KPE model (downloading it if needed) and the ranking transformer.

        Args:
            device: device string passed to both ``KPE`` and
                ``SentenceTransformer``. Defaults to ``"cpu"``, as before.

        Raises:
            RuntimeError: if ``MODEL_NAME`` is unset or empty, or if the model
                file is missing locally and ``MODEL_REPO`` is unset or empty.
        """
        model_name = os.environ.get("MODEL_NAME")
        model_repo = os.environ.get("MODEL_REPO")
        model_token = os.environ.get("MODEL_TOKEN")
        ner_model = os.environ.get("NER_MODEL")
        transformer_model = os.environ.get("TRANSFORMER_MODEL")

        # Fail with an actionable message instead of the TypeError that
        # os.path.join(local_dir, None) would otherwise raise.
        if not model_name:
            raise RuntimeError("MODEL_NAME environment variable is not set")

        local_dir = "./"
        model_path = os.path.join(local_dir, model_name)
        if not os.path.isfile(model_path):
            if not model_repo:
                raise RuntimeError(
                    f"Model file {model_path!r} not found locally and the "
                    "MODEL_REPO environment variable is not set"
                )
            # hf_hub_download returns the local path of the file it fetched;
            # use it instead of assuming where the file was written.
            model_path = hf_hub_download(
                repo_id=model_repo,
                filename=model_name,
                local_dir=local_dir,
                token=model_token,
            )

        self.kpe = KPE(
            trained_kpe_model=model_path,
            flair_ner_model=ner_model,
            device=device,
        )
        self.ranker_transformer = SentenceTransformer(transformer_model, device=device)

    def extract(self, text, count, using_ner, return_sorted):
        """Extract keyphrases from *text*, optionally ranked, capped at *count*.

        Args:
            text: input text. It is normalized with ``utils.normalize`` before
                extraction and ranking.
            count: maximum number of keyphrases to return. ``None`` means no
                limit. Values <= 0 yield an empty result.
            using_ner: forwarded to ``KPE.extract``.
            return_sorted: if true, order the keyphrases with
                ``ranker.get_sorted_keywords`` using the sentence transformer.
                Otherwise keep extraction order and pair each keyphrase with a
                constant score of 1.

        Returns:
            At most *count* items. In the unsorted case each item is
            ``(keyphrase, 1)``. In the sorted case the items are whatever
            ``ranker.get_sorted_keywords`` returns; presumably
            ``(keyphrase, score)`` pairs, so confirm in ``ranker``.
        """
        text = utils.normalize(text)
        kps = self.kpe.extract(text, using_ner=using_ner)
        if return_sorted:
            kps = ranker.get_sorted_keywords(self.ranker_transformer, text, kps)
        else:
            kps = [(kp, 1) for kp in kps]

        if count is None:
            return kps
        # Clamp so a negative count cannot turn the slice into
        # "drop items from the end" (kps[:-1] etc.).
        limit = max(count, 0)
        if len(kps) > limit:
            kps = kps[:limit]
        return kps