"""
@author: Nianlong Gu, Institute of Neuroinformatics, ETH Zurich
@email: nianlonggu@gmail.com

The source code for the paper: MemSum: Extractive Summarization of Long Documents using Multi-step Episodic Markov Decision Processes

When using this code or some of our pre-trained models for your application, please cite the following paper:
@article{DBLP:journals/corr/abs-2107-08929,
  author    = {Nianlong Gu and
               Elliott Ash and
               Richard H. R. Hahnloser},
  title     = {MemSum: Extractive Summarization of Long Documents using Multi-step
               Episodic Markov Decision Processes},
  journal   = {CoRR},
  volume    = {abs/2107.08929},
  year      = {2021},
  url       = {https://arxiv.org/abs/2107.08929},
  eprinttype = {arXiv},
  eprint    = {2107.08929},
  timestamp = {Thu, 22 Jul 2021 11:14:11 +0200},
  biburl    = {https://dblp.org/rec/journals/corr/abs-2107-08929.bib},
  bibsource = {dblp computer science bibliography, https://dblp.org}
}
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import pickle
import numpy as np

class AddMask( nn.Module ):
    def __init__( self, pad_index ):
        super().__init__()
        self.pad_index = pad_index
    def forward( self, x):
        # here x is a batch of input token sequences (not embeddings), with shape [batch_size, seq_len]
        mask = x == self.pad_index
        return mask


class PositionalEncoding( nn.Module ):
    def __init__(self,  embed_dim, max_seq_len = 512  ):
        super().__init__()
        self.embed_dim = embed_dim
        self.max_seq_len = max_seq_len
        pe = torch.zeros( 1, max_seq_len,  embed_dim )
        for pos in range( max_seq_len ):
            for i in range( 0, embed_dim, 2 ):
                pe[ 0, pos, i ] = math.sin( pos / ( 10000 ** ( i/embed_dim ) )  )
                if i+1 < embed_dim:
                    pe[ 0, pos, i+1 ] = math.cos( pos / ( 10000** ( i/embed_dim ) ) )
        self.register_buffer( "pe", pe )
        ## register_buffer registers a tensor that is saved and loaded via state_dict, but is not trainable, since it is not returned by model.parameters()
    def forward( self, x ):
        return x + self.pe[ :, : x.size(1), :]
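
# A minimal sanity-check sketch (illustrative, not part of the original model): for an even
# embed_dim, the nested loop above matches the vectorized form of the standard sinusoidal
# encoding PE[pos, i] = sin(pos / 10000^(i/d)) and PE[pos, i+1] = cos(pos / 10000^(i/d)) for even i.
def _check_positional_encoding(embed_dim=8, max_seq_len=16):
    pos = torch.arange(max_seq_len, dtype=torch.float32).unsqueeze(1)                   # [max_seq_len, 1]
    div = 10000.0 ** (torch.arange(0, embed_dim, 2, dtype=torch.float32) / embed_dim)   # [embed_dim/2]
    pe = torch.zeros(1, max_seq_len, embed_dim)
    pe[0, :, 0::2] = torch.sin(pos / div)
    pe[0, :, 1::2] = torch.cos(pos / div)
    assert torch.allclose(pe, PositionalEncoding(embed_dim, max_seq_len).pe, atol=1e-6)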



class MultiHeadAttention( nn.Module ):
    def __init__(self, embed_dim, num_heads ):
        super().__init__()
        dim_per_head = int( embed_dim/num_heads )
        
        self.ln_q = nn.Linear( embed_dim, num_heads * dim_per_head )
        self.ln_k = nn.Linear( embed_dim, num_heads * dim_per_head )
        self.ln_v = nn.Linear( embed_dim, num_heads * dim_per_head )

        self.ln_out = nn.Linear( num_heads * dim_per_head, embed_dim )

        self.num_heads = num_heads
        self.dim_per_head = dim_per_head
    
    def forward( self, q,k,v, mask = None):
        q = self.ln_q( q )
        k = self.ln_k( k )
        v = self.ln_v( v )

        q = q.view( q.size(0), q.size(1),  self.num_heads, self.dim_per_head  ).transpose( 1,2 )
        k = k.view( k.size(0), k.size(1),  self.num_heads, self.dim_per_head  ).transpose( 1,2 )
        v = v.view( v.size(0), v.size(1),  self.num_heads, self.dim_per_head  ).transpose( 1,2 )

        a = self.scaled_dot_product_attention( q,k, mask )
        new_v = a.matmul(v)
        new_v = new_v.transpose( 1,2 ).contiguous()
        new_v = new_v.view( new_v.size(0), new_v.size(1), -1 )
        new_v = self.ln_out(new_v)
        return new_v

    def scaled_dot_product_attention( self, q, k, mask = None ):
        ## note that here q and k have already been split into multi-head form:
        ## q's shape is [batch_size, num_heads, seq_len_q, dim_per_head]
        ## k's shape is [batch_size, num_heads, seq_len_k, dim_per_head]
        # scaled dot product
        a = q.matmul( k.transpose( 2,3 ) )/ math.sqrt( q.size(-1) )
        # apply the mask (either a padding mask or a sequence mask)
        if mask is not None:
            a = a.masked_fill( mask.unsqueeze(1).unsqueeze(1) , -1e9 )  
        # apply softmax to obtain the attention weights
        a = F.softmax( a, dim=-1 )
        return a
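
# A minimal shape-check sketch (illustrative only): with a boolean padding mask over the key
# positions, MultiHeadAttention maps queries of shape [batch_size, seq_len, embed_dim] to an
# output of the same shape; masked key positions receive near-zero attention weight.
def _check_multi_head_attention(batch_size=2, seq_len=5, embed_dim=16, num_heads=4):
    mha = MultiHeadAttention(embed_dim, num_heads)
    x = torch.randn(batch_size, seq_len, embed_dim)
    mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)
    mask[:, -1] = True  # pretend the last position of every sequence is padding
    out = mha(x, x, x, mask)
    assert out.shape == (batch_size, seq_len, embed_dim)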

class FeedForward( nn.Module ):
    def __init__( self, embed_dim, hidden_dim ):
        super().__init__()
        self.ln1 = nn.Linear( embed_dim, hidden_dim )
        self.ln2 = nn.Linear( hidden_dim, embed_dim )
    def forward(  self, x):
        net = F.relu(self.ln1(x))
        out = self.ln2(net)
        return out


class TransformerEncoderLayer(nn.Module):
    def __init__(self, embed_dim, num_heads, hidden_dim ):
        super().__init__()
        self.mha = MultiHeadAttention( embed_dim, num_heads  )
        self.norm1 = nn.LayerNorm( embed_dim )
        self.feed_forward = FeedForward( embed_dim, hidden_dim )
        self.norm2 = nn.LayerNorm( embed_dim )
    def forward( self, x, mask, dropout_rate = 0. ):
        short_cut = x
        net = F.dropout(self.mha( x,x,x, mask ), p = dropout_rate)
        net = self.norm1( short_cut + net )
        short_cut = net
        net = F.dropout(self.feed_forward( net ), p = dropout_rate )
        net = self.norm2( short_cut + net )
        return net

class TransformerDecoderLayer( nn.Module ):
    def __init__(self, embed_dim, num_heads, hidden_dim ):
        super().__init__()
        self.masked_mha = MultiHeadAttention(  embed_dim, num_heads )
        self.norm1 = nn.LayerNorm( embed_dim )
        self.mha = MultiHeadAttention( embed_dim, num_heads )
        self.norm2 = nn.LayerNorm( embed_dim )
        self.feed_forward = FeedForward( embed_dim, hidden_dim )
        self.norm3 = nn.LayerNorm( embed_dim )
    def forward(self, encoder_output, x, src_mask, trg_mask , dropout_rate = 0. ):
        short_cut = x
        net = F.dropout(self.masked_mha( x,x,x, trg_mask ), p = dropout_rate)
        net = self.norm1( short_cut + net )
        short_cut = net
        net = F.dropout(self.mha( net, encoder_output, encoder_output, src_mask ), p = dropout_rate)
        net = self.norm2( short_cut + net )
        short_cut = net
        net = F.dropout(self.feed_forward( net ), p = dropout_rate)
        net = self.norm3( short_cut + net )
        return net 

class MultiHeadPoolingLayer( nn.Module ):
    def __init__( self, embed_dim, num_heads  ):
        super().__init__()
        self.num_heads = num_heads
        self.dim_per_head = int( embed_dim/num_heads )
        self.ln_attention_score = nn.Linear( embed_dim, num_heads )
        self.ln_value = nn.Linear( embed_dim,  num_heads * self.dim_per_head )
        self.ln_out = nn.Linear( num_heads * self.dim_per_head , embed_dim )
    def forward(self, input_embedding , mask=None):
        a = self.ln_attention_score( input_embedding )
        v = self.ln_value( input_embedding )
        
        a = a.view( a.size(0), a.size(1), self.num_heads, 1 ).transpose(1,2)
        v = v.view( v.size(0), v.size(1),  self.num_heads, self.dim_per_head  ).transpose(1,2)
        a = a.transpose(2,3)
        if mask is not None:
            a = a.masked_fill( mask.unsqueeze(1).unsqueeze(1) , -1e9 ) 
        a = F.softmax(a , dim = -1 )

        new_v = a.matmul(v)
        new_v = new_v.transpose( 1,2 ).contiguous()
        new_v = new_v.view( new_v.size(0), new_v.size(1) ,-1 ).squeeze(1)
        new_v = self.ln_out( new_v )
        return new_v
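
# A minimal shape-check sketch (illustrative only): multi-head pooling collapses the sequence
# dimension, mapping word-level embeddings of shape [batch_size, seq_len, embed_dim] to a
# single vector per sequence of shape [batch_size, embed_dim].
def _check_multi_head_pooling(batch_size=2, seq_len=7, embed_dim=16, num_heads=4):
    pool = MultiHeadPoolingLayer(embed_dim, num_heads)
    x = torch.randn(batch_size, seq_len, embed_dim)
    mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)
    out = pool(x, mask)
    assert out.shape == (batch_size, embed_dim)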


class LocalSentenceEncoder( nn.Module ):
    def __init__( self, vocab_size, pad_index, embed_dim, num_heads , hidden_dim , num_enc_layers , pretrained_word_embedding ):
        super().__init__()
        self.addmask = AddMask( pad_index )
      
        self.rnn = nn.LSTM(  embed_dim, embed_dim, 2, batch_first = True, bidirectional = True)
        self.mh_pool = MultiHeadPoolingLayer( 2*embed_dim, num_heads )
        self.norm_out = nn.LayerNorm( 2*embed_dim )
        self.ln_out = nn.Linear( 2*embed_dim, embed_dim )

        if pretrained_word_embedding is not None:
            ## make sure the pad embedding is 0
            pretrained_word_embedding[pad_index] = 0
            self.register_buffer( "word_embedding", torch.from_numpy( pretrained_word_embedding ) )
        else:
            self.register_buffer( "word_embedding", torch.randn( vocab_size, embed_dim ) )

    """
    input_seq 's shape:  batch_size x seq_len 
    """
    def forward( self, input_seq, dropout_rate = 0. ):
        mask = self.addmask( input_seq )
        ## batch_size x seq_len x embed_dim
        net = self.word_embedding[ input_seq ]
        net, _ = self.rnn( net )
        net =  self.ln_out(F.relu(self.norm_out(self.mh_pool( net, mask ))))
        return net
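
# A minimal usage sketch (illustrative only, using the randomly initialized embedding table
# created when pretrained_word_embedding is None): the local sentence encoder maps a batch of
# token-id sequences [batch_size, seq_len] to one embedding per sentence [batch_size, embed_dim].
def _check_local_sentence_encoder(vocab_size=50, pad_index=0, embed_dim=16, num_heads=4):
    encoder = LocalSentenceEncoder(vocab_size, pad_index, embed_dim, num_heads,
                                   hidden_dim=64, num_enc_layers=2,
                                   pretrained_word_embedding=None)
    input_seq = torch.randint(1, vocab_size, (2, 10))  # token ids > 0, so no position is treated as padding
    out = encoder(input_seq)
    assert out.shape == (2, embed_dim)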


class GlobalContextEncoder(nn.Module):
    def __init__(self, embed_dim,  num_heads, hidden_dim, num_dec_layers ):
        super().__init__()
        # self.pos_encode = PositionalEncoding( embed_dim)
        # self.layer_list = nn.ModuleList( [  TransformerEncoderLayer( embed_dim, num_heads, hidden_dim ) for _ in range(num_dec_layers) ] )
        self.rnn = nn.LSTM(  embed_dim, embed_dim, 2, batch_first = True, bidirectional = True)
        self.norm_out = nn.LayerNorm( 2*embed_dim )
        self.ln_out = nn.Linear( 2*embed_dim, embed_dim )

    def forward(self, sen_embed, doc_mask, dropout_rate = 0.):
        net, _ = self.rnn( sen_embed )
        net = self.ln_out(F.relu( self.norm_out(net) ) )
        return net


class ExtractionContextDecoder( nn.Module ):
    def __init__( self, embed_dim,  num_heads, hidden_dim, num_dec_layers ):
        super().__init__()
        self.layer_list = nn.ModuleList( [  TransformerDecoderLayer( embed_dim, num_heads, hidden_dim ) for _ in range(num_dec_layers) ] )
    ## remaining_mask: True at the indices of sentences that have not been extracted yet
    ## extraction_mask: True at the indices of sentences that have already been extracted
    def forward( self, sen_embed, remaining_mask, extraction_mask, dropout_rate = 0. ):
        net = sen_embed
        for layer in self.layer_list:
            #  encoder_output, x,  src_mask, trg_mask , dropout_rate = 0.
            net = layer( sen_embed, net, remaining_mask, extraction_mask, dropout_rate )
        return net

class Extractor( nn.Module ):
    def __init__( self, embed_dim, num_heads ):
        super().__init__()
        self.norm_input = nn.LayerNorm( 3*embed_dim  )
        
        self.ln_hidden1 = nn.Linear(  3*embed_dim, 2*embed_dim  )
        self.norm_hidden1 = nn.LayerNorm( 2*embed_dim  )
        
        self.ln_hidden2 = nn.Linear(  2*embed_dim, embed_dim  )
        self.norm_hidden2 = nn.LayerNorm( embed_dim  )

        self.ln_out = nn.Linear(  embed_dim, 1 )

        self.mh_pool = MultiHeadPoolingLayer( embed_dim, num_heads )
        self.norm_pool = nn.LayerNorm( embed_dim  )
        self.ln_stop = nn.Linear(  embed_dim, 1 )

        self.mh_pool_2 = MultiHeadPoolingLayer( embed_dim, num_heads )
        self.norm_pool_2 = nn.LayerNorm( embed_dim  )
        self.ln_baseline = nn.Linear(  embed_dim, 1 )

    def forward( self, sen_embed, relevance_embed, redundancy_embed , extraction_mask, dropout_rate = 0. ):
        if redundancy_embed is None:
            redundancy_embed = torch.zeros_like( sen_embed )
        net = self.norm_input( F.dropout( torch.cat( [ sen_embed, relevance_embed, redundancy_embed ], dim = 2 ) , p = dropout_rate  )  ) 
        net = F.relu( self.norm_hidden1( F.dropout( self.ln_hidden1( net ) , p = dropout_rate  )   ))
        hidden_net = F.relu( self.norm_hidden2( F.dropout( self.ln_hidden2( net)  , p = dropout_rate  )  ))
        
        p = self.ln_out( hidden_net ).sigmoid().squeeze(2)

        net = F.relu( self.norm_pool(  F.dropout( self.mh_pool( hidden_net, extraction_mask) , p = dropout_rate  )  ))
        p_stop = self.ln_stop( net ).sigmoid().squeeze(1)

        net = F.relu( self.norm_pool_2(  F.dropout( self.mh_pool_2( hidden_net, extraction_mask ) , p = dropout_rate  )  ))
        baseline = self.ln_baseline(net)

        return p, p_stop, baseline
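
# A minimal shape-check sketch (illustrative only): given per-sentence embeddings, the matching
# relevance embeddings, and redundancy_embed = None (as at the first extraction step), the
# extractor returns sentence scores p [batch_size, num_sen], a stop probability p_stop [batch_size],
# and a baseline value [batch_size, 1].
def _check_extractor(batch_size=1, num_sen=6, embed_dim=16, num_heads=4):
    extractor = Extractor(embed_dim, num_heads)
    sen_embed = torch.randn(batch_size, num_sen, embed_dim)
    relevance_embed = torch.randn(batch_size, num_sen, embed_dim)
    extraction_mask = torch.zeros(batch_size, num_sen, dtype=torch.bool)
    p, p_stop, baseline = extractor(sen_embed, relevance_embed, None, extraction_mask)
    assert p.shape == (batch_size, num_sen)
    assert p_stop.shape == (batch_size,)
    assert baseline.shape == (batch_size, 1)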


## a naive sentence tokenizer that only lower-cases the input
class SentenceTokenizer:
    def __init__(self ):
        pass
    def tokenize(self, sen ):
        return sen.lower()


class Vocab:
    def __init__(self, words, eos_token = "<eos>", pad_token = "<pad>", unk_token = "<unk>" ):
        self.words = words
        self.index_to_word = {}
        self.word_to_index = {}
        for idx in range( len(words) ):
            self.index_to_word[ idx ] = words[idx]
            self.word_to_index[ words[idx] ] = idx
        self.eos_token = eos_token
        self.pad_token = pad_token
        self.unk_token = unk_token
        self.eos_index = self.word_to_index[self.eos_token]
        self.pad_index = self.word_to_index[self.pad_token]

        self.tokenizer = SentenceTokenizer()   

    def index2word( self, idx ):
        return self.index_to_word.get( idx, self.unk_token)
    def word2index( self, word ):
        return self.word_to_index.get( word, -1 )
    # The sentence is lower-cased by the tokenizer first, unless tokenize = False
    def sent2seq( self, sent, max_len = None , tokenize = True):
        if tokenize:
            sent = self.tokenizer.tokenize(sent)
        seq = []
        for w in sent.split():
            if w in self.word_to_index:
                seq.append( self.word2index(w) )
        if max_len is not None:
            if len(seq) >= max_len:
                seq = seq[:max_len -1]
                seq.append( self.eos_index )
            else:
                seq.append( self.eos_index )
                seq += [ self.pad_index ] * ( max_len - len(seq) )
        return seq
    def seq2sent( self, seq ):
        sent = []
        for i in seq:
            if i == self.eos_index or i == self.pad_index:
                break
            sent.append( self.index2word(i) )
        return " ".join(sent)



class MemSum:
    def __init__( self, model_path, vocabulary_path, gpu = None ):
        
        ## max_seq_len is used to truncate overly long sentences to at most 100 words
        max_seq_len = 100
        ## max_doc_len is used to truncate overly long documents to at most 200 sentences
        max_doc_len = 200
        
        ## The parameters below have been fine-tuned for the pretrained model
        embed_dim=200
        num_heads=8
        hidden_dim = 1024
        N_enc_l = 2
        N_enc_g = 2
        N_dec = 3
        
        with open( vocabulary_path , "rb" ) as f:
            words = pickle.load(f)
        self.vocab = Vocab( words )
        vocab_size = len(words)
        self.local_sentence_encoder = LocalSentenceEncoder( vocab_size, self.vocab.pad_index, embed_dim,num_heads,hidden_dim,N_enc_l, None )
        self.global_context_encoder = GlobalContextEncoder( embed_dim, num_heads, hidden_dim, N_enc_g )
        self.extraction_context_decoder = ExtractionContextDecoder( embed_dim, num_heads, hidden_dim, N_dec )
        self.extractor = Extractor( embed_dim, num_heads )
        ckpt = torch.load( model_path, map_location = "cpu" )
        self.local_sentence_encoder.load_state_dict( ckpt["local_sentence_encoder"] )
        self.global_context_encoder.load_state_dict( ckpt["global_context_encoder"] )
        self.extraction_context_decoder.load_state_dict( ckpt["extraction_context_decoder"] )
        self.extractor.load_state_dict(ckpt["extractor"])
        
        self.device =  torch.device( "cuda:%d"%(gpu) if gpu is not None and torch.cuda.is_available() else "cpu"  )        
        self.local_sentence_encoder.to(self.device)
        self.global_context_encoder.to(self.device)
        self.extraction_context_decoder.to(self.device)
        self.extractor.to(self.device)
        
        self.sentence_tokenizer = SentenceTokenizer()
        self.max_seq_len = max_seq_len
        self.max_doc_len = max_doc_len
    
    def get_ngram(self,  w_list, n = 4 ):
        ngram_set = set()
        for pos in range(len(w_list) - n + 1 ):
            ngram_set.add( "_".join( w_list[ pos:pos+n] )  )
        return ngram_set
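    # Example (illustrative): get_ngram("the cat sat on the mat".split(), n=4) returns
    # {"the_cat_sat_on", "cat_sat_on_the", "sat_on_the_mat"}; during extraction these n-gram
    # sets support optional n-gram blocking of redundant sentences.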

    def extract( self, document_batch, p_stop_thres , ngram_blocking , ngram, return_sentence_position, return_sentence_score_history, max_extracted_sentences_per_document ):
        """document_batch is a batch of documents:
        [  [ sen1, sen2, ... , senL1 ], 
           [ sen1, sen2, ... , senL2], ...
         ]
        """
        ## tokenization:
        document_length_list = []
        sentence_length_list = []
        tokenized_document_batch = []
        for document in document_batch:
            tokenized_document = []
            for sen in document:
                tokenized_sen = self.sentence_tokenizer.tokenize( sen )
                tokenized_document.append( tokenized_sen )
                sentence_length_list.append( len(tokenized_sen.split()) )
            tokenized_document_batch.append( tokenized_document )
            document_length_list.append( len(tokenized_document) )

        max_document_length =  self.max_doc_len 
        max_sentence_length =  self.max_seq_len 
        ## convert to sequence
        seqs = []
        doc_mask = []
        
        for document in tokenized_document_batch:
            if len(document) > max_document_length:
                # doc_mask.append(  [0] * max_document_length )
                document = document[:max_document_length]
            else:
                # doc_mask.append(  [0] * len(document) +[1] * ( max_document_length -  len(document) ) )
                document = document + [""] * ( max_document_length -  len(document) )

            doc_mask.append(  [ 1 if sen.strip() == "" else 0 for sen in  document   ] )

            document_sequences = []
            for sen in document:
                seq = self.vocab.sent2seq( sen, max_sentence_length )
                document_sequences.append(seq)
            seqs.append(document_sequences)
        seqs = np.asarray(seqs)
        doc_mask = np.asarray(doc_mask) == 1
        seqs = torch.from_numpy(seqs).to(self.device)
        doc_mask = torch.from_numpy(doc_mask).to(self.device)

        extracted_sentences = []
        sentence_score_history = []
        p_stop_history = []
        
        with torch.no_grad():
            num_sentences = seqs.size(1)
            sen_embed  = self.local_sentence_encoder( seqs.view(-1, seqs.size(2) )  )
            sen_embed = sen_embed.view( -1, num_sentences, sen_embed.size(1) )
            relevance_embed = self.global_context_encoder( sen_embed, doc_mask  )
    
            num_documents = seqs.size(0)
            doc_mask = doc_mask.detach().cpu().numpy()
            seqs = seqs.detach().cpu().numpy()
    
            extracted_sentences = []
            extracted_sentences_positions = []
        
            for doc_i in range(num_documents):
                current_doc_mask = doc_mask[doc_i:doc_i+1]
                current_remaining_mask_np = np.ones_like(current_doc_mask ).astype(np.bool_) | current_doc_mask
                current_extraction_mask_np = np.zeros_like(current_doc_mask).astype(np.bool_) | current_doc_mask
        
                current_sen_embed = sen_embed[doc_i:doc_i+1]
                current_relevance_embed = relevance_embed[ doc_i:doc_i+1 ]
                current_redundancy_embed = None
        
                current_hyps = []
                extracted_sen_ngrams = set()

                sentence_score_history_for_doc_i = []

                p_stop_history_for_doc_i = []
                
                for step in range( max_extracted_sentences_per_document+1 ) :
                    current_extraction_mask = torch.from_numpy( current_extraction_mask_np ).to(self.device)
                    current_remaining_mask = torch.from_numpy( current_remaining_mask_np ).to(self.device)
                    if step > 0:
                        current_redundancy_embed = self.extraction_context_decoder( current_sen_embed, current_remaining_mask, current_extraction_mask  )
                    p, p_stop, _ = self.extractor( current_sen_embed, current_relevance_embed, current_redundancy_embed , current_extraction_mask  )
                    p_stop = p_stop.unsqueeze(1)
            
            
                    p = p.masked_fill( current_extraction_mask, 1e-12 ) 

                    sentence_score_history_for_doc_i.append( p.detach().cpu().numpy() )

                    p_stop_history_for_doc_i.append(  p_stop.squeeze(1).item() )

                    normalized_p = p / p.sum(dim=1, keepdim = True)

                    stop = p_stop.squeeze(1).item()> p_stop_thres #and step > 0
                    
                    #sen_i = normalized_p.argmax(dim=1)[0]
                    _, sorted_sen_indices =normalized_p.sort(dim=1, descending= True)
                    sorted_sen_indices = sorted_sen_indices[0]
                    
                    extracted = False
                    for sen_i in sorted_sen_indices:
                        sen_i = sen_i.item()
                        if sen_i< len(document_batch[doc_i]):
                            sen = document_batch[doc_i][sen_i]
                        else:
                            break
                        sen_ngrams = self.get_ngram( sen.lower().split(), ngram )
                        if not ngram_blocking or len( extracted_sen_ngrams &  sen_ngrams ) < 1:
                            extracted_sen_ngrams.update( sen_ngrams )
                            extracted = True
                            break
                                        
                    if stop or step == max_extracted_sentences_per_document or not extracted:
                        extracted_sentences.append( [ document_batch[doc_i][sen_i] for sen_i in  current_hyps if sen_i < len(document_batch[doc_i])    ] )
                        extracted_sentences_positions.append( [ sen_i for sen_i in  current_hyps if sen_i < len(document_batch[doc_i])  ]  )
                        break
                    else:
                        current_hyps.append(sen_i)
                        current_extraction_mask_np[0, sen_i] = True
                        current_remaining_mask_np[0, sen_i] = False

                sentence_score_history.append(sentence_score_history_for_doc_i)
                p_stop_history.append( p_stop_history_for_doc_i )
        

        results = [extracted_sentences]
        if return_sentence_position:
            results.append( extracted_sentences_positions )
        if return_sentence_score_history:
            results+=[sentence_score_history , p_stop_history ]
        if len(results) == 1:
            results = results[0]
        
        return results
    
    ## document is a list of sentences
    def summarize( self, document, p_stop_thres = 0.7, max_extracted_sentences_per_document = 10, return_sentence_position = False ):
        sentences, sentence_positions = self.extract( [document], p_stop_thres, ngram_blocking = False, ngram = 0, return_sentence_position = True, return_sentence_score_history = False, max_extracted_sentences_per_document = max_extracted_sentences_per_document )
        try:
            sentences, sentence_positions = list(zip(*sorted( zip( sentences[0], sentence_positions[0] ), key = lambda x:x[1] )))
        except ValueError:
            sentences, sentence_positions = (), ()
        if return_sentence_position:
            return sentences, sentence_positions
        else:
            return sentences
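

# A minimal end-to-end usage sketch (illustrative only): "model.pt" and "vocab.pkl" are
# placeholder paths for a trained MemSum checkpoint and its vocabulary pickle, not files
# shipped with this module.
if __name__ == "__main__":
    memsum = MemSum(model_path="model.pt", vocabulary_path="vocab.pkl", gpu=None)
    document = [
        "MemSum is an extractive summarizer for long documents.",
        "It selects sentences over multiple extraction steps.",
        "At each step it scores the remaining sentences and decides whether to stop.",
    ]
    summary = memsum.summarize(document, p_stop_thres=0.6, max_extracted_sentences_per_document=3)
    print("\n".join(summary))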