bauerem committed on
Commit 90fa1fd · 1 Parent(s): 2abeaea

Upload 5 files

app.py ADDED
@@ -0,0 +1,23 @@
+ import gradio as gr
+ import nltk
+ from memsum import MemSum
+ from nltk import sent_tokenize
+ 
+ # sent_tokenize needs the NLTK "punkt" data; fetch it if it is not already present
+ nltk.download("punkt", quiet=True)
+ 
+ model_path = "model/MemSum_Final/model.pt"
+ summarizer = MemSum(model_path, "model/glove/vocabulary_200dim.pkl")
+ 
+ def preprocess(text):
+     # replace newlines and pilcrow marks with spaces, then collapse repeated whitespace
+     text = text.replace("\n", " ")
+     text = text.replace("¶", " ")
+     text = " ".join(text.split())
+     return text
+ 
+ def summarize(text):
+     # split the cleaned text into sentences, extract a summary, and join it line by line
+     sentences = sent_tokenize(preprocess(text))
+     summary = "\n".join(summarizer.summarize(sentences))
+     return summary
+ 
+ demo = gr.Interface(fn=summarize, inputs="text", outputs="text")
+ demo.launch(share=True)
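
For reference, the same pipeline can be driven directly from Python without the Gradio interface. A minimal sketch, assuming the checkpoint and vocabulary uploaded in this commit are available locally, the NLTK punkt data is installed, and the sample text is a placeholder:

    from nltk import sent_tokenize
    from memsum import MemSum

    summarizer = MemSum("model/MemSum_Final/model.pt",
                        "model/glove/vocabulary_200dim.pkl")
    # any long text works here; it is split into sentences before extraction
    document = sent_tokenize("Replace this placeholder with the full text of a long document.")
    print("\n".join(summarizer.summarize(document)))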
memsum.py ADDED
@@ -0,0 +1,523 @@
+ """
+ @author: Nianlong Gu, Institute of Neuroinformatics, ETH Zurich
+ @email: nianlonggu@gmail.com
+ 
+ The source code for the paper: MemSum: Extractive Summarization of Long Documents using Multi-step Episodic Markov Decision Processes
+ 
+ When using this code or some of our pre-trained models for your application, please cite the following paper:
+ @article{DBLP:journals/corr/abs-2107-08929,
+   author     = {Nianlong Gu and
+                 Elliott Ash and
+                 Richard H. R. Hahnloser},
+   title      = {MemSum: Extractive Summarization of Long Documents using Multi-step
+                 Episodic Markov Decision Processes},
+   journal    = {CoRR},
+   volume     = {abs/2107.08929},
+   year       = {2021},
+   url        = {https://arxiv.org/abs/2107.08929},
+   eprinttype = {arXiv},
+   eprint     = {2107.08929},
+   timestamp  = {Thu, 22 Jul 2021 11:14:11 +0200},
+   biburl     = {https://dblp.org/rec/journals/corr/abs-2107-08929.bib},
+   bibsource  = {dblp computer science bibliography, https://dblp.org}
+ }
+ """
+ 
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import math
+ import pickle
+ import numpy as np
+ 
+ class AddMask( nn.Module ):
+     def __init__( self, pad_index ):
+         super().__init__()
+         self.pad_index = pad_index
+     def forward( self, x ):
+         # here x is a batch of input sequences (not embeddings) with the shape [ batch_size, seq_len ]
+         mask = x == self.pad_index
+         return mask
+ 
+ 
+ class PositionalEncoding( nn.Module ):
+     def __init__( self, embed_dim, max_seq_len = 512 ):
+         super().__init__()
+         self.embed_dim = embed_dim
+         self.max_seq_len = max_seq_len
+         pe = torch.zeros( 1, max_seq_len, embed_dim )
+         for pos in range( max_seq_len ):
+             for i in range( 0, embed_dim, 2 ):
+                 pe[ 0, pos, i ] = math.sin( pos / ( 10000 ** ( i/embed_dim ) ) )
+                 if i+1 < embed_dim:
+                     pe[ 0, pos, i+1 ] = math.cos( pos / ( 10000 ** ( i/embed_dim ) ) )
+         self.register_buffer( "pe", pe )
+         ## register_buffer registers variables that are saved and loaded via state_dict, but they are not trainable because they are not returned by model.parameters()
+     def forward( self, x ):
+         return x + self.pe[ :, : x.size(1), : ]
+ 
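+ # Note: the "pe" buffer built above is the standard sinusoidal positional encoding;
+ # since the loop index i advances in steps of 2 (i = 2k), it corresponds to
+ #     PE(pos, 2k)   = sin( pos / 10000**(2k/embed_dim) )
+ #     PE(pos, 2k+1) = cos( pos / 10000**(2k/embed_dim) )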
+ 
+ 
+ class MultiHeadAttention( nn.Module ):
+     def __init__( self, embed_dim, num_heads ):
+         super().__init__()
+         dim_per_head = int( embed_dim/num_heads )
+ 
+         self.ln_q = nn.Linear( embed_dim, num_heads * dim_per_head )
+         self.ln_k = nn.Linear( embed_dim, num_heads * dim_per_head )
+         self.ln_v = nn.Linear( embed_dim, num_heads * dim_per_head )
+ 
+         self.ln_out = nn.Linear( num_heads * dim_per_head, embed_dim )
+ 
+         self.num_heads = num_heads
+         self.dim_per_head = dim_per_head
+ 
+     def forward( self, q, k, v, mask = None ):
+         q = self.ln_q( q )
+         k = self.ln_k( k )
+         v = self.ln_v( v )
+ 
+         q = q.view( q.size(0), q.size(1), self.num_heads, self.dim_per_head ).transpose( 1,2 )
+         k = k.view( k.size(0), k.size(1), self.num_heads, self.dim_per_head ).transpose( 1,2 )
+         v = v.view( v.size(0), v.size(1), self.num_heads, self.dim_per_head ).transpose( 1,2 )
+ 
+         a = self.scaled_dot_product_attention( q, k, mask )
+         new_v = a.matmul(v)
+         new_v = new_v.transpose( 1,2 ).contiguous()
+         new_v = new_v.view( new_v.size(0), new_v.size(1), -1 )
+         new_v = self.ln_out(new_v)
+         return new_v
+ 
+     def scaled_dot_product_attention( self, q, k, mask = None ):
+         ## note that here q and k have already been converted into multi-head form
+         ## q's shape is [ batch_size, num_heads, seq_len_q, dim_per_head ]
+         ## k's shape is [ batch_size, num_heads, seq_len_k, dim_per_head ]
+         # scaled dot product
+         a = q.matmul( k.transpose( 2,3 ) ) / math.sqrt( q.size(-1) )
+         # apply mask (either padding mask or sequence mask)
+         if mask is not None:
+             a = a.masked_fill( mask.unsqueeze(1).unsqueeze(1) , -1e9 )
+         # apply softmax to get the attention weight matrix
+         a = F.softmax( a, dim=-1 )
+         return a
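+     # Note: scaled_dot_product_attention returns only the attention weights,
+     # i.e. softmax( q @ k^T / sqrt(dim_per_head) ) with masked positions filled
+     # with -1e9 before the softmax; forward() then multiplies these weights with v.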
+ 
+ class FeedForward( nn.Module ):
+     def __init__( self, embed_dim, hidden_dim ):
+         super().__init__()
+         self.ln1 = nn.Linear( embed_dim, hidden_dim )
+         self.ln2 = nn.Linear( hidden_dim, embed_dim )
+     def forward( self, x ):
+         net = F.relu(self.ln1(x))
+         out = self.ln2(net)
+         return out
+ 
+ 
+ class TransformerEncoderLayer(nn.Module):
+     def __init__(self, embed_dim, num_heads, hidden_dim ):
+         super().__init__()
+         self.mha = MultiHeadAttention( embed_dim, num_heads )
+         self.norm1 = nn.LayerNorm( embed_dim )
+         self.feed_forward = FeedForward( embed_dim, hidden_dim )
+         self.norm2 = nn.LayerNorm( embed_dim )
+     def forward( self, x, mask, dropout_rate = 0. ):
+         short_cut = x
+         net = F.dropout(self.mha( x,x,x, mask ), p = dropout_rate)
+         net = self.norm1( short_cut + net )
+         short_cut = net
+         net = F.dropout(self.feed_forward( net ), p = dropout_rate )
+         net = self.norm2( short_cut + net )
+         return net
+ 
+ class TransformerDecoderLayer( nn.Module ):
+     def __init__(self, embed_dim, num_heads, hidden_dim ):
+         super().__init__()
+         self.masked_mha = MultiHeadAttention( embed_dim, num_heads )
+         self.norm1 = nn.LayerNorm( embed_dim )
+         self.mha = MultiHeadAttention( embed_dim, num_heads )
+         self.norm2 = nn.LayerNorm( embed_dim )
+         self.feed_forward = FeedForward( embed_dim, hidden_dim )
+         self.norm3 = nn.LayerNorm( embed_dim )
+     def forward(self, encoder_output, x, src_mask, trg_mask , dropout_rate = 0. ):
+         short_cut = x
+         net = F.dropout(self.masked_mha( x,x,x, trg_mask ), p = dropout_rate)
+         net = self.norm1( short_cut + net )
+         short_cut = net
+         net = F.dropout(self.mha( net, encoder_output, encoder_output, src_mask ), p = dropout_rate)
+         net = self.norm2( short_cut + net )
+         short_cut = net
+         net = F.dropout(self.feed_forward( net ), p = dropout_rate)
+         net = self.norm3( short_cut + net )
+         return net
+ 
+ class MultiHeadPoolingLayer( nn.Module ):
+     def __init__( self, embed_dim, num_heads ):
+         super().__init__()
+         self.num_heads = num_heads
+         self.dim_per_head = int( embed_dim/num_heads )
+         self.ln_attention_score = nn.Linear( embed_dim, num_heads )
+         self.ln_value = nn.Linear( embed_dim, num_heads * self.dim_per_head )
+         self.ln_out = nn.Linear( num_heads * self.dim_per_head , embed_dim )
+     def forward(self, input_embedding , mask=None):
+         a = self.ln_attention_score( input_embedding )
+         v = self.ln_value( input_embedding )
+ 
+         a = a.view( a.size(0), a.size(1), self.num_heads, 1 ).transpose(1,2)
+         v = v.view( v.size(0), v.size(1), self.num_heads, self.dim_per_head ).transpose(1,2)
+         a = a.transpose(2,3)
+         if mask is not None:
+             a = a.masked_fill( mask.unsqueeze(1).unsqueeze(1) , -1e9 )
+         a = F.softmax(a , dim = -1 )
+ 
+         new_v = a.matmul(v)
+         new_v = new_v.transpose( 1,2 ).contiguous()
+         new_v = new_v.view( new_v.size(0), new_v.size(1) ,-1 ).squeeze(1)
+         new_v = self.ln_out( new_v )
+         return new_v
+ 
+ 
+ class LocalSentenceEncoder( nn.Module ):
+     def __init__( self, vocab_size, pad_index, embed_dim, num_heads , hidden_dim , num_enc_layers , pretrained_word_embedding ):
+         super().__init__()
+         self.addmask = AddMask( pad_index )
+ 
+         self.rnn = nn.LSTM( embed_dim, embed_dim, 2, batch_first = True, bidirectional = True)
+         self.mh_pool = MultiHeadPoolingLayer( 2*embed_dim, num_heads )
+         self.norm_out = nn.LayerNorm( 2*embed_dim )
+         self.ln_out = nn.Linear( 2*embed_dim, embed_dim )
+ 
+         if pretrained_word_embedding is not None:
+             ## make sure the pad embedding is 0
+             pretrained_word_embedding[pad_index] = 0
+             self.register_buffer( "word_embedding", torch.from_numpy( pretrained_word_embedding ) )
+         else:
+             self.register_buffer( "word_embedding", torch.randn( vocab_size, embed_dim ) )
+ 
+     """
+     input_seq 's shape: batch_size x seq_len
+     """
+     def forward( self, input_seq, dropout_rate = 0. ):
+         mask = self.addmask( input_seq )
+         ## batch_size x seq_len x embed_dim
+         net = self.word_embedding[ input_seq ]
+         net, _ = self.rnn( net )
+         net = self.ln_out(F.relu(self.norm_out(self.mh_pool( net, mask ))))
+         return net
+ 
+ 
+ class GlobalContextEncoder(nn.Module):
+     def __init__(self, embed_dim, num_heads, hidden_dim, num_dec_layers ):
+         super().__init__()
+         # self.pos_encode = PositionalEncoding( embed_dim)
+         # self.layer_list = nn.ModuleList( [ TransformerEncoderLayer( embed_dim, num_heads, hidden_dim ) for _ in range(num_dec_layers) ] )
+         self.rnn = nn.LSTM( embed_dim, embed_dim, 2, batch_first = True, bidirectional = True)
+         self.norm_out = nn.LayerNorm( 2*embed_dim )
+         self.ln_out = nn.Linear( 2*embed_dim, embed_dim )
+ 
+     def forward(self, sen_embed, doc_mask, dropout_rate = 0.):
+         net, _ = self.rnn( sen_embed )
+         net = self.ln_out(F.relu( self.norm_out(net) ) )
+         return net
+ 
+ 
+ class ExtractionContextDecoder( nn.Module ):
+     def __init__( self, embed_dim, num_heads, hidden_dim, num_dec_layers ):
+         super().__init__()
+         self.layer_list = nn.ModuleList( [ TransformerDecoderLayer( embed_dim, num_heads, hidden_dim ) for _ in range(num_dec_layers) ] )
+     ## remaining_mask: set all unextracted sentence indices to True
+     ## extraction_mask: set all extracted sentence indices to True
+     def forward( self, sen_embed, remaining_mask, extraction_mask, dropout_rate = 0. ):
+         net = sen_embed
+         for layer in self.layer_list:
+             # encoder_output, x, src_mask, trg_mask , dropout_rate = 0.
+             net = layer( sen_embed, net, remaining_mask, extraction_mask, dropout_rate )
+         return net
+ 
+ class Extractor( nn.Module ):
+     def __init__( self, embed_dim, num_heads ):
+         super().__init__()
+         self.norm_input = nn.LayerNorm( 3*embed_dim )
+ 
+         self.ln_hidden1 = nn.Linear( 3*embed_dim, 2*embed_dim )
+         self.norm_hidden1 = nn.LayerNorm( 2*embed_dim )
+ 
+         self.ln_hidden2 = nn.Linear( 2*embed_dim, embed_dim )
+         self.norm_hidden2 = nn.LayerNorm( embed_dim )
+ 
+         self.ln_out = nn.Linear( embed_dim, 1 )
+ 
+         self.mh_pool = MultiHeadPoolingLayer( embed_dim, num_heads )
+         self.norm_pool = nn.LayerNorm( embed_dim )
+         self.ln_stop = nn.Linear( embed_dim, 1 )
+ 
+         self.mh_pool_2 = MultiHeadPoolingLayer( embed_dim, num_heads )
+         self.norm_pool_2 = nn.LayerNorm( embed_dim )
+         self.ln_baseline = nn.Linear( embed_dim, 1 )
+ 
+     def forward( self, sen_embed, relevance_embed, redundancy_embed , extraction_mask, dropout_rate = 0. ):
+         if redundancy_embed is None:
+             redundancy_embed = torch.zeros_like( sen_embed )
+         net = self.norm_input( F.dropout( torch.cat( [ sen_embed, relevance_embed, redundancy_embed ], dim = 2 ) , p = dropout_rate ) )
+         net = F.relu( self.norm_hidden1( F.dropout( self.ln_hidden1( net ) , p = dropout_rate ) ))
+         hidden_net = F.relu( self.norm_hidden2( F.dropout( self.ln_hidden2( net ) , p = dropout_rate ) ))
+ 
+         p = self.ln_out( hidden_net ).sigmoid().squeeze(2)
+ 
+         net = F.relu( self.norm_pool( F.dropout( self.mh_pool( hidden_net, extraction_mask ) , p = dropout_rate ) ))
+         p_stop = self.ln_stop( net ).sigmoid().squeeze(1)
+ 
+         net = F.relu( self.norm_pool_2( F.dropout( self.mh_pool_2( hidden_net, extraction_mask ) , p = dropout_rate ) ))
+         baseline = self.ln_baseline(net)
+ 
+         return p, p_stop, baseline
+ 
+ 
+ ## naive tokenizer that just lower-cases the sentence
+ class SentenceTokenizer:
+     def __init__(self ):
+         pass
+     def tokenize(self, sen ):
+         return sen.lower()
+ 
+ 
+ class Vocab:
+     def __init__(self, words, eos_token = "<eos>", pad_token = "<pad>", unk_token = "<unk>" ):
+         self.words = words
+         self.index_to_word = {}
+         self.word_to_index = {}
+         for idx in range( len(words) ):
+             self.index_to_word[ idx ] = words[idx]
+             self.word_to_index[ words[idx] ] = idx
+         self.eos_token = eos_token
+         self.pad_token = pad_token
+         self.unk_token = unk_token
+         self.eos_index = self.word_to_index[self.eos_token]
+         self.pad_index = self.word_to_index[self.pad_token]
+ 
+         self.tokenizer = SentenceTokenizer()
+ 
+     def index2word( self, idx ):
+         return self.index_to_word.get( idx, self.unk_token)
+     def word2index( self, word ):
+         return self.word_to_index.get( word, -1 )
+     # The sentence needs to be tokenized
+     def sent2seq( self, sent, max_len = None , tokenize = True):
+         if tokenize:
+             sent = self.tokenizer.tokenize(sent)
+         seq = []
+         for w in sent.split():
+             if w in self.word_to_index:
+                 seq.append( self.word2index(w) )
+         if max_len is not None:
+             if len(seq) >= max_len:
+                 seq = seq[:max_len -1]
+                 seq.append( self.eos_index )
+             else:
+                 seq.append( self.eos_index )
+                 seq += [ self.pad_index ] * ( max_len - len(seq) )
+         return seq
+     def seq2sent( self, seq ):
+         sent = []
+         for i in seq:
+             if i == self.eos_index or i == self.pad_index:
+                 break
+             sent.append( self.index2word(i) )
+         return " ".join(sent)
+ 
+ 
+ 
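+ # Example use of the Vocab class above with a toy vocabulary (not the shipped GloVe one):
+ #     v = Vocab( ["<eos>", "<pad>", "<unk>", "the", "cat", "sat"] )
+ #     v.sent2seq( "The cat sat down", max_len = 6 )  ->  [3, 4, 5, 0, 1, 1]
+ #     ("down" is out of vocabulary and skipped; <eos> and <pad> fill up to max_len)
+ #     v.seq2sent( [3, 4, 5, 0, 1, 1] )  ->  "the cat sat"
+ 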
+ class MemSum:
+     def __init__( self, model_path, vocabulary_path, gpu = None ):
+ 
+         ## max_seq_len is used to truncate overly long sentences to at most 100 words
+         max_seq_len = 100
+         ## max_doc_len is used to truncate overly long documents to at most 200 sentences
+         max_doc_len = 200
+ 
+         ## the parameters below have been fine-tuned for the pretrained model
+         embed_dim = 200
+         num_heads = 8
+         hidden_dim = 1024
+         N_enc_l = 2
+         N_enc_g = 2
+         N_dec = 3
+ 
+         with open( vocabulary_path , "rb" ) as f:
+             words = pickle.load(f)
+         self.vocab = Vocab( words )
+         vocab_size = len(words)
+         self.local_sentence_encoder = LocalSentenceEncoder( vocab_size, self.vocab.pad_index, embed_dim, num_heads, hidden_dim, N_enc_l, None )
+         self.global_context_encoder = GlobalContextEncoder( embed_dim, num_heads, hidden_dim, N_enc_g )
+         self.extraction_context_decoder = ExtractionContextDecoder( embed_dim, num_heads, hidden_dim, N_dec )
+         self.extractor = Extractor( embed_dim, num_heads )
+         ckpt = torch.load( model_path, map_location = "cpu" )
+         self.local_sentence_encoder.load_state_dict( ckpt["local_sentence_encoder"] )
+         self.global_context_encoder.load_state_dict( ckpt["global_context_encoder"] )
+         self.extraction_context_decoder.load_state_dict( ckpt["extraction_context_decoder"] )
+         self.extractor.load_state_dict(ckpt["extractor"])
+ 
+         self.device = torch.device( "cuda:%d"%(gpu) if gpu is not None and torch.cuda.is_available() else "cpu" )
+         self.local_sentence_encoder.to(self.device)
+         self.global_context_encoder.to(self.device)
+         self.extraction_context_decoder.to(self.device)
+         self.extractor.to(self.device)
+ 
+         self.sentence_tokenizer = SentenceTokenizer()
+         self.max_seq_len = max_seq_len
+         self.max_doc_len = max_doc_len
+ 
+     def get_ngram(self, w_list, n = 4 ):
+         ngram_set = set()
+         for pos in range(len(w_list) - n + 1 ):
+             ngram_set.add( "_".join( w_list[ pos:pos+n] ) )
+         return ngram_set
+ 
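+     # For example, get_ngram( ["the", "cat", "sat", "on", "the", "mat"], n = 4 ) returns
+     # {"the_cat_sat_on", "cat_sat_on_the", "sat_on_the_mat"}; these n-gram sets drive the
+     # optional ngram_blocking used in extract() below.
+ 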
+     def extract( self, document_batch, p_stop_thres , ngram_blocking , ngram, return_sentence_position, return_sentence_score_history, max_extracted_sentences_per_document ):
+         """document_batch is a batch of documents:
+         [ [ sen1, sen2, ... , senL1 ],
+           [ sen1, sen2, ... , senL2 ], ...
+         ]
+         """
+         ## tokenization:
+         document_length_list = []
+         sentence_length_list = []
+         tokenized_document_batch = []
+         for document in document_batch:
+             tokenized_document = []
+             for sen in document:
+                 tokenized_sen = self.sentence_tokenizer.tokenize( sen )
+                 tokenized_document.append( tokenized_sen )
+                 sentence_length_list.append( len(tokenized_sen.split()) )
+             tokenized_document_batch.append( tokenized_document )
+             document_length_list.append( len(tokenized_document) )
+ 
+         max_document_length = self.max_doc_len
+         max_sentence_length = self.max_seq_len
+         ## convert to sequences
+         seqs = []
+         doc_mask = []
+ 
+         for document in tokenized_document_batch:
+             if len(document) > max_document_length:
+                 # doc_mask.append( [0] * max_document_length )
+                 document = document[:max_document_length]
+             else:
+                 # doc_mask.append( [0] * len(document) +[1] * ( max_document_length - len(document) ) )
+                 document = document + [""] * ( max_document_length - len(document) )
+ 
+             doc_mask.append( [ 1 if sen.strip() == "" else 0 for sen in document ] )
+ 
+             document_sequences = []
+             for sen in document:
+                 seq = self.vocab.sent2seq( sen, max_sentence_length )
+                 document_sequences.append(seq)
+             seqs.append(document_sequences)
+         seqs = np.asarray(seqs)
+         doc_mask = np.asarray(doc_mask) == 1
+         seqs = torch.from_numpy(seqs).to(self.device)
+         doc_mask = torch.from_numpy(doc_mask).to(self.device)
+ 
+         extracted_sentences = []
+         sentence_score_history = []
+         p_stop_history = []
+ 
+         with torch.no_grad():
+             num_sentences = seqs.size(1)
+             sen_embed = self.local_sentence_encoder( seqs.view(-1, seqs.size(2) ) )
+             sen_embed = sen_embed.view( -1, num_sentences, sen_embed.size(1) )
+             relevance_embed = self.global_context_encoder( sen_embed, doc_mask )
+ 
+             num_documents = seqs.size(0)
+             doc_mask = doc_mask.detach().cpu().numpy()
+             seqs = seqs.detach().cpu().numpy()
+ 
+             extracted_sentences = []
+             extracted_sentences_positions = []
+ 
+             for doc_i in range(num_documents):
+                 current_doc_mask = doc_mask[doc_i:doc_i+1]
+                 current_remaining_mask_np = np.ones_like(current_doc_mask ).astype(np.bool_) | current_doc_mask
+                 current_extraction_mask_np = np.zeros_like(current_doc_mask).astype(np.bool_) | current_doc_mask
+ 
+                 current_sen_embed = sen_embed[doc_i:doc_i+1]
+                 current_relevance_embed = relevance_embed[ doc_i:doc_i+1 ]
+                 current_redundancy_embed = None
+ 
+                 current_hyps = []
+                 extracted_sen_ngrams = set()
+ 
+                 sentence_score_history_for_doc_i = []
+ 
+                 p_stop_history_for_doc_i = []
+ 
+                 ## extract one sentence per step until the stop signal fires or the budget is reached
+                 for step in range( max_extracted_sentences_per_document+1 ):
+                     current_extraction_mask = torch.from_numpy( current_extraction_mask_np ).to(self.device)
+                     current_remaining_mask = torch.from_numpy( current_remaining_mask_np ).to(self.device)
+                     if step > 0:
+                         current_redundancy_embed = self.extraction_context_decoder( current_sen_embed, current_remaining_mask, current_extraction_mask )
+                     p, p_stop, _ = self.extractor( current_sen_embed, current_relevance_embed, current_redundancy_embed , current_extraction_mask )
+                     p_stop = p_stop.unsqueeze(1)
+ 
+                     p = p.masked_fill( current_extraction_mask, 1e-12 )
+ 
+                     sentence_score_history_for_doc_i.append( p.detach().cpu().numpy() )
+ 
+                     p_stop_history_for_doc_i.append( p_stop.squeeze(1).item() )
+ 
+                     normalized_p = p / p.sum(dim=1, keepdim = True)
+ 
+                     stop = p_stop.squeeze(1).item() > p_stop_thres  # and step > 0
+ 
+                     # sen_i = normalized_p.argmax(dim=1)[0]
+                     _, sorted_sen_indices = normalized_p.sort(dim=1, descending = True)
+                     sorted_sen_indices = sorted_sen_indices[0]
+ 
+                     ## pick the highest-scoring sentence that passes the (optional) n-gram blocking
+                     extracted = False
+                     for sen_i in sorted_sen_indices:
+                         sen_i = sen_i.item()
+                         if sen_i < len(document_batch[doc_i]):
+                             sen = document_batch[doc_i][sen_i]
+                         else:
+                             break
+                         sen_ngrams = self.get_ngram( sen.lower().split(), ngram )
+                         if not ngram_blocking or len( extracted_sen_ngrams & sen_ngrams ) < 1:
+                             extracted_sen_ngrams.update( sen_ngrams )
+                             extracted = True
+                             break
+ 
+                     if stop or step == max_extracted_sentences_per_document or not extracted:
+                         extracted_sentences.append( [ document_batch[doc_i][sen_i] for sen_i in current_hyps if sen_i < len(document_batch[doc_i]) ] )
+                         extracted_sentences_positions.append( [ sen_i for sen_i in current_hyps if sen_i < len(document_batch[doc_i]) ] )
+                         break
+                     else:
+                         current_hyps.append(sen_i)
+                         current_extraction_mask_np[0, sen_i] = True
+                         current_remaining_mask_np[0, sen_i] = False
+ 
+                 sentence_score_history.append(sentence_score_history_for_doc_i)
+                 p_stop_history.append( p_stop_history_for_doc_i )
+ 
+         results = [extracted_sentences]
+         if return_sentence_position:
+             results.append( extracted_sentences_positions )
+         if return_sentence_score_history:
+             results += [ sentence_score_history , p_stop_history ]
+         if len(results) == 1:
+             results = results[0]
+ 
+         return results
+ 
+     ## document is a list of sentences
+     def summarize( self, document, p_stop_thres = 0.7, max_extracted_sentences_per_document = 10, return_sentence_position = False ):
+         sentences, sentence_positions = self.extract( [document], p_stop_thres, ngram_blocking = False, ngram = 0, return_sentence_position = True, return_sentence_score_history = False, max_extracted_sentences_per_document = max_extracted_sentences_per_document )
+         try:
+             sentences, sentence_positions = list(zip(*sorted( zip( sentences[0], sentence_positions[0] ), key = lambda x:x[1] )))
+         except ValueError:
+             sentences, sentence_positions = (), ()
+         if return_sentence_position:
+             return sentences, sentence_positions
+         else:
+             return sentences
+ 
+ 
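
As with app.py above, MemSum can also be used directly; a minimal sketch (the paths are the files added in this commit, the three-sentence document is a placeholder) that additionally recovers the positions of the extracted sentences:

    from memsum import MemSum

    summarizer = MemSum("model/MemSum_Final/model.pt",
                        "model/glove/vocabulary_200dim.pkl")
    document = ["First sentence of a long document.",
                "Second sentence.",
                "A further sentence that might be extracted."]
    sentences, positions = summarizer.summarize(document, p_stop_thres=0.7,
                                                return_sentence_position=True)
    # sentences come back in document order; positions are their indices in `document`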
model/MemSum_Final/model.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d0f6e31b7bcdede4ee9e08f8ced14b8460726d812b47dab4bf236d9f912358e7
+ size 396802782
model/glove/vocabulary_200dim.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:dd326474ff86a654fb23b3c4a809b6998271e3f9b4762c6b1f739e893405c7a1
+ size 4157881
requirements.txt ADDED
@@ -0,0 +1,87 @@
+ aiofiles==23.1.0
+ aiohttp==3.8.4
+ aiosignal==1.3.1
+ altair==5.0.0
+ anyio==3.6.2
+ async-timeout==4.0.2
+ attrs==23.1.0
+ certifi==2023.5.7
+ charset-normalizer==3.1.0
+ click==8.1.3
+ cmake==3.26.3
+ contourpy==1.0.7
+ cycler==0.11.0
+ dropbox==11.36.0
+ fastapi==0.95.1
+ ffmpy==0.3.0
+ filelock==3.12.0
+ fonttools==4.39.4
+ frozenlist==1.3.3
+ fsspec==2023.5.0
+ gradio==3.30.0
+ gradio_client==0.2.4
+ h11==0.14.0
+ httpcore==0.17.0
+ httpx==0.24.0
+ huggingface-hub==0.14.1
+ idna==3.4
+ Jinja2==3.1.2
+ joblib==1.2.0
+ jsonschema==4.17.3
+ kiwisolver==1.4.4
+ linkify-it-py==2.0.2
+ lit==16.0.3
+ markdown-it-py==2.2.0
+ MarkupSafe==2.1.2
+ matplotlib==3.7.1
+ mdit-py-plugins==0.3.3
+ mdurl==0.1.2
+ mpmath==1.3.0
+ multidict==6.0.4
+ networkx==3.1
+ nltk==3.8.1
+ numpy==1.24.3
+ nvidia-cublas-cu11==11.10.3.66
+ nvidia-cuda-cupti-cu11==11.7.101
+ nvidia-cuda-nvrtc-cu11==11.7.99
+ nvidia-cuda-runtime-cu11==11.7.99
+ nvidia-cudnn-cu11==8.5.0.96
+ nvidia-cufft-cu11==10.9.0.58
+ nvidia-curand-cu11==10.2.10.91
+ nvidia-cusolver-cu11==11.4.0.1
+ nvidia-cusparse-cu11==11.7.4.91
+ nvidia-nccl-cu11==2.14.3
+ nvidia-nvtx-cu11==11.7.91
+ orjson==3.8.12
+ packaging==23.1
+ pandas==2.0.1
+ Pillow==9.5.0
+ ply==3.11
+ pydantic==1.10.7
+ pydub==0.25.1
+ Pygments==2.15.1
+ pyparsing==3.0.9
+ pyrsistent==0.19.3
+ python-dateutil==2.8.2
+ python-multipart==0.0.6
+ pytz==2023.3
+ PyYAML==6.0
+ regex==2023.5.5
+ requests==2.30.0
+ semantic-version==2.10.0
+ six==1.16.0
+ sniffio==1.3.0
+ starlette==0.26.1
+ stone==3.3.1
+ sympy==1.11.1
+ toolz==0.12.0
+ torch==2.0.1
+ tqdm==4.65.0
+ triton==2.0.0
+ typing_extensions==4.5.0
+ tzdata==2023.3
+ uc-micro-py==1.0.2
+ urllib3==2.0.2
+ uvicorn==0.22.0
+ websockets==11.0.3
+ yarl==1.9.2