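"""Query-based summarization: retrieve the sentences or dialogue turns most
relevant to a query, then summarize the retrieved text with a pluggable
single-document backend model."""
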
from typing import List, Union

from nltk import sent_tokenize, word_tokenize
from nltk.corpus import stopwords
from nltk.stem import PorterStemmer

from model.base_model import SummModel
from model.single_doc import TextRankModel
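
# NOTE: sent_tokenize/word_tokenize need the NLTK "punkt" resource, and
# Preprocessor needs the "stopwords" corpus. If missing, fetch them once with
# nltk.download("punkt") and nltk.download("stopwords").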


class QueryBasedSummModel(SummModel):
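    """Pipeline summarizer: a retriever keeps the fraction `retrieval_ratio`
    of the sentences/turns most relevant to the query, then `model_backend`
    (TextRankModel by default) summarizes the retained text. Subclasses
    provide the retrieval strategy via `_retrieve`."""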

    is_query_based = True

    def __init__(
        self,
        trained_domain: str = None,
        max_input_length: int = None,
        max_output_length: int = None,
        model_backend: SummModel = TextRankModel,
        retrieval_ratio: float = 0.5,
        preprocess: bool = True,
        **kwargs,
    ):
        super(QueryBasedSummModel, self).__init__(
            trained_domain=trained_domain,
            max_input_length=max_input_length,
            max_output_length=max_output_length,
        )
        self.model = model_backend(**kwargs)
        self.retrieval_ratio = retrieval_ratio
        self.preprocess = preprocess

    def _retrieve(
        self, instance: List[str], query: List[str], n_best: int
    ) -> List[str]:
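        """Return the `n_best` items of `instance` most relevant to `query`;
        concrete retriever subclasses (e.g. TF-IDF or BM25 scoring) override this."""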
        raise NotImplementedError()

    def summarize(
        self,
        corpus: Union[List[str], List[List[str]]],
        queries: List[str] = None,
    ) -> List[str]:
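        """For each (document, query) pair, retrieve the most query-relevant
        sentences (or dialogue turns), then summarize the retrieved text with
        the backend model."""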
        self.assert_summ_input_type(corpus, queries)

        retrieval_output = []  # List[str] or List[List[str]]
        preprocessor = Preprocessor() if self.preprocess else None
        for instance, query in zip(corpus, queries):
            if isinstance(instance, str):
                # A plain document: split it into sentences.
                is_dialogue = False
                instance = sent_tokenize(instance)
            else:
                # A list of utterances is treated as a dialogue.
                is_dialogue = True
            query = [query]

            # instance and query are now both List[str].
            if self.preprocess:
                instance = preprocessor.preprocess(instance)
                query = preprocessor.preprocess(query)

            n_best = max(int(len(instance) * self.retrieval_ratio), 1)
            top_n_sent = self._retrieve(instance, query, n_best)

            if not is_dialogue:
                top_n_sent = " ".join(top_n_sent)  # str
            retrieval_output.append(top_n_sent)

        summaries = self.model.summarize(
            retrieval_output
        )  # List[str] or List[List[str]]
        return summaries

    def generate_specific_description(self):
        is_neural = self.model.is_neural and self.is_neural
        is_extractive = self.model.is_extractive or self.is_extractive
        model_name = "Pipeline with retriever: {}, summarizer: {}".format(
            self.model_name, self.model.model_name
        )

        extractive_abstractive = "extractive" if is_extractive else "abstractive"
        neural = "neural" if is_neural else "non-neural"

        basic_description = (
            f"{model_name} is a "
            f"{'query-based' if self.is_query_based else ''} "
            f"{extractive_abstractive}, {neural} model for summarization."
        )

        return basic_description

    @classmethod
    def assert_summ_input_type(cls, corpus, query):
        if query is None:
            raise TypeError(
                "Query-based summarization models summarize query-text pairs, but the query is missing."
            )

        if not isinstance(query, list) or not all(isinstance(q, str) for q in query):
            raise TypeError(
                "Query-based summarization requires queries of type `List[str]`."
            )

    @classmethod
    def generate_basic_description(cls) -> str:
        basic_description = (
            "QueryBasedSummModel performs query-based summarization. Given a query-text pair, "
            "the model first extracts the most relevant sentences from an article (or turns from "
            "a dialogue), then uses a single-document summarization model to generate the summary."
        )
        return basic_description

    @classmethod
    def show_capability(cls):
        basic_description = cls.generate_basic_description()
        more_details = (
            "A query-based summarization model."
            " Allows a custom backend model to be selected at initialization."
            " Retrieves query-relevant sentences or turns, then summarizes them.\n"
            "Strengths: \n - Allows control over the backend model.\n"
            "Weaknesses: \n - Heavily depends on the performance of both the retriever and the summarizer.\n"
        )
        print(f"{basic_description}\n{'#' * 20}\n{more_details}")


class Preprocessor:
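    """Light preprocessing for retrieval: optional lowercasing, stopword
    removal, and Porter stemming. Uses NLTK's English stopword list."""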
    def __init__(self, remove_stopwords=True, lower_case=True, stem=False):
        self.sw = stopwords.words("english")
        self.stemmer = PorterStemmer()
        self.remove_stopwords = remove_stopwords
        self.lower_case = lower_case
        self.stem = stem

    def preprocess(self, corpus: List[str]) -> List[str]:
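        """Apply the configured steps to each sentence and re-join the tokens
        into whitespace-separated strings."""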
        if self.lower_case:
            corpus = [sent.lower() for sent in corpus]
        tokenized_corpus = [word_tokenize(sent) for sent in corpus]
        if self.remove_stopwords:
            tokenized_corpus = [
                [word for word in sent if word not in self.sw]
                for sent in tokenized_corpus
            ]
        if self.stem:
            tokenized_corpus = [
                [self.stemmer.stem(word) for word in sent] for sent in tokenized_corpus
            ]
        return [" ".join(sent) for sent in tokenized_corpus]
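

# Minimal usage sketch (illustrative only; `FirstNSummModel` is a hypothetical
# subclass, and the NLTK "punkt" and "stopwords" resources must be downloaded
# first, e.g. via nltk.download("punkt") and nltk.download("stopwords")):
#
#     class FirstNSummModel(QueryBasedSummModel):
#         """Toy retriever: ignores the query, keeps the first n_best sentences."""
#
#         def _retrieve(self, instance, query, n_best):
#             return instance[:n_best]
#
#     model = FirstNSummModel(retrieval_ratio=0.5)
#     summaries = model.summarize(
#         ["Sentence one. Sentence two. Sentence three. Sentence four."],
#         queries=["a short query"],
#     )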