johnpaulbin committed
Commit 22d4f29
1 Parent(s): 17528b8
large.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7e649a63d0f6cdce341cfde2ac8210cd7a1796420aac21585b1380e283cac01c
+ size 1265336
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8cbf6f7067d56aa1c2d571bb169f05fba16cea4c263c06fb3f217f42c591a978
+ size 89616062
torchmoji/.gitkeep ADDED
@@ -0,0 +1 @@
+
torchmoji/__init__.py ADDED
File without changes
torchmoji/attlayer.py ADDED
@@ -0,0 +1,71 @@
+ # -*- coding: utf-8 -*-
+ """ Define the Attention Layer of the model.
+ """
+
+ from __future__ import print_function, division
+
+ import torch
+
+ from torch.autograd import Variable
+ from torch.nn import Module
+ from torch.nn.parameter import Parameter
+
+ class Attention(Module):
+     """
+     Computes a weighted average of the different channels across timesteps.
+     Uses 1 parameter per channel to compute the attention value for a single timestep.
+     """
+
+     def __init__(self, attention_size, return_attention=False):
+         """ Initialize the attention layer
+
+         # Arguments:
+             attention_size: Size of the attention vector.
+             return_attention: If true, output will include the weight for each input token
+                 used for the prediction
+
+         """
+         super(Attention, self).__init__()
+         self.return_attention = return_attention
+         self.attention_size = attention_size
+         self.attention_vector = Parameter(torch.FloatTensor(attention_size))
+         self.attention_vector.data.normal_(std=0.05) # Initialize attention vector
+
+     def __repr__(self):
+         s = '{name}({attention_size}, return attention={return_attention})'
+         return s.format(name=self.__class__.__name__, **self.__dict__)
+
+     def forward(self, inputs, input_lengths):
+         """ Forward pass.
+
+         # Arguments:
+             inputs (Torch.Variable): Tensor of input sequences
+             input_lengths (torch.LongTensor): Lengths of the sequences
+
+         # Return:
+             Tuple with (representations and attentions if self.return_attention else None).
+         """
+         logits = inputs.matmul(self.attention_vector)
+         unnorm_ai = (logits - logits.max()).exp()
+
+         # Compute a mask for the attention on the padded sequences
+         # See e.g. https://discuss.pytorch.org/t/self-attention-on-words-and-masking/5671/5
+         max_len = unnorm_ai.size(1)
+         idxes = torch.arange(0, max_len, out=torch.LongTensor(max_len)).unsqueeze(0)
+         mask = Variable((idxes < input_lengths.unsqueeze(1)).float())
+
+         # apply mask and renormalize attention scores (weights)
+         if self.attention_vector.device.type == "cuda":
+             masked_weights = unnorm_ai * mask.cuda()
+         else:
+             masked_weights = unnorm_ai * mask
+         att_sums = masked_weights.sum(dim=1, keepdim=True) # sums per sequence
+         attentions = masked_weights.div(att_sums)
+
+         # apply attention weights
+         weighted = torch.mul(inputs, attentions.unsqueeze(-1).expand_as(inputs))
+
+         # get the final fixed vector representations of the sentences
+         representations = weighted.sum(dim=1)
+
+         return (representations, attentions if self.return_attention else None)
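
A minimal usage sketch of the Attention layer above, not part of the committed files; the input shapes and random values are illustrative only, and exact behavior depends on the PyTorch version this commit targets:

import torch
from torchmoji.attlayer import Attention

layer = Attention(attention_size=64, return_attention=True)
inputs = torch.randn(2, 5, 64)          # (batch, timesteps, channels)
lengths = torch.LongTensor([5, 3])      # number of valid timesteps per sequence
representations, attentions = layer(inputs, lengths)
print(representations.shape)            # torch.Size([2, 64]) fixed-size sentence vectors
print(attentions.shape)                 # torch.Size([2, 5]); padded steps get zero weight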
torchmoji/class_avg_finetuning.py ADDED
@@ -0,0 +1,315 @@
+ # -*- coding: utf-8 -*-
+ """ Class average finetuning functions. Before using any of these finetuning
+ functions, ensure that the model is set up with nb_classes=2.
+ """
+ from __future__ import print_function
+
+ import uuid
+ from time import sleep
+ import numpy as np
+
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+
+ from torchmoji.global_variables import (
+     FINETUNING_METHODS,
+     WEIGHTS_DIR)
+ from torchmoji.finetuning import (
+     freeze_layers,
+     get_data_loader,
+     fit_model,
+     train_by_chain_thaw,
+     find_f1_threshold)
+
+ def relabel(y, current_label_nr, nb_classes):
+     """ Makes a binary classification for a specific class in a
+     multi-class dataset.
+
+     # Arguments:
+         y: Outputs to be relabelled.
+         current_label_nr: Current label number.
+         nb_classes: Total number of classes.
+
+     # Returns:
+         Relabelled outputs of a given multi-class dataset into a binary
+         classification dataset.
+     """
+
+     # Handling binary classification
+     if nb_classes == 2 and len(y.shape) == 1:
+         return y
+
+     y_new = np.zeros(len(y))
+     y_cut = y[:, current_label_nr]
+     label_pos = np.where(y_cut == 1)[0]
+     y_new[label_pos] = 1
+     return y_new
+
+
+ def class_avg_finetune(model, texts, labels, nb_classes, batch_size,
+                        method, epoch_size=5000, nb_epochs=1000, embed_l2=1E-6,
+                        verbose=True):
+     """ Compiles and finetunes the given model.
+
+     # Arguments:
+         model: Model to be finetuned
+         texts: List of three lists, containing tokenized inputs for training,
+             validation and testing (in that order).
+         labels: List of three lists, containing labels for training,
+             validation and testing (in that order).
+         nb_classes: Number of classes in the dataset.
+         batch_size: Batch size.
+         method: Finetuning method to be used. For available methods, see
+             FINETUNING_METHODS in global_variables.py. Note that the model
+             should be defined accordingly (see docstring for torchmoji_transfer())
+         epoch_size: Number of samples in an epoch.
+         nb_epochs: Number of epochs. Doesn't matter much as early stopping is used.
+         embed_l2: L2 regularization for the embedding layer.
+         verbose: Verbosity flag.
+
+     # Returns:
+         Model after finetuning,
+         score after finetuning using the class average F1 metric.
+     """
+
+     if method not in FINETUNING_METHODS:
+         raise ValueError('ERROR (class_avg_tune_trainable): '
+                          'Invalid method parameter. '
+                          'Available options: {}'.format(FINETUNING_METHODS))
+
+     (X_train, y_train) = (texts[0], labels[0])
+     (X_val, y_val) = (texts[1], labels[1])
+     (X_test, y_test) = (texts[2], labels[2])
+
+     checkpoint_path = '{}/torchmoji-checkpoint-{}.bin' \
+                       .format(WEIGHTS_DIR, str(uuid.uuid4()))
+
+     f1_init_path = '{}/torchmoji-f1-init-{}.bin' \
+                    .format(WEIGHTS_DIR, str(uuid.uuid4()))
+
+     if method in ['last', 'new']:
+         lr = 0.001
+     elif method in ['full', 'chain-thaw']:
+         lr = 0.0001
+
+     loss_op = nn.BCEWithLogitsLoss()
+
+     # Freeze layers if using last
+     if method == 'last':
+         model = freeze_layers(model, unfrozen_keyword='output_layer')
+
+     # Define optimizer, for chain-thaw we define it later (after freezing)
+     if method == 'last':
+         adam = optim.Adam((p for p in model.parameters() if p.requires_grad), lr=lr)
+     elif method in ['full', 'new']:
+         # Add L2 regularization on embeddings only
+         special_params = [id(p) for p in model.embed.parameters()]
+         base_params = [p for p in model.parameters() if id(p) not in special_params and p.requires_grad]
+         embed_parameters = [p for p in model.parameters() if id(p) in special_params and p.requires_grad]
+         adam = optim.Adam([
+             {'params': base_params},
+             {'params': embed_parameters, 'weight_decay': embed_l2},
+         ], lr=lr)
+
+     # Training
+     if verbose:
+         print('Method: {}'.format(method))
+         print('Classes: {}'.format(nb_classes))
+
+     if method == 'chain-thaw':
+         result = class_avg_chainthaw(model, nb_classes=nb_classes,
+                                      loss_op=loss_op,
+                                      train=(X_train, y_train),
+                                      val=(X_val, y_val),
+                                      test=(X_test, y_test),
+                                      batch_size=batch_size,
+                                      epoch_size=epoch_size,
+                                      nb_epochs=nb_epochs,
+                                      checkpoint_weight_path=checkpoint_path,
+                                      f1_init_weight_path=f1_init_path,
+                                      verbose=verbose)
+     else:
+         result = class_avg_tune_trainable(model, nb_classes=nb_classes,
+                                           loss_op=loss_op,
+                                           optim_op=adam,
+                                           train=(X_train, y_train),
+                                           val=(X_val, y_val),
+                                           test=(X_test, y_test),
+                                           epoch_size=epoch_size,
+                                           nb_epochs=nb_epochs,
+                                           batch_size=batch_size,
+                                           init_weight_path=f1_init_path,
+                                           checkpoint_weight_path=checkpoint_path,
+                                           verbose=verbose)
+     return model, result
+
+
+ def prepare_labels(y_train, y_val, y_test, iter_i, nb_classes):
+     # Relabel into binary classification
+     y_train_new = relabel(y_train, iter_i, nb_classes)
+     y_val_new = relabel(y_val, iter_i, nb_classes)
+     y_test_new = relabel(y_test, iter_i, nb_classes)
+     return y_train_new, y_val_new, y_test_new
+
+ def prepare_generators(X_train, y_train_new, X_val, y_val_new, batch_size, epoch_size):
+     # Create sample generators
+     # Make a fixed validation set to avoid fluctuations in validation
+     train_gen = get_data_loader(X_train, y_train_new, batch_size,
+                                 extended_batch_sampler=True)
+     val_gen = get_data_loader(X_val, y_val_new, epoch_size,
+                               extended_batch_sampler=True)
+     X_val_resamp, y_val_resamp = next(iter(val_gen))
+     return train_gen, X_val_resamp, y_val_resamp
+
+
+ def class_avg_tune_trainable(model, nb_classes, loss_op, optim_op, train, val, test,
+                              epoch_size, nb_epochs, batch_size,
+                              init_weight_path, checkpoint_weight_path, patience=5,
+                              verbose=True):
+     """ Finetunes the given model using the F1 measure.
+
+     # Arguments:
+         model: Model to be finetuned.
+         nb_classes: Number of classes in the given dataset.
+         train: Training data, given as a tuple of (inputs, outputs)
+         val: Validation data, given as a tuple of (inputs, outputs)
+         test: Testing data, given as a tuple of (inputs, outputs)
+         epoch_size: Number of samples in an epoch.
+         nb_epochs: Number of epochs.
+         batch_size: Batch size.
+         init_weight_path: Filepath where weights will be initially saved before
+             training each class. This file will be rewritten by the function.
+         checkpoint_weight_path: Filepath where weights will be checkpointed to
+             during training. This file will be rewritten by the function.
+         verbose: Verbosity flag.
+
+     # Returns:
+         F1 score of the trained model
+     """
+     total_f1 = 0
+     nb_iter = nb_classes if nb_classes > 2 else 1
+
+     # Unpack args
+     X_train, y_train = train
+     X_val, y_val = val
+     X_test, y_test = test
+
+     # Save and reload initial weights after running for
+     # each class to avoid learning across classes
+     torch.save(model.state_dict(), init_weight_path)
+     for i in range(nb_iter):
+         if verbose:
+             print('Iteration number {}/{}'.format(i+1, nb_iter))
+
+         model.load_state_dict(torch.load(init_weight_path))
+         y_train_new, y_val_new, y_test_new = prepare_labels(y_train, y_val,
+                                                             y_test, i, nb_classes)
+         train_gen, X_val_resamp, y_val_resamp = \
+             prepare_generators(X_train, y_train_new, X_val, y_val_new,
+                                batch_size, epoch_size)
+
+         if verbose:
+             print("Training..")
+         fit_model(model, loss_op, optim_op, train_gen, [(X_val_resamp, y_val_resamp)],
+                   nb_epochs, checkpoint_weight_path, patience, verbose=0)
+
+         # Reload the best weights found to avoid overfitting
+         # Wait a bit to allow proper closing of weights file
+         sleep(1)
+         model.load_state_dict(torch.load(checkpoint_weight_path))
+
+         # Evaluate
+         y_pred_val = model(X_val).cpu().numpy()
+         y_pred_test = model(X_test).cpu().numpy()
+
+         f1_test, best_t = find_f1_threshold(y_val_new, y_pred_val,
+                                             y_test_new, y_pred_test)
+         if verbose:
+             print('f1_test: {}'.format(f1_test))
+             print('best_t: {}'.format(best_t))
+         total_f1 += f1_test
+
+     return total_f1 / nb_iter
+
+
+ def class_avg_chainthaw(model, nb_classes, loss_op, train, val, test, batch_size,
+                         epoch_size, nb_epochs, checkpoint_weight_path,
+                         f1_init_weight_path, patience=5,
+                         initial_lr=0.001, next_lr=0.0001, verbose=True):
+     """ Finetunes given model using chain-thaw and evaluates using F1.
+     For a dataset with multiple classes, the model is trained once for
+     each class, relabeling those classes into a binary classification task.
+     The result is an average of all F1 scores for each class.
+
+     # Arguments:
+         model: Model to be finetuned.
+         nb_classes: Number of classes in the given dataset.
+         train: Training data, given as a tuple of (inputs, outputs)
+         val: Validation data, given as a tuple of (inputs, outputs)
+         test: Testing data, given as a tuple of (inputs, outputs)
+         batch_size: Batch size.
+         loss: Loss function to be used during training.
+         epoch_size: Number of samples in an epoch.
+         nb_epochs: Number of epochs.
+         checkpoint_weight_path: Filepath where weights will be checkpointed to
+             during training. This file will be rewritten by the function.
+         f1_init_weight_path: Filepath where weights will be saved to and
+             reloaded from before training each class. This ensures that
+             each class is trained independently. This file will be rewritten.
+         initial_lr: Initial learning rate. Will only be used for the first
+             training step (i.e. the softmax layer)
+         next_lr: Learning rate for every subsequent step.
+         seed: Random number generator seed.
+         verbose: Verbosity flag.
+
+     # Returns:
+         Averaged F1 score.
+     """
+
+     # Unpack args
+     X_train, y_train = train
+     X_val, y_val = val
+     X_test, y_test = test
+
+     total_f1 = 0
+     nb_iter = nb_classes if nb_classes > 2 else 1
+
+     torch.save(model.state_dict(), f1_init_weight_path)
+
+     for i in range(nb_iter):
+         if verbose:
+             print('Iteration number {}/{}'.format(i+1, nb_iter))
+
+         model.load_state_dict(torch.load(f1_init_weight_path))
+         y_train_new, y_val_new, y_test_new = prepare_labels(y_train, y_val,
+                                                             y_test, i, nb_classes)
+         train_gen, X_val_resamp, y_val_resamp = \
+             prepare_generators(X_train, y_train_new, X_val, y_val_new,
+                                batch_size, epoch_size)
+
+         if verbose:
+             print("Training..")
+
+         # Train using chain-thaw
+         train_by_chain_thaw(model=model, train_gen=train_gen,
+                             val_gen=[(X_val_resamp, y_val_resamp)],
+                             loss_op=loss_op, patience=patience,
+                             nb_epochs=nb_epochs,
+                             checkpoint_path=checkpoint_weight_path,
+                             initial_lr=initial_lr, next_lr=next_lr,
+                             verbose=verbose)
+
+         # Evaluate
+         y_pred_val = model(X_val).cpu().numpy()
+         y_pred_test = model(X_test).cpu().numpy()
+
+         f1_test, best_t = find_f1_threshold(y_val_new, y_pred_val,
+                                             y_test_new, y_pred_test)
+
+         if verbose:
+             print('f1_test: {}'.format(f1_test))
+             print('best_t: {}'.format(best_t))
+         total_f1 += f1_test
+
+     return total_f1 / nb_iter
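
A small illustration of what relabel() above does when reducing a one-hot multi-class label matrix to a per-class binary vector; the values are made up for demonstration:

import numpy as np
from torchmoji.class_avg_finetuning import relabel

y = np.array([[1, 0, 0],
              [0, 1, 0],
              [0, 0, 1],
              [0, 1, 0]])
# Binary labels for class 1 only
print(relabel(y, current_label_nr=1, nb_classes=3))   # [0. 1. 0. 1.]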
torchmoji/create_vocab.py ADDED
@@ -0,0 +1,271 @@
+ # -*- coding: utf-8 -*-
+ from __future__ import print_function, division
+
+ import glob
+ import json
+ import uuid
+ from copy import deepcopy
+ from collections import defaultdict, OrderedDict
+ import numpy as np
+
+ from torchmoji.filter_utils import is_special_token
+ from torchmoji.word_generator import WordGenerator
+ from torchmoji.global_variables import SPECIAL_TOKENS, VOCAB_PATH
+
+ class VocabBuilder():
+     """ Create vocabulary with words extracted from sentences as fed from a
+     word generator.
+     """
+     def __init__(self, word_gen):
+         # initialize any new key with value of 0
+         self.word_counts = defaultdict(lambda: 0, {})
+         self.word_length_limit=30
+
+         for token in SPECIAL_TOKENS:
+             assert len(token) < self.word_length_limit
+             self.word_counts[token] = 0
+         self.word_gen = word_gen
+
+     def count_words_in_sentence(self, words):
+         """ Generates word counts for all tokens in the given sentence.
+
+         # Arguments:
+             words: Tokenized sentence whose words should be counted.
+         """
+         for word in words:
+             if 0 < len(word) and len(word) <= self.word_length_limit:
+                 try:
+                     self.word_counts[word] += 1
+                 except KeyError:
+                     self.word_counts[word] = 1
+
+     def save_vocab(self, path=None):
+         """ Saves the vocabulary into a file.
+
+         # Arguments:
+             path: Where the vocabulary should be saved. If not specified, a
+                 randomly generated filename is used instead.
+         """
+         dtype = ([('word','|S{}'.format(self.word_length_limit)),('count','int')])
+         np_dict = np.array(self.word_counts.items(), dtype=dtype)
+
+         # sort from highest to lowest frequency
+         np_dict[::-1].sort(order='count')
+         data = np_dict
+
+         if path is None:
+             path = str(uuid.uuid4())
+
+         np.savez_compressed(path, data=data)
+         print("Saved dict to {}".format(path))
+
+     def get_next_word(self):
+         """ Returns next tokenized sentence from the word generator.
+
+         # Returns:
+             List of strings, representing the next tokenized sentence.
+         """
+         return self.word_gen.__iter__().next()
+
+     def count_all_words(self):
+         """ Generates word counts for all words in all sentences of the word
+         generator.
+         """
+         for words, _ in self.word_gen:
+             self.count_words_in_sentence(words)
+
+ class MasterVocab():
+     """ Combines vocabularies.
+     """
+     def __init__(self):
+
+         # initialize custom tokens
+         self.master_vocab = {}
+
+     def populate_master_vocab(self, vocab_path, min_words=1, force_appearance=None):
+         """ Populates the master vocabulary using all vocabularies found in the
+         given path. Vocabularies should be named *.npz. Expects the
+         vocabularies to be numpy arrays with counts. Normalizes the counts
+         and combines them.
+
+         # Arguments:
+             vocab_path: Path containing vocabularies to be combined.
+             min_words: Minimum number of occurrences a word must have in order
+                 to be included in the master vocabulary.
+             force_appearance: Optional vocabulary filename that will be added
+                 to the master vocabulary no matter what. This vocabulary must
+                 be present in vocab_path.
+         """
+
+         paths = glob.glob(vocab_path + '*.npz')
+         sizes = {path: 0 for path in paths}
+         dicts = {path: {} for path in paths}
+
+         # set up and get sizes of individual dictionaries
+         for path in paths:
+             np_data = np.load(path)['data']
+
+             for entry in np_data:
+                 word, count = entry
+                 if count < min_words:
+                     continue
+                 if is_special_token(word):
+                     continue
+                 dicts[path][word] = count
+
+             sizes[path] = sum(dicts[path].values())
+             print('Overall word count for {} -> {}'.format(path, sizes[path]))
+             print('Overall word number for {} -> {}'.format(path, len(dicts[path])))
+
+         vocab_of_max_size = max(sizes, key=sizes.get)
+         max_size = sizes[vocab_of_max_size]
+         print('Min: {}, {}, {}'.format(sizes, vocab_of_max_size, max_size))
+
+         # can force one vocabulary to always be present
+         if force_appearance is not None:
+             force_appearance_path = [p for p in paths if force_appearance in p][0]
+             force_appearance_vocab = deepcopy(dicts[force_appearance_path])
+             print(force_appearance_path)
+         else:
+             force_appearance_path, force_appearance_vocab = None, None
+
+         # normalize word counts before inserting into master dict
+         for path in paths:
+             normalization_factor = max_size / sizes[path]
+             print('Norm factor for path {} -> {}'.format(path, normalization_factor))
+
+             for word in dicts[path]:
+                 if is_special_token(word):
+                     print("SPECIAL - ", word)
+                     continue
+                 normalized_count = dicts[path][word] * normalization_factor
+
+                 # can force one vocabulary to always be present
+                 if force_appearance_vocab is not None:
+                     try:
+                         force_word_count = force_appearance_vocab[word]
+                     except KeyError:
+                         continue
+                     #if force_word_count < 5:
+                     #continue
+
+                 if word in self.master_vocab:
+                     self.master_vocab[word] += normalized_count
+                 else:
+                     self.master_vocab[word] = normalized_count
+
+         print('Size of master_dict {}'.format(len(self.master_vocab)))
+         print("Hashes for master dict: {}".format(
+             len([w for w in self.master_vocab if '#' in w[0]])))
+
+     def save_vocab(self, path_count, path_vocab, word_limit=100000):
+         """ Saves the master vocabulary into a file.
+         """
+
+         # reserve space for 10 special tokens
+         words = OrderedDict()
+         for token in SPECIAL_TOKENS:
+             # store -1 instead of np.inf, which can overflow
+             words[token] = -1
+
+         # sort words by frequency
+         desc_order = OrderedDict(sorted(self.master_vocab.items(),
+                                         key=lambda kv: kv[1], reverse=True))
+         words.update(desc_order)
+
+         # use encoding of up to 30 characters (no token conversions)
+         # use float to store large numbers (we don't care about precision loss)
+         np_vocab = np.array(words.items(),
+                             dtype=([('word','|S30'),('count','float')]))
+
+         # output count for debugging
+         counts = np_vocab[:word_limit]
+         np.savez_compressed(path_count, counts=counts)
+
+         # output the index of each word for easy lookup
+         final_words = OrderedDict()
+         for i, w in enumerate(words.keys()[:word_limit]):
+             final_words.update({w:i})
+         with open(path_vocab, 'w') as f:
+             f.write(json.dumps(final_words, indent=4, separators=(',', ': ')))
+
+
+ def all_words_in_sentences(sentences):
+     """ Extracts all unique words from a given list of sentences.
+
+     # Arguments:
+         sentences: List or word generator of sentences to be processed.
+
+     # Returns:
+         List of all unique words contained in the given sentences.
+     """
+     vocab = []
+     if isinstance(sentences, WordGenerator):
+         sentences = [s for s, _ in sentences]
+
+     for sentence in sentences:
+         for word in sentence:
+             if word not in vocab:
+                 vocab.append(word)
+
+     return vocab
+
+
+ def extend_vocab_in_file(vocab, max_tokens=10000, vocab_path=VOCAB_PATH):
+     """ Extends JSON-formatted vocabulary with words from vocab that are not
+     present in the current vocabulary. Adds up to max_tokens words.
+     Overwrites file in vocab_path.
+
+     # Arguments:
+         new_vocab: Vocabulary to be added. MUST have word_counts populated, i.e.
+             must have run count_all_words() previously.
+         max_tokens: Maximum number of words to be added.
+         vocab_path: Path to the vocabulary json which is to be extended.
+     """
+     try:
+         with open(vocab_path, 'r') as f:
+             current_vocab = json.load(f)
+     except IOError:
+         print('Vocabulary file not found, expected at ' + vocab_path)
+         return
+
+     extend_vocab(current_vocab, vocab, max_tokens)
+
+     # Save back to file
+     with open(vocab_path, 'w') as f:
+         json.dump(current_vocab, f, sort_keys=True, indent=4, separators=(',',': '))
+
+
+ def extend_vocab(current_vocab, new_vocab, max_tokens=10000):
+     """ Extends current vocabulary with words from vocab that are not
+     present in the current vocabulary. Adds up to max_tokens words.
+
+     # Arguments:
+         current_vocab: Current dictionary of tokens.
+         new_vocab: Vocabulary to be added. MUST have word_counts populated, i.e.
+             must have run count_all_words() previously.
+         max_tokens: Maximum number of words to be added.
+
+     # Returns:
+         How many new tokens have been added.
+     """
+     if max_tokens < 0:
+         max_tokens = 10000
+
+     words = OrderedDict()
+
+     # sort words by frequency
+     desc_order = OrderedDict(sorted(new_vocab.word_counts.items(),
+                                     key=lambda kv: kv[1], reverse=True))
+     words.update(desc_order)
+
+     base_index = len(current_vocab.keys())
+     added = 0
+     for word in words:
+         if added >= max_tokens:
+             break
+         if word not in current_vocab.keys():
+             current_vocab[word] = base_index + added
+             added += 1
+
+     return added
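
A small sketch of extend_vocab() in isolation, not part of this commit; it uses a stand-in object in place of a full VocabBuilder, since the function only reads a word_counts attribute:

from collections import namedtuple
from torchmoji.create_vocab import extend_vocab

FakeVocab = namedtuple('FakeVocab', 'word_counts')
current = {'CUSTOM_MASK': 0, 'hello': 1}
new = FakeVocab(word_counts={'hello': 10, 'world': 7, 'emoji': 3})

added = extend_vocab(current, new, max_tokens=10)
print(added)     # 2: only the tokens missing from current are appended
print(current)   # {'CUSTOM_MASK': 0, 'hello': 1, 'world': 2, 'emoji': 3}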
torchmoji/filter_input.py ADDED
@@ -0,0 +1,36 @@
+ # -*- coding: utf-8 -*-
+ from __future__ import print_function, division
+ import codecs
+ import csv
+ import numpy as np
+ from emoji import UNICODE_EMOJI
+
+ def read_english(path="english_words.txt", add_emojis=True):
+     # read english words for filtering (includes emojis as part of set)
+     english = set()
+     with codecs.open(path, "r", "utf-8") as f:
+         for line in f:
+             line = line.strip().lower().replace('\n', '')
+             if len(line):
+                 english.add(line)
+     if add_emojis:
+         for e in UNICODE_EMOJI:
+             english.add(e)
+     return english
+
+ def read_wanted_emojis(path="wanted_emojis.csv"):
+     emojis = []
+     with open(path, 'rb') as f:
+         reader = csv.reader(f)
+         for line in reader:
+             line = line[0].strip().replace('\n', '')
+             line = line.decode('unicode-escape')
+             emojis.append(line)
+     return emojis
+
+ def read_non_english_users(path="unwanted_users.npz"):
+     try:
+         neu_set = set(np.load(path)['userids'])
+     except IOError:
+         neu_set = set()
+     return neu_set
torchmoji/filter_utils.py ADDED
@@ -0,0 +1,194 @@
+
+ # -*- coding: utf-8 -*-
+ from __future__ import print_function, division, unicode_literals
+ import sys
+ import re
+ import string
+ import emoji
+ from itertools import groupby
+
+ import numpy as np
+ from torchmoji.tokenizer import RE_MENTION, RE_URL
+ from torchmoji.global_variables import SPECIAL_TOKENS
+
+ try:
+     unichr # Python 2
+ except NameError:
+     unichr = chr # Python 3
+
+
+ AtMentionRegex = re.compile(RE_MENTION)
+ urlRegex = re.compile(RE_URL)
+
+ # from http://bit.ly/2rdjgjE (UTF-8 encodings and Unicode chars)
+ VARIATION_SELECTORS = [ '\ufe00',
+                         '\ufe01',
+                         '\ufe02',
+                         '\ufe03',
+                         '\ufe04',
+                         '\ufe05',
+                         '\ufe06',
+                         '\ufe07',
+                         '\ufe08',
+                         '\ufe09',
+                         '\ufe0a',
+                         '\ufe0b',
+                         '\ufe0c',
+                         '\ufe0d',
+                         '\ufe0e',
+                         '\ufe0f']
+
+ # from https://stackoverflow.com/questions/92438/stripping-non-printable-characters-from-a-string-in-python
+ ALL_CHARS = (unichr(i) for i in range(sys.maxunicode))
+ CONTROL_CHARS = ''.join(map(unichr, list(range(0,32)) + list(range(127,160))))
+ CONTROL_CHAR_REGEX = re.compile('[%s]' % re.escape(CONTROL_CHARS))
+
+ def is_special_token(word):
+     equal = False
+     for spec in SPECIAL_TOKENS:
+         if word == spec:
+             equal = True
+             break
+     return equal
+
+ def mostly_english(words, english, pct_eng_short=0.5, pct_eng_long=0.6, ignore_special_tokens=True, min_length=2):
+     """ Ensure text meets threshold for containing English words """
+
+     n_words = 0
+     n_english = 0
+
+     if english is None:
+         return True, 0, 0
+
+     for w in words:
+         if len(w) < min_length:
+             continue
+         if punct_word(w):
+             continue
+         if ignore_special_tokens and is_special_token(w):
+             continue
+         n_words += 1
+         if w in english:
+             n_english += 1
+
+     if n_words < 2:
+         return True, n_words, n_english
+     if n_words < 5:
+         valid_english = n_english >= n_words * pct_eng_short
+     else:
+         valid_english = n_english >= n_words * pct_eng_long
+     return valid_english, n_words, n_english
+
+ def correct_length(words, min_words, max_words, ignore_special_tokens=True):
+     """ Ensure text meets threshold for containing English words
+     and that it's within the min and max words limits. """
+
+     if min_words is None:
+         min_words = 0
+
+     if max_words is None:
+         max_words = 99999
+
+     n_words = 0
+     for w in words:
+         if punct_word(w):
+             continue
+         if ignore_special_tokens and is_special_token(w):
+             continue
+         n_words += 1
+     valid = min_words <= n_words and n_words <= max_words
+     return valid
+
+ def punct_word(word, punctuation=string.punctuation):
+     return all([True if c in punctuation else False for c in word])
+
+ def load_non_english_user_set():
+     non_english_user_set = set(np.load('uids.npz')['data'])
+     return non_english_user_set
+
+ def non_english_user(userid, non_english_user_set):
+     neu_found = int(userid) in non_english_user_set
+     return neu_found
+
+ def separate_emojis_and_text(text):
+     emoji_chars = []
+     non_emoji_chars = []
+     for c in text:
+         if c in emoji.UNICODE_EMOJI:
+             emoji_chars.append(c)
+         else:
+             non_emoji_chars.append(c)
+     return ''.join(emoji_chars), ''.join(non_emoji_chars)
+
+ def extract_emojis(text, wanted_emojis):
+     text = remove_variation_selectors(text)
+     return [c for c in text if c in wanted_emojis]
+
+ def remove_variation_selectors(text):
+     """ Remove styling glyph variants for Unicode characters.
+     For instance, remove skin color from emojis.
+     """
+     for var in VARIATION_SELECTORS:
+         text = text.replace(var, '')
+     return text
+
+ def shorten_word(word):
+     """ Shorten groupings of 3+ identical consecutive chars to 2, e.g. '!!!!' --> '!!'
+     """
+
+     # only shorten ASCII words
+     try:
+         word.decode('ascii')
+     except (UnicodeDecodeError, UnicodeEncodeError, AttributeError) as e:
+         return word
+
+     # must have at least 3 char to be shortened
+     if len(word) < 3:
+         return word
+
+     # find groups of 3+ consecutive letters
+     letter_groups = [list(g) for k, g in groupby(word)]
+     triple_or_more = [''.join(g) for g in letter_groups if len(g) >= 3]
+     if len(triple_or_more) == 0:
+         return word
+
+     # replace letters to find the short word
+     short_word = word
+     for trip in triple_or_more:
+         short_word = short_word.replace(trip, trip[0]*2)
+
+     return short_word
+
+ def detect_special_tokens(word):
+     try:
+         int(word)
+         word = SPECIAL_TOKENS[4]
+     except ValueError:
+         if AtMentionRegex.findall(word):
+             word = SPECIAL_TOKENS[2]
+         elif urlRegex.findall(word):
+             word = SPECIAL_TOKENS[3]
+     return word
+
+ def process_word(word):
+     """ Shortening and converting the word to a special token if relevant.
+     """
+     word = shorten_word(word)
+     word = detect_special_tokens(word)
+     return word
+
+ def remove_control_chars(text):
+     return CONTROL_CHAR_REGEX.sub('', text)
+
+ def convert_nonbreaking_space(text):
+     # ugly hack handling non-breaking space no matter how badly it's been encoded in the input
+     for r in ['\\\\xc2', '\\xc2', '\xc2', '\\\\xa0', '\\xa0', '\xa0']:
+         text = text.replace(r, ' ')
+     return text
+
+ def convert_linebreaks(text):
+     # ugly hack handling non-breaking space no matter how badly it's been encoded in the input
+     # space around to ensure proper tokenization
+     for r in ['\\\\n', '\\n', '\n', '\\\\r', '\\r', '\r', '<br>']:
+         text = text.replace(r, ' ' + SPECIAL_TOKENS[5] + ' ')
+     return text
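
The token filters above are pure functions, so they can be checked in isolation; a short sketch of expected behavior, not part of this commit, with the '@someone' case depending on how RE_MENTION is defined in torchmoji.tokenizer:

from torchmoji.filter_utils import detect_special_tokens, punct_word, is_special_token

print(detect_special_tokens('123'))        # 'CUSTOM_NUMBER' (SPECIAL_TOKENS[4])
print(detect_special_tokens('@someone'))   # expected to map to 'CUSTOM_AT' via RE_MENTION
print(punct_word('!!!'))                   # True: every character is punctuation
print(is_special_token('CUSTOM_URL'))      # True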
torchmoji/finetuning.py ADDED
@@ -0,0 +1,674 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """ Finetuning functions for doing transfer learning to new datasets.
3
+ """
4
+ from __future__ import print_function
5
+
6
+ import uuid
7
+ from time import sleep
8
+ from io import open
9
+
10
+ import math
11
+ import pickle
12
+ import numpy as np
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.optim as optim
17
+ from sklearn.metrics import accuracy_score
18
+ from torch.autograd import Variable
19
+ from torch.utils.data import Dataset, DataLoader
20
+ from torch.utils.data.sampler import BatchSampler, SequentialSampler
21
+ from torch.nn.utils import clip_grad_norm
22
+
23
+ from sklearn.metrics import f1_score
24
+
25
+ from torchmoji.global_variables import (FINETUNING_METHODS,
26
+ FINETUNING_METRICS,
27
+ WEIGHTS_DIR)
28
+ from torchmoji.tokenizer import tokenize
29
+ from torchmoji.sentence_tokenizer import SentenceTokenizer
30
+
31
+ try:
32
+ unicode
33
+ IS_PYTHON2 = True
34
+ except NameError:
35
+ unicode = str
36
+ IS_PYTHON2 = False
37
+
38
+
39
+ def load_benchmark(path, vocab, extend_with=0):
40
+ """ Loads the given benchmark dataset.
41
+
42
+ Tokenizes the texts using the provided vocabulary, extending it with
43
+ words from the training dataset if extend_with > 0. Splits them into
44
+ three lists: training, validation and testing (in that order).
45
+
46
+ Also calculates the maximum length of the texts and the
47
+ suggested batch_size.
48
+
49
+ # Arguments:
50
+ path: Path to the dataset to be loaded.
51
+ vocab: Vocabulary to be used for tokenizing texts.
52
+ extend_with: If > 0, the vocabulary will be extended with up to
53
+ extend_with tokens from the training set before tokenizing.
54
+
55
+ # Returns:
56
+ A dictionary with the following fields:
57
+ texts: List of three lists, containing tokenized inputs for
58
+ training, validation and testing (in that order).
59
+ labels: List of three lists, containing labels for training,
60
+ validation and testing (in that order).
61
+ added: Number of tokens added to the vocabulary.
62
+ batch_size: Batch size.
63
+ maxlen: Maximum length of an input.
64
+ """
65
+ # Pre-processing dataset
66
+ with open(path, 'rb') as dataset:
67
+ if IS_PYTHON2:
68
+ data = pickle.load(dataset)
69
+ else:
70
+ data = pickle.load(dataset, fix_imports=True)
71
+
72
+ # Decode data
73
+ try:
74
+ texts = [unicode(x) for x in data['texts']]
75
+ except UnicodeDecodeError:
76
+ texts = [x.decode('utf-8') for x in data['texts']]
77
+
78
+ # Extract labels
79
+ labels = [x['label'] for x in data['info']]
80
+
81
+ batch_size, maxlen = calculate_batchsize_maxlen(texts)
82
+
83
+ st = SentenceTokenizer(vocab, maxlen)
84
+
85
+ # Split up dataset. Extend the existing vocabulary with up to extend_with
86
+ # tokens from the training dataset.
87
+ texts, labels, added = st.split_train_val_test(texts,
88
+ labels,
89
+ [data['train_ind'],
90
+ data['val_ind'],
91
+ data['test_ind']],
92
+ extend_with=extend_with)
93
+ return {'texts': texts,
94
+ 'labels': labels,
95
+ 'added': added,
96
+ 'batch_size': batch_size,
97
+ 'maxlen': maxlen}
98
+
99
+
100
+ def calculate_batchsize_maxlen(texts):
101
+ """ Calculates the maximum length in the provided texts and a suitable
102
+ batch size. Rounds up maxlen to the nearest multiple of ten.
103
+
104
+ # Arguments:
105
+ texts: List of inputs.
106
+
107
+ # Returns:
108
+ Batch size,
109
+ max length
110
+ """
111
+ def roundup(x):
112
+ return int(math.ceil(x / 10.0)) * 10
113
+
114
+ # Calculate max length of sequences considered
115
+ # Adjust batch_size accordingly to prevent GPU overflow
116
+ lengths = [len(tokenize(t)) for t in texts]
117
+ maxlen = roundup(np.percentile(lengths, 80.0))
118
+ batch_size = 250 if maxlen <= 100 else 50
119
+ return batch_size, maxlen
120
+
121
+
122
+
123
+ def freeze_layers(model, unfrozen_types=[], unfrozen_keyword=None):
124
+ """ Freezes all layers in the given model, except for ones that are
125
+ explicitly specified to not be frozen.
126
+
127
+ # Arguments:
128
+ model: Model whose layers should be modified.
129
+ unfrozen_types: List of layer types which shouldn't be frozen.
130
+ unfrozen_keyword: Name keywords of layers that shouldn't be frozen.
131
+
132
+ # Returns:
133
+ Model with the selected layers frozen.
134
+ """
135
+ # Get trainable modules
136
+ trainable_modules = [(n, m) for n, m in model.named_children() if len([id(p) for p in m.parameters()]) != 0]
137
+ for name, module in trainable_modules:
138
+ trainable = (any(typ in str(module) for typ in unfrozen_types) or
139
+ (unfrozen_keyword is not None and unfrozen_keyword.lower() in name.lower()))
140
+ change_trainable(module, trainable, verbose=False)
141
+ return model
142
+
143
+
144
+ def change_trainable(module, trainable, verbose=False):
145
+ """ Helper method that freezes or unfreezes a given layer.
146
+
147
+ # Arguments:
148
+ module: Module to be modified.
149
+ trainable: Whether the layer should be frozen or unfrozen.
150
+ verbose: Verbosity flag.
151
+ """
152
+
153
+ if verbose: print('Changing MODULE', module, 'to trainable =', trainable)
154
+ for name, param in module.named_parameters():
155
+ if verbose: print('Setting weight', name, 'to trainable =', trainable)
156
+ param.requires_grad = trainable
157
+
158
+ if verbose:
159
+ action = 'Unfroze' if trainable else 'Froze'
160
+ if verbose: print("{} {}".format(action, module))
161
+
162
+
163
+ def find_f1_threshold(model, val_gen, test_gen, average='binary'):
164
+ """ Choose a threshold for F1 based on the validation dataset
165
+ (see https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4442797/
166
+ for details on why to find another threshold than simply 0.5)
167
+
168
+ # Arguments:
169
+ model: pyTorch model
170
+ val_gen: Validation set dataloader.
171
+ test_gen: Testing set dataloader.
172
+
173
+ # Returns:
174
+ F1 score for the given data and
175
+ the corresponding F1 threshold
176
+ """
177
+ thresholds = np.arange(0.01, 0.5, step=0.01)
178
+ f1_scores = []
179
+
180
+ model.eval()
181
+ val_out = [(y, model(X)) for X, y in val_gen]
182
+ y_val, y_pred_val = (list(t) for t in zip(*val_out))
183
+
184
+ test_out = [(y, model(X)) for X, y in test_gen]
185
+ y_test, y_pred_test = (list(t) for t in zip(*val_out))
186
+
187
+ for t in thresholds:
188
+ y_pred_val_ind = (y_pred_val > t)
189
+ f1_val = f1_score(y_val, y_pred_val_ind, average=average)
190
+ f1_scores.append(f1_val)
191
+
192
+ best_t = thresholds[np.argmax(f1_scores)]
193
+ y_pred_ind = (y_pred_test > best_t)
194
+ f1_test = f1_score(y_test, y_pred_ind, average=average)
195
+ return f1_test, best_t
196
+
197
+
198
+ def finetune(model, texts, labels, nb_classes, batch_size, method,
199
+ metric='acc', epoch_size=5000, nb_epochs=1000, embed_l2=1E-6,
200
+ verbose=1):
201
+ """ Compiles and finetunes the given pytorch model.
202
+
203
+ # Arguments:
204
+ model: Model to be finetuned
205
+ texts: List of three lists, containing tokenized inputs for training,
206
+ validation and testing (in that order).
207
+ labels: List of three lists, containing labels for training,
208
+ validation and testing (in that order).
209
+ nb_classes: Number of classes in the dataset.
210
+ batch_size: Batch size.
211
+ method: Finetuning method to be used. For available methods, see
212
+ FINETUNING_METHODS in global_variables.py.
213
+ metric: Evaluation metric to be used. For available metrics, see
214
+ FINETUNING_METRICS in global_variables.py.
215
+ epoch_size: Number of samples in an epoch.
216
+ nb_epochs: Number of epochs. Doesn't matter much as early stopping is used.
217
+ embed_l2: L2 regularization for the embedding layer.
218
+ verbose: Verbosity flag.
219
+
220
+ # Returns:
221
+ Model after finetuning,
222
+ score after finetuning using the provided metric.
223
+ """
224
+
225
+ if method not in FINETUNING_METHODS:
226
+ raise ValueError('ERROR (finetune): Invalid method parameter. '
227
+ 'Available options: {}'.format(FINETUNING_METHODS))
228
+ if metric not in FINETUNING_METRICS:
229
+ raise ValueError('ERROR (finetune): Invalid metric parameter. '
230
+ 'Available options: {}'.format(FINETUNING_METRICS))
231
+
232
+ train_gen = get_data_loader(texts[0], labels[0], batch_size,
233
+ extended_batch_sampler=True, epoch_size=epoch_size)
234
+ val_gen = get_data_loader(texts[1], labels[1], batch_size,
235
+ extended_batch_sampler=False)
236
+ test_gen = get_data_loader(texts[2], labels[2], batch_size,
237
+ extended_batch_sampler=False)
238
+
239
+ checkpoint_path = '{}/torchmoji-checkpoint-{}.bin' \
240
+ .format(WEIGHTS_DIR, str(uuid.uuid4()))
241
+
242
+ if method in ['last', 'new']:
243
+ lr = 0.001
244
+ elif method in ['full', 'chain-thaw']:
245
+ lr = 0.0001
246
+
247
+ loss_op = nn.BCEWithLogitsLoss() if nb_classes <= 2 \
248
+ else nn.CrossEntropyLoss()
249
+
250
+ # Freeze layers if using last
251
+ if method == 'last':
252
+ model = freeze_layers(model, unfrozen_keyword='output_layer')
253
+
254
+ # Define optimizer, for chain-thaw we define it later (after freezing)
255
+ if method == 'last':
256
+ adam = optim.Adam((p for p in model.parameters() if p.requires_grad), lr=lr)
257
+ elif method in ['full', 'new']:
258
+ # Add L2 regulation on embeddings only
259
+ embed_params_id = [id(p) for p in model.embed.parameters()]
260
+ output_layer_params_id = [id(p) for p in model.output_layer.parameters()]
261
+ base_params = [p for p in model.parameters()
262
+ if id(p) not in embed_params_id and id(p) not in output_layer_params_id and p.requires_grad]
263
+ embed_params = [p for p in model.parameters() if id(p) in embed_params_id and p.requires_grad]
264
+ output_layer_params = [p for p in model.parameters() if id(p) in output_layer_params_id and p.requires_grad]
265
+ adam = optim.Adam([
266
+ {'params': base_params},
267
+ {'params': embed_params, 'weight_decay': embed_l2},
268
+ {'params': output_layer_params, 'lr': 0.001},
269
+ ], lr=lr)
270
+
271
+ # Training
272
+ if verbose:
273
+ print('Method: {}'.format(method))
274
+ print('Metric: {}'.format(metric))
275
+ print('Classes: {}'.format(nb_classes))
276
+
277
+ if method == 'chain-thaw':
278
+ result = chain_thaw(model, train_gen, val_gen, test_gen, nb_epochs, checkpoint_path, loss_op, embed_l2=embed_l2,
279
+ evaluate=metric, verbose=verbose)
280
+ else:
281
+ result = tune_trainable(model, loss_op, adam, train_gen, val_gen, test_gen, nb_epochs, checkpoint_path,
282
+ evaluate=metric, verbose=verbose)
283
+ return model, result
284
+
285
+
286
+ def tune_trainable(model, loss_op, optim_op, train_gen, val_gen, test_gen,
287
+ nb_epochs, checkpoint_path, patience=5, evaluate='acc',
288
+ verbose=2):
289
+ """ Finetunes the given model using the accuracy measure.
290
+
291
+ # Arguments:
292
+ model: Model to be finetuned.
293
+ nb_classes: Number of classes in the given dataset.
294
+ train: Training data, given as a tuple of (inputs, outputs)
295
+ val: Validation data, given as a tuple of (inputs, outputs)
296
+ test: Testing data, given as a tuple of (inputs, outputs)
297
+ epoch_size: Number of samples in an epoch.
298
+ nb_epochs: Number of epochs.
299
+ batch_size: Batch size.
300
+ checkpoint_weight_path: Filepath where weights will be checkpointed to
301
+ during training. This file will be rewritten by the function.
302
+ patience: Patience for callback methods.
303
+ evaluate: Evaluation method to use. Can be 'acc' or 'weighted_f1'.
304
+ verbose: Verbosity flag.
305
+
306
+ # Returns:
307
+ Accuracy of the trained model, ONLY if 'evaluate' is set.
308
+ """
309
+ if verbose:
310
+ print("Trainable weights: {}".format([n for n, p in model.named_parameters() if p.requires_grad]))
311
+ print("Training...")
312
+ if evaluate == 'acc':
313
+ print("Evaluation on test set prior training:", evaluate_using_acc(model, test_gen))
314
+ elif evaluate == 'weighted_f1':
315
+ print("Evaluation on test set prior training:", evaluate_using_weighted_f1(model, test_gen, val_gen))
316
+
317
+ fit_model(model, loss_op, optim_op, train_gen, val_gen, nb_epochs, checkpoint_path, patience)
318
+
319
+ # Reload the best weights found to avoid overfitting
320
+ # Wait a bit to allow proper closing of weights file
321
+ sleep(1)
322
+ model.load_state_dict(torch.load(checkpoint_path))
323
+ if verbose >= 2:
324
+ print("Loaded weights from {}".format(checkpoint_path))
325
+
326
+ if evaluate == 'acc':
327
+ return evaluate_using_acc(model, test_gen)
328
+ elif evaluate == 'weighted_f1':
329
+ return evaluate_using_weighted_f1(model, test_gen, val_gen)
330
+
331
+
332
+ def evaluate_using_weighted_f1(model, test_gen, val_gen):
333
+ """ Evaluation function using macro weighted F1 score.
334
+
335
+ # Arguments:
336
+ model: Model to be evaluated.
337
+ X_test: Inputs of the testing set.
338
+ y_test: Outputs of the testing set.
339
+ X_val: Inputs of the validation set.
340
+ y_val: Outputs of the validation set.
341
+ batch_size: Batch size.
342
+
343
+ # Returns:
344
+ Weighted F1 score of the given model.
345
+ """
346
+ # Evaluate on test and val data
347
+ f1_test, _ = find_f1_threshold(model, test_gen, val_gen, average='weighted_f1')
348
+ return f1_test
349
+
350
+
351
+ def evaluate_using_acc(model, test_gen):
352
+ """ Evaluation function using accuracy.
353
+
354
+ # Arguments:
355
+ model: Model to be evaluated.
356
+ test_gen: Testing data iterator (DataLoader)
357
+
358
+ # Returns:
359
+ Accuracy of the given model.
360
+ """
361
+
362
+ # Validate on test_data
363
+ model.eval()
364
+ accs = []
365
+ for i, data in enumerate(test_gen):
366
+ x, y = data
367
+ outs = model(x)
368
+ if model.nb_classes > 2:
369
+ pred = torch.max(outs, 1)[1]
370
+ acc = accuracy_score(y.squeeze().numpy(), pred.squeeze().numpy())
371
+ else:
372
+ pred = (outs >= 0).long()
373
+ acc = (pred == y).double().sum() / len(pred)
374
+ accs.append(acc)
375
+ return np.mean(accs)
376
+
377
+
378
+ def chain_thaw(model, train_gen, val_gen, test_gen, nb_epochs, checkpoint_path, loss_op,
379
+ patience=5, initial_lr=0.001, next_lr=0.0001, embed_l2=1E-6, evaluate='acc', verbose=1):
380
+ """ Finetunes given model using chain-thaw and evaluates using accuracy.
381
+
382
+ # Arguments:
383
+ model: Model to be finetuned.
384
+ train: Training data, given as a tuple of (inputs, outputs)
385
+ val: Validation data, given as a tuple of (inputs, outputs)
386
+ test: Testing data, given as a tuple of (inputs, outputs)
387
+ batch_size: Batch size.
388
+ loss: Loss function to be used during training.
389
+ epoch_size: Number of samples in an epoch.
390
+ nb_epochs: Number of epochs.
391
+ checkpoint_weight_path: Filepath where weights will be checkpointed to
392
+ during training. This file will be rewritten by the function.
393
+ initial_lr: Initial learning rate. Will only be used for the first
394
+ training step (i.e. the output_layer layer)
395
+ next_lr: Learning rate for every subsequent step.
396
+ seed: Random number generator seed.
397
+ verbose: Verbosity flag.
398
+ evaluate: Evaluation method to use. Can be 'acc' or 'weighted_f1'.
399
+
400
+ # Returns:
401
+ Accuracy of the finetuned model.
402
+ """
403
+ if verbose:
404
+ print('Training..')
405
+
406
+ # Train using chain-thaw
407
+ train_by_chain_thaw(model, train_gen, val_gen, loss_op, patience, nb_epochs, checkpoint_path,
408
+ initial_lr, next_lr, embed_l2, verbose)
409
+
410
+ if evaluate == 'acc':
411
+ return evaluate_using_acc(model, test_gen)
412
+ elif evaluate == 'weighted_f1':
413
+ return evaluate_using_weighted_f1(model, test_gen, val_gen)
414
+
415
+
416
+ def train_by_chain_thaw(model, train_gen, val_gen, loss_op, patience, nb_epochs, checkpoint_path,
417
+ initial_lr=0.001, next_lr=0.0001, embed_l2=1E-6, verbose=1):
418
+ """ Finetunes model using the chain-thaw method.
419
+
420
+ This is done as follows:
421
+ 1) Freeze every layer except the last (output_layer) layer and train it.
422
+ 2) Freeze every layer except the first layer and train it.
423
+ 3) Freeze every layer except the second etc., until the second last layer.
424
+ 4) Unfreeze all layers and train entire model.
425
+
426
+ # Arguments:
427
+ model: Model to be trained.
428
+ train_gen: Training sample generator.
429
+ val_data: Validation data.
430
+ loss: Loss function to be used.
431
+ finetuning_args: Training early stopping and checkpoint saving parameters
432
+ epoch_size: Number of samples in an epoch.
433
+ nb_epochs: Number of epochs.
434
+ checkpoint_weight_path: Where weight checkpoints should be saved.
435
+ batch_size: Batch size.
436
+ initial_lr: Initial learning rate. Will only be used for the first
437
+ training step (i.e. the output_layer layer)
438
+ next_lr: Learning rate for every subsequent step.
439
+ verbose: Verbosity flag.
440
+ """
441
+ # Get trainable layers
442
+ layers = [m for m in model.children() if len([id(p) for p in m.parameters()]) != 0]
443
+
444
+ # Bring last layer to front
445
+ layers.insert(0, layers.pop(len(layers) - 1))
446
+
447
+ # Add None to the end to signify finetuning all layers
448
+ layers.append(None)
449
+
450
+ lr = None
451
+ # Finetune each layer one by one and finetune all of them at once
452
+ # at the end
453
+ for layer in layers:
454
+ if lr is None:
455
+ lr = initial_lr
456
+ elif lr == initial_lr:
457
+ lr = next_lr
458
+
459
+ # Freeze all except current layer
460
+ for _layer in layers:
461
+ if _layer is not None:
462
+ trainable = _layer == layer or layer is None
463
+ change_trainable(_layer, trainable=trainable, verbose=False)
464
+
465
+ # Verify we froze the right layers
466
+ for _layer in model.children():
467
+ assert all(p.requires_grad == (_layer == layer) for p in _layer.parameters()) or layer is None
468
+
469
+ if verbose:
470
+ if layer is None:
471
+ print('Finetuning all layers')
472
+ else:
473
+ print('Finetuning {}'.format(layer))
474
+
475
+ special_params = [id(p) for p in model.embed.parameters()]
476
+ base_params = [p for p in model.parameters() if id(p) not in special_params and p.requires_grad]
477
+ embed_parameters = [p for p in model.parameters() if id(p) in special_params and p.requires_grad]
478
+ adam = optim.Adam([
479
+ {'params': base_params},
480
+ {'params': embed_parameters, 'weight_decay': embed_l2},
481
+ ], lr=lr)
482
+
483
+ fit_model(model, loss_op, adam, train_gen, val_gen, nb_epochs,
484
+ checkpoint_path, patience)
485
+
486
+ # Reload the best weights found to avoid overfitting
487
+ # Wait a bit to allow proper closing of weights file
488
+ sleep(1)
489
+ model.load_state_dict(torch.load(checkpoint_path))
490
+ if verbose >= 2:
491
+ print("Loaded weights from {}".format(checkpoint_path))
492
+
493
+
494
+ def calc_loss(loss_op, pred, yv):
495
+ if type(loss_op) is nn.CrossEntropyLoss:
496
+ return loss_op(pred.squeeze(), yv.squeeze())
497
+ else:
498
+ return loss_op(pred.squeeze(), yv.squeeze().float())
499
+
500
+
501
+ def fit_model(model, loss_op, optim_op, train_gen, val_gen, epochs,
502
+ checkpoint_path, patience):
503
+ """ Analog to Keras fit_generator function.
504
+
505
+ # Arguments:
506
+ model: Model to be finetuned.
507
+ loss_op: loss operation (BCEWithLogitsLoss or CrossEntropy for e.g.)
508
+ optim_op: optimization operation (Adam e.g.)
509
+ train_gen: Training data iterator (DataLoader)
510
+ val_gen: Validation data iterator (DataLoader)
511
+ epochs: Number of epochs.
512
+ checkpoint_path: Filepath where weights will be checkpointed to
513
+ during training. This file will be rewritten by the function.
514
+ patience: Patience for callback methods.
515
+ verbose: Verbosity flag.
516
+
517
+ # Returns:
518
+ Accuracy of the trained model, ONLY if 'evaluate' is set.
519
+ """
520
+ # Save original checkpoint
521
+ torch.save(model.state_dict(), checkpoint_path)
522
+
523
+ model.eval()
524
+ best_loss = np.mean([calc_loss(loss_op, model(Variable(xv)), Variable(yv)).data.cpu().numpy()[0] for xv, yv in val_gen])
525
+ print("original val loss", best_loss)
526
+
527
+ epoch_without_impr = 0
528
+ for epoch in range(epochs):
529
+ for i, data in enumerate(train_gen):
530
+ X_train, y_train = data
531
+ X_train = Variable(X_train, requires_grad=False)
532
+ y_train = Variable(y_train, requires_grad=False)
533
+ model.train()
534
+ optim_op.zero_grad()
535
+ output = model(X_train)
536
+ loss = calc_loss(loss_op, output, y_train)
537
+ loss.backward()
538
+ clip_grad_norm(model.parameters(), 1)
539
+ optim_op.step()
540
+
541
+ acc = evaluate_using_acc(model, [(X_train.data, y_train.data)])
542
+ print("== Epoch", epoch, "step", i, "train loss", loss.data.cpu().numpy()[0], "train acc", acc)
543
+
544
+ model.eval()
545
+ acc = evaluate_using_acc(model, val_gen)
546
+ print("val acc", acc)
547
+
548
+ val_loss = np.mean([calc_loss(loss_op, model(Variable(xv)), Variable(yv)).data.cpu().numpy()[0] for xv, yv in val_gen])
549
+ print("val loss", val_loss)
550
+ if best_loss is not None and val_loss >= best_loss:
551
+ epoch_without_impr += 1
552
+ print('No improvement over previous best loss: ', best_loss)
553
+
554
+ # Save checkpoint
555
+ if best_loss is None or val_loss < best_loss:
556
+ best_loss = val_loss
557
+ torch.save(model.state_dict(), checkpoint_path)
558
+ print('Saving model at', checkpoint_path)
559
+
560
+ # Early stopping
561
+ if epoch_without_impr >= patience:
562
+ break
563
+
564
+ def get_data_loader(X_in, y_in, batch_size, extended_batch_sampler=True, epoch_size=25000, upsample=False, seed=42):
565
+ """ Returns a dataloader that enables larger epochs on small datasets and
566
+ has upsampling functionality.
567
+
568
+ # Arguments:
569
+ X_in: Inputs of the given dataset.
570
+ y_in: Outputs of the given dataset.
571
+ batch_size: Batch size.
572
+ epoch_size: Number of samples in an epoch.
573
+ upsample: Whether upsampling should be done. This flag should only be
574
+ set on binary class problems.
575
+
576
+ # Returns:
577
+ DataLoader.
578
+ """
579
+ dataset = DeepMojiDataset(X_in, y_in)
580
+
581
+ if extended_batch_sampler:
582
+ batch_sampler = DeepMojiBatchSampler(y_in, batch_size, epoch_size=epoch_size, upsample=upsample, seed=seed)
583
+ else:
584
+ batch_sampler = BatchSampler(SequentialSampler(y_in), batch_size, drop_last=False)
585
+
586
+ return DataLoader(dataset, batch_sampler=batch_sampler, num_workers=0)
587
+
588
+ class DeepMojiDataset(Dataset):
589
+ """ A simple Dataset class.
590
+
591
+ # Arguments:
592
+ X_in: Inputs of the given dataset.
593
+ y_in: Outputs of the given dataset.
594
+
595
+ # __getitem__ output:
596
+ (torch.LongTensor, torch.LongTensor)
597
+ """
598
+ def __init__(self, X_in, y_in):
599
+ # Check if we have Torch.LongTensor inputs (assume Numpy array otherwise)
600
+ if not isinstance(X_in, torch.LongTensor):
601
+ X_in = torch.from_numpy(X_in.astype('int64')).long()
602
+ if not isinstance(y_in, torch.LongTensor):
603
+ y_in = torch.from_numpy(y_in.astype('int64')).long()
604
+
605
+ self.X_in = torch.split(X_in, 1, dim=0)
606
+ self.y_in = torch.split(y_in, 1, dim=0)
607
+
608
+ def __len__(self):
609
+ return len(self.X_in)
610
+
611
+ def __getitem__(self, idx):
612
+ return self.X_in[idx].squeeze(), self.y_in[idx].squeeze()
613
+
614
+ class DeepMojiBatchSampler(object):
615
+ """A Batch sampler that enables larger epochs on small datasets and
616
+ has upsampling functionality.
617
+
618
+ # Arguments:
619
+ y_in: Labels of the dataset.
620
+ batch_size: Batch size.
621
+ epoch_size: Number of samples in an epoch.
622
+ upsample: Whether upsampling should be done. This flag should only be
623
+ set on binary class problems.
624
+ seed: Random number generator seed.
625
+
626
+ # __iter__ output:
627
+ iterator of lists (batches) of indices in the dataset
628
+ """
629
+
630
+ def __init__(self, y_in, batch_size, epoch_size, upsample, seed):
631
+ self.batch_size = batch_size
632
+ self.epoch_size = epoch_size
633
+ self.upsample = upsample
634
+
635
+ np.random.seed(seed)
636
+
637
+ if upsample:
638
+ # Should only be used on binary class problems
639
+ assert len(y_in.shape) == 1
640
+ neg = np.where(y_in.numpy() == 0)[0]
641
+ pos = np.where(y_in.numpy() == 1)[0]
642
+ assert epoch_size % 2 == 0
643
+ samples_pr_class = int(epoch_size / 2)
644
+ else:
645
+ ind = range(len(y_in))
646
+
647
+ if not upsample:
648
+ # Randomly sample observations with replacement
649
+ self.sample_ind = np.random.choice(ind, epoch_size, replace=True)
650
+ else:
651
+ # Randomly sample observations in a balanced way
652
+ sample_neg = np.random.choice(neg, samples_pr_class, replace=True)
653
+ sample_pos = np.random.choice(pos, samples_pr_class, replace=True)
654
+ concat_ind = np.concatenate((sample_neg, sample_pos), axis=0)
655
+
656
+ # Shuffle to avoid labels being in specific order
657
+ # (all negative then positive)
658
+ p = np.random.permutation(len(concat_ind))
659
+ self.sample_ind = concat_ind[p]
660
+
661
+ label_dist = np.mean(y_in.numpy()[self.sample_ind])
662
+ assert(label_dist > 0.45)
663
+ assert(label_dist < 0.55)
664
+
665
+ def __iter__(self):
666
+ # Hand-off data using batch_size
667
+ for i in range(int(self.epoch_size/self.batch_size)):
668
+ start = i * self.batch_size
669
+ end = min(start + self.batch_size, self.epoch_size)
670
+ yield self.sample_ind[start:end]
671
+
672
+ def __len__(self):
673
+ # Take care of the last (maybe incomplete) batch
674
+ return (self.epoch_size + self.batch_size - 1) // self.batch_size
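
As a quick orientation, here is a minimal sketch of how the data-loading helpers above could be wired together for a small, imbalanced binary dataset. The torchmoji.finetuning import path is an assumption (adjust it to wherever these helpers actually live), and the toy data is made up.

import numpy as np
import torch
from torchmoji.finetuning import get_data_loader  # assumed module path

# Toy tokenized inputs (token ids) and heavily imbalanced binary labels
X = torch.from_numpy(np.random.randint(0, 100, size=(500, 30))).long()
y = torch.cat([torch.zeros(450), torch.ones(50)]).long()

# epoch_size must be even when upsample=True (half the samples come from each class)
loader = get_data_loader(X, y, batch_size=32,
                         extended_batch_sampler=True,
                         epoch_size=1000, upsample=True)

for X_batch, y_batch in loader:
    # Each batch should be roughly class-balanced thanks to the upsampling sampler
    print(X_batch.shape, y_batch.float().mean().item())
    break
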
torchmoji/global_variables.py ADDED
@@ -0,0 +1,28 @@
1
+ # -*- coding: utf-8 -*-
2
+ """ Global variables.
3
+ """
4
+ import tempfile
5
+ from os.path import abspath, dirname
6
+
7
+ # The ordering of these special tokens matters
8
+ # blank tokens can be used for new purposes
9
+ # Tokenizer should be updated if special token prefix is changed
10
+ SPECIAL_PREFIX = 'CUSTOM_'
11
+ SPECIAL_TOKENS = ['CUSTOM_MASK',
12
+ 'CUSTOM_UNKNOWN',
13
+ 'CUSTOM_AT',
14
+ 'CUSTOM_URL',
15
+ 'CUSTOM_NUMBER',
16
+ 'CUSTOM_BREAK']
17
+ SPECIAL_TOKENS.extend(['{}BLANK_{}'.format(SPECIAL_PREFIX, i) for i in range(6, 10)])
18
+
19
+ ROOT_PATH = dirname(dirname(abspath(__file__)))
20
+ VOCAB_PATH = '{}/model/vocabulary.json'.format(ROOT_PATH)
21
+ PRETRAINED_PATH = '{}/model/pytorch_model.bin'.format(ROOT_PATH)
22
+
23
+ WEIGHTS_DIR = tempfile.mkdtemp()
24
+
25
+ NB_TOKENS = 50000
26
+ NB_EMOJI_CLASSES = 64
27
+ FINETUNING_METHODS = ['last', 'full', 'new', 'chain-thaw']
28
+ FINETUNING_METRICS = ['acc', 'weighted']
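
A minimal sketch of what these constants give you at import time (assuming the package is importable as torchmoji; the paths are only printed, not opened):

from torchmoji.global_variables import SPECIAL_TOKENS, VOCAB_PATH, PRETRAINED_PATH, NB_TOKENS

print(len(SPECIAL_TOKENS))   # 10: six named tokens plus CUSTOM_BLANK_6 .. CUSTOM_BLANK_9
print(SPECIAL_TOKENS[0])     # 'CUSTOM_MASK'    -> id 0, used for padding/masking
print(SPECIAL_TOKENS[1])     # 'CUSTOM_UNKNOWN' -> id 1, used for out-of-vocabulary words
print(NB_TOKENS)             # 50000 entries expected in vocabulary.json
print(VOCAB_PATH, PRETRAINED_PATH)   # resolved relative to the repository root
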
torchmoji/lstm.py ADDED
@@ -0,0 +1,356 @@
1
+ # -*- coding: utf-8 -*-
2
+ """ Implement a pyTorch LSTM with hard sigmoid reccurent activation functions.
3
+ Adapted from the non-cuda variant of pyTorch LSTM at
4
+ https://github.com/pytorch/pytorch/blob/master/torch/nn/_functions/rnn.py
5
+ """
6
+
7
+ from __future__ import print_function, division
8
+ import math
9
+ import torch
10
+
11
+ from torch.nn import Module
12
+ from torch.nn.parameter import Parameter
13
+ from torch.nn.utils.rnn import PackedSequence
14
+ import torch.nn.functional as F
15
+
16
+ class LSTMHardSigmoid(Module):
17
+
18
+ def __init__(self, input_size, hidden_size,
19
+ num_layers=1, bias=True, batch_first=False,
20
+ dropout=0, bidirectional=False):
21
+ super(LSTMHardSigmoid, self).__init__()
22
+ self.input_size = input_size
23
+ self.hidden_size = hidden_size
24
+ self.num_layers = num_layers
25
+ self.bias = bias
26
+ self.batch_first = batch_first
27
+ self.dropout = dropout
28
+ self.dropout_state = {}
29
+ self.bidirectional = bidirectional
30
+ num_directions = 2 if bidirectional else 1
31
+
32
+ gate_size = 4 * hidden_size
33
+
34
+ self._all_weights = []
35
+ for layer in range(num_layers):
36
+ for direction in range(num_directions):
37
+ layer_input_size = input_size if layer == 0 else hidden_size * num_directions
38
+
39
+ w_ih = Parameter(torch.Tensor(gate_size, layer_input_size))
40
+ w_hh = Parameter(torch.Tensor(gate_size, hidden_size))
41
+ b_ih = Parameter(torch.Tensor(gate_size))
42
+ b_hh = Parameter(torch.Tensor(gate_size))
43
+ layer_params = (w_ih, w_hh, b_ih, b_hh)
44
+
45
+ suffix = '_reverse' if direction == 1 else ''
46
+ param_names = ['weight_ih_l{}{}', 'weight_hh_l{}{}']
47
+ if bias:
48
+ param_names += ['bias_ih_l{}{}', 'bias_hh_l{}{}']
49
+ param_names = [x.format(layer, suffix) for x in param_names]
50
+
51
+ for name, param in zip(param_names, layer_params):
52
+ setattr(self, name, param)
53
+ self._all_weights.append(param_names)
54
+
55
+ self.flatten_parameters()
56
+ self.reset_parameters()
57
+
58
+ def flatten_parameters(self):
59
+ """Resets parameter data pointer so that they can use faster code paths.
60
+
61
+ Right now, this is a no-op since we don't use CUDA acceleration.
62
+ """
63
+ self._data_ptrs = []
64
+
65
+ def _apply(self, fn):
66
+ ret = super(LSTMHardSigmoid, self)._apply(fn)
67
+ self.flatten_parameters()
68
+ return ret
69
+
70
+ def reset_parameters(self):
71
+ stdv = 1.0 / math.sqrt(self.hidden_size)
72
+ for weight in self.parameters():
73
+ weight.data.uniform_(-stdv, stdv)
74
+
75
+ def forward(self, input, hx=None):
76
+ is_packed = isinstance(input, PackedSequence)
77
+ if is_packed:
78
+ input, batch_sizes, _, _ = input
79
+ max_batch_size = batch_sizes[0]
80
+ else:
81
+ batch_sizes = None
82
+ max_batch_size = input.size(0) if self.batch_first else input.size(1)
83
+
84
+ if hx is None:
85
+ num_directions = 2 if self.bidirectional else 1
86
+ hx = torch.autograd.Variable(input.data.new(self.num_layers *
87
+ num_directions,
88
+ max_batch_size,
89
+ self.hidden_size).zero_(), requires_grad=False)
90
+ hx = (hx, hx)
91
+
92
+ has_flat_weights = list(p.data.data_ptr() for p in self.parameters()) == self._data_ptrs
93
+ if has_flat_weights:
94
+ first_data = next(self.parameters()).data
95
+ assert first_data.storage().size() == self._param_buf_size
96
+ flat_weight = first_data.new().set_(first_data.storage(), 0, torch.Size([self._param_buf_size]))
97
+ else:
98
+ flat_weight = None
99
+ func = AutogradRNN(
100
+ self.input_size,
101
+ self.hidden_size,
102
+ num_layers=self.num_layers,
103
+ batch_first=self.batch_first,
104
+ dropout=self.dropout,
105
+ train=self.training,
106
+ bidirectional=self.bidirectional,
107
+ batch_sizes=batch_sizes,
108
+ dropout_state=self.dropout_state,
109
+ flat_weight=flat_weight
110
+ )
111
+ output, hidden = func(input, self.all_weights, hx)
112
+ if is_packed:
113
+ output = PackedSequence(output, batch_sizes)
114
+ return output, hidden
115
+
116
+ def __repr__(self):
117
+ s = '{name}({input_size}, {hidden_size}'
118
+ if self.num_layers != 1:
119
+ s += ', num_layers={num_layers}'
120
+ if self.bias is not True:
121
+ s += ', bias={bias}'
122
+ if self.batch_first is not False:
123
+ s += ', batch_first={batch_first}'
124
+ if self.dropout != 0:
125
+ s += ', dropout={dropout}'
126
+ if self.bidirectional is not False:
127
+ s += ', bidirectional={bidirectional}'
128
+ s += ')'
129
+ return s.format(name=self.__class__.__name__, **self.__dict__)
130
+
131
+ def __setstate__(self, d):
132
+ super(LSTMHardSigmoid, self).__setstate__(d)
133
+ self.__dict__.setdefault('_data_ptrs', [])
134
+ if 'all_weights' in d:
135
+ self._all_weights = d['all_weights']
136
+ if isinstance(self._all_weights[0][0], str):
137
+ return
138
+ num_layers = self.num_layers
139
+ num_directions = 2 if self.bidirectional else 1
140
+ self._all_weights = []
141
+ for layer in range(num_layers):
142
+ for direction in range(num_directions):
143
+ suffix = '_reverse' if direction == 1 else ''
144
+ weights = ['weight_ih_l{}{}', 'weight_hh_l{}{}', 'bias_ih_l{}{}', 'bias_hh_l{}{}']
145
+ weights = [x.format(layer, suffix) for x in weights]
146
+ if self.bias:
147
+ self._all_weights += [weights]
148
+ else:
149
+ self._all_weights += [weights[:2]]
150
+
151
+ @property
152
+ def all_weights(self):
153
+ return [[getattr(self, weight) for weight in weights] for weights in self._all_weights]
154
+
155
+ def AutogradRNN(input_size, hidden_size, num_layers=1, batch_first=False,
156
+ dropout=0, train=True, bidirectional=False, batch_sizes=None,
157
+ dropout_state=None, flat_weight=None):
158
+
159
+ cell = LSTMCell
160
+
161
+ if batch_sizes is None:
162
+ rec_factory = Recurrent
163
+ else:
164
+ rec_factory = variable_recurrent_factory(batch_sizes)
165
+
166
+ if bidirectional:
167
+ layer = (rec_factory(cell), rec_factory(cell, reverse=True))
168
+ else:
169
+ layer = (rec_factory(cell),)
170
+
171
+ func = StackedRNN(layer,
172
+ num_layers,
173
+ True,
174
+ dropout=dropout,
175
+ train=train)
176
+
177
+ def forward(input, weight, hidden):
178
+ if batch_first and batch_sizes is None:
179
+ input = input.transpose(0, 1)
180
+
181
+ nexth, output = func(input, hidden, weight)
182
+
183
+ if batch_first and batch_sizes is None:
184
+ output = output.transpose(0, 1)
185
+
186
+ return output, nexth
187
+
188
+ return forward
189
+
190
+ def Recurrent(inner, reverse=False):
191
+ def forward(input, hidden, weight):
192
+ output = []
193
+ steps = range(input.size(0) - 1, -1, -1) if reverse else range(input.size(0))
194
+ for i in steps:
195
+ hidden = inner(input[i], hidden, *weight)
196
+ # hack to handle LSTM
197
+ output.append(hidden[0] if isinstance(hidden, tuple) else hidden)
198
+
199
+ if reverse:
200
+ output.reverse()
201
+ output = torch.cat(output, 0).view(input.size(0), *output[0].size())
202
+
203
+ return hidden, output
204
+
205
+ return forward
206
+
207
+
208
+ def variable_recurrent_factory(batch_sizes):
209
+ def fac(inner, reverse=False):
210
+ if reverse:
211
+ return VariableRecurrentReverse(batch_sizes, inner)
212
+ else:
213
+ return VariableRecurrent(batch_sizes, inner)
214
+ return fac
215
+
216
+ def VariableRecurrent(batch_sizes, inner):
217
+ def forward(input, hidden, weight):
218
+ output = []
219
+ input_offset = 0
220
+ last_batch_size = batch_sizes[0]
221
+ hiddens = []
222
+ flat_hidden = not isinstance(hidden, tuple)
223
+ if flat_hidden:
224
+ hidden = (hidden,)
225
+ for batch_size in batch_sizes:
226
+ step_input = input[input_offset:input_offset + batch_size]
227
+ input_offset += batch_size
228
+
229
+ dec = last_batch_size - batch_size
230
+ if dec > 0:
231
+ hiddens.append(tuple(h[-dec:] for h in hidden))
232
+ hidden = tuple(h[:-dec] for h in hidden)
233
+ last_batch_size = batch_size
234
+
235
+ if flat_hidden:
236
+ hidden = (inner(step_input, hidden[0], *weight),)
237
+ else:
238
+ hidden = inner(step_input, hidden, *weight)
239
+
240
+ output.append(hidden[0])
241
+ hiddens.append(hidden)
242
+ hiddens.reverse()
243
+
244
+ hidden = tuple(torch.cat(h, 0) for h in zip(*hiddens))
245
+ assert hidden[0].size(0) == batch_sizes[0]
246
+ if flat_hidden:
247
+ hidden = hidden[0]
248
+ output = torch.cat(output, 0)
249
+
250
+ return hidden, output
251
+
252
+ return forward
253
+
254
+
255
+ def VariableRecurrentReverse(batch_sizes, inner):
256
+ def forward(input, hidden, weight):
257
+ output = []
258
+ input_offset = input.size(0)
259
+ last_batch_size = batch_sizes[-1]
260
+ initial_hidden = hidden
261
+ flat_hidden = not isinstance(hidden, tuple)
262
+ if flat_hidden:
263
+ hidden = (hidden,)
264
+ initial_hidden = (initial_hidden,)
265
+ hidden = tuple(h[:batch_sizes[-1]] for h in hidden)
266
+ for batch_size in reversed(batch_sizes):
267
+ inc = batch_size - last_batch_size
268
+ if inc > 0:
269
+ hidden = tuple(torch.cat((h, ih[last_batch_size:batch_size]), 0)
270
+ for h, ih in zip(hidden, initial_hidden))
271
+ last_batch_size = batch_size
272
+ step_input = input[input_offset - batch_size:input_offset]
273
+ input_offset -= batch_size
274
+
275
+ if flat_hidden:
276
+ hidden = (inner(step_input, hidden[0], *weight),)
277
+ else:
278
+ hidden = inner(step_input, hidden, *weight)
279
+ output.append(hidden[0])
280
+
281
+ output.reverse()
282
+ output = torch.cat(output, 0)
283
+ if flat_hidden:
284
+ hidden = hidden[0]
285
+ return hidden, output
286
+
287
+ return forward
288
+
289
+ def StackedRNN(inners, num_layers, lstm=False, dropout=0, train=True):
290
+
291
+ num_directions = len(inners)
292
+ total_layers = num_layers * num_directions
293
+
294
+ def forward(input, hidden, weight):
295
+ assert(len(weight) == total_layers)
296
+ next_hidden = []
297
+
298
+ if lstm:
299
+ hidden = list(zip(*hidden))
300
+
301
+ for i in range(num_layers):
302
+ all_output = []
303
+ for j, inner in enumerate(inners):
304
+ l = i * num_directions + j
305
+
306
+ hy, output = inner(input, hidden[l], weight[l])
307
+ next_hidden.append(hy)
308
+ all_output.append(output)
309
+
310
+ input = torch.cat(all_output, input.dim() - 1)
311
+
312
+ if dropout != 0 and i < num_layers - 1:
313
+ input = F.dropout(input, p=dropout, training=train, inplace=False)
314
+
315
+ if lstm:
316
+ next_h, next_c = zip(*next_hidden)
317
+ next_hidden = (
318
+ torch.cat(next_h, 0).view(total_layers, *next_h[0].size()),
319
+ torch.cat(next_c, 0).view(total_layers, *next_c[0].size())
320
+ )
321
+ else:
322
+ next_hidden = torch.cat(next_hidden, 0).view(
323
+ total_layers, *next_hidden[0].size())
324
+
325
+ return next_hidden, input
326
+
327
+ return forward
328
+
329
+ def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
330
+ """
331
+ A modified LSTM cell with hard sigmoid activation on the input, forget and output gates.
332
+ """
333
+ hx, cx = hidden
334
+ gates = F.linear(input, w_ih, b_ih) + F.linear(hx, w_hh, b_hh)
335
+
336
+ ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
337
+
338
+ ingate = hard_sigmoid(ingate)
339
+ forgetgate = hard_sigmoid(forgetgate)
340
+ cellgate = F.tanh(cellgate)
341
+ outgate = hard_sigmoid(outgate)
342
+
343
+ cy = (forgetgate * cx) + (ingate * cellgate)
344
+ hy = outgate * F.tanh(cy)
345
+
346
+ return hy, cy
347
+
348
+ def hard_sigmoid(x):
349
+ """
350
+ Computes element-wise hard sigmoid of x.
351
+ See e.g. https://github.com/Theano/Theano/blob/master/theano/tensor/nnet/sigm.py#L279
352
+ """
353
+ x = (0.2 * x) + 0.5
354
+ x = F.threshold(-x, -1, -1)
355
+ x = F.threshold(-x, 0, 0)
356
+ return x
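
The hard_sigmoid above is a piecewise-linear approximation of the logistic function; as a small sanity check, it is equivalent to clamp(0.2*x + 0.5, 0, 1) (import path matches the file shown above):

import torch
from torchmoji.lstm import hard_sigmoid

x = torch.linspace(-4.0, 4.0, steps=9)
print(hard_sigmoid(x))                         # 0 below x=-2.5, 1 above x=2.5, linear in between
print(torch.clamp(0.2 * x + 0.5, 0.0, 1.0))    # identical values
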
torchmoji/model_def.py ADDED
@@ -0,0 +1,315 @@
1
+ # -*- coding: utf-8 -*-
2
+ """ Model definition functions and weight loading.
3
+ """
4
+
5
+ from __future__ import print_function, division, unicode_literals
6
+
7
+ from os.path import exists
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.autograd import Variable
12
+ from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, PackedSequence
13
+
14
+ from torchmoji.lstm import LSTMHardSigmoid
15
+ from torchmoji.attlayer import Attention
16
+ from torchmoji.global_variables import NB_TOKENS, NB_EMOJI_CLASSES
17
+
18
+
19
+ def torchmoji_feature_encoding(weight_path, return_attention=False):
20
+ """ Loads the pretrained torchMoji model for extracting features
21
+ from the penultimate feature layer. In this way, it transforms
22
+ the text into its emotional encoding.
23
+
24
+ # Arguments:
25
+ weight_path: Path to model weights to be loaded.
26
+ return_attention: If true, output will include weight of each input token
27
+ used for the prediction
28
+
29
+ # Returns:
30
+ Pretrained model for encoding text into feature vectors.
31
+ """
32
+
33
+ model = TorchMoji(nb_classes=None,
34
+ nb_tokens=NB_TOKENS,
35
+ feature_output=True,
36
+ return_attention=return_attention)
37
+ load_specific_weights(model, weight_path, exclude_names=['output_layer'])
38
+ return model
39
+
40
+
41
+ def torchmoji_emojis(weight_path, return_attention=False):
42
+ """ Loads the pretrained torchMoji model for extracting features
43
+ from the penultimate feature layer. In this way, it transforms
44
+ the text into its emotional encoding.
45
+
46
+ # Arguments:
47
+ weight_path: Path to model weights to be loaded.
48
+ return_attention: If true, output will include weight of each input token
49
+ used for the prediction
50
+
51
+ # Returns:
52
+ Pretrained model for predicting emoji class probabilities from text.
53
+ """
54
+
55
+ model = TorchMoji(nb_classes=NB_EMOJI_CLASSES,
56
+ nb_tokens=NB_TOKENS,
57
+ return_attention=return_attention)
58
+ model.load_state_dict(torch.load(weight_path))
59
+ return model
60
+
61
+
62
+ def torchmoji_transfer(nb_classes, weight_path=None, extend_embedding=0,
63
+ embed_dropout_rate=0.1, final_dropout_rate=0.5):
64
+ """ Loads the pretrained torchMoji model for finetuning/transfer learning.
65
+ Does not load weights for the softmax layer.
66
+
67
+ Note that if you are planning to use class average F1 for evaluation,
68
+ nb_classes should be set to 2 instead of the actual number of classes
69
+ in the dataset, since binary classification will be performed on each
70
+ class individually.
71
+
72
+ Note that for the 'new' method, weight_path should be left as None.
73
+
74
+ # Arguments:
75
+ nb_classes: Number of classes in the dataset.
76
+ weight_path: Path to model weights to be loaded.
77
+ extend_embedding: Number of tokens that have been added to the
78
+ vocabulary on top of NB_TOKENS. If this number is larger than 0,
79
+ the embedding layer's dimensions are adjusted accordingly, with the
80
+ additional weights being set to random values.
81
+ embed_dropout_rate: Dropout rate for the embedding layer.
82
+ final_dropout_rate: Dropout rate for the final Softmax layer.
83
+
84
+ # Returns:
85
+ Model with the given parameters.
86
+ """
87
+
88
+ model = TorchMoji(nb_classes=nb_classes,
89
+ nb_tokens=NB_TOKENS + extend_embedding,
90
+ embed_dropout_rate=embed_dropout_rate,
91
+ final_dropout_rate=final_dropout_rate,
92
+ output_logits=True)
93
+ if weight_path is not None:
94
+ load_specific_weights(model, weight_path,
95
+ exclude_names=['output_layer'],
96
+ extend_embedding=extend_embedding)
97
+ return model
98
+
99
+
100
+ class TorchMoji(nn.Module):
101
+ def __init__(self, nb_classes, nb_tokens, feature_output=False, output_logits=False,
102
+ embed_dropout_rate=0, final_dropout_rate=0, return_attention=False):
103
+ """
104
+ torchMoji model.
105
+ IMPORTANT: The model is loaded in evaluation mode by default (self.eval())
106
+
107
+ # Arguments:
108
+ nb_classes: Number of classes in the dataset.
109
+ nb_tokens: Number of tokens in the dataset (i.e. vocabulary size).
110
+ feature_output: If True the model returns the penultimate
111
+ feature vector rather than Softmax probabilities
112
+ (defaults to False).
113
+ output_logits: If True the model returns logits rather than probabilities
114
+ (defaults to False).
115
+ embed_dropout_rate: Dropout rate for the embedding layer.
116
+ final_dropout_rate: Dropout rate for the final Softmax layer.
117
+ return_attention: If True the model also returns attention weights over the sentence
118
+ (defaults to False).
119
+ """
120
+ super(TorchMoji, self).__init__()
121
+
122
+ embedding_dim = 256
123
+ hidden_size = 512
124
+ attention_size = 4 * hidden_size + embedding_dim
125
+
126
+ self.feature_output = feature_output
127
+ self.embed_dropout_rate = embed_dropout_rate
128
+ self.final_dropout_rate = final_dropout_rate
129
+ self.return_attention = return_attention
130
+ self.hidden_size = hidden_size
131
+ self.output_logits = output_logits
132
+ self.nb_classes = nb_classes
133
+
134
+ self.add_module('embed', nn.Embedding(nb_tokens, embedding_dim))
135
+ # dropout2D: embedding channels are dropped out instead of words
136
+ # many examples in the datasets contain so few words that losing one or more of them can alter the emotions completely
137
+ self.add_module('embed_dropout', nn.Dropout2d(embed_dropout_rate))
138
+ self.add_module('lstm_0', LSTMHardSigmoid(embedding_dim, hidden_size, batch_first=True, bidirectional=True))
139
+ self.add_module('lstm_1', LSTMHardSigmoid(hidden_size*2, hidden_size, batch_first=True, bidirectional=True))
140
+ self.add_module('attention_layer', Attention(attention_size=attention_size, return_attention=return_attention))
141
+ if not feature_output:
142
+ self.add_module('final_dropout', nn.Dropout(final_dropout_rate))
143
+ if output_logits:
144
+ self.add_module('output_layer', nn.Sequential(nn.Linear(attention_size, nb_classes if self.nb_classes > 2 else 1)))
145
+ else:
146
+ self.add_module('output_layer', nn.Sequential(nn.Linear(attention_size, nb_classes if self.nb_classes > 2 else 1),
147
+ nn.Softmax() if self.nb_classes > 2 else nn.Sigmoid()))
148
+ self.init_weights()
149
+ # Put model in evaluation mode by default
150
+ self.eval()
151
+
152
+ def init_weights(self):
153
+ """
154
+ Here we reproduce Keras default initialization weights for consistency with Keras version
155
+ """
156
+ ih = (param.data for name, param in self.named_parameters() if 'weight_ih' in name)
157
+ hh = (param.data for name, param in self.named_parameters() if 'weight_hh' in name)
158
+ b = (param.data for name, param in self.named_parameters() if 'bias' in name)
159
+ nn.init.uniform(self.embed.weight.data, a=-0.5, b=0.5)
160
+ for t in ih:
161
+ nn.init.xavier_uniform(t)
162
+ for t in hh:
163
+ nn.init.orthogonal(t)
164
+ for t in b:
165
+ nn.init.constant(t, 0)
166
+ if not self.feature_output:
167
+ nn.init.xavier_uniform(self.output_layer[0].weight.data)
168
+
169
+ def forward(self, input_seqs):
170
+ """ Forward pass.
171
+
172
+ # Arguments:
173
+ input_seqs: Can be one of Numpy array, Torch.LongTensor, Torch.Variable, Torch.PackedSequence.
174
+
175
+ # Return:
176
+ Same format as input format (except for PackedSequence returned as Variable).
177
+ """
178
+ # Check whether the input is a torch.LongTensor or a torch.Variable (otherwise assume a Numpy array) and note the format so the output can match it
179
+ return_numpy = False
180
+ return_tensor = False
181
+ if isinstance(input_seqs, (torch.LongTensor, torch.cuda.LongTensor)):
182
+ input_seqs = Variable(input_seqs)
183
+ return_tensor = True
184
+ elif not isinstance(input_seqs, Variable):
185
+ input_seqs = Variable(torch.from_numpy(input_seqs.astype('int64')).long())
186
+ return_numpy = True
187
+
188
+ # If we don't have a packed inputs, let's pack it
189
+ reorder_output = False
190
+ if not isinstance(input_seqs, PackedSequence):
191
+ ho = self.lstm_0.weight_hh_l0.data.new(2, input_seqs.size()[0], self.hidden_size).zero_()
192
+ co = self.lstm_0.weight_hh_l0.data.new(2, input_seqs.size()[0], self.hidden_size).zero_()
193
+
194
+ # Reorder batch by sequence length
195
+ input_lengths = torch.LongTensor([torch.max(input_seqs[i, :].data.nonzero()) + 1 for i in range(input_seqs.size()[0])])
196
+ input_lengths, perm_idx = input_lengths.sort(0, descending=True)
197
+ input_seqs = input_seqs[perm_idx][:, :input_lengths.max()]
198
+
199
+ # Pack sequence and work on data tensor to reduce embeddings/dropout computations
200
+ packed_input = pack_padded_sequence(input_seqs, input_lengths.cpu().numpy(), batch_first=True)
201
+ reorder_output = True
202
+ else:
203
+ ho = self.lstm_0.weight_hh_l0.data.data.new(2, input_seqs.size()[0], self.hidden_size).zero_()
204
+ co = self.lstm_0.weight_hh_l0.data.data.new(2, input_seqs.size()[0], self.hidden_size).zero_()
205
+ input_lengths = input_seqs.batch_sizes
206
+ packed_input = input_seqs
207
+
208
+ hidden = (Variable(ho, requires_grad=False), Variable(co, requires_grad=False))
209
+
210
+ # Embed with an activation function to bound the values of the embeddings
211
+ x = self.embed(packed_input.data)
212
+ x = nn.Tanh()(x)
213
+
214
+ # pyTorch's Dropout2d operates on axis 1, which is fine for us
215
+ x = self.embed_dropout(x)
216
+
217
+ # Update packed sequence data for RNN
218
+ packed_input = PackedSequence(x, packed_input.batch_sizes)
219
+
220
+ # skip-connection from embedding to output eases gradient-flow and allows access to lower-level features
221
+ # ordering of the way the merge is done is important for consistency with the pretrained model
222
+ lstm_0_output, _ = self.lstm_0(packed_input, hidden)
223
+ lstm_1_output, _ = self.lstm_1(lstm_0_output, hidden)
224
+
225
+ # Update packed sequence data for attention layer
226
+ packed_input = PackedSequence(torch.cat((lstm_1_output.data,
227
+ lstm_0_output.data,
228
+ packed_input.data), dim=1),
229
+ packed_input.batch_sizes)
230
+
231
+ input_seqs, _ = pad_packed_sequence(packed_input, batch_first=True)
232
+
233
+ x, att_weights = self.attention_layer(input_seqs, input_lengths)
234
+
235
+ # output class probabilities or penultimate feature vector
236
+ if not self.feature_output:
237
+ x = self.final_dropout(x)
238
+ outputs = self.output_layer(x)
239
+ else:
240
+ outputs = x
241
+
242
+ # Reorder output if needed
243
+ if reorder_output:
244
+ reordered = Variable(outputs.data.new(outputs.size()))
245
+ reordered[perm_idx] = outputs
246
+ outputs = reordered
247
+
248
+ # Adapt return format if needed
249
+ if return_tensor:
250
+ outputs = outputs.data
251
+ if return_numpy:
252
+ outputs = outputs.data.numpy()
253
+
254
+ if self.return_attention:
255
+ return outputs, att_weights
256
+ else:
257
+ return outputs
258
+
259
+
260
+ def load_specific_weights(model, weight_path, exclude_names=[], extend_embedding=0, verbose=True):
261
+ """ Loads model weights from the given file path, excluding any
262
+ given layers.
263
+
264
+ # Arguments:
265
+ model: Model whose weights should be loaded.
266
+ weight_path: Path to file containing model weights.
267
+ exclude_names: List of layer names whose weights should not be loaded.
268
+ extend_embedding: Number of new words being added to vocabulary.
269
+ verbose: Verbosity flag.
270
+
271
+ # Raises:
272
+ ValueError if the file at weight_path does not exist.
273
+ """
274
+ if not exists(weight_path):
275
+ raise ValueError('ERROR (load_weights): The weights file at {} does '
276
+ 'not exist. Refer to the README for instructions.'
277
+ .format(weight_path))
278
+
279
+ if extend_embedding and 'embed' in exclude_names:
280
+ raise ValueError('ERROR (load_weights): Cannot extend a vocabulary '
281
+ 'without loading the embedding weights.')
282
+
283
+ # Copy only weights from the temporary model that are wanted
284
+ # for the specific task (e.g. the Softmax is often ignored)
285
+ weights = torch.load(weight_path)
286
+ for key, weight in weights.items():
287
+ if any(excluded in key for excluded in exclude_names):
288
+ if verbose:
289
+ print('Ignoring weights for {}'.format(key))
290
+ continue
291
+
292
+ try:
293
+ model_w = model.state_dict()[key]
294
+ except KeyError:
295
+ raise KeyError("Weights had parameters {},".format(key)
296
+ + " but could not find this parameters in model.")
297
+
298
+ if verbose:
299
+ print('Loading weights for {}'.format(key))
300
+
301
+ # extend embedding layer to allow new randomly initialized words
302
+ # if requested. Otherwise, just load the weights for the layer.
303
+ if 'embed' in key and extend_embedding > 0:
304
+ weight = torch.cat((weight, model_w[NB_TOKENS:, :]), dim=0)
305
+ if verbose:
306
+ print('Extended vocabulary for embedding layer ' +
307
+ 'from {} to {} tokens.'.format(
308
+ NB_TOKENS, NB_TOKENS + extend_embedding))
309
+ try:
310
+ model_w.copy_(weight)
311
+ except:
312
+ print('While copying the weights named {}, whose dimensions in the model are'
313
+ ' {} and whose dimensions in the saved file are {}, ...'.format(
314
+ key, model_w.size(), weight.size()))
315
+ raise
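
A minimal sketch of loading the feature encoder above and pushing a dummy tokenized batch through it. It assumes the weights added in this commit are available where global_variables.PRETRAINED_PATH expects them (model/pytorch_model.bin); the token ids are made up.

import numpy as np
from torchmoji.model_def import torchmoji_feature_encoding
from torchmoji.global_variables import PRETRAINED_PATH

model = torchmoji_feature_encoding(PRETRAINED_PATH)

# Two fake, already-tokenized sentences; 0 is the padding/mask id,
# so every row needs at least one non-zero token.
tokens = np.array([[2, 3, 4, 5, 0, 0],
                   [6, 7, 0, 0, 0, 0]], dtype='uint16')

features = model(tokens)      # numpy in -> numpy out (see forward() above)
print(features.shape)         # (2, 2304) = 4 * hidden_size + embedding_dim
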
torchmoji/sentence_tokenizer.py ADDED
@@ -0,0 +1,245 @@
1
+ # -*- coding: utf-8 -*-
2
+ '''
3
+ Provides functionality for converting a given list of tokens (words) into
4
+ numbers, according to the given vocabulary.
5
+ '''
6
+ from __future__ import print_function, division, unicode_literals
7
+
8
+ import numbers
9
+ import numpy as np
10
+
11
+ from torchmoji.create_vocab import extend_vocab, VocabBuilder
12
+ from torchmoji.word_generator import WordGenerator
13
+ from torchmoji.global_variables import SPECIAL_TOKENS
14
+
15
+ # import torch
16
+
17
+ from sklearn.model_selection import train_test_split
18
+
19
+ from copy import deepcopy
20
+
21
+ class SentenceTokenizer():
22
+ """ Create numpy array of tokens corresponding to input sentences.
23
+ The vocabulary can include Unicode tokens.
24
+ """
25
+ def __init__(self, vocabulary, fixed_length, custom_wordgen=None,
26
+ ignore_sentences_with_only_custom=False, masking_value=0,
27
+ unknown_value=1):
28
+ """ Needs a dictionary as input for the vocabulary.
29
+ """
30
+
31
+ if len(vocabulary) > np.iinfo('uint16').max:
32
+ raise ValueError('Dictionary is too big ({} tokens) for the numpy '
33
+ 'datatypes used (max limit={}). Reduce vocabulary'
34
+ ' or adjust code accordingly!'
35
+ .format(len(vocabulary), np.iinfo('uint16').max))
36
+
37
+ # Shouldn't be able to modify the given vocabulary
38
+ self.vocabulary = deepcopy(vocabulary)
39
+ self.fixed_length = fixed_length
40
+ self.ignore_sentences_with_only_custom = ignore_sentences_with_only_custom
41
+ self.masking_value = masking_value
42
+ self.unknown_value = unknown_value
43
+
44
+ # Initialized with an empty stream of sentences that must then be fed
45
+ # to the generator at a later point for reusability.
46
+ # A custom word generator can be used for domain-specific filtering etc
47
+ if custom_wordgen is not None:
48
+ assert custom_wordgen.stream is None
49
+ self.wordgen = custom_wordgen
50
+ self.uses_custom_wordgen = True
51
+ else:
52
+ self.wordgen = WordGenerator(None, allow_unicode_text=True,
53
+ ignore_emojis=False,
54
+ remove_variation_selectors=True,
55
+ break_replacement=True)
56
+ self.uses_custom_wordgen = False
57
+
58
+ def tokenize_sentences(self, sentences, reset_stats=True, max_sentences=None):
59
+ """ Converts a given list of sentences into a numpy array according to
60
+ its vocabulary.
61
+
62
+ # Arguments:
63
+ sentences: List of sentences to be tokenized.
64
+ reset_stats: Whether the word generator's stats should be reset.
65
+ max_sentences: Maximum length of sentences. Must be set if the
66
+ length cannot be inferred from the input.
67
+
68
+ # Returns:
69
+ Numpy array of the tokenized sentences with masking,
70
+ infos,
71
+ stats
72
+
73
+ # Raises:
74
+ ValueError: When maximum length is not set and cannot be inferred.
75
+ """
76
+
77
+ if max_sentences is None and not hasattr(sentences, '__len__'):
78
+ raise ValueError('Either you must provide an array with a length '
79
+ 'attribute (e.g. a list) or specify the maximum '
80
+ 'length yourself using `max_sentences`!')
81
+ n_sentences = (max_sentences if max_sentences is not None
82
+ else len(sentences))
83
+
84
+ if self.masking_value == 0:
85
+ tokens = np.zeros((n_sentences, self.fixed_length), dtype='uint16')
86
+ else:
87
+ tokens = (np.ones((n_sentences, self.fixed_length), dtype='uint16')
88
+ * self.masking_value)
89
+
90
+ if reset_stats:
91
+ self.wordgen.reset_stats()
92
+
93
+ # With a custom word generator info can be extracted from each
94
+ # sentence (e.g. labels)
95
+ infos = []
96
+
97
+ # Returns words as strings and then map them to vocabulary
98
+ self.wordgen.stream = sentences
99
+ next_insert = 0
100
+ n_ignored_unknowns = 0
101
+ for s_words, s_info in self.wordgen:
102
+ s_tokens = self.find_tokens(s_words)
103
+
104
+ if (self.ignore_sentences_with_only_custom and
105
+ np.all([True if t < len(SPECIAL_TOKENS)
106
+ else False for t in s_tokens])):
107
+ n_ignored_unknowns += 1
108
+ continue
109
+ if len(s_tokens) > self.fixed_length:
110
+ s_tokens = s_tokens[:self.fixed_length]
111
+ tokens[next_insert, :len(s_tokens)] = s_tokens
112
+ infos.append(s_info)
113
+ next_insert += 1
114
+
115
+ # For standard word generators all sentences should be tokenized
116
+ # this is not necessarily the case for custom wordgenerators as they
117
+ # may filter the sentences etc.
118
+ if not self.uses_custom_wordgen and not self.ignore_sentences_with_only_custom:
119
+ assert len(sentences) == next_insert
120
+ else:
121
+ # adjust based on actual tokens received
122
+ tokens = tokens[:next_insert]
123
+ infos = infos[:next_insert]
124
+
125
+ return tokens, infos, self.wordgen.stats
126
+
127
+ def find_tokens(self, words):
128
+ assert len(words) > 0
129
+ tokens = []
130
+ for w in words:
131
+ try:
132
+ tokens.append(self.vocabulary[w])
133
+ except KeyError:
134
+ tokens.append(self.unknown_value)
135
+ return tokens
136
+
137
+ def split_train_val_test(self, sentences, info_dicts,
138
+ split_parameter=[0.7, 0.1, 0.2], extend_with=0):
139
+ """ Splits given sentences into three different datasets: training,
140
+ validation and testing.
141
+
142
+ # Arguments:
143
+ sentences: The sentences to be tokenized.
144
+ info_dicts: A list of dicts that contain information about each
145
+ sentence (e.g. a label).
146
+ split_parameter: A parameter for deciding the splits between the
147
+ three different datasets. If instead of being passed three
148
+ values, three lists are passed, then these will be used to
149
+ specify which observation belong to which dataset.
150
+ extend_with: An optional parameter. If > 0 then this is the number
151
+ of tokens added to the vocabulary from this dataset. The
152
+ expanded vocab will be generated using only the training set,
153
+ but is applied to all three sets.
154
+
155
+ # Returns:
156
+ List of three lists of tokenized sentences,
157
+
158
+ List of three corresponding dictionaries with information,
159
+
160
+ How many tokens have been added to the vocab. Make sure to extend
161
+ the embedding layer of the model accordingly.
162
+ """
163
+
164
+ # If passed three lists, use those directly
165
+ if isinstance(split_parameter, list) and \
166
+ all(isinstance(x, list) for x in split_parameter) and \
167
+ len(split_parameter) == 3:
168
+
169
+ # Helper function to verify provided indices are numbers in range
170
+ def verify_indices(inds):
171
+ return list(filter(lambda i: isinstance(i, numbers.Number)
172
+ and i < len(sentences), inds))
173
+
174
+ ind_train = verify_indices(split_parameter[0])
175
+ ind_val = verify_indices(split_parameter[1])
176
+ ind_test = verify_indices(split_parameter[2])
177
+ else:
178
+ # Split sentences and dicts
179
+ ind = list(range(len(sentences)))
180
+ ind_train, ind_test = train_test_split(ind, test_size=split_parameter[2])
181
+ ind_train, ind_val = train_test_split(ind_train, test_size=split_parameter[1])
182
+
183
+ # Map indices to data
184
+ train = np.array([sentences[x] for x in ind_train])
185
+ test = np.array([sentences[x] for x in ind_test])
186
+ val = np.array([sentences[x] for x in ind_val])
187
+
188
+ info_train = np.array([info_dicts[x] for x in ind_train])
189
+ info_test = np.array([info_dicts[x] for x in ind_test])
190
+ info_val = np.array([info_dicts[x] for x in ind_val])
191
+
192
+ added = 0
193
+ # Extend vocabulary with training set tokens
194
+ if extend_with > 0:
195
+ wg = WordGenerator(train)
196
+ vb = VocabBuilder(wg)
197
+ vb.count_all_words()
198
+ added = extend_vocab(self.vocabulary, vb, max_tokens=extend_with)
199
+
200
+ # Wrap results
201
+ result = [self.tokenize_sentences(s)[0] for s in [train, val, test]]
202
+ result_infos = [info_train, info_val, info_test]
203
+ # if type(result_infos[0][0]) in [np.double, np.float, np.int64, np.int32, np.uint8]:
204
+ # result_infos = [torch.from_numpy(label).long() for label in result_infos]
205
+
206
+ return result, result_infos, added
207
+
208
+ def to_sentence(self, sentence_idx):
209
+ """ Converts a tokenized sentence back to a list of words.
210
+
211
+ # Arguments:
212
+ sentence_idx: List of numbers, representing a tokenized sentence
213
+ given the current vocabulary.
214
+
215
+ # Returns:
216
+ String created by converting all numbers back to words and joined
217
+ together with spaces.
218
+ """
219
+ # Have to recalculate the mappings in case the vocab was extended.
220
+ ind_to_word = {ind: word for word, ind in self.vocabulary.items()}
221
+
222
+ sentence_as_list = [ind_to_word[x] for x in sentence_idx]
223
+ cleaned_list = [x for x in sentence_as_list if x != 'CUSTOM_MASK']
224
+ return " ".join(cleaned_list)
225
+
226
+
227
+ def coverage(dataset, verbose=False):
228
+ """ Computes the percentage of words in a given dataset that are unknown.
229
+
230
+ # Arguments:
231
+ dataset: Tokenized dataset to be checked.
232
+ verbose: Verbosity flag.
233
+
234
+ # Returns:
235
+ Percentage of unknown tokens.
236
+ """
237
+ n_total = np.count_nonzero(dataset)
238
+ n_unknown = np.sum(dataset == 1)
239
+ coverage = 1.0 - float(n_unknown) / n_total
240
+
241
+ if verbose:
242
+ print("Unknown words: {}".format(n_unknown))
243
+ print("Total words: {}".format(n_total))
244
+ print("Coverage: {}".format(coverage))
245
+ return coverage
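
A minimal sketch of the SentenceTokenizer above with a toy vocabulary (a real run would load the vocabulary.json shipped in this commit instead); ids 0 and 1 are the default masking and unknown values:

from torchmoji.sentence_tokenizer import SentenceTokenizer
from torchmoji.global_variables import SPECIAL_TOKENS

vocab = {token: i for i, token in enumerate(SPECIAL_TOKENS)}   # ids 0-9
vocab.update({'i': 10, 'love': 11, 'emojis': 12})

st = SentenceTokenizer(vocab, fixed_length=5)
tokens, infos, stats = st.tokenize_sentences(['I love emojis', 'completely unknown words'])
print(tokens)
# [[10 11 12  0  0]      <- known words, padded with the mask id 0
#  [ 1  1  1  0  0]]     <- out-of-vocabulary words map to the unknown id 1
print(st.to_sentence(list(tokens[0])))   # 'i love emojis'
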
torchmoji/tokenizer.py ADDED
@@ -0,0 +1,156 @@
1
+ # -*- coding: utf-8 -*-
2
+ '''
3
+ Splits up a Unicode string into a list of tokens.
4
+ Recognises:
5
+ - Abbreviations
6
+ - URLs
7
+ - Emails
8
+ - #hashtags
9
+ - @mentions
10
+ - emojis
11
+ - emoticons (limited support)
12
+
13
+ Multiple consecutive symbols are also treated as a single token.
14
+ '''
15
+ from __future__ import absolute_import, division, print_function, unicode_literals
16
+
17
+ import re
18
+
19
+ # Basic patterns.
20
+ RE_NUM = r'[0-9]+'
21
+ RE_WORD = r'[a-zA-Z]+'
22
+ RE_WHITESPACE = r'\s+'
23
+ RE_ANY = r'.'
24
+
25
+ # Combined words such as 'red-haired' or 'CUSTOM_TOKEN'
26
+ RE_COMB = r'[a-zA-Z]+[-_][a-zA-Z]+'
27
+
28
+ # English-specific patterns
29
+ RE_CONTRACTIONS = RE_WORD + r'\'' + RE_WORD
30
+
31
+ TITLES = [
32
+ r'Mr\.',
33
+ r'Ms\.',
34
+ r'Mrs\.',
35
+ r'Dr\.',
36
+ r'Prof\.',
37
+ ]
38
+ # Ensure case insensitivity
39
+ RE_TITLES = r'|'.join([r'(?i)' + t for t in TITLES])
40
+
41
+ # Symbols have to be created as separate patterns in order to match consecutive
42
+ # identical symbols.
43
+ SYMBOLS = r'()<!?.,/\'\"-_=\\§|´ˇ°[]<>{}~$^&*;:%+\xa3€`'
44
+ RE_SYMBOL = r'|'.join([re.escape(s) + r'+' for s in SYMBOLS])
45
+
46
+ # Hash symbols and at symbols have to be defined separately in order to not
47
+ # clash with hashtags and mentions if there are multiple - i.e.
48
+ # ##hello -> ['#', '#hello'] instead of ['##', 'hello']
49
+ SPECIAL_SYMBOLS = r'|#+(?=#[a-zA-Z0-9_]+)|@+(?=@[a-zA-Z0-9_]+)|#+|@+'
50
+ RE_SYMBOL += SPECIAL_SYMBOLS
51
+
52
+ RE_ABBREVIATIONS = r'\b(?<!\.)(?:[A-Za-z]\.){2,}'
53
+
54
+ # Twitter-specific patterns
55
+ RE_HASHTAG = r'#[a-zA-Z0-9_]+'
56
+ RE_MENTION = r'@[a-zA-Z0-9_]+'
57
+
58
+ RE_URL = r'(?:https?://|www\.)(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+'
59
+ RE_EMAIL = r'\b[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+\b'
60
+
61
+ # Emoticons and emojis
62
+ RE_HEART = r'(?:<+/?3+)+'
63
+ EMOTICONS_START = [
64
+ r'>:',
65
+ r':',
66
+ r'=',
67
+ r';',
68
+ ]
69
+ EMOTICONS_MID = [
70
+ r'-',
71
+ r',',
72
+ r'^',
73
+ '\'',
74
+ '\"',
75
+ ]
76
+ EMOTICONS_END = [
77
+ r'D',
78
+ r'd',
79
+ r'p',
80
+ r'P',
81
+ r'v',
82
+ r')',
83
+ r'o',
84
+ r'O',
85
+ r'(',
86
+ r'3',
87
+ r'/',
88
+ r'|',
89
+ '\\',
90
+ ]
91
+ EMOTICONS_EXTRA = [
92
+ r'-_-',
93
+ r'x_x',
94
+ r'^_^',
95
+ r'o.o',
96
+ r'o_o',
97
+ r'(:',
98
+ r'):',
99
+ r');',
100
+ r'(;',
101
+ ]
102
+
103
+ RE_EMOTICON = r'|'.join([re.escape(s) for s in EMOTICONS_EXTRA])
104
+ for s in EMOTICONS_START:
105
+ for m in EMOTICONS_MID:
106
+ for e in EMOTICONS_END:
107
+ RE_EMOTICON += '|{0}{1}?{2}+'.format(re.escape(s), re.escape(m), re.escape(e))
108
+
109
+ # requires ucs4 in python2.7 or python3+
110
+ # RE_EMOJI = r"""[\U0001F300-\U0001F64F\U0001F680-\U0001F6FF\u2600-\u26FF\u2700-\u27BF]"""
111
+ # safe for all python
112
+ RE_EMOJI = r"""\ud83c[\udf00-\udfff]|\ud83d[\udc00-\ude4f\ude80-\udeff]|[\u2600-\u26FF\u2700-\u27BF]"""
113
+
114
+ # List of matched token patterns, ordered from most specific to least specific.
115
+ TOKENS = [
116
+ RE_URL,
117
+ RE_EMAIL,
118
+ RE_COMB,
119
+ RE_HASHTAG,
120
+ RE_MENTION,
121
+ RE_HEART,
122
+ RE_EMOTICON,
123
+ RE_CONTRACTIONS,
124
+ RE_TITLES,
125
+ RE_ABBREVIATIONS,
126
+ RE_NUM,
127
+ RE_WORD,
128
+ RE_SYMBOL,
129
+ RE_EMOJI,
130
+ RE_ANY
131
+ ]
132
+
133
+ # List of ignored token patterns
134
+ IGNORED = [
135
+ RE_WHITESPACE
136
+ ]
137
+
138
+ # Final pattern
139
+ RE_PATTERN = re.compile(r'|'.join(IGNORED) + r'|(' + r'|'.join(TOKENS) + r')',
140
+ re.UNICODE)
141
+
142
+
143
+ def tokenize(text):
144
+ '''Splits given input string into a list of tokens.
145
+
146
+ # Arguments:
147
+ text: Input string to be tokenized.
148
+
149
+ # Returns:
150
+ List of strings (tokens).
151
+ '''
152
+ result = RE_PATTERN.findall(text)
153
+
154
+ # Remove empty strings
155
+ result = [t for t in result if t.strip()]
156
+ return result
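
For illustration, a quick run of the regex tokenizer above on a tweet-like string (expected tokens shown as a comment):

from torchmoji.tokenizer import tokenize

print(tokenize('@friend check https://example.com #torchMoji!! :) <3'))
# ['@friend', 'check', 'https://example.com', '#torchMoji', '!!', ':)', '<3']
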
torchmoji/word_generator.py ADDED
@@ -0,0 +1,312 @@
1
+ # -*- coding: utf-8 -*-
2
+ ''' Extracts lists of words from a given input to be used for later vocabulary
3
+ generation or for creating tokenized datasets.
4
+ Supports functionality for handling different file types and
5
+ filtering/processing of this input.
6
+ '''
7
+
8
+ from __future__ import division, print_function, unicode_literals
9
+
10
+ import re
11
+ import unicodedata
12
+ import numpy as np
13
+ from text_unidecode import unidecode
14
+
15
+ from torchmoji.tokenizer import RE_MENTION, tokenize
16
+ from torchmoji.filter_utils import (convert_linebreaks,
17
+ convert_nonbreaking_space,
18
+ correct_length,
19
+ extract_emojis,
20
+ mostly_english,
21
+ non_english_user,
22
+ process_word,
23
+ punct_word,
24
+ remove_control_chars,
25
+ remove_variation_selectors,
26
+ separate_emojis_and_text)
27
+
28
+ try:
29
+ unicode # Python 2
30
+ except NameError:
31
+ unicode = str # Python 3
32
+
33
+ # Only catch retweets in the beginning of the tweet as those are the
34
+ # automatically added ones.
35
+ # We do not want to remove tweets like "Omg.. please RT this!!"
36
+ RETWEETS_RE = re.compile(r'^[rR][tT]')
37
+
38
+ # Use fast and less precise regex for removing tweets with URLs
39
+ # It doesn't matter too much if a few tweets with URL's make it through
40
+ URLS_RE = re.compile(r'https?://|www\.')
41
+
42
+ MENTION_RE = re.compile(RE_MENTION)
43
+ ALLOWED_CONVERTED_UNICODE_PUNCTUATION = """!"#$'()+,-.:;<=>?@`~"""
44
+
45
+
46
+ class WordGenerator():
47
+ ''' Cleanses input and converts into words. Needs all sentences to be in
48
+ Unicode format. Has subclasses that read sentences differently based on
49
+ file type.
50
+
51
+ Takes a generator as input. This can be from e.g. a file.
52
+ unicode_handling in ['ignore_sentence', 'convert_punctuation', 'allow']
53
+ unicode_handling in ['ignore_emoji', 'ignore_sentence', 'allow']
54
+ '''
55
+ def __init__(self, stream, allow_unicode_text=False, ignore_emojis=True,
56
+ remove_variation_selectors=True, break_replacement=True):
57
+ self.stream = stream
58
+ self.allow_unicode_text = allow_unicode_text
59
+ self.remove_variation_selectors = remove_variation_selectors
60
+ self.ignore_emojis = ignore_emojis
61
+ self.break_replacement = break_replacement
62
+ self.reset_stats()
63
+
64
+ def get_words(self, sentence):
65
+ """ Tokenizes a sentence into individual words.
66
+ Converts Unicode punctuation into ASCII if that option is set.
67
+ Ignores sentences with Unicode if that option is set.
68
+ Returns an empty list of words if the sentence has Unicode and
69
+ that is not allowed.
70
+ """
71
+
72
+ if not isinstance(sentence, unicode):
73
+ raise ValueError("All sentences should be Unicode-encoded!")
74
+ sentence = sentence.strip().lower()
75
+
76
+ if self.break_replacement:
77
+ sentence = convert_linebreaks(sentence)
78
+
79
+ if self.remove_variation_selectors:
80
+ sentence = remove_variation_selectors(sentence)
81
+
82
+ # Split into words using simple whitespace splitting and convert
83
+ # Unicode. This is done to prevent word splitting issues with
84
+ # twokenize and Unicode
85
+ words = sentence.split()
86
+ converted_words = []
87
+ for w in words:
88
+ accept_sentence, c_w = self.convert_unicode_word(w)
89
+ # Unicode word detected and not allowed
90
+ if not accept_sentence:
91
+ return []
92
+ else:
93
+ converted_words.append(c_w)
94
+ sentence = ' '.join(converted_words)
95
+
96
+ words = tokenize(sentence)
97
+ words = [process_word(w) for w in words]
98
+ return words
99
+
100
+ def check_ascii(self, word):
101
+ """ Returns whether a word is ASCII """
102
+
103
+ try:
104
+ word.encode('ascii')  # encode() works on both Python 2 and 3 strings; str has no decode() in Python 3
105
+ return True
106
+ except (UnicodeDecodeError, UnicodeEncodeError, AttributeError):
107
+ return False
108
+
109
+ def convert_unicode_punctuation(self, word):
110
+ word_converted_punct = []
111
+ for c in word:
112
+ decoded_c = unidecode(c).lower()
113
+ if len(decoded_c) == 0:
114
+ # Cannot decode to anything reasonable
115
+ word_converted_punct.append(c)
116
+ else:
117
+ # Check if all punctuation and therefore fine
118
+ # to include unidecoded version
119
+ allowed_punct = punct_word(
120
+ decoded_c,
121
+ punctuation=ALLOWED_CONVERTED_UNICODE_PUNCTUATION)
122
+
123
+ if allowed_punct:
124
+ word_converted_punct.append(decoded_c)
125
+ else:
126
+ word_converted_punct.append(c)
127
+ return ''.join(word_converted_punct)
128
+
129
+ def convert_unicode_word(self, word):
130
+ """ Converts Unicode words to ASCII using unidecode. If Unicode is not
131
+ allowed (set as a variable during initialization), then only
132
+ punctuation that can be converted to ASCII will be allowed.
133
+ """
134
+ if self.check_ascii(word):
135
+ return True, word
136
+
137
+ # First we ensure that the Unicode is normalized so it's
138
+ # always a single character.
139
+ word = unicodedata.normalize("NFKC", word)
140
+
141
+ # Convert Unicode punctuation to ASCII equivalent. We want
142
+ # e.g. "\u203c" (double exclamation mark) to be treated the same
143
+ # as "!!" no matter if we allow other Unicode characters or not.
144
+ word = self.convert_unicode_punctuation(word)
145
+
146
+ if self.ignore_emojis:
147
+ _, word = separate_emojis_and_text(word)
148
+
149
+ # If conversion of punctuation and removal of emojis took care
150
+ # of all the Unicode or if we allow Unicode then everything is fine
151
+ if self.check_ascii(word) or self.allow_unicode_text:
152
+ return True, word
153
+ else:
154
+ # Sometimes we might want to simply ignore Unicode sentences
155
+ # (e.g. for vocabulary creation). This is another way to prevent
156
+ # "polution" of strange Unicode tokens from low quality datasets
157
+ return False, ''
158
+
159
+ def data_preprocess_filtering(self, line, iter_i):
160
+ """ To be overridden with specific preprocessing/filtering behavior
161
+ if desired.
162
+
163
+ Returns a boolean of whether the line should be accepted and the
164
+ preprocessed text.
165
+
166
+ Runs prior to tokenization.
167
+ """
168
+ return True, line, {}
169
+
170
+ def data_postprocess_filtering(self, words, iter_i):
171
+ """ To be overridden with specific postprocessing/filtering behavior
172
+ if desired.
173
+
174
+ Returns a boolean of whether the line should be accepted and the
175
+ postprocessed text.
176
+
177
+ Runs after tokenization.
178
+ """
179
+ return True, words, {}
180
+
181
+ def extract_valid_sentence_words(self, line):
182
+ """ Line may either a string of a list of strings depending on how
183
+ the stream is being parsed.
184
+ Domain-specific processing and filtering can be done both prior to
185
+ and after tokenization.
186
+ Custom information about the line can be extracted during the
187
+ processing phases and returned as a dict.
188
+ """
189
+
190
+ info = {}
191
+
192
+ pre_valid, pre_line, pre_info = \
193
+ self.data_preprocess_filtering(line, self.stats['total'])
194
+ info.update(pre_info)
195
+ if not pre_valid:
196
+ self.stats['pretokenization_filtered'] += 1
197
+ return False, [], info
198
+
199
+ words = self.get_words(pre_line)
200
+ if len(words) == 0:
201
+ self.stats['unicode_filtered'] += 1
202
+ return False, [], info
203
+
204
+ post_valid, post_words, post_info = \
205
+ self.data_postprocess_filtering(words, self.stats['total'])
206
+ info.update(post_info)
207
+ if not post_valid:
208
+ self.stats['posttokenization_filtered'] += 1
209
+ return post_valid, post_words, info
210
+
211
+ def generate_array_from_input(self):
212
+ sentences = []
213
+ for words in self:
214
+ sentences.append(words)
215
+ return sentences
216
+
217
+ def reset_stats(self):
218
+ self.stats = {'pretokenization_filtered': 0,
219
+ 'unicode_filtered': 0,
220
+ 'posttokenization_filtered': 0,
221
+ 'total': 0,
222
+ 'valid': 0}
223
+
224
+ def __iter__(self):
225
+ if self.stream is None:
226
+ raise ValueError("Stream should be set before iterating over it!")
227
+
228
+ for line in self.stream:
229
+ valid, words, info = self.extract_valid_sentence_words(line)
230
+
231
+ # Words may be filtered away due to unidecode etc.
232
+ # In that case the words should not be passed on.
233
+ if valid and len(words):
234
+ self.stats['valid'] += 1
235
+ yield words, info
236
+
237
+ self.stats['total'] += 1
238
+
239
+
240
+ class TweetWordGenerator(WordGenerator):
241
+ ''' Returns np array or generator of ASCII sentences for given tweet input.
242
+ Any file opening/closing should be handled outside of this class.
243
+ '''
244
+ def __init__(self, stream, wanted_emojis=None, english_words=None,
245
+ non_english_user_set=None, allow_unicode_text=False,
246
+ ignore_retweets=True, ignore_url_tweets=True,
247
+ ignore_mention_tweets=False):
248
+
249
+ self.wanted_emojis = wanted_emojis
250
+ self.english_words = english_words
251
+ self.non_english_user_set = non_english_user_set
252
+ self.ignore_retweets = ignore_retweets
253
+ self.ignore_url_tweets = ignore_url_tweets
254
+ self.ignore_mention_tweets = ignore_mention_tweets
255
+ WordGenerator.__init__(self, stream,
256
+ allow_unicode_text=allow_unicode_text)
257
+
258
+ def validated_tweet(self, data):
259
+ ''' A bunch of checks to determine whether the tweet is valid.
260
+ Also returns emojis contained by the tweet.
261
+ '''
262
+
263
+ # Ordering of validations is important for speed
264
+ # If it passes all checks, then the tweet is validated for usage
265
+
266
+ # Skips incomplete tweets
267
+ if len(data) <= 9:
268
+ return False, []
269
+
270
+ text = data[9]
271
+
272
+ if self.ignore_retweets and RETWEETS_RE.search(text):
273
+ return False, []
274
+
275
+ if self.ignore_url_tweets and URLS_RE.search(text):
276
+ return False, []
277
+
278
+ if self.ignore_mention_tweets and MENTION_RE.search(text):
279
+ return False, []
280
+
281
+ if self.wanted_emojis is not None:
282
+ uniq_emojis = np.unique(extract_emojis(text, self.wanted_emojis))
283
+ if len(uniq_emojis) == 0:
284
+ return False, []
285
+ else:
286
+ uniq_emojis = []
287
+
288
+ if self.non_english_user_set is not None and \
289
+ non_english_user(data[1], self.non_english_user_set):
290
+ return False, []
291
+ return True, uniq_emojis
292
+
293
+ def data_preprocess_filtering(self, line, iter_i):
294
+ fields = line.strip().split("\t")
295
+ valid, emojis = self.validated_tweet(fields)
296
+ text = fields[9].replace('\\n', '') \
297
+ .replace('\\r', '') \
298
+ .replace('&amp', '&') if valid else ''
299
+ return valid, text, {'emojis': emojis}
300
+
301
+ def data_postprocess_filtering(self, words, iter_i):
302
+ valid_length = correct_length(words, 1, None)
303
+ valid_english, n_words, n_english = mostly_english(words,
304
+ self.english_words)
305
+ if valid_length and valid_english:
306
+ return True, words, {'length': len(words),
307
+ 'n_normal_words': n_words,
308
+ 'n_english': n_english}
309
+ else:
310
+ return False, [], {'length': len(words),
311
+ 'n_normal_words': n_words,
312
+ 'n_english': n_english}
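
A minimal sketch of the base WordGenerator iterating over an in-memory list of sentences (TweetWordGenerator would instead expect tab-separated tweet fields and filter retweets/URLs):

from torchmoji.word_generator import WordGenerator

sentences = ['Testing the word generator :)', 'Another sentence!!']
wg = WordGenerator(sentences, allow_unicode_text=True)
for words, info in wg:
    print(words, info)
# ['testing', 'the', 'word', 'generator', ':)'] {}
# ['another', 'sentence', '!!'] {}
print(wg.stats['valid'], wg.stats['total'])   # 2 2
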
vocabulary.json ADDED
The diff for this file is too large to render. See raw diff