# NOTE: web-scrape page header ("Spaces: / Running / Running") removed; not part of the program.
# Standard library
import argparse
import csv
import os
import sys

# Third-party
import torch
from keybert import KeyBERT
from sentence_transformers import SentenceTransformer
class KeyWordExtractor():
    """Extract ranked keyword phrases from (Arabic) text with KeyBERT.

    A SentenceTransformer embedding model is loaded once at construction
    and reused for every extraction call.
    """

    def __init__(self):
        # Pretrained Arabic keyword model; the repo name really is spelled
        # "Mdel" upstream — do not "fix" the string.
        KWE_PRETRAINED = 'medmediani/Arabic-KW-Mdel'
        self.SEQ_LENGTH = 512   # model input limit motivating per-paragraph processing
        self.MAX_KW_NGS = 3     # default upper bound on keyphrase n-gram length
        self.NKW = 3            # default number of keywords returned
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        sentence_model = SentenceTransformer(KWE_PRETRAINED)
        sentence_model.to(self.device)
        self.kw_model = KeyBERT(model=sentence_model)

    def _extract_by_paragraph(self, ctxt, nkws=None, max_kw_ngs=None):
        """Run KeyBERT on each non-empty paragraph of ``ctxt`` and merge results.

        Returns up to ``nkws`` unique keywords, ordered by descending score.
        (The original implementation returned an unordered set, which made the
        caller's join nondeterministic; an order-preserving dedupe fixes that.)
        """
        kws = []
        for paragraph in map(str.strip, ctxt.split("\n")):
            if paragraph:
                kws.extend(self.kw_model.extract_keywords(
                    paragraph,
                    keyphrase_ngram_range=(1, max_kw_ngs),
                    top_n=nkws,
                    # NOTE(review): diversity only takes effect together with
                    # use_mmr=True in KeyBERT; kept to match the original call.
                    diversity=0.8,
                    stop_words=None,
                ))
        print("KWS=", kws, file=sys.stderr)
        # Highest-scoring first, then deduplicate while preserving that order.
        kws.sort(key=lambda pair: pair[1], reverse=True)
        unique_kws = []
        seen = set()
        for kw, _score in kws:
            if kw not in seen:
                seen.add(kw)
                unique_kws.append(kw)
                if len(unique_kws) >= nkws:
                    break
        return unique_kws

    def extract(self, ctxt, nkws=None, max_kw_ngs=None):
        """Return up to ``nkws`` keywords from ``ctxt`` as a comma-separated string.

        ``nkws`` and ``max_kw_ngs`` fall back to the instance defaults when None.
        """
        nkws = nkws if nkws is not None else self.NKW
        max_kw_ngs = max_kw_ngs if max_kw_ngs is not None else self.MAX_KW_NGS
        # The model truncates input at SEQ_LENGTH tokens, so we extract
        # paragraph by paragraph rather than feeding the whole document.
        kws = self._extract_by_paragraph(ctxt, nkws, max_kw_ngs)
        return ", ".join(kws)