import sys
import argparse

import torch
from keybert import KeyBERT
from sentence_transformers import SentenceTransformer


class KeyWordExtractor:
    def __init__(self):
        KWE_PRETRAINED = 'medmediani/Arabic-KW-Mdel'
        self.SEQ_LENGTH = 512  # model input limit; motivates per-paragraph extraction
        self.MAX_KW_NGS = 3    # default maximum keyphrase n-gram length
        self.NKW = 3           # default number of keywords to return
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        sentence_model = SentenceTransformer(KWE_PRETRAINED)
        sentence_model.to(self.device)
        self.kw_model = KeyBERT(model=sentence_model)

    def _extract_by_paragraph(self, ctxt, nkws=None, max_kw_ngs=None):
        # Extract keywords per paragraph, then merge and rank globally.
        paragraphs = map(str.strip, ctxt.split("\n"))
        kws = []
        for paragraph in paragraphs:
            if paragraph:
                kws.extend(self.kw_model.extract_keywords(
                    paragraph,
                    keyphrase_ngram_range=(1, max_kw_ngs),
                    top_n=nkws,
                    # Alternative ranking strategies kept from the original:
                    # use_maxsum=True, nr_candidates=20, top_n=5,
                    # use_mmr=True, diversity=0.8,
                    stop_words=None))
        # Debug: dump all candidate (keyword, score) pairs.
        print("KWS=", kws, file=sys.stderr)
        # Sort all pairs by score, then deduplicate, keeping the nkws
        # highest-scoring unique keywords in ranking order.
        kws.sort(key=lambda x: x[1], reverse=True)
        ukws = []
        seen = set()
        for kw, _ in kws:
            if len(ukws) >= nkws:
                break
            if kw not in seen:
                seen.add(kw)
                ukws.append(kw)
        return ukws

    def extract(self, ctxt, nkws=None, max_kw_ngs=None):
        nkws = nkws if nkws is not None else self.NKW
        max_kw_ngs = max_kw_ngs if max_kw_ngs is not None else self.MAX_KW_NGS
        # Since the model only sees SEQ_LENGTH tokens at a time,
        # extract paragraph by paragraph.
        kw = self._extract_by_paragraph(ctxt, nkws, max_kw_ngs)
        return ", ".join(kw)
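
# --- Usage sketch (an assumption, not part of the original module) ---
# A minimal example of how this class might be driven from the command
# line, using the argparse import already present above. The argument
# names and file-reading wiring here are illustrative.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Extract keywords from a text file.")
    parser.add_argument("path", help="path to a UTF-8 text file")
    parser.add_argument("--nkws", type=int, default=None,
                        help="number of keywords to return")
    args = parser.parse_args()

    with open(args.path, encoding="utf-8") as f:
        text = f.read()

    extractor = KeyWordExtractor()
    # Prints the top-ranked unique keywords as a comma-separated string.
    print(extractor.extract(text, nkws=args.nkws))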