File size: 4,593 Bytes
04f3f18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
from transformers import Pipeline
import torch
from .utilities import padToSize
from .summary import select, splitDocument


"""
    Generates the segment ids for BERT
"""
def generateSegmentIds(doc_ids, tokenizer):
    """
        Generates the segment ids for BERT.

        The segment id starts at 0 and flips between 0 and 1 right after
        every "[SEP]" token, so each "[SEP]" still belongs to the segment
        it closes.

        Inputs
        ------
            doc_ids : list[int]
                Token ids of the document (or of a chunk of it).

            tokenizer
                Object exposing a `vocab` mapping that contains "[SEP]".

        Outputs
        -------
            segments_ids : list[int]
                One 0/1 segment id per entry of `doc_ids`.
    """
    # Nothing to label: return early without touching the tokenizer.
    if len(doc_ids) == 0:
        return []

    # The separator id is loop-invariant: resolve it once instead of once
    # per token (`vocab` may be a computed property on some tokenizers).
    sep_id = tokenizer.vocab["[SEP]"]

    segments_ids = []
    curr_segment = 0
    for token in doc_ids:
        segments_ids.append(curr_segment)
        if token == sep_id:
            # Alternating 0s and 1s
            curr_segment = 1 - curr_segment

    return segments_ids


class ExtSummPipeline(Pipeline):
    """
        Extractive summarization pipeline

        Inputs
        ------
            inputs : dict
                'sentences' : list[str]
                    Sentences of the document

            strategy : str
                Strategy to summarize the document:
                - 'length': summary with a maximum length (strategy_args is the maximum length).
                - 'count': summary with the given number of sentences (strategy_args is the number of sentences).
                - 'ratio': summary proportional to the length of the document (strategy_args is the ratio [0, 1]).
                - 'threshold': summary only with sentences with a score higher than a given value (strategy_args is the minimum score).

            strategy_args : any
                Parameters of the strategy.

            `strategy` and `strategy_args` must be given together; when both
            are omitted, `postprocess` uses 'count' with 3 sentences.

        Outputs
        -------
            selected_sents : list[str]
                List of the selected sentences

            selected_idxs : list[int]
                List of the indexes of the selected sentences in the original input
    """

    def _sanitize_parameters(self, **kwargs):
        """
            Routes the call keyword arguments to the pipeline stages.

            Only `postprocess` takes parameters (`strategy`, `strategy_args`);
            the preprocess and forward parameter dicts are always empty.

            Raises
            ------
                ValueError
                    If only one of `strategy` / `strategy_args` is provided.
        """
        has_strategy = "strategy" in kwargs
        has_strategy_args = "strategy_args" in kwargs

        # Exactly one of the two present is an error (both or neither is fine).
        if has_strategy != has_strategy_args:
            raise ValueError("`strategy` and `strategy_args` have to be both set")

        postprocess_kwargs = {}
        if has_strategy:
            postprocess_kwargs["strategy"] = kwargs["strategy"]
            postprocess_kwargs["strategy_args"] = kwargs["strategy_args"]

        return {}, {}, postprocess_kwargs

    def preprocess(self, inputs):
        """
            Tokenizes the document and builds the padded model batch.

            The sentences are joined with "<sep><cls>" between them, wrapped
            in a leading cls token and a trailing sep token, split into
            chunks by `splitDocument` and every chunk is padded to
            `self.model.config.input_size`.

            Outputs
            -------
                dict
                    'inputs' : dict of tensors with keys 'ids',
                        'segments_ids', 'clss_mask' and 'attn_mask'
                        (one row per chunk).
                    'sentences' : the original list of sentences.
        """
        sentences = inputs["sentences"]
        tokenizer = self.tokenizer
        input_size = self.model.config.input_size
        cls_token = tokenizer.cls_token
        sep_token = tokenizer.sep_token

        # Tokenization and chunking
        doc_tokens = tokenizer.tokenize(f"{sep_token}{cls_token}".join(sentences))
        doc_tokens = [cls_token] + doc_tokens + [sep_token]
        doc_chunks = splitDocument(doc_tokens, cls_token, sep_token, input_size)

        # Batch preparation
        cls_token_id = tokenizer.cls_token_id
        pad_token_id = tokenizer.pad_token_id
        batch = {
            "ids": [],
            "segments_ids": [],
            "clss_mask": [],
            "attn_mask": [],
        }
        for chunk_tokens in doc_chunks:
            doc_ids = tokenizer.convert_tokens_to_ids(chunk_tokens)
            segment_ids = generateSegmentIds(doc_ids, tokenizer)
            # True exactly at the positions holding a cls token.
            clss_mask = [token_id == cls_token_id for token_id in doc_ids]
            # Every real token is attended to; padding gets 0 below.
            attn_mask = [1] * len(doc_ids)

            batch["ids"].append(padToSize(doc_ids, input_size, pad_token_id))
            batch["segments_ids"].append(padToSize(segment_ids, input_size, 0))
            batch["clss_mask"].append(padToSize(clss_mask, input_size, False))
            batch["attn_mask"].append(padToSize(attn_mask, input_size, 0))

        tensors = {name: torch.as_tensor(rows) for name, rows in batch.items()}
        return {"inputs": tensors, "sentences": sentences}

    def _forward(self, args):
        """
            Runs the model and gathers the per-sentence predictions.

            The model is called with the whole batch dict and must return a
            pair; only the first element (per-chunk predictions) is used.
            For each chunk, the first k predictions are kept, where k is the
            number of cls positions of that chunk (the remaining entries are
            presumably padding -- TODO confirm against the model).

            Outputs
            -------
                dict
                    'predictions' : 1-D tensor with the kept predictions of
                        all chunks, in chunk order.
                    'sentences' : the original list of sentences.
        """
        batch = args["inputs"]
        sentences = args["sentences"]

        # The empty seed keeps the result a valid (empty) tensor on
        # `self.device` even when the batch contains no chunk.
        pieces = [torch.as_tensor([]).to(self.device)]

        self.model.eval()
        with torch.no_grad():
            batch_preds, _ = self.model(batch)
            for i, clss_mask in enumerate(batch["clss_mask"]):
                n_sentences = int(clss_mask.sum())
                pieces.append(batch_preds[i][:n_sentences])
            # Concatenate once instead of re-copying the accumulated
            # tensor at every iteration.
            out_predictions = torch.cat(pieces)

        return {"predictions": out_predictions, "sentences": sentences}

    def postprocess(self, args, strategy: str="count", strategy_args=3):
        """
            Selects the summary sentences from the predictions.

            Delegates to `select` with the sentences, their predictions and
            the requested strategy; by default the strategy is 'count' with
            3 sentences. The result of `select` is returned unchanged.
        """
        predictions = args["predictions"]
        sentences = args["sentences"]
        return select(sentences, predictions, strategy, strategy_args)