Arabic-KW / kwextractor.py
medmediani
Changed the model path
dcb5fc8
import torch
import csv, os, sys
import argparse
from keybert import KeyBERT
from sentence_transformers import SentenceTransformer
class KeyWordExtractor():
    """Extract keywords from text using KeyBERT on top of a SentenceTransformer.

    Text is processed paragraph by paragraph (split on newlines). The
    highest-scoring unique keywords across all paragraphs are returned.
    """

    # Default model id/path handed to SentenceTransformer when the caller
    # does not supply one.
    KWE_PRETRAINED = 'medmediani/Arabic-KW-Mdel'

    def __init__(self, model_name_or_path=None):
        """Load the sentence model and wrap it in a KeyBERT instance.

        Args:
            model_name_or_path: Model id or local path for SentenceTransformer.
                Defaults to ``KWE_PRETRAINED`` when None.
        """
        model_path = (self.KWE_PRETRAINED if model_name_or_path is None
                      else model_name_or_path)
        # Kept as a public attribute; not read anywhere in this class.
        self.SEQ_LENGTH = 512
        # Default upper bound of the keyphrase n-gram range.
        self.MAX_KW_NGS = 3
        # Default number of keywords returned by extract().
        self.NKW = 3
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        sentence_model = SentenceTransformer(model_path)
        sentence_model.to(self.device)
        self.kw_model = KeyBERT(model=sentence_model)

    def _extract_by_paragraph(self, ctxt, nkws=None, max_kw_ngs=None):
        """Return the top ``nkws`` unique keywords of ``ctxt``, best first.

        Each non-blank line (after stripping) is scored separately by
        ``self.kw_model.extract_keywords``. The resulting ``(keyword, score)``
        pairs are pooled, sorted by descending score, and de-duplicated so that
        each keyword keeps its best-scoring occurrence.

        Args:
            ctxt: Input text; paragraphs are separated by ``"\\n"``.
            nkws: Maximum number of keywords; defaults to ``self.NKW``.
            max_kw_ngs: Largest keyphrase n-gram size; defaults to
                ``self.MAX_KW_NGS``.

        Returns:
            A list of at most ``nkws`` distinct keywords, ordered by
            descending score.
        """
        nkws = self.NKW if nkws is None else nkws
        max_kw_ngs = self.MAX_KW_NGS if max_kw_ngs is None else max_kw_ngs

        kws = []
        for paragraph in map(str.strip, ctxt.split("\n")):
            if not paragraph:
                continue
            kws.extend(self.kw_model.extract_keywords(
                paragraph,
                keyphrase_ngram_range=(1, max_kw_ngs),
                top_n=nkws,
                # NOTE(review): per KeyBERT docs, `diversity` only has an
                # effect when use_mmr=True; kept as-is for parity.
                diversity=0.8,
                stop_words=None,
            ))
        print("KWS=", kws, file=sys.stderr)

        # Stable sort: ties keep paragraph order, so the result is deterministic.
        kws.sort(key=lambda x: x[1], reverse=True)
        ranked = []
        seen = set()
        for kw, _score in kws:
            if len(ranked) >= nkws:
                break
            if kw not in seen:
                seen.add(kw)
                # A list (not a set) preserves the score ranking for callers.
                ranked.append(kw)
        return ranked

    def extract(self, ctxt, nkws=None, max_kw_ngs=None):
        """Return the top keywords of ``ctxt`` as a comma-separated string.

        Args:
            ctxt: Input text.
            nkws: Maximum number of keywords; defaults to ``self.NKW``.
            max_kw_ngs: Largest keyphrase n-gram size; defaults to
                ``self.MAX_KW_NGS``.

        Returns:
            Keywords joined with ``", "`` in descending score order
            (empty string when nothing is found).
        """
        nkws = nkws if nkws is not None else self.NKW
        max_kw_ngs = max_kw_ngs if max_kw_ngs is not None else self.MAX_KW_NGS
        # Work paragraph by paragraph because the model input length is
        # limited (the original note cites 512 tokens, cf. SEQ_LENGTH).
        return ", ".join(self._extract_by_paragraph(ctxt, nkws, max_kw_ngs))