HuyenNguyen nguyenvulebinh committed
Commit 394811b
0 Parent(s):

Duplicate from nguyenvulebinh/spoken-norm-taggen

Co-authored-by: Binh Nguyen <nguyenvulebinh@users.noreply.huggingface.co>

Files changed (10):
  1. .gitattributes +27 -0
  2. README.md +38 -0
  3. app.py +25 -0
  4. attentions.py +466 -0
  5. data_handling.py +336 -0
  6. infer.py +374 -0
  7. model_config_handling.py +90 -0
  8. model_handling.py +763 -0
  9. requirements.txt +8 -0
  10. utils.py +271 -0
.gitattributes ADDED
@@ -0,0 +1,27 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,38 @@
+ ---
+ title: Spoken Norm
+ emoji: 📊
+ colorFrom: gray
+ colorTo: yellow
+ sdk: gradio
+ app_file: app.py
+ pinned: false
+ duplicated_from: nguyenvulebinh/spoken-norm-taggen
+ ---
+
+ # Configuration
+
+ `title`: _string_
+ Display title for the Space
+
+ `emoji`: _string_
+ Space emoji (emoji-only character allowed)
+
+ `colorFrom`: _string_
+ Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
+
+ `colorTo`: _string_
+ Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
+
+ `sdk`: _string_
+ Can be either `gradio` or `streamlit`
+
+ `sdk_version`: _string_
+ Only applicable for `streamlit` SDK.
+ See [doc](https://hf.co/docs/hub/spaces) for more info on supported versions.
+
+ `app_file`: _string_
+ Path to your main application file (which contains either `gradio` or `streamlit` Python code).
+ Path is relative to the root of the repository.
+
+ `pinned`: _boolean_
+ Whether the Space stays on top of your list.
app.py ADDED
@@ -0,0 +1,25 @@
+ import gradio as gr
+ from infer import infer
+
+
+
+ def format_text(text_input, list_bias_input):
+     print('{}\n{}\n\n'.format(text_input, list_bias_input))
+     bias_list = list_bias_input.strip().split('\n')
+     norm_result = infer([text_input], bias_list)
+     return norm_result[0]
+
+
+ title = "Transformation spoken text to written text"
+
+ iface = gr.Interface(format_text,
+                      [
+                          gr.inputs.Textbox(
+                              lines=1,
+                              default="ngày hai tám tháng tư cô vít bùng phát ở xì cút len chiếm tám mươi phần trăm là biến chủng đen ta và bê ta và ô mi cờ ron"),
+                          gr.inputs.Textbox(
+                              lines=5, default='covid\ndelta\nbeta\nomicron | ô mi cờ ron\nscotland | sờ cốt lờn | xì cút len'),
+                      ],
+                      outputs="text",
+                      title=title)
+ iface.launch()
attentions.py ADDED
@@ -0,0 +1,466 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from torch import Tensor
6
+ import numpy as np
7
+ from typing import Optional, Tuple
8
+
9
+
10
+ class ScaledDotProductAttention(nn.Module):
11
+ """
12
+ Scaled Dot-Product Attention proposed in "Attention Is All You Need"
13
+ Compute the dot products of the query with all keys, divide each by sqrt(dim),
14
+ and apply a softmax function to obtain the weights on the values
15
+
16
+ Args: dim, mask
17
+ dim (int): dimension of attention
18
+ mask (torch.Tensor): tensor containing indices to be masked
19
+
20
+ Inputs: query, key, value, mask
21
+ - **query** (batch, q_len, d_model): tensor containing projection vector for decoder.
22
+ - **key** (batch, k_len, d_model): tensor containing projection vector for encoder.
23
+ - **value** (batch, v_len, d_model): tensor containing features of the encoded input sequence.
24
+ - **mask** (-): tensor containing indices to be masked
25
+
26
+ Returns: context, attn
27
+ - **context**: tensor containing the context vector from attention mechanism.
28
+ - **attn**: tensor containing the attention (alignment) from the encoder outputs.
29
+ """
30
+
31
+ def __init__(self, dim: int):
32
+ super(ScaledDotProductAttention, self).__init__()
33
+ self.sqrt_dim = np.sqrt(dim)
34
+
35
+ def forward(self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None) -> Tuple[
36
+ Tensor, Tensor]:
37
+ score = torch.bmm(query, key.transpose(1, 2)) / self.sqrt_dim
38
+
39
+ if mask is not None:
40
+ score.masked_fill_(mask.view(score.size()), -float('Inf'))
41
+
42
+ attn = F.softmax(score, -1)
43
+ context = torch.bmm(attn, value)
44
+ return context, attn
45
+
46
+
47
+ class DotProductAttention(nn.Module):
48
+ """
49
+ Compute the dot products of the query with all values and apply a softmax function to obtain the weights on the values
50
+ """
51
+
52
+ def __init__(self, hidden_dim):
53
+ super(DotProductAttention, self).__init__()
54
+ self.normalize = nn.LayerNorm(hidden_dim)
55
+
56
+ def forward(self, query: Tensor, value: Tensor) -> Tuple[Tensor, Tensor]:
57
+ batch_size, hidden_dim, input_size = query.size(0), query.size(2), value.size(1)
58
+
59
+ score = torch.bmm(query, value.transpose(1, 2))
60
+ attn = F.softmax(score.view(-1, input_size), dim=1).view(batch_size, -1, input_size)
61
+ context = torch.bmm(attn, value)
62
+
63
+ return context, attn
64
+
65
+
66
+ class AdditiveAttention(nn.Module):
67
+ """
68
+ Applies an additive (Bahdanau) attention mechanism on the output features from the decoder.
69
+ Additive attention proposed in "Neural Machine Translation by Jointly Learning to Align and Translate" paper.
70
+
71
+ Args:
72
+ hidden_dim (int): dimension of the hidden state vector
73
+
74
+ Inputs: query, value
75
+ - **query** (batch_size, q_len, hidden_dim): tensor containing the output features from the decoder.
76
+ - **value** (batch_size, v_len, hidden_dim): tensor containing features of the encoded input sequence.
77
+
78
+ Returns: context, attn
79
+ - **context**: tensor containing the context vector from attention mechanism.
80
+ - **attn**: tensor containing the alignment from the encoder outputs.
81
+
82
+ Reference:
83
+ - **Neural Machine Translation by Jointly Learning to Align and Translate**: https://arxiv.org/abs/1409.0473
84
+ """
85
+
86
+ def __init__(self, hidden_dim: int) -> None:
87
+ super(AdditiveAttention, self).__init__()
88
+ self.query_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
89
+ self.key_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
90
+ self.bias = nn.Parameter(torch.rand(hidden_dim).uniform_(-0.1, 0.1))
91
+ self.score_proj = nn.Linear(hidden_dim, 1)
92
+
93
+ def forward(self, query: Tensor, key: Tensor, value: Tensor) -> Tuple[Tensor, Tensor]:
94
+ score = self.score_proj(torch.tanh(self.key_proj(key) + self.query_proj(query) + self.bias)).squeeze(-1)
95
+ attn = F.softmax(score, dim=-1)
96
+ context = torch.bmm(attn.unsqueeze(1), value)
97
+ return context, attn
98
+
99
+
100
+ class LocationAwareAttention(nn.Module):
101
+ """
102
+ Applies a location-aware attention mechanism on the output features from the decoder.
103
+ Location-aware attention proposed in "Attention-Based Models for Speech Recognition" paper.
104
+ The location-aware attention mechanism performs well in speech recognition tasks.
+ We follow the ClovaCall-style implementation.
106
+
107
+ Args:
108
+ hidden_dim (int): dimension of the hidden state vector
109
+ smoothing (bool): flag indication whether to use smoothing or not.
110
+
111
+ Inputs: query, value, last_attn, smoothing
112
+ - **query** (batch, q_len, hidden_dim): tensor containing the output features from the decoder.
113
+ - **value** (batch, v_len, hidden_dim): tensor containing features of the encoded input sequence.
114
+ - **last_attn** (batch_size * num_heads, v_len): tensor containing previous timestep`s attention (alignment)
115
+
116
+ Returns: output, attn
117
+ - **output** (batch, output_len, dimensions): tensor containing the feature from encoder outputs
118
+ - **attn** (batch * num_heads, v_len): tensor containing the attention (alignment) from the encoder outputs.
119
+
120
+ Reference:
121
+ - **Attention-Based Models for Speech Recognition**: https://arxiv.org/abs/1506.07503
122
+ - **ClovaCall**: https://github.com/clovaai/ClovaCall/blob/master/las.pytorch/models/attention.py
123
+ """
124
+
125
+ def __init__(self, hidden_dim: int, smoothing: bool = True) -> None:
126
+ super(LocationAwareAttention, self).__init__()
127
+ self.hidden_dim = hidden_dim
128
+ self.conv1d = nn.Conv1d(in_channels=1, out_channels=hidden_dim, kernel_size=3, padding=1)
129
+ self.query_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
130
+ self.value_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
131
+ self.score_proj = nn.Linear(hidden_dim, 1, bias=True)
132
+ self.bias = nn.Parameter(torch.rand(hidden_dim).uniform_(-0.1, 0.1))
133
+ self.smoothing = smoothing
134
+
135
+ def forward(self, query: Tensor, value: Tensor, last_attn: Tensor) -> Tuple[Tensor, Tensor]:
136
+ batch_size, hidden_dim, seq_len = query.size(0), query.size(2), value.size(1)
137
+
138
+ # Initialize previous attention (alignment) to zeros
139
+ if last_attn is None:
140
+ last_attn = value.new_zeros(batch_size, seq_len)
141
+
142
+ conv_attn = torch.transpose(self.conv1d(last_attn.unsqueeze(1)), 1, 2)
143
+ score = self.score_proj(torch.tanh(
144
+ self.query_proj(query.reshape(-1, hidden_dim)).view(batch_size, -1, hidden_dim)
145
+ + self.value_proj(value.reshape(-1, hidden_dim)).view(batch_size, -1, hidden_dim)
146
+ + conv_attn
147
+ + self.bias
148
+ )).squeeze(dim=-1)
149
+
150
+ if self.smoothing:
151
+ score = torch.sigmoid(score)
152
+ attn = torch.div(score, score.sum(dim=-1).unsqueeze(dim=-1))
153
+ else:
154
+ attn = F.softmax(score, dim=-1)
155
+
156
+ context = torch.bmm(attn.unsqueeze(dim=1), value).squeeze(dim=1) # Bx1xT X BxTxD => Bx1xD => BxD
157
+
158
+ return context, attn
159
+
160
+
161
+ class MultiHeadLocationAwareAttention(nn.Module):
162
+ """
163
+ Applies a multi-headed location-aware attention mechanism on the output features from the decoder.
164
+ Location-aware attention proposed in "Attention-Based Models for Speech Recognition" paper.
165
+ The location-aware attention mechanism performs well in speech recognition tasks.
+ The above paper applied a single head; here we extend it to the multi-head setting.
167
+
168
+ Args:
169
+ hidden_dim (int): The number of expected features in the output
170
+ num_heads (int): The number of heads. (default: 8)
+ conv_out_channel (int): The number of output channels of the convolution
172
+
173
+ Inputs: query, value, prev_attn
174
+ - **query** (batch, q_len, hidden_dim): tensor containing the output features from the decoder.
175
+ - **value** (batch, v_len, hidden_dim): tensor containing features of the encoded input sequence.
176
+ - **prev_attn** (batch_size * num_heads, v_len): tensor containing previous timestep`s attention (alignment)
177
+
178
+ Returns: output, attn
179
+ - **output** (batch, output_len, dimensions): tensor containing the feature from encoder outputs
180
+ - **attn** (batch * num_heads, v_len): tensor containing the attention (alignment) from the encoder outputs.
181
+
182
+ Reference:
183
+ - **Attention Is All You Need**: https://arxiv.org/abs/1706.03762
184
+ - **Attention-Based Models for Speech Recognition**: https://arxiv.org/abs/1506.07503
185
+ """
186
+
187
+ def __init__(self, hidden_dim: int, num_heads: int = 8, conv_out_channel: int = 10) -> None:
188
+ super(MultiHeadLocationAwareAttention, self).__init__()
189
+ self.hidden_dim = hidden_dim
190
+ self.num_heads = num_heads
191
+ self.dim = int(hidden_dim / num_heads)
192
+ self.conv1d = nn.Conv1d(num_heads, conv_out_channel, kernel_size=3, padding=1)
193
+ self.loc_proj = nn.Linear(conv_out_channel, self.dim, bias=False)
194
+ self.query_proj = nn.Linear(hidden_dim, self.dim * num_heads, bias=False)
195
+ self.value_proj = nn.Linear(hidden_dim, self.dim * num_heads, bias=False)
196
+ self.score_proj = nn.Linear(self.dim, 1, bias=True)
197
+ self.bias = nn.Parameter(torch.rand(self.dim).uniform_(-0.1, 0.1))
198
+
199
+ def forward(self, query: Tensor, value: Tensor, last_attn: Tensor) -> Tuple[Tensor, Tensor]:
200
+ batch_size, seq_len = value.size(0), value.size(1)
201
+
202
+ if last_attn is None:
203
+ last_attn = value.new_zeros(batch_size, self.num_heads, seq_len)
204
+
205
+ loc_energy = torch.tanh(self.loc_proj(self.conv1d(last_attn).transpose(1, 2)))
206
+ loc_energy = loc_energy.unsqueeze(1).repeat(1, self.num_heads, 1, 1).view(-1, seq_len, self.dim)
207
+
208
+ query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.dim).permute(0, 2, 1, 3)
209
+ value = self.value_proj(value).view(batch_size, -1, self.num_heads, self.dim).permute(0, 2, 1, 3)
210
+ query = query.contiguous().view(-1, 1, self.dim)
211
+ value = value.contiguous().view(-1, seq_len, self.dim)
212
+
213
+ score = self.score_proj(torch.tanh(value + query + loc_energy + self.bias)).squeeze(2)
214
+ attn = F.softmax(score, dim=1)
215
+
216
+ value = value.view(batch_size, seq_len, self.num_heads, self.dim).permute(0, 2, 1, 3)
217
+ value = value.contiguous().view(-1, seq_len, self.dim)
218
+
219
+ context = torch.bmm(attn.unsqueeze(1), value).view(batch_size, -1, self.num_heads * self.dim)
220
+ attn = attn.view(batch_size, self.num_heads, -1)
221
+
222
+ return context, attn
223
+
224
+
225
+ class MultiHeadAttention(nn.Module):
226
+ """
227
+ Multi-Head Attention proposed in "Attention Is All You Need"
228
+ Instead of performing a single attention function with d_model-dimensional keys, values, and queries,
229
+ project the queries, keys and values h times with different, learned linear projections to d_head dimensions.
230
+ These are concatenated and once again projected, resulting in the final values.
231
+ Multi-head attention allows the model to jointly attend to information from different representation
232
+ subspaces at different positions.
233
+
234
+ MultiHead(Q, K, V) = Concat(head_1, ..., head_h) · W_o
235
+ where head_i = Attention(Q · W_q, K · W_k, V · W_v)
236
+
237
+ Args:
238
+ d_model (int): The dimension of keys / values / queries (default: 512)
239
+ num_heads (int): The number of attention heads. (default: 8)
240
+
241
+ Inputs: query, key, value, mask
242
+ - **query** (batch, q_len, d_model): In transformer, three different ways:
243
+ Case 1: come from the previous decoder layer
244
+ Case 2: come from the input embedding
245
+ Case 3: come from the output embedding (masked)
246
+
247
+ - **key** (batch, k_len, d_model): In transformer, three different ways:
248
+ Case 1: come from the output of the encoder
249
+ Case 2: come from the input embeddings
250
+ Case 3: come from the output embedding (masked)
251
+
252
+ - **value** (batch, v_len, d_model): In transformer, three different ways:
253
+ Case 1: come from the output of the encoder
254
+ Case 2: come from the input embeddings
255
+ Case 3: come from the output embedding (masked)
256
+
257
+ - **mask** (-): tensor containing indices to be masked
258
+
259
+ Returns: output, attn
260
+ - **output** (batch, output_len, dimensions): tensor containing the attended output features.
261
+ - **attn** (batch * num_heads, v_len): tensor containing the attention (alignment) from the encoder outputs.
262
+ """
263
+
264
+ def __init__(self, d_model: int = 512, num_heads: int = 8):
265
+ super(MultiHeadAttention, self).__init__()
266
+
267
+ assert d_model % num_heads == 0, "d_model % num_heads should be zero."
268
+
269
+ self.d_head = int(d_model / num_heads)
270
+ self.num_heads = num_heads
271
+ self.scaled_dot_attn = ScaledDotProductAttention(self.d_head)
272
+ self.query_proj = nn.Linear(d_model, self.d_head * num_heads)
273
+ self.key_proj = nn.Linear(d_model, self.d_head * num_heads)
274
+ self.value_proj = nn.Linear(d_model, self.d_head * num_heads)
275
+
276
+ def forward(
277
+ self,
278
+ query: Tensor,
279
+ key: Tensor,
280
+ value: Tensor,
281
+ mask: Optional[Tensor] = None
282
+ ) -> Tuple[Tensor, Tensor]:
283
+ batch_size = value.size(0)
284
+
285
+ query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head) # BxQ_LENxNxD
286
+ key = self.key_proj(key).view(batch_size, -1, self.num_heads, self.d_head) # BxK_LENxNxD
287
+ value = self.value_proj(value).view(batch_size, -1, self.num_heads, self.d_head) # BxV_LENxNxD
288
+
289
+ query = query.permute(2, 0, 1, 3).contiguous().view(batch_size * self.num_heads, -1, self.d_head) # BNxQ_LENxD
290
+ key = key.permute(2, 0, 1, 3).contiguous().view(batch_size * self.num_heads, -1, self.d_head) # BNxK_LENxD
291
+ value = value.permute(2, 0, 1, 3).contiguous().view(batch_size * self.num_heads, -1, self.d_head) # BNxV_LENxD
292
+
293
+ if mask is not None:
294
+ mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1) # BxNxQ_LENxK_LEN
295
+
296
+ context, attn = self.scaled_dot_attn(query, key, value, mask)
297
+
298
+ context = context.view(self.num_heads, batch_size, -1, self.d_head)
299
+ context = context.permute(1, 2, 0, 3).contiguous().view(batch_size, -1, self.num_heads * self.d_head) # BxTxND
300
+
301
+ return context, attn
302
+
303
+
304
+ class RelativeMultiHeadAttention(nn.Module):
305
+ """
306
+ Multi-head attention with relative positional encoding.
307
+ This concept was proposed in the "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
308
+
309
+ Args:
310
+ d_model (int): The dimension of model
311
+ num_heads (int): The number of attention heads.
312
+ dropout_p (float): probability of dropout
313
+
314
+ Inputs: query, key, value, pos_embedding, mask
315
+ - **query** (batch, time, dim): Tensor containing query vector
316
+ - **key** (batch, time, dim): Tensor containing key vector
317
+ - **value** (batch, time, dim): Tensor containing value vector
318
+ - **pos_embedding** (batch, time, dim): Positional embedding tensor
319
+ - **mask** (batch, 1, time2) or (batch, time1, time2): Tensor containing indices to be masked
320
+
321
+ Returns:
322
+ - **outputs**: Tensor produces by relative multi head attention module.
323
+ """
324
+
325
+ def __init__(
326
+ self,
327
+ d_model: int = 512,
328
+ num_heads: int = 16,
329
+ dropout_p: float = 0.1,
330
+ ):
331
+ super(RelativeMultiHeadAttention, self).__init__()
332
+ assert d_model % num_heads == 0, "d_model % num_heads should be zero."
333
+ self.d_model = d_model
334
+ self.d_head = int(d_model / num_heads)
335
+ self.num_heads = num_heads
336
+ self.sqrt_dim = math.sqrt(d_model)
337
+
338
+ self.query_proj = nn.Linear(d_model, d_model)
339
+ self.key_proj = nn.Linear(d_model, d_model)
340
+ self.value_proj = nn.Linear(d_model, d_model)
341
+ self.pos_proj = nn.Linear(d_model, d_model, bias=False)
342
+
343
+ self.dropout = nn.Dropout(p=dropout_p)
344
+ self.u_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
345
+ self.v_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
346
+ torch.nn.init.xavier_uniform_(self.u_bias)
347
+ torch.nn.init.xavier_uniform_(self.v_bias)
348
+
349
+ self.out_proj = nn.Linear(d_model, d_model)
350
+
351
+ def forward(
352
+ self,
353
+ query: Tensor,
354
+ key: Tensor,
355
+ value: Tensor,
356
+ pos_embedding: Tensor,
357
+ mask: Optional[Tensor] = None,
358
+ ) -> Tensor:
359
+ batch_size = value.size(0)
360
+
361
+ query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head)
362
+ key = self.key_proj(key).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)
363
+ value = self.value_proj(value).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)
364
+ pos_embedding = self.pos_proj(pos_embedding).view(batch_size, -1, self.num_heads, self.d_head)
365
+
366
+ content_score = torch.matmul((query + self.u_bias).transpose(1, 2), key.transpose(2, 3))
367
+ pos_score = torch.matmul((query + self.v_bias).transpose(1, 2), pos_embedding.permute(0, 2, 3, 1))
368
+ pos_score = self._compute_relative_positional_encoding(pos_score)
369
+
370
+ score = (content_score + pos_score) / self.sqrt_dim
371
+
372
+ if mask is not None:
373
+ mask = mask.unsqueeze(1)
374
+ score.masked_fill_(mask, -1e9)
375
+
376
+ attn = F.softmax(score, -1)
377
+ attn = self.dropout(attn)
378
+
379
+ context = torch.matmul(attn, value).transpose(1, 2)
380
+ context = context.contiguous().view(batch_size, -1, self.d_model)
381
+
382
+ return self.out_proj(context)
383
+
384
+ def _compute_relative_positional_encoding(self, pos_score: Tensor) -> Tensor:
385
+ batch_size, num_heads, seq_length1, seq_length2 = pos_score.size()
386
+ zeros = pos_score.new_zeros(batch_size, num_heads, seq_length1, 1)
387
+ padded_pos_score = torch.cat([zeros, pos_score], dim=-1)
388
+
389
+ padded_pos_score = padded_pos_score.view(batch_size, num_heads, seq_length2 + 1, seq_length1)
390
+ pos_score = padded_pos_score[:, :, 1:].view_as(pos_score)
391
+
392
+ return pos_score
393
+
394
+
395
+ class CustomizingAttention(nn.Module):
396
+ r"""
397
+ Customizing Attention
398
+
399
+ Applies a multi-head + location-aware attention mechanism on the output features from the decoder.
400
+ Multi-head attention proposed in "Attention Is All You Need" paper.
401
+ Location-aware attention proposed in "Attention-Based Models for Speech Recognition" paper.
402
+ This module combines the two attention mechanisms.
403
+
404
+ Args:
405
+ hidden_dim (int): The number of expected features in the output
406
+ num_heads (int): The number of heads. (default: 4)
+ conv_out_channel (int): The number of output channels of the convolution
408
+
409
+ Inputs: query, value, last_attn
410
+ - **query** (batch, q_len, hidden_dim): tensor containing the output features from the decoder.
411
+ - **value** (batch, v_len, hidden_dim): tensor containing features of the encoded input sequence.
412
+ - **last_attn** (batch_size * num_heads, v_len): tensor containing previous timestep`s alignment
413
+
414
+ Returns: output, attn
415
+ - **output** (batch, output_len, dimensions): tensor containing the attended output features from the decoder.
416
+ - **attn** (batch * num_heads, v_len): tensor containing the alignment from the encoder outputs.
417
+
418
+ Reference:
419
+ - **Attention Is All You Need**: https://arxiv.org/abs/1706.03762
420
+ - **Attention-Based Models for Speech Recognition**: https://arxiv.org/abs/1506.07503
421
+ """
422
+
423
+ def __init__(self, hidden_dim: int, num_heads: int = 4, conv_out_channel: int = 10) -> None:
424
+ super(CustomizingAttention, self).__init__()
425
+ self.hidden_dim = hidden_dim
426
+ self.num_heads = num_heads
427
+ self.dim = int(hidden_dim / num_heads)
428
+ self.scaled_dot_attn = ScaledDotProductAttention(self.dim)
429
+ self.conv1d = nn.Conv1d(1, conv_out_channel, kernel_size=3, padding=1)
430
+ self.query_proj = nn.Linear(hidden_dim, self.dim * num_heads, bias=True)
431
+ self.value_proj = nn.Linear(hidden_dim, self.dim * num_heads, bias=False)
432
+ self.loc_proj = nn.Linear(conv_out_channel, self.dim, bias=False)
433
+ self.bias = nn.Parameter(torch.rand(self.dim * num_heads).uniform_(-0.1, 0.1))
434
+
435
+ def forward(self, query: Tensor, value: Tensor, last_attn: Tensor) -> Tuple[Tensor, Tensor]:
436
+ batch_size, q_len, v_len = value.size(0), query.size(1), value.size(1)
437
+
438
+ if last_attn is None:
439
+ last_attn = value.new_zeros(batch_size * self.num_heads, v_len)
440
+
441
+ loc_energy = self.get_loc_energy(last_attn, batch_size, v_len) # get location energy
442
+
443
+ query = self.query_proj(query).view(batch_size, q_len, self.num_heads * self.dim)
444
+ value = self.value_proj(value).view(batch_size, v_len, self.num_heads * self.dim) + loc_energy + self.bias
445
+
446
+ query = query.view(batch_size, q_len, self.num_heads, self.dim).permute(2, 0, 1, 3)
447
+ value = value.view(batch_size, v_len, self.num_heads, self.dim).permute(2, 0, 1, 3)
448
+ query = query.contiguous().view(-1, q_len, self.dim)
449
+ value = value.contiguous().view(-1, v_len, self.dim)
450
+
451
+ # value also serves as the key here; ScaledDotProductAttention expects (query, key, value)
+ context, attn = self.scaled_dot_attn(query, value, value)
452
+ attn = attn.squeeze()
453
+
454
+ context = context.view(self.num_heads, batch_size, q_len, self.dim).permute(1, 2, 0, 3)
455
+ context = context.contiguous().view(batch_size, q_len, -1)
456
+
457
+ return context, attn
458
+
459
+ def get_loc_energy(self, last_attn: Tensor, batch_size: int, v_len: int) -> Tensor:
460
+ conv_feat = self.conv1d(last_attn.unsqueeze(1))
461
+ conv_feat = conv_feat.view(batch_size, self.num_heads, -1, v_len).permute(0, 1, 3, 2)
462
+
463
+ loc_energy = self.loc_proj(conv_feat).view(batch_size, self.num_heads, v_len, self.dim)
464
+ loc_energy = loc_energy.permute(0, 2, 1, 3).reshape(batch_size, v_len, self.num_heads * self.dim)
465
+
466
+ return loc_energy
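A quick shape check for the MultiHeadAttention module above (a minimal sketch, assuming attentions.py is on the Python path; tensor sizes are only illustrative):

import torch
from attentions import MultiHeadAttention

mha = MultiHeadAttention(d_model=512, num_heads=8)
query = torch.randn(2, 10, 512)         # (batch, q_len, d_model)
key = torch.randn(2, 20, 512)           # (batch, k_len, d_model)
value = torch.randn(2, 20, 512)         # (batch, v_len, d_model)
context, attn = mha(query, key, value)  # context: (2, 10, 512), attn: (16, 10, 20)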
data_handling.py ADDED
@@ -0,0 +1,336 @@
1
+ import datasets
2
+ import model_handling
3
+ from transformers import PreTrainedTokenizerBase
4
+ from typing import Optional, Union, Any
5
+ from transformers.file_utils import PaddingStrategy
6
+ import re
7
+ import os
8
+ from tqdm import tqdm
9
+ # import time
10
+ import json
11
+ import random
12
+ import regtag
13
+ from dataclasses import dataclass
14
+ import validators
15
+
16
+ import utils
17
+
18
+ regexp = re.compile(r"\d{4}[\-/]\d{2}[\-/]\d{2}t\d{2}:\d{2}:\d{2}")
19
+ target_bias_words = set(regtag.get_general_en_word())
20
+ tokenizer = None
21
+
22
+
23
+ def get_bias_words():
24
+ regtag.augment.get_random_oov()
25
+ return list(regtag.augment.oov_dict.keys())
26
+
27
+
28
+ def check_common_phrase(word):
29
+ if validators.email(word.replace(' @', '@')):
30
+ return True
31
+ if validators.domain(word):
32
+ return True
33
+ if validators.url(word):
34
+ return True
35
+ if word in regtag.get_general_en_word():
36
+ return True
37
+ return False
38
+
39
+
40
+ @dataclass
41
+ class DataCollatorForNormSeq2Seq:
42
+ tokenizer: PreTrainedTokenizerBase
43
+ model: Optional[Any] = None
44
+ padding: Union[bool, str, PaddingStrategy] = True
45
+ max_length: Optional[int] = None
46
+ pad_to_multiple_of: Optional[int] = None
47
+ label_pad_token_id: int = -100
48
+ return_tensors: str = "pt"
49
+
50
+ def bias_phrases_extractor(self, features, max_bias_per_sample=30):
51
+ # src_ids, src_length, tgt_ids, tgt_length
52
+ phrase_candidate = []
53
+ sample_output_words = []
54
+ bias_labels = []
55
+
56
+ for sample in features:
57
+ words = []
58
+ for idx, (src_word_len, tgt_word_len) in enumerate(zip(sample['inputs_length'], sample['outputs_length'])):
59
+ src_start_idx = sum(sample['inputs_length'][:idx])
60
+ tgt_start_idx = sum(sample['outputs_length'][:idx])
61
+ word_input = self.tokenizer.decode(sample['input_ids'][src_start_idx: src_start_idx + src_word_len])
62
+ word_output = self.tokenizer.decode(sample['outputs'][tgt_start_idx: tgt_start_idx + tgt_word_len])
63
+ words.append(word_output)
64
+ if word_input != word_output and not any(map(str.isdigit, word_output)):
65
+ phrase_candidate.append(word_output)
66
+ sample_output_words.append(words)
67
+
68
+ phrase_candidate = list(set(phrase_candidate))
69
+ phrase_candidate_revised = []
70
+ phrase_candidate_common = []
71
+ raw_phrase_candidate = []
72
+ for item in phrase_candidate:
73
+ raw_item = self.tokenizer.sp_model.DecodePieces(item.split())
74
+ if check_common_phrase(raw_item):
75
+ phrase_candidate_common.append(raw_item)
76
+ else:
77
+ phrase_candidate_revised.append(item)
78
+ raw_phrase_candidate.append(raw_item)
79
+
80
+ remain_phrase = max(0, max_bias_per_sample * len(features) - len(phrase_candidate_revised))
81
+
82
+ if remain_phrase > 0:
83
+ words_candidate = list(
84
+ set(get_bias_words()) - set(raw_phrase_candidate))
85
+ random.shuffle(words_candidate)
86
+ phrase_candidate_revised += [' '.join(self.tokenizer.sp_model.EncodeAsPieces(item)[:5]) for item in
87
+ words_candidate[:remain_phrase]]
88
+
89
+ for i in range(len(features)):
90
+ sample_bias_lables = []
91
+ for w_idx, w in enumerate(sample_output_words[i]):
92
+ try:
93
+ sample_bias_lables.extend(
94
+ [phrase_candidate_revised.index(w) + 1] * features[i]['outputs_length'][w_idx])
95
+ except ValueError:
+ # w is not a bias phrase: randomly assign the 0 (no-bias) label or ignore the position
97
+ if random.random() < 0.5:
98
+ sample_bias_lables.extend([0] * features[i]['outputs_length'][w_idx])
99
+ else:
100
+ sample_bias_lables.extend([self.label_pad_token_id] * features[i]['outputs_length'][w_idx])
101
+ bias_labels.append(sample_bias_lables)
102
+ assert len(sample_bias_lables) == len(features[i]['outputs']), "{} vs {}".format(sample_bias_lables,
103
+ features[i]['outputs'])
104
+
105
+ # phrase_candidate_ids = [self.tokenizer.encode(item) for item in phrase_candidate]
106
+ phrase_candidate_ids = [self.tokenizer.encode(self.tokenizer.sp_model.DecodePieces(item.split())) for item in
107
+ phrase_candidate_revised]
108
+ phrase_candidate_mask = [[self.tokenizer.pad_token_id] * len(item) for item in phrase_candidate_ids]
109
+
110
+ return phrase_candidate_ids, phrase_candidate_mask, bias_labels
111
+ # pass
112
+
113
+ def encode_list_string(self, list_text):
114
+ text_tokenized = self.tokenizer(list_text)
115
+ return self.tokenizer.pad(
116
+ text_tokenized,
117
+ padding=self.padding,
118
+ max_length=self.max_length,
119
+ pad_to_multiple_of=self.pad_to_multiple_of,
120
+ return_tensors='pt',
121
+ )
122
+
123
+ def __call__(self, features, return_tensors=None):
124
+ # start_time = time.time()
125
+ batch_src, batch_tgt = [], []
126
+ for item in features:
127
+ src_spans, tgt_spans = utils.make_spoken(item['text'])
128
+ batch_src.append(src_spans)
129
+ batch_tgt.append(tgt_spans)
130
+ # print("Make src-tgt {}s".format(time.time() - start_time))
131
+ # start_time = time.time()
132
+
133
+ features = preprocess_function({"src": batch_src, "tgt": batch_tgt})
134
+
135
+
136
+ # print("Make feature {}s".format(time.time() - start_time))
137
+ # start_time = time.time()
138
+
139
+ phrase_candidate_ids, phrase_candidate_mask, samples_bias_labels = self.bias_phrases_extractor(features)
140
+ # print("Make bias {}s".format(time.time() - start_time))
141
+ # start_time = time.time()
142
+
143
+ if return_tensors is None:
144
+ return_tensors = self.return_tensors
145
+ labels = [feature["outputs"] for feature in features] if "outputs" in features[0].keys() else None
146
+ spoken_labels = [feature["spoken_label"] for feature in features] if "spoken_label" in features[0].keys() else None
147
+ spoken_idx = [feature["src_spoken_idx"] for feature in features] if "src_spoken_idx" in features[0].keys() else None
148
+
149
+ word_src_lengths = [feature["inputs_length"] for feature in features] if "inputs_length" in features[0].keys() else None
150
+ word_tgt_lengths = [feature["outputs_length"] for feature in features] if "outputs_length" in features[0].keys() else None
151
+ # We have to pad the labels before calling `tokenizer.pad` as this method won't pad them and needs them of the
152
+ # same length to return tensors.
153
+ if labels is not None:
154
+ max_label_length = max(len(l) for l in labels)
155
+ max_src_length = max(len(l) for l in spoken_labels)
156
+ max_spoken_idx_length = max(len(l) for l in spoken_idx)
157
+ max_word_src_length = max(len(l) for l in word_src_lengths)
158
+ max_word_tgt_length = max(len(l) for l in word_tgt_lengths)
159
+
160
+ padding_side = self.tokenizer.padding_side
161
+ for feature, bias_labels in zip(features, samples_bias_labels):
162
+ remainder = [self.label_pad_token_id] * (max_label_length - len(feature["outputs"]))
163
+ remainder_word_tgt_length = [0] * (max_word_tgt_length - len(feature["outputs_length"]))
164
+ remainder_spoken = [self.label_pad_token_id] * (max_src_length - len(feature["spoken_label"]))
165
+ remainder_spoken_idx = [self.label_pad_token_id] * (max_spoken_idx_length - len(feature["src_spoken_idx"]))
166
+ remainder_word_src_length = [0] * (max_word_src_length - len(feature["inputs_length"]))
167
+
168
+ feature["labels"] = (
169
+ feature["outputs"] + [
170
+ self.tokenizer.eos_token_id] + remainder if padding_side == "right" else remainder + feature[
171
+ "outputs"] + [self.tokenizer.eos_token_id]
172
+ )
173
+ feature["labels_bias"] = (
174
+ bias_labels + [0] + remainder if padding_side == "right" else remainder + bias_labels + [0]
175
+ )
176
+
177
+ feature["spoken_label"] = [self.label_pad_token_id] + feature["spoken_label"] + [self.label_pad_token_id]
178
+ feature["spoken_label"] = feature["spoken_label"] + remainder_spoken if padding_side == "right" else remainder_spoken + feature["spoken_label"]
179
+ feature["src_spoken_idx"] = feature["src_spoken_idx"] + remainder_spoken_idx
180
+
181
+ feature['inputs_length'] = [1] + feature['inputs_length'] + [1]
182
+ feature['outputs_length'] = feature['outputs_length'] + [1]
183
+
184
+ feature["inputs_length"] = feature["inputs_length"] + remainder_word_src_length
185
+ feature["outputs_length"] = feature["outputs_length"] + remainder_word_tgt_length
186
+
187
+
188
+ features_inputs = [{
189
+ "input_ids": [self.tokenizer.bos_token_id] + item["input_ids"] + [self.tokenizer.eos_token_id],
190
+ "attention_mask": [self.tokenizer.pad_token_id] + item["attention_mask"] + [self.tokenizer.pad_token_id]
191
+ } for item in features]
192
+ features_inputs = self.tokenizer.pad(
193
+ features_inputs,
194
+ padding=self.padding,
195
+ max_length=self.max_length,
196
+ pad_to_multiple_of=self.pad_to_multiple_of,
197
+ return_tensors=return_tensors,
198
+ )
199
+
200
+ bias_phrases_inputs = [{
201
+ "input_ids": ids,
202
+ "attention_mask": mask
203
+ } for ids, mask in zip(phrase_candidate_ids, phrase_candidate_mask)]
204
+ bias_phrases_inputs = self.tokenizer.pad(
205
+ bias_phrases_inputs,
206
+ padding=self.padding,
207
+ max_length=self.max_length,
208
+ pad_to_multiple_of=self.pad_to_multiple_of,
209
+ return_tensors=return_tensors,
210
+ )
211
+
212
+ outputs = self.tokenizer.pad({"input_ids": [feature["labels"] for feature in features]},
213
+ return_tensors=return_tensors)['input_ids']
214
+ outputs_bias = self.tokenizer.pad({"input_ids": [feature["labels_bias"] for feature in features]},
215
+ return_tensors=return_tensors)['input_ids']
216
+ spoken_label = self.tokenizer.pad({"input_ids": [feature["spoken_label"] for feature in features]},
217
+ return_tensors=return_tensors)['input_ids']
218
+ spoken_idx = self.tokenizer.pad({"input_ids": [feature["src_spoken_idx"] for feature in features]},
219
+ return_tensors=return_tensors)['input_ids'] + 1 # 1 for bos token
220
+ word_src_lengths = self.tokenizer.pad({"input_ids": [feature["inputs_length"] for feature in features]},
221
+ return_tensors=return_tensors)['input_ids']
222
+ word_tgt_lengths = self.tokenizer.pad({"input_ids": [feature["outputs_length"] for feature in features]},
223
+ return_tensors=return_tensors)['input_ids']
224
+
225
+ features = {
226
+ "input_ids": features_inputs["input_ids"],
227
+ "spoken_label": spoken_label,
228
+ "spoken_idx": spoken_idx,
229
+ "word_src_lengths": word_src_lengths,
230
+ "word_tgt_lengths": word_tgt_lengths,
231
+ "attention_mask": features_inputs["attention_mask"],
232
+ "bias_input_ids": bias_phrases_inputs["input_ids"],
233
+ "bias_attention_mask": bias_phrases_inputs["attention_mask"],
234
+ "labels": outputs,
235
+ "labels_bias": outputs_bias
236
+ }
237
+
238
+ # print("Make batch {}s".format(time.time() - start_time))
239
+ # start_time = time.time()
240
+
241
+ # prepare decoder_input_ids
242
+ if self.model is not None and hasattr(self.model, "prepare_decoder_input_ids_from_labels"):
243
+ decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(labels=features["labels"])
244
+ features["decoder_input_ids"] = decoder_input_ids
245
+
246
+ return features
247
+
248
+
249
+ # data init
250
+ def init_data(train_corpus_path='./data-bin/raw/train_raw.txt',
251
+ test_corpus_path='./data-bin/raw/valid_raw.txt'):
252
+ dataset_oov = datasets.load_dataset('text', data_files={"train": train_corpus_path,
253
+ "test": test_corpus_path})
254
+
255
+ print(dataset_oov)
256
+ return dataset_oov
257
+
258
+
259
+ def preprocess_function(batch):
260
+
261
+ global tokenizer
262
+ if tokenizer is None:
263
+ tokenizer = model_handling.init_tokenizer()
264
+
265
+ features = []
266
+ for src_words, tgt_words in zip(batch["src"], batch["tgt"]):
267
+ src_ids, pad_ids, src_lengths, tgt_ids, tgt_lengths = [], [], [], [], []
268
+ src_spoken_label = [] # 0: "O", 1: "B", 2: "I"
269
+
270
+ src_spoken_idx = []
271
+ tgt_spoken_ids = []
272
+
273
+ for idx, (src, tgt) in enumerate(zip(src_words, tgt_words)):
274
+ is_remain = False
275
+ if src == tgt:
276
+ is_remain = True
277
+
278
+ src_tokenized = tokenizer(src)
279
+ if len(src_tokenized['input_ids']) < 3:
280
+ continue
281
+ # hard-coded fix: insert a space before '@' so the tokenizer handles email addresses consistently
282
+ if validators.email(tgt):
283
+ tgt_tokenized = tokenizer(tgt.replace('@', ' @'))
284
+ else:
285
+ tgt_tokenized = tokenizer(tgt)
286
+ if len(tgt_tokenized['input_ids']) < 3:
287
+ continue
288
+ src_ids.extend(src_tokenized["input_ids"][1:-1])
289
+ if is_remain:
290
+ src_spoken_label.extend([0 if random.random() < 0.5 else -100 for _ in range(len(src_tokenized["input_ids"][1:-1]))])
291
+ if random.random() < 0.1:
292
+ # Random pick normal word for spoken norm
293
+ src_spoken_idx.append(idx)
294
+ tgt_spoken_ids.append(tgt_tokenized["input_ids"][1:-1])
295
+ else:
296
+ src_spoken_label.extend([1] + [2] * (len(src_tokenized["input_ids"][1:-1]) - 1))
297
+ src_spoken_idx.append(idx)
298
+ tgt_spoken_ids.append(tgt_tokenized["input_ids"][1:-1])
299
+
300
+ pad_ids.extend(src_tokenized["attention_mask"][1:-1])
301
+ src_lengths.append(len(src_tokenized["input_ids"]) - 2)
302
+ tgt_ids.extend(tgt_tokenized["input_ids"][1:-1])
303
+ tgt_lengths.append(len(tgt_tokenized["input_ids"]) - 2)
304
+ if len(src_ids) > 70 or len(tgt_ids) > 70:
305
+ # print("Ignore sample")
306
+ break
307
+
308
+ if len(src_ids) < 1 or len(tgt_ids) < 1:
309
+ continue
310
+ # else:
311
+ # print("ignore")
312
+
313
+ features.append({
314
+ "input_ids": src_ids,
315
+ "attention_mask": pad_ids,
316
+ "spoken_label": src_spoken_label,
317
+ "inputs_length": src_lengths,
318
+ "outputs": tgt_ids,
319
+ "outputs_length": tgt_lengths,
320
+ "src_spoken_idx": src_spoken_idx,
321
+ "tgt_spoken_ids": tgt_spoken_ids
322
+ })
323
+
324
+ return features
325
+
326
+
327
+ if __name__ == "__main__":
328
+ split_datasets = init_data()
329
+
330
+ model, model_tokenizer = model_handling.init_model()
331
+ data_collator = DataCollatorForNormSeq2Seq(model_tokenizer, model=model)
332
+
333
+ # start = time.time()
334
+ batch = data_collator([split_datasets["train"][i] for i in [random.randint(0, 900) for _ in range(0, 12)]])
335
+ print(batch)
336
+ # print("{}s".format(time.time() - start))
infer.py ADDED
@@ -0,0 +1,374 @@
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+ import torch
4
+ import model_handling
5
+ from data_handling import DataCollatorForNormSeq2Seq
6
+ from model_handling import EncoderDecoderSpokenNorm
7
+ import os
8
+ import random
9
+ import data_handling
10
+ from transformers.generation_logits_process import LogitsProcessorList
11
+ from transformers.generation_stopping_criteria import StoppingCriteriaList
12
+ from transformers.generation_beam_search import BeamSearchScorer
13
+ from dataclasses import dataclass
14
+ from transformers.file_utils import ModelOutput
15
+ import utils
16
+
17
+ # os.environ["CUDA_VISIBLE_DEVICES"] = "4"
18
+
19
+ use_gpu = False
20
+ if use_gpu:
21
+ if not torch.cuda.is_available():
22
+ use_gpu = False
23
+ tokenizer = model_handling.init_tokenizer()
24
+ model = EncoderDecoderSpokenNorm.from_pretrained('nguyenvulebinh/spoken-norm-taggen-v2').eval()
25
+ data_collator = DataCollatorForNormSeq2Seq(tokenizer)
26
+ if use_gpu:
27
+ model = model.cuda()
28
+
29
+
30
+ def make_batch_input(text_input_list):
31
+ batch_src_ids, batch_src_lengths = [], []
32
+ for text_input in text_input_list:
33
+ src_ids, src_lengths = [], []
34
+ for src in text_input.split():
35
+ src_tokenized = tokenizer(src)
36
+ ids = src_tokenized["input_ids"][1:-1]
37
+ src_ids.extend(ids)
38
+ src_lengths.append(len(ids))
39
+ src_ids = torch.tensor([0] + src_ids + [2])  # 0 / 2 are the bos / eos token ids
40
+ src_lengths = torch.tensor([1] + src_lengths + [1]) + 1
41
+ batch_src_ids.append(src_ids)
42
+ batch_src_lengths.append(src_lengths)
43
+ assert sum(src_lengths - 1) == len(src_ids), "{} vs {}".format(sum(src_lengths), len(src_ids))
44
+ input_tokenized = tokenizer.pad({"input_ids": batch_src_ids}, padding=True)
45
+ input_word_length = tokenizer.pad({"input_ids": batch_src_lengths}, padding=True)["input_ids"] - 1
46
+ return input_tokenized['input_ids'], input_tokenized['attention_mask'], input_word_length
47
+
48
+
49
+ def make_batch_bias_list(bias_list):
50
+ if len(bias_list) > 0:
51
+ bias = data_collator.encode_list_string(bias_list)
52
+ bias_input_ids = bias['input_ids']
53
+ bias_attention_mask = bias['attention_mask']
54
+ else:
55
+ bias_input_ids = None
56
+ bias_attention_mask = None
57
+
58
+ return bias_input_ids, bias_attention_mask
59
+
60
+
61
+ def build_spoken_pronounce_mapping(bias_list):
62
+ list_pronounce = []
63
+ mapping = dict({})
64
+ for item in bias_list:
65
+ pronounces = item.split(' | ')[1:]
66
+ pronounces = [tokenizer(item)['input_ids'][1:-1] for item in pronounces]
67
+ list_pronounce.extend(pronounces)
68
+ subword_ids = list(set([item for sublist in list_pronounce for item in sublist]))
69
+ mapping = {item: [] for item in subword_ids}
70
+ for item in list_pronounce:
71
+ for wid in subword_ids:
72
+ if wid in item:
73
+ mapping[wid].append(item)
74
+ return mapping
75
+
76
+ def find_pivot(seq, subseq):
77
+ n = len(seq)
78
+ m = len(subseq)
79
+ result = []
80
+ for i in range(n - m + 1):
81
+ if seq[i] == subseq[0] and seq[i:i + m] == subseq:
82
+ result.append(i)
83
+ return result
84
+
85
+ def revise_spoken_tagging(list_tags, list_words, pronounce_mapping):
86
+ if len(pronounce_mapping) == 0:
87
+ return list_tags
88
+ result = []
89
+ for tags_tensor, sen in zip(list_tags, list_words):
90
+ tags = tags_tensor.detach().numpy().tolist()
91
+ sen = sen.detach().numpy().tolist()
92
+ candidate_pronounce = dict({})
93
+ for idx in range(len(tags)):
94
+ if tags[idx] != 0 and sen[idx] in pronounce_mapping:
95
+ for pronounce in pronounce_mapping[sen[idx]]:
96
+ pronounce_word = str(pronounce)
97
+ start_find_idx = max(0, idx - len(pronounce))
98
+ end_find_idx = idx + len(pronounce)
99
+ find_idx = find_pivot(sen[start_find_idx: end_find_idx], pronounce)
100
+ if len(find_idx) > 0:
101
+ find_idx = [item + start_find_idx for item in find_idx]
102
+ for map_idx in find_idx:
103
+ if candidate_pronounce.get(map_idx, None) is None:
104
+ candidate_pronounce[map_idx] = len(pronounce)
105
+ else:
106
+ candidate_pronounce[map_idx] = max(candidate_pronounce[map_idx], len(pronounce))
107
+ for idx, len_word in candidate_pronounce.items():
108
+ tags_tensor[idx] = 1
109
+ for i in range(1, len_word):
110
+ tags_tensor[idx + i] = 2
111
+ result.append(tags_tensor)
112
+ return result
113
+
114
+
115
+ def make_spoken_feature(input_features, text_input_list, pronounce_mapping=dict({})):
116
+ features = {
117
+ "input_ids": input_features[0],
118
+ "word_src_lengths": input_features[2],
119
+ "attention_mask": input_features[1],
120
+ # "bias_input_ids": bias_features[0],
121
+ # "bias_attention_mask": bias_features[1],
122
+ "bias_input_ids": None,
123
+ "bias_attention_mask": None,
124
+ }
125
+ if use_gpu:
126
+ for key in features.keys():
127
+ if features[key] is not None:
128
+ features[key] = features[key].cuda()
129
+
130
+ encoder_output = model.get_encoder()(**features)
131
+ spoken_tagging_output = torch.argmax(encoder_output[0].spoken_tagging_output, dim=-1)
132
+ spoken_tagging_output = revise_spoken_tagging(spoken_tagging_output, features['input_ids'], pronounce_mapping)
133
+
134
+ # print(spoken_tagging_output)
135
+ # print(features['input_ids'])
136
+ word_src_lengths = features['word_src_lengths']
137
+ encoder_features = encoder_output[0][0]
138
+ list_spoken_features = []
139
+ list_pre_norm = []
140
+ for tagging_sample, sample_word_length, text_input_features, sample_text in zip(spoken_tagging_output, word_src_lengths, encoder_features, text_input_list):
141
+ spoken_feature_idx = []
142
+ sample_words = ['<s>'] + sample_text.split() + ['</s>']
143
+ norm_words = []
144
+ spoken_phrase = []
145
+ spoken_features = []
146
+ if tagging_sample.sum() == 0:
147
+ list_pre_norm.append(sample_words)
148
+ continue
149
+ for idx, word_length in enumerate(sample_word_length):
150
+ if word_length > 0:
151
+ start = sample_word_length[:idx].sum()
152
+ end = start + word_length
153
+ if tagging_sample[start: end].sum() > 0 and sample_words[idx] not in ['<s>', '</s>']:
154
+ # Word has start tag
155
+ if (tagging_sample[start: end] == 1).sum():
156
+ if len(spoken_phrase) > 0:
157
+ norm_words.append('<mask>[{}]({})'.format(len(list_spoken_features), ' '.join(spoken_phrase)))
158
+ spoken_phrase = []
159
+ list_spoken_features.append(torch.cat(spoken_features))
160
+ spoken_features = []
161
+ spoken_phrase.append(sample_words[idx])
162
+ spoken_features.append(text_input_features[start: end])
163
+ else:
164
+ if len(spoken_phrase) > 0:
165
+ norm_words.append('<mask>[{}]({})'.format(len(list_spoken_features), ' '.join(spoken_phrase)))
166
+ spoken_phrase = []
167
+ list_spoken_features.append(torch.cat(spoken_features))
168
+ spoken_features = []
169
+ norm_words.append(sample_words[idx])
170
+ if len(spoken_phrase) > 0:
171
+ norm_words.append('<mask>[{}]({})'.format(len(list_spoken_features), ' '.join(spoken_phrase)))
172
+ spoken_phrase = []
173
+ list_spoken_features.append(torch.cat(spoken_features))
174
+ spoken_features = []
175
+ list_pre_norm.append(norm_words)
176
+
177
+
178
+ list_features_mask = []
179
+ if len(list_spoken_features) > 0:
180
+ feature_pad = torch.zeros_like(list_spoken_features[0][:1, :])
181
+ max_length = max([len(item) for item in list_spoken_features])
182
+ for i in range(len(list_spoken_features)):
183
+ spoken_length = len(list_spoken_features[i])
184
+ remain_length = max_length - spoken_length
185
+ device = list_spoken_features[i].device
186
+ list_spoken_features[i] = torch.cat([list_spoken_features[i],
187
+ feature_pad.expand(remain_length, feature_pad.size(-1))]).unsqueeze(0)
188
+ list_features_mask.append(torch.cat([torch.ones(spoken_length, device=device, dtype=torch.int64),
189
+ torch.zeros(remain_length, device=device, dtype=torch.int64)]).unsqueeze(0))
190
+ if len(list_spoken_features) > 0:
191
+ list_spoken_features = torch.cat(list_spoken_features)
192
+ list_features_mask = torch.cat(list_features_mask)
193
+
194
+ return list_spoken_features, list_features_mask, list_pre_norm
195
+
196
+
197
+ def make_bias_feature(bias_raw_features):
198
+ features = {
199
+ "bias_input_ids": bias_raw_features[0],
200
+ "bias_attention_mask": bias_raw_features[1]
201
+ }
202
+ if use_gpu:
203
+ for key in features.keys():
204
+ if features[key] is not None:
205
+ features[key] = features[key].cuda()
206
+ return model.forward_bias(**features)
207
+
208
+
209
+ def decode_plain_output(decoder_output):
210
+ plain_output = [item.split()[1:] for item in tokenizer.batch_decode(decoder_output['sequences'], skip_special_tokens=False)]
211
+ scores = torch.stack(list(decoder_output['scores'])).transpose(1, 0)
212
+ logit_output = torch.gather(scores, -1, decoder_output['sequences'][:, 1:].unsqueeze(-1)).squeeze(-1)
213
+ special_tokens = list(tokenizer.special_tokens_map.values())
214
+ generated_output = []
215
+ generated_scores = []
216
+ # filter special tokens
217
+ for out_text, out_score in zip(plain_output, logit_output):
218
+ temp_str, tmp_score = [], []
219
+ for piece, score in zip(out_text, out_score):
220
+ if piece not in special_tokens:
221
+ temp_str.append(piece)
222
+ tmp_score.append(score)
223
+ if len(temp_str) > 0:
224
+ generated_output.append(' '.join(temp_str).replace('▁', '|').replace(' ', '').replace('|', ' ').strip())
225
+ generated_scores.append((sum(tmp_score)/len(tmp_score)).cpu().detach().numpy().tolist())
226
+ else:
227
+ generated_output.append("")
228
+ generated_scores.append(0)
229
+ return generated_output, generated_scores
230
+
231
+
232
+ def generate_spoken_norm(list_spoken_features, list_features_mask, bias_features):
233
+ @dataclass
234
+ class EncoderOutputs(ModelOutput):
235
+ last_hidden_state: torch.FloatTensor = None
236
+ hidden_states: torch.FloatTensor = None
237
+ attentions: torch.FloatTensor = None
238
+
239
+ batch_size = list_spoken_features.size(0)
240
+ max_length = 50
241
+ device = list_spoken_features.device
242
+ decoder_input_ids = torch.zeros((batch_size, 1), device=device, dtype=torch.int64)
243
+ stopping_criteria = model._get_stopping_criteria(max_length=max_length, max_time=None,
244
+ stopping_criteria=StoppingCriteriaList())
245
+ model_kwargs = {
246
+ "encoder_outputs": EncoderOutputs(last_hidden_state=list_spoken_features),
247
+ "encoder_bias_outputs": bias_features,
248
+ "attention_mask": list_features_mask
249
+ }
250
+ decoder_output = model.greedy_search(
251
+ decoder_input_ids,
252
+ logits_processor=LogitsProcessorList(),
253
+ stopping_criteria=stopping_criteria,
254
+ pad_token_id=tokenizer.pad_token_id,
255
+ eos_token_id=tokenizer.eos_token_id,
256
+ output_scores=True,
257
+ return_dict_in_generate=True,
258
+ **model_kwargs,
259
+ )
260
+ plain_output, plain_score = decode_plain_output(decoder_output)
261
+ # plain_output = tokenizer.batch_decode(decoder_output['sequences'], skip_special_tokens=True)
262
+ # # print(decoder_output)
263
+ # plain_output = [word.replace('▁', '|').replace(' ', '').replace('|', ' ').strip() for word in plain_output]
264
+ return plain_output, plain_score
265
+
266
+
267
+ def generate_beam_spoken_norm(list_spoken_features, list_features_mask, bias_features, num_beams=3):
268
+ @dataclass
269
+ class EncoderOutputs(ModelOutput):
270
+ last_hidden_state: torch.FloatTensor = None
271
+
272
+ batch_size = list_spoken_features.size(0)
273
+ max_length = 50
274
+ num_return_sequences = 1
275
+ device = list_spoken_features.device
276
+ decoder_input_ids = torch.zeros((batch_size, 1), device=device, dtype=torch.int64)
277
+ stopping_criteria = model._get_stopping_criteria(max_length=max_length, max_time=None,
278
+ stopping_criteria=StoppingCriteriaList())
279
+ model_kwargs = {
280
+ "encoder_outputs": EncoderOutputs(last_hidden_state=list_spoken_features),
281
+ "encoder_bias_outputs": bias_features,
282
+ "attention_mask": list_features_mask
283
+ }
284
+ beam_scorer = BeamSearchScorer(
285
+ batch_size=batch_size,
286
+ num_beams=num_beams,
287
+ device=device,
288
+ do_early_stopping=True,
289
+ num_beam_hyps_to_keep=num_return_sequences,
290
+ )
291
+ decoder_input_ids, model_kwargs = model._expand_inputs_for_generation(
292
+ decoder_input_ids, expand_size=num_beams, is_encoder_decoder=True, **model_kwargs
293
+ )
294
+
295
+ decoder_output = model.beam_search(
296
+ decoder_input_ids,
297
+ beam_scorer,
298
+ logits_processor=LogitsProcessorList(),
299
+ stopping_criteria=stopping_criteria,
300
+ pad_token_id=tokenizer.pad_token_id,
301
+ eos_token_id=tokenizer.eos_token_id,
302
+ output_scores=None,
303
+ return_dict_in_generate=True,
304
+ **model_kwargs,
305
+ )
306
+
307
+ plain_output = tokenizer.batch_decode(decoder_output['sequences'], skip_special_tokens=True)
308
+ plain_output = [word.replace('▁', '|').replace(' ', '').replace('|', ' ').strip() for word in plain_output]
309
+ return plain_output, None
310
+
311
+
312
+ def reformat_normed_term(list_pre_norm, spoken_norm_output, spoken_norm_output_score=None, threshold=None, debug=False):
313
+ output = []
314
+ for pre_norm in list_pre_norm:
315
+ normed_words = []
316
+ # words = pre_norm.split()
317
+ for w in pre_norm:
318
+ if w.startswith('<mask>'):
319
+ term = w[7:].split('](')
320
+ # print(w)
321
+ # print(term)
322
+ term_idx = int(term[0])
323
+ norm_val = spoken_norm_output[term_idx]
324
+ norm_val_score = None if (spoken_norm_output_score is None or threshold is None) else spoken_norm_output_score[term_idx]
325
+ pre_norm_val = term[1][:-1]
326
+ if debug:
327
+ if norm_val_score is not None:
328
+ normed_words.append("({})({:.2f})[{}]".format(norm_val, norm_val_score, pre_norm_val))
329
+ else:
330
+ normed_words.append("({})[{}]".format(norm_val, pre_norm_val))
331
+ else:
332
+ if threshold is not None and norm_val_score is not None:
333
+ if norm_val_score > threshold:
334
+ normed_words.append(norm_val)
335
+ else:
336
+ normed_words.append(pre_norm_val)
337
+ else:
338
+ normed_words.append(norm_val)
339
+ else:
340
+ normed_words.append(w)
341
+ output.append(" ".join(normed_words))
342
+ return output
343
+
344
+
345
+ def infer(text_input_list, bias_list):
346
+ # extract bias feature
347
+ bias_raw_features = make_batch_bias_list(bias_list)
348
+ bias_features = make_bias_feature(bias_raw_features)
349
+ pronounce_mapping = build_spoken_pronounce_mapping(bias_list)
350
+
351
+ # Chunk split input and create feature
352
+ text_input_chunk_list = [utils.split_chunk_input(item, chunk_size=60, overlap=20) for item in text_input_list]
353
+ num_chunks = [len(i) for i in text_input_chunk_list]
354
+ flatten_list = [y for x in text_input_chunk_list for y in x]
355
+ input_raw_features = make_batch_input(flatten_list)
356
+
357
+ # Extract norm term and spoken feature
358
+ list_spoken_features, list_features_mask, list_pre_norm = make_spoken_feature(input_raw_features, flatten_list, pronounce_mapping)
359
+
360
+ # Merge overlap chunks
361
+ list_pre_norm_by_input = []
362
+ for idx, input_num in enumerate(num_chunks):
363
+ start = sum(num_chunks[:idx])
364
+ end = start + input_num
365
+ list_pre_norm_by_input.append(list_pre_norm[start:end])
366
+ text_input_list_pre_norm = [utils.merge_chunk_pre_norm(list_chunks, overlap=20, debug=False) for list_chunks in list_pre_norm_by_input]
367
+
368
+ if len(list_spoken_features) > 0:
369
+ spoken_norm_output, spoken_norm_score = generate_spoken_norm(list_spoken_features, list_features_mask, bias_features)
370
+ else:
371
+ spoken_norm_output, spoken_norm_score = [], None
372
+
373
+ return reformat_normed_term(text_input_list_pre_norm, spoken_norm_output, spoken_norm_score, threshold=15, debug=False)
374
+
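+ # Example usage (a minimal sketch; the strings below are made-up placeholders, not
+ # values from this repository):
+ #   written = infer(["<spoken-form sentence>"], ["<bias phrase 1>", "<bias phrase 2>"])
+ #   print(written[0])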
model_config_handling.py ADDED
@@ -0,0 +1,90 @@
1
+ # coding=utf-8
2
+ # Copyright 2020 The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import copy
18
+
19
+ from transformers.configuration_utils import PretrainedConfig
20
+ from transformers import BertConfig
21
+ from transformers.utils import logging
22
+ # from model_handling import DecoderSpokenNorm
23
+
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+
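+ # Decoder config: a BertConfig with Roberta-style special token ids and the number of
+ # hidden layers fixed to 2, keeping the decoder small relative to the encoder.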
28
+ class DecoderSpokenNormConfig(BertConfig):
29
+ # model_type = "decoder-spoken-norm"
30
+
31
+ def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2, **kwargs):
32
+ """Constructs RobertaConfig."""
33
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
34
+ self.num_hidden_layers = 2
35
+ # self.hidden_layers_from_pretrained = list(range(self.num_hidden_layers))
36
+ # self.hidden_layers_from_pretrained = [0, 3]
37
+
38
+ # if len(self.hidden_layers_from_pretrained) < self.num_hidden_layers:
39
+ # self.num_hidden_layers = len(self.hidden_layers_from_pretrained)
40
+
41
+
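+ # Composite config holding one encoder config and one decoder config, mirroring
+ # transformers' EncoderDecoderConfig but typed for the spoken-norm model.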
42
+ class EncoderDecoderSpokenNormConfig(PretrainedConfig):
43
+ # model_type = "encoder-decoder-spoken-norm"
44
+ is_composition = True
45
+
46
+ def __init__(self, **kwargs):
47
+ super().__init__(**kwargs)
48
+ assert (
49
+ "encoder" in kwargs and "decoder" in kwargs
50
+ ), "Config has to be initialized with encoder and decoder config"
51
+ encoder_config = kwargs.pop("encoder")
52
+ encoder_model_type = encoder_config.pop("model_type")
53
+ decoder_config = kwargs.pop("decoder")
54
+ decoder_model_type = decoder_config.pop("model_type")
55
+
56
+ from transformers.models.auto.configuration_auto import AutoConfig
57
+
58
+ self.encoder = AutoConfig.for_model(encoder_model_type, **encoder_config)
59
+ self.decoder = AutoConfig.for_model(decoder_model_type, **decoder_config)
60
+ self.is_encoder_decoder = True
61
+
62
+ @classmethod
63
+ def from_encoder_decoder_configs(
64
+ cls, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig, **kwargs
65
+ ) -> PretrainedConfig:
66
+ r"""
67
+ Instantiate a :class:`~transformers.EncoderDecoderConfig` (or a derived class) from a pre-trained encoder model
68
+ configuration and decoder model configuration.
69
+
70
+ Returns:
71
+ :class:`EncoderDecoderConfig`: An instance of a configuration object
72
+ """
73
+ logger.info("Set `config.is_decoder=True` and `config.add_cross_attention=True` for decoder_config")
74
+ decoder_config.is_decoder = True
75
+ decoder_config.add_cross_attention = True
76
+
77
+ return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict(), **kwargs)
78
+
79
+ def to_dict(self):
80
+ """
81
+ Serializes this instance to a Python dictionary. Override the default `to_dict()` from `PretrainedConfig`.
82
+
83
+ Returns:
84
+ :obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
85
+ """
86
+ output = copy.deepcopy(self.__dict__)
87
+ output["encoder"] = self.encoder.to_dict()
88
+ output["decoder"] = self.decoder.to_dict()
89
+ output["model_type"] = self.__class__.model_type
90
+ return output
model_handling.py ADDED
@@ -0,0 +1,763 @@
1
+ from transformers.file_utils import cached_path, hf_bucket_url
2
+ from importlib.machinery import SourceFileLoader
3
+ import os
4
+ from transformers import EncoderDecoderModel, AutoConfig, AutoModel, EncoderDecoderConfig, RobertaForCausalLM, \
5
+ RobertaModel
6
+ from transformers.modeling_utils import PreTrainedModel, logging
7
+ import torch
8
+ from torch.nn import CrossEntropyLoss, Parameter
9
+ from transformers.modeling_outputs import Seq2SeqLMOutput, CausalLMOutputWithCrossAttentions, \
10
+ ModelOutput
11
+ from attentions import ScaledDotProductAttention, MultiHeadAttention
12
+ from collections import namedtuple
13
+ from typing import Dict, Any, Optional, Tuple
14
+ from dataclasses import dataclass
15
+ import random
16
+ from model_config_handling import EncoderDecoderSpokenNormConfig, DecoderSpokenNormConfig, PretrainedConfig
17
+
18
+ cache_dir = './cache'
19
+ model_name = 'nguyenvulebinh/envibert'
20
+
21
+ if not os.path.exists(cache_dir):
22
+ os.makedirs(cache_dir)
23
+ logger = logging.get_logger(__name__)
24
+
25
+
26
+ @dataclass
27
+ class SpokenNormOutput(ModelOutput):
28
+ loss: Optional[torch.FloatTensor] = None
29
+ logits: torch.FloatTensor = None
30
+ logits_spoken_tagging: torch.FloatTensor = None
31
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
32
+ decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
33
+ decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
34
+ cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
35
+ encoder_last_hidden_state: Optional[torch.FloatTensor] = None
36
+ encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
37
+ encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
38
+
39
+
40
+
41
+
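+ # For every tagged spoken phrase, gather its slice of encoder hidden states (using the
+ # per-word token lengths), pad each slice to the longest phrase in the batch and build
+ # the matching attention mask.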
42
+ def collect_spoken_phrases_features(encoder_hidden_states, word_src_lengths, spoken_label):
43
+ list_features = []
44
+ list_features_mask = []
45
+ max_length = word_src_lengths.max()
46
+ feature_pad = torch.zeros_like(encoder_hidden_states[0, :1, :])
47
+ for hidden_state, word_length, list_idx in zip(encoder_hidden_states, word_src_lengths, spoken_label):
48
+ for idx in list_idx:
49
+ if idx > 0:
50
+ start = sum(word_length[:idx])
51
+ end = start + word_length[idx]
52
+ remain_length = max_length - word_length[idx]
53
+ list_features_mask.append(torch.cat([torch.ones_like(spoken_label[0, 0]).expand(word_length[idx]),
54
+ torch.zeros_like(
55
+ spoken_label[0, 0].expand(remain_length))]).unsqueeze(0))
56
+ spoken_phrases_feature = hidden_state[start: end]
57
+
58
+ list_features.append(torch.cat([spoken_phrases_feature,
59
+ feature_pad.expand(remain_length, feature_pad.size(-1))]).unsqueeze(0))
60
+ return torch.cat(list_features), torch.cat(list_features_mask)
61
+
62
+
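+ # Build per-phrase decoder inputs and labels: slice the target sequence for each spoken
+ # phrase, prepend BOS / append EOS, and pad the labels with -100 so padded positions are
+ # ignored by the cross-entropy loss.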
63
+ def collect_spoken_phrases_labels(decoder_input_ids, labels, labels_bias, word_tgt_lengths, spoken_idx):
64
+ list_decoder_input_ids = []
65
+ list_labels = []
66
+ list_labels_bias = []
67
+ max_length = word_tgt_lengths.max()
68
+ init_decoder_ids = torch.tensor([0], device=labels.device, dtype=labels.dtype)
69
+ pad_decoder_ids = torch.tensor([1], device=labels.device, dtype=labels.dtype)
70
+ eos_decoder_ids = torch.tensor([2], device=labels.device, dtype=labels.dtype)
71
+ none_labels_bias = torch.tensor([0], device=labels.device, dtype=labels.dtype)
72
+ ignore_labels_bias = torch.tensor([-100], device=labels.device, dtype=labels.dtype)
73
+
74
+ for decoder_inputs, decoder_label, decoder_label_bias, word_length, list_idx in zip(decoder_input_ids,
75
+ labels, labels_bias,
76
+ word_tgt_lengths, spoken_idx):
77
+ for idx in list_idx:
78
+ if idx > 0:
79
+ start = sum(word_length[:idx - 1])
80
+ end = start + word_length[idx - 1]
81
+ remain_length = max_length - word_length[idx - 1]
82
+ remain_decoder_input_ids = max_length - len(decoder_inputs[start + 1:end + 1])
83
+ list_decoder_input_ids.append(torch.cat([init_decoder_ids,
84
+ decoder_inputs[start + 1:end + 1],
85
+ pad_decoder_ids.expand(remain_decoder_input_ids)]).unsqueeze(0))
86
+ list_labels.append(torch.cat([decoder_label[start:end],
87
+ eos_decoder_ids,
88
+ ignore_labels_bias.expand(remain_length)]).unsqueeze(0))
89
+ list_labels_bias.append(torch.cat([decoder_label_bias[start:end],
90
+ none_labels_bias,
91
+ ignore_labels_bias.expand(remain_length)]).unsqueeze(0))
92
+
93
+ decoder_input_ids = torch.cat(list_decoder_input_ids)
94
+ labels = torch.cat(list_labels)
95
+ labels_bias = torch.cat(list_labels_bias)
96
+
97
+ return decoder_input_ids, labels, labels_bias
98
+
99
+
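+ # Encoder-decoder model for spoken-to-written normalization. On top of the standard
+ # EncoderDecoderModel it adds a 3-way spoken-phrase tagger (O/B/I) over the encoder
+ # output and a second encoder pass over the bias phrase list.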
100
+ class EncoderDecoderSpokenNorm(EncoderDecoderModel):
101
+ config_class = EncoderDecoderSpokenNormConfig
102
+
103
+ def __init__(
104
+ self,
105
+ config: Optional[PretrainedConfig] = None,
106
+ encoder: Optional[PreTrainedModel] = None,
107
+ decoder: Optional[PreTrainedModel] = None,
108
+ ):
109
+ if config is None and (encoder is None or decoder is None):
110
+ raise ValueError("Either a configuration or an encoder and a decoder has to be provided.")
111
+ if config is None:
112
+ config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config)
113
+ else:
114
+ if not isinstance(config, self.config_class):
115
+ raise ValueError(f"Config: {config} has to be of type {self.config_class}")
116
+
117
+ if config.decoder.cross_attention_hidden_size is not None:
118
+ if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
119
+ raise ValueError(
120
+ "If `cross_attention_hidden_size` is specified in the decoder's configuration, "
121
+ "it has to be equal to the encoder's `hidden_size`. "
122
+ f"Got {config.decoder.cross_attention_hidden_size} for `config.decoder.cross_attention_hidden_size` "
123
+ f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`."
124
+ )
125
+
126
+ # initialize with config
127
+ super().__init__(config)
128
+
129
+ if encoder is None:
130
+ from transformers.models.auto.modeling_auto import AutoModel
131
+
132
+ encoder = AutoModel.from_config(config.encoder)
133
+
134
+ if decoder is None:
135
+ # from transformers.models.auto.modeling_auto import AutoModelForCausalLM
136
+
137
+ decoder = DecoderSpokenNorm._from_config(config.decoder)
138
+
139
+ self.encoder = encoder
140
+ self.decoder = decoder
141
+
142
+ if self.encoder.config.to_dict() != self.config.encoder.to_dict():
143
+ logger.warning(
144
+ f"Config of the encoder: {self.encoder.__class__} is overwritten by shared encoder config: {self.config.encoder}"
145
+ )
146
+ if self.decoder.config.to_dict() != self.config.decoder.to_dict():
147
+ logger.warning(
148
+ f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config: {self.config.decoder}"
149
+ )
150
+
151
+ # make sure that the individual model's config refers to the shared config
152
+ # so that the updates to the config will be synced
153
+ self.encoder.config = self.config.encoder
154
+ self.decoder.config = self.config.decoder
155
+
156
+ # encoder outputs might need to be projected to different dimension for decoder
157
+ if (
158
+ self.encoder.config.hidden_size != self.decoder.config.hidden_size
159
+ and self.decoder.config.cross_attention_hidden_size is None
160
+ ):
161
+ self.enc_to_dec_proj = torch.nn.Linear(self.encoder.config.hidden_size, self.decoder.config.hidden_size)
162
+
163
+ if self.encoder.get_output_embeddings() is not None:
164
+ raise ValueError(
165
+ f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head"
166
+ )
167
+
168
+ # spoken tagging
169
+ self.dropout = torch.nn.Dropout(0.3)
170
+ # 0: "O", 1: "B", 2: "I"
171
+ self.spoken_tagging_classifier = torch.nn.Linear(config.encoder.hidden_size, 3)
172
+
173
+ # tie encoder, decoder weights if config set accordingly
174
+ self.tie_weights()
175
+
176
+ @classmethod
177
+ def from_encoder_decoder_pretrained(
178
+ cls,
179
+ encoder_pretrained_model_name_or_path: str = None,
180
+ decoder_pretrained_model_name_or_path: str = None,
181
+ *model_args,
182
+ **kwargs
183
+ ) -> PreTrainedModel:
184
+
185
+ kwargs_encoder = {
186
+ argument[len("encoder_"):]: value for argument, value in kwargs.items() if argument.startswith("encoder_")
187
+ }
188
+
189
+ kwargs_decoder = {
190
+ argument[len("decoder_"):]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
191
+ }
192
+
193
+ # remove encoder, decoder kwargs from kwargs
194
+ for key in kwargs_encoder.keys():
195
+ del kwargs["encoder_" + key]
196
+ for key in kwargs_decoder.keys():
197
+ del kwargs["decoder_" + key]
198
+
199
+ # Load and initialize the encoder and decoder
200
+ # The distinction between encoder and decoder at the model level is made
201
+ # by the value of the flag `is_decoder` that we need to set correctly.
202
+ encoder = kwargs_encoder.pop("model", None)
203
+ if encoder is None:
204
+ if encoder_pretrained_model_name_or_path is None:
205
+ raise ValueError(
206
+ "If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has "
207
+ "to be defined."
208
+ )
209
+
210
+ if "config" not in kwargs_encoder:
211
+ encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path)
212
+ if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
213
+ logger.info(
214
+ f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
215
+ "from a decoder model. Cross-attention and casual mask are disabled."
216
+ )
217
+ encoder_config.is_decoder = False
218
+ encoder_config.add_cross_attention = False
219
+
220
+ kwargs_encoder["config"] = encoder_config
221
+
222
+ encoder = AutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args,
223
+ **kwargs_encoder)
224
+
225
+ decoder = kwargs_decoder.pop("model", None)
226
+ if decoder is None:
227
+ if decoder_pretrained_model_name_or_path is None:
228
+ raise ValueError(
229
+ "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
230
+ "to be defined."
231
+ )
232
+
233
+ if "config" not in kwargs_decoder:
234
+ decoder_config = DecoderSpokenNormConfig.from_pretrained(decoder_pretrained_model_name_or_path)
235
+ if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
236
+ logger.info(
237
+ f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
238
+ f"Cross attention layers are added to {decoder_pretrained_model_name_or_path} "
239
+ f"and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for "
240
+ "cross attention layers."
241
+ )
242
+ decoder_config.is_decoder = True
243
+ decoder_config.add_cross_attention = True
244
+
245
+ kwargs_decoder["config"] = decoder_config
246
+
247
+ if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
248
+ logger.warning(
249
+ f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
250
+ f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
251
+ "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
252
+ "passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a "
253
+ "`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
254
+ )
255
+
256
+ decoder = DecoderSpokenNorm.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
257
+
258
+ # instantiate config with corresponding kwargs
259
+ config = EncoderDecoderSpokenNormConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)
260
+ return cls(encoder=encoder, decoder=decoder, config=config)
261
+
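+ # Note: get_encoder returns a closure instead of the encoder module, so that generate()
+ # also runs the spoken-phrase tagger and the bias encoder in the same call and receives
+ # both (encoder_outputs, encoder_bias_outputs).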
262
+ def get_encoder(self):
263
+ def forward(input_ids=None,
264
+ attention_mask=None,
265
+ bias_input_ids=None,
266
+ bias_attention_mask=None,
267
+ return_dict=True,
268
+ output_attentions=False,
269
+ output_hidden_states=False,
270
+ word_src_lengths=None,
271
+ spoken_idx=None,
272
+ **kwargs_encoder):
273
+ encoder_outputs = self.encoder(
274
+ input_ids=input_ids,
275
+ attention_mask=attention_mask,
276
+ inputs_embeds=None,
277
+ output_attentions=output_attentions,
278
+ output_hidden_states=output_hidden_states,
279
+ return_dict=return_dict,
280
+ **kwargs_encoder,
281
+ )
282
+ encoder_outputs.word_src_lengths = word_src_lengths
283
+ encoder_outputs.spoken_tagging_output = self.spoken_tagging_classifier(self.dropout(encoder_outputs[0]))
284
+ if spoken_idx is not None:
285
+ encoder_outputs.spoken_idx = spoken_idx
286
+ else:
287
+ pass
288
+
289
+ encoder_bias_outputs = self.forward_bias(bias_input_ids,
290
+ bias_attention_mask,
291
+ output_attentions=output_attentions,
292
+ return_dict=return_dict,
293
+ output_hidden_states=output_hidden_states,
294
+ **kwargs_encoder)
295
+ # d = {
296
+ # "encoder_bias_outputs": None,
297
+ # "bias_attention_mask": None,
298
+ # "last_hidden_state": None,
299
+ # "pooler_output": None
300
+ #
301
+ # }
302
+ # encoder_bias_outputs = namedtuple('Struct', d.keys())(*d.values())
303
+ # if bias_input_ids is not None:
304
+ # encoder_bias_outputs = self.encoder(
305
+ # input_ids=bias_input_ids,
306
+ # attention_mask=bias_attention_mask,
307
+ # inputs_embeds=None,
308
+ # output_attentions=output_attentions,
309
+ # output_hidden_states=output_hidden_states,
310
+ # return_dict=return_dict,
311
+ # **kwargs_encoder,
312
+ # )
313
+ # encoder_bias_outputs.bias_attention_mask = bias_attention_mask
314
+ return encoder_outputs, encoder_bias_outputs
315
+
316
+ return forward
317
+
318
+ def forward_bias(self,
319
+ bias_input_ids,
320
+ bias_attention_mask,
321
+ output_attentions=False,
322
+ return_dict=True,
323
+ output_hidden_states=False,
324
+ **kwargs_encoder):
325
+ d = {
326
+ "encoder_bias_outputs": None,
327
+ "bias_attention_mask": None,
328
+ "last_hidden_state": None,
329
+ "pooler_output": None
330
+
331
+ }
332
+ encoder_bias_outputs = namedtuple('Struct', d.keys())(*d.values())
333
+ if bias_input_ids is not None:
334
+ encoder_bias_outputs = self.encoder(
335
+ input_ids=bias_input_ids,
336
+ attention_mask=bias_attention_mask,
337
+ inputs_embeds=None,
338
+ output_attentions=output_attentions,
339
+ output_hidden_states=output_hidden_states,
340
+ return_dict=return_dict,
341
+ **kwargs_encoder,
342
+ )
343
+ encoder_bias_outputs.bias_attention_mask = bias_attention_mask
344
+ return encoder_bias_outputs
345
+
346
+ def _prepare_encoder_decoder_kwargs_for_generation(
347
+ self, input_ids: torch.LongTensor, model_kwargs, model_input_name
348
+ ) -> Dict[str, Any]:
349
+ if "encoder_outputs" not in model_kwargs:
350
+ # retrieve encoder hidden states
351
+ encoder = self.get_encoder()
352
+ encoder_kwargs = {
353
+ argument: value
354
+ for argument, value in model_kwargs.items()
355
+ if not (argument.startswith("decoder_") or argument.startswith("cross_attn"))
356
+ }
357
+ encoder_outputs, encoder_bias_outputs = encoder(input_ids, return_dict=True, **encoder_kwargs)
358
+ model_kwargs["encoder_outputs"]: ModelOutput = encoder_outputs
359
+ model_kwargs["encoder_bias_outputs"]: ModelOutput = encoder_bias_outputs
360
+
361
+ return model_kwargs
362
+
363
+ def _prepare_decoder_input_ids_for_generation(
364
+ self,
365
+ batch_size: int,
366
+ decoder_start_token_id: int = None,
367
+ bos_token_id: int = None,
368
+ model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
369
+ ) -> torch.LongTensor:
370
+
371
+ if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
372
+ return model_kwargs.pop("decoder_input_ids")
373
+ else:
374
+ decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
375
+ num_spoken_phrases = (model_kwargs['encoder_outputs'].spoken_idx >= 0).view(-1).sum()
376
+ return torch.ones((num_spoken_phrases, 1), dtype=torch.long, device=self.device) * decoder_start_token_id
377
+
378
+ def prepare_inputs_for_generation(
379
+ self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
380
+ ):
381
+ decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past=past)
382
+ decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
383
+ input_dict = {
384
+ "attention_mask": attention_mask,
385
+ "decoder_attention_mask": decoder_attention_mask,
386
+ "decoder_input_ids": decoder_inputs["input_ids"],
387
+ "encoder_outputs": encoder_outputs,
388
+ "encoder_bias_outputs": kwargs["encoder_bias_outputs"],
389
+ "past_key_values": decoder_inputs["past_key_values"],
390
+ "use_cache": use_cache,
391
+ }
392
+ return input_dict
393
+
394
+ def forward(
395
+ self,
396
+ input_ids=None,
397
+ attention_mask=None,
398
+ decoder_input_ids=None,
399
+ bias_input_ids=None,
400
+ bias_attention_mask=None,
401
+ labels_bias=None,
402
+ decoder_attention_mask=None,
403
+ encoder_outputs=None,
404
+ encoder_bias_outputs=None,
405
+ past_key_values=None,
406
+ inputs_embeds=None,
407
+ decoder_inputs_embeds=None,
408
+ labels=None,
409
+ use_cache=None,
410
+ spoken_label=None,
411
+ word_src_lengths=None,
412
+ word_tgt_lengths=None,
413
+ spoken_idx=None,
414
+ output_attentions=None,
415
+ output_hidden_states=None,
416
+ return_dict=None,
417
+ inputs_length=None,
418
+ outputs=None,
419
+ outputs_length=None,
420
+ text=None,
421
+ **kwargs,
422
+ ):
423
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
424
+
425
+ kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}
426
+
427
+ kwargs_decoder = {
428
+ argument[len("decoder_"):]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
429
+ }
430
+ spoken_tagging_output = None
431
+ if encoder_outputs is None:
432
+ encoder_outputs = self.encoder(
433
+ input_ids=input_ids,
434
+ attention_mask=attention_mask,
435
+ inputs_embeds=inputs_embeds,
436
+ output_attentions=output_attentions,
437
+ output_hidden_states=output_hidden_states,
438
+ return_dict=return_dict,
439
+ **kwargs_encoder,
440
+ )
441
+ spoken_tagging_output = self.spoken_tagging_classifier(self.dropout(encoder_outputs[0]))
442
+ # else:
443
+ # word_src_lengths = encoder_outputs.word_src_lengths
444
+ # spoken_tagging_output = encoder_outputs.spoken_tagging_output
445
+
446
+ if encoder_bias_outputs is None:
447
+ encoder_bias_outputs = self.encoder(
448
+ input_ids=bias_input_ids,
449
+ attention_mask=bias_attention_mask,
450
+ inputs_embeds=inputs_embeds,
451
+ output_attentions=output_attentions,
452
+ output_hidden_states=output_hidden_states,
453
+ return_dict=return_dict,
454
+ **kwargs_encoder,
455
+ )
456
+ encoder_bias_outputs.bias_attention_mask = bias_attention_mask
457
+
458
+ encoder_hidden_states = encoder_outputs[0]
459
+
460
+ # if spoken_idx is None:
461
+ # # extract spoken_idx from spoken_tagging_output
462
+ # spoken_idx = None
463
+
464
+ # encoder_hidden_states, attention_mask = collect_spoken_phrases_features(encoder_hidden_states,
465
+ # word_src_lengths,
466
+ # spoken_idx)
467
+ # if labels is not None:
468
+ # decoder_input_ids, labels, labels_bias = collect_spoken_phrases_labels(decoder_input_ids,
469
+ # labels, labels_bias,
470
+ # word_tgt_lengths,
471
+ # spoken_idx)
472
+
473
+ if spoken_idx is not None:
474
+ encoder_hidden_states, attention_mask = collect_spoken_phrases_features(encoder_hidden_states,
475
+ word_src_lengths,
476
+ spoken_idx)
477
+
478
+ decoder_input_ids, labels, labels_bias = collect_spoken_phrases_labels(decoder_input_ids,
479
+ labels, labels_bias,
480
+ word_tgt_lengths,
481
+ spoken_idx)
482
+
483
+
484
+ # optionally project encoder_hidden_states
485
+ if (
486
+ self.encoder.config.hidden_size != self.decoder.config.hidden_size
487
+ and self.decoder.config.cross_attention_hidden_size is None
488
+ ):
489
+ encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
490
+
491
+ # Decode
492
+ decoder_outputs = self.decoder(
493
+ input_ids=decoder_input_ids,
494
+ attention_mask=decoder_attention_mask,
495
+ encoder_hidden_states=encoder_hidden_states,
496
+ encoder_bias_pooling=encoder_bias_outputs.pooler_output,
497
+ # encoder_bias_hidden_states=encoder_bias_outputs[0],
498
+ encoder_bias_hidden_states=encoder_bias_outputs.last_hidden_state,
499
+ bias_attention_mask=encoder_bias_outputs.bias_attention_mask,
500
+ encoder_attention_mask=attention_mask,
501
+ inputs_embeds=decoder_inputs_embeds,
502
+ output_attentions=output_attentions,
503
+ output_hidden_states=output_hidden_states,
504
+ use_cache=use_cache,
505
+ past_key_values=past_key_values,
506
+ return_dict=return_dict,
507
+ labels_bias=labels_bias,
508
+ **kwargs_decoder,
509
+ )
510
+
511
+ # Compute loss independent from decoder (as some shift the logits inside them)
512
+ loss = None
513
+ if labels is not None:
514
+ logits = decoder_outputs.logits if return_dict else decoder_outputs[1]
515
+ loss_fct = CrossEntropyLoss()
516
+ loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.view(-1))
517
+ loss = loss + decoder_outputs.loss
518
+
519
+ if spoken_label is not None:
520
+ loss_fct = CrossEntropyLoss()
521
+ spoken_tagging_loss = loss_fct(spoken_tagging_output.reshape(-1, 3), spoken_label.view(-1))
522
+ loss = loss + spoken_tagging_loss
523
+
524
+ if not return_dict:
525
+ if loss is not None:
526
+ return (loss,) + decoder_outputs + encoder_outputs
527
+ else:
528
+ return decoder_outputs + encoder_outputs
529
+
530
+ return SpokenNormOutput(
531
+ loss=loss,
532
+ logits=decoder_outputs.logits,
533
+ logits_spoken_tagging=spoken_tagging_output,
534
+ past_key_values=decoder_outputs.past_key_values,
535
+ decoder_hidden_states=decoder_outputs.hidden_states,
536
+ decoder_attentions=decoder_outputs.attentions,
537
+ cross_attentions=decoder_outputs.cross_attentions,
538
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
539
+ encoder_hidden_states=encoder_outputs.hidden_states,
540
+ encoder_attentions=encoder_outputs.attentions,
541
+ )
542
+
543
+
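+ # Decoder: a RobertaForCausalLM extended with a copy-attention head over the encoder
+ # states and a bias-attention head that first ranks the bias entries (plus a learned
+ # "no entry" slot) and then attends over the hidden states of the selected entry.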
544
+ class DecoderSpokenNorm(RobertaForCausalLM):
545
+ config_class = DecoderSpokenNormConfig
546
+
547
+ # Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->Roberta
548
+ def __init__(self, config):
549
+ super().__init__(config)
550
+ self.dense_query_copy = torch.nn.Linear(config.hidden_size, config.hidden_size)
551
+ self.mem_no_entry = Parameter(torch.randn(config.hidden_size).unsqueeze(0))
552
+ self.bias_attention_layer = MultiHeadAttention(config.hidden_size)
553
+ self.copy_attention_layer = MultiHeadAttention(config.hidden_size)
554
+
555
+ def forward_bias_attention(self, query, values, values_mask):
556
+ """
557
+ :param query: batch * output_steps * hidden_state
558
+ :param values: batch * output_steps * max_bias_steps * hidden_state
559
+ :param values_mask: batch * output_steps * max_bias_steps
560
+ :return: batch * output_steps * hidden_state
561
+ """
562
+ batch, output_steps, hidden_state = query.size()
563
+ _, _, max_bias_steps, _ = values.size()
564
+
565
+ query = query.view(batch * output_steps, 1, hidden_state)
566
+ values = values.view(-1, max_bias_steps, hidden_state)
567
+ values_mask = 1 - values_mask.view(-1, max_bias_steps)
568
+ result_attention, attention_score = self.bias_attention_layer(query=query,
569
+ key=values,
570
+ value=values,
571
+ mask=values_mask.bool())
572
+ result_attention = result_attention.squeeze(1).view(batch, output_steps, hidden_state)
573
+ return result_attention
574
+
575
+ def forward_copy_attention(self, query, values, values_mask):
576
+ """
577
+ :param query: batch * output_steps * hidden_state
578
+ :param values: batch * max_encoder_steps * hidden_state
579
+ :param values_mask: batch * output_steps * max_encoder_steps
580
+ :return: batch * output_steps * hidden_state
581
+ """
582
+ dot_attn_score = torch.bmm(query, values.transpose(2, 1))
583
+ attn_mask = (1 - values_mask.clone().unsqueeze(1)).bool()
584
+ dot_attn_score.masked_fill_(attn_mask, -float('inf'))
585
+ dot_attn_score = torch.softmax(dot_attn_score, dim=-1)
586
+ result_attention = torch.bmm(dot_attn_score, values)
587
+ return result_attention
588
+
589
+ def forward(
590
+ self,
591
+ input_ids=None,
592
+ attention_mask=None,
593
+ token_type_ids=None,
594
+ position_ids=None,
595
+ head_mask=None,
596
+ encoder_bias_pooling=None,
597
+ encoder_bias_hidden_states=None,
598
+ bias_attention_mask=None,
599
+ inputs_embeds=None,
600
+ encoder_hidden_states=None,
601
+ encoder_attention_mask=None,
602
+ labels=None,
603
+ labels_bias=None,
604
+ past_key_values=None,
605
+ use_cache=None,
606
+ output_attentions=None,
607
+ output_hidden_states=None,
608
+ return_dict=None,
609
+ ):
610
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
611
+ if labels is not None:
612
+ use_cache = False
613
+
614
+ # attention with input encoded
615
+ outputs = self.roberta(
616
+ input_ids,
617
+ attention_mask=attention_mask,
618
+ token_type_ids=token_type_ids,
619
+ position_ids=position_ids,
620
+ head_mask=head_mask,
621
+ inputs_embeds=inputs_embeds,
622
+ encoder_hidden_states=encoder_hidden_states,
623
+ encoder_attention_mask=encoder_attention_mask,
624
+ past_key_values=past_key_values,
625
+ use_cache=use_cache,
626
+ output_attentions=output_attentions,
627
+ output_hidden_states=output_hidden_states,
628
+ return_dict=return_dict,
629
+ )
630
+
631
+ # Query for bias
632
+ sequence_output = outputs[0]
633
+ bias_indicate_output = None
634
+
635
+ # output copy attention
636
+ query_copy = torch.relu(self.dense_query_copy(sequence_output))
637
+ sequence_atten_copy_output = self.forward_copy_attention(query_copy,
638
+ encoder_hidden_states,
639
+ encoder_attention_mask)
640
+
641
+ if encoder_bias_pooling is not None:
642
+
643
+ # Make bias features
644
+ encoder_bias_pooling = torch.cat([self.mem_no_entry, encoder_bias_pooling], dim=0)
645
+ mem_no_entry_feature = torch.zeros_like(encoder_bias_hidden_states[0]).unsqueeze(0)
646
+ mem_no_entry_mask = torch.ones_like(bias_attention_mask[0]).unsqueeze(0)
647
+ encoder_bias_hidden_states = torch.cat([mem_no_entry_feature, encoder_bias_hidden_states], dim=0)
648
+ bias_attention_mask = torch.cat([mem_no_entry_mask, bias_attention_mask], dim=0)
649
+
650
+ # Compute ranking score
651
+ b, s, h = sequence_output.size()
652
+ bias_ranking_score = sequence_output.view(b * s, h).mm(encoder_bias_pooling.T)
653
+ bias_ranking_score = bias_ranking_score.view(b, s, encoder_bias_pooling.size(0))
654
+
655
+ # teacher force with bias label
656
+ if not self.training:
657
+ bias_indicate_output = torch.argmax(bias_ranking_score, dim=-1)
658
+ else:
659
+ if random.random() < 0.5:
660
+ bias_indicate_output = labels_bias.clone()
661
+ bias_indicate_output[torch.where(bias_indicate_output < 0)] = 0
662
+ else:
663
+ bias_indicate_output = torch.argmax(bias_ranking_score, dim=-1)
664
+
665
+ # Bias encoder hidden state
666
+ _, max_len, _ = encoder_bias_hidden_states.size()
667
+ bias_encoder_hidden_states = torch.index_select(input=encoder_bias_hidden_states,
668
+ dim=0,
669
+ index=bias_indicate_output.view(b * s)).view(b, s, max_len,
670
+ h)
671
+ bias_encoder_attention_mask = torch.index_select(input=bias_attention_mask,
672
+ dim=0,
673
+ index=bias_indicate_output.view(b * s)).view(b, s, max_len)
674
+
675
+ sequence_atten_bias_output = self.forward_bias_attention(sequence_output,
676
+ bias_encoder_hidden_states,
677
+ bias_encoder_attention_mask)
678
+
679
+ # Find output words
680
+ prediction_scores = self.lm_head(sequence_output + sequence_atten_bias_output + sequence_atten_copy_output)
681
+ else:
682
+ prediction_scores = self.lm_head(sequence_output + sequence_atten_copy_output)
683
+
684
+ # run attention with bias
685
+
686
+ bias_ranking_loss = None
687
+ if labels_bias is not None:
688
+ loss_fct = CrossEntropyLoss()
689
+ bias_ranking_loss = loss_fct(bias_ranking_score.view(-1, encoder_bias_pooling.size(0)),
690
+ labels_bias.view(-1))
691
+
692
+ if not return_dict:
693
+ output = (prediction_scores,) + outputs[2:]
694
+ return ((bias_ranking_loss,) + output) if bias_ranking_loss is not None else output
695
+
696
+ result = CausalLMOutputWithCrossAttentions(
697
+ loss=bias_ranking_loss,
698
+ logits=prediction_scores,
699
+ past_key_values=outputs.past_key_values,
700
+ hidden_states=outputs.hidden_states,
701
+ attentions=outputs.attentions,
702
+ cross_attentions=outputs.cross_attentions,
703
+ )
704
+
705
+ result.bias_indicate_output = bias_indicate_output
706
+
707
+ return result
708
+
709
+
710
+ def download_tokenizer_files():
711
+ resources = ['envibert_tokenizer.py', 'dict.txt', 'sentencepiece.bpe.model']
712
+ for item in resources:
713
+ if not os.path.exists(os.path.join(cache_dir, item)):
714
+ tmp_file = hf_bucket_url(model_name, filename=item)
715
+ tmp_file = cached_path(tmp_file, cache_dir=cache_dir)
716
+ os.rename(tmp_file, os.path.join(cache_dir, item))
717
+
718
+
719
+ def init_tokenizer():
720
+ download_tokenizer_files()
721
+ tokenizer = SourceFileLoader("envibert.tokenizer",
722
+ os.path.join(cache_dir,
723
+ 'envibert_tokenizer.py')).load_module().RobertaTokenizer(cache_dir)
724
+ tokenizer.model_input_names = ["input_ids",
725
+ "attention_mask",
726
+ "bias_input_ids",
727
+ "bias_attention_mask",
728
+ "labels"
729
+ "labels_bias"]
730
+ return tokenizer
731
+
732
+
733
+ def init_model():
734
+ download_tokenizer_files()
735
+ tokenizer = SourceFileLoader("envibert.tokenizer",
736
+ os.path.join(cache_dir,
737
+ 'envibert_tokenizer.py')).load_module().RobertaTokenizer(cache_dir)
738
+ tokenizer.model_input_names = ["input_ids",
739
+ "attention_mask",
740
+ "bias_input_ids",
741
+ "bias_attention_mask",
742
+ "labels"
743
+ "labels_bias"]
744
+ # encoder and decoder weights are not tied here (tie_encoder_decoder=False)
745
+ roberta_shared = EncoderDecoderSpokenNorm.from_encoder_decoder_pretrained(model_name,
746
+ model_name,
747
+ tie_encoder_decoder=False)
748
+
749
+ # set special tokens
750
+ roberta_shared.config.decoder_start_token_id = tokenizer.bos_token_id
751
+ roberta_shared.config.eos_token_id = tokenizer.eos_token_id
752
+ roberta_shared.config.pad_token_id = tokenizer.pad_token_id
753
+
754
+ # sensible parameters for beam search
755
+ # set decoding params
756
+ roberta_shared.config.max_length = 50
757
+ roberta_shared.config.early_stopping = True
758
+ roberta_shared.config.no_repeat_ngram_size = 3
759
+ roberta_shared.config.length_penalty = 2.0
760
+ roberta_shared.config.num_beams = 1
761
+ roberta_shared.config.vocab_size = roberta_shared.config.encoder.vocab_size
762
+
763
+ return roberta_shared, tokenizer
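+
+ # Example (a minimal sketch): init_model builds the encoder-decoder from the pretrained
+ # envibert checkpoints only; any fine-tuned spoken-norm weights would need to be loaded
+ # separately.
+ #   model, tokenizer = init_model()
+ #   model.eval()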
requirements.txt ADDED
@@ -0,0 +1,8 @@
1
+ torch==1.10.0
2
+ sentencepiece==0.1.91
3
+ transformers==4.16.2
4
+ datasets==1.17.0
5
+ regtag
6
+ validators
7
+ jiwer
8
+ gradio
utils.py ADDED
@@ -0,0 +1,271 @@
1
+ import difflib
2
+ import regtag
3
+ import random
4
+
5
+
6
+ def merge_span(words, tags):
7
+ spans, span_tags = [], []
8
+ current_tag = 'O'
9
+ span = []
10
+ for w, t in zip(words, tags):
11
+ w = w.strip(":-")
12
+ if len(w) == 0:
13
+ continue
14
+ t_info = t.split('-')
15
+ if t_info[-1] != current_tag or t_info[0] == 'B':
16
+ if len(span) > 0:
17
+ spans.append(' '.join(span))
18
+ span_tags.append(current_tag)
19
+ span = [w]
20
+ current_tag = t_info[-1]
21
+ else:
22
+ span.append(w)
23
+ if len(span) > 0:
24
+ spans.append(' '.join(span))
25
+ span_tags.append(current_tag)
26
+ return spans, span_tags
27
+
28
+
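+ # Training-data helper: tag a written-form text with regtag, then replace the tagged
+ # spans (and, rarely, random spans) with spoken-form variants, returning parallel
+ # source (spoken) and target (written) token lists.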
29
+ def make_spoken(text, do_split=True):
30
+ src, tgt = [], []
31
+ if do_split:
32
+ chunk_size = random.choice(list(range(0, 10)) + list(range(10, 35)) * 4)
33
+ if chunk_size > 0:
34
+ text = random.choice(split_chunk_input(text, chunk_size))
35
+ else:
36
+ text = ''
37
+ words, word_tags = merge_span(*regtag.tagging(text))
38
+ for span, t in zip(words, word_tags):
39
+ if t == 'O':
40
+ for w in span.split():
41
+ w = w.strip('/.,?!').lower()
42
+ if len(w) > 0:
43
+ src.append(w)
44
+ tgt.append(w)
45
+ if random.random() < 0.01:
46
+ random_value = regtag.augment.get_random_span()
47
+ tgt.append(random_value[0])
48
+ src.append(random_value[1].lower())
49
+ else:
50
+ random_value = regtag.augment.get_random_span(t, span.lower())
51
+ tgt.append(random_value[0])
52
+ src.append(random_value[1].lower())
53
+
54
+ if len(src) == 0:
55
+ tgt, src = regtag.get_random_span()
56
+ src = [src]
57
+ tgt = [tgt]
58
+
59
+ return src, tgt
60
+
61
+
62
+ def split_chunk_input(raw_text, chunk_size):
63
+ input_words = raw_text.strip().split()
64
+ clean_data = [input_words[i:i + chunk_size] for i in range(0, len(input_words), chunk_size)]
65
+ if len(clean_data) > 1:
66
+ clean_data = [" ".join(clean_data[i] + clean_data[i + 1]) for i in range(len(clean_data) - 1)]
67
+ else:
68
+ clean_data = [" ".join(clean_data[0])]
69
+ return clean_data
70
+
71
+
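+ # NOTE: this overlapping-window version redefines (and therefore shadows) the simpler
+ # split_chunk_input above; infer.py uses this variant with chunk_size=60, overlap=20.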
72
+ def split_chunk_input(raw_text, chunk_size=40, overlap=10):
73
+ input_words = raw_text.strip().split()
74
+ part_per_chunk = chunk_size // overlap
75
+ clean_data = [input_words[i:i + overlap] for i in range(0, len(input_words), overlap)]
76
+ if len(clean_data) > 1:
77
+ merge_data = []
78
+ for i in range(0, len(clean_data) - 1, part_per_chunk - 1):
79
+ merge_data.append(' '.join([y for x in clean_data[i:i + part_per_chunk] for y in x]))
80
+ else:
81
+ merge_data = [" ".join(clean_data[0])]
82
+ return merge_data
83
+
84
+
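+ # Merge two overlapping pre-normalized chunks: split each chunk into a "remain" part and
+ # an overlap ("compete") region, then reconcile the two overlap regions phrase by phrase,
+ # preferring the left chunk in the first half of the overlap and the right chunk in the
+ # second half.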
85
+ def merge_two_chunk(chunk_1, chunk_2, overlap, debug=False):
86
+ def extract_phrase_word(phrase):
87
+ if phrase.startswith('<mask>'):
88
+ return phrase[7:].split('](')[1][:-1].split()
89
+ else:
90
+ return [phrase]
91
+
92
+ def has_tag(phrase):
93
+ if phrase.startswith('<') and phrase.endswith(')'):
94
+ return True
95
+ return False
96
+
97
+ def extract_compete_region(list_phrases, is_head):
98
+ if is_head:
99
+ list_phrases = list_phrases[::-1]
100
+ compete = []
101
+ remain = []
102
+ handle_count = 0
103
+ for phrase in list_phrases:
104
+ phrase_word = extract_phrase_word(phrase)
105
+ if len(phrase_word) + handle_count <= overlap:
106
+ compete.append(phrase)
107
+ handle_count += len(phrase_word)
108
+ else:
109
+ if handle_count < overlap:
110
+ remain_compete_count = overlap - handle_count
111
+ remain.append(phrase)
112
+ if not is_head:
113
+ compete.extend(["<delete>({})".format(item) for item in phrase_word[:remain_compete_count]])
114
+ else:
115
+ compete.extend(
116
+ ["<delete>({})".format(item) for item in phrase_word[::-1][:remain_compete_count]])
117
+ handle_count = overlap
118
+ else:
119
+ remain.append(phrase)
120
+ if is_head:
121
+ compete = compete[::-1]
122
+ remain = remain[::-1]
123
+ return remain, compete
124
+
125
+ def is_equal(phrase_1, phrase_2):
126
+ if phrase_1 == phrase_2:
127
+ return True
128
+ if extract_phrase_word(phrase_1) == extract_phrase_word(phrase_2):
129
+ if phrase_1.startswith('<mask>') and phrase_2.startswith('<mask>'):
130
+ return True
131
+ return False
132
+
133
+ def merge_compete(list_1, list_2):
134
+ idx_list_1, idx_list_2, combine_phrases = [], [], []
135
+ mark_term_complete = []
136
+ list_raw = [extract_phrase_word(item) for item in list_1]
137
+ list_raw = [y for x in list_raw for y in x]
138
+ for idx, phrase in enumerate(list_1):
139
+ idx_list_1.extend([idx] * len(extract_phrase_word(phrase)))
140
+ for idx, phrase in enumerate(list_2):
141
+ idx_list_2.extend([idx] * len(extract_phrase_word(phrase)))
142
+ # print(idx_list_1, idx_list_2)
143
+ for idx, (idx_1, idx_2) in enumerate(zip(idx_list_1, idx_list_2)):
144
+ if list_1[idx_1].startswith('<delete>') or list_2[idx_2].startswith('<delete>'):
145
+ continue
146
+ elif is_equal(list_1[idx_1], list_2[idx_2]):
147
+ # print(list_1[idx_1])
148
+ if '1_{}'.format(idx_1) not in mark_term_complete and '2_{}'.format(idx_2) not in mark_term_complete:
149
+ if idx <= overlap//2:
150
+ combine_phrases.append(list_1[idx_1])
151
+ mark_term_complete.append('1_{}'.format(idx_1))
152
+ else:
153
+ combine_phrases.append(list_2[idx_2])
154
+ mark_term_complete.append('2_{}'.format(idx_2))
155
+ else:
156
+ combine_phrases.append(list_raw[idx])
157
+ mark_term_complete.extend(['1_{}'.format(idx_1), '2_{}'.format(idx_2)])
158
+ # print(mark_term_complete)
159
+ return combine_phrases
160
+
161
+ remain_1, compete_1 = extract_compete_region(chunk_1, is_head=True)
162
+ remain_2, compete_2 = extract_compete_region(chunk_2[1:-1], is_head=False)
163
+ compromise = merge_compete(compete_1, compete_2)
164
+
165
+ if debug:
166
+ print(remain_1, '\n', compete_1)
167
+ print('-----------------------')
168
+ print(compete_2, '\n', remain_2)
169
+ print('-----------------------')
170
+ print(compromise, '\n\n')
171
+
172
+ return remain_1 + compromise + remain_2
173
+
174
+
175
+ def merge_chunk_pre_norm(list_chunks, overlap, debug=False):
176
+ if len(list_chunks) == 0:
177
+ return []
178
+ if len(list_chunks) == 1:
179
+ return list_chunks[0][1:-1]
180
+ current_chunk = list_chunks[0][1:-1]
181
+ for tmp_chunk in list_chunks[1:]:
182
+ current_chunk = merge_two_chunk(current_chunk, tmp_chunk, overlap, debug=debug)
183
+ return current_chunk
184
+
185
+
186
+ def equalize(s1, s2):
187
+ l1 = s1.split()
188
+ l2 = s2.split()
189
+ res1 = []
190
+ res2 = []
191
+ combine = []
192
+ prev = difflib.Match(0, 0, 0)
193
+ for match in difflib.SequenceMatcher(a=l1, b=l2).get_matching_blocks():
194
+ if prev.a + prev.size != match.a:
195
+ for i in range(prev.a + prev.size, match.a):
196
+ res2 += ['_' * len(l1[i])]
197
+ res1 += l1[prev.a + prev.size:match.a]
198
+
199
+ for i in l1[prev.a + prev.size:match.a]:
200
+ if len(combine) < len(l1) // 2:
201
+ print(l1[prev.a + prev.size:match.a])
202
+ combine.append(i)
203
+ if prev.b + prev.size != match.b:
204
+ for i in range(prev.b + prev.size, match.b):
205
+ res1 += ['_' * len(l2[i])]
206
+ res2 += l2[prev.b + prev.size:match.b]
207
+
208
+ for i in l2[prev.b + prev.size:match.b]:
209
+ if len(combine) >= len(l2) // 2:
210
+ print(l2[prev.b + prev.size:match.b])
211
+ combine.append(i)
212
+ res1 += l1[match.a:match.a + match.size]
213
+ res2 += l2[match.b:match.b + match.size]
214
+ combine += l2[match.b:match.b + match.size]
215
+ prev = match
216
+ return ' '.join(res1), ' '.join(res2), combine
217
+
218
+
219
+ def count_overlap(words_1, words_2):
220
+ # print(words_1, words_2)
221
+ assert len(words_1) == len(words_2)
222
+ len_overlap = 0
223
+ for match in difflib.SequenceMatcher(a=words_1, b=words_2).get_matching_blocks():
224
+ len_overlap += match.size
225
+
226
+ # for w1, w2 in zip(words_1, words_2):
227
+ # if w1 == w2:
228
+ # len_overlap += 1
229
+ return len_overlap
230
+
231
+
232
+ def find_overlap_chunk(txt_1, txt_2):
233
+ # print(txt_1)
234
+ # print(txt_2)
235
+ window_view = 1
236
+ idx_1 = len(txt_1) - window_view
237
+ idx_2 = window_view
238
+ over_lap = 0
239
+ current_best_idx_1 = len(txt_1)
240
+ current_best_idx_2 = 0
241
+
242
+ while window_view <= len(txt_1) and window_view <= len(txt_2):
243
+ current_overlap = count_overlap(txt_1[idx_1:], txt_2[:idx_2])
244
+ print(current_overlap)
245
+ if over_lap < current_overlap:
246
+ over_lap = current_overlap
247
+ current_best_idx_1 = idx_1
248
+ current_best_idx_2 = idx_2
249
+ window_view += 1
250
+ idx_1 = len(txt_1) - window_view
251
+ idx_2 = window_view
252
+ # else:
253
+ # break
254
+ print('----->', txt_1[current_best_idx_1:], txt_2[:current_best_idx_2])
255
+ return txt_1[current_best_idx_1:], txt_2[:current_best_idx_2]
256
+
257
+
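+ # Concatenate already-normalized text chunks: use difflib to find the best-matching
+ # overlap between consecutive chunks and keep a single merged copy of that region.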
258
+ def concat_chunks(list_chunks):
259
+ concat_string = list_chunks[0].split()
260
+ for i in range(1, len(list_chunks)):
261
+ remain_string = list_chunks[i].split()
262
+ s1, s2 = find_overlap_chunk(concat_string, remain_string)
263
+ s1 = ' '.join(s1)
264
+ s2 = ' '.join(s2)
265
+ _, _, overlap_merged = equalize(s1, s2)
266
+ merge_len = len(s1.split())
267
+
268
+ concat_string = concat_string[:len(concat_string) - merge_len] + overlap_merged + remain_string[merge_len:]
269
+
270
+ concat_string = ' '.join(concat_string)
271
+ return concat_string