from .base_query_based_model import QueryBasedSummModel
from model.base_model import SummModel
from model.single_doc import TextRankModel
from typing import List

# NOTE: gensim.summarization.bm25 is only available in gensim < 4.0
# (the summarization module was removed in gensim 4.0).
from gensim.summarization.bm25 import BM25
from nltk import word_tokenize


class BM25SummModel(QueryBasedSummModel):
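    """Query-based summarization model: BM25 retrieves the sentences most relevant
    to the query, and the retrieved text is then summarized by ``model_backend``
    (``TextRankModel`` by default)."""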

    # static variables
    model_name = "BM25"
    is_extractive = True  # only represents the retrieval part
    is_neural = False  # only represents the retrieval part
    is_query_based = True

    def __init__(
        self,
        trained_domain: str = None,
        max_input_length: int = None,
        max_output_length: int = None,
        model_backend: SummModel = TextRankModel,
        retrieval_ratio: float = 0.5,
        preprocess: bool = True,
        **kwargs
    ):
        super(BM25SummModel, self).__init__(
            trained_domain=trained_domain,
            max_input_length=max_input_length,
            max_output_length=max_output_length,
            model_backend=model_backend,
            retrieval_ratio=retrieval_ratio,
            preprocess=preprocess,
            **kwargs
        )

    def _retrieve(self, instance: List[str], query: List[str], n_best: int) -> List[str]:
        # Build a BM25 index over the tokenized source sentences.
        bm25 = BM25([word_tokenize(s) for s in instance])
        # Tokenize the query strings so BM25 scores sentences against individual
        # query terms rather than against whole query strings.
        query_tokens = [token for q in query for token in word_tokenize(q)]
        scores = bm25.get_scores(query_tokens)
        # Take the n_best highest-scoring sentences, then restore their original order.
        best_sent_ind = sorted(
            range(len(scores)), key=lambda i: scores[i], reverse=True
        )[:n_best]
        top_n_sent = [instance[ind] for ind in sorted(best_sent_ind)]
        return top_n_sent
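
# Illustrative sketch of the retrieval step in isolation (hypothetical inputs,
# not part of the library's public API; end users would normally call the
# summarization entry point inherited from QueryBasedSummModel):
#
#   model = BM25SummModel(retrieval_ratio=0.5)
#   sentences = [
#       "BM25 ranks sentences by term overlap with the query.",
#       "Rare query terms receive higher weight.",
#       "Unrelated sentences score low.",
#   ]
#   top = model._retrieve(sentences, ["BM25 ranking"], n_best=2)
#   # `top` holds the two best-matching sentences in their original order.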