Upload 5 files
- app.py +23 -0
- memsum.py +523 -0
- model/MemSum_Final/model.pt +3 -0
- model/glove/vocabulary_200dim.pkl +3 -0
- requirements.txt +87 -0
app.py
ADDED
@@ -0,0 +1,23 @@
import gradio as gr
import nltk
from nltk import sent_tokenize

from memsum import MemSum

## sent_tokenize needs the "punkt" tokenizer data; download it once at startup,
## otherwise the first request fails with a LookupError.
nltk.download("punkt", quiet=True)

model_path = "model/MemSum_Final/model.pt"
summarizer = MemSum(model_path, "model/glove/vocabulary_200dim.pkl")

def preprocess(text):
    ## replace newlines with spaces (not ""), so words at line breaks are not glued together
    text = text.replace("\n", " ")
    text = text.replace("¶", "")
    text = " ".join(text.split())
    return text

def summarize(text):
    text = sent_tokenize(preprocess(text))
    summary = "\n".join(summarizer.summarize(text))
    return summary

demo = gr.Interface(fn=summarize, inputs="text", outputs="text")
demo.launch(share=True)
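
For reference, a deployed copy of this app can also be queried programmatically with gradio_client, which is pinned in requirements.txt below. A minimal sketch, assuming a hypothetical Space URL (replace it with the real one):

## Minimal sketch: call the Gradio endpoint defined by app.py (gradio_client==0.2.4).
## "https://username-memsum.hf.space" is a placeholder assumption, not the real Space.
from gradio_client import Client

client = Client("https://username-memsum.hf.space")  # hypothetical URL
result = client.predict("Paste a long document here ...", api_name="/predict")
print(result)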
memsum.py
ADDED
@@ -0,0 +1,523 @@
"""
@author: Nianlong Gu, Institute of Neuroinformatics, ETH Zurich
@email: nianlonggu@gmail.com

The source code for the paper: MemSum: Extractive Summarization of Long Documents using Multi-step Episodic Markov Decision Processes

When using this code or some of our pre-trained models for your application, please cite the following paper:
@article{DBLP:journals/corr/abs-2107-08929,
  author     = {Nianlong Gu and
                Elliott Ash and
                Richard H. R. Hahnloser},
  title      = {MemSum: Extractive Summarization of Long Documents using Multi-step
                Episodic Markov Decision Processes},
  journal    = {CoRR},
  volume     = {abs/2107.08929},
  year       = {2021},
  url        = {https://arxiv.org/abs/2107.08929},
  eprinttype = {arXiv},
  eprint     = {2107.08929},
  timestamp  = {Thu, 22 Jul 2021 11:14:11 +0200},
  biburl     = {https://dblp.org/rec/journals/corr/abs-2107-08929.bib},
  bibsource  = {dblp computer science bibliography, https://dblp.org}
}
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import pickle
import numpy as np


class AddMask(nn.Module):
    def __init__(self, pad_index):
        super().__init__()
        self.pad_index = pad_index

    def forward(self, x):
        # here x is a batch of input sequences (token ids, not embeddings) with the shape [batch_size, seq_len]
        mask = x == self.pad_index
        return mask


class PositionalEncoding(nn.Module):
    def __init__(self, embed_dim, max_seq_len=512):
        super().__init__()
        self.embed_dim = embed_dim
        self.max_seq_len = max_seq_len
        pe = torch.zeros(1, max_seq_len, embed_dim)
        for pos in range(max_seq_len):
            for i in range(0, embed_dim, 2):
                pe[0, pos, i] = math.sin(pos / (10000 ** (i / embed_dim)))
                if i + 1 < embed_dim:
                    pe[0, pos, i + 1] = math.cos(pos / (10000 ** (i / embed_dim)))
        self.register_buffer("pe", pe)
        ## register_buffer registers variables that are saved and loaded via state_dict,
        ## but are not trainable, since they are not returned by model.parameters()

    def forward(self, x):
        return x + self.pe[:, : x.size(1), :]


class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        dim_per_head = int(embed_dim / num_heads)

        self.ln_q = nn.Linear(embed_dim, num_heads * dim_per_head)
        self.ln_k = nn.Linear(embed_dim, num_heads * dim_per_head)
        self.ln_v = nn.Linear(embed_dim, num_heads * dim_per_head)

        self.ln_out = nn.Linear(num_heads * dim_per_head, embed_dim)

        self.num_heads = num_heads
        self.dim_per_head = dim_per_head

    def forward(self, q, k, v, mask=None):
        q = self.ln_q(q)
        k = self.ln_k(k)
        v = self.ln_v(v)

        q = q.view(q.size(0), q.size(1), self.num_heads, self.dim_per_head).transpose(1, 2)
        k = k.view(k.size(0), k.size(1), self.num_heads, self.dim_per_head).transpose(1, 2)
        v = v.view(v.size(0), v.size(1), self.num_heads, self.dim_per_head).transpose(1, 2)

        a = self.scaled_dot_product_attention(q, k, mask)
        new_v = a.matmul(v)
        new_v = new_v.transpose(1, 2).contiguous()
        new_v = new_v.view(new_v.size(0), new_v.size(1), -1)
        new_v = self.ln_out(new_v)
        return new_v

    def scaled_dot_product_attention(self, q, k, mask=None):
        ## note that here q and k have been converted into multi-head form:
        ## q's shape is [batch_size, num_heads, seq_len_q, dim_per_head]
        ## k's shape is [batch_size, num_heads, seq_len_k, dim_per_head]
        # scaled dot product
        a = q.matmul(k.transpose(2, 3)) / math.sqrt(q.size(-1))
        # apply mask (either a padding mask or a sequence mask)
        if mask is not None:
            a = a.masked_fill(mask.unsqueeze(1).unsqueeze(1), -1e9)
        # apply softmax to obtain the attention matrix
        a = F.softmax(a, dim=-1)
        return a


class FeedForward(nn.Module):
    def __init__(self, embed_dim, hidden_dim):
        super().__init__()
        self.ln1 = nn.Linear(embed_dim, hidden_dim)
        self.ln2 = nn.Linear(hidden_dim, embed_dim)

    def forward(self, x):
        net = F.relu(self.ln1(x))
        out = self.ln2(net)
        return out


class TransformerEncoderLayer(nn.Module):
    def __init__(self, embed_dim, num_heads, hidden_dim):
        super().__init__()
        self.mha = MultiHeadAttention(embed_dim, num_heads)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.feed_forward = FeedForward(embed_dim, hidden_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x, mask, dropout_rate=0.):
        short_cut = x
        net = F.dropout(self.mha(x, x, x, mask), p=dropout_rate)
        net = self.norm1(short_cut + net)
        short_cut = net
        net = F.dropout(self.feed_forward(net), p=dropout_rate)
        net = self.norm2(short_cut + net)
        return net


class TransformerDecoderLayer(nn.Module):
    def __init__(self, embed_dim, num_heads, hidden_dim):
        super().__init__()
        self.masked_mha = MultiHeadAttention(embed_dim, num_heads)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.mha = MultiHeadAttention(embed_dim, num_heads)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.feed_forward = FeedForward(embed_dim, hidden_dim)
        self.norm3 = nn.LayerNorm(embed_dim)

    def forward(self, encoder_output, x, src_mask, trg_mask, dropout_rate=0.):
        short_cut = x
        net = F.dropout(self.masked_mha(x, x, x, trg_mask), p=dropout_rate)
        net = self.norm1(short_cut + net)
        short_cut = net
        net = F.dropout(self.mha(net, encoder_output, encoder_output, src_mask), p=dropout_rate)
        net = self.norm2(short_cut + net)
        short_cut = net
        net = F.dropout(self.feed_forward(net), p=dropout_rate)
        net = self.norm3(short_cut + net)
        return net


class MultiHeadPoolingLayer(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.dim_per_head = int(embed_dim / num_heads)
        self.ln_attention_score = nn.Linear(embed_dim, num_heads)
        self.ln_value = nn.Linear(embed_dim, num_heads * self.dim_per_head)
        self.ln_out = nn.Linear(num_heads * self.dim_per_head, embed_dim)

    def forward(self, input_embedding, mask=None):
        a = self.ln_attention_score(input_embedding)
        v = self.ln_value(input_embedding)

        a = a.view(a.size(0), a.size(1), self.num_heads, 1).transpose(1, 2)
        v = v.view(v.size(0), v.size(1), self.num_heads, self.dim_per_head).transpose(1, 2)
        a = a.transpose(2, 3)
        if mask is not None:
            a = a.masked_fill(mask.unsqueeze(1).unsqueeze(1), -1e9)
        a = F.softmax(a, dim=-1)

        new_v = a.matmul(v)
        new_v = new_v.transpose(1, 2).contiguous()
        new_v = new_v.view(new_v.size(0), new_v.size(1), -1).squeeze(1)
        new_v = self.ln_out(new_v)
        return new_v


class LocalSentenceEncoder(nn.Module):
    def __init__(self, vocab_size, pad_index, embed_dim, num_heads, hidden_dim, num_enc_layers, pretrained_word_embedding):
        super().__init__()
        self.addmask = AddMask(pad_index)

        self.rnn = nn.LSTM(embed_dim, embed_dim, 2, batch_first=True, bidirectional=True)
        self.mh_pool = MultiHeadPoolingLayer(2 * embed_dim, num_heads)
        self.norm_out = nn.LayerNorm(2 * embed_dim)
        self.ln_out = nn.Linear(2 * embed_dim, embed_dim)

        if pretrained_word_embedding is not None:
            ## make sure the pad embedding is 0
            pretrained_word_embedding[pad_index] = 0
            self.register_buffer("word_embedding", torch.from_numpy(pretrained_word_embedding))
        else:
            self.register_buffer("word_embedding", torch.randn(vocab_size, embed_dim))

    """
    input_seq's shape: batch_size x seq_len
    """
    def forward(self, input_seq, dropout_rate=0.):
        mask = self.addmask(input_seq)
        ## batch_size x seq_len x embed_dim
        net = self.word_embedding[input_seq]
        net, _ = self.rnn(net)
        net = self.ln_out(F.relu(self.norm_out(self.mh_pool(net, mask))))
        return net


class GlobalContextEncoder(nn.Module):
    def __init__(self, embed_dim, num_heads, hidden_dim, num_dec_layers):
        super().__init__()
        # self.pos_encode = PositionalEncoding( embed_dim)
        # self.layer_list = nn.ModuleList( [ TransformerEncoderLayer( embed_dim, num_heads, hidden_dim ) for _ in range(num_dec_layers) ] )
        self.rnn = nn.LSTM(embed_dim, embed_dim, 2, batch_first=True, bidirectional=True)
        self.norm_out = nn.LayerNorm(2 * embed_dim)
        self.ln_out = nn.Linear(2 * embed_dim, embed_dim)

    def forward(self, sen_embed, doc_mask, dropout_rate=0.):
        net, _ = self.rnn(sen_embed)
        net = self.ln_out(F.relu(self.norm_out(net)))
        return net


class ExtractionContextDecoder(nn.Module):
    def __init__(self, embed_dim, num_heads, hidden_dim, num_dec_layers):
        super().__init__()
        self.layer_list = nn.ModuleList([TransformerDecoderLayer(embed_dim, num_heads, hidden_dim) for _ in range(num_dec_layers)])

    ## remaining_mask: set all unextracted sentence indices to True
    ## extraction_mask: set all extracted sentence indices to True
    def forward(self, sen_embed, remaining_mask, extraction_mask, dropout_rate=0.):
        net = sen_embed
        for layer in self.layer_list:
            # arguments: encoder_output, x, src_mask, trg_mask, dropout_rate
            net = layer(sen_embed, net, remaining_mask, extraction_mask, dropout_rate)
        return net


class Extractor(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.norm_input = nn.LayerNorm(3 * embed_dim)

        self.ln_hidden1 = nn.Linear(3 * embed_dim, 2 * embed_dim)
        self.norm_hidden1 = nn.LayerNorm(2 * embed_dim)

        self.ln_hidden2 = nn.Linear(2 * embed_dim, embed_dim)
        self.norm_hidden2 = nn.LayerNorm(embed_dim)

        self.ln_out = nn.Linear(embed_dim, 1)

        self.mh_pool = MultiHeadPoolingLayer(embed_dim, num_heads)
        self.norm_pool = nn.LayerNorm(embed_dim)
        self.ln_stop = nn.Linear(embed_dim, 1)

        self.mh_pool_2 = MultiHeadPoolingLayer(embed_dim, num_heads)
        self.norm_pool_2 = nn.LayerNorm(embed_dim)
        self.ln_baseline = nn.Linear(embed_dim, 1)

    def forward(self, sen_embed, relevance_embed, redundancy_embed, extraction_mask, dropout_rate=0.):
        if redundancy_embed is None:
            redundancy_embed = torch.zeros_like(sen_embed)
        net = self.norm_input(F.dropout(torch.cat([sen_embed, relevance_embed, redundancy_embed], dim=2), p=dropout_rate))
        net = F.relu(self.norm_hidden1(F.dropout(self.ln_hidden1(net), p=dropout_rate)))
        hidden_net = F.relu(self.norm_hidden2(F.dropout(self.ln_hidden2(net), p=dropout_rate)))

        p = self.ln_out(hidden_net).sigmoid().squeeze(2)

        net = F.relu(self.norm_pool(F.dropout(self.mh_pool(hidden_net, extraction_mask), p=dropout_rate)))
        p_stop = self.ln_stop(net).sigmoid().squeeze(1)

        net = F.relu(self.norm_pool_2(F.dropout(self.mh_pool_2(hidden_net, extraction_mask), p=dropout_rate)))
        baseline = self.ln_baseline(net)

        return p, p_stop, baseline


## naive tokenizer that only applies the lower() function
class SentenceTokenizer:
    def __init__(self):
        pass

    def tokenize(self, sen):
        return sen.lower()


class Vocab:
    def __init__(self, words, eos_token="<eos>", pad_token="<pad>", unk_token="<unk>"):
        self.words = words
        self.index_to_word = {}
        self.word_to_index = {}
        for idx in range(len(words)):
            self.index_to_word[idx] = words[idx]
            self.word_to_index[words[idx]] = idx
        self.eos_token = eos_token
        self.pad_token = pad_token
        self.unk_token = unk_token
        self.eos_index = self.word_to_index[self.eos_token]
        self.pad_index = self.word_to_index[self.pad_token]

        self.tokenizer = SentenceTokenizer()

    def index2word(self, idx):
        return self.index_to_word.get(idx, self.unk_token)

    def word2index(self, word):
        return self.word_to_index.get(word, -1)

    # the sentence needs to be tokenized
    def sent2seq(self, sent, max_len=None, tokenize=True):
        if tokenize:
            sent = self.tokenizer.tokenize(sent)
        seq = []
        for w in sent.split():
            if w in self.word_to_index:
                seq.append(self.word2index(w))
        if max_len is not None:
            if len(seq) >= max_len:
                seq = seq[:max_len - 1]
                seq.append(self.eos_index)
            else:
                seq.append(self.eos_index)
                seq += [self.pad_index] * (max_len - len(seq))
        return seq

    def seq2sent(self, seq):
        sent = []
        for i in seq:
            if i == self.eos_index or i == self.pad_index:
                break
            sent.append(self.index2word(i))
        return " ".join(sent)


class MemSum:
    def __init__(self, model_path, vocabulary_path, gpu=None):

        ## max_seq_len is used to truncate an overly long sentence to at most 100 words
        max_seq_len = 100
        ## max_doc_len is used to truncate an overly long document to at most 200 sentences
        max_doc_len = 200

        ## the parameters below have been fine-tuned for the pretrained model
        embed_dim = 200
        num_heads = 8
        hidden_dim = 1024
        N_enc_l = 2
        N_enc_g = 2
        N_dec = 3

        with open(vocabulary_path, "rb") as f:
            words = pickle.load(f)
        self.vocab = Vocab(words)
        vocab_size = len(words)
        self.local_sentence_encoder = LocalSentenceEncoder(vocab_size, self.vocab.pad_index, embed_dim, num_heads, hidden_dim, N_enc_l, None)
        self.global_context_encoder = GlobalContextEncoder(embed_dim, num_heads, hidden_dim, N_enc_g)
        self.extraction_context_decoder = ExtractionContextDecoder(embed_dim, num_heads, hidden_dim, N_dec)
        self.extractor = Extractor(embed_dim, num_heads)
        ckpt = torch.load(model_path, map_location="cpu")
        self.local_sentence_encoder.load_state_dict(ckpt["local_sentence_encoder"])
        self.global_context_encoder.load_state_dict(ckpt["global_context_encoder"])
        self.extraction_context_decoder.load_state_dict(ckpt["extraction_context_decoder"])
        self.extractor.load_state_dict(ckpt["extractor"])

        self.device = torch.device("cuda:%d" % (gpu) if gpu is not None and torch.cuda.is_available() else "cpu")
        self.local_sentence_encoder.to(self.device)
        self.global_context_encoder.to(self.device)
        self.extraction_context_decoder.to(self.device)
        self.extractor.to(self.device)

        self.sentence_tokenizer = SentenceTokenizer()
        self.max_seq_len = max_seq_len
        self.max_doc_len = max_doc_len

    def get_ngram(self, w_list, n=4):
        ngram_set = set()
        for pos in range(len(w_list) - n + 1):
            ngram_set.add("_".join(w_list[pos:pos + n]))
        return ngram_set

    def extract(self, document_batch, p_stop_thres, ngram_blocking, ngram, return_sentence_position, return_sentence_score_history, max_extracted_sentences_per_document):
        """document_batch is a batch of documents:
        [ [ sen1, sen2, ..., senL1 ],
          [ sen1, sen2, ..., senL2 ], ...
        ]
        """
        ## tokenization:
        document_length_list = []
        sentence_length_list = []
        tokenized_document_batch = []
        for document in document_batch:
            tokenized_document = []
            for sen in document:
                tokenized_sen = self.sentence_tokenizer.tokenize(sen)
                tokenized_document.append(tokenized_sen)
                sentence_length_list.append(len(tokenized_sen.split()))
            tokenized_document_batch.append(tokenized_document)
            document_length_list.append(len(tokenized_document))

        max_document_length = self.max_doc_len
        max_sentence_length = self.max_seq_len
        ## convert to sequences
        seqs = []
        doc_mask = []

        for document in tokenized_document_batch:
            if len(document) > max_document_length:
                # doc_mask.append( [0] * max_document_length )
                document = document[:max_document_length]
            else:
                # doc_mask.append( [0] * len(document) + [1] * ( max_document_length - len(document) ) )
                document = document + [""] * (max_document_length - len(document))

            doc_mask.append([1 if sen.strip() == "" else 0 for sen in document])

            document_sequences = []
            for sen in document:
                seq = self.vocab.sent2seq(sen, max_sentence_length)
                document_sequences.append(seq)
            seqs.append(document_sequences)
        seqs = np.asarray(seqs)
        doc_mask = np.asarray(doc_mask) == 1
        seqs = torch.from_numpy(seqs).to(self.device)
        doc_mask = torch.from_numpy(doc_mask).to(self.device)

        extracted_sentences = []
        sentence_score_history = []
        p_stop_history = []

        with torch.no_grad():
            num_sentences = seqs.size(1)
            sen_embed = self.local_sentence_encoder(seqs.view(-1, seqs.size(2)))
            sen_embed = sen_embed.view(-1, num_sentences, sen_embed.size(1))
            relevance_embed = self.global_context_encoder(sen_embed, doc_mask)

            num_documents = seqs.size(0)
            doc_mask = doc_mask.detach().cpu().numpy()
            seqs = seqs.detach().cpu().numpy()

            extracted_sentences = []
            extracted_sentences_positions = []

            for doc_i in range(num_documents):
                current_doc_mask = doc_mask[doc_i:doc_i + 1]
                current_remaining_mask_np = np.ones_like(current_doc_mask).astype(np.bool_) | current_doc_mask
                current_extraction_mask_np = np.zeros_like(current_doc_mask).astype(np.bool_) | current_doc_mask

                current_sen_embed = sen_embed[doc_i:doc_i + 1]
                current_relevance_embed = relevance_embed[doc_i:doc_i + 1]
                current_redundancy_embed = None

                current_hyps = []
                extracted_sen_ngrams = set()

                sentence_score_history_for_doc_i = []
                p_stop_history_for_doc_i = []

                for step in range(max_extracted_sentences_per_document + 1):
                    current_extraction_mask = torch.from_numpy(current_extraction_mask_np).to(self.device)
                    current_remaining_mask = torch.from_numpy(current_remaining_mask_np).to(self.device)
                    if step > 0:
                        current_redundancy_embed = self.extraction_context_decoder(current_sen_embed, current_remaining_mask, current_extraction_mask)
                    p, p_stop, _ = self.extractor(current_sen_embed, current_relevance_embed, current_redundancy_embed, current_extraction_mask)
                    p_stop = p_stop.unsqueeze(1)

                    p = p.masked_fill(current_extraction_mask, 1e-12)

                    sentence_score_history_for_doc_i.append(p.detach().cpu().numpy())
                    p_stop_history_for_doc_i.append(p_stop.squeeze(1).item())

                    normalized_p = p / p.sum(dim=1, keepdim=True)

                    stop = p_stop.squeeze(1).item() > p_stop_thres  # and step > 0

                    # sen_i = normalized_p.argmax(dim=1)[0]
                    _, sorted_sen_indices = normalized_p.sort(dim=1, descending=True)
                    sorted_sen_indices = sorted_sen_indices[0]

                    extracted = False
                    for sen_i in sorted_sen_indices:
                        sen_i = sen_i.item()
                        if sen_i < len(document_batch[doc_i]):
                            sen = document_batch[doc_i][sen_i]
                        else:
                            break
                        sen_ngrams = self.get_ngram(sen.lower().split(), ngram)
                        if not ngram_blocking or len(extracted_sen_ngrams & sen_ngrams) < 1:
                            extracted_sen_ngrams.update(sen_ngrams)
                            extracted = True
                            break

                    if stop or step == max_extracted_sentences_per_document or not extracted:
                        extracted_sentences.append([document_batch[doc_i][sen_i] for sen_i in current_hyps if sen_i < len(document_batch[doc_i])])
                        extracted_sentences_positions.append([sen_i for sen_i in current_hyps if sen_i < len(document_batch[doc_i])])
                        break
                    else:
                        current_hyps.append(sen_i)
                        current_extraction_mask_np[0, sen_i] = True
                        current_remaining_mask_np[0, sen_i] = False

                sentence_score_history.append(sentence_score_history_for_doc_i)
                p_stop_history.append(p_stop_history_for_doc_i)

        results = [extracted_sentences]
        if return_sentence_position:
            results.append(extracted_sentences_positions)
        if return_sentence_score_history:
            results += [sentence_score_history, p_stop_history]
        if len(results) == 1:
            results = results[0]

        return results

    ## document is a list of sentences
    def summarize(self, document, p_stop_thres=0.7, max_extracted_sentences_per_document=10, return_sentence_position=False):
        sentences, sentence_positions = self.extract([document], p_stop_thres, ngram_blocking=False, ngram=0, return_sentence_position=True, return_sentence_score_history=False, max_extracted_sentences_per_document=max_extracted_sentences_per_document)
        try:
            sentences, sentence_positions = list(zip(*sorted(zip(sentences[0], sentence_positions[0]), key=lambda x: x[1])))
        except ValueError:
            sentences, sentence_positions = (), ()
        if return_sentence_position:
            return sentences, sentence_positions
        else:
            return sentences
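
For reference, a minimal sketch of driving the MemSum class above directly, outside the Gradio app; the sample text is a placeholder, and the model paths assume the layout of this repository:

## Minimal sketch: standalone use of the MemSum extractor defined in memsum.py.
import nltk
from nltk import sent_tokenize
from memsum import MemSum

nltk.download("punkt", quiet=True)  # sentence-splitter data required by sent_tokenize

summarizer = MemSum("model/MemSum_Final/model.pt", "model/glove/vocabulary_200dim.pkl")

document = sent_tokenize("Paste a long document here. It is split into sentences first.")
summary, positions = summarizer.summarize(
    document,
    p_stop_thres=0.7,                         # stop once the stop probability exceeds this threshold
    max_extracted_sentences_per_document=10,  # hard cap on the number of extracted sentences
    return_sentence_position=True,            # also return the original sentence indices
)
print("\n".join(summary))
print(positions)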
model/MemSum_Final/model.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d0f6e31b7bcdede4ee9e08f8ced14b8460726d812b47dab4bf236d9f912358e7
size 396802782
model/glove/vocabulary_200dim.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:dd326474ff86a654fb23b3c4a809b6998271e3f9b4762c6b1f739e893405c7a1
size 4157881
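
Both binaries are stored as Git LFS pointers (the three-line stubs above), so a download without git-lfs yields only the stubs. A minimal sketch for resolving the real files with huggingface_hub, which is pinned below; the repo id is a hypothetical placeholder:

## Minimal sketch: fetch the LFS-backed files via huggingface_hub (==0.14.1).
from huggingface_hub import hf_hub_download

model_path = hf_hub_download(
    repo_id="username/memsum-app",  # hypothetical Space id, not the real one
    filename="model/MemSum_Final/model.pt",
    repo_type="space",
)
vocab_path = hf_hub_download(
    repo_id="username/memsum-app",  # hypothetical Space id, not the real one
    filename="model/glove/vocabulary_200dim.pkl",
    repo_type="space",
)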
requirements.txt
ADDED
@@ -0,0 +1,87 @@
aiofiles==23.1.0
aiohttp==3.8.4
aiosignal==1.3.1
altair==5.0.0
anyio==3.6.2
async-timeout==4.0.2
attrs==23.1.0
certifi==2023.5.7
charset-normalizer==3.1.0
click==8.1.3
cmake==3.26.3
contourpy==1.0.7
cycler==0.11.0
dropbox==11.36.0
fastapi==0.95.1
ffmpy==0.3.0
filelock==3.12.0
fonttools==4.39.4
frozenlist==1.3.3
fsspec==2023.5.0
gradio==3.30.0
gradio_client==0.2.4
h11==0.14.0
httpcore==0.17.0
httpx==0.24.0
huggingface-hub==0.14.1
idna==3.4
Jinja2==3.1.2
joblib==1.2.0
jsonschema==4.17.3
kiwisolver==1.4.4
linkify-it-py==2.0.2
lit==16.0.3
markdown-it-py==2.2.0
MarkupSafe==2.1.2
matplotlib==3.7.1
mdit-py-plugins==0.3.3
mdurl==0.1.2
mpmath==1.3.0
multidict==6.0.4
networkx==3.1
nltk==3.8.1
numpy==1.24.3
nvidia-cublas-cu11==11.10.3.66
nvidia-cuda-cupti-cu11==11.7.101
nvidia-cuda-nvrtc-cu11==11.7.99
nvidia-cuda-runtime-cu11==11.7.99
nvidia-cudnn-cu11==8.5.0.96
nvidia-cufft-cu11==10.9.0.58
nvidia-curand-cu11==10.2.10.91
nvidia-cusolver-cu11==11.4.0.1
nvidia-cusparse-cu11==11.7.4.91
nvidia-nccl-cu11==2.14.3
nvidia-nvtx-cu11==11.7.91
orjson==3.8.12
packaging==23.1
pandas==2.0.1
Pillow==9.5.0
ply==3.11
pydantic==1.10.7
pydub==0.25.1
Pygments==2.15.1
pyparsing==3.0.9
pyrsistent==0.19.3
python-dateutil==2.8.2
python-multipart==0.0.6
pytz==2023.3
PyYAML==6.0
regex==2023.5.5
requests==2.30.0
semantic-version==2.10.0
six==1.16.0
sniffio==1.3.0
starlette==0.26.1
stone==3.3.1
sympy==1.11.1
toolz==0.12.0
torch==2.0.1
tqdm==4.65.0
triton==2.0.0
typing_extensions==4.5.0
tzdata==2023.3
uc-micro-py==1.0.2
urllib3==2.0.2
uvicorn==0.22.0
websockets==11.0.3
yarl==1.9.2