lvwerra HF staff committed on
Commit 85f9580
1 Parent(s): 1e59200
Files changed (3)
  1. bary_score.py +9 -18
  2. requirements.txt +4 -1
  3. score.py +255 -0
bary_score.py CHANGED
@@ -16,6 +16,8 @@
 import evaluate
 import datasets
 
+from score import BaryScoreMetric
+
 
 # TODO: Add BibTeX citation
 _CITATION = """\
@@ -53,10 +55,6 @@ Examples:
     {'accuracy': 1.0}
 """
 
-# TODO: Define external resources urls if needed
-BAD_WORDS_URL = "http://url/to/external/resource/bad_words.txt"
-
-
 @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
 class BaryScore(evaluate.EvaluationModule):
     """TODO: Short description of my evaluation module."""
@@ -71,8 +69,8 @@ class BaryScore(evaluate.EvaluationModule):
             inputs_description=_KWARGS_DESCRIPTION,
             # This defines the format of each prediction and reference
             features=datasets.Features({
-                'predictions': datasets.Value('int64'),
-                'references': datasets.Value('int64'),
+                'predictions': datasets.Value('string'),
+                'references': datasets.Value('string'),
             }),
             # Homepage of the module for documentation
             homepage="http://module.homepage",
@@ -81,15 +79,8 @@ class BaryScore(evaluate.EvaluationModule):
             reference_urls=["http://path.to.reference.url/new_module"]
         )
 
-    def _download_and_prepare(self, dl_manager):
-        """Optional: download external resources useful to compute the scores"""
-        # TODO: Download external resources if needed
-        pass
-
-    def _compute(self, predictions, references):
-        """Returns the scores"""
-        # TODO: Compute the different scores of the module
-        accuracy = sum(i == j for i, j in zip(predictions, references)) / len(predictions)
-        return {
-            "accuracy": accuracy,
-        }
+    def _compute(self, predictions, references, model_name="bert-base-uncased", last_layers=5, use_idfs=True, sinkhorn_ref=0.01):
+        metric_call = BaryScoreMetric(model_name=model_name, last_layers=last_layers, use_idfs=use_idfs, sinkhorn_ref=sinkhorn_ref)
+        metric_call.prepare_idfs(references, predictions)
+        result = metric_call.evaluate_batch(references, predictions)
+        return result
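With this change, _compute delegates to the real BaryScore implementation in score.py instead of the template's accuracy stub. A minimal usage sketch of the updated module (the module id "lvwerra/bary_score" is an assumption based on this repo; substitute whatever path the module is published under):

import evaluate

bary = evaluate.load("lvwerra/bary_score")  # hypothetical module id
results = bary.compute(
    predictions=['I like my cakes very much', 'I like my cakes very much'],
    references=['I like my cakes very much', 'I hate these cakes!'],
)
# results maps each score name ('baryscore_W', 'baryscore_SD_10', ...) to a
# list with one value per prediction/reference pair
print(results)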
requirements.txt CHANGED
@@ -1,2 +1,5 @@
 evaluate==0.1.0
-datasets~=2.0
+datasets~=2.0
+POT
+transformers
+torch
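The three new dependencies back the added score.py: POT supplies the optimal-transport solvers (free-support barycenters, Sinkhorn, exact EMD), while transformers and torch supply the pretrained encoder. Only evaluate and datasets stay pinned; the new entries are left unpinned, so e.g. pip install POT transformers torch alongside the two pinned packages matches this file.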
score.py ADDED
@@ -0,0 +1,255 @@
+from __future__ import absolute_import, division, print_function
+import numpy as np
+import torch
+from tqdm import tqdm
+import ot
+from math import log
+from collections import defaultdict, Counter
+from transformers import AutoModelForMaskedLM, AutoTokenizer
+
+
+class BaryScoreMetric:
+    def __init__(self, model_name="bert-base-uncased", last_layers=5, use_idfs=True, sinkhorn_ref=0.01):
+        """
+        BaryScore metric
+        :param model_name: model name or path from the HuggingFace library
+        :param last_layers: number of last layers of the pretrained model to use
+        :param use_idfs: if True use idf weights, else use uniform weights
+        :param sinkhorn_ref: regularization weight of the KL term in the Sinkhorn divergence (SD)
+        """
+
+        self.model_name = model_name
+        self.load_tokenizer_and_model()
+        # n counts the embedding output plus every hidden layer; e.g. for
+        # bert-base-uncased n = 13, so last_layers = 5 selects layers 8..12
+        n = self.model.config.num_hidden_layers + 1
+        assert n - last_layers > 0
+        self.layers_to_consider = range(n - last_layers, n)
+        self.use_idfs = use_idfs
+        self.sinkhorn_ref = sinkhorn_ref
+        self.idfs = []
+        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+    def prepare_idfs(self, hyps, refs):
+        """
+        :param hyps: hypothesis list of string sentences; idfs have to be computed at corpus level
+        :param refs: reference list of string sentences; idfs have to be computed at corpus level
+        """
+        t_hyps = self.tokenizer(hyps)['input_ids']
+        t_refs = self.tokenizer(refs)['input_ids']
+        idf_dict_ref = self.ref_list_to_idf(t_refs)
+        idf_dict_hyp = self.ref_list_to_idf(t_hyps)
+        # stored as (idf_dict_ref, idf_dict_hyp) and unpacked in evaluate_batch
+        self.model_ids = (idf_dict_ref, idf_dict_hyp)
+        return idf_dict_hyp, idf_dict_ref
+
+    def ref_list_to_idf(self, input_refs):
+        """
+        :param input_refs: list of input references (lists of token ids)
+        :return: idf dictionary
+        """
+        idf_count = Counter()
+        num_docs = len(input_refs)
+
+        # document frequency: count each token id once per document
+        idf_count.update(sum([list(set(i)) for i in input_refs], []))
+
+        # smoothed idf; tokens never seen default to log(num_docs + 1)
+        idf_dict = defaultdict(lambda: log((num_docs + 1) / (1)))
+        idf_dict.update({idx: log((num_docs + 1) / (c + 1)) for (idx, c) in idf_count.items()})
+        return idf_dict
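+
+    # Worked example for the smoothed idf above: with num_docs = 3, a token id
+    # appearing in all 3 documents gets log((3 + 1) / (3 + 1)) = 0, one seen in
+    # a single document gets log(4 / 2) = log 2, and an unseen token falls back
+    # to the default log(4).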
+
+    def load_tokenizer_and_model(self):
+        """
+        Load and initialize the chosen model and tokenizer
+        """
+        tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+        model = AutoModelForMaskedLM.from_pretrained(self.model_name)
+        model.config.output_hidden_states = True
+        model.eval()
+        self.tokenizer = tokenizer
+        self.model = model
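+
+    # Setting output_hidden_states=True makes the forward pass return the
+    # hidden states of every layer, which evaluate_batch indexes below through
+    # self.layers_to_consider.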
+
+    def evaluate_batch(self, batch_hyps, batch_refs, idf_hyps=None, idf_ref=None):
+        """
+        :param batch_hyps: hypothesis list of string sentences
+        :param batch_refs: reference list of string sentences
+        :param idf_hyps: idfs of hypotheses computed at corpus level
+        :param idf_ref: idfs of references computed at corpus level
+        :return: dictionary of scores
+        """
+        if isinstance(batch_hyps, str):
+            batch_hyps = [batch_hyps]
+        if isinstance(batch_refs, str):
+            batch_refs = [batch_refs]
+        nb_sentences = len(batch_refs)
+        baryscores = []
+        assert len(batch_hyps) == len(batch_refs)
+
+        if (idf_hyps is None) and (idf_ref is None):
+            idf_hyps, idf_ref = self.model_ids
+
+        model = self.model.to(self.device)
+
+        with torch.no_grad():
+            ###############################################
+            ## Extract Embeddings From Pretrained Models ##
+            ###############################################
+            batch_refs = self.tokenizer(batch_refs, return_tensors='pt', padding=True, truncation=True).to(self.device)
+            batch_refs_embeddings_ = model(**batch_refs)[-1]
+
+            batch_hyps = self.tokenizer(batch_hyps, return_tensors='pt', padding=True, truncation=True).to(self.device)
+            batch_hyps_embeddings_ = model(**batch_hyps)[-1]
+
+            batch_refs_embeddings = [batch_refs_embeddings_[i] for i in list(self.layers_to_consider)]
+            batch_hyps_embeddings = [batch_hyps_embeddings_[i] for i in list(self.layers_to_consider)]
+
+            # stack the selected layers and L2-normalize along the hidden dimension
+            batch_refs_embeddings = torch.cat([i.unsqueeze(0) for i in batch_refs_embeddings])
+            batch_refs_embeddings.div_(torch.norm(batch_refs_embeddings, dim=-1).unsqueeze(-1))
+            batch_hyps_embeddings = torch.cat([i.unsqueeze(0) for i in batch_hyps_embeddings])
+            batch_hyps_embeddings.div_(torch.norm(batch_hyps_embeddings, dim=-1).unsqueeze(-1))
+
+            ref_tokens_id = batch_refs['input_ids'].cpu().tolist()
+            hyp_tokens_id = batch_hyps['input_ids'].cpu().tolist()
+
+            ####################################
+            ## Unbatched BaryScore Prediction ##
+            ####################################
+            for index_sentence in tqdm(range(nb_sentences), 'BaryScore Progress'):
+                dict_score = {}
+                ref_ids_idf = batch_refs['input_ids'][index_sentence]
+                hyp_idf_ids = batch_hyps['input_ids'][index_sentence]
+
+                ref_tokens = [i for i in self.tokenizer.convert_ids_to_tokens(ref_tokens_id[index_sentence],
+                                                                              skip_special_tokens=False) if
+                              i != self.tokenizer.pad_token]
+                hyp_tokens = [i for i in self.tokenizer.convert_ids_to_tokens(hyp_tokens_id[index_sentence],
+                                                                              skip_special_tokens=False) if
+                              i != self.tokenizer.pad_token]
+
+                ref_ids = [k for k, w in enumerate(ref_tokens)]
+                hyp_ids = [k for k, w in enumerate(hyp_tokens)]
+
+                # idf weight of each non-pad token (stop words included);
+                # .tolist() converts tensor ids to ints so the idf dict lookups match
+                ref_idf_i = [idf_ref[i] for i in ref_ids_idf[ref_ids].tolist()]
+                hyp_idf_i = [idf_hyps[i] for i in hyp_idf_ids[hyp_ids].tolist()]
+
+                ref_embedding_i = batch_refs_embeddings[:, index_sentence, ref_ids, :]
+                hyp_embedding_i = batch_hyps_embeddings[:, index_sentence, hyp_ids, :]
+                measures_locations_ref = ref_embedding_i.permute(1, 0, 2).cpu().numpy().tolist()
+                measures_locations_ref = [np.array(i) for i in measures_locations_ref]
+                measures_locations_hyps = hyp_embedding_i.permute(1, 0, 2).cpu().numpy().tolist()
+                measures_locations_hyps = [np.array(i) for i in measures_locations_hyps]
+
+                # regroup per layer: one (n_tokens, dim) discrete measure per considered layer
+                measures_locations_ref = [np.array(i) for i in
+                                          np.array(measures_locations_ref).transpose(1, 0, 2).tolist()]
+                measures_locations_hyps = [np.array(i) for i in
+                                           np.array(measures_locations_hyps).transpose(1, 0, 2).tolist()]
+
+                if self.use_idfs:
+                    ########################
+                    ## Use TF-IDF weights ##
+                    ########################
+                    baryscore = self.baryscore(measures_locations_ref, measures_locations_hyps, ref_idf_i,
+                                               hyp_idf_i)
+                else:
+                    #####################
+                    ## Uniform Weights ##
+                    #####################
+                    baryscore = self.baryscore(measures_locations_ref, measures_locations_hyps, None, None)
+
+                for key, value in baryscore.items():
+                    dict_score['baryscore_{}'.format(key)] = value
+                baryscores.append(dict_score)
+
+        # pivot the list of per-sentence dicts into a dict of per-key lists
+        baryscores_dic = {}
+        for k in dict_score.keys():
+            baryscores_dic[k] = []
+            for score in baryscores:
+                baryscores_dic[k].append(score[k])
+
+        return baryscores_dic
+
+    def baryscore(self, measures_locations_ref, measures_locations_hyps, weights_refs, weights_hyps):
+        """
+        :param measures_locations_ref: input measure reference locations
+        :param measures_locations_hyps: input measure hypothesis locations
+        :param weights_refs: reference weights in the Wasserstein barycenters
+        :param weights_hyps: hypothesis weights in the Wasserstein barycenters
+        :return: dictionary with the Wasserstein distance ("W") and Sinkhorn divergences ("SD_{reg}")
+        """
+        if weights_hyps is not None or weights_refs is not None:
+            assert weights_refs is not None
+            assert weights_hyps is not None
+            weights_hyps = np.array([i / sum(weights_hyps) for i in weights_hyps]).astype(np.float64)
+            weights_refs = np.array([i / sum(weights_refs) for i in weights_refs]).astype(np.float64)
+
+        self.n_layers = len(measures_locations_ref)
+        self.d_bert = measures_locations_ref[0].shape[1]
+        ####################################
+        ## Compute Wasserstein Barycenter ##
+        ####################################
+        bary_ref = self.w_barycenter(measures_locations_ref, weights_refs)
+        bary_hyp = self.w_barycenter(measures_locations_hyps, weights_hyps)
+
+        #################################################
+        ## Compute Wasserstein and Sinkhorn Divergence ##
+        #################################################
+        C = ot.dist(bary_ref, bary_hyp)
+        weights_first_barycenter = np.zeros((C.shape[0])) + 1 / C.shape[0]
+        weights_second_barycenter = np.zeros((C.shape[1])) + 1 / C.shape[1]
+        wasserstein_distance = ot.emd2(weights_first_barycenter, weights_second_barycenter, C,
+                                       log=True)[0]
+        dic_results = {
+            "W": wasserstein_distance,
+        }
+        for reg in [10, 5, 1, 0.5, 0.1, 0.01, 0.001]:
+            wasserstein_sinkhorn = ot.bregman.sinkhorn2(weights_first_barycenter, weights_second_barycenter, C,
+                                                        reg=reg, numItermax=10000).tolist()
+            if isinstance(wasserstein_sinkhorn, list):
+                wasserstein_sinkhorn = wasserstein_sinkhorn[0]  # for POT==0.7.0
+            dic_results['SD_{}'.format(reg)] = wasserstein_sinkhorn
+        return dic_results
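+
+    # Each SD_{reg} above is the Sinkhorn transport cost between the two
+    # barycenters at entropic regularization reg; as reg shrinks it approaches
+    # the exact Wasserstein cost reported under "W".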
+
+    def w_barycenter(self, measures_locations, weights):
+        """
+        :param measures_locations: locations of the discrete input measures
+        :param weights: weights of the input measures
+        :return: barycentric distribution
+        """
+        X_init = np.zeros((measures_locations[0].shape[0], self.d_bert)).astype(np.float64)
+        if weights is None:
+            # uniform weights over the tokens, one measure per layer
+            measures_weights = [np.array(
+                [1 / measures_locations[0].shape[0]] * measures_locations[0].shape[0])] * self.n_layers
+        else:
+            measures_weights = [weights / sum(weights)] * self.n_layers
+        b = np.array([1 / measures_locations[0].shape[0]] * measures_locations[0].shape[0]).astype(np.float64)
+        measure_bary = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init,
+                                                     b=b, numItermax=1000, verbose=False)
+        return measure_bary
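+
+    # free_support_barycenter alternates optimal-transport assignments with
+    # support updates starting from X_init (Cuturi & Doucet, 2014), yielding
+    # the barycenter of the per-layer token measures.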
+
+    @property
+    def supports_multi_ref(self):
+        """
+        :return: False; BaryScore does not support multiple references
+        """
+        return False
+
+
+if __name__ == '__main__':
+    """
+    Example of how to use BaryScore
+    """
+    metric_call = BaryScoreMetric(use_idfs=False)
+
+    ref = [
+        'I like my cakes very much',
+        'I hate these cakes!']
+    hypothesis = ['I like my cakes very much',
+                  'I like my cakes very much']
+
+    metric_call.prepare_idfs(ref, hypothesis)
+    final_preds = metric_call.evaluate_batch(ref, hypothesis)
+    print(final_preds)
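
Running python score.py executes the example above and prints a dictionary of the form {'baryscore_W': [...], 'baryscore_SD_10': [...], ..., 'baryscore_SD_0.001': [...]}, with one value per sentence pair; the exact numbers depend on the downloaded bert-base-uncased weights.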