SummerTime / model /query_based /bm25_model.py
aliabd
full demo working with old gradio
7e3e85d
from .base_query_based_model import QueryBasedSummModel
from model.base_model import SummModel
from model.single_doc import TextRankModel
from typing import List
from gensim.summarization.bm25 import BM25
from nltk import word_tokenize
class BM25SummModel(QueryBasedSummModel):
    """Query-based summarizer whose retrieval step ranks sentences with BM25.

    ``_retrieve`` scores every sentence of an instance against the tokenized
    query with gensim's ``BM25``. It keeps the ``n_best`` highest-scoring
    sentences, returned in their original document order.

    Everything else is handled by ``QueryBasedSummModel``. Presumably this
    covers query preprocessing, choosing ``n_best`` from ``retrieval_ratio``,
    and running ``model_backend`` on the retrieved sentences — confirm in the
    base class.
    """

    # static variables
    model_name = "BM25"
    is_extractive = True  # only represents the retrieval part
    is_neural = False  # only represents the retrieval part
    is_query_based = True

    def __init__(
        self,
        trained_domain: str = None,
        max_input_length: int = None,
        max_output_length: int = None,
        model_backend: SummModel = TextRankModel,
        retrieval_ratio: float = 0.5,
        preprocess: bool = True,
        **kwargs
    ):
        """Forward all configuration unchanged to ``QueryBasedSummModel``.

        Args:
            trained_domain: forwarded to the base class.
            max_input_length: forwarded to the base class.
            max_output_length: forwarded to the base class.
            model_backend: forwarded to the base class as-is (not
                instantiated here). Defaults to ``TextRankModel``.
            retrieval_ratio: forwarded to the base class. Presumably the
                fraction of sentences to retrieve — confirm in the base class.
            preprocess: forwarded to the base class.
            **kwargs: any extra options, forwarded to the base class.
        """
        super().__init__(
            trained_domain=trained_domain,
            max_input_length=max_input_length,
            max_output_length=max_output_length,
            model_backend=model_backend,
            retrieval_ratio=retrieval_ratio,
            preprocess=preprocess,
            **kwargs
        )

    def _retrieve(
        self, instance: List[str], query: List[str], n_best: int
    ) -> List[str]:
        """Return the ``n_best`` sentences of ``instance`` that best match ``query``.

        Args:
            instance: the document as a list of sentences.
            query: the query as a list of tokens. ``BM25.get_scores`` matches
                tokens, so a raw string would be scored character by
                character.
            n_best: how many sentences to keep.

        Returns:
            The selected sentences, in the order they appear in ``instance``.
            Returns an empty list when ``instance`` is empty.
        """
        # gensim's BM25 divides by the corpus size to get the average
        # sentence length, so an empty corpus raises ZeroDivisionError.
        if not instance:
            return []

        # Materialize the corpus as a list. Pre-3.8 gensim releases call
        # len() on it and iterate it more than once, which a generator
        # cannot support.
        tokenized_sents = [word_tokenize(sent) for sent in instance]
        # NOTE(review): sentences are tokenized case-sensitively. If the base
        # class lowercases the query during preprocessing, capitalized
        # sentence words will never match — confirm.
        bm25 = BM25(tokenized_sents)
        scores = bm25.get_scores(query)

        # Rank indices by descending score. sorted() is stable, so ties keep
        # document order. Then re-sort the winners into document order.
        ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
        best_sent_ind = sorted(ranked[:n_best])
        return [instance[ind] for ind in best_sent_ind]