Jongmo committed on
Commit
a5bbcdb
1 Parent(s): 49664ed

Upload 25 files

utils/__init__.py ADDED
File without changes
utils/__pycache__/bert_model.cpython-39.pyc ADDED
Binary file (30.6 kB)
 
utils/__pycache__/callbacks.cpython-39.pyc ADDED
Binary file (4.9 kB)
 
utils/__pycache__/file_utils.cpython-39.pyc ADDED
Binary file (6.81 kB)
 
utils/__pycache__/finetune.cpython-39.pyc ADDED
Binary file (20.2 kB)
 
utils/__pycache__/lightning_base.cpython-39.pyc ADDED
Binary file (13.5 kB)
 
utils/__pycache__/sentence_retrieval_model.cpython-39.pyc ADDED
Binary file (1.11 kB)
 
utils/__pycache__/sentence_retrieval_module.cpython-39.pyc ADDED
Binary file (2.52 kB)
 
utils/__pycache__/textual_entailment_module.cpython-39.pyc ADDED
Binary file (2.65 kB)
 
utils/__pycache__/utils_graph2text.cpython-39.pyc ADDED
Binary file (3.12 kB)
 
utils/__pycache__/utils_verbalisation_module.cpython-39.pyc ADDED
Binary file (23.9 kB)
 
utils/__pycache__/verbalisation_module.cpython-39.pyc ADDED
Binary file (7.37 kB)
 
utils/__pycache__/wikidata_utils.cpython-39.pyc ADDED
Binary file (5.29 kB)
 
utils/bert_model.py ADDED
@@ -0,0 +1,775 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """PyTorch BERT model."""
17
+
18
+ from __future__ import absolute_import, division, print_function, unicode_literals
19
+
20
+ import copy
21
+ import json
22
+ import logging
23
+ import math
24
+ import os
25
+ import shutil
26
+ import tarfile
27
+ import tempfile
28
+ import sys
29
+ from io import open
30
+
31
+ import torch
32
+ from torch import nn
33
+ from torch.nn import CrossEntropyLoss
34
+
35
+ from utils.file_utils import cached_path
36
+
37
+ logger = logging.getLogger(__name__)
38
+
39
+ PRETRAINED_MODEL_ARCHIVE_MAP = {
40
+ 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz",
41
+ 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz",
42
+ 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz",
43
+ 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz",
44
+ 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz",
45
+ 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz",
46
+ 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz",
47
+ }
48
+ CONFIG_NAME = 'bert_config.json'
49
+ WEIGHTS_NAME = 'pytorch_model.bin'
50
+ TF_WEIGHTS_NAME = 'model.ckpt'
51
+
52
+ def load_tf_weights_in_bert(model, tf_checkpoint_path):
53
+ """ Load tf checkpoints in a pytorch model
54
+ """
55
+ try:
56
+ import re
57
+ import numpy as np
58
+ import tensorflow as tf
59
+ except ImportError:
60
+ print("Loading TensorFlow models in PyTorch requires TensorFlow to be installed. Please see "
61
+ "https://www.tensorflow.org/install/ for installation instructions.")
62
+ raise
63
+ tf_path = os.path.abspath(tf_checkpoint_path)
64
+ print("Converting TensorFlow checkpoint from {}".format(tf_path))
65
+ # Load weights from TF model
66
+ init_vars = tf.train.list_variables(tf_path)
67
+ names = []
68
+ arrays = []
69
+ for name, shape in init_vars:
70
+ print("Loading TF weight {} with shape {}".format(name, shape))
71
+ array = tf.train.load_variable(tf_path, name)
72
+ names.append(name)
73
+ arrays.append(array)
74
+
75
+ for name, array in zip(names, arrays):
76
+ name = name.split('/')
77
+ # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
78
+ # which are not required for using pretrained model
79
+ if any(n in ["adam_v", "adam_m"] for n in name):
80
+ print("Skipping {}".format("/".join(name)))
81
+ continue
82
+ pointer = model
83
+ for m_name in name:
84
+ if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
85
+ l = re.split(r'_(\d+)', m_name)
86
+ else:
87
+ l = [m_name]
88
+ if l[0] == 'kernel' or l[0] == 'gamma':
89
+ pointer = getattr(pointer, 'weight')
90
+ elif l[0] == 'output_bias' or l[0] == 'beta':
91
+ pointer = getattr(pointer, 'bias')
92
+ elif l[0] == 'output_weights':
93
+ pointer = getattr(pointer, 'weight')
94
+ else:
95
+ pointer = getattr(pointer, l[0])
96
+ if len(l) >= 2:
97
+ num = int(l[1])
98
+ pointer = pointer[num]
99
+ if m_name[-11:] == '_embeddings':
100
+ pointer = getattr(pointer, 'weight')
101
+ elif m_name == 'kernel':
102
+ array = np.transpose(array)
103
+ try:
104
+ assert pointer.shape == array.shape
105
+ except AssertionError as e:
106
+ e.args += (pointer.shape, array.shape)
107
+ raise
108
+ print("Initialize PyTorch weight {}".format(name))
109
+ pointer.data = torch.from_numpy(array)
110
+ return model
111
+
112
+
113
+ def gelu(x):
114
+ """Implementation of the gelu activation function.
115
+ For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
116
+ 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
117
+ Also see https://arxiv.org/abs/1606.08415
118
+ """
119
+ return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
120
+
121
+
122
+ def swish(x):
123
+ return x * torch.sigmoid(x)
124
+
125
+
126
+ ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
127
+
128
+
129
+ class BertConfig(object):
130
+ """Configuration class to store the configuration of a `BertModel`.
131
+ """
132
+ def __init__(self,
133
+ vocab_size_or_config_json_file,
134
+ hidden_size=768,
135
+ num_hidden_layers=12,
136
+ num_attention_heads=12,
137
+ intermediate_size=3072,
138
+ hidden_act="gelu",
139
+ hidden_dropout_prob=0.1,
140
+ attention_probs_dropout_prob=0.1,
141
+ max_position_embeddings=512,
142
+ type_vocab_size=2,
143
+ initializer_range=0.02):
144
+ """Constructs BertConfig.
145
+
146
+ Args:
147
+ vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`.
148
+ hidden_size: Size of the encoder layers and the pooler layer.
149
+ num_hidden_layers: Number of hidden layers in the Transformer encoder.
150
+ num_attention_heads: Number of attention heads for each attention layer in
151
+ the Transformer encoder.
152
+ intermediate_size: The size of the "intermediate" (i.e., feed-forward)
153
+ layer in the Transformer encoder.
154
+ hidden_act: The non-linear activation function (function or string) in the
155
+ encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
156
+ hidden_dropout_prob: The dropout probability for all fully connected
157
+ layers in the embeddings, encoder, and pooler.
158
+ attention_probs_dropout_prob: The dropout ratio for the attention
159
+ probabilities.
160
+ max_position_embeddings: The maximum sequence length that this model might
161
+ ever be used with. Typically set this to something large just in case
162
+ (e.g., 512 or 1024 or 2048).
163
+ type_vocab_size: The vocabulary size of the `token_type_ids` passed into
164
+ `BertModel`.
165
+ initializer_range: The stddev of the truncated_normal_initializer for
166
+ initializing all weight matrices.
167
+ """
168
+ if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
169
+ and isinstance(vocab_size_or_config_json_file, unicode)):
170
+ with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
171
+ json_config = json.loads(reader.read())
172
+ for key, value in json_config.items():
173
+ self.__dict__[key] = value
174
+ elif isinstance(vocab_size_or_config_json_file, int):
175
+ self.vocab_size = vocab_size_or_config_json_file
176
+ self.hidden_size = hidden_size
177
+ self.num_hidden_layers = num_hidden_layers
178
+ self.num_attention_heads = num_attention_heads
179
+ self.hidden_act = hidden_act
180
+ self.intermediate_size = intermediate_size
181
+ self.hidden_dropout_prob = hidden_dropout_prob
182
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
183
+ self.max_position_embeddings = max_position_embeddings
184
+ self.type_vocab_size = type_vocab_size
185
+ self.initializer_range = initializer_range
186
+ else:
187
+ raise ValueError("First argument must be either a vocabulary size (int)"
188
+ "or the path to a pretrained model config file (str)")
189
+
190
+ @classmethod
191
+ def from_dict(cls, json_object):
192
+ """Constructs a `BertConfig` from a Python dictionary of parameters."""
193
+ config = BertConfig(vocab_size_or_config_json_file=-1)
194
+ for key, value in json_object.items():
195
+ config.__dict__[key] = value
196
+ return config
197
+
198
+ @classmethod
199
+ def from_json_file(cls, json_file):
200
+ """Constructs a `BertConfig` from a json file of parameters."""
201
+ with open(json_file, "r", encoding='utf-8') as reader:
202
+ text = reader.read()
203
+ return cls.from_dict(json.loads(text))
204
+
205
+ def __repr__(self):
206
+ return str(self.to_json_string())
207
+
208
+ def to_dict(self):
209
+ """Serializes this instance to a Python dictionary."""
210
+ output = copy.deepcopy(self.__dict__)
211
+ return output
212
+
213
+ def to_json_string(self):
214
+ """Serializes this instance to a JSON string."""
215
+ return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
216
+
217
+ try:
218
+ from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm
219
+ except ImportError:
220
+ print("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.")
221
+ class BertLayerNorm(nn.Module):
222
+ def __init__(self, hidden_size, eps=1e-12):
223
+ """Construct a layernorm module in the TF style (epsilon inside the square root).
224
+ """
225
+ super(BertLayerNorm, self).__init__()
226
+ self.weight = nn.Parameter(torch.ones(hidden_size))
227
+ self.bias = nn.Parameter(torch.zeros(hidden_size))
228
+ self.variance_epsilon = eps
229
+
230
+ def forward(self, x):
231
+ u = x.mean(-1, keepdim=True)
232
+ s = (x - u).pow(2).mean(-1, keepdim=True)
233
+ x = (x - u) / torch.sqrt(s + self.variance_epsilon)
234
+ return self.weight * x + self.bias
235
+
236
+ class BertEmbeddings(nn.Module):
237
+ """Construct the embeddings from word, position and token_type embeddings.
238
+ """
239
+ def __init__(self, config):
240
+ super(BertEmbeddings, self).__init__()
241
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
242
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
243
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
244
+
245
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
246
+ # any TensorFlow checkpoint file
247
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
248
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
249
+
250
+ def forward(self, input_ids, token_type_ids=None):
251
+ seq_length = input_ids.size(1)
252
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
253
+ position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
254
+ if token_type_ids is None:
255
+ token_type_ids = torch.zeros_like(input_ids)
256
+
257
+ words_embeddings = self.word_embeddings(input_ids)
258
+ position_embeddings = self.position_embeddings(position_ids)
259
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
260
+
261
+ embeddings = words_embeddings + position_embeddings + token_type_embeddings
262
+ embeddings = self.LayerNorm(embeddings)
263
+ embeddings = self.dropout(embeddings)
264
+ return embeddings
265
+
266
+
267
+ class BertSelfAttention(nn.Module):
268
+ def __init__(self, config):
269
+ super(BertSelfAttention, self).__init__()
270
+ if config.hidden_size % config.num_attention_heads != 0:
271
+ raise ValueError(
272
+ "The hidden size (%d) is not a multiple of the number of attention "
273
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads))
274
+ self.num_attention_heads = config.num_attention_heads
275
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
276
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
277
+
278
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
279
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
280
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
281
+
282
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
283
+
284
+ def transpose_for_scores(self, x):
285
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
286
+ x = x.view(*new_x_shape)
287
+ return x.permute(0, 2, 1, 3)
288
+
289
+ def forward(self, hidden_states, attention_mask):
290
+ mixed_query_layer = self.query(hidden_states)
291
+ mixed_key_layer = self.key(hidden_states)
292
+ mixed_value_layer = self.value(hidden_states)
293
+
294
+ query_layer = self.transpose_for_scores(mixed_query_layer)
295
+ key_layer = self.transpose_for_scores(mixed_key_layer)
296
+ value_layer = self.transpose_for_scores(mixed_value_layer)
297
+
298
+ # Take the dot product between "query" and "key" to get the raw attention scores.
299
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
300
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
301
+ # Apply the attention mask (precomputed for all layers in the BertModel forward() function)
302
+ attention_scores = attention_scores + attention_mask
303
+
304
+ # Normalize the attention scores to probabilities.
305
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
306
+
307
+ # This is actually dropping out entire tokens to attend to, which might
308
+ # seem a bit unusual, but is taken from the original Transformer paper.
309
+ attention_probs = self.dropout(attention_probs)
310
+
311
+ context_layer = torch.matmul(attention_probs, value_layer)
312
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
313
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
314
+ context_layer = context_layer.view(*new_context_layer_shape)
315
+ return context_layer
316
+
317
+
318
+ class BertSelfOutput(nn.Module):
319
+ def __init__(self, config):
320
+ super(BertSelfOutput, self).__init__()
321
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
322
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
323
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
324
+
325
+ def forward(self, hidden_states, input_tensor):
326
+ hidden_states = self.dense(hidden_states)
327
+ hidden_states = self.dropout(hidden_states)
328
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
329
+ return hidden_states
330
+
331
+
332
+ class BertAttention(nn.Module):
333
+ def __init__(self, config):
334
+ super(BertAttention, self).__init__()
335
+ self.self = BertSelfAttention(config)
336
+ self.output = BertSelfOutput(config)
337
+
338
+ def forward(self, input_tensor, attention_mask):
339
+ self_output = self.self(input_tensor, attention_mask)
340
+ attention_output = self.output(self_output, input_tensor)
341
+ return attention_output
342
+
343
+
344
+ class BertIntermediate(nn.Module):
345
+ def __init__(self, config):
346
+ super(BertIntermediate, self).__init__()
347
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
348
+ if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
349
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
350
+ else:
351
+ self.intermediate_act_fn = config.hidden_act
352
+
353
+ def forward(self, hidden_states):
354
+ hidden_states = self.dense(hidden_states)
355
+ hidden_states = self.intermediate_act_fn(hidden_states)
356
+ return hidden_states
357
+
358
+
359
+ class BertOutput(nn.Module):
360
+ def __init__(self, config):
361
+ super(BertOutput, self).__init__()
362
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
363
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
364
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
365
+
366
+ def forward(self, hidden_states, input_tensor):
367
+ hidden_states = self.dense(hidden_states)
368
+ hidden_states = self.dropout(hidden_states)
369
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
370
+ return hidden_states
371
+
372
+
373
+ class BertLayer(nn.Module):
374
+ def __init__(self, config):
375
+ super(BertLayer, self).__init__()
376
+ self.attention = BertAttention(config)
377
+ self.intermediate = BertIntermediate(config)
378
+ self.output = BertOutput(config)
379
+
380
+ def forward(self, hidden_states, attention_mask):
381
+ attention_output = self.attention(hidden_states, attention_mask)
382
+ intermediate_output = self.intermediate(attention_output)
383
+ layer_output = self.output(intermediate_output, attention_output)
384
+ return layer_output
385
+
386
+
387
+ class BertEncoder(nn.Module):
388
+ def __init__(self, config):
389
+ super(BertEncoder, self).__init__()
390
+ layer = BertLayer(config)
391
+ self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
392
+
393
+ def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
394
+ all_encoder_layers = []
395
+ for layer_module in self.layer:
396
+ hidden_states = layer_module(hidden_states, attention_mask)
397
+ if output_all_encoded_layers:
398
+ all_encoder_layers.append(hidden_states)
399
+ if not output_all_encoded_layers:
400
+ all_encoder_layers.append(hidden_states)
401
+ return all_encoder_layers
402
+
403
+
404
+ class BertPooler(nn.Module):
405
+ def __init__(self, config):
406
+ super(BertPooler, self).__init__()
407
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
408
+ self.activation = nn.Tanh()
409
+
410
+ def forward(self, hidden_states):
411
+ # We "pool" the model by simply taking the hidden state corresponding
412
+ # to the first token.
413
+ first_token_tensor = hidden_states[:, 0]
414
+ pooled_output = self.dense(first_token_tensor)
415
+ pooled_output = self.activation(pooled_output)
416
+ return pooled_output
417
+
418
+
419
+ class BertPredictionHeadTransform(nn.Module):
420
+ def __init__(self, config):
421
+ super(BertPredictionHeadTransform, self).__init__()
422
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
423
+ if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
424
+ self.transform_act_fn = ACT2FN[config.hidden_act]
425
+ else:
426
+ self.transform_act_fn = config.hidden_act
427
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
428
+
429
+ def forward(self, hidden_states):
430
+ hidden_states = self.dense(hidden_states)
431
+ hidden_states = self.transform_act_fn(hidden_states)
432
+ hidden_states = self.LayerNorm(hidden_states)
433
+ return hidden_states
434
+
435
+
436
+ class BertLMPredictionHead(nn.Module):
437
+ def __init__(self, config, bert_model_embedding_weights):
438
+ super(BertLMPredictionHead, self).__init__()
439
+ self.transform = BertPredictionHeadTransform(config)
440
+
441
+ # The output weights are the same as the input embeddings, but there is
442
+ # an output-only bias for each token.
443
+ self.decoder = nn.Linear(bert_model_embedding_weights.size(1),
444
+ bert_model_embedding_weights.size(0),
445
+ bias=False)
446
+ self.decoder.weight = bert_model_embedding_weights
447
+ self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0)))
448
+
449
+ def forward(self, hidden_states):
450
+ hidden_states = self.transform(hidden_states)
451
+ hidden_states = self.decoder(hidden_states) + self.bias
452
+ return hidden_states
453
+
454
+
455
+ class BertOnlyMLMHead(nn.Module):
456
+ def __init__(self, config, bert_model_embedding_weights):
457
+ super(BertOnlyMLMHead, self).__init__()
458
+ self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)
459
+
460
+ def forward(self, sequence_output):
461
+ prediction_scores = self.predictions(sequence_output)
462
+ return prediction_scores
463
+
464
+
465
+ class BertOnlyNSPHead(nn.Module):
466
+ def __init__(self, config):
467
+ super(BertOnlyNSPHead, self).__init__()
468
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
469
+
470
+ def forward(self, pooled_output):
471
+ seq_relationship_score = self.seq_relationship(pooled_output)
472
+ return seq_relationship_score
473
+
474
+
475
+ class BertPreTrainingHeads(nn.Module):
476
+ def __init__(self, config, bert_model_embedding_weights):
477
+ super(BertPreTrainingHeads, self).__init__()
478
+ self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)
479
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
480
+
481
+ def forward(self, sequence_output, pooled_output):
482
+ prediction_scores = self.predictions(sequence_output)
483
+ seq_relationship_score = self.seq_relationship(pooled_output)
484
+ return prediction_scores, seq_relationship_score
485
+
486
+
487
+ class BertPreTrainedModel(nn.Module):
488
+ """ An abstract class to handle weights initialization and
489
+ a simple interface for downloading and loading pretrained models.
490
+ """
491
+ def __init__(self, config, *inputs, **kwargs):
492
+ super(BertPreTrainedModel, self).__init__()
493
+ if not isinstance(config, BertConfig):
494
+ raise ValueError(
495
+ "Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
496
+ "To create a model from a Google pretrained model use "
497
+ "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
498
+ self.__class__.__name__, self.__class__.__name__
499
+ ))
500
+ self.config = config
501
+
502
+ def init_bert_weights(self, module):
503
+ """ Initialize the weights.
504
+ """
505
+ if isinstance(module, (nn.Linear, nn.Embedding)):
506
+ # Slightly different from the TF version which uses truncated_normal for initialization
507
+ # cf https://github.com/pytorch/pytorch/pull/5617
508
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
509
+ elif isinstance(module, BertLayerNorm):
510
+ module.bias.data.zero_()
511
+ module.weight.data.fill_(1.0)
512
+ if isinstance(module, nn.Linear) and module.bias is not None:
513
+ module.bias.data.zero_()
514
+
515
+ @classmethod
516
+ def from_pretrained(cls, pretrained_model_name_or_path, state_dict=None, cache_dir=None,
517
+ from_tf=False, *inputs, **kwargs):
518
+ """
519
+ Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.
520
+ Download and cache the pre-trained model file if needed.
521
+
522
+ Params:
523
+ pretrained_model_name_or_path: either:
524
+ - a str with the name of a pre-trained model to load selected in the list of:
525
+ . `bert-base-uncased`
526
+ . `bert-large-uncased`
527
+ . `bert-base-cased`
528
+ . `bert-large-cased`
529
+ . `bert-base-multilingual-uncased`
530
+ . `bert-base-multilingual-cased`
531
+ . `bert-base-chinese`
532
+ - a path or url to a pretrained model archive containing:
533
+ . `bert_config.json` a configuration file for the model
534
+ . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
535
+ - a path or url to a pretrained model archive containing:
536
+ . `bert_config.json` a configuration file for the model
537
+ . `model.ckpt` a TensorFlow checkpoint
538
+ from_tf: should we load the weights from a locally saved TensorFlow checkpoint
539
+ cache_dir: an optional path to a folder in which the pre-trained models will be cached.
540
+ state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of Google pre-trained models
541
+ *inputs, **kwargs: additional input for the specific Bert class
542
+ (ex: num_labels for BertForSequenceClassification)
543
+ """
544
+ if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
545
+ archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
546
+ else:
547
+ archive_file = pretrained_model_name_or_path
548
+ # redirect to the cache, if necessary
549
+ try:
550
+ resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
551
+ except EnvironmentError:
552
+ logger.error(
553
+ "Model name '{}' was not found in model name list ({}). "
554
+ "We assumed '{}' was a path or url but couldn't find any file "
555
+ "associated to this path or url.".format(
556
+ pretrained_model_name_or_path,
557
+ ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
558
+ archive_file))
559
+ return None
560
+ if resolved_archive_file == archive_file:
561
+ logger.info("loading archive file {}".format(archive_file))
562
+ else:
563
+ logger.info("loading archive file {} from cache at {}".format(
564
+ archive_file, resolved_archive_file))
565
+ tempdir = None
566
+ if os.path.isdir(resolved_archive_file) or from_tf:
567
+ serialization_dir = resolved_archive_file
568
+ else:
569
+ # Extract archive to temp dir
570
+ tempdir = tempfile.mkdtemp()
571
+ logger.info("extracting archive file {} to temp dir {}".format(
572
+ resolved_archive_file, tempdir))
573
+ with tarfile.open(resolved_archive_file, 'r:gz') as archive:
574
+ archive.extractall(tempdir)
575
+ serialization_dir = tempdir
576
+ # Load config
577
+ config_file = os.path.join(serialization_dir, CONFIG_NAME)
578
+ config = BertConfig.from_json_file(config_file)
579
+ logger.info("Model config {}".format(config))
580
+ # Instantiate model.
581
+ model = cls(config, *inputs, **kwargs)
582
+ if state_dict is None and not from_tf:
583
+ weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
584
+ state_dict = torch.load(weights_path, map_location='cpu' if not torch.cuda.is_available() else None)
585
+ if tempdir:
586
+ # Clean up temp dir
587
+ shutil.rmtree(tempdir)
588
+ if from_tf:
589
+ # Directly load from a TensorFlow checkpoint
590
+ weights_path = os.path.join(serialization_dir, TF_WEIGHTS_NAME)
591
+ return load_tf_weights_in_bert(model, weights_path)
592
+ # Load from a PyTorch state_dict
593
+ old_keys = []
594
+ new_keys = []
595
+ for key in state_dict.keys():
596
+ new_key = None
597
+ if 'gamma' in key:
598
+ new_key = key.replace('gamma', 'weight')
599
+ if 'beta' in key:
600
+ new_key = key.replace('beta', 'bias')
601
+ if new_key:
602
+ old_keys.append(key)
603
+ new_keys.append(new_key)
604
+ for old_key, new_key in zip(old_keys, new_keys):
605
+ state_dict[new_key] = state_dict.pop(old_key)
606
+
607
+ missing_keys = []
608
+ unexpected_keys = []
609
+ error_msgs = []
610
+ # copy state_dict so _load_from_state_dict can modify it
611
+ metadata = getattr(state_dict, '_metadata', None)
612
+ state_dict = state_dict.copy()
613
+ if metadata is not None:
614
+ state_dict._metadata = metadata
615
+
616
+ def load(module, prefix=''):
617
+ local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
618
+ module._load_from_state_dict(
619
+ state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
620
+ for name, child in module._modules.items():
621
+ if child is not None:
622
+ load(child, prefix + name + '.')
623
+ start_prefix = ''
624
+ if not hasattr(model, 'bert') and any(s.startswith('bert.') for s in state_dict.keys()):
625
+ start_prefix = 'bert.'
626
+ load(model, prefix=start_prefix)
627
+ if len(missing_keys) > 0:
628
+ logger.info("Weights of {} not initialized from pretrained model: {}".format(
629
+ model.__class__.__name__, missing_keys))
630
+ if len(unexpected_keys) > 0:
631
+ logger.info("Weights from pretrained model not used in {}: {}".format(
632
+ model.__class__.__name__, unexpected_keys))
633
+ if len(error_msgs) > 0:
634
+ raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
635
+ model.__class__.__name__, "\n\t".join(error_msgs)))
636
+ return model
637
+
638
+
639
+ class BertModel(BertPreTrainedModel):
640
+ """BERT model ("Bidirectional Encoder Representations from Transformers").
641
+
642
+ Params:
643
+ config: a BertConfig class instance with the configuration to build a new model
644
+
645
+ Inputs:
646
+ `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
647
+ with the word token indices in the vocabulary (see the token preprocessing logic in the scripts
648
+ `extract_features.py`, `run_classifier.py` and `run_squad.py`)
649
+ `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
650
+ types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
651
+ a `sentence B` token (see BERT paper for more details).
652
+ `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
653
+ selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
654
+ input sequence length in the current batch. It's the mask that we typically use for attention when
655
+ a batch has varying length sentences.
656
+ `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`.
657
+
658
+ Outputs: Tuple of (encoded_layers, pooled_output)
659
+ `encoded_layers`: controlled by `output_all_encoded_layers` argument:
660
+ - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end
661
+ of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each
662
+ encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
663
+ - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
664
+ to the last attention block of shape [batch_size, sequence_length, hidden_size],
665
+ `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
666
+ classifier pretrained on top of the hidden state associated with the first token of the
667
+ input (`CLS`) to train on the Next-Sentence task (see BERT's paper).
668
+
669
+ Example usage:
670
+ ```python
671
+ # Already been converted into WordPiece token ids
672
+ input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
673
+ input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
674
+ token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
675
+
676
+ config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
677
+ num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
678
+
679
+ model = modeling.BertModel(config=config)
680
+ all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
681
+ ```
682
+ """
683
+ def __init__(self, config):
684
+ super(BertModel, self).__init__(config)
685
+ self.embeddings = BertEmbeddings(config)
686
+ self.encoder = BertEncoder(config)
687
+ self.pooler = BertPooler(config)
688
+ self.apply(self.init_bert_weights)
689
+
690
+ def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True):
691
+ if attention_mask is None:
692
+ attention_mask = torch.ones_like(input_ids)
693
+ if token_type_ids is None:
694
+ token_type_ids = torch.zeros_like(input_ids)
695
+
696
+ # We create a 3D attention mask from a 2D tensor mask.
697
+ # Sizes are [batch_size, 1, 1, to_seq_length]
698
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
699
+ # this attention mask is simpler than the triangular masking of causal attention
700
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
701
+ extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
702
+
703
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
704
+ # masked positions, this operation will create a tensor which is 0.0 for
705
+ # positions we want to attend and -10000.0 for masked positions.
706
+ # Since we are adding it to the raw scores before the softmax, this is
707
+ # effectively the same as removing these entirely.
708
+ extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
709
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
710
+
711
+ embedding_output = self.embeddings(input_ids, token_type_ids)
712
+ encoded_layers = self.encoder(embedding_output,
713
+ extended_attention_mask,
714
+ output_all_encoded_layers=output_all_encoded_layers)
715
+ sequence_output = encoded_layers[-1]
716
+ pooled_output = self.pooler(sequence_output)
717
+ if not output_all_encoded_layers:
718
+ encoded_layers = encoded_layers[-1]
719
+ return encoded_layers, pooled_output
720
+
721
+
722
+
723
+
724
+
725
+ class BertForSequenceEncoder(BertPreTrainedModel):
726
+ """BERT model for sequence encoding.
727
+ This module wraps the BERT model and applies dropout to its sequence output
728
+ and pooled output before returning them.
729
+ Params:
730
+ `config`: a BertConfig class instance with the configuration to build a new model.
732
+ Inputs:
733
+ `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
734
+ with the word token indices in the vocabulary (see the token preprocessing logic in the scripts
735
+ `extract_features.py`, `run_classifier.py` and `run_squad.py`)
736
+ `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
737
+ types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
738
+ a `sentence B` token (see BERT paper for more details).
739
+ `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
740
+ selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
741
+ input sequence length in the current batch. It's the mask that we typically use for attention when
742
+ a batch has varying length sentences.
743
+ Outputs:
744
+ a tuple (`output`, `pooled_output`): `output` contains the final hidden states of shape
745
+ [batch_size, sequence_length, hidden_size], and `pooled_output` has shape
746
+ [batch_size, hidden_size]; dropout is applied to both before they are returned.
750
+ Example usage:
751
+ ```python
752
+ # Already been converted into WordPiece token ids
753
+ input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
754
+ input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
755
+ token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
756
+ config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
757
+ num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
758
+ model = BertForSequenceEncoder(config)
759
+ output, pooled_output = model(input_ids, input_mask, token_type_ids)
761
+ ```
762
+ """
763
+ def __init__(self, config):
764
+ super(BertForSequenceEncoder, self).__init__(config)
765
+ self.bert = BertModel(config)
766
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
767
+ self.apply(self.init_bert_weights)
768
+
769
+ def forward(self, input_ids, attention_mask, token_type_ids):
770
+ output, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
771
+ output = self.dropout(output)
772
+ pooled_output = self.dropout(pooled_output)
773
+ return output, pooled_output
774
+
775
+
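To make the module above easier to try out, here is a minimal sketch (not part of the commit) of how `BertForSequenceEncoder` could be exercised end to end; the tensor values are toy placeholders rather than real WordPiece ids, and the import path assumes this repository layout.

```python
import torch
from utils.bert_model import BertConfig, BertForSequenceEncoder

# Toy configuration using the defaults documented in BertConfig.
config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
                    num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
model = BertForSequenceEncoder(config)
model.eval()  # disable dropout for a deterministic forward pass

# Placeholder ids; in practice these come from a WordPiece tokenizer.
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])      # 1 = real token, 0 = padding
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])  # sentence A / sentence B segments

with torch.no_grad():
    output, pooled_output = model(input_ids, input_mask, token_type_ids)
# output: [2, 3, 768] per-token encodings; pooled_output: [2, 768] first-token summary.
```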
utils/callbacks.py ADDED
@@ -0,0 +1,140 @@
1
+ import logging
2
+ import os
3
+ from pathlib import Path
4
+
5
+ import numpy as np
6
+ import pytorch_lightning as pl
7
+ import torch
8
+ from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
9
+ from pytorch_lightning.utilities import rank_zero_only
10
+
11
+ from utils.utils_verbalisation_module import save_json
12
+ from pytorch_lightning.utilities import rank_zero_info
13
+
14
+ def count_trainable_parameters(model):
15
+ model_parameters = filter(lambda p: p.requires_grad, model.parameters())
16
+ params = sum([np.prod(p.size()) for p in model_parameters])
17
+ return params
18
+
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
+
24
+ class Seq2SeqLoggingCallback(pl.Callback):
25
+ def on_batch_end(self, trainer, pl_module):
26
+ lrs = {f"lr_group_{i}": param["lr"] for i, param in enumerate(pl_module.trainer.optimizers[0].param_groups)}
27
+ pl_module.logger.log_metrics(lrs)
28
+
29
+ @rank_zero_only
30
+ def _write_logs(
31
+ self, trainer: pl.Trainer, pl_module: pl.LightningModule, type_path: str, save_generations=True
32
+ ) -> None:
33
+ logger.info(f"***** {type_path} results at step {trainer.global_step:05d} *****")
34
+ metrics = trainer.callback_metrics
35
+ #print(metrics.keys())
36
+ new_metrics = {}
37
+ ms = ["log", "progress_bar", "preds"]
38
+ for k, v in metrics.items():
39
+ ver = True
40
+ for m in ms:
41
+ if m in k:
42
+ ver = False
43
+ break
44
+ if ver:
45
+ new_metrics[k] = v
46
+
47
+ print(new_metrics)
48
+ trainer.logger.log_metrics(new_metrics)
49
+ # Log results
50
+ od = Path(pl_module.hparams.output_dir)
51
+ if type_path == "test":
52
+ results_file = od / "test_results.txt"
53
+ generations_file = od / "test_generations.txt"
54
+ else:
55
+ # this never gets hit. I prefer not to save intermediate generations, and results are in metrics.json
56
+ # If people want this it will be easy enough to add back.
57
+ results_file = od / f"{type_path}_results/{trainer.global_step:05d}.txt"
58
+ generations_file = od / f"{type_path}_generations/{trainer.global_step:05d}.txt"
59
+ results_file.parent.mkdir(exist_ok=True)
60
+ generations_file.parent.mkdir(exist_ok=True)
61
+ with open(results_file, "a+") as writer:
62
+ for key in sorted(metrics):
63
+ if key in ["log", "progress_bar", "preds"]:
64
+ continue
65
+ try:
66
+ val = metrics[key]
67
+ if isinstance(val, torch.Tensor):
68
+ val = val.item()
69
+ msg = f"{key}: {val:.6f}\n"
70
+ writer.write(msg)
71
+ except:
72
+ pass
73
+
74
+ if not save_generations:
75
+ return
76
+
77
+ if "preds" in metrics:
78
+ content = "\n".join(metrics["preds"])
79
+ generations_file.open("w+").write(content)
80
+
81
+ @rank_zero_only
82
+ def on_train_start(self, trainer, pl_module):
83
+ try:
84
+ npars = pl_module.model.model.num_parameters()
85
+ except AttributeError:
86
+ npars = pl_module.model.num_parameters()
87
+
88
+ n_trainable_pars = count_trainable_parameters(pl_module)
89
+ # mp stands for million parameters
90
+ trainer.logger.log_metrics({"n_params": npars, "mp": npars / 1e6, "grad_mp": n_trainable_pars / 1e6})
91
+
92
+ @rank_zero_only
93
+ def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
94
+ save_json(pl_module.metrics, pl_module.metrics_save_path)
95
+ return self._write_logs(trainer, pl_module, "test")
96
+
97
+ @rank_zero_only
98
+ def on_validation_end(self, trainer: pl.Trainer, pl_module):
99
+ save_json(pl_module.metrics, pl_module.metrics_save_path)
100
+
101
+ rank_zero_info("***** Validation results *****")
102
+ metrics = trainer.callback_metrics
103
+ # Log results
104
+ for key in sorted(metrics):
105
+ if key not in ["log", "progress_bar", "preds"]:
106
+ rank_zero_info("{} = {}\n".format(key, str(metrics[key])))
107
+ # Uncommenting this will save val generations
108
+ # return self._write_logs(trainer, pl_module, "valid")
109
+
110
+
111
+ def get_checkpoint_callback(output_dir, metric, save_top_k=1, lower_is_better=False):
112
+ """Saves the best model by the chosen validation metric (rouge2, bleu, or loss)."""
113
+ if metric == "rouge2":
114
+ exp = "{val_avg_rouge2:.4f}-{step_count}"
115
+ elif metric == "bleu":
116
+ exp = "{val_avg_bleu:.4f}-{step_count}"
117
+ elif metric == "loss":
118
+ exp = "{val_avg_loss:.4f}-{step_count}"
119
+ else:
120
+ raise NotImplementedError(
121
+ f"seq2seq callbacks only support rouge2, bleu and loss; got {metric}. You can make your own by adding to this function."
122
+ )
123
+
124
+ checkpoint_callback = ModelCheckpoint(
125
+ filepath=os.path.join(output_dir, exp),
126
+ monitor=f"val_{metric}",
127
+ mode="min" if "loss" in metric else "max",
128
+ save_top_k=save_top_k,
129
+ period=0, # maybe save a checkpoint every time val is run, not just end of epoch.
130
+ )
131
+ return checkpoint_callback
132
+
133
+
134
+ def get_early_stopping_callback(metric, patience):
135
+ return EarlyStopping(
136
+ monitor=f"val_{metric}", # does this need avg?
137
+ mode="min" if "loss" in metric else "max",
138
+ patience=patience,
139
+ verbose=True,
140
+ )
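As a rough orientation, the snippet below sketches how these callbacks might be wired into a `pytorch_lightning.Trainer`; the exact `Trainer` keyword arguments depend on the pytorch-lightning version these utilities target, and `model` stands in for a LightningModule such as the `SummarizationModule` defined in `utils/finetune.py`.

```python
import pytorch_lightning as pl
from utils.callbacks import (Seq2SeqLoggingCallback, get_checkpoint_callback,
                             get_early_stopping_callback)

# Checkpoint on validation loss and stop after 3 validations without improvement.
checkpoint_cb = get_checkpoint_callback(output_dir="outputs", metric="loss")
early_stop_cb = get_early_stopping_callback(metric="loss", patience=3)

trainer = pl.Trainer(
    max_epochs=3,
    callbacks=[Seq2SeqLoggingCallback(), early_stop_cb],
    checkpoint_callback=checkpoint_cb,  # older pytorch-lightning versions accept the callback instance here
)
# trainer.fit(model)  # `model` is a LightningModule, e.g. SummarizationModule
```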
utils/file_utils.py ADDED
@@ -0,0 +1,249 @@
1
+ """
2
+ Utilities for working with the local dataset cache.
3
+ This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
4
+ Copyright by the AllenNLP authors.
5
+ """
6
+ from __future__ import (absolute_import, division, print_function, unicode_literals)
7
+
8
+ import json
9
+ import logging
10
+ import os
11
+ import shutil
12
+ import tempfile
13
+ from functools import wraps
14
+ from hashlib import sha256
15
+ import sys
16
+ from io import open
17
+
18
+ import boto3
19
+ import requests
20
+ from botocore.exceptions import ClientError
21
+ from tqdm import tqdm
22
+
23
+ try:
24
+ from urllib.parse import urlparse
25
+ except ImportError:
26
+ from urlparse import urlparse
27
+
28
+ try:
29
+ from pathlib import Path
30
+ PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
31
+ Path.home() / '.pytorch_pretrained_bert'))
32
+ except AttributeError:
33
+ PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
34
+ os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert'))
35
+
36
+ logger = logging.getLogger(__name__) # pylint: disable=invalid-name
37
+
38
+
39
+ def url_to_filename(url, etag=None):
40
+ """
41
+ Convert `url` into a hashed filename in a repeatable way.
42
+ If `etag` is specified, append its hash to the url's, delimited
43
+ by a period.
44
+ """
45
+ url_bytes = url.encode('utf-8')
46
+ url_hash = sha256(url_bytes)
47
+ filename = url_hash.hexdigest()
48
+
49
+ if etag:
50
+ etag_bytes = etag.encode('utf-8')
51
+ etag_hash = sha256(etag_bytes)
52
+ filename += '.' + etag_hash.hexdigest()
53
+
54
+ return filename
55
+
56
+
57
+ def filename_to_url(filename, cache_dir=None):
58
+ """
59
+ Return the url and etag (which may be ``None``) stored for `filename`.
60
+ Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
61
+ """
62
+ if cache_dir is None:
63
+ cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
64
+ if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
65
+ cache_dir = str(cache_dir)
66
+
67
+ cache_path = os.path.join(cache_dir, filename)
68
+ if not os.path.exists(cache_path):
69
+ raise EnvironmentError("file {} not found".format(cache_path))
70
+
71
+ meta_path = cache_path + '.json'
72
+ if not os.path.exists(meta_path):
73
+ raise EnvironmentError("file {} not found".format(meta_path))
74
+
75
+ with open(meta_path, encoding="utf-8") as meta_file:
76
+ metadata = json.load(meta_file)
77
+ url = metadata['url']
78
+ etag = metadata['etag']
79
+
80
+ return url, etag
81
+
82
+
83
+ def cached_path(url_or_filename, cache_dir=None):
84
+ """
85
+ Given something that might be a URL (or might be a local path),
86
+ determine which. If it's a URL, download the file and cache it, and
87
+ return the path to the cached file. If it's already a local path,
88
+ make sure the file exists and then return the path.
89
+ """
90
+ if cache_dir is None:
91
+ cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
92
+ if sys.version_info[0] == 3 and isinstance(url_or_filename, Path):
93
+ url_or_filename = str(url_or_filename)
94
+ if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
95
+ cache_dir = str(cache_dir)
96
+
97
+ parsed = urlparse(url_or_filename)
98
+
99
+ if parsed.scheme in ('http', 'https', 's3'):
100
+ # URL, so get it from the cache (downloading if necessary)
101
+ return get_from_cache(url_or_filename, cache_dir)
102
+ elif os.path.exists(url_or_filename):
103
+ # File, and it exists.
104
+ return url_or_filename
105
+ elif parsed.scheme == '':
106
+ # File, but it doesn't exist.
107
+ raise EnvironmentError("file {} not found".format(url_or_filename))
108
+ else:
109
+ # Something unknown
110
+ raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
111
+
112
+
113
+ def split_s3_path(url):
114
+ """Split a full s3 path into the bucket name and path."""
115
+ parsed = urlparse(url)
116
+ if not parsed.netloc or not parsed.path:
117
+ raise ValueError("bad s3 path {}".format(url))
118
+ bucket_name = parsed.netloc
119
+ s3_path = parsed.path
120
+ # Remove '/' at beginning of path.
121
+ if s3_path.startswith("/"):
122
+ s3_path = s3_path[1:]
123
+ return bucket_name, s3_path
124
+
125
+
126
+ def s3_request(func):
127
+ """
128
+ Wrapper function for s3 requests in order to create more helpful error
129
+ messages.
130
+ """
131
+
132
+ @wraps(func)
133
+ def wrapper(url, *args, **kwargs):
134
+ try:
135
+ return func(url, *args, **kwargs)
136
+ except ClientError as exc:
137
+ if int(exc.response["Error"]["Code"]) == 404:
138
+ raise EnvironmentError("file {} not found".format(url))
139
+ else:
140
+ raise
141
+
142
+ return wrapper
143
+
144
+
145
+ @s3_request
146
+ def s3_etag(url):
147
+ """Check ETag on S3 object."""
148
+ s3_resource = boto3.resource("s3")
149
+ bucket_name, s3_path = split_s3_path(url)
150
+ s3_object = s3_resource.Object(bucket_name, s3_path)
151
+ return s3_object.e_tag
152
+
153
+
154
+ @s3_request
155
+ def s3_get(url, temp_file):
156
+ """Pull a file directly from S3."""
157
+ s3_resource = boto3.resource("s3")
158
+ bucket_name, s3_path = split_s3_path(url)
159
+ s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
160
+
161
+
162
+ def http_get(url, temp_file):
163
+ req = requests.get(url, stream=True)
164
+ content_length = req.headers.get('Content-Length')
165
+ total = int(content_length) if content_length is not None else None
166
+ progress = tqdm(unit="B", total=total)
167
+ for chunk in req.iter_content(chunk_size=1024):
168
+ if chunk: # filter out keep-alive new chunks
169
+ progress.update(len(chunk))
170
+ temp_file.write(chunk)
171
+ progress.close()
172
+
173
+
174
+ def get_from_cache(url, cache_dir=None):
175
+ """
176
+ Given a URL, look for the corresponding dataset in the local cache.
177
+ If it's not there, download it. Then return the path to the cached file.
178
+ """
179
+ if cache_dir is None:
180
+ cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
181
+ if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
182
+ cache_dir = str(cache_dir)
183
+
184
+ if not os.path.exists(cache_dir):
185
+ os.makedirs(cache_dir)
186
+
187
+ # Get eTag to add to filename, if it exists.
188
+ if url.startswith("s3://"):
189
+ etag = s3_etag(url)
190
+ else:
191
+ response = requests.head(url, allow_redirects=True)
192
+ if response.status_code != 200:
193
+ raise IOError("HEAD request failed for url {} with status code {}"
194
+ .format(url, response.status_code))
195
+ etag = response.headers.get("ETag")
196
+
197
+ filename = url_to_filename(url, etag)
198
+
199
+ # get cache path to put the file
200
+ cache_path = os.path.join(cache_dir, filename)
201
+
202
+ if not os.path.exists(cache_path):
203
+ # Download to temporary file, then copy to cache dir once finished.
204
+ # Otherwise you get corrupt cache entries if the download gets interrupted.
205
+ with tempfile.NamedTemporaryFile() as temp_file:
206
+ logger.info("%s not found in cache, downloading to %s", url, temp_file.name)
207
+
208
+ # GET file object
209
+ if url.startswith("s3://"):
210
+ s3_get(url, temp_file)
211
+ else:
212
+ http_get(url, temp_file)
213
+
214
+ # we are copying the file before closing it, so flush to avoid truncation
215
+ temp_file.flush()
216
+ # shutil.copyfileobj() starts at the current position, so go to the start
217
+ temp_file.seek(0)
218
+
219
+ logger.info("copying %s to cache at %s", temp_file.name, cache_path)
220
+ with open(cache_path, 'wb') as cache_file:
221
+ shutil.copyfileobj(temp_file, cache_file)
222
+
223
+ logger.info("creating metadata file for %s", cache_path)
224
+ meta = {'url': url, 'etag': etag}
225
+ meta_path = cache_path + '.json'
226
+ with open(meta_path, 'w', encoding="utf-8") as meta_file:
227
+ json.dump(meta, meta_file)
228
+
229
+ logger.info("removing temp file %s", temp_file.name)
230
+
231
+ return cache_path
232
+
233
+
234
+ def read_set_from_file(filename):
235
+ '''
236
+ Extract a de-duped collection (set) of text from a file.
237
+ Expected file format is one item per line.
238
+ '''
239
+ collection = set()
240
+ with open(filename, 'r', encoding='utf-8') as file_:
241
+ for line in file_:
242
+ collection.add(line.rstrip())
243
+ return collection
244
+
245
+
246
+ def get_file_extension(path, dot=True, lower=True):
247
+ ext = os.path.splitext(path)[1]
248
+ ext = ext if dot else ext[1:]
249
+ return ext.lower() if lower else ext
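For context, a small sketch of how the caching helpers above behave; the URL is one of the archives already listed in `bert_model.py`, and the cache directory is only an example path.

```python
from utils.file_utils import cached_path

url = "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz"

# The first call downloads into the cache (the filename is a hash of the URL plus its ETag);
# subsequent calls return the already-cached path without re-downloading.
local_path = cached_path(url, cache_dir="/tmp/bert_cache")
print(local_path)
```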
utils/finetune.py ADDED
@@ -0,0 +1,633 @@
1
+ #!/usr/bin/env python
2
+
3
+ import argparse
4
+ import glob
5
+ import logging
6
+ import os
7
+ import sys
8
+ import time
9
+ from collections import defaultdict
10
+ from pathlib import Path
11
+ from typing import Dict, List, Tuple
12
+ import pdb
13
+
14
+ import numpy as np
15
+ import pytorch_lightning as pl
16
+ import torch
17
+ from torch.utils.data import DataLoader
18
+
19
+ from pytorch_lightning.utilities import rank_zero_info
20
+
21
+ from utils.callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
22
+ from transformers import MBartTokenizer, T5ForConditionalGeneration
23
+
24
+ from transformers.models.bart.modeling_bart import shift_tokens_right
25
+ from utils.utils_verbalisation_module import (
26
+ ROUGE_KEYS,
27
+ LegacySeq2SeqDataset,
28
+ Seq2SeqDataset,
29
+ assert_all_frozen,
30
+ calculate_bleu,
31
+ calculate_rouge,
32
+ flatten_list,
33
+ freeze_embeds,
34
+ freeze_params,
35
+ label_smoothed_nll_loss,
36
+ lmap,
37
+ pickle_save,
38
+ save_json,
39
+ use_task_specific_params,
40
+ )
41
+
42
+ from utils.utils_graph2text import convert_text, eval_meteor, eval_bleu, eval_chrf, eval_meteor_test_webnlg, eval_chrf_test_webnlg
43
+
44
+ # need the parent dir module
45
+ sys.path.insert(2, str(Path(__file__).resolve().parents[1]))
46
+ from utils.lightning_base import BaseTransformer, add_generic_args, generic_train # noqa
47
+
48
+
49
+ logger = logging.getLogger(__name__)
50
+
51
+
52
+ class SummarizationModule(BaseTransformer):
53
+ mode = "summarization"
54
+ loss_names = ["loss"]
55
+ metric_names = ROUGE_KEYS
56
+ default_val_metric = "rouge2"
57
+
58
+ def __init__(self, hparams, **kwargs):
59
+ if hparams.sortish_sampler and hparams.gpus > 1:
60
+ hparams.replace_sampler_ddp = False
61
+ elif hparams.max_tokens_per_batch is not None:
62
+ if hparams.gpus > 1:
63
+ raise NotImplementedError("Dynamic Batch size does not work for multi-gpu training")
64
+ if hparams.sortish_sampler:
65
+ raise ValueError("--sortish_sampler and --max_tokens_per_batch may not be used simultaneously")
66
+
67
+ super().__init__(hparams, num_labels=None, mode=self.mode, **kwargs)
68
+ #use_task_specific_params(self.model, "summarization")
69
+
70
+ self.metrics_save_path = Path('base') / "metrics.json"
71
+ self.hparams_save_path = Path('base') / "hparams.pkl"
72
+ pickle_save(self.hparams, self.hparams_save_path)
73
+ self.step_count = -2
74
+ self.metrics = defaultdict(list)
75
+ self.model_type = self.config.model_type
76
+ self.vocab_size = self.config.tgt_vocab_size if self.model_type == "fsmt" else self.config.vocab_size
77
+
78
+ if 't5' in hparams.model_name_or_path:
79
+ self.model.config.prefix = 'translate Graph to English: '
80
+ self.dataset_kwargs: dict = dict(
81
+ data_dir=self.hparams.data_dir,
82
+ max_source_length=self.hparams.max_source_length,
83
+ prefix=self.model.config.prefix or "",
84
+ )
85
+ n_observations_per_split = {
86
+ "train": self.hparams.n_train,
87
+ "val": self.hparams.n_val,
88
+ "test_seen": self.hparams.n_test,
89
+ "test_unseen": self.hparams.n_test,
90
+ "test_both": self.hparams.n_test,
91
+ }
92
+ self.n_obs = {k: v if v >= 0 else None for k, v in n_observations_per_split.items()}
93
+
94
+ self.target_lens = {
95
+ "train": self.hparams.max_target_length,
96
+ "val": self.hparams.val_max_target_length,
97
+ "test_seen": self.hparams.test_max_target_length,
98
+ "test_unseen": self.hparams.test_max_target_length,
99
+ "test_both": self.hparams.test_max_target_length,
100
+ }
101
+ assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
102
+ assert self.target_lens["train"] <= self.target_lens["test_both"], f"target_lens: {self.target_lens}"
103
+ if self.hparams.freeze_embeds:
104
+ freeze_embeds(self.model)
105
+ if self.hparams.freeze_encoder:
106
+ freeze_params(self.model.get_encoder())
107
+ assert_all_frozen(self.model.get_encoder())
108
+
109
+ self.num_workers = hparams.num_workers
110
+ self.decoder_start_token_id = None # default to config
111
+ if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
112
+ self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
113
+ self.model.config.decoder_start_token_id = self.decoder_start_token_id
114
+ self.dataset_class = (
115
+ Seq2SeqDataset if hasattr(self.tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset
116
+ )
117
+ self.already_saved_batch = False
118
+ self.eval_beams = self.model.config.num_beams if self.hparams.eval_beams is None else self.hparams.eval_beams
119
+ if self.hparams.eval_max_gen_length is not None:
120
+ self.eval_max_length = self.hparams.eval_max_gen_length
121
+ else:
122
+ self.eval_max_length = self.model.config.max_length
123
+ self.val_metric = self.default_val_metric if self.hparams.val_metric is None else self.hparams.val_metric
124
+
125
+ def save_readable_batch(self, batch: Dict[str, torch.Tensor]) -> Dict[str, List[str]]:
126
+ """A debugging utility"""
127
+
128
+ readable_batch = {
129
+ k: self.tokenizer.batch_decode(v.tolist()) if "mask" not in k else v.shape for k, v in batch.items()
130
+ }
131
+ save_json(readable_batch, Path(self.output_dir) / "text_batch.json")
132
+
133
+ tb = {}
134
+ for k, v in batch.items():
135
+ tb[k] = v.tolist()
136
+
137
+ save_json(tb, Path(self.output_dir) / "tok_batch.json")
138
+
139
+ self.already_saved_batch = True
140
+ return readable_batch
141
+
142
+ def forward(self, input_ids, **kwargs):
143
+ return self.model(input_ids, **kwargs)
144
+
145
+ def ids_to_clean_text(self, generated_ids: List[int]):
146
+ gen_text = self.tokenizer.batch_decode(
147
+ generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
148
+ )
149
+ return lmap(str.strip, gen_text)
150
+
151
+ def _step(self, batch: dict) -> Tuple:
152
+ pad_token_id = self.tokenizer.pad_token_id
153
+ src_ids, src_mask = batch["input_ids"], batch["attention_mask"]
154
+ if isinstance(self.model, T5ForConditionalGeneration):
155
+ tgt_ids = batch["labels"]
156
+ decoder_input_ids = self.model._shift_right(tgt_ids)
157
+ else:
158
+ #decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id)
159
+ y = batch["labels"]
160
+ decoder_input_ids = y[:, :-1].contiguous()
161
+ tgt_ids = y[:, 1:].clone()
162
+ if not self.already_saved_batch: # This would be slightly better if it only happened on rank zero
163
+ batch["decoder_input_ids"] = decoder_input_ids
164
+ self.save_readable_batch(batch)
165
+
166
+ outputs = self(src_ids, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
167
+ lm_logits = outputs[0]
168
+ if self.hparams.label_smoothing == 0:
169
+ # Same behavior as modeling_bart.py, besides ignoring pad_token_id
170
+ ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
171
+
172
+ assert lm_logits.shape[-1] == self.vocab_size
173
+ loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))
174
+ else:
175
+ lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
176
+ loss, nll_loss = label_smoothed_nll_loss(
177
+ lprobs, tgt_ids, self.hparams.label_smoothing, ignore_index=pad_token_id
178
+ )
179
+ return (loss,)
180
+
181
+ @property
182
+ def pad(self) -> int:
183
+ return self.tokenizer.pad_token_id
184
+
185
+ def training_step(self, batch, batch_idx) -> Dict:
186
+ loss_tensors = self._step(batch)
187
+
188
+ logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
189
+ # tokens per batch
190
+ logs["tpb"] = batch["input_ids"].ne(self.pad).sum() + batch["labels"].ne(self.pad).sum()
191
+ logs["bs"] = batch["input_ids"].shape[0]
192
+ logs["src_pad_tok"] = batch["input_ids"].eq(self.pad).sum()
193
+ logs["src_pad_frac"] = batch["input_ids"].eq(self.pad).float().mean()
194
+ # TODO(SS): make a wandb summary metric for this
195
+ return {"loss": loss_tensors[0], "log": logs}
196
+
197
+ def validation_step(self, batch, batch_idx) -> Dict:
198
+ return self._generative_step(batch)
199
+
200
+ def validation_epoch_end(self, outputs, prefix="val") -> Dict:
201
+
202
+ self.step_count += 1
203
+
204
+ val_outputs_folder = "val_outputs"
205
+ os.makedirs(os.path.join(self.hparams.output_dir, val_outputs_folder), exist_ok=True)
206
+
207
+ if prefix == "val":
208
+ output_test_predictions_file = os.path.join(self.hparams.output_dir, val_outputs_folder, "validation_predictions_" +
209
+ str(self.step_count) + ".txt")
210
+ output_test_targets_file = os.path.join(self.hparams.output_dir, val_outputs_folder, "validation_targets_" +
211
+ str(self.step_count) + ".txt")
212
+ # write predictions and targets for later rouge evaluation.
213
+ with open(output_test_predictions_file, "w") as p_writer, open(output_test_targets_file, "w") as t_writer:
214
+ for output_batch in outputs:
215
+ p_writer.writelines(convert_text(s) + "\n" for s in output_batch["preds"])
216
+ t_writer.writelines(convert_text(s) + "\n" for s in output_batch["target"])
217
+ p_writer.close()
218
+ t_writer.close()
219
+
220
+ bleu_info = eval_bleu(self.hparams.data_dir, output_test_predictions_file, 'val')
221
+
222
+ rank_zero_info("%s bleu_info: %s", self.step_count, bleu_info)
223
+
224
+ if bleu_info == -1:
225
+ bleu_info = float(bleu_info)
226
+ else:
227
+ bleu_info = float(bleu_info.split(",")[0].split("BLEU = ")[1])
228
+
229
+ losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
230
+ loss = losses["loss"]
231
+ generative_metrics = {
232
+ k: np.array([x[k] for x in outputs]).mean() for k in self.metric_names + ["gen_time", "gen_len"]
233
+ }
234
+
235
+ generative_metrics['bleu'] = bleu_info
236
+
237
+ metric_val = (
238
+ generative_metrics[self.val_metric] if self.val_metric in generative_metrics else losses[
239
+ self.val_metric]
240
+ )
241
+ metric_tensor: torch.FloatTensor = torch.tensor(metric_val).type_as(loss)
242
+ generative_metrics.update({k: v.item() for k, v in losses.items()})
243
+ losses.update(generative_metrics)
244
+ all_metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
245
+ all_metrics["step_count"] = self.step_count
246
+ self.metrics[prefix].append(all_metrics) # callback writes this to self.metrics_save_path
247
+ preds = flatten_list([x["preds"] for x in outputs])
248
+
249
+ return {
250
+ "bleu": bleu_info,
251
+ "log": all_metrics,
252
+ "preds": preds,
253
+ f"{prefix}_loss": loss,
254
+ f"{prefix}_{self.val_metric}": metric_tensor,
255
+ }
256
+ else:
257
+
258
+ data_logs = {}
259
+ for output in outputs:
260
+
261
+ dataset_idx = output[0]['dataloader_idx']
262
+
263
+ if dataset_idx == 0:
264
+ dataset_name = 'test_both'
265
+ elif dataset_idx == 1:
266
+ dataset_name = 'test_seen'
267
+ else:
268
+ dataset_name = 'test_unseen'
269
+
270
+ if output[0]['bleu'] == -1:
271
+ bleu_info = float(output[0]['bleu'])
272
+ else:
273
+ bleu_info = float(output[0]['bleu'].split(",")[0].split("BLEU = ")[1])
274
+
275
+
276
+ losses = {k: torch.stack([x[k] for x in output]).mean() for k in self.loss_names}
277
+ loss = losses["loss"]
278
+ generative_metrics = {
279
+ k: np.array([x[k] for x in output]).mean() for k in self.metric_names + ["gen_time", "gen_len"]
280
+ }
281
+
282
+ generative_metrics['bleu'] = bleu_info
283
+
284
+ metric_val = (
285
+ generative_metrics[self.val_metric] if self.val_metric in generative_metrics else losses[
286
+ self.val_metric]
287
+ )
288
+ metric_tensor: torch.FloatTensor = torch.tensor(metric_val).type_as(loss)
289
+ generative_metrics.update({k: v.item() for k, v in losses.items()})
290
+ losses.update(generative_metrics)
291
+ all_metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
292
+ all_metrics["step_count"] = self.step_count
293
+ self.metrics[prefix].append(all_metrics) # callback writes this to self.metrics_save_path
294
+ preds = flatten_list([x["preds"] for x in output])
295
+
296
+ data_logs.update({
297
+ "log" + "_" + dataset_name: all_metrics,
298
+ "preds" + "_" + dataset_name: preds,
299
+ f"{prefix}_loss" + "_" + dataset_name: loss,
300
+ f"{prefix}_{self.val_metric}" + "_" + dataset_name: metric_tensor,
301
+ })
302
+ return data_logs
303
+
304
+
305
+ #######
306
+
307
+
308
+
309
+
310
+ def calc_generative_metrics(self, preds, target) -> Dict:
311
+ return calculate_rouge(preds, target)
312
+
313
+ def _generative_step(self, batch: dict, batch_idx=None, dataloader_idx=None) -> dict:
314
+ t0 = time.time()
315
+
316
+ # parser.add_argument('--eval_max_gen_length', type=int, default=None, help='never generate more than n tokens')
317
+ generated_ids = self.model.generate(
318
+ batch["input_ids"],
319
+ attention_mask=batch["attention_mask"],
320
+ use_cache=True,
321
+ decoder_start_token_id=self.decoder_start_token_id,
322
+ num_beams=self.eval_beams,
323
+ max_length=self.eval_max_length,
324
+ length_penalty=1.0
325
+ )
326
+ gen_time = (time.time() - t0) / batch["input_ids"].shape[0]
327
+ preds: List[str] = self.ids_to_clean_text(generated_ids)
328
+ target: List[str] = self.ids_to_clean_text(batch["labels"])
329
+ loss_tensors = self._step(batch)
330
+ base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
331
+ rouge: Dict = self.calc_generative_metrics(preds, target)
332
+ summ_len = np.mean(lmap(len, generated_ids))
333
+ base_metrics.update(gen_time=gen_time, gen_len=summ_len, preds=preds, target=target, **rouge)
334
+
335
+ if dataloader_idx is not None:
336
+ base_metrics.update(batch_idx=batch_idx, dataloader_idx=dataloader_idx)
337
+ return base_metrics
338
+
339
+ def test_step(self, batch, batch_idx, dataloader_idx):
340
+ return self._generative_step(batch, batch_idx, dataloader_idx)
341
+
342
+ def test_epoch_end(self, outputs_all_testsets):
343
+
344
+ val_outputs_folder = "val_outputs"
345
+ os.makedirs(os.path.join(self.hparams.output_dir, val_outputs_folder), exist_ok=True)
346
+
347
+ for outputs in outputs_all_testsets:
348
+ dataset_idx = outputs[0]['dataloader_idx']
349
+
350
+ if dataset_idx == 0:
351
+ file_name = "test_both_predictions.txt"
352
+ file_name_tgt = "test_both_targets.txt"
353
+ dataset_name = 'test_both'
354
+ elif dataset_idx == 1:
355
+ file_name = "test_seen_predictions.txt"
356
+ file_name_tgt = "test_seen_targets.txt"
357
+ dataset_name = 'test_seen'
358
+ else:
359
+ file_name = "test_unseen_predictions.txt"
360
+ file_name_tgt = "test_unseen_targets.txt"
361
+ dataset_name = 'test_unseen'
362
+
363
+ file_name += '.debug'
364
+ file_name_tgt += '.debug'
365
+
366
+ output_test_predictions_file = os.path.join(self.hparams.output_dir, val_outputs_folder, file_name)
367
+ output_test_targets_file = os.path.join(self.hparams.output_dir, val_outputs_folder, file_name_tgt)
368
+ # write predictions and targets for later rouge evaluation.
369
+ with open(output_test_predictions_file, "w") as p_writer, open(output_test_targets_file, "w") as t_writer:
370
+ for output_batch in outputs:
371
+
372
+ p_writer.writelines(convert_text(s) + "\n" for s in output_batch["preds"])
373
+ t_writer.writelines(convert_text(s) + "\n" for s in output_batch["target"])
374
+ p_writer.close()
375
+ t_writer.close()
376
+
377
+ bleu_info = eval_bleu(self.hparams.data_dir, output_test_predictions_file, dataset_name)
378
+ meteor_info = eval_meteor_test_webnlg(self.hparams.data_dir, output_test_predictions_file, dataset_name)
379
+ chrf_info = eval_chrf_test_webnlg(self.hparams.data_dir, output_test_predictions_file, dataset_name)
380
+
381
+ rank_zero_info(" %s - bleu_info: %s", dataset_name, bleu_info)
382
+ rank_zero_info(" %s - meteor_info: %s", dataset_name, meteor_info)
383
+ rank_zero_info(" %s - chrf_info: %s", dataset_name, chrf_info)
384
+
385
+ outputs[0]['bleu'] = bleu_info
386
+
387
+ return self.validation_epoch_end(outputs_all_testsets, prefix="test")
388
+
389
+ def get_dataset(self, type_path) -> Seq2SeqDataset:
390
+ n_obs = self.n_obs[type_path]
391
+ max_target_length = self.target_lens[type_path]
392
+ dataset = self.dataset_class(
393
+ self.tokenizer,
394
+ type_path=type_path,
395
+ n_obs=n_obs,
396
+ max_target_length=max_target_length,
397
+ **self.dataset_kwargs,
398
+ )
399
+ return dataset
400
+
401
+ def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
402
+ dataset = self.get_dataset(type_path)
403
+
404
+ if self.hparams.sortish_sampler and type_path != "test":
405
+ sampler = dataset.make_sortish_sampler(batch_size, distributed=self.hparams.gpus > 1)
406
+ return DataLoader(
407
+ dataset,
408
+ batch_size=batch_size,
409
+ collate_fn=dataset.collate_fn,
410
+ shuffle=False,
411
+ num_workers=self.num_workers,
412
+ sampler=sampler,
413
+ )
414
+
415
+ elif self.hparams.max_tokens_per_batch is not None and type_path != "test":
416
+ batch_sampler = dataset.make_dynamic_sampler(
417
+ self.hparams.max_tokens_per_batch, distributed=self.hparams.gpus > 1
418
+ )
419
+ return DataLoader(
420
+ dataset,
421
+ batch_sampler=batch_sampler,
422
+ collate_fn=dataset.collate_fn,
423
+ # shuffle=False,
424
+ num_workers=self.num_workers,
425
+ # batch_size=None,
426
+ )
427
+ else:
428
+ return DataLoader(
429
+ dataset,
430
+ batch_size=batch_size,
431
+ collate_fn=dataset.collate_fn,
432
+ shuffle=shuffle,
433
+ num_workers=self.num_workers,
434
+ sampler=None,
435
+ )
436
+
437
+ def train_dataloader(self) -> DataLoader:
438
+ dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
439
+ return dataloader
440
+
441
+ def val_dataloader(self) -> DataLoader:
442
+ return self.get_dataloader("val", batch_size=self.hparams.eval_batch_size)
443
+
444
+ def test_dataloader(self) -> List[DataLoader]:
445
+ test_dataloader = self.get_dataloader("test_both", batch_size=self.hparams.eval_batch_size)
446
+ test_seen_dataloader = self.get_dataloader("test_seen", batch_size=self.hparams.eval_batch_size)
447
+ test_unseen_dataloader = self.get_dataloader("test_unseen", batch_size=self.hparams.eval_batch_size)
448
+
449
+ return [test_dataloader, test_seen_dataloader, test_unseen_dataloader]
450
+
451
+ @staticmethod
452
+ def add_model_specific_args(parser, root_dir):
453
+ BaseTransformer.add_model_specific_args(parser, root_dir)
454
+ add_generic_args(parser, root_dir)
455
+ parser.add_argument(
456
+ "--max_source_length",
457
+ default=1024,
458
+ type=int,
459
+ help="The maximum total input sequence length after tokenization. Sequences longer "
460
+ "than this will be truncated, sequences shorter will be padded.",
461
+ )
462
+ parser.add_argument(
463
+ "--max_target_length",
464
+ default=56,
465
+ type=int,
466
+ help="The maximum total input sequence length after tokenization. Sequences longer "
467
+ "than this will be truncated, sequences shorter will be padded.",
468
+ )
469
+ parser.add_argument(
470
+ "--val_max_target_length",
471
+ default=142, # these defaults are optimized for CNNDM. For xsum, see README.md.
472
+ type=int,
473
+ help="The maximum total input sequence length after tokenization. Sequences longer "
474
+ "than this will be truncated, sequences shorter will be padded.",
475
+ )
476
+ parser.add_argument(
477
+ "--test_max_target_length",
478
+ default=142,
479
+ type=int,
480
+ help="The maximum total input sequence length after tokenization. Sequences longer "
481
+ "than this will be truncated, sequences shorter will be padded.",
482
+ )
483
+ parser.add_argument("--freeze_encoder", action="store_true")
484
+ parser.add_argument("--freeze_embeds", action="store_true")
485
+ parser.add_argument("--sortish_sampler", action="store_true", default=False)
486
+ parser.add_argument("--max_tokens_per_batch", type=int, default=None)
487
+ parser.add_argument("--logger_name", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
488
+ parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
489
+ parser.add_argument("--n_val", type=int, default=-1, required=False, help="# examples. -1 means use all.")
490
+ parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.")
491
+ parser.add_argument(
492
+ "--task", type=str, default="summarization", required=False, help="# examples. -1 means use all."
493
+ )
494
+ parser.add_argument("--label_smoothing", type=float, default=0.0, required=False)
495
+ parser.add_argument("--src_lang", type=str, default="", required=False)
496
+ parser.add_argument("--tgt_lang", type=str, default="", required=False)
497
+ parser.add_argument("--eval_beams", type=int, default=None, required=False)
498
+ parser.add_argument("--checkpoint", type=str, default=None, required=False)
499
+ parser.add_argument(
500
+ "--val_metric", type=str, default=None, required=False, choices=["bleu", "rouge2", "loss", None]
501
+ )
502
+ parser.add_argument("--eval_max_gen_length", type=int, default=None, help="never generate more than n tokens")
503
+ parser.add_argument("--save_top_k", type=int, default=1, required=False, help="How many checkpoints to save")
504
+ parser.add_argument(
505
+ "--early_stopping_patience",
506
+ type=int,
507
+ default=-1,
508
+ required=False,
509
+ help="-1 means never early stop. early_stopping_patience is measured in validation checks, not epochs. So val_check_interval will effect it.",
510
+ )
511
+
512
+ return parser
513
+
514
+
515
+ class TranslationModule(SummarizationModule):
516
+ mode = "translation"
517
+ loss_names = ["loss"]
518
+ metric_names = ["bleu"]
519
+ default_val_metric = "bleu"
520
+
521
+ def __init__(self, hparams, **kwargs):
522
+ super().__init__(hparams, **kwargs)
523
+ self.dataset_kwargs["src_lang"] = hparams.src_lang
524
+ self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang
525
+
526
+ def calc_generative_metrics(self, preds, target) -> dict:
527
+ return calculate_bleu(preds, target)
528
+
529
+
530
+ class Graph2TextModule(SummarizationModule):
531
+ mode = "graph2text"
532
+ loss_names = ["loss"]
533
+ metric_names = ["sacrebleu"]
534
+ default_val_metric = "bleu"
535
+
536
+ def __init__(self, hparams, **kwargs):
537
+ if isinstance(hparams, dict):
538
+ hparams = argparse.Namespace(**hparams)
539
+ print(f'Graph2Text hparams are: {hparams}')
540
+ super().__init__(hparams, **kwargs)
541
+
542
+ self.hparams.update(vars(hparams))
543
+
544
+ rank_zero_info("parameters %s", hparams)
545
+
546
+ def calc_generative_metrics(self, preds, target) -> dict:
547
+ return calculate_bleu(preds, target)
548
+
549
+
550
+ def main(args, model=None) -> SummarizationModule:
551
+ Path(args.output_dir).mkdir(exist_ok=True)
552
+ if len(os.listdir(args.output_dir)) > 3 and args.do_train:
553
+ raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
554
+ if model is None:
555
+ if "summarization" in args.task:
556
+ model: SummarizationModule = SummarizationModule(args)
557
+ elif "translation" in args.task:
558
+ model: SummarizationModule = TranslationModule(args)
559
+ else:
560
+ model: SummarizationModule = Graph2TextModule(args)
561
+ dataset = Path(args.data_dir).name
562
+ if (
563
+ args.logger_name == "default"
564
+ or args.fast_dev_run
565
+ or str(args.output_dir).startswith("/tmp")
566
+ or str(args.output_dir).startswith("/var")
567
+ ):
568
+ logger = True # don't pollute wandb logs unnecessarily
569
+ elif args.logger_name == "wandb":
570
+ from pytorch_lightning.loggers import WandbLogger
571
+
572
+ project = os.environ.get("WANDB_PROJECT", dataset)
573
+ logger = WandbLogger(name=model.output_dir.name, project=project)
574
+
575
+ elif args.logger_name == "wandb_shared":
576
+ from pytorch_lightning.loggers import WandbLogger
577
+
578
+ logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")
579
+
580
+ if args.early_stopping_patience >= 0:
581
+ es_callback = get_early_stopping_callback(model.val_metric, args.early_stopping_patience)
582
+ else:
583
+ es_callback = False
584
+
585
+ lower_is_better = args.val_metric == "loss"
586
+ trainer: pl.Trainer = generic_train(
587
+ model,
588
+ args,
589
+ logging_callback=Seq2SeqLoggingCallback(),
590
+ checkpoint_callback=get_checkpoint_callback(
591
+ args.output_dir, model.val_metric, args.save_top_k, lower_is_better
592
+ ),
593
+ early_stopping_callback=es_callback,
594
+ logger=logger,
595
+ )
596
+ pickle_save(model.hparams, model.output_dir / "hparams.pkl")
597
+ if not args.do_predict:
598
+ return model
599
+
600
+ model.hparams.test_checkpoint = ""
601
+ if not args.checkpoint:
602
+ checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True)))
603
+ else:
604
+ checkpoints = [args.checkpoint]
605
+
606
+ if checkpoints:
607
+ model.hparams.test_checkpoint = checkpoints[-1]
608
+ trainer.resume_from_checkpoint = checkpoints[-1]
609
+
610
+ if args.do_predict and not args.do_train:
611
+
612
+ checkpoint = checkpoints[-1]
613
+ print(checkpoint)
614
+ #trainer.test(ckpt_path=checkpoints[-1])
615
+ trainer.test(model, ckpt_path=checkpoint)
616
+ return model
617
+
618
+
619
+ trainer.logger.log_hyperparams(model.hparams)
620
+
621
+ # test() without a model tests using the best checkpoint automatically
622
+ trainer.test()
623
+ return model
624
+
625
+
626
+ if __name__ == "__main__":
627
+ parser = argparse.ArgumentParser()
628
+ parser = pl.Trainer.add_argparse_args(parser)
629
+ parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
630
+
631
+ args = parser.parse_args()
632
+
633
+ main(args)
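A minimal invocation sketch for the entry point above (illustrative only, not part of the uploaded files): the data directory, output directory and model name are placeholders, and the snippet assumes it lives in this same module so that argparse, os, pl, SummarizationModule and main are already in scope.

import sys

sys.argv = [
    "finetune.py",
    "--data_dir", "data/graph2text",      # placeholder: dir with train/val/test .source and .target files
    "--output_dir", "outputs/run1",       # placeholder: must be (nearly) empty when --do_train is set
    "--model_name_or_path", "t5-base",    # placeholder model identifier
    "--task", "graph2text",
    "--gpus", "1",
    "--do_train",
]
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
main(parser.parse_args())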
utils/lightning_base.py ADDED
@@ -0,0 +1,418 @@
1
+ import argparse
2
+ import logging
3
+ import os
4
+ from pathlib import Path
5
+ from typing import Any, Dict
6
+ import sys
7
+ import pytorch_lightning as pl
8
+ from pytorch_lightning.utilities import rank_zero_info
9
+ from pytorch_lightning.callbacks import LearningRateMonitor
10
+
11
+ from transformers import (
12
+ AdamW,
13
+ AutoConfig,
14
+ AutoModel,
15
+ AutoModelForPreTraining,
16
+ AutoModelForQuestionAnswering,
17
+ AutoModelForSeq2SeqLM,
18
+ AutoModelForSequenceClassification,
19
+ AutoModelForTokenClassification,
20
+ AutoModelWithLMHead,
21
+ AutoTokenizer,
22
+ PretrainedConfig,
23
+ PreTrainedTokenizer,
24
+ )
25
+ from transformers.optimization import (
26
+ Adafactor,
27
+ get_cosine_schedule_with_warmup,
28
+ get_cosine_with_hard_restarts_schedule_with_warmup,
29
+ get_linear_schedule_with_warmup,
30
+ get_polynomial_decay_schedule_with_warmup,
31
+ )
32
+
33
+ from tokenizers import AddedToken
34
+
35
+ logger = logging.getLogger(__name__)
36
+
37
+ MODEL_MODES = {
38
+ "base": AutoModel,
39
+ "sequence-classification": AutoModelForSequenceClassification,
40
+ "question-answering": AutoModelForQuestionAnswering,
41
+ "pretraining": AutoModelForPreTraining,
42
+ "token-classification": AutoModelForTokenClassification,
43
+ "language-modeling": AutoModelWithLMHead,
44
+ "summarization": AutoModelForSeq2SeqLM,
45
+ "translation": AutoModelForSeq2SeqLM,
46
+ "graph2text": AutoModelForSeq2SeqLM,
47
+ }
48
+
49
+
50
+ # update this and the import above to support new schedulers from transformers.optimization
51
+ arg_to_scheduler = {
52
+ "linear": get_linear_schedule_with_warmup,
53
+ "cosine": get_cosine_schedule_with_warmup,
54
+ "cosine_w_restarts": get_cosine_with_hard_restarts_schedule_with_warmup,
55
+ "polynomial": get_polynomial_decay_schedule_with_warmup,
56
+ # '': get_constant_schedule, # not supported for now
57
+ # '': get_constant_schedule_with_warmup, # not supported for now
58
+ }
59
+ arg_to_scheduler_choices = sorted(arg_to_scheduler.keys())
60
+ arg_to_scheduler_metavar = "{" + ", ".join(arg_to_scheduler_choices) + "}"
61
+
62
+
63
+ class BaseTransformer(pl.LightningModule):
64
+ def __init__(
65
+ self,
66
+ hparams: argparse.Namespace,
67
+ num_labels=None,
68
+ mode="base",
69
+ config=None,
70
+ tokenizer=None,
71
+ model=None,
72
+ **config_kwargs
73
+ ):
74
+ """Initialize a model, tokenizer and config."""
75
+ super().__init__()
76
+ # TODO: move to self.save_hyperparameters()
77
+ # self.save_hyperparameters()
78
+ # can also expand arguments into trainer signature for easier reading
79
+ self.save_hyperparameters(hparams)
80
+ self.step_count = -2
81
+ self.output_dir = Path(self.hparams.output_dir)
82
+ cache_dir = self.hparams.cache_dir if self.hparams.cache_dir else None
83
+ if config is None:
84
+ self.config = AutoConfig.from_pretrained(
85
+ self.hparams.config_name if self.hparams.config_name else self.hparams.model_name_or_path,
86
+ **({"num_labels": num_labels} if num_labels is not None else {}),
87
+ cache_dir=cache_dir,
88
+ **config_kwargs,
89
+ )
90
+ else:
91
+ self.config: PretrainedConfig = config
92
+
93
+ extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "dropout", "attention_dropout")
94
+ for p in extra_model_params:
95
+ if getattr(self.hparams, p, None):
96
+ assert hasattr(self.config, p), f"model config doesn't have a `{p}` attribute"
97
+ setattr(self.config, p, getattr(self.hparams, p))
98
+
99
+ if tokenizer is None:
100
+ self.tokenizer = AutoTokenizer.from_pretrained(
101
+ self.hparams.tokenizer_name if self.hparams.tokenizer_name else self.hparams.model_name_or_path,
102
+ cache_dir=cache_dir,
103
+ )
104
+ new_tokens = [
105
+ '<H>','<R>','<T>'
106
+ ]
107
+ new_tokens_vocab = {}
108
+ new_tokens_vocab['additional_special_tokens'] = []
109
+ for idx, t in enumerate(new_tokens):
110
+ new_tokens_vocab['additional_special_tokens'].append(t)
111
+ num_added_toks = self.tokenizer.add_special_tokens(new_tokens_vocab)
112
+ rank_zero_info('We have added %s tokens', num_added_toks)
113
+ else:
114
+ self.tokenizer: PreTrainedTokenizer = tokenizer
115
+ self.model_type = MODEL_MODES[mode]
116
+ if model is None:
117
+ self.model = self.model_type.from_pretrained(
118
+ self.hparams.model_name_or_path,
119
+ from_tf=bool(".ckpt" in self.hparams.model_name_or_path),
120
+ config=self.config,
121
+ cache_dir=cache_dir,
122
+ )
123
+ self.model.resize_token_embeddings(len(self.tokenizer))
124
+ else:
125
+ self.model = model
126
+
127
+ def load_hf_checkpoint(self, *args, **kwargs):
128
+ self.model = self.model_type.from_pretrained(*args, **kwargs)
129
+
130
+ def get_lr_scheduler(self):
131
+ get_schedule_func = arg_to_scheduler[self.hparams.lr_scheduler]
132
+ scheduler = get_schedule_func(
133
+ self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=self.total_steps
134
+ )
135
+ scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
136
+ return scheduler
137
+
138
+ def configure_optimizers(self):
139
+ """Prepare optimizer and schedule (linear warmup and decay)"""
140
+ model = self.model
141
+ no_decay = ["bias", "LayerNorm.weight"]
142
+ optimizer_grouped_parameters = [
143
+ {
144
+ "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
145
+ "weight_decay": self.hparams.weight_decay,
146
+ },
147
+ {
148
+ "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
149
+ "weight_decay": 0.0,
150
+ },
151
+ ]
152
+ if self.hparams.adafactor:
153
+ optimizer = Adafactor(
154
+ optimizer_grouped_parameters, lr=self.hparams.learning_rate, scale_parameter=False, relative_step=False
155
+ )
156
+
157
+ else:
158
+ optimizer = AdamW(
159
+ optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon
160
+ )
161
+ self.opt = optimizer
162
+
163
+ scheduler = self.get_lr_scheduler()
164
+
165
+ return [optimizer], [scheduler]
166
+
167
+
168
+ def test_step(self, batch, batch_nb):
169
+ return self.validation_step(batch, batch_nb)
170
+
171
+ def test_epoch_end(self, outputs):
172
+ return self.validation_end(outputs)
173
+
174
+ @property
175
+ def total_steps(self) -> int:
176
+ # print('self.hparams.gpus', self.hparams.gpus)
177
+ # print('self.hparams.accumulate_grad_batches', self.hparams.accumulate_grad_batches)
178
+ # print('self.train_loader.dataset', self.train_loader.dataset)
179
+ # print('self.hparams.max_epochs', self.hparams.max_epochs)
180
+ # print('self.hparams.train_batch_size', self.hparams.train_batch_size)
181
+ # exit()
182
+ """The number of total training steps that will be run. Used for lr scheduler purposes."""
183
+ num_devices = max(1, self.hparams.gpus) # TODO: consider num_tpu_cores
184
+ effective_batch_size = self.hparams.train_batch_size * self.hparams.accumulate_grad_batches * num_devices
185
+ dataset_size = len(self.train_loader.dataset)
186
+ return int((dataset_size / effective_batch_size) * self.hparams.max_epochs)
187
+
188
+ def setup(self, mode):
189
+ #if mode == "fit":
190
+ self.train_loader = self.get_dataloader("train", self.hparams.train_batch_size, shuffle=True)
191
+
192
+ def get_dataloader(self, type_path, batch_size, shuffle=False):
193
+ raise NotImplementedError("You must implement this for your task")
194
+
195
+ def train_dataloader(self):
196
+ return self.train_loader
197
+
198
+ def val_dataloader(self):
199
+ return self.get_dataloader("dev", self.hparams.eval_batch_size, shuffle=False)
200
+
201
+ def test_dataloader(self):
202
+ return self.get_dataloader("test", self.hparams.eval_batch_size, shuffle=False)
203
+
204
+ def _feature_file(self, mode):
205
+ return os.path.join(
206
+ self.hparams.data_dir,
207
+ "cached_{}_{}_{}".format(
208
+ mode,
209
+ list(filter(None, self.hparams.model_name_or_path.split("/"))).pop(),
210
+ str(self.hparams.max_seq_length),
211
+ ),
212
+ )
213
+
214
+ def get_progress_bar_dict(self):
215
+ #metrics = self.trainer.callback_metrics
216
+ #print(self.trainer.lr_logger.lrs)
217
+ lrs = self.trainer.lr_logger.lrs['lr-AdamW/pg1'][-1]
218
+ running_train_loss = self.trainer.running_loss.mean()
219
+ avg_training_loss = running_train_loss.cpu().item() if running_train_loss is not None else float('NaN')
220
+ tqdm_dict = {"loss": "{:.3f}".format(avg_training_loss), "lr": lrs}
221
+ return tqdm_dict
222
+
223
+ @pl.utilities.rank_zero_only
224
+ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
225
+ save_path = self.output_dir.joinpath("best_tfmr")
226
+ self.model.config.save_step = self.step_count
227
+ self.model.save_pretrained(save_path)
228
+ self.tokenizer.save_pretrained(save_path)
229
+
230
+ @staticmethod
231
+ def add_model_specific_args(parser, root_dir):
232
+ parser.add_argument(
233
+ "--model_name_or_path",
234
+ default=None,
235
+ type=str,
236
+ required=True,
237
+ help="Path to pretrained model or model identifier from huggingface.co/models",
238
+ )
239
+ parser.add_argument(
240
+ "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
241
+ )
242
+ parser.add_argument(
243
+ "--tokenizer_name",
244
+ default=None,
245
+ type=str,
246
+ help="Pretrained tokenizer name or path if not the same as model_name",
247
+ )
248
+ parser.add_argument(
249
+ "--cache_dir",
250
+ default="",
251
+ type=str,
252
+ help="Where do you want to store the pre-trained models downloaded from s3",
253
+ )
254
+ parser.add_argument(
255
+ "--encoder_layerdrop",
256
+ type=float,
257
+ help="Encoder layer dropout probability (Optional). Goes into model.config",
258
+ )
259
+ parser.add_argument(
260
+ "--decoder_layerdrop",
261
+ type=float,
262
+ help="Decoder layer dropout probability (Optional). Goes into model.config",
263
+ )
264
+ parser.add_argument(
265
+ "--dropout",
266
+ type=float,
267
+ help="Dropout probability (Optional). Goes into model.config",
268
+ )
269
+ parser.add_argument(
270
+ "--attention_dropout",
271
+ type=float,
272
+ help="Attention dropout probability (Optional). Goes into model.config",
273
+ )
274
+ parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
275
+ parser.add_argument(
276
+ "--lr_scheduler",
277
+ default="linear",
278
+ choices=arg_to_scheduler_choices,
279
+ metavar=arg_to_scheduler_metavar,
280
+ type=str,
281
+ help="Learning rate scheduler",
282
+ )
283
+ parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
284
+ parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
285
+ parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
286
+ parser.add_argument("--num_workers", default=4, type=int, help="kwarg passed to DataLoader")
287
+ parser.add_argument("--num_train_epochs", dest="max_epochs", default=3, type=int)
288
+ parser.add_argument("--train_batch_size", default=32, type=int)
289
+ parser.add_argument("--eval_batch_size", default=32, type=int)
290
+ parser.add_argument("--adafactor", action="store_true")
291
+
292
+
293
+ class LoggingCallback(pl.Callback):
294
+ def on_batch_end(self, trainer, pl_module):
295
+ lr_scheduler = trainer.lr_schedulers[0]["scheduler"]
296
+ lrs = {f"lr_group_{i}": lr for i, lr in enumerate(lr_scheduler.get_lr())}
297
+ pl_module.logger.log_metrics(lrs)
298
+
299
+ def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
300
+ rank_zero_info("***** Validation results *****")
301
+ metrics = trainer.callback_metrics
302
+ rank_zero_info(trainer.logger)
303
+ # Log results
304
+ for key in sorted(metrics):
305
+ if key not in ["log", "progress_bar"]:
306
+ rank_zero_info("{} = {}\n".format(key, str(metrics[key])))
307
+
308
+ def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
309
+ rank_zero_info("***** Test results *****")
310
+ metrics = trainer.callback_metrics
311
+ # Log and save results to file
312
+ output_test_results_file = os.path.join(pl_module.hparams.output_dir, "test_results.txt")
313
+ with open(output_test_results_file, "w") as writer:
314
+ for key in sorted(metrics):
315
+ if key not in ["log", "progress_bar"]:
316
+ rank_zero_info("{} = {}\n".format(key, str(metrics[key])))
317
+ writer.write("{} = {}\n".format(key, str(metrics[key])))
318
+
319
+
320
+ def add_generic_args(parser, root_dir) -> None:
321
+ # TODO(SS): allow all pl args? parser = pl.Trainer.add_argparse_args(parser)
322
+ parser.add_argument(
323
+ "--output_dir",
324
+ default=None,
325
+ type=str,
326
+ required=True,
327
+ help="The output directory where the model predictions and checkpoints will be written.",
328
+ )
329
+ parser.add_argument(
330
+ "--fp16",
331
+ action="store_true",
332
+ help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
333
+ )
334
+
335
+ parser.add_argument(
336
+ "--fp16_opt_level",
337
+ type=str,
338
+ default="O2",
339
+ help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
340
+ "See details at https://nvidia.github.io/apex/amp.html",
341
+ )
342
+ parser.add_argument("--n_tpu_cores", dest="tpu_cores", type=int)
343
+ parser.add_argument("--max_grad_norm", dest="gradient_clip_val", default=1.0, type=float, help="Max gradient norm")
344
+ parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
345
+ parser.add_argument("--do_predict", action="store_true", help="Whether to run predictions on the test set.")
346
+ parser.add_argument(
347
+ "--gradient_accumulation_steps",
348
+ dest="accumulate_grad_batches",
349
+ type=int,
350
+ default=1,
351
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
352
+ )
353
+ parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
354
+ parser.add_argument(
355
+ "--data_dir",
356
+ default=None,
357
+ type=str,
358
+ required=True,
359
+ help="The input data dir. Should contain the training files for the CoNLL-2003 NER task.",
360
+ )
361
+
362
+
363
+ def generic_train(
364
+ model: BaseTransformer,
365
+ args: argparse.Namespace,
366
+ early_stopping_callback=False,
367
+ logger=True, # can pass WandbLogger() here
368
+ extra_callbacks=[],
369
+ checkpoint_callback=None,
370
+ logging_callback=None,
371
+ **extra_train_kwargs
372
+ ):
373
+ pl.seed_everything(args.seed)
374
+
375
+ # init model
376
+ odir = Path(model.hparams.output_dir)
377
+ odir.mkdir(exist_ok=True)
378
+
379
+ # add custom checkpoints
380
+ if checkpoint_callback is None:
381
+ checkpoint_callback = pl.callbacks.ModelCheckpoint(
382
+ filepath=args.output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=1
383
+ )
384
+ if logging_callback is None:
385
+ logging_callback = LoggingCallback()
386
+
387
+ train_params = {}
388
+
389
+ # TODO: remove with PyTorch 1.6 since pl uses native amp
390
+ if args.fp16:
391
+ train_params["precision"] = 16
392
+ train_params["amp_level"] = args.fp16_opt_level
393
+
394
+ if args.gpus > 1:
395
+ train_params["distributed_backend"] = "ddp"
396
+
397
+ train_params["accumulate_grad_batches"] = args.accumulate_grad_batches
398
+
399
+ lr_logger = LearningRateMonitor(logging_interval='step')
400
+
401
+ # deterministic=True,
402
+ trainer = pl.Trainer.from_argparse_args(
403
+ args,
404
+ weights_summary='full',
405
+ callbacks=[logging_callback, lr_logger],
406
+ logger=logger,
407
+ checkpoint_callback=checkpoint_callback,
408
+ early_stop_callback=early_stopping_callback,
409
+ num_sanity_val_steps=4,
410
+ **train_params,
411
+ )
412
+
413
+ trainer.lr_logger = lr_logger
414
+
415
+ if args.do_train:
416
+ trainer.fit(model)
417
+
418
+ return trainer
utils/sentence_retrieval_model.py ADDED
@@ -0,0 +1,20 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from utils.bert_model import BertForSequenceEncoder
5
+
6
+ class sentence_retrieval_model(nn.Module):
7
+ def __init__(self, args):
8
+ super(sentence_retrieval_model, self).__init__()
9
+ self.pred_model = BertForSequenceEncoder.from_pretrained(args['bert_pretrain'])
10
+ self.bert_hidden_dim = args['bert_hidden_dim']
11
+ self.dropout = nn.Dropout(args['dropout'])
12
+ self.proj_match = nn.Linear(self.bert_hidden_dim, 1)
13
+
14
+
15
+ def forward(self, inp_tensor, msk_tensor, seg_tensor):
16
+ _, inputs = self.pred_model(inp_tensor, msk_tensor, seg_tensor)
17
+ inputs = self.dropout(inputs)
18
+ score = self.proj_match(inputs).squeeze(-1)
19
+ score = torch.tanh(score)
20
+ return score
utils/sentence_retrieval_module.py ADDED
@@ -0,0 +1,77 @@
1
+ import re
2
+ from typing import List, Tuple
3
+ import pathlib
4
+
5
+ import torch
6
+ from transformers import BertTokenizer
7
+
8
+ from utils.sentence_retrieval_model import sentence_retrieval_model
9
+
10
+
11
+ THIS_DIR = pathlib.Path(__file__).parent.absolute()
12
+ ARGS = {
13
+ 'batch_size': 32,
14
+ 'bert_pretrain': 'base/bert_base',
15
+ 'checkpoint': 'base/model.best.32.pt',
16
+ 'dropout': 0.6,
17
+ 'bert_hidden_dim': 768,
18
+ 'max_len': 384,
19
+ 'cuda': torch.cuda.is_available()
20
+ }
21
+
22
+ if not ARGS['cuda']:
23
+ print('CUDA NOT AVAILABLE')
24
+
25
+
26
+ def process_sent(sentence):
27
+ sentence = re.sub("LSB.*?RSB", "", sentence)
28
+ sentence = re.sub(r"LRB\s*?RRB", "", sentence)
29
+ sentence = re.sub(r"(\s*?)LRB((\s*?))", r"\1(\2", sentence)
30
+ sentence = re.sub(r"(\s*?)RRB((\s*?))", r"\1)\2", sentence)
31
+ sentence = re.sub("--", "-", sentence)
32
+ sentence = re.sub("``", '"', sentence)
33
+ sentence = re.sub("''", '"', sentence)
34
+ return sentence
35
+
36
+ class SentenceRetrievalModule():
37
+
38
+ def __init__(self, max_len=None):
39
+
40
+ if max_len:
41
+ ARGS['max_len'] = max_len
42
+
43
+ self.tokenizer = BertTokenizer.from_pretrained(ARGS['bert_pretrain'], do_lower_case=False)
44
+ self.model = sentence_retrieval_model(ARGS)
45
+ self.model.load_state_dict(torch.load(ARGS['checkpoint'], map_location=torch.device('cpu'))['model'])
46
+ if ARGS['cuda']:
47
+ self.model = self.model.cuda()
48
+
49
+ def score_sentence_pairs(self, inputs: List[Tuple[str]]):
50
+ inputs_processed = [(process_sent(input[0]), process_sent(input[1])) for input in inputs]
51
+
52
+ encodings = self.tokenizer(
53
+ inputs_processed,
54
+ padding='max_length',
55
+ truncation='longest_first',
56
+ max_length=ARGS['max_len'],
57
+ return_token_type_ids=True,
58
+ return_attention_mask=True,
59
+ return_tensors='pt',
60
+ )
61
+
62
+ inp = encodings['input_ids']
63
+ msk = encodings['attention_mask']
64
+ seg = encodings['token_type_ids']
65
+
66
+ if ARGS['cuda']:
67
+ inp = inp.cuda()
68
+ msk = msk.cuda()
69
+ seg = seg.cuda()
70
+
71
+ self.model.eval()
72
+ with torch.no_grad():
73
+ outputs = self.model(inp, msk, seg).tolist()
74
+
75
+ assert len(outputs) == len(inputs)
76
+
77
+ return outputs
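A brief usage sketch (illustrative, not part of the upload): SentenceRetrievalModule expects the BERT checkpoint directory and fine-tuned weights referenced in ARGS ('base/bert_base' and 'base/model.best.32.pt') to exist locally; the claim and evidence strings below are placeholders.

from utils.sentence_retrieval_module import SentenceRetrievalModule

retriever = SentenceRetrievalModule(max_len=384)
pairs = [
    ("Example claim about an entity.", "A candidate evidence sentence."),
    ("Example claim about an entity.", "Another candidate evidence sentence."),
]
scores = retriever.score_sentence_pairs(pairs)  # one tanh-squashed relevance score per (claim, evidence) pair
print(scores)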
utils/textual_entailment_module.py ADDED
@@ -0,0 +1,94 @@
1
+ import json
2
+ import numpy as np
3
+ import pandas as pd
4
+ from pathlib import Path
5
+ import torch
6
+ import re
7
+
8
+ from transformers import BertTokenizer, BertForSequenceClassification
9
+
10
+ # Constants and paths
11
+ HOME = Path('/users/k2031554')
12
+ DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
13
+ MAX_LEN = 512
14
+ CLASSES = ['SUPPORTS','REFUTES','NOT ENOUGH INFO']
15
+ METHODS = ['WEIGHTED_SUM', 'MALON']
16
+
17
+ def process_sent(sentence):
18
+ sentence = re.sub("LSB.*?RSB", "", sentence)
19
+ sentence = re.sub(r"LRB\s*?RRB", "", sentence)
20
+ sentence = re.sub(r"(\s*?)LRB((\s*?))", r"\1(\2", sentence)
21
+ sentence = re.sub(r"(\s*?)RRB((\s*?))", r"\1)\2", sentence)
22
+ sentence = re.sub("--", "-", sentence)
23
+ sentence = re.sub("``", '"', sentence)
24
+ sentence = re.sub("''", '"', sentence)
25
+ return sentence
26
+
27
+ class TextualEntailmentModule():
28
+
29
+ def __init__(
30
+ self,
31
+ model_path = 'base/models/BERT_FEVER_v4_model_PBT',
32
+ tokenizer_path = 'base/models/BERT_FEVER_v4_tok_PBT'
33
+ ):
34
+ self.tokenizer = BertTokenizer.from_pretrained(
35
+ tokenizer_path
36
+ )
37
+ self.model = BertForSequenceClassification.from_pretrained(
38
+ model_path
39
+ )
40
+ self.model.to(DEVICE)
41
+
42
+ #def get_pair_scores(self, claim, evidence):
43
+ #
44
+ # encodings = self.tokenizer(
45
+ # [claim, evidence],
46
+ # max_length= MAX_LEN,
47
+ # return_token_type_ids=False,
48
+ # padding='max_length',
49
+ # truncation=True,
50
+ # return_tensors='pt',
51
+ # ).to(DEVICE)
52
+ #
53
+ # self.model.eval()
54
+ # with torch.no_grad():
55
+ # probs = self.model(
56
+ # input_ids=encodings['input_ids'],
57
+ # attention_mask=encodings['attention_mask']
58
+ # )
59
+ #
60
+ # return torch.softmax(probs.logits,dim=1).cpu().numpy()
61
+
62
+ def get_batch_scores(self, claims, evidence):
63
+
64
+ inputs = list(zip(claims, evidence))
65
+
66
+ encodings = self.tokenizer(
67
+ inputs,
68
+ max_length= MAX_LEN,
69
+ return_token_type_ids=False,
70
+ padding='max_length',
71
+ truncation=True,
72
+ return_tensors='pt',
73
+ ).to(DEVICE)
74
+
75
+ self.model.eval()
76
+ with torch.no_grad():
77
+ probs = self.model(
78
+ input_ids=encodings['input_ids'],
79
+ attention_mask=encodings['attention_mask']
80
+ )
81
+
82
+ return torch.softmax(probs.logits,dim=1).cpu().numpy()
83
+
84
+ def get_label_from_scores(self, scores):
85
+ return CLASSES[np.argmax(scores)]
86
+
87
+ def get_label_malon(self, score_set):
88
+ score_labels = [np.argmax(s) for s in score_set]
89
+ if 1 not in score_labels and 0 not in score_labels:
90
+ return CLASSES[2] #NOT ENOUGH INFO
91
+ elif 0 in score_labels:
92
+ return CLASSES[0] #SUPPORTS
93
+ elif 1 in score_labels:
94
+ return CLASSES[1] #REFUTES
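A brief usage sketch (illustrative, not part of the upload): the default model_path and tokenizer_path point to local checkpoint directories under base/models, so the example assumes those exist; the claim and evidence strings are placeholders.

from utils.textual_entailment_module import TextualEntailmentModule

te = TextualEntailmentModule()
claims = ["Example claim."]
evidence = ["An evidence sentence retrieved for the claim."]
scores = te.get_batch_scores(claims, evidence)   # softmax scores over SUPPORTS / REFUTES / NOT ENOUGH INFO
print(te.get_label_from_scores(scores[0]))       # e.g. 'SUPPORTS'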
utils/utils_graph2text.py ADDED
@@ -0,0 +1,114 @@
1
+ import re
2
+ import os
3
+
4
+ def convert_text(text):
5
+ #return text
6
+ text = text.lower()
7
+ text = ' '.join(re.split('(\W)', text))
8
+ text = ' '.join(text.split())
9
+ return text
10
+
11
+ def eval_meteor_test_webnlg(folder_data, pred_file, dataset):
12
+
13
+ dir_path = os.path.dirname(os.path.realpath(__file__))
14
+ folder_data_before = dir_path + "/../utils"
15
+
16
+ cmd_string = "java -jar " + folder_data_before + "/meteor-1.5.jar " + pred_file + " " \
17
+ + folder_data + "/" + dataset + ".target_eval_meteor -l en -norm -r 3 > " + pred_file.replace("txt", "meteor")
18
+
19
+ os.system(cmd_string)
20
+
21
+ meteor_info = open(pred_file.replace("txt", "meteor"), 'r').readlines()[-1].strip()
22
+
23
+ return meteor_info
24
+
25
+
26
+ def eval_chrf_test_webnlg(folder_data, pred_file, dataset):
27
+
28
+ dir_path = os.path.dirname(os.path.realpath(__file__))
29
+ folder_data_before = dir_path + "/../utils"
30
+
31
+ cmd_string = "python " + folder_data_before + "/chrf++.py -H " + pred_file + " -R " \
32
+ + folder_data + "/" + dataset + ".target_eval_crf > " + pred_file.replace("txt", "chrf")
33
+
34
+ os.system(cmd_string)
35
+
36
+ chrf_info_1 = open(pred_file.replace("txt", "chrf"), 'r').readlines()[1].strip()
37
+ chrf_info_2 = open(pred_file.replace("txt", "chrf"), 'r').readlines()[2].strip()
38
+
39
+ return chrf_info_1 + " " + chrf_info_2
40
+
41
+ def eval_bleu(folder_data, pred_file, dataset):
42
+
43
+ dir_path = os.path.dirname(os.path.realpath(__file__))
44
+ folder_data_before = dir_path + "/data/"
45
+
46
+ cmd_string = "perl " + folder_data_before + "/multi-bleu.perl -lc " + folder_data + "/" + dataset + ".target_eval " \
47
+ + folder_data + "/" + dataset + ".target2_eval " + folder_data + "/" + dataset + ".target3_eval < " \
48
+ + pred_file + " > " + pred_file.replace("txt", "bleu")
49
+
50
+ os.system(cmd_string)
51
+
52
+ try:
53
+ bleu_info = open(pred_file.replace("txt", "bleu"), 'r').readlines()[0].strip()
54
+ except Exception:
55
+ bleu_info = -1
56
+
57
+ return bleu_info
58
+
59
+
60
+ def eval_bleu_sents_tok(pred_file, folder_data, dataset):
61
+
62
+ dir_path = os.path.dirname(os.path.realpath(__file__))
63
+ folder_data_before = dir_path + "/../utils"
64
+
65
+ cmd_string = "perl " + folder_data_before + "/tokenizer.perl -threads 4 -no-escape < " + pred_file + " > " +\
66
+ pred_file + "_tok"
67
+ os.system(cmd_string)
68
+
69
+ cmd_string = "perl " + folder_data_before + "/multi-bleu.perl -lc " + folder_data + "/" + dataset + ".target.tok"\
70
+ + " < " + pred_file + "_tok" + " > " + pred_file.replace("txt", "bleu_data")
71
+ os.system(cmd_string)
72
+
73
+ try:
74
+ bleu_info_data = open(pred_file.replace("txt", "bleu_data"), 'r').readlines()[0].strip()
75
+ except Exception:
76
+ bleu_info_data = 'no data'
77
+
78
+ return bleu_info_data
79
+
80
+
81
+ def eval_meteor(ref_file, pred_file):
82
+
83
+ dir_path = os.path.dirname(os.path.realpath(__file__))
84
+ folder_data_before = dir_path + "/../utils"
85
+
86
+ cmd_string = "java -jar " + folder_data_before + "/meteor-1.5.jar " + pred_file + " " \
87
+ + ref_file + " > " + pred_file.replace("txt", "meteor")
88
+
89
+ os.system(cmd_string)
90
+
91
+ meteor_info = open(pred_file.replace("txt", "meteor"), 'r').readlines()[-1].strip()
92
+
93
+ return meteor_info
94
+
95
+
96
+ def eval_chrf(ref_file, pred_file):
97
+
98
+ dir_path = os.path.dirname(os.path.realpath(__file__))
99
+ folder_data_before = dir_path + "/../utils"
100
+
101
+ cmd_string = "python " + folder_data_before + "/chrf++.py -H " + pred_file + " -R " \
102
+ + ref_file + " > " + pred_file.replace("txt", "chrf")
103
+
104
+ os.system(cmd_string)
105
+
106
+ try:
107
+ chrf_info_1 = open(pred_file.replace("txt", "chrf"), 'r').readlines()[1].strip()
108
+ chrf_info_2 = open(pred_file.replace("txt", "chrf"), 'r').readlines()[2].strip()
109
+ chrf_data = chrf_info_1 + " " + chrf_info_2
110
+ except Exception:
111
+ chrf_data = "no data"
112
+
113
+
114
+ return chrf_data
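A brief usage sketch (illustrative, not part of the upload): convert_text lowercases and re-tokenises a generated sentence before it is written out for scoring, while eval_bleu shells out to multi-bleu.perl and expects <dataset>.target_eval reference files under the given data directory; the paths below are placeholders.

from utils.utils_graph2text import convert_text, eval_bleu

print(convert_text("The Eiffel Tower, located in Paris."))
# -> 'the eiffel tower , located in paris .'

# Requires multi-bleu.perl plus the val.target_eval / val.target2_eval / val.target3_eval references:
# bleu_info = eval_bleu("data/graph2text", "outputs/run1/val_outputs/validation_predictions_1.txt", "val")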
utils/utils_verbalisation_module.py ADDED
@@ -0,0 +1,610 @@
1
+ import itertools
2
+ import json
3
+ import linecache
4
+ import math
5
+ import os
6
+ import pickle
7
+ import socket
8
+ from logging import getLogger
9
+ from pathlib import Path
10
+ from typing import Callable, Dict, Iterable, List, Tuple, Union
11
+
12
+ import numpy as np
13
+ import torch
14
+ import torch.distributed as dist
15
+ from rouge_score import rouge_scorer, scoring
16
+ from sacrebleu import corpus_bleu
17
+ from torch import nn
18
+ from torch.utils.data import Dataset, Sampler
19
+
20
+ from transformers import BartTokenizer, EvalPrediction, PreTrainedTokenizer, T5Tokenizer
21
+ from transformers.file_utils import cached_property
22
+ from transformers.models.bart.modeling_bart import shift_tokens_right
23
+ from utils.utils_graph2text import convert_text, eval_bleu
24
+ from pytorch_lightning.utilities import rank_zero_info
25
+ import pdb
26
+
27
+
28
+ try:
29
+ from fairseq.data.data_utils import batch_by_size
30
+
31
+ FAIRSEQ_AVAILABLE = True
32
+ except (ImportError, ModuleNotFoundError):
33
+ FAIRSEQ_AVAILABLE = False
34
+
35
+
36
+ def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
37
+ """From fairseq"""
38
+ if target.dim() == lprobs.dim() - 1:
39
+ target = target.unsqueeze(-1)
40
+ nll_loss = -lprobs.gather(dim=-1, index=target)
41
+ smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
42
+ if ignore_index is not None:
43
+ pad_mask = target.eq(ignore_index)
44
+ nll_loss.masked_fill_(pad_mask, 0.0)
45
+ smooth_loss.masked_fill_(pad_mask, 0.0)
46
+ else:
47
+ nll_loss = nll_loss.squeeze(-1)
48
+ smooth_loss = smooth_loss.squeeze(-1)
49
+
50
+ nll_loss = nll_loss.sum() # mean()? Scared to break other math.
51
+ smooth_loss = smooth_loss.sum()
52
+ eps_i = epsilon / lprobs.size(-1)
53
+ loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
54
+ return loss, nll_loss
55
+
56
+
57
+ def lmap(f: Callable, x: Iterable) -> List:
58
+ """list(map(f, x))"""
59
+ return list(map(f, x))
60
+
61
+
62
+ def calculate_bleu(output_lns, refs_lns) -> dict:
63
+ """Uses sacrebleu's corpus_bleu implementation."""
64
+ return {"sacrebleu": round(corpus_bleu(output_lns, [refs_lns]).score, 4)}
65
+
66
+
67
+ def build_compute_metrics_fn(task_name: str, tokenizer: PreTrainedTokenizer) -> Callable[[EvalPrediction], Dict]:
68
+ def non_pad_len(tokens: np.ndarray) -> int:
69
+ return np.count_nonzero(tokens != tokenizer.pad_token_id)
70
+
71
+ def decode_pred(pred: EvalPrediction) -> Tuple[List[str], List[str]]:
72
+ pred_str = tokenizer.batch_decode(pred.predictions, skip_special_tokens=True)
73
+ label_str = tokenizer.batch_decode(pred.label_ids, skip_special_tokens=True)
74
+ pred_str = lmap(str.strip, pred_str)
75
+ label_str = lmap(str.strip, label_str)
76
+ return pred_str, label_str
77
+
78
+ def summarization_metrics(pred: EvalPrediction) -> Dict:
79
+ pred_str, label_str = decode_pred(pred)
80
+ rouge: Dict = calculate_rouge(pred_str, label_str)
81
+ summ_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)
82
+ rouge.update({"gen_len": summ_len})
83
+ return rouge
84
+
85
+ def translation_metrics(pred: EvalPrediction) -> Dict:
86
+ pred_str, label_str = decode_pred(pred)
87
+ bleu: Dict = calculate_bleu(pred_str, label_str)
88
+ gen_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)
89
+ bleu.update({"gen_len": gen_len})
90
+ return bleu
91
+
92
+ compute_metrics_fn = summarization_metrics if "summarization" in task_name else translation_metrics
93
+ return compute_metrics_fn
94
+
95
+
96
+ def trim_batch(
97
+ input_ids,
98
+ pad_token_id,
99
+ attention_mask=None,
100
+ ):
101
+ """Remove columns that are populated exclusively by pad_token_id"""
102
+ keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)
103
+ if attention_mask is None:
104
+ return input_ids[:, keep_column_mask]
105
+ else:
106
+ return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])
107
+
108
+
109
+ class AbstractSeq2SeqDataset(Dataset):
110
+ def __init__(
111
+ self,
112
+ tokenizer,
113
+ data_dir,
114
+ max_source_length,
115
+ max_target_length,
116
+ type_path="train",
117
+ n_obs=None,
118
+ prefix="",
119
+ **dataset_kwargs
120
+ ):
121
+ super().__init__()
122
+ self.src_file = Path(data_dir).joinpath(type_path + ".source")
123
+ self.tgt_file = Path(data_dir).joinpath(type_path + ".target")
124
+ self.len_file = Path(data_dir).joinpath(type_path + ".len")
125
+ if os.path.exists(self.len_file):
126
+ self.src_lens = pickle_load(self.len_file)
127
+ self.used_char_len = False
128
+ else:
129
+ self.src_lens = self.get_char_lens(self.src_file)
130
+ self.used_char_len = True
131
+ self.max_source_length = max_source_length
132
+ self.max_target_length = max_target_length
133
+ assert min(self.src_lens) > 0, f"found empty line in {self.src_file}"
134
+ self.tokenizer = tokenizer
135
+ self.prefix = prefix if prefix is not None else ""
136
+
137
+ if n_obs is not None:
138
+ self.src_lens = self.src_lens[:n_obs]
139
+ self.pad_token_id = self.tokenizer.pad_token_id
140
+ self.dataset_kwargs = dataset_kwargs
141
+ dataset_kwargs.update({"add_prefix_space": True} if isinstance(self.tokenizer, BartTokenizer) else {})
142
+
143
+ def __len__(self):
144
+ return len(self.src_lens)
145
+
146
+ @staticmethod
147
+ def get_char_lens(data_file):
148
+ return [len(x) for x in Path(data_file).open().readlines()]
149
+
150
+ @cached_property
151
+ def tgt_lens(self):
152
+ """Length in characters of target documents"""
153
+ return self.get_char_lens(self.tgt_file)
154
+
155
+ def make_sortish_sampler(self, batch_size, distributed=False, shuffle=True, **kwargs):
156
+ if distributed:
157
+ return DistributedSortishSampler(self, batch_size, shuffle=shuffle, **kwargs)
158
+ else:
159
+ return SortishSampler(self.src_lens, batch_size, shuffle=shuffle)
160
+
161
+ def make_dynamic_sampler(self, max_tokens_per_batch=1024, **kwargs):
162
+ assert FAIRSEQ_AVAILABLE, "Dynamic batch size requires `pip install fairseq`"
163
+ assert not self.used_char_len, "You must call python make_len_file.py before calling make_dynamic_sampler"
164
+ sorted_indices = list(self.make_sortish_sampler(1024, shuffle=False))
165
+
166
+ def num_tokens_in_example(i):
167
+ return min(self.src_lens[i], self.max_target_length)
168
+
169
+ # call fairseq cython function
170
+ batch_sampler: List[List[int]] = batch_by_size(
171
+ sorted_indices,
172
+ num_tokens_fn=num_tokens_in_example,
173
+ max_tokens=max_tokens_per_batch,
174
+ required_batch_size_multiple=64,
175
+ )
176
+ shuffled_batches = [batch_sampler[i] for i in np.random.permutation(range(len(batch_sampler)))]
177
+ # move the largest batch to the front to OOM quickly (uses an approximation for padding)
178
+ approximate_toks_per_batch = [max(self.src_lens[i] for i in batch) * len(batch) for batch in shuffled_batches]
179
+ largest_batch_idx = np.argmax(approximate_toks_per_batch)
180
+ shuffled_batches[0], shuffled_batches[largest_batch_idx] = (
181
+ shuffled_batches[largest_batch_idx],
182
+ shuffled_batches[0],
183
+ )
184
+ return shuffled_batches
185
+
186
+ def __getitem__(self, item):
187
+ raise NotImplementedError("You must implement this")
188
+
189
+ def collate_fn(self, batch):
190
+ raise NotImplementedError("You must implement this")
191
+
192
+
193
+ class LegacySeq2SeqDataset(AbstractSeq2SeqDataset):
194
+ def __getitem__(self, index) -> Dict[str, torch.Tensor]:
195
+ """Call tokenizer on src and tgt_lines"""
196
+
197
+
198
+ index = index + 1 # linecache starts at 1
199
+ source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
200
+ tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
201
+ assert source_line, f"empty source line for index {index}"
202
+ assert tgt_line, f"empty tgt line for index {index}"
203
+ source_inputs = self.encode_line(self.tokenizer, source_line, self.max_source_length)
204
+ target_inputs = self.encode_line(self.tokenizer, tgt_line, self.max_target_length)
205
+
206
+ source_ids = source_inputs["input_ids"].squeeze()
207
+ target_ids = target_inputs["input_ids"].squeeze()
208
+ src_mask = source_inputs["attention_mask"].squeeze()
209
+ return {
210
+ "input_ids": source_ids,
211
+ "attention_mask": src_mask,
212
+ "labels": target_ids,
213
+ }
214
+
215
+ def encode_line(self, tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"):
216
+ """Only used by LegacyDataset"""
217
+ return tokenizer(
218
+ [line],
219
+ max_length=max_length,
220
+ padding="max_length" if pad_to_max_length else None,
221
+ truncation=True,
222
+ return_tensors=return_tensors,
223
+ **self.dataset_kwargs,
224
+ )
225
+
226
+ def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
227
+ input_ids = torch.stack([x["input_ids"] for x in batch])
228
+ masks = torch.stack([x["attention_mask"] for x in batch])
229
+ target_ids = torch.stack([x["labels"] for x in batch])
230
+ pad_token_id = self.pad_token_id
231
+ y = trim_batch(target_ids, pad_token_id)
232
+ source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks)
233
+ batch = {
234
+ "input_ids": source_ids,
235
+ "attention_mask": source_mask,
236
+ "labels": y,
237
+ }
238
+ return batch
239
+
240
+
241
+ class Seq2SeqDataset(AbstractSeq2SeqDataset):
242
+ """A dataset that calls prepare_seq2seq_batch."""
243
+
244
+ def __getitem__(self, index) -> Dict[str, str]:
245
+
246
+ #print(self.dataset_kwargs['model_t'])
247
+ # if 't5' in self.dataset_kwargs['model_t']:
248
+ # self.prefix = 'translate Graph to English: '
249
+ # print('aac')
250
+ # exit()
251
+
252
+ index = index + 1 # linecache starts at 1
253
+ source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
254
+ tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
255
+ assert source_line, f"empty source line for index {index}"
256
+ assert tgt_line, f"empty tgt line for index {index}"
257
+ return {"tgt_texts": tgt_line, "src_texts": source_line, "id": index - 1}
258
+
259
+ def collate_fn(self, batch):
260
+ """Call prepare_seq2seq_batch."""
261
+ batch_encoding: Dict[str, torch.Tensor] = self.tokenizer.prepare_seq2seq_batch(
262
+ [x["src_texts"] for x in batch],
263
+ tgt_texts=[x["tgt_texts"] for x in batch],
264
+ max_length=self.max_source_length,
265
+ max_target_length=self.max_target_length,
266
+ return_tensors="pt",
267
+ **self.dataset_kwargs,
268
+ ).data
269
+ #lens = (batch_encoding['attention_mask'] == 1.).sum(dim=1).tolist()
270
+
271
+ batch_encoding["ids"] = torch.tensor([x["id"] for x in batch])
272
+
273
+ return batch_encoding
274
+
275
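Illustrative only (the data files themselves are not part of this commit): for the graph-to-text task, each line of train.source carries the linearised triples and the matching line of train.target carries the reference sentence, along the lines of:

src_line = "translate Graph to English: <H> World Trade Center <R> height <T> 200 meter"
tgt_line = "The World Trade Center is 200 meters tall."   # invented reference, for illustration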
+
276
+
277
+ class Seq2SeqDataCollator:
278
+ def __init__(self, tokenizer, data_args, tpu_num_cores=None):
279
+ self.tokenizer = tokenizer
280
+ self.pad_token_id = tokenizer.pad_token_id
281
+ assert (
282
+ self.pad_token_id is not None
283
+ ), f"pad_token_id is not defined for ({self.tokenizer.__class__.__name__}), it must be defined."
284
+ self.data_args = data_args
285
+ self.tpu_num_cores = tpu_num_cores
286
+ self.dataset_kwargs = {"add_prefix_space": isinstance(tokenizer, BartTokenizer)}
287
+ if data_args.src_lang is not None:
288
+ self.dataset_kwargs["src_lang"] = data_args.src_lang
289
+ if data_args.tgt_lang is not None:
290
+ self.dataset_kwargs["tgt_lang"] = data_args.tgt_lang
291
+
292
+ def __call__(self, batch) -> Dict[str, torch.Tensor]:
293
+ if hasattr(self.tokenizer, "prepare_seq2seq_batch"):
294
+ batch = self._encode(batch)
295
+ input_ids, attention_mask, labels = (
296
+ batch["input_ids"],
297
+ batch["attention_mask"],
298
+ batch["labels"],
299
+ )
300
+ else:
301
+ input_ids = torch.stack([x["input_ids"] for x in batch])
302
+ attention_mask = torch.stack([x["attention_mask"] for x in batch])
303
+ labels = torch.stack([x["labels"] for x in batch])
304
+
305
+ labels = trim_batch(labels, self.pad_token_id)
306
+ input_ids, attention_mask = trim_batch(input_ids, self.pad_token_id, attention_mask=attention_mask)
307
+
308
+ if isinstance(self.tokenizer, T5Tokenizer):
309
+ decoder_input_ids = self._shift_right_t5(labels)
310
+ else:
311
+ decoder_input_ids = shift_tokens_right(labels, self.pad_token_id)
312
+
313
+ batch = {
314
+ "input_ids": input_ids,
315
+ "attention_mask": attention_mask,
316
+ "decoder_input_ids": decoder_input_ids,
317
+ "labels": labels,
318
+ }
319
+ return batch
320
+
321
+ def _shift_right_t5(self, input_ids):
322
+ # shift inputs to the right
323
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
324
+ shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
325
+ shifted_input_ids[..., 0] = self.pad_token_id
326
+ return shifted_input_ids
327
+
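A small sketch of the right shift performed above, assuming pad_token_id is 0 as for T5:

labels            = [[8213, 12, 5, 1]]
# _shift_right_t5 prepends the pad id and drops the last position:
decoder_input_ids = [[0, 8213, 12, 5]]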
328
+ def _encode(self, batch) -> Dict[str, torch.Tensor]:
329
+ batch_encoding = self.tokenizer.prepare_seq2seq_batch(
330
+ [x["src_texts"] for x in batch],
331
+ tgt_texts=[x["tgt_texts"] for x in batch],
332
+ max_length=self.data_args.max_source_length,
333
+ max_target_length=self.data_args.max_target_length,
334
+ padding="max_length" if self.tpu_num_cores is not None else "longest", # TPU hack
335
+ return_tensors="pt",
336
+ **self.dataset_kwargs,
337
+ )
338
+ return batch_encoding.data
339
+
340
+
341
+ class SortishSampler(Sampler):
342
+ "Go through the text data by order of src length with a bit of randomness. From fastai repo."
343
+
344
+ def __init__(self, data, batch_size, shuffle=True):
345
+ self.data, self.bs, self.shuffle = data, batch_size, shuffle
346
+
347
+ def __len__(self) -> int:
348
+ return len(self.data)
349
+
350
+ def __iter__(self):
351
+ return iter(sortish_sampler_indices(self.data, self.bs, shuffle=self.shuffle))
352
+
353
+
354
+ def sortish_sampler_indices(data: List, bs: int, shuffle=True) -> np.ndarray:
355
+ "Go through the text data by order of src length with a bit of randomness. From fastai repo."
356
+ if not shuffle:
357
+ return np.argsort(np.array(data) * -1)
358
+
359
+ def key_fn(i):
360
+ return data[i]
361
+
362
+ idxs = np.random.permutation(len(data))
363
+ sz = bs * 50
364
+ ck_idx = [idxs[i : i + sz] for i in range(0, len(idxs), sz)]
365
+ sort_idx = np.concatenate([sorted(s, key=key_fn, reverse=True) for s in ck_idx])
366
+ sz = bs
367
+ ck_idx = [sort_idx[i : i + sz] for i in range(0, len(sort_idx), sz)]
368
+ max_ck = np.argmax([key_fn(ck[0]) for ck in ck_idx]) # find the chunk with the largest key,
369
+ ck_idx[0], ck_idx[max_ck] = ck_idx[max_ck], ck_idx[0] # then make sure it goes first.
370
+ sort_idx = np.concatenate(np.random.permutation(ck_idx[1:])) if len(ck_idx) > 1 else np.array([], dtype=int)
371
+ sort_idx = np.concatenate((ck_idx[0], sort_idx))
372
+ return sort_idx
373
+
374
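A hypothetical illustration of the sortish ordering (lengths below are made up):

lens = [3, 10, 1, 7, 2, 8, 5]                  # source lengths in characters
order = sortish_sampler_indices(lens, bs=2, shuffle=True)
# `order` is a permutation of range(7); indices within each chunk of 2 point at
# examples of similar length, and the chunk holding the longest examples comes first.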
+
375
+ class DistributedSortishSampler(Sampler):
376
+ """Copied from torch DistributedSampler"""
377
+
378
+ def __init__(self, dataset, batch_size, num_replicas=None, rank=None, add_extra_examples=True, shuffle=True):
379
+ if num_replicas is None:
380
+ if not dist.is_available():
381
+ raise RuntimeError("Requires distributed package to be available")
382
+ num_replicas = dist.get_world_size()
383
+ if rank is None:
384
+ if not dist.is_available():
385
+ raise RuntimeError("Requires distributed package to be available")
386
+ rank = dist.get_rank()
387
+ self.dataset = dataset
388
+ self.num_replicas = num_replicas
389
+ self.rank = rank
390
+ self.epoch = 0
391
+ if add_extra_examples:
392
+ self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
393
+ self.total_size = self.num_samples * self.num_replicas
394
+ else:
395
+ self.total_size = len(dataset)
396
+ self.num_samples = len(self.available_indices)
397
+ self.batch_size = batch_size
398
+ self.add_extra_examples = add_extra_examples
399
+ self.shuffle = shuffle
400
+
401
+ def __iter__(self) -> Iterable:
402
+ g = torch.Generator()
403
+ g.manual_seed(self.epoch)
404
+
405
+ sortish_data = [self.dataset.src_lens[i] for i in self.available_indices]
406
+ sortish_indices = sortish_sampler_indices(sortish_data, self.batch_size, shuffle=self.shuffle)
407
+ indices = [self.available_indices[i] for i in sortish_indices]
408
+ assert len(indices) == self.num_samples
409
+ return iter(indices)
410
+
411
+ @cached_property
412
+ def available_indices(self) -> np.ndarray:
413
+ indices = list(range(len(self.dataset)))
414
+ # add extra samples to make it evenly divisible
415
+ indices += indices[: (self.total_size - len(indices))]
416
+ assert len(indices) == self.total_size
417
+ # subsample
418
+ available_indices = indices[self.rank : self.total_size : self.num_replicas]
419
+ return available_indices
420
+
421
+ def __len__(self):
422
+ return self.num_samples
423
+
424
+ def set_epoch(self, epoch):
425
+ self.epoch = epoch
426
+
427
+
428
+ logger = getLogger(__name__)
429
+
430
+
431
+ def use_task_specific_params(model, task):
432
+ """Update config with summarization specific params."""
433
+ task_specific_params = model.config.task_specific_params
434
+
435
+ if task_specific_params is not None:
436
+ pars = task_specific_params.get(task, {})
437
+ logger.info(f"using task specific params for {task}: {pars}")
438
+ model.config.update(pars)
439
+
440
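A sketch of the effect, assuming the config carries task_specific_params similar to those shipped with bart-large-cnn (values shown only as an illustration):

# model.config.task_specific_params might look like:
#   {"summarization": {"num_beams": 4, "max_length": 142, "min_length": 56}}
use_task_specific_params(model, "summarization")
# copies those values onto the top-level model.config before generation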
+
441
+ def pickle_load(path):
442
+ """pickle.load(path)"""
443
+ with open(path, "rb") as f:
444
+ return pickle.load(f)
445
+
446
+
447
+ def pickle_save(obj, path):
448
+ """pickle.dump(obj, path)"""
449
+ with open(path, "wb") as f:
450
+ return pickle.dump(obj, f)
451
+
452
+
453
+ def flatten_list(summary_ids: List[List]):
454
+ return [x for x in itertools.chain.from_iterable(summary_ids)]
455
+
456
+
457
+ def save_json(content, path, indent=4, **json_dump_kwargs):
458
+ with open(path, "w") as f:
459
+ json.dump(content, f, indent=indent, **json_dump_kwargs)
460
+
461
+
462
+ def load_json(path):
463
+ with open(path) as f:
464
+ return json.load(f)
465
+
466
+
467
+ ROUGE_KEYS = ["rouge1", "rouge2", "rougeL", "rougeLsum"]
468
+
469
+
470
+ def extract_rouge_mid_statistics(dct):
471
+ new_dict = {}
472
+ for k1, v1 in dct.items():
473
+ mid = v1.mid
474
+ new_dict[k1] = {stat: round(getattr(mid, stat), 4) for stat in ["precision", "recall", "fmeasure"]}
475
+ return new_dict
476
+
477
+
478
+ def calculate_rouge(
479
+ pred_lns: List[str],
480
+ tgt_lns: List[str],
481
+ use_stemmer=True,
482
+ rouge_keys=ROUGE_KEYS,
483
+ return_precision_and_recall=False,
484
+ bootstrap_aggregation=True,
485
+ newline_sep=True,
486
+ ) -> Dict:
487
+ """Calculate rouge using rouge_scorer package.
488
+
489
+ Args:
490
+ pred_lns: list of summaries generated by model
491
+ tgt_lns: list of groundtruth summaries (e.g. contents of val.target)
492
+ use_stemmer: Bool indicating whether Porter stemmer should be used to
493
+ strip word suffixes to improve matching.
494
+ rouge_keys: which metrics to compute, defaults to rouge1, rouge2, rougeL, rougeLsum
495
+ return_precision_and_recall: (False) whether to also return precision and recall.
496
+ bootstrap_aggregation: whether to do the typical bootstrap resampling of scores. Defaults to True; if False,
497
+ this function returns a collections.defaultdict keyed by metric, mapping to the list of per-observation subscores.
498
+ newline_sep: (default=True) whether to add a newline between sentences. This is essential for calculating rougeL
499
+ on multi-sentence summaries (CNN/DM dataset).
500
+
501
+ Returns:
502
+ Dict[score: value] if aggregate else defaultdict(list) keyed by rouge_keys
503
+
504
+ """
505
+ scorer = rouge_scorer.RougeScorer(rouge_keys, use_stemmer=use_stemmer)
506
+ aggregator = scoring.BootstrapAggregator()
507
+ for pred, tgt in zip(tgt_lns, pred_lns):
508
+ # rougeLsum expects "\n" separated sentences within a summary
509
+ if newline_sep:
510
+ pred = add_newline_to_end_of_each_sentence(pred)
511
+ tgt = add_newline_to_end_of_each_sentence(tgt)
512
+ scores = scorer.score(pred, tgt)
513
+ aggregator.add_scores(scores)
514
+
515
+ if bootstrap_aggregation:
516
+ result = aggregator.aggregate()
517
+ if return_precision_and_recall:
518
+ return extract_rouge_mid_statistics(result) # here we return dict
519
+ else:
520
+ return {k: round(v.mid.fmeasure * 100, 4) for k, v in result.items()}
521
+
522
+ else:
523
+ return aggregator._scores # here we return defaultdict(list)
524
+
525
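A toy usage sketch (strings invented for illustration):

preds = ["the cat sat on the mat", "hello world"]
refs = ["a cat sat on the mat", "hello there world"]
print(calculate_rouge(preds, refs))
# -> {"rouge1": ..., "rouge2": ..., "rougeL": ..., "rougeLsum": ...}  (F1 * 100, rounded)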
+
526
+ # Utilities for freezing parameters and checking whether they are frozen
527
+
528
+
529
+ def freeze_params(model: nn.Module):
530
+ """Set requires_grad=False for each of model.parameters()"""
531
+ for par in model.parameters():
532
+ par.requires_grad = False
533
+
534
+
535
+ def freeze_embeds(model):
536
+ """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
537
+ model_type = model.config.model_type
538
+
539
+ if model_type == "t5":
540
+ freeze_params(model.shared)
541
+ for d in [model.encoder, model.decoder]:
542
+ freeze_params(d.embed_tokens)
543
+ elif model_type == "fsmt":
544
+ for d in [model.model.encoder, model.model.decoder]:
545
+ freeze_params(d.embed_positions)
546
+ freeze_params(d.embed_tokens)
547
+ else:
548
+ freeze_params(model.model.shared)
549
+ for d in [model.model.encoder, model.model.decoder]:
550
+ freeze_params(d.embed_positions)
551
+ freeze_params(d.embed_tokens)
552
+
553
+
554
+ def grad_status(model: nn.Module) -> Iterable:
555
+ return (par.requires_grad for par in model.parameters())
556
+
557
+
558
+ def any_requires_grad(model: nn.Module) -> bool:
559
+ return any(grad_status(model))
560
+
561
+
562
+ def assert_all_frozen(model):
563
+ model_grads: List[bool] = list(grad_status(model))
564
+ n_require_grad = sum(lmap(int, model_grads))
565
+ npars = len(model_grads)
566
+ assert not any(model_grads), f"{n_require_grad/npars:.1%} of {npars} weights require grad"
567
+
568
+
569
+ def assert_not_all_frozen(model):
570
+ model_grads: List[bool] = list(grad_status(model))
571
+ npars = len(model_grads)
572
+ assert any(model_grads), f"none of {npars} weights require grad"
573
+
574
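A quick sketch of the freezing helpers on a throwaway module (not a seq2seq model):

import torch.nn as nn
tiny = nn.Linear(4, 2)
assert_not_all_frozen(tiny)   # passes: parameters still require grad
freeze_params(tiny)
assert_all_frozen(tiny)       # passes: every parameter now has requires_grad=False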
+
575
+ def parse_numeric_n_bool_cl_kwargs(unparsed_args: List[str]) -> Dict[str, Union[int, float, bool]]:
576
+ """
577
+ Parse an argv list of unspecified command line args to a dict.
578
+ Assumes all values are either numeric or boolean in the form of true/false.
579
+ """
580
+ result = {}
581
+ assert len(unparsed_args) % 2 == 0, f"got odd number of unparsed args: {unparsed_args}"
582
+ num_pairs = len(unparsed_args) // 2
583
+ for pair_num in range(num_pairs):
584
+ i = 2 * pair_num
585
+ assert unparsed_args[i].startswith("--")
586
+ if unparsed_args[i + 1].lower() == "true":
587
+ value = True
588
+ elif unparsed_args[i + 1].lower() == "false":
589
+ value = False
590
+ else:
591
+ try:
592
+ value = int(unparsed_args[i + 1])
593
+ except ValueError:
594
+ value = float(unparsed_args[i + 1]) # this can raise another informative ValueError
595
+
596
+ result[unparsed_args[i][2:]] = value
597
+ return result
598
+
599
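Example of the parsing behaviour (flag names invented for illustration):

extra = parse_numeric_n_bool_cl_kwargs(["--num_beams", "4", "--early_stopping", "true", "--length_penalty", "0.8"])
# -> {"num_beams": 4, "early_stopping": True, "length_penalty": 0.8}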
+
600
+ def write_txt_file(ordered_tgt, path):
601
+ f = Path(path).open("w")
602
+ for ln in ordered_tgt:
603
+ f.write(ln + "\n")
604
+ f.flush()
605
+
606
+
607
+ def chunks(lst, n):
608
+ """Yield successive n-sized chunks from lst."""
609
+ for i in range(0, len(lst), n):
610
+ yield lst[i : i + n]
utils/verbalisation_module.py ADDED
@@ -0,0 +1,300 @@
1
+ from utils.finetune import Graph2TextModule
2
+ from typing import Dict, List, Tuple, Union, Optional
3
+ import torch
4
+ import re
5
+
6
+ if torch.cuda.is_available():
7
+ DEVICE = 'cuda'
8
+ else:
9
+ DEVICE = 'cpu'
10
+ print('CUDA NOT AVAILABLE')
11
+
12
+ CHECKPOINT = 'base/t5-base_13881_val_avg_bleu=68.1000-step_count=5.ckpt'
13
+ MAX_LENGTH = 384
14
+ SEED = 42
15
+
16
+
17
+ class VerbModule():
18
+
19
+ def __init__(self, override_args: Dict[str, str] = None):
20
+ # Model
21
+ if not override_args:
22
+ override_args = {}
23
+ self.g2t_module = Graph2TextModule.load_from_checkpoint(CHECKPOINT, strict=False, **override_args)
24
+ self.tokenizer = self.g2t_module.tokenizer
25
+ # Unk replacer
26
+ self.vocab = self.tokenizer.get_vocab()
27
+ self.convert_some_japanese_characters = True
28
+ self.unk_char_replace_sliding_window_size = 2
29
+ self.unknowns = []
30
+
31
+ def __generate_verbalisations_from_inputs(self, inputs: Union[str, List[str]]):
32
+ try:
33
+ inputs_encoding = self.tokenizer.prepare_seq2seq_batch(
34
+ inputs, truncation=True, max_length=MAX_LENGTH, return_tensors='pt'
35
+ )
36
+ inputs_encoding = {k: v.to(DEVICE) for k, v in inputs_encoding.items()}
37
+
38
+ self.g2t_module.model.eval()
39
+ with torch.no_grad():
40
+ gen_output = self.g2t_module.model.generate(
41
+ inputs_encoding['input_ids'],
42
+ attention_mask=inputs_encoding['attention_mask'],
43
+ use_cache=True,
44
+ decoder_start_token_id = self.g2t_module.decoder_start_token_id,
45
+ num_beams= self.g2t_module.eval_beams,
46
+ max_length= self.g2t_module.eval_max_length,
47
+ length_penalty=1.0
48
+ )
49
+ except Exception:
50
+ print(inputs)
51
+ raise
52
+
53
+ return gen_output
54
+
55
+ '''
56
+ We create this function as an alteration of [this one](https://github.com/huggingface/transformers/blob/198c335d219a5eb4d3f124fdd1ce1a9cd9f78a9b/src/transformers/tokenization_utils_fast.py#L537), mainly because the official 'tokenizer.decode' treats all special tokens the same, whereas we want to drop all special tokens from the decoded sentence EXCEPT for the <unk> token, which we replace later on.
57
+ '''
58
+ def __decode_ids_to_string_custom(
59
+ self, token_ids: List[int], skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = True
60
+ ) -> str:
61
+ filtered_tokens = self.tokenizer.convert_ids_to_tokens(token_ids, skip_special_tokens=False)
62
+ # Do not remove special tokens yet
63
+
64
+ # To avoid mixing byte-level and unicode for byte-level BPE
65
+ # we need to build the string separately for added tokens and byte-level tokens
66
+ # cf. https://github.com/huggingface/transformers/issues/1133
67
+ sub_texts = []
68
+ current_sub_text = []
69
+ for token in filtered_tokens:
70
+ if skip_special_tokens and\
71
+ token != self.tokenizer.unk_token and\
72
+ token in self.tokenizer.all_special_tokens:
73
+
74
+ continue
75
+ else:
76
+ current_sub_text.append(token)
77
+ if current_sub_text:
78
+ sub_texts.append(self.tokenizer.convert_tokens_to_string(current_sub_text))
79
+ text = " ".join(sub_texts)
80
+
81
+ if clean_up_tokenization_spaces:
82
+ clean_text = self.tokenizer.clean_up_tokenization(text)
83
+ return clean_text
84
+ else:
85
+ return text
86
+
87
+ def __decode_sentences(self, encoded_sentences: Union[str, List[str]]):
88
+ if type(encoded_sentences) == str:
89
+ encoded_sentences = [encoded_sentences]
90
+ decoded_sentences = [self.__decode_ids_to_string_custom(i, skip_special_tokens=True) for i in encoded_sentences]
91
+ return decoded_sentences
92
+
93
+ def verbalise_sentence(self, inputs: Union[str, List[str]]):
94
+ if type(inputs) == str:
95
+ inputs = [inputs]
96
+
97
+ gen_output = self.__generate_verbalisations_from_inputs(inputs)
98
+
99
+ decoded_sentences = self.__decode_sentences(gen_output)
100
+
101
+ if len(decoded_sentences) == 1:
102
+ return decoded_sentences[0]
103
+ else:
104
+ return decoded_sentences
105
+
106
+ def verbalise_triples(self, input_triples: Union[Dict[str, str], List[Dict[str, str]], List[List[Dict[str, str]]]]):
107
+ if type(input_triples) == dict:
108
+ input_triples = [input_triples]
109
+
110
+ verbalisation_inputs = []
111
+ for triple in input_triples:
112
+ if type(triple) == dict:
113
+ assert 'subject' in triple
114
+ assert 'predicate' in triple
115
+ assert 'object' in triple
116
+ verbalisation_inputs.append(
117
+ f'translate Graph to English: <H> {triple["subject"]} <R> {triple["predicate"]} <T> {triple["object"]}'
118
+ )
119
+ elif type(triple) == list:
120
+ input_sentence = ['translate Graph to English:']
121
+ for subtriple in triple:
122
+ assert 'subject' in subtriple
123
+ assert 'predicate' in subtriple
124
+ assert 'object' in subtriple
125
+ input_sentence.append(f'<H> {subtriple["subject"]}')
126
+ input_sentence.append(f'<R> {subtriple["predicate"]}')
127
+ input_sentence.append(f'<T> {subtriple["object"]}')
128
+ verbalisation_inputs.append(
129
+ ' '.join(input_sentence)
130
+ )
131
+
132
+ return self.verbalise_sentence(verbalisation_inputs)
133
+
134
+ def verbalise(self, input: Union[str, List, Dict]):
135
+ try:
136
+ if (type(input) == str) or (type(input) == list and type(input[0]) == str):
137
+ return self.verbalise_sentence(input)
138
+ elif (type(input) == dict) or (type(input) == list and type(input[0]) == dict):
139
+ return self.verbalise_triples(input)
140
+ else:
141
+ return self.verbalise_triples(input)
142
+ except Exception:
143
+ print(f'ERROR VERBALISING {input}')
144
+ raise
145
+
146
+ def add_label_to_unk_replacer(self, label: str):
147
+ N = self.unk_char_replace_sliding_window_size
148
+ self.unknowns.append({})
149
+
150
+ # Some pre-processing of labels to normalise some characters
151
+ if self.convert_some_japanese_characters:
152
+ label = label.replace('(','(')
153
+ label = label.replace(')',')')
154
+ label = label.replace('〈','<')
155
+ label = label.replace('/','/')
156
+ label = label.replace('〉','>')
157
+
158
+ label_encoded = self.tokenizer.encode(label)
159
+ label_tokens = self.tokenizer.convert_ids_to_tokens(label_encoded)
160
+
161
+ # Here, we also remove </s> (eos) and <pad> tokens in the replacing key, because:
162
+ # 1) When the whole label is all unk:
163
+ # label_token_to_string would be '<unk></s>', meaning the replacing key (which is the same) only replaces
164
+ # the <unk> if it appears at the end of the sentence, which is not the desired effect.
165
+ # But since this means ANY <unk> will be replaced by this, it would be good to only replace keys that are <unk>
166
+ # on the last replacing pass.
167
+ # 2) In other cases, when the unk is in the label but not in its entirety, like at the start/end, it might
168
+ # involve the starting <pad> token or the ending <eos> token in the replacing key, again forcing the replacement
169
+ # to only happen if the label appears at the end of the sentence.
170
+ label_tokens = [t for t in label_tokens if t not in [
171
+ self.tokenizer.eos_token, self.tokenizer.pad_token
172
+ ]]
173
+
174
+ label_token_to_string = self.tokenizer.convert_tokens_to_string(label_tokens)
175
+ unk_token_to_string = self.tokenizer.convert_tokens_to_string([self.tokenizer.unk_token])
176
+
177
+ #print(label_encoded,label_tokens,label_token_to_string)
178
+
179
+ match_unks_in_label = re.findall('(?:(?: )*<unk>(?: )*)+', label_token_to_string)
180
+ if len(match_unks_in_label) > 0:
181
+ # If the whole label is made of UNK
182
+ if (match_unks_in_label[0]) == label_token_to_string:
183
+ #print('Label is all unks')
184
+ self.unknowns[-1][label_token_to_string.strip()] = label
185
+ # Else, there should be non-UNK characters in the label
186
+ else:
187
+ #print('Label is NOT all unks')
188
+ # Analyse the label with a sliding window of size N (N before, N ahead)
189
+ for idx, token in enumerate(label_tokens):
190
+ idx_before = max(0,idx-N)
191
+ idx_ahead = min(len(label_tokens), idx+N+1)
192
+
193
+
194
+ # Found a UNK
195
+ if token == self.tokenizer.unk_token:
196
+
197
+ # In case multiple UNK, exclude UNKs seen after this one, expand window to other side if possible
198
+ if len(match_unks_in_label) > 1:
199
+ #print(idx)
200
+ #print(label_tokens)
201
+ #print(label_tokens[idx_before:idx_ahead])
202
+ #print('HERE!')
203
+ # Reduce on the right, expanding on the left
204
+ while self.tokenizer.unk_token in label_tokens[idx+1:idx_ahead]:
205
+ idx_before = max(0,idx_before-1)
206
+ idx_ahead = min(idx+2, idx_ahead-1)
207
+ #print(label_tokens[idx_before:idx_ahead])
208
+ # Now just reduce on the left
209
+ while self.tokenizer.unk_token in label_tokens[idx_before:idx]:
210
+ idx_before = min(idx-1,idx_before+2)
211
+ #print(label_tokens[idx_before:idx_ahead])
212
+
213
+ span = self.tokenizer.convert_tokens_to_string(label_tokens[idx_before:idx_ahead])
214
+ # First token of the label is UNK
215
+ if idx == 1 and label_tokens[0] == '▁':
216
+ #print('Label begins with unks')
217
+ to_replace = '^' + re.escape(span).replace(
218
+ re.escape(unk_token_to_string),
219
+ '.+?'
220
+ )
221
+
222
+ replaced_span = re.search(
223
+ to_replace,
224
+ label
225
+ )[0]
226
+ self.unknowns[-1][span.strip()] = replaced_span
227
+ # Last token of the label is UNK
228
+ elif idx == len(label_tokens)-2 and label_tokens[-1] == self.tokenizer.eos_token:
229
+ #print('Label ends with unks')
230
+ pre_idx = self.tokenizer.convert_tokens_to_string(label_tokens[idx_before:idx])
231
+ pre_idx_unk_counts = pre_idx.count(unk_token_to_string)
232
+ to_replace = re.escape(span).replace(
233
+ re.escape(unk_token_to_string),
234
+ f'[^{re.escape(pre_idx)}]+?'
235
+ ) + '$'
236
+
237
+ if pre_idx.strip() == '':
238
+ to_replace = to_replace.replace('[^]', r'(?<=\s)[^a-zA-Z0-9]')
239
+
240
+ replaced_span = re.search(
241
+ to_replace,
242
+ label
243
+ )[0]
244
+ self.unknowns[-1][span.strip()] = replaced_span
245
+
246
+ # A token in-between the label is UNK
247
+ else:
248
+ #print('Label has unks in the middle')
249
+ pre_idx = self.tokenizer.convert_tokens_to_string(label_tokens[idx_before:idx])
250
+
251
+ to_replace = re.escape(span).replace(
252
+ re.escape(unk_token_to_string),
253
+ f'[^{re.escape(pre_idx)}]+?'
254
+ )
255
+ #If there is nothing behind the ??, because it is in the middle but the previous token is also
256
+ #a ??, then we would end up with to_replace beginning with [^], which we can't have
257
+ if pre_idx.strip() == '':
258
+ to_replace = to_replace.replace('[^]', r'(?<=\s)[^a-zA-Z0-9]')
259
+
260
+ replaced_span = re.search(
261
+ to_replace,
262
+ label
263
+ )
264
+
265
+ if replaced_span:
266
+ span = re.sub(r'\s([?.!",](?:\s|$))', r'\1', span.strip())
267
+ self.unknowns[-1][span] = replaced_span[0]
268
+
269
+ def replace_unks_on_sentence(self, sentence: str, loop_n : int = 3, empty_after : bool = False):
270
+ # Loop through in case the labels are repeated, maximum of three times
271
+ while '<unk>' in sentence and loop_n > 0:
272
+ loop_n -= 1
273
+ for unknowns in self.unknowns:
274
+ for k,v in unknowns.items():
275
+ # Leave to replace all-unk labels at the last pass
276
+ if k == '<unk>' and loop_n > 0:
277
+ continue
278
+ # In case it is because the first letter of the sentence has been uppercased
279
+ if k not in sentence and k[0] == k[0].lower() and k[0].upper() == sentence[0]:
280
+ k = k[0].upper() + k[1:]
281
+ v = v[0].upper() + v[1:]
282
+ # In case it is because a double space is found where it should not be
283
+ elif k not in sentence and len(re.findall(r'\s{2,}',k))>0:
284
+ k = re.sub(r'\s+', ' ', k)
285
+ #print(k,'/',v,'/',sentence)
286
+ sentence = sentence.replace(k.strip(),v.strip(),1)
287
+ #sentence = re.sub(k, v, sentence)
288
+ # Removing final doublespaces
289
+ sentence = re.sub(r'\s+', ' ', sentence).strip()
290
+ # Removing spaces before punctuation
291
+ sentence = re.sub(r'\s([?.!",](?:\s|$))', r'\1', sentence)
292
+ if empty_after:
293
+ self.unknowns = []
294
+ return sentence
295
+
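A hedged sketch of how the two methods above combine; the label and draft sentence are invented, and whether a given character actually tokenises to <unk> depends on the T5 vocabulary:

vm = VerbModule()
vm.add_label_to_unk_replacer("Tōkyō 〈東京〉")        # label whose rare characters may become <unk>
draft = "<unk> is the capital of Japan."            # model output containing <unk>
print(vm.replace_unks_on_sentence(draft, empty_after=True))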
296
+ if __name__ == '__main__':
297
+
298
+ verb_module = VerbModule()
299
+ verbs = verb_module.verbalise('translate Graph to English: <H> World Trade Center <R> height <T> 200 meter <H> World Trade Center <R> is a <T> tower')
300
+ print(verbs)
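The same request expressed with structured triples instead of a pre-linearised string; the outer list groups the two triples into a single verbalisation (labels are illustrative):

verbs = verb_module.verbalise([[
    {"subject": "World Trade Center", "predicate": "height", "object": "200 meter"},
    {"subject": "World Trade Center", "predicate": "is a", "object": "tower"},
]])
print(verbs)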
utils/wikidata_utils.py ADDED
@@ -0,0 +1,173 @@
1
+ import json
2
+ import random
3
+ import uuid
4
+ import numpy as np
5
+ import time
6
+ import requests
7
+ import traceback
8
+ import pdb
9
+ import math
10
+ import ast
11
+ import pandas as pd
12
+ import pickle
13
+ from qwikidata.linked_data_interface import get_entity_dict_from_api
14
+ from qwikidata.sparql import return_sparql_query_results
15
+
16
+ from urllib3.exceptions import MaxRetryError, ConnectionError
17
+ from qwikidata.linked_data_interface import LdiResponseNotOk
18
+
19
+ import hashlib
20
+
21
+ class CachedWikidataAPI():
22
+
23
+ def __init__(self, cache_path = 'entity_cache.p', save_every_x_queries=1):
24
+ self.save_every_x_queries = save_every_x_queries
25
+ self.x_queries_passed = 0
26
+ self.languages = ['en','fr','es','pt','pt-br','it','de']
27
+ self.cache_path = cache_path
28
+ try:
29
+ with open(self.cache_path,'rb') as f:
30
+ self.entity_cache = pickle.load(f)
31
+ except FileNotFoundError:
32
+ self.entity_cache = {}
33
+
34
+ def get_unique_id_from_str(self, my_str):
35
+ return hashlib.md5(str.encode(my_str)).hexdigest()
36
+
37
+ def save_entity_cache(self, force=False):
38
+ if force:
39
+ self.x_queries_passed = self.save_every_x_queries
40
+ self.x_queries_passed = self.x_queries_passed+1
41
+ if self.x_queries_passed >= self.save_every_x_queries:
42
+ with open(self.cache_path,'wb') as f:
43
+ pickle.dump(self.entity_cache,f)
44
+ self.x_queries_passed = 0
45
+
46
+ def get_entity(self, item_id):
47
+ if item_id in self.entity_cache:
48
+ return self.entity_cache[item_id]
49
+ while True:
50
+ try:
51
+ entity = get_entity_dict_from_api(item_id)
52
+ self.entity_cache[item_id] = entity
53
+ self.save_entity_cache()
54
+ return entity
55
+ except (ConnectionError, MaxRetryError) as e:
56
+ #traceback.print_exc()
57
+ time.sleep(1)
58
+ continue
59
+ except LdiResponseNotOk:
60
+ #traceback.print_exc()
61
+ self.entity_cache[item_id] = 'deleted'
62
+ self.save_entity_cache()
63
+ return 'deleted'
64
+
65
+ def get_label(self, item, non_language_set=False):
66
+ if type(item) == str:
67
+ entity = self.get_entity(item)
68
+ if entity == 'deleted':
69
+ return (entity, 'none')
70
+ labels = entity['labels' if 'labels' in entity else 'lemmas']
71
+ elif type(item) == dict:
72
+ if 'labels' in item:
73
+ labels = item['labels']
74
+ elif 'lemmas' in item:
75
+ labels = item['lemmas']
76
+ for l in self.languages:
77
+ if l in labels:
78
+ return (labels[l]['value'], l)
79
+ if non_language_set:
80
+ all_labels = list(labels.keys())
81
+ if len(all_labels)>0:
82
+ return (labels[all_labels[0]]['value'], all_labels[0])
83
+ return ('no-label', 'none')
84
+
85
+ def get_desc(self, item, non_language_set=False):
86
+ if type(item) == str:
87
+ entity = self.get_entity(item)
88
+ if entity == 'deleted':
89
+ return (entity, 'none')
90
+ descriptions = entity['descriptions']
91
+ elif type(item) == dict:
92
+ if 'descriptions' in item:
93
+ descriptions = item['descriptions']
94
+ for l in self.languages:
95
+ if l in descriptions:
96
+ return (descriptions[l]['value'], l)
97
+ if non_language_set:
98
+ all_descriptions = list(descriptions.keys())
99
+ if len(all_descriptions)>0:
100
+ return (descriptions[all_descriptions[0]]['value'], all_descriptions[0])
101
+ return ('no-desc', 'none')
102
+
103
+ def get_alias(self, item, non_language_set=False):
104
+ if type(item) == str:
105
+ entity = self.get_entity(item)
106
+ if entity == 'deleted':
107
+ return ([entity], 'none')
108
+ aliases = entity['aliases']
109
+ elif type(item) == dict:
110
+ if 'aliases' in item:
111
+ aliases = item['aliases']
112
+ for l in self.languages:
113
+ if l in aliases:
114
+ return ([alias['value'] for alias in aliases[l]], l)
115
+ if non_language_set:
116
+ all_aliases = list(aliases.keys())
117
+ if len(all_aliases)>0:
118
+ return ([alias['value'] for alias in aliases[all_aliases[0]]], all_aliases[0])
119
120
+ return ('no-alias', 'none')
121
+
122
+ def get_datatype(self, item):
123
+ try:
124
+ if type(item) == str:
125
+ entity = self.get_entity(item)
126
+ if entity == 'deleted':
127
+ return entity
128
+ datatype = entity['datatype']
129
+ elif type(item) == dict:
130
+ datatype = item['datatype']
131
+ return datatype
132
+ except KeyError:
133
+ return 'none'
134
+
135
+ def get_claim_values_of(self, item, property_id):
136
+ if type(item) == str:
137
+ entity = self.get_entity(item)
138
+ if entity == 'deleted':
139
+ return entity
140
+ claims = entity['claims']
141
+ elif type(item) == dict:
142
+ claims = item['claims']
143
+ if property_id in claims:
144
+ instance_of_claims = claims[property_id]
145
+ return [i['mainsnak']['datavalue']['value']['id'] for i in instance_of_claims]
146
+ else:
147
+ return []
148
+
149
+ def query_sparql_endpoint(self, sparql_query):
150
+ sparql_query_id = self.get_unique_id_from_str(sparql_query)
151
+ if sparql_query_id in self.entity_cache:
152
+ return self.entity_cache[sparql_query_id]
153
+ else:
154
+ wikidata_sparql_url = 'https://query.wikidata.org/sparql'
155
+ try:
156
+ while True:
157
+ res = requests.get(wikidata_sparql_url, params={"query": sparql_query, "format": "json"})
158
+ if res.status_code in (429,504):
159
+ time.sleep(1)
160
+ continue
161
+ elif res.status_code == 200:
162
+ res = res.json()
163
+ self.entity_cache[sparql_query_id] = res
164
+ self.save_entity_cache()
165
+ return res
166
+ else:
167
+ print(res.status_code)
168
+ raise Exception
169
+ except json.JSONDecodeError as e:
170
+ #pdb.set_trace()
171
+ print(res, res.__dict__)
172
+ raise e
173
+
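A short usage sketch (Q42 and P31 are ordinary Wikidata identifiers, used here purely for illustration; results are cached to entity_cache.p between calls):

api = CachedWikidataAPI()
label, lang = api.get_label('Q42')                    # e.g. ('Douglas Adams', 'en')
desc, lang = api.get_desc('Q42')
instance_of = api.get_claim_values_of('Q42', 'P31')   # Q-ids of the classes this entity is an instance of
res = api.query_sparql_endpoint('SELECT ?item WHERE { ?item wdt:P31 wd:Q5 } LIMIT 3')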