TheComputerMan committed on
Commit 9944ebb
1 Parent(s): e1ce12a

Upload utils.py

Files changed (1)
  1. utils.py (+320, -0)
utils.py ADDED
@@ -0,0 +1,320 @@
"""
Taken from ESPNet, modified by Florian Lux
"""

import os
from abc import ABC

import torch


def cumsum_durations(durations):
    # cumulative duration boundaries (starting at 0) and the center of each segment
    out = [0]
    for duration in durations:
        out.append(duration + out[-1])
    centers = list()
    for index, _ in enumerate(out):
        if index + 1 < len(out):
            centers.append((out[index] + out[index + 1]) / 2)
    return out, centers
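
# Illustrative behaviour (a hedged sketch, not part of the original upload): for
# durations [2, 3, 1], cumsum_durations returns the boundaries [0, 2, 5, 6] and
# the segment centers [1.0, 3.5, 5.5].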


def delete_old_checkpoints(checkpoint_dir, keep=5):
    checkpoint_list = list()
    for el in os.listdir(checkpoint_dir):
        if el.endswith(".pt") and el != "best.pt":
            checkpoint_list.append(int(el.split(".")[0].split("_")[1]))
    if len(checkpoint_list) <= keep:
        return
    else:
        checkpoint_list.sort(reverse=False)
        checkpoints_to_delete = [os.path.join(checkpoint_dir, "checkpoint_{}.pt".format(step)) for step in checkpoint_list[:-keep]]
        for old_checkpoint in checkpoints_to_delete:
            os.remove(old_checkpoint)


def get_most_recent_checkpoint(checkpoint_dir, verbose=True):
    checkpoint_list = list()
    for el in os.listdir(checkpoint_dir):
        if el.endswith(".pt") and el != "best.pt":
            checkpoint_list.append(int(el.split(".")[0].split("_")[1]))
    if len(checkpoint_list) == 0:
        print("No previous checkpoints found, cannot reload.")
        return None
    checkpoint_list.sort(reverse=True)
    if verbose:
        print("Reloading checkpoint_{}.pt".format(checkpoint_list[0]))
    return os.path.join(checkpoint_dir, "checkpoint_{}.pt".format(checkpoint_list[0]))
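
# Illustrative usage of the two checkpoint helpers above (a hedged sketch; the
# directory name and the surrounding training loop are assumptions, not part of
# this file):
#
#     latest = get_most_recent_checkpoint("Models/MyTTS")
#     if latest is not None:
#         check_dict = torch.load(latest, map_location="cpu")
#     ...
#     delete_old_checkpoints("Models/MyTTS", keep=5)  # prune all but the 5 newest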


def make_pad_mask(lengths, xs=None, length_dim=-1, device=None):
    """
    Make mask tensor containing indices of padded part.

    Args:
        lengths (LongTensor or List): Batch of lengths (B,).
        xs (Tensor, optional): The reference tensor.
            If set, masks will be the same shape as this tensor.
        length_dim (int, optional): Dimension indicator of the above tensor.
            See the example.

    Returns:
        Tensor: Mask tensor containing indices of padded part.
            dtype=torch.uint8 in PyTorch 1.2-
            dtype=torch.bool in PyTorch 1.2+ (including 1.2)

    """
    if length_dim == 0:
        raise ValueError("length_dim cannot be 0: {}".format(length_dim))

    if not isinstance(lengths, list):
        lengths = lengths.tolist()
    bs = int(len(lengths))
    if xs is None:
        maxlen = int(max(lengths))
    else:
        maxlen = xs.size(length_dim)

    if device is not None:
        seq_range = torch.arange(0, maxlen, dtype=torch.int64, device=device)
    else:
        seq_range = torch.arange(0, maxlen, dtype=torch.int64)
    seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen)
    seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1)
    mask = seq_range_expand >= seq_length_expand

    if xs is not None:
        assert xs.size(0) == bs, (xs.size(0), bs)

        if length_dim < 0:
            length_dim = xs.dim() + length_dim
        # ind = (:, None, ..., None, :, None, ..., None)
        ind = tuple(slice(None) if i in (0, length_dim) else None for i in range(xs.dim()))
        mask = mask[ind].expand_as(xs).to(xs.device)
    return mask
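
# Illustrative behaviour (a hedged sketch, not part of the original upload):
# make_pad_mask([5, 3, 2]) returns a (3, 5) bool tensor in which True marks the
# padded positions, e.g. the second row is [False, False, False, True, True].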


def make_non_pad_mask(lengths, xs=None, length_dim=-1, device=None):
    """
    Make mask tensor containing indices of non-padded part.

    Args:
        lengths (LongTensor or List): Batch of lengths (B,).
        xs (Tensor, optional): The reference tensor.
            If set, masks will be the same shape as this tensor.
        length_dim (int, optional): Dimension indicator of the above tensor.
            See the example.

    Returns:
        ByteTensor: mask tensor containing indices of non-padded part.
            dtype=torch.uint8 in PyTorch 1.2-
            dtype=torch.bool in PyTorch 1.2+ (including 1.2)

    """
    return ~make_pad_mask(lengths, xs, length_dim, device=device)
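
# Illustrative behaviour (hedged sketch): make_non_pad_mask([5, 3, 2]) is the
# logical inverse of make_pad_mask([5, 3, 2]), so True marks the valid frames,
# e.g. the second row is [True, True, True, False, False].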
115
+
116
+
117
+ def initialize(model, init):
118
+ """
119
+ Initialize weights of a neural network module.
120
+
121
+ Parameters are initialized using the given method or distribution.
122
+
123
+ Args:
124
+ model: Target.
125
+ init: Method of initialization.
126
+ """
127
+
128
+ # weight init
129
+ for p in model.parameters():
130
+ if p.dim() > 1:
131
+ if init == "xavier_uniform":
132
+ torch.nn.init.xavier_uniform_(p.data)
133
+ elif init == "xavier_normal":
134
+ torch.nn.init.xavier_normal_(p.data)
135
+ elif init == "kaiming_uniform":
136
+ torch.nn.init.kaiming_uniform_(p.data, nonlinearity="relu")
137
+ elif init == "kaiming_normal":
138
+ torch.nn.init.kaiming_normal_(p.data, nonlinearity="relu")
139
+ else:
140
+ raise ValueError("Unknown initialization: " + init)
141
+ # bias init
142
+ for p in model.parameters():
143
+ if p.dim() == 1:
144
+ p.data.zero_()
145
+
146
+ # reset some modules with default init
147
+ for m in model.modules():
148
+ if isinstance(m, (torch.nn.Embedding, torch.nn.LayerNorm)):
149
+ m.reset_parameters()
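
# Illustrative usage (a hedged sketch; the small model below is an assumption,
# not part of this file):
#
#     net = torch.nn.Sequential(torch.nn.Linear(80, 256), torch.nn.ReLU(), torch.nn.Linear(256, 80))
#     initialize(net, "xavier_uniform")   # weight matrices re-drawn, bias vectors zeroed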


def pad_list(xs, pad_value):
    """
    Perform padding for the list of tensors.

    Args:
        xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
        pad_value (float): Value for padding.

    Returns:
        Tensor: Padded tensor (B, Tmax, `*`).

    """
    n_batch = len(xs)
    max_len = max(x.size(0) for x in xs)
    pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)

    for i in range(n_batch):
        pad[i, : xs[i].size(0)] = xs[i]

    return pad
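
# Illustrative behaviour (hedged sketch): padding three 1-D tensors of lengths
# 4, 2 and 1 with zeros yields a (3, 4) tensor, e.g.
#
#     pad_list([torch.ones(4), torch.ones(2), torch.ones(1)], 0.0)
#     # -> rows [1, 1, 1, 1], [1, 1, 0, 0], [1, 0, 0, 0]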


def subsequent_mask(size, device="cpu", dtype=torch.bool):
    """
    Create mask for subsequent steps (size, size).

    :param int size: size of mask
    :param str device: "cpu" or "cuda" or torch.Tensor.device
    :param torch.dtype dtype: result dtype
    :rtype: torch.Tensor
    """
    ret = torch.ones(size, size, device=device, dtype=dtype)
    return torch.tril(ret, out=ret)
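
# Illustrative behaviour (hedged sketch): subsequent_mask(3) is the lower
# triangular matrix
#     [[ True, False, False],
#      [ True,  True, False],
#      [ True,  True,  True]]
# i.e. position i may only attend to positions <= i during decoding.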
185
+
186
+
187
+ class ScorerInterface:
188
+ """
189
+ Scorer interface for beam search.
190
+
191
+ The scorer performs scoring of the all tokens in vocabulary.
192
+
193
+ Examples:
194
+ * Search heuristics
195
+ * :class:`espnet.nets.scorers.length_bonus.LengthBonus`
196
+ * Decoder networks of the sequence-to-sequence models
197
+ * :class:`espnet.nets.pytorch_backend.nets.transformer.decoder.Decoder`
198
+ * :class:`espnet.nets.pytorch_backend.nets.rnn.decoders.Decoder`
199
+ * Neural language models
200
+ * :class:`espnet.nets.pytorch_backend.lm.transformer.TransformerLM`
201
+ * :class:`espnet.nets.pytorch_backend.lm.default.DefaultRNNLM`
202
+ * :class:`espnet.nets.pytorch_backend.lm.seq_rnn.SequentialRNNLM`
203
+
204
+ """
205
+
206
+ def init_state(self, x):
207
+ """
208
+ Get an initial state for decoding (optional).
209
+
210
+ Args:
211
+ x (torch.Tensor): The encoded feature tensor
212
+
213
+ Returns: initial state
214
+
215
+ """
216
+ return None
217
+
218
+ def select_state(self, state, i, new_id=None):
219
+ """
220
+ Select state with relative ids in the main beam search.
221
+
222
+ Args:
223
+ state: Decoder state for prefix tokens
224
+ i (int): Index to select a state in the main beam search
225
+ new_id (int): New label index to select a state if necessary
226
+
227
+ Returns:
228
+ state: pruned state
229
+
230
+ """
231
+ return None if state is None else state[i]
232
+
233
+ def score(self, y, state, x):
234
+ """
235
+ Score new token (required).
236
+
237
+ Args:
238
+ y (torch.Tensor): 1D torch.int64 prefix tokens.
239
+ state: Scorer state for prefix tokens
240
+ x (torch.Tensor): The encoder feature that generates ys.
241
+
242
+ Returns:
243
+ tuple[torch.Tensor, Any]: Tuple of
244
+ scores for next token that has a shape of `(n_vocab)`
245
+ and next state for ys
246
+
247
+ """
248
+ raise NotImplementedError
249
+
250
+ def final_score(self, state):
251
+ """
252
+ Score eos (optional).
253
+
254
+ Args:
255
+ state: Scorer state for prefix tokens
256
+
257
+ Returns:
258
+ float: final score
259
+
260
+ """
261
+ return 0.0
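
# A minimal sketch of a concrete scorer (an assumption for illustration, not
# part of the original upload): a uniform length bonus that adds a constant
# score to every token, in the spirit of the LengthBonus heuristic referenced
# in the class docstring above.
#
#     class UniformLengthBonus(ScorerInterface):
#         def __init__(self, n_vocab):
#             self.n_vocab = n_vocab
#
#         def score(self, y, state, x):
#             return torch.ones(self.n_vocab, device=x.device), None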


class BatchScorerInterface(ScorerInterface, ABC):
    """
    Batch scorer interface for beam search; the default implementation simply
    falls back to the per-hypothesis score() of ScorerInterface.
    """

    def batch_init_state(self, x):
        """
        Get an initial state for decoding (optional).

        Args:
            x (torch.Tensor): The encoded feature tensor

        Returns: initial state

        """
        return self.init_state(x)

    def batch_score(self, ys, states, xs):
        """
        Score new token batch (required).

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batchified scores for the next token with shape `(n_batch, n_vocab)`
                and next state list for ys.

        """
        scores = list()
        outstates = list()
        for i, (y, state, x) in enumerate(zip(ys, states, xs)):
            score, outstate = self.score(y, state, x)
            outstates.append(outstate)
            scores.append(score)
        scores = torch.cat(scores, 0).view(ys.shape[0], -1)
        return scores, outstates
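
# Illustrative shape contract (hedged sketch): for a batch of 3 hypotheses and a
# vocabulary of 42 tokens, batch_score(ys, states, xs) with ys of shape (3, ylen)
# returns a score tensor of shape (3, 42) and a list of 3 per-hypothesis states;
# the default above just loops and calls score() once per hypothesis, so
# subclasses may override it with a truly batched implementation.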


def to_device(m, x):
    """
    Send tensor into the device of the module.

    Args:
        m (torch.nn.Module): Torch module.
        x (Tensor): Torch tensor.

    Returns:
        Tensor: Torch tensor located in the same place as torch module.
    """
    if isinstance(m, torch.nn.Module):
        device = next(m.parameters()).device
    elif isinstance(m, torch.Tensor):
        device = m.device
    else:
        raise TypeError(f"Expected torch.nn.Module or torch.Tensor, but got: {type(m)}")
    return x.to(device)
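
# Illustrative usage (a hedged sketch; the small module below is an assumption,
# not part of this file):
#
#     layer = torch.nn.Linear(4, 4).to("cuda" if torch.cuda.is_available() else "cpu")
#     batch = torch.randn(2, 4)
#     batch = to_device(layer, batch)   # batch now lives on the same device as layer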