beny2000 commited on
Commit
0e956f2
1 Parent(s): 65e6caa

Upload model

Browse files
Files changed (9) hide show
  1. config.json +30 -0
  2. configuration.py +30 -0
  3. embedding_model.py +121 -0
  4. embeddings.py +637 -0
  5. match_groups.py +870 -0
  6. pytorch_model.bin +3 -0
  7. scoring.py +196 -0
  8. scoring_model.py +60 -0
  9. similarity_model.py +369 -0
config.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "SimilarityModel"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration.SimilarityModelConfig",
7
+ "AutoModel": "similarity_model.SimilarityModel"
8
+ },
9
+ "device": "cpu",
10
+ "embedding_model_config": {
11
+ "add_upper": true,
12
+ "d": 128,
13
+ "device": "cpu",
14
+ "model_class": "roberta",
15
+ "model_name": "roberta-base",
16
+ "normalize": true,
17
+ "pooling": "pooler",
18
+ "prompt": "",
19
+ "upper_case": false
20
+ },
21
+ "model_type": "roberta",
22
+ "score_model_config": {
23
+ "alpha": 50
24
+ },
25
+ "torch_dtype": "float32",
26
+ "transformers_version": "4.27.4",
27
+ "weighting_function_config": {
28
+ "weighting_exponent": 0.5
29
+ }
30
+ }
configuration.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
+
4
+ class SimilarityModelConfig(PretrainedConfig):
5
+ model_type = 'roberta'
6
+ def __init__(self, **kwargs):
7
+ super().__init__(**kwargs)
8
+
9
+ self.embedding_model_config = kwargs.get("embedding_model_config")
10
+ self.score_model_config = kwargs.get("score_model_config")
11
+ self.weighting_function_config = kwargs.get("weighting_function_config")
12
+
13
+
14
+ nama_base = SimilarityModelConfig(
15
+ embedding_model_config={
16
+ "model_class": 'roberta',
17
+ "model_name":'roberta-base',
18
+ "pooling": 'pooler',
19
+ "normalize":True,
20
+ "d":128,
21
+ "prompt":'',
22
+ "device":'cpu',
23
+ "add_upper": True,
24
+ "upper_case":False
25
+ },
26
+ score_model_config={"alpha": 50},
27
+ weighting_function_config={"weighting_exponent": 0.5},
28
+ device="cpu",
29
+ )
30
+
embedding_model.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoTokenizer, RobertaModel
3
+
4
+ class EmbeddingModel(torch.nn.Module):
5
+
6
+
7
+ tokenizers = {'roberta': RobertaModel}
8
+ """
9
+ A basic wrapper around a Hugging Face transformer model.
10
+ Takes a string as input and produces an embedding vector of size d.
11
+ """
12
+ def __init__(self, config, **kwargs):
13
+
14
+ super().__init__()
15
+
16
+ self.model_class = self.tokenizers.get(config.get("model_class").lower())
17
+ self.model_name = config.get("model_name")
18
+ self.pooling = config.get("pooling")
19
+ self.normalize = config.get("normalize")
20
+ self.d = config.get("d")
21
+ self.prompt = config.get("prompt")
22
+ self.add_upper = config.get("add_upper")
23
+ self.upper_case = config.get("upper_case")
24
+
25
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
26
+
27
+ try:
28
+ self.transformer = self.model_class.from_pretrained(self.model_name)
29
+ except OSError:
30
+ self.transformer = self.model_class.from_pretrained(self.model_name,from_tf=True)
31
+
32
+ self.dropout = torch.nn.Dropout(0.5)
33
+
34
+ if self.d:
35
+ # Project embedding to a lower dimension
36
+ # Initialization based on random projection LSH (preserves approximate cosine distances)
37
+ self.projection = torch.nn.Linear(self.transformer.config.hidden_size,self.d)
38
+ torch.nn.init.normal_(self.projection.weight)
39
+ torch.nn.init.constant_(self.projection.bias,0)
40
+
41
+ self.to(config.get("device"))
42
+
43
+ def to(self,device):
44
+ super().to(device)
45
+ self.device = device
46
+
47
+ def encode(self,strings):
48
+ if self.prompt is not None:
49
+ strings = [self.prompt + s for s in strings]
50
+ if self.add_upper:
51
+ strings = [s + ' </s> ' + s.upper() for s in strings]
52
+ if self.upper_case:
53
+ strings = [s + ' </s> ' + s.upper() for s in strings]
54
+
55
+ try:
56
+ encoded = self.tokenizer(strings,padding=True,truncation=True)
57
+ except Exception as e:
58
+ print(strings)
59
+ raise Exception(e)
60
+ input_ids = torch.tensor(encoded['input_ids']).long()
61
+ attention_mask = torch.tensor(encoded['attention_mask'])
62
+
63
+ return input_ids,attention_mask
64
+
65
+ def forward(self,strings):
66
+
67
+ with torch.no_grad():
68
+ input_ids,attention_mask = self.encode(strings)
69
+
70
+ input_ids = input_ids.to(device=self.device)
71
+ attention_mask = attention_mask.to(device=self.device)
72
+
73
+ # with amp.autocast(self.amp):
74
+ batch_out = self.transformer(input_ids=input_ids,
75
+ attention_mask=attention_mask,
76
+ return_dict=True)
77
+
78
+ if self.pooling == 'pooler':
79
+ v = batch_out['pooler_output']
80
+ elif self.pooling == 'mean':
81
+ h = batch_out['last_hidden_state']
82
+
83
+ # Compute mean of unmasked token vectors
84
+ h = h*attention_mask[:,:,None]
85
+ v = h.sum(dim=1)/attention_mask.sum(dim=1)[:,None]
86
+
87
+ if self.d:
88
+ v = self.projection(v)
89
+
90
+ if self.normalize:
91
+ v = v/torch.sqrt((v**2).sum(dim=1)[:,None])
92
+
93
+ return v
94
+
95
+ def config_optimizer(self,transformer_lr=1e-5,projection_lr=1e-4):
96
+
97
+ parameters = list(self.named_parameters())
98
+ grouped_parameters = [
99
+ {
100
+ 'params': [param for name,param in parameters if name.startswith('transformer') and name.endswith('bias')],
101
+ 'weight_decay_rate': 0.0,
102
+ 'lr':transformer_lr,
103
+ },
104
+ {
105
+ 'params': [param for name,param in parameters if name.startswith('transformer') and not name.endswith('bias')],
106
+ 'weight_decay_rate': 0.0,
107
+ 'lr':transformer_lr,
108
+ },
109
+ {
110
+ 'params': [param for name,param in parameters if name.startswith('projection')],
111
+ 'weight_decay_rate': 0.0,
112
+ 'lr':projection_lr,
113
+ },
114
+ ]
115
+
116
+ # Drop groups with lr of 0
117
+ grouped_parameters = [p for p in grouped_parameters if p['lr']]
118
+
119
+ optimizer = torch.optim.AdamW(grouped_parameters)
120
+
121
+ return optimizer
embeddings.py ADDED
@@ -0,0 +1,637 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ from tqdm import tqdm
4
+ from copy import copy
5
+ from collections import Counter
6
+ import torch
7
+ from zipfile import ZipFile
8
+ import pickle
9
+ from io import BytesIO
10
+
11
+ from .match_groups import MatchGroups
12
+
13
+
14
+ class Embeddings(torch.nn.Module):
15
+ """
16
+ Stores embeddings for a fixed array of strings and provides methods for
17
+ clustering the strings to create MatchGroups objects according to different
18
+ algorithms.
19
+ """
20
+ def __init__(self,strings,V,score_model,weighting_function,counts,device='cpu'):
21
+ super().__init__()
22
+
23
+ self.strings = np.array(list(strings))
24
+ self.string_map = {s:i for i,s in enumerate(strings)}
25
+ self.V = V
26
+ self.counts = counts
27
+ self.w = weighting_function(counts)
28
+ self.score_model = score_model
29
+ self.weighting_function = weighting_function
30
+ self.device = device
31
+
32
+ self.to(device)
33
+
34
+ def __repr__(self):
35
+ return f'<nama.Embeddings containing {self.V.shape[1]}-d vectors for {len(self)} strings'
36
+
37
+ def to(self,device):
38
+ super().to(device)
39
+ self.V = self.V.to(device)
40
+ self.counts = self.counts.to(device)
41
+ self.w = self.w.to(device)
42
+ self.score_model.to(device)
43
+ self.device = device
44
+
45
+ def save(self,f):
46
+ """
47
+ Save embeddings in a simple custom zipped archive format (torch.save
48
+ works too, but it requires huge amounts of memory to serialize large
49
+ embeddings objects).
50
+ """
51
+ with ZipFile(f,'w') as zip:
52
+
53
+ # Write score model
54
+ zip.writestr('score_model.pkl',pickle.dumps(self.score_model))
55
+
56
+ # Write score model
57
+ zip.writestr('weighting_function.pkl',pickle.dumps(self.weighting_function))
58
+
59
+ # Write string info
60
+ strings_df = pd.DataFrame().assign(
61
+ string=self.strings,
62
+ count=self.counts.to('cpu').numpy())
63
+ zip.writestr('strings.csv',strings_df.to_csv(index=False))
64
+
65
+ # Write embedding vectors
66
+ byte_io = BytesIO()
67
+ np.save(byte_io,self.V.to('cpu').numpy(),allow_pickle=False)
68
+ zip.writestr('V.npy',byte_io.getvalue())
69
+
70
+ def __getitem__(self,arg):
71
+ """
72
+ Slice a Match Groups object
73
+ """
74
+ if isinstance(arg,slice):
75
+ i = arg
76
+ elif isinstance(arg, MatchGroups):
77
+ return self[arg.strings()]
78
+ elif hasattr(arg,'__iter__'):
79
+ # Return a subset of the embeddings and their weights
80
+ string_map = self.string_map
81
+ i = [string_map[s] for s in arg]
82
+
83
+ if i == list(range(len(self))):
84
+ # Just selecting the whole match groups object - no need to slice the embedding
85
+ return copy(self)
86
+ else:
87
+ raise ValueError(f'Unknown slice input type ({type(input)}). Can only slice Embedding with a slice, match group, or iterable.')
88
+
89
+ new = copy(self)
90
+ new.strings = self.strings[i]
91
+ new.V = self.V[i]
92
+ new.counts = self.counts[i]
93
+ new.w = self.w[i]
94
+ new.string_map = {s:i for i,s in enumerate(new.strings)}
95
+
96
+ return new
97
+
98
+ def embed(self,grouping):
99
+ """
100
+ Construct updated Embeddings with counts from the input MatchGroups
101
+ """
102
+ new = self[grouping]
103
+ new.counts = torch.tensor([grouping.counts[s] for s in new.strings],device=self.device)
104
+ new.w = new.weighting_function(new.counts)
105
+
106
+ return new
107
+
108
+ def __len__(self):
109
+ return len(self.strings)
110
+
111
+ def _group_to_ids(self,grouping):
112
+ group_id_map = {g:i for i,g in enumerate(grouping.groups.keys())}
113
+ group_ids = torch.tensor([group_id_map[grouping[s]] for s in self.strings]).to(self.device)
114
+ return group_ids
115
+
116
+ def _ids_to_group(self,group_ids):
117
+ if isinstance(group_ids,torch.Tensor):
118
+ group_ids = group_ids.to('cpu').numpy()
119
+
120
+ strings = self.strings
121
+ counts = self.counts.to('cpu').numpy()
122
+
123
+ # Sort by group and string count
124
+ g_sort = np.lexsort((counts,group_ids))
125
+ group_ids = group_ids[g_sort]
126
+ strings = strings[g_sort]
127
+ counts = counts[g_sort]
128
+
129
+ # Identify group boundaries and split locations
130
+ split_locs = np.nonzero(group_ids[1:] != group_ids[:-1])[0] + 1
131
+
132
+ # Get grouped strings as separate arrays
133
+ groups = np.split(strings,split_locs)
134
+
135
+ # Build the groupings
136
+ grouping = MatchGroups()
137
+ grouping.counts = Counter({s:int(c) for s,c in zip(strings,counts)})
138
+ grouping.labels = {s:g[-1] for g in groups for s in g}
139
+ grouping.groups = {g[-1]:list(g) for g in groups}
140
+
141
+ return grouping
142
+
143
+ @torch.no_grad()
144
+ def _fast_unite_similar(self,group_ids,threshold=0.5,progress_bar=True,batch_size=64):
145
+
146
+ V = self.V
147
+ cos_threshold = self.score_model.score_to_cos(threshold)
148
+
149
+ for batch_start in tqdm(range(0,len(self),batch_size),
150
+ delay=1,desc='Predicting matches',disable=not progress_bar):
151
+
152
+ i_slice = slice(batch_start,batch_start+batch_size)
153
+ j_slice = slice(batch_start+1,None)
154
+
155
+ g_i = group_ids[i_slice]
156
+ g_j = group_ids[j_slice]
157
+
158
+ # Find j's with jaccard > threshold ("matches")
159
+ batch_matched = (V[i_slice]@V[j_slice].T >= cos_threshold) \
160
+ * (g_i[:,None] != g_j[None,:])
161
+
162
+ for k,matched in enumerate(batch_matched):
163
+ if matched.any():
164
+ # Get the group ids of the matched j's
165
+ matched_groups = g_j[matched]
166
+
167
+ # Identify all embeddings in these groups
168
+ ids_to_group = torch.isin(group_ids,matched_groups)
169
+
170
+ # Assign all matched embeddings to the same group
171
+ group_ids[ids_to_group] = g_i[k].clone()
172
+
173
+ return self._ids_to_group(group_ids)
174
+
175
+ @torch.no_grad()
176
+ def unite_similar(self,
177
+ threshold=0.5,
178
+ group_threshold=None,
179
+ always_match=None,
180
+ never_match=None,
181
+ batch_size=64,
182
+ progress_bar=True,
183
+ always_never_conflicts='warn',
184
+ return_united=False):
185
+
186
+ """
187
+ Unite embedding strings according to predicted pairwise similarity.
188
+
189
+ - "theshold" sets the minimimum match similarity required to unite two strings.
190
+ - Note that strings with similarity<threshold can end up matched if they are
191
+ linked by a chain of sufficiently similar strings (matching is transitive).
192
+ "group_threshold" can be used to add an additional constraing on the minimum
193
+ similarity within each group.
194
+ - "group_threshold" sets the minimum similarity required within a single group.
195
+ - "always_match" takes any argument that can be used to unite strings. These
196
+ strings will always be matched.
197
+ - "never_match" takes a set, or a list of sets, where each set indicates two or
198
+ more strings that should never be united with each other (these strings may
199
+ still be united with other strings).
200
+ - "always_never_conflicts" determines how to handle conflicts between
201
+ "always_match" and "never_match":
202
+ - always_never_conflicts="warn": Check for conflicts and print a warning
203
+ if any are found (default)
204
+ - always_never_conflicts="raise": Check for conflicts and raise an error
205
+ if any are found
206
+ - always_never_conflicts="ignore": Do not check for conflicts ("always_match"
207
+ will take precedence)
208
+
209
+ If "group_threshold" or "never_match" arguments are supplied, strings pairs are
210
+ united in order of similarity. Highest similarity strings are matched first, and
211
+ before each time a new pair of strings is united, the function checks if this will
212
+ result in grouping any two strings with similarity<group_threshold. If so, this
213
+ pair is skipped. This version of the algorithm requires more memory and processing
214
+ time, but guaruntees deterministic output that is consistent with the constraints.
215
+
216
+ returns: MatchGroups object
217
+ """
218
+ if group_threshold and group_threshold < threshold:
219
+ raise ValueError('group_threshold must be greater than or equal to threshold')
220
+
221
+ group_ids = torch.arange(len(self)).to(self.device)
222
+
223
+ if always_match is not None:
224
+ always_grouping = (MatchGroups(self.strings)
225
+ .unite(always_match))
226
+ always_match_labels = always_grouping.labels
227
+
228
+
229
+ # Use a simpler, faster prediction algorithm if possible
230
+ if not (return_united or group_threshold or (never_match is not None)):
231
+ if always_match is not None:
232
+ group_ids = self._group_to_ids(always_grouping)
233
+
234
+ return self._fast_unite_similar(
235
+ group_ids=group_ids,
236
+ threshold=threshold,
237
+ batch_size=batch_size,
238
+ progress_bar=progress_bar)
239
+
240
+ if never_match is not None:
241
+ # Ensure never_match is a nested list
242
+ if all(isinstance(s,str) for s in never_match):
243
+ never_match = [never_match]
244
+
245
+ if always_match is not None:
246
+
247
+ assert always_never_conflicts in ['raise','warn','ignore']
248
+
249
+ if always_never_conflicts != 'ignore':
250
+
251
+ # Find conflicts between never_match and always_match groups
252
+ conflicts = []
253
+ for i,g in enumerate(never_match):
254
+ g = sorted(list(g))
255
+ g_labels = [always_match_labels.get(s,s) for s in g]
256
+ if len(set(g_labels)) < len(g):
257
+ df = (pd.DataFrame()
258
+ .assign(
259
+ string=g,
260
+ never_match_group=i,
261
+ always_match_group=g_labels
262
+ ))
263
+ conflicts.append(df)
264
+
265
+ if conflicts:
266
+ conflicts_df = pd.concat(conflicts)
267
+
268
+ if always_never_conflicts == 'warn':
269
+ print(f'Warning: The following never_match groups are in conflict with always_match groups:\n{conflicts_df}')
270
+ print('Conflicted never_match relationships will be ignored')
271
+ else:
272
+ raise ValueError(f'The following never_match groups are in conflict with always_match groups\n{conflicts_df}')
273
+
274
+
275
+ # If always_match, collapse to group labels that should not match
276
+ # Note: Implicitly letting always_match over-ride never_match here
277
+ never_match = [{always_match_labels[s] for s in g if s in always_match_labels} for g in never_match]
278
+
279
+ else:
280
+ # Otherwise just use the strings themselves as labels
281
+ never_match = [set(s) for s in never_match]
282
+
283
+ # Convert thresholds from scores to raw cosine distances
284
+ V = self.V
285
+ cos_threshold = self.score_model.score_to_cos(threshold)
286
+ if group_threshold is not None:
287
+ separate_cos = self.score_model.score_to_cos(group_threshold)
288
+
289
+ # First collect all pairs to match (can be memory intensive!)
290
+ matches = []
291
+ cos_scores = []
292
+ for batch_start in tqdm(range(0,len(self),batch_size),
293
+ desc='Scoring pairs',
294
+ delay=1,disable=not progress_bar):
295
+
296
+ i_slice = slice(batch_start,batch_start+batch_size)
297
+ j_slice = slice(batch_start+1,None)
298
+
299
+ # Find j's with jaccard > threshold ("matches")
300
+ batch_cos = V[i_slice]@V[j_slice].T
301
+
302
+ # Search upper diagonal entries only
303
+ # (note j_slice starting index is offset by one)
304
+ batch_cos = torch.triu(batch_cos)
305
+
306
+ bi,bj = torch.nonzero(batch_cos >= cos_threshold,as_tuple=True)
307
+
308
+ if len(bi):
309
+ # Convert batch index locations to global index locations
310
+ i = bi + batch_start
311
+ j = bj + batch_start + 1
312
+
313
+ cos = batch_cos[bi,bj]
314
+
315
+ # Can skip strings that are already matched in the base grouping
316
+ unmatched = group_ids[i] != group_ids[j]
317
+ i = i[unmatched]
318
+ j = j[unmatched]
319
+ cos = cos[unmatched]
320
+
321
+ if len(i):
322
+ batch_matches = torch.hstack([i[:,None],j[:,None]])
323
+
324
+ matches.append(batch_matches.to('cpu').numpy())
325
+ cos_scores.append(cos.to('cpu').numpy())
326
+
327
+ # Unite potential match pairs in priority order, while respecting
328
+ # the group_threshold and never_match arguments
329
+ united = []
330
+ if matches:
331
+ matches = np.vstack(matches)
332
+ cos_scores = np.hstack(cos_scores).T
333
+
334
+ # Sort matches in descending order of score
335
+ m_sort = cos_scores.argsort()[::-1]
336
+ matches = matches[m_sort]
337
+
338
+ if return_united:
339
+ # Save cos scores for later return
340
+ cos_scores_df = pd.DataFrame(matches,columns=['i','j'])
341
+ cos_scores_df['cos'] = cos_scores[m_sort]
342
+
343
+ # Set up tensors
344
+ matches = torch.tensor(matches).to(self.device)
345
+
346
+ # Set-up per-string tracking of never-match relationships
347
+ if never_match is not None:
348
+ never_match_map = {s:sep for sep in never_match for s in sep}
349
+
350
+ if always_match is not None:
351
+ # If always_match, we use group labels instead of the strings themselves
352
+ never_match_array = np.array([never_match_map.get(always_match_labels[s],set()) for s in self.strings])
353
+ else:
354
+ never_match_array = np.array([never_match_map.get(s,set()) for s in self.strings])
355
+
356
+
357
+ n_matches = matches.shape[0]
358
+ with tqdm(total=n_matches,desc='Uniting matches',
359
+ delay=1,disable=not progress_bar) as p_bar:
360
+
361
+ while len(matches):
362
+
363
+ # Select the current match pair and remove it from the queue
364
+ match_pair = matches[0]
365
+ matches = matches[1:]
366
+
367
+ # Get the groups of the current match pair
368
+ g = group_ids[match_pair]
369
+ g0 = group_ids == g[0]
370
+ g1 = group_ids == g[1]
371
+
372
+ # Identify which strings should be united
373
+ to_unite = g0 | g1
374
+
375
+ # Flag whether the new group will have three or more strings
376
+ singletons = to_unite.sum() < 3
377
+
378
+ # Start by asuming that we can match this pair
379
+ unite_ok = True
380
+
381
+ # Check whether uniting this pair will unite any never_match strings/labels
382
+ if never_match is not None:
383
+ never_0 = never_match_array[match_pair[0]]
384
+ never_1 = never_match_array[match_pair[1]]
385
+
386
+ if never_0 and never_1 and (never_0 & never_1):
387
+ # Here we make use of the fact that any pair of never_match strings/labels
388
+ # will appear in both never_0 and never_1 if one string/label is in each group
389
+ unite_ok = False
390
+
391
+ # Check whether the uniting the pair will violate the group_threshold
392
+ # (impossible if the strings are singletons)
393
+ if unite_ok and group_threshold and not singletons:
394
+ V0 = V[g0,:]
395
+ V1 = V[g1,:]
396
+
397
+ unite_ok = (V0@V1.T).min() >= separate_cos
398
+
399
+
400
+ if unite_ok:
401
+
402
+ # Unite groups
403
+ group_ids[to_unite] = g[0]
404
+
405
+ if never_match and (never_0 or never_1):
406
+ # Propagate never_match information to the whole group
407
+ never_match_array[to_unite.detach().cpu().numpy()] = never_0 | never_1
408
+
409
+ # If we are uniting more than two strings, we can eliminate
410
+ # some redundant matches in the queue
411
+ if not singletons:
412
+ # Removed queued matches that are now in the same group
413
+ matches = matches[group_ids[matches[:,0]] != group_ids[matches[:,1]]]
414
+
415
+ if return_united:
416
+ match_record = np.empty(4,dtype=int)
417
+ match_record[:2] = match_pair.cpu().numpy().ravel()
418
+ match_record[2] = self.counts[g0].sum().item()
419
+ match_record[3] = self.counts[g1].sum().item()
420
+
421
+ united.append(match_record)
422
+ else:
423
+ # Remove queued matches connecting these groups
424
+ matches = matches[torch.isin(group_ids[matches[:,0]],g,invert=True) \
425
+ | torch.isin(group_ids[matches[:,1]],g,invert=True)]
426
+
427
+ # Update progress bar
428
+ p_bar.update(n_matches - matches.shape[0])
429
+ n_matches = matches.shape[0]
430
+
431
+ predicted_grouping = self.ids_to_group(group_ids)
432
+
433
+ if always_match is not None:
434
+ predicted_grouping = predicted_grouping.unite(always_grouping)
435
+
436
+ if return_united:
437
+ united_df = pd.DataFrame(np.vstack(united),columns=['i','j','n_i','n_j'])
438
+ united_df = pd.merge(united_df,cos_scores_df,how='inner',on=['i','j'])
439
+ united_df['score'] = self.score_model(
440
+ torch.tensor(united_df['cos'].values).to(self.device)
441
+ ).cpu().numpy()
442
+
443
+ united_df = united_df.drop('cos',axis=1)
444
+
445
+ for c in ['i','j']:
446
+ united_df[c] = [self.strings[i] for i in united_df[c]]
447
+
448
+ if always_match is not None:
449
+ united_df['always_match'] = [always_grouping[i] == always_grouping[j]
450
+ for i,j in united_df[['i','j']].values]
451
+
452
+ return predicted_grouping,united_df
453
+
454
+ else:
455
+
456
+ return predicted_grouping
457
+
458
+ @torch.no_grad()
459
+ def unite_nearest(self,target_strings,threshold=0,always_grouping=None,progress_bar=True,batch_size=64):
460
+ """
461
+ Unite embedding strings with each string's most similar target string.
462
+
463
+ - "always_grouping" will be used to inialize the group_ids before uniting new matches
464
+ - "theshold" sets the minimimum match similarity required between a string and target string
465
+ for the string to be matched. (i.e., setting theshold=0 will result in every embedding
466
+ string to be matched its nearest target string, while setting threshold=0.9 will leave
467
+ strings that have similarity<0.9 with their nearest target string unaffected)
468
+
469
+ returns: MatchGroups object
470
+ """
471
+
472
+ if always_grouping is not None:
473
+ # self = self.embed(always_grouping)
474
+ group_ids = self._group_to_ids(always_grouping)
475
+ else:
476
+ group_ids = torch.arange(len(self)).to(self.device)
477
+
478
+ V = self.V
479
+ cos_threshold = self.score_model.score_to_cos(threshold)
480
+
481
+ seed_ids = torch.tensor([self.string_map[s] for s in target_strings]).to(self.device)
482
+ V_seed = V[seed_ids]
483
+ g_seed = group_ids[seed_ids]
484
+ is_seed = torch.zeros(V.shape[0],dtype=torch.bool).to(self.device)
485
+ is_seed[g_seed] = True
486
+
487
+ for batch_start in tqdm(range(0,len(self),batch_size),
488
+ delay=1,desc='Predicting matches',disable=not progress_bar):
489
+
490
+ batch_slice = slice(batch_start,batch_start+batch_size)
491
+
492
+ batch_cos = V[batch_slice]@V_seed.T
493
+
494
+ max_cos,max_seed = torch.max(batch_cos,dim=1)
495
+
496
+ # Get batch index locations where score > threshold
497
+ batch_i = torch.nonzero(max_cos > cos_threshold)
498
+
499
+ if len(batch_i):
500
+ # Drop target strings from matches (otherwise numerical precision
501
+ # issues can allow target strings to match to other strings)
502
+ batch_i = batch_i[~is_seed[batch_slice][batch_i]]
503
+
504
+ if len(batch_i):
505
+ # Get indices of matched strings
506
+ i = batch_i + batch_start
507
+
508
+ # Assign matched strings to the target string's group
509
+ group_ids[i] = g_seed[max_seed[batch_i]]
510
+
511
+ return self._ids_to_group(group_ids)
512
+
513
+ @torch.no_grad()
514
+ def score_pairs(self,string_pairs,batch_size=64,progress_bar=True):
515
+ string_pairs = np.array(string_pairs)
516
+
517
+ scores = []
518
+ for batch_start in tqdm(range(0,string_pairs.shape[0],batch_size),desc='Scoring pairs',disable=not progress_bar):
519
+
520
+ V0 = self[string_pairs[batch_start:batch_start+batch_size,0]].V
521
+ V1 = self[string_pairs[batch_start:batch_start+batch_size,1]].V
522
+
523
+ batch_cos = (V0*V1).sum(dim=1).ravel()
524
+ batch_scores = self.score_model(batch_cos)
525
+
526
+ scores.append(batch_scores.cpu().numpy())
527
+
528
+ return np.concatenate(scores)
529
+
530
+ @torch.no_grad()
531
+ def _batch_scores(self,group_ids,batch_start,batch_size,
532
+ is_match=None,
533
+ min_score=None,max_score=None,
534
+ min_loss=None,max_loss=None):
535
+
536
+ strings = self.strings
537
+ V = self.V
538
+ w = self.w
539
+
540
+ # Create simple slice objects to avoid creating copies with advanced indexing
541
+ i_slice = slice(batch_start,batch_start+batch_size)
542
+ j_slice = slice(batch_start+1,None)
543
+
544
+ X = V[i_slice]@V[j_slice].T
545
+ Y = (group_ids[i_slice,None] == group_ids[None,j_slice]).float()
546
+ if w is not None:
547
+ W = w[i_slice,None]*w[None,j_slice]
548
+ else:
549
+ W = None
550
+
551
+ scores = self.score_model(X)
552
+ loss = self.score_model.loss(X,Y,weights=W)
553
+
554
+ # Search upper diagonal entries only
555
+ # (note j_slice starting index is offset by one)
556
+ scores = torch.triu(scores)
557
+
558
+ # Filter by match type
559
+ if is_match is not None:
560
+ if is_match:
561
+ scores *= Y
562
+ else:
563
+ scores *= (1 - Y)
564
+
565
+ # Filter by min score
566
+ if min_score is not None:
567
+ scores *= (scores >= min_score)
568
+
569
+ # Filter by max score
570
+ if max_score is not None:
571
+ scores *= (scores <= max_score)
572
+
573
+ # Filter by min loss
574
+ if min_loss is not None:
575
+ scores *= (loss >= min_loss)
576
+
577
+ # Filter by max loss
578
+ if max_loss is not None:
579
+ scores *= (loss <= max_loss)
580
+
581
+ # Collect scored pairs
582
+ i,j = torch.nonzero(scores,as_tuple=True)
583
+
584
+ pairs = np.hstack([
585
+ strings[i.cpu().numpy() + batch_start][:,None],
586
+ strings[j.cpu().numpy() + (batch_start + 1)][:,None]
587
+ ])
588
+
589
+ pair_groups = np.hstack([
590
+ strings[group_ids[i + batch_start].cpu().numpy()][:,None],
591
+ strings[group_ids[j + (batch_start + 1)].cpu().numpy()][:,None]
592
+ ])
593
+
594
+ pair_scores = scores[i,j].cpu().numpy()
595
+ pair_losses = loss[i,j].cpu().numpy()
596
+
597
+ return pairs,pair_groups,pair_scores,pair_losses
598
+
599
+ def iter_scores(self,grouping=None,batch_size=64,progress_bar=True,**kwargs):
600
+
601
+ if grouping is not None:
602
+ self = self.embed(grouping)
603
+ group_ids = self._group_to_ids(grouping)
604
+ else:
605
+ group_ids = torch.arange(len(self)).to(self.device)
606
+
607
+ for batch_start in tqdm(range(0,len(self),batch_size),desc='Scoring pairs',disable=not progress_bar):
608
+ pairs,pair_groups,scores,losses = self._batch_scored_pairs(self,group_ids,batch_start,batch_size,**kwargs)
609
+ for (s0,s1),(g0,g1),score,loss in zip(pairs,pair_groups,scores,losses):
610
+ yield {
611
+ 'string0':s0,
612
+ 'string1':s1,
613
+ 'group0':g0,
614
+ 'group1':g1,
615
+ 'score':score,
616
+ 'loss':loss,
617
+ }
618
+
619
+
620
+ def load_embeddings(f):
621
+ """
622
+ Load embeddings from custom zipped archive format
623
+ """
624
+ with ZipFile(f,'r') as zip:
625
+ score_model = pickle.loads(zip.read('score_model.pkl'))
626
+ weighting_function = pickle.loads(zip.read('weighting_function.pkl'))
627
+ strings_df = pd.read_csv(zip.open('strings.csv'),na_filter=False)
628
+ V = np.load(zip.open('V.npy'))
629
+
630
+ return Embeddings(
631
+ strings=strings_df['string'].values,
632
+ counts=torch.tensor(strings_df['count'].values),
633
+ score_model=score_model,
634
+ weighting_function=weighting_function,
635
+ V=torch.tensor(V)
636
+ )
637
+
match_groups.py ADDED
@@ -0,0 +1,870 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ from collections import Counter, defaultdict
4
+ from itertools import islice
5
+ import pandas as pd
6
+ import numpy as np
7
+ import networkx as nx
8
+ import matplotlib.pyplot as plt
9
+ import matplotlib as mplt
10
+
11
+ MAX_STR = 50
12
+
13
+
14
+ class MatchGroups():
15
+ """A class for grouping strings based on set membership. Supports splitting and uniting of groups."""
16
+
17
+ def __init__(self, strings=None):
18
+ """
19
+ Initialize MatchGroups object.
20
+
21
+ Parameters
22
+ ----------
23
+ strings : list, optional
24
+ List of strings to add to the match groups object, by default None
25
+ """
26
+ self.counts = Counter()
27
+ self.labels = {}
28
+ self.groups = {}
29
+
30
+ if strings is not None:
31
+ self.add_strings(strings, inplace=True)
32
+
33
+ def __len__(self):
34
+ """Return the number of strings in the match groups object."""
35
+ return len(self.labels)
36
+
37
+ def __repr__(self):
38
+ """Return a string representation of the MatchGroups object."""
39
+ return f'<nama.MatchGroups containing {len(self)} strings in {len(self.groups)} groups>'
40
+
41
+ def __str__(self):
42
+ """Return a string representation of the groups of a MatchGroups object."""
43
+ output = self.__repr__()
44
+ remaining = MAX_STR
45
+ for group in self.groups.values():
46
+ for s in group:
47
+ if remaining:
48
+ output += '\n' + s
49
+ remaining -= 1
50
+ else:
51
+ output += f'...\n(Output truncated at {MAX_STR} strings)'
52
+ return output
53
+
54
+ output += '\n'
55
+
56
+ return output
57
+
58
+ def __contains__(self, s):
59
+ """Return True if string is in the match groups object, False otherwise."""
60
+ return s in self.labels
61
+
62
+ def __getitem__(self, strings):
63
+ """Return the group label for a single string or a list of strings."""
64
+ if isinstance(strings, str):
65
+ return self.labels[strings]
66
+ else:
67
+ return [self.labels[s] for s in strings]
68
+
69
+ def __add__(self, match_obj):
70
+ """Add two match groups objects together and return the result."""
71
+ result = self.add_strings(match_obj)
72
+ result.unite(match_obj, inplace=True)
73
+
74
+ return result
75
+
76
+ def items(self):
77
+ """Return an iterator of strings and their group labels."""
78
+ for i, g in self.labels.items():
79
+ yield i, g
80
+
81
+ def copy(self):
82
+ """Return a copy of the MatchGroups object."""
83
+ new_match_obj = MatchGroups()
84
+ new_match_obj.counts = self.counts.copy()
85
+ new_match_obj.labels = self.labels.copy()
86
+ new_match_obj.groups = self.groups.copy()
87
+
88
+ return new_match_obj
89
+
90
+ def strings(self):
91
+ """Return a list of strings in the match groups object. Order is not guaranteed."""
92
+ return list(self.labels.keys())
93
+
94
+ def matches(self, string):
95
+ """Return the group of strings that match the given string."""
96
+ return self.groups[self.labels[string]]
97
+
98
+ def add_strings(self, arg, inplace=False):
99
+ """Add new strings to the match groups object.
100
+
101
+ Parameters
102
+ ----------
103
+ arg : str, Counter, MatchGroups, Iterable
104
+ String or group of strings to add to the match groups object
105
+ inplace : bool, optional
106
+ If True, add strings to the existing MatchGroups object, by default False
107
+
108
+ Returns
109
+ -------
110
+ MatchGroups
111
+ The updated MatchGroups object
112
+ """
113
+ if isinstance(arg, str):
114
+ counts = {arg: 1}
115
+
116
+ elif isinstance(arg, Counter):
117
+ counts = arg
118
+
119
+ elif isinstance(arg, MatchGroups):
120
+ counts = arg.counts
121
+
122
+ elif hasattr(arg, '__next__') or hasattr(arg, '__iter__'):
123
+ counts = Counter(arg)
124
+
125
+ if not inplace:
126
+ self = self.copy()
127
+
128
+ for s in counts.keys():
129
+ if s not in self.labels:
130
+ self.labels[s] = s
131
+ self.groups[s] = [s]
132
+
133
+ self.counts += counts
134
+
135
+ return self
136
+
137
+ def drop(self, strings, inplace=False):
138
+ """Remove strings from the match groups object.
139
+
140
+ Parameters
141
+ ----------
142
+ strings : list or str
143
+ String or list of strings to remove from the match groups object
144
+ inplace : bool, optional
145
+ If True, remove strings from the existing MatchGroups object, by default False
146
+
147
+ Returns
148
+ -------
149
+ MatchGroups
150
+ The updated MatchGroups object
151
+ """
152
+ if isinstance(strings, str):
153
+ strings = [strings]
154
+
155
+ strings = set(strings)
156
+
157
+ if not inplace:
158
+ self = self.copy()
159
+
160
+ # Remove strings from their groups
161
+ affected_group_labels = {self[s] for s in strings}
162
+ for old_label in affected_group_labels:
163
+ old_group = self.groups[old_label]
164
+ new_group = [s for s in old_group if s not in strings]
165
+
166
+ if new_group:
167
+ counts = self.counts
168
+ new_label = min((-counts[s], s) for s in new_group)[1]
169
+
170
+ if new_label != old_label:
171
+ del self.groups[old_label]
172
+
173
+ self.groups[new_label] = new_group
174
+
175
+ for s in new_group:
176
+ self.labels[s] = new_label
177
+ else:
178
+ del self.groups[old_label]
179
+
180
+ # Remove strings from counts and labels
181
+ for s in strings:
182
+ del self.counts[s]
183
+ del self.labels[s]
184
+
185
+ return self
186
+
187
+ def keep(self, strings, inplace=False):
188
+ """Drop all strings from the match groups object except the passed strings.
189
+
190
+ Parameters
191
+ ----------
192
+ strings : list
193
+ List of strings to keep in the match groups object
194
+ inplace : bool, optional
195
+ If True, drop strings from the existing MatchGroups object, by default False
196
+
197
+ Returns
198
+ -------
199
+ MatchGroups
200
+ The updated MatchGroups object
201
+ """
202
+ strings = set(strings)
203
+
204
+ to_drop = [s for s in self.strings() if s not in strings]
205
+
206
+ return self.drop(to_drop, inplace=inplace)
207
+
208
+ def _unite_strings(self, strings):
209
+ """
210
+ Unite strings in the match groups object without checking argument type.
211
+ Intended as a low-level function called by self.unite()
212
+
213
+ Parameters
214
+ ----------
215
+ strings : list
216
+ List of strings to unite in the match groups object
217
+
218
+ Returns
219
+ -------
220
+ None
221
+ """
222
+ strings = {s for s in strings if s in self.labels}
223
+
224
+ if len(strings) > 1:
225
+
226
+ # Identify groups that will be united
227
+ old_labels = set(self[strings])
228
+
229
+ # Only need to do the merge if the strings span multiple groups
230
+ if len(old_labels) > 1:
231
+
232
+ # Identify the new group label
233
+ counts = self.counts
234
+ new_label = min((-counts[s], s) for s in old_labels)[1]
235
+
236
+ # Identify the groups which need to be modified
237
+ old_labels.remove(new_label)
238
+
239
+ for old_label in old_labels:
240
+ # Update the string group labels
241
+ for s in self.groups[old_label]:
242
+ self.labels[s] = new_label
243
+
244
+ # Update group dict
245
+ self.groups[new_label] = self.groups[new_label] + \
246
+ self.groups[old_label]
247
+ del self.groups[old_label]
248
+
249
+ def unite(self, arg, inplace=False, **kwargs):
250
+ """
251
+ Merge groups containing the passed strings. Groups can be passed as:
252
+ - A list of strings to unite
253
+ - A nested list to unite each set of strings
254
+ - A dictionary mapping strings to labels to unite by label
255
+ - A function mapping strings to labels to unite by label
256
+ - A MatchGroups instance to unite by MatchGroups groups
257
+
258
+ Parameters
259
+ ----------
260
+ arg : list, dict, function or MatchGroups instance
261
+ Argument representing the strings or labels to merge.
262
+ inplace : bool, optional
263
+ Whether to perform the operation in place or return a new MatchGroups.
264
+ kwargs : dict, optional
265
+ Additional arguments to be passed to predict_groupings method if arg
266
+ is a similarity model with a predict_groupings method.
267
+
268
+ Returns
269
+ -------
270
+ MatchGroups
271
+ The updated MatchGroups object. If `inplace` is True, the updated object
272
+ is returned, else a new MatchGroups object with the updates is returned.
273
+ """
274
+
275
+ if not inplace:
276
+ self = self.copy()
277
+
278
+ if isinstance(arg, str):
279
+ raise ValueError('Cannot unite a single string')
280
+
281
+ elif isinstance(arg, MatchGroups):
282
+ self.unite(arg.groups.values(), inplace=True)
283
+
284
+ elif hasattr(arg, 'unite_similar'):
285
+ # Unite can accept a similarity model if it has a unite_similar
286
+ # method
287
+ self.unite(arg.unite_similar(self, **kwargs))
288
+
289
+ elif callable(arg):
290
+ # Assume arg is a mapping from strings to labels and unite by label
291
+ groups = {s: arg(s) for s in self.strings()}
292
+ self.unite(groups, inplace=True)
293
+
294
+ elif isinstance(arg, dict):
295
+ # Assume arg is a mapping from strings to labels and unite by label
296
+ # groups = {label:[] for label in arg.values()}
297
+ groups = defaultdict(list)
298
+ for string, label in arg.items():
299
+ groups[label].append(string)
300
+
301
+ for group in groups.values():
302
+ self._unite_strings(group)
303
+
304
+ elif hasattr(arg, '__next__'):
305
+ # Assume arg is an iterator of groups to unite
306
+ # (This needs to be checked early to avoid consuming the first group)
307
+ for group in arg:
308
+ self._unite_strings(group)
309
+
310
+ elif all(isinstance(s, str) for s in arg):
311
+ # Main case: Unite group of strings
312
+ self._unite_strings(arg)
313
+
314
+ elif hasattr(arg, '__iter__'):
315
+ # Assume arg is an iterable of groups to unite
316
+ for group in arg:
317
+ self._unite_strings(group)
318
+
319
+ else:
320
+ raise ValueError('Unknown input type')
321
+
322
+ if not inplace:
323
+ return self
324
+
325
+ def split(self, strings, inplace=False):
326
+ """
327
+ Split strings into singleton groups. Strings can be passed as:
328
+ - A single string to isolate into a singleton group
329
+ - A list or iterator of strings to split
330
+
331
+ Parameters
332
+ ----------
333
+ strings : str or list of str
334
+ The string(s) to split into singleton groups.
335
+ inplace : bool, optional
336
+ Whether to perform the operation in place or return a new MatchGroups.
337
+
338
+ Returns
339
+ -------
340
+ MatchGroups
341
+ The updated MatchGroups object. If `inplace` is True, the updated object
342
+ is returned, else a new MatchGroups object with the updates is returned.
343
+ """
344
+ if not inplace:
345
+ self = self.copy()
346
+
347
+ if isinstance(strings, str):
348
+ strings = [strings]
349
+
350
+ strings = set(strings)
351
+
352
+ # Remove strings from their groups
353
+ affected_group_labels = {self[s] for s in strings}
354
+ for old_label in affected_group_labels:
355
+ old_group = self.groups[old_label]
356
+ if len(old_group) > 1:
357
+ new_group = [s for s in old_group if s not in strings]
358
+ if new_group:
359
+ counts = self.counts
360
+ new_label = min((-counts[s], s) for s in new_group)[1]
361
+
362
+ if new_label != old_label:
363
+ del self.groups[old_label]
364
+
365
+ self.groups[new_label] = new_group
366
+
367
+ for s in new_group:
368
+ self.labels[s] = new_label
369
+
370
+ # Update labels and add singleton groups
371
+ for s in strings:
372
+ self.labels[s] = s
373
+ self.groups[s] = [s]
374
+
375
+ return self
376
+
377
+ def split_all(self, inplace=False):
378
+ """
379
+ Split all strings into singleton groups.
380
+
381
+ Parameters
382
+ ----------
383
+ inplace : bool, optional
384
+ Whether to perform the operation in place or return a new MatchGroups.
385
+
386
+ Returns
387
+ -------
388
+ MatchGroups
389
+ The updated MatchGroups object. If `inplace` is True, the updated object
390
+ is returned, else a new MatchGroups object with the updates is returned.
391
+ """
392
+ if not inplace:
393
+ self = self.copy()
394
+
395
+ self.labels = {s: s for s in self.strings()}
396
+ self.groups = {s: [s] for s in self.strings()}
397
+
398
+ return self
399
+
400
+ def separate(
401
+ self,
402
+ strings,
403
+ similarity_model,
404
+ inplace=False,
405
+ threshold=0,
406
+ **kwargs):
407
+ """
408
+ Separate the strings in according to the prediction of the similarity_model.
409
+
410
+ Parameters
411
+ ----------
412
+ strings: list
413
+ List of strings to be separated.
414
+ similarity_model: Model
415
+ Model used to predict similarity between strings.
416
+ inplace: bool, optional
417
+ If True, the separation operation is performed in-place. Otherwise, a copy is created.
418
+ threshold: float, optional
419
+ Threshold value for prediction.
420
+ kwargs: dict, optional
421
+ Additional keyword arguments passed to the prediction function.
422
+
423
+ Returns
424
+ -------
425
+ self: MatchGroups
426
+ Returns the MatchGroups object after the separation operation.
427
+
428
+ """
429
+ if not inplace:
430
+ self = self.copy()
431
+
432
+ # Identify which groups contain the strings to separate
433
+ group_map = defaultdict(list)
434
+ for s in set(strings):
435
+ group_map[self[s]].append(s)
436
+
437
+ for g, g_sep in group_map.items():
438
+
439
+ # If group contains strings to separate...
440
+ if len(g_sep) > 1:
441
+ group_strings = self.groups[g]
442
+
443
+ # Split the group strings
444
+ self.split(group_strings, inplace=True)
445
+
446
+ # Re-unite with new prediction that enforces separation
447
+ try:
448
+ embeddings = similarity_model[group_strings]
449
+ except Exception as e:
450
+ print(f'{g=} {g_sep} {group_strings}')
451
+ raise e
452
+ predicted = embeddings.predict(
453
+ threshold=threshold,
454
+ separate_strings=strings,
455
+ **kwargs)
456
+
457
+ self.unite(predicted, inplace=True)
458
+
459
+ return self
460
+
461
+ # def refine(self,similarity_model)
462
+
463
+ def top_scored_pairs_df(self, similarity_model,
464
+ n=10000, buffer_n=100000,
465
+ by_group=True,
466
+ sort_by=['impact', 'score'], ascending=False,
467
+ skip_pairs=None, **kwargs):
468
+ """
469
+ Return the DataFrame containing the n most important pairs of strings, according to the score generated by the `similarity_model`.
470
+
471
+ Parameters
472
+ ----------
473
+ similarity_model: Model
474
+ Model used to predict similarity between strings.
475
+ n: int, optional
476
+ Number of most important pairs to return. Default is 10000.
477
+ buffer_n: int, optional
478
+ Size of buffer to iterate through the scored pairs. Default is 100000.
479
+ by_group: bool, optional
480
+ If True, only the most important pair will be returned for each unique group combination.
481
+ sort_by: list, optional
482
+ A list of column names by which to sort the dataframe. Default is ['impact','score'].
483
+ ascending: bool, optional
484
+ Whether the sort order should be ascending or descending. Default is False.
485
+ skip_pairs: list, optional
486
+ List of string pairs to ignore when constructing the ranking.
487
+ If by_group=True, any group combination represented in the skip_pairs list will be ignored
488
+ kwargs: dict, optional
489
+ Additional keyword arguments passed to the `iter_scored_pairs` function.
490
+
491
+ Returns
492
+ -------
493
+ top_df: pandas.DataFrame
494
+ The DataFrame containing the n most important pairs of strings.
495
+
496
+ """
497
+
498
+ top_df = pd.DataFrame(
499
+ columns=[
500
+ 'string0',
501
+ 'string1',
502
+ 'group0',
503
+ 'group1',
504
+ 'impact',
505
+ 'score',
506
+ 'loss'])
507
+ pair_iterator = similarity_model.iter_scored_pairs(self, **kwargs)
508
+
509
+ def group_size(g):
510
+ return len(self.groups[g])
511
+
512
+ if skip_pairs is not None:
513
+ if by_group:
514
+ skip_pairs = {tuple(sorted([self[s0], self[s1]]))
515
+ for s0, s1 in skip_pairs}
516
+ else:
517
+ skip_pairs = {tuple(sorted([s0, s1])) for s0, s1 in skip_pairs}
518
+
519
+ while True:
520
+ df = pd.DataFrame(islice(pair_iterator, buffer_n))
521
+
522
+ if len(df):
523
+ for i in 0, 1:
524
+ df[f'group{i}'] = [self[s] for s in df[f'string{i}']]
525
+ df['impact'] = df['group0'].apply(
526
+ group_size) * df['group1'].apply(group_size)
527
+
528
+ if by_group:
529
+ df['group_pair'] = [tuple(sorted([g0, g1])) for g0, g1 in df[[
530
+ 'group0', 'group1']].values]
531
+
532
+ if skip_pairs:
533
+ if by_group:
534
+ df = df[~df['group_pair'].isin(skip_pairs)]
535
+ else:
536
+ string_pairs = [tuple(sorted([s0, s1])) for s0, s1 in df[[
537
+ 'string0', 'string1']].values]
538
+ df = df[~string_pairs.isin(skip_pairs)]
539
+
540
+ if len(df):
541
+ top_df = pd.concat([top_df, df]) \
542
+ .sort_values(sort_by, ascending=ascending)
543
+
544
+ if by_group:
545
+ top_df = top_df \
546
+ .groupby('group_pair') \
547
+ .first() \
548
+ .reset_index()
549
+
550
+ top_df = top_df \
551
+ .sort_values(sort_by, ascending=ascending) \
552
+ .head(n)
553
+ else:
554
+ break
555
+
556
+ if len(top_df) and by_group:
557
+ top_df = top_df \
558
+ .drop('group_pair', axis=1) \
559
+ .reset_index()
560
+
561
+ return top_df
562
+
563
+ def reset_counts(self, inplace=False):
564
+ """
565
+ Reset the counts of strings in the MatchGroups object.
566
+
567
+ Parameters
568
+ ----------
569
+ inplace: bool, optional
570
+ If True, the operation is performed in-place. Otherwise, a copy is created.
571
+
572
+ Returns
573
+ -------
574
+ self: MatchGroups
575
+ Returns the MatchGroups object after the reset operation.
576
+
577
+ """
578
+ if not inplace:
579
+ self = self.copy()
580
+
581
+ self.counts = Counter(self.strings())
582
+
583
+ return self
584
+
585
+ def to_df(self, singletons=True, sort_groups=True):
586
+ """
587
+ Convert the match groups object to a dataframe with string, count and group columns.
588
+
589
+ Parameters
590
+ ----------
591
+ singletons: bool, optional
592
+ If True, the resulting DataFrame will include singleton groups. Default is True.
593
+ ...
594
+
595
+ Returns
596
+ -------
597
+ df: pandas.DataFrame
598
+ The resulting DataFrame.
599
+ """
600
+ strings = self.strings()
601
+
602
+ if singletons:
603
+ df = pd.DataFrame([(s, self.counts[s], self.labels[s]) for s in strings],
604
+ columns=['string', 'count', 'group'])
605
+ else:
606
+ df = pd.DataFrame([(s, self.counts[s], self.labels[s]) for s in strings
607
+ if len(self.groups[self[s]]) > 1],
608
+ columns=['string', 'count', 'group'])
609
+ if sort_groups:
610
+ df['group_count'] = df.groupby('group')['count'].transform('sum')
611
+ df = df.sort_values(['group_count', 'group', 'count', 'string'], ascending=[
612
+ False, True, False, True])
613
+ df = df.drop('group_count', axis=1)
614
+ df = df.reset_index(drop=True)
615
+
616
+ return df
617
+
618
+ def to_csv(self, filename, singletons=True, **pandas_args):
619
+ """
620
+ Save the match groups object as a csv file with string, count and group columns.
621
+
622
+ Parameters
623
+ ----------
624
+ filename : str
625
+ Path to file to save the data.
626
+ singletons : bool, optional
627
+ If True, include singleton groups in the saved file, by default True.
628
+ pandas_args : dict
629
+ Additional keyword arguments to pass to the pandas.DataFrame.to_csv method.
630
+ """
631
+ df = self.to_df(singletons=singletons)
632
+ df.to_csv(filename, index=False, **pandas_args)
633
+
634
+ def merge_dfs(self, left_df, right_df, how='inner',
635
+ on=None, left_on=None, right_on=None,
636
+ group_column_name='match_group', suffixes=('_x', '_y'),
637
+ **merge_args):
638
+ """
639
+ Replicated pandas.merge() functionality, except that dataframes are merged by match group instead of directly on the strings in the "on" columns.
640
+
641
+ Parameters
642
+ ----------
643
+ left_df : pandas.DataFrame
644
+ The left dataframe to merge.
645
+ right_df : pandas.DataFrame
646
+ The right dataframe to merge.
647
+ how : str, optional
648
+ How to merge the dataframes. Possible values are 'left', 'right', 'outer', 'inner', by default 'inner'.
649
+ on : str, optional
650
+ Columns in both left and right dataframes to merge on.
651
+ left_on : str, optional
652
+ Columns in the left dataframe to merge on.
653
+ right_on : str, optional
654
+ Columns in the right dataframe to merge on.
655
+ group_column_name : str, optional
656
+ Column name for the merged match group, by default 'match_group'.
657
+ suffixes : tuple of str, optional
658
+ Suffix to apply to overlapping column names in the left and right dataframes, by default ('_x','_y').
659
+ **merge_args : dict
660
+ Additional keyword arguments to pass to the pandas.DataFrame.merge method.
661
+
662
+ Returns
663
+ -------
664
+ pandas.DataFrame
665
+ The merged dataframe.
666
+
667
+ Raises
668
+ ------
669
+ ValueError
670
+ If 'on', 'left_on', and 'right_on' are all None.
671
+ ValueError
672
+ If `group_column_name` already exists in one of the dataframes.
673
+ """
674
+
675
+ if ((left_on is None) or (right_on is None)) and (on is None):
676
+ raise ValueError('Must provide column(s) to merge on')
677
+
678
+ left_df = left_df.copy()
679
+ right_df = right_df.copy()
680
+
681
+ if on is not None:
682
+ left_on = on + suffixes[0]
683
+ right_on = on + suffixes[1]
684
+
685
+ left_df = left_df.rename(columns={on:left_on})
686
+ right_df = right_df.rename(columns={on:right_on})
687
+
688
+ group_map = lambda s: self[s] if s in self.labels else np.nan
689
+
690
+ left_group = left_df[left_on].apply(group_map)
691
+ right_group = right_df[right_on].apply(group_map)
692
+
693
+ if group_column_name:
694
+ if group_column_name in list(left_df.columns) + list(right_df.columns):
695
+ raise ValueError('f{group_column_name=} already exists in one of the dataframes.')
696
+ else:
697
+ left_df[group_column_name] = left_group
698
+
699
+ merged_df = pd.merge(left_df,right_df,left_on=left_group,right_on=right_group,how=how,suffixes=suffixes,**merge_args)
700
+
701
+ merged_df = merged_df[[c for c in merged_df.columns if c in list(left_df.columns) + list(right_df.columns)]]
702
+
703
+ return merged_df
704
+
705
+
706
+ def from_df(
707
+ df,
708
+ match_format='detect',
709
+ pair_columns=[
710
+ 'string0',
711
+ 'string1'],
712
+ string_column='string',
713
+ group_column='group',
714
+ count_column='count'):
715
+ """
716
+ Construct a new match groups object from a pandas DataFrame.
717
+
718
+ Parameters
719
+ ----------
720
+ df : pandas.DataFrame
721
+ The input dataframe.
722
+ match_format : str, optional
723
+ The format of the dataframe, by default "detect".
724
+ It can be one of ['unmatched', 'groups', 'pairs', 'detect'].
725
+ pair_columns : list of str, optional
726
+ The columns names containing the string pairs, by default ['string0','string1'].
727
+ string_column : str, optional
728
+ The column name containing the strings, by default 'string'.
729
+ group_column : str, optional
730
+ The column name containing the groups, by default 'group'.
731
+ count_column : str, optional
732
+ The column name containing the counts, by default 'count'.
733
+
734
+ Returns
735
+ -------
736
+ MatchGroups
737
+ The constructed MatchGroups object.
738
+
739
+ Raises
740
+ ------
741
+ ValueError
742
+ If the input `match_format` is not one of ['unmatched', 'groups', 'pairs', 'detect'].
743
+ ValueError
744
+ If the `match_format` is 'detect' and the input dataframe format could not be inferred.
745
+
746
+ Notes
747
+ -----
748
+ The function accepts two formats of the input dataframe:
749
+
750
+ - "groups": The standard format for a match groups object dataframe. It includes a
751
+ string column, and a "group" column that contains group labels, and an
752
+ optional "count" column. These three columns completely describe a
753
+ match groups object, allowing lossless match groups object -> dataframe -> match groups object
754
+ conversion (though the specific group labels in the dataframe will be
755
+ ignored and rebuilt in the new match groups object).
756
+
757
+ - "pairs": The dataframe includes two string columns, and each row indicates
758
+ a link between a pair of strings. A new match groups object will be constructed by
759
+ uniting each pair of strings.
760
+ """
761
+
762
+ if match_format not in ['unmatched', 'groups', 'pairs', 'detect']:
763
+ raise ValueError(
764
+ 'match_format must be one of "unmatched", "groups", "pairs", or "detect"')
765
+
766
+ # Create an empty match groups object
767
+ match_obj = MatchGroups()
768
+
769
+ if match_format == 'detect':
770
+ if (string_column in df.columns):
771
+ if group_column is None:
772
+ match_format = 'unmatched'
773
+ elif (group_column in df.columns):
774
+ match_format = 'groups'
775
+ elif set(df.columns) == set(pair_columns):
776
+ match_format = 'pairs'
777
+
778
+ if match_format == 'detect':
779
+ raise ValueError('Could not infer valid dataframe format from input')
780
+
781
+ if count_column in df.columns:
782
+ counts = df[count_column].values
783
+ else:
784
+ counts = np.ones(len(df))
785
+
786
+ if match_format == 'unmatched':
787
+ strings = df[string_column].values
788
+
789
+ # Build the match groups object
790
+ match_obj.counts = Counter({s: int(c) for s, c in zip(strings, counts)})
791
+ match_obj.labels = {s: s for s in strings}
792
+ match_obj.groups = {s: [s] for s in strings}
793
+
794
+ elif match_format == 'groups':
795
+
796
+ strings = df[string_column].values
797
+ group_ids = df[group_column].values
798
+
799
+ # Sort by group and string count
800
+ g_sort = np.lexsort((counts, group_ids))
801
+ group_ids = group_ids[g_sort]
802
+ strings = strings[g_sort]
803
+ counts = counts[g_sort]
804
+
805
+ # Identify group boundaries and split locations
806
+ split_locs = np.nonzero(group_ids[1:] != group_ids[:-1])[0] + 1
807
+
808
+ # Get grouped strings as separate arrays
809
+ groups = np.split(strings, split_locs)
810
+
811
+ # Build the match groups object
812
+ match_obj.counts = Counter({s: int(c) for s, c in zip(strings, counts)})
813
+ match_obj.labels = {s: g[-1] for g in groups for s in g}
814
+ match_obj.groups = {g[-1]: list(g) for g in groups}
815
+
816
+ elif match_format == 'pairs':
817
+ # TODO: Allow pairs data to use counts
818
+ for pair_column in pair_columns:
819
+ match_obj.add_strings(df[pair_column].values, inplace=True)
820
+
821
+ # There are several ways to unite pairs
822
+ # Guessing it is most efficient to "group by" one of the string columns
823
+ groups = {s: pair[1] for pair in df[pair_columns].values for s in pair}
824
+
825
+ match_obj.unite(groups, inplace=True)
826
+
827
+ return match_obj
828
+
829
+
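For illustration, a minimal usage sketch of from_df with the "groups" format described above; the import path and the example data are assumptions, not part of this repository:

# Hypothetical usage of from_df with a "groups"-format dataframe.
# The import path is an assumption about how this file is packaged.
import pandas as pd
from match_groups import from_df

df = pd.DataFrame({
    'string': ['ACME Inc', 'ACME Incorporated', 'Initech LLC'],
    'group': [0, 0, 1],
    'count': [3, 1, 2],
})

# match_format='detect' would also resolve to 'groups' here, because
# both the 'string' and 'group' columns are present.
groups = from_df(df, match_format='groups')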
830
+ def read_csv(
831
+ filename,
832
+ match_format='detect',
833
+ pair_columns=[
834
+ 'string0',
835
+ 'string1'],
836
+ string_column='string',
837
+ group_column='group',
838
+ count_column='count',
839
+ **pandas_args):
840
+ """
841
+ Read a csv file and construct a new match groups object.
842
+
843
+ Parameters
844
+ ----------
845
+ filename : str
846
+ The path to the csv file.
847
+ match_format : str, optional (default='detect')
848
+ One of "unmatched", "groups", "pairs", or "detect".
849
+ pair_columns : list of str, optional (default=['string0', 'string1'])
850
+ Two string columns to use if match_format='pairs'.
851
+ string_column : str, optional (default='string')
852
+ Column name for string values in match_format='unmatched' or 'groups'.
853
+ group_column : str, optional (default='group')
854
+ Column name for group values in match_format='groups'.
855
+ count_column : str, optional (default='count')
856
+ Column name for count values in match_format='unmatched' or 'groups'.
857
+ **pandas_args : optional
858
+ Optional arguments to pass to `pandas.read_csv`.
859
+
860
+ Returns
861
+ -------
862
+ MatchGroups
863
+ A new match groups object built from the csv file.
864
+ """
865
+ df = pd.read_csv(filename, **pandas_args, na_filter=False)
866
+ df = df.astype(str)
867
+
868
+ return from_df(df, match_format=match_format, pair_columns=pair_columns,
869
+ string_column=string_column, group_column=group_column,
870
+ count_column=count_column)
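A companion sketch for the "pairs" format via read_csv; the file name and its contents are hypothetical:

# Hypothetical: pairs.csv contains exactly two columns, string0 and string1,
# where each row links a pair of strings that should share a group.
groups = read_csv('pairs.csv', match_format='pairs')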
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e112a851e5079096d2f0bab96d21dfb0dea8cd92a2f23fe9218cc65dd9777fbe
3
+ size 499051193
scoring.py ADDED
@@ -0,0 +1,196 @@
1
+ import pandas as pd
2
+ import random
3
+
4
+
5
+ def confusion_df(predicted_groupings, gold_groupings, use_counts=True):
6
+ """
7
+ Computes the confusion matrix dataframe for a predicted match groups object relative to a gold match groups object.
8
+
9
+ Parameters
10
+ ----------
11
+ predicted_groupings : MatchGroups
12
+ The predicted match groups object.
13
+ gold_groupings : MatchGroups
14
+ The gold match groups object.
15
+ use_counts : bool, optional
16
+ Use the count of each string. If False, the count is set to 1.
17
+
18
+ Returns
19
+ -------
20
+ df : pandas.DataFrame
21
+ Confusion matrix dataframe with columns 'TP', 'FP', 'TN', and 'FN'.
22
+ """
23
+
24
+ df = pd.merge(
25
+ predicted_groupings.to_df(),
26
+ gold_groupings.to_df().drop(
27
+ 'count',
28
+ axis=1),
29
+ on='string',
30
+ suffixes=[
31
+ '_pred',
32
+ '_gold'])
33
+
34
+ if not use_counts:
35
+ df['count'] = 1
36
+
37
+ df['TP'] = (df.groupby(['group_pred', 'group_gold'])[
38
+ 'count'].transform('sum') - df['count']) * df['count']
39
+ df['FP'] = (df.groupby('group_pred')['count'].transform(
40
+ 'sum') - df['count']) * df['count'] - df['TP']
41
+ df['FN'] = (df.groupby('group_gold')['count'].transform(
42
+ 'sum') - df['count']) * df['count'] - df['TP']
43
+ df['TN'] = (df['count'].sum() - df['count']) * \
44
+ df['count'] - df['TP'] - df['FP'] - df['FN']
45
+
46
+ return df
47
+
48
+
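The TP/FP/FN/TN columns above count ordered string pairs weighted by counts, which is why confusion_matrix below halves the column sums. A small self-contained check of that arithmetic on made-up data (two predicted groups {a, b} and {c} against one gold group {a, b, c}):

# Made-up data to check the pairwise counting logic used in confusion_df.
import pandas as pd

df = pd.DataFrame({
    'string': ['a', 'b', 'c'],
    'count': [1, 1, 1],
    'group_pred': ['g1', 'g1', 'g2'],
    'group_gold': ['h1', 'h1', 'h1'],
})

df['TP'] = (df.groupby(['group_pred', 'group_gold'])['count'].transform('sum') - df['count']) * df['count']
df['FP'] = (df.groupby('group_pred')['count'].transform('sum') - df['count']) * df['count'] - df['TP']
df['FN'] = (df.groupby('group_gold')['count'].transform('sum') - df['count']) * df['count'] - df['TP']
df['TN'] = (df['count'].sum() - df['count']) * df['count'] - df['TP'] - df['FP'] - df['FN']

# Halving the column sums (as confusion_matrix does) gives
# TP=1 (the a-b pair), FP=0, FN=2 (a-c and b-c), TN=0.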
49
+ def confusion_matrix(predicted_groupings, gold_groupings, use_counts=True):
50
+ """
51
+ Computes the confusion matrix for a predicted match groups object relative to a gold match groups object.
52
+
53
+ Parameters
54
+ ----------
55
+ predicted_groupings : MatchGroups
56
+ The predicted match groups object.
57
+ gold_groupings : MatchGroups
58
+ The gold match groups object.
59
+ use_counts : bool, optional
60
+ Use the count of each string. If False, the count is set to 1.
61
+
62
+ Returns
63
+ -------
64
+ confusion_matrix : dict
65
+ Dictionary with keys 'TP', 'FP', 'TN', and 'FN', representing the values in the confusion matrix.
66
+ """
67
+
68
+ df = confusion_df(predicted_groupings, gold_groupings, use_counts=use_counts)
69
+
70
+ return {c: df[c].sum() // 2 for c in ['TP', 'FP', 'TN', 'FN']}
71
+
72
+
73
+ def score_predicted(
74
+ predicted_groupings,
75
+ gold_groupings,
76
+ use_counts=True,
77
+ drop_self_matches=True):
78
+ """
79
+ Computes the F1 score of a predicted match groups object relative to a gold match groups object
80
+ which is assumed to be correct.
81
+
82
+ Parameters
83
+ ----------
84
+ predicted_groupings : MatchGroups
85
+ The predicted match groups object.
86
+ gold_groupings : MatchGroups
87
+ The gold match groups object.
88
+ use_counts : bool, optional
89
+ Use the count of each string. If False, the count is set to 1.
90
+ drop_self_matches : bool, optional
91
+ Remove the matches between a string and itself.
92
+
93
+ Returns
94
+ -------
95
+ scores : dict
96
+ Dictionary with keys 'accuracy', 'precision', 'recall', 'F1', and 'coverage'.
97
+ """
98
+
99
+ scores = confusion_matrix(
100
+ predicted_groupings,
101
+ gold_groupings,
102
+ use_counts=use_counts)
103
+
104
+ n_scored = scores['TP'] + scores['TN'] + scores['FP'] + scores['FN']
105
+
106
+ if use_counts:
107
+ n_predicted = (sum(predicted_groupings.counts.values())**2 -
108
+ sum(c**2 for c in predicted_groupings.counts.values())) / 2
109
+ else:
110
+ n_predicted = (len(predicted_groupings)**2
111
+ - len(predicted_groupings)) / 2
112
+
113
+ scores['coverage'] = n_scored / n_predicted
114
+
115
+ if scores['TP']:
116
+ scores['accuracy'] = (scores['TP'] + scores['TN']) / n_scored
117
+ scores['precision'] = scores['TP'] / (scores['TP'] + scores['FP'])
118
+ scores['recall'] = scores['TP'] / (scores['TP'] + scores['FN'])
119
+ scores['F1'] = 2 * (scores['precision'] * scores['recall']) / \
120
+ (scores['precision'] + scores['recall'])
121
+
122
+ else:
123
+ scores['accuracy'] = 0
124
+ scores['precision'] = 0
125
+ scores['recall'] = 0
126
+ scores['F1'] = 0
127
+
128
+ return scores
129
+
130
+
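Continuing the made-up three-string example above (TP=1, FP=0, FN=2, TN=0), score_predicted would report precision = 1/(1+0) = 1.0, recall = 1/(1+2) ≈ 0.33, F1 = 2·precision·recall/(precision+recall) = 0.5, accuracy = 1/3, and coverage = 3/3 = 1.0, since all three possible string pairs are scored.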
131
+ def split_on_groups(groupings, frac=0.5, seed=None):
132
+ """
133
+ Splits the match groups object into two parts by given fraction.
134
+
135
+ Parameters
136
+ ----------
137
+ groupings : MatchGroups
138
+ The match groups object to be split.
139
+ frac : float, optional
140
+ The fraction of groups to select.
141
+ seed : int, optional
142
+ Seed for the random number generator.
143
+
144
+ Returns
145
+ -------
146
+ groupings1, groupings2 : tuple of match groups objects
147
+ Tuple of two match groups objects.
148
+ """
149
+ if seed is not None:
150
+ random.seed(seed)
151
+
152
+ groups = list(groupings.groups.values())
153
+ random.shuffle(groups)
154
+
155
+ selected_groups = groups[:int(frac * len(groups))]
156
+ selected_strings = [s for group in selected_groups for s in group]
157
+
158
+ return groupings.keep(selected_strings), groupings.drop(selected_strings)
159
+
160
+
161
+ def kfold_on_groups(groupings, k=4, shuffle=True, seed=None):
162
+ """
163
+ Perform K-fold cross validation on groups of strings.
164
+
165
+ Parameters
166
+ ----------
167
+ groupings : object
168
+ MatchGroups object to perform K-fold cross validation on.
169
+ k : int, optional
170
+ Number of folds to perform, by default 4.
171
+ shuffle : bool, optional
172
+ Whether to shuffle the groups before splitting, by default True.
173
+ seed : int, optional
174
+ Seed for the random number generator, by default None.
175
+
176
+ Yields
177
+ ------
178
+ tuple : MatchGroups, MatchGroups
179
+ For each of the k folds, a tuple of two match groups objects: the first for the training set and the second for the testing set.
180
+ """
181
+ if seed is not None:
182
+ random.seed(seed)
183
+
184
+ groups = list(groupings.groups.keys())
185
+
186
+ if shuffle:
187
+ random.shuffle(groups)
188
+ else:
189
+ groups = sorted(groups)
190
+
191
+ for fold in range(k):
192
+
193
+ fold_groups = groups[fold::k]
194
+ fold_strings = [s for g in fold_groups for s in groupings.groups[g]]
195
+
196
+ yield groupings.drop(fold_strings), groupings.keep(fold_strings)
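A sketch of how split_on_groups and kfold_on_groups might be used for evaluation; gold_groups (a MatchGroups) and model (a SimilarityModel, defined in similarity_model.py) are assumed to exist already:

# Hypothetical: hold out 20% of the gold groups for validation/early stopping.
train_groups, val_groups = split_on_groups(gold_groups, frac=0.8, seed=1)

history = model.train(train_groups, validation_groupings=val_groups, max_epochs=3)
scores = model.test(val_groups, threshold=0.5)   # accuracy, precision, recall, F1, coverage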
scoring_model.py ADDED
@@ -0,0 +1,60 @@
1
+ import torch
2
+ import numpy as np
3
+
4
+ class SimilarityScore(torch.nn.Module):
5
+ """
6
+ A trainable similarity scoring model that estimates the probability
7
+ of a match as a negative exponential of the scaled cosine distance between
8
+ embeddings:
9
+ p(match|v_i,v_j) = exp(-alpha*(1-v_i@v_j))
10
+ """
11
+ def __init__(self,config,**kwargs):
12
+
13
+ super().__init__()
14
+
15
+ self.alpha = torch.nn.Parameter(torch.tensor(float(config.get("alpha"))))
16
+
17
+ def __repr__(self):
18
+ return f'<SimilarityScore with {self.alpha=}>'
19
+
20
+ def forward(self,X):
21
+ # Z is a scaled distance measure: Z=0 means that the score should be 1
22
+ Z = self.alpha*(1 - X)
23
+ return torch.clamp(torch.exp(-Z),min=0,max=1.0)
24
+
25
+ def loss(self,X,Y,weights=None,decay=1e-6,epsilon=1e-6):
26
+
27
+ Z = self.alpha*(1 - X)
28
+
29
+ # Put epsilon floor to prevent overflow/undefined results
30
+ # Z = torch.tensor([1e-2,1e-3,1e-6,1e-7,1e-8,1e-9])
31
+ # torch.log(1 - torch.exp(-Z))
32
+ # 1/(1 - torch.exp(-Z))
33
+ with torch.no_grad():
34
+ Z_eps_adjustment = torch.clamp(epsilon-Z,min=0)
35
+
36
+ Z += Z_eps_adjustment
37
+
38
+ # Cross entropy loss with a simplified and (hopefully) numerically appropriate formula
39
+ # TODO: Stick an epsilon in here to prevent nan?
40
+ loss = Y*Z - torch.xlogy(1-Y,-torch.expm1(-Z))
41
+ # loss = Y*Z - torch.xlogy(1-Y,1-torch.exp(-Z))
42
+
43
+ if weights is not None:
44
+ loss *= weights
45
+
46
+ if decay:
47
+ loss += decay*self.alpha**2
48
+
49
+ return loss
50
+
51
+ def score_to_cos(self,score):
52
+ if score > 0:
53
+ return 1 + np.log(score)/self.alpha.item()
54
+ else:
55
+ return -99
56
+
57
+ def config_optimizer(self,lr=10):
58
+ optimizer = torch.optim.AdamW(self.parameters(),lr=lr,weight_decay=0)
59
+
60
+ return optimizer
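A small numeric sketch of the scoring function above, using the alpha=50 value from config.json (the cosine inputs are made up): a cosine similarity of 0.99 maps to exp(-50·0.01) ≈ 0.61, while 0.90 maps to exp(-5) ≈ 0.007.

# Minimal sketch; alpha is taken from config.json, the cosine values are made up.
import torch

score_model = SimilarityScore({'alpha': 50})
X = torch.tensor([1.00, 0.99, 0.90])       # cosine similarities
print(score_model(X))                      # ~ [1.000, 0.607, 0.007]
print(score_model.score_to_cos(0.5))       # cosine needed for p=0.5, ~0.986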
similarity_model.py ADDED
@@ -0,0 +1,369 @@
1
+ import random
2
+ import pandas as pd
3
+ import numpy as np
4
+ from tqdm import tqdm
5
+ from copy import copy,deepcopy
6
+ from collections import Counter
7
+ import torch
8
+ from torch import nn
9
+ from torch.utils.data import DataLoader
10
+ from transformers import get_cosine_schedule_with_warmup,get_linear_schedule_with_warmup, logging
11
+ from transformers.modeling_utils import PreTrainedModel
12
+
13
+ from .match_groups import MatchGroups
14
+ from .scoring import score_predicted
15
+ from .scoring_model import SimilarityScore
16
+ from .embeddings import Embeddings
17
+ from .embedding_model import EmbeddingModel
18
+ from .configuration import SimilarityModelConfig
19
+ logging.set_verbosity_error()
20
+
21
+
22
+ class ExponentWeights():
23
+ def __init__(self, config,**kwargs):
24
+ self.exponent = config.get("weighting_exponent", 0.5)
25
+
26
+ def __call__(self,counts):
27
+ return counts**self.exponent
28
+
29
+
30
+ class SimilarityModel(PreTrainedModel):
31
+ config_class = SimilarityModelConfig
32
+ """
33
+ A combined embedding/scorer model that produces Embeddings objects
34
+ as its primary output.
35
+
36
+ - train() jointly optimizes the embedding_model and score_model using
37
+ contrastive learning to learn from a training MatchGroups.
38
+ """
39
+ def __init__(self, config, **kwargs):
40
+ super().__init__(config)
41
+
42
+ self.embedding_model = EmbeddingModel(config.embedding_model_config, **kwargs)
43
+ self.score_model = SimilarityScore(config.score_model_config, **kwargs)
44
+ self.weighting_function = ExponentWeights(config.weighting_function_config, **kwargs)
45
+
46
+ self.config = config
47
+ self.to(config.device)
48
+
49
+ def to(self,device):
50
+ super().to(device)
51
+ self.embedding_model.to(device)
52
+ self.score_model.to(device)
53
+ #self.device = device
54
+
55
+ def save(self,savefile):
56
+ torch.save({'metadata': self.config, 'state_dict': self.state_dict()}, savefile)
57
+
58
+ @torch.no_grad()
59
+ def embed(self,input,to=None,batch_size=64,progress_bar=True,**kwargs):
60
+ """
61
+ Construct an Embeddings object from input strings or a MatchGroups
62
+ """
63
+
64
+ if to is None:
65
+ to = self.device
66
+
67
+ if isinstance(input, MatchGroups):
68
+ strings = input.strings()
69
+ counts = torch.tensor([input.counts[s] for s in strings],device=self.device).float().to(to)
70
+
71
+ else:
72
+ strings = list(input)
73
+ counts = torch.ones(len(strings),device=self.device).float().to(to)
74
+
75
+ input_loader = DataLoader(strings,batch_size=batch_size,num_workers=0)
76
+
77
+ self.embedding_model.eval()
78
+
79
+ V = None
80
+ batch_start = 0
81
+ with tqdm(total=len(strings),delay=1,desc='Embedding strings',disable=not progress_bar) as pbar:
82
+ for batch_strings in input_loader:
83
+
84
+ v = self.embedding_model(batch_strings).detach().to(to)
85
+
86
+ if V is None:
87
+ # Use v to determine dim and dtype of pre-allocated embedding tensor
88
+ # (Pre-allocating avoids duplicating tensors with a big .cat() operation)
89
+ V = torch.empty(len(strings),v.shape[1],device=to,dtype=v.dtype)
90
+
91
+ V[batch_start:batch_start+len(batch_strings),:] = v
92
+
93
+ pbar.update(len(batch_strings))
94
+ batch_start += len(batch_strings)
95
+
96
+ score_model = copy(self.score_model)
97
+ score_model.load_state_dict(self.score_model.state_dict())
98
+ score_model.to(to)
99
+
100
+ weighting_function = deepcopy(self.weighting_function)
101
+
102
+ return Embeddings(strings=strings,
103
+ V=V.detach(),
104
+ counts=counts.detach(),
105
+ score_model=score_model,
106
+ weighting_function=weighting_function,
107
+ device=to)
108
+
109
+ def train(self,training_groupings,max_epochs=1,batch_size=8,
110
+ score_decay=0,regularization=0,
111
+ transformer_lr=1e-5,projection_lr=1e-5,score_lr=10,warmup_frac=0.1,
112
+ max_grad_norm=1,dropout=False,
113
+ validation_groupings=None,target='F1',restore_best=True,val_seed=None,
114
+ validation_interval=1000,early_stopping=True,early_stopping_patience=3,
115
+ verbose=False,progress_bar=True,
116
+ **kwargs):
117
+
118
+ """
119
+ Train the embedding_model and score_model to predict match probabilities
120
+ using the training_groupings as a source of "correct" matches.
121
+ The training algorithm uses contrastive learning with hard-positive
123
+ and hard-negative mining to fine-tune the embedding model to place
124
+ matched strings near each other in embedding space, while
125
+ simultaneously calibrating the score_model to predict match
126
+ probabilities as a function of cosine distance.
126
+ """
127
+
128
+ if validation_groupings is None:
129
+ early_stopping = False
130
+ restore_best = False
131
+
132
+ num_training_steps = max_epochs*len(training_groupings)//batch_size
133
+ num_warmup_steps = int(warmup_frac*num_training_steps)
134
+
135
+ if transformer_lr or projection_lr:
136
+ embedding_optimizer = self.embedding_model.config_optimizer(transformer_lr,projection_lr)
137
+ embedding_scheduler = get_cosine_schedule_with_warmup(
138
+ embedding_optimizer,
139
+ num_warmup_steps=num_warmup_steps,
140
+ num_training_steps=num_training_steps)
141
+ if score_lr:
142
+ score_optimizer = self.score_model.config_optimizer(score_lr)
143
+ score_scheduler = get_linear_schedule_with_warmup(
144
+ score_optimizer,
145
+ num_warmup_steps=num_warmup_steps,
146
+ num_training_steps=num_training_steps)
147
+
148
+ step = 0
149
+ self.history = []
150
+ self.validation_scores = []
151
+ for epoch in range(max_epochs):
152
+
153
+ global_embeddings = self.embed(training_groupings)
154
+
155
+ strings = global_embeddings.strings
156
+ V = global_embeddings.V
157
+ w = global_embeddings.w
158
+
159
+ groups = torch.tensor([global_embeddings.string_map[training_groupings[s]] for s in strings],device=self.device)
160
+
161
+ # Normalize weights to make learning rates more general
162
+ if w is not None:
163
+ w = w/w.mean()
164
+
165
+ shuffled_ids = list(range(len(strings)))
166
+ random.shuffle(shuffled_ids)
167
+
168
+ if dropout:
169
+ self.embedding_model.train()
170
+ else:
171
+ self.embedding_model.eval()
172
+
173
+ for batch_start in tqdm(range(0,len(strings),batch_size),desc=f'training epoch {epoch}',disable=not progress_bar):
174
+
175
+ h = {'epoch':epoch,'step':step}
176
+
177
+ batch_i = shuffled_ids[batch_start:batch_start+batch_size]
178
+
179
+ # Recycle ids from the beginning to pad the last batch if necessary
180
+ if len(batch_i) < batch_size:
181
+ batch_i = batch_i + shuffled_ids[:(batch_size-len(batch_i))]
182
+
183
+ """
184
+ Find highest loss match for each batch string (global search)
185
+
186
+ Note: If we compute V_i with dropout enabled, it will add noise
187
+ to the embeddings and prevent the same pairs from being selected
188
+ every time.
189
+ """
190
+ V_i = self.embedding_model(strings[batch_i])
191
+
192
+ # Update global embedding cache
193
+ V[batch_i,:] = V_i.detach()
194
+
195
+ with torch.no_grad():
196
+
197
+ global_X = V_i@V.T
198
+ global_Y = (groups[batch_i][:,None] == groups[None,:]).float()
199
+
200
+ if w is not None:
201
+ global_W = torch.outer(w[batch_i],w)
202
+ else:
203
+ global_W = None
204
+
205
+ # Train scoring model only
206
+ if score_lr:
207
+ # Make sure gradients are enabled for score model
208
+ self.score_model.requires_grad_(True)
209
+
210
+ global_loss = self.score_model.loss(global_X,global_Y,weights=global_W,decay=score_decay)
211
+
212
+ score_optimizer.zero_grad()
213
+ global_loss.nanmean().backward()
214
+ torch.nn.utils.clip_grad_norm_(self.score_model.parameters(),max_norm=max_grad_norm)
215
+
216
+ score_optimizer.step()
217
+ score_scheduler.step()
218
+
219
+ h['score_lr'] = score_optimizer.param_groups[0]['lr']
220
+ h['global_mean_cos'] = global_X.mean().item()
221
+ try:
222
+ h['score_alpha'] = self.score_model.alpha.item()
223
+ except AttributeError:
224
+ pass
225
+
226
+ else:
227
+ with torch.no_grad():
228
+ global_loss = self.score_model.loss(global_X,global_Y)
229
+
230
+ h['global_loss'] = global_loss.detach().nanmean().item()
231
+
232
+ # Train embedding model
233
+ if (transformer_lr or projection_lr) and step <= num_warmup_steps + num_training_steps:
234
+
235
+ # Turn off score model updating - only want to train embedding here
236
+ self.score_model.requires_grad_(False)
237
+
238
+ # Select hard training examples
239
+ with torch.no_grad():
240
+ batch_j = global_loss.argmax(dim=1).flatten()
241
+
242
+ if w is not None:
243
+ batch_W = torch.outer(w[batch_i],w[batch_j])
244
+ else:
245
+ batch_W = None
246
+
247
+ # Train the model on the selected high-loss pairs
248
+ V_j = self.embedding_model(strings[batch_j.tolist()])
249
+
250
+ # Update global embedding cache
251
+ V[batch_j,:] = V_j.detach()
252
+
253
+ batch_X = V_i@V_j.T
254
+ batch_Y = (groups[batch_i][:,None] == groups[batch_j][None,:]).float()
255
+ h['batch_obs'] = len(batch_i)*len(batch_j)
256
+
257
+ batch_loss = self.score_model.loss(batch_X,batch_Y,weights=batch_W)
258
+
259
+ if regularization:
260
+ # Apply Global Orthogonal Regularization from https://arxiv.org/abs/1708.06320
261
+ gor_Y = (groups[batch_i][:,None] != groups[batch_i][None,:]).float()
262
+ gor_n = gor_Y.sum()
263
+ if gor_n > 1:
264
+ gor_X = (V_i@V_i.T)*gor_Y
265
+ gor_m1 = 0.5*gor_X.sum()/gor_n
266
+ gor_m2 = 0.5*(gor_X**2).sum()/gor_n
267
+ batch_loss += regularization*(gor_m1 + torch.clamp(gor_m2 - 1/self.embedding_model.d,min=0))
268
+
269
+ h['batch_nan'] = torch.isnan(batch_loss.detach()).sum().item()
270
+
271
+ embedding_optimizer.zero_grad()
272
+ batch_loss.nanmean().backward()
273
+
274
+ torch.nn.utils.clip_grad_norm_(self.parameters(),max_norm=max_grad_norm)
275
+
276
+ embedding_optimizer.step()
277
+ embedding_scheduler.step()
278
+
279
+ h['transformer_lr'] = embedding_optimizer.param_groups[1]['lr']
280
+ h['projection_lr'] = embedding_optimizer.param_groups[-1]['lr']
281
+
282
+ # Save stats
283
+ h['batch_loss'] = batch_loss.detach().mean().item()
284
+ h['batch_pos_target'] = batch_Y.detach().mean().item()
285
+
286
+ self.history.append(h)
287
+ step += 1
288
+
289
+ if (validation_groupings is not None) and not (step % validation_interval):
290
+
291
+ validation = len(self.validation_scores)
292
+ val_scores = self.test(validation_groupings)
293
+ val_scores['step'] = step - 1
294
+ val_scores['epoch'] = epoch
295
+ val_scores['validation'] = validation
296
+
297
+ self.validation_scores.append(val_scores)
298
+
299
+ # Print validation stats
300
+ if verbose:
301
+ print(f'\nValidation results at step {step} (current epoch {epoch})')
302
+ for k,v in val_scores.items():
303
+ print(f' {k}: {v:.4f}')
304
+
305
+ print(list(self.score_model.named_parameters()))
306
+
307
+ # Update best saved model
308
+ if restore_best:
309
+ if val_scores[target] >= max(h[target] for h in self.validation_scores):
310
+ best_state = deepcopy({
311
+ 'state_dict':self.state_dict(),
312
+ 'val_scores':val_scores
313
+ })
314
+
315
+ if early_stopping and (validation - best_state['val_scores']['validation'] > early_stopping_patience):
316
+ print(f'Stopping training ({early_stopping_patience} validation checks since best validation score)')
317
+ break
318
+
319
+ if restore_best:
320
+ print(f"Restoring to best state (step {best_state['val_scores']['step']}):")
321
+ for k,v in best_state['val_scores'].items():
322
+ print(f' {k}: {v:.4f}')
323
+
324
+ self.to('cpu')
325
+ self.load_state_dict(best_state['state_dict'])
326
+ self.to(self.device)
327
+
328
+ return pd.DataFrame(self.history)
329
+
330
+ def unite_similar(self,input,**kwargs):
331
+ embeddings = self.embed(input,**kwargs)
332
+ return embeddings.unite_similar(**kwargs)
333
+
334
+ def test(self,gold_groupings, threshold=0.5, **kwargs):
335
+ embeddings = self.embed(gold_groupings, **kwargs)
336
+
337
+ if (isinstance(threshold, float)):
338
+ predicted = embeddings.unite_similar(threshold=threshold, **kwargs)
339
+ scores = score_predicted(predicted, gold_groupings, use_counts=True)
340
+
341
+ return scores
342
+
343
+ results = []
344
+ for thres in threshold:
345
+ predicted = embeddings.unite_similar(threshold=thres, **kwargs)
346
+
347
+ scores = score_predicted(predicted, gold_groupings, use_counts=True)
348
+ scores["threshold"] = thres
349
+ results.append(scores)
350
+
351
+
352
+ return results
353
+
354
+
355
+
356
+ def load_similarity_model(f,map_location='cpu',*args,**kwargs):
357
+ checkpoint = torch.load(f, map_location=map_location, **kwargs)
358
+ metadata = checkpoint['metadata']
359
+ state_dict = checkpoint['state_dict']
360
+
361
+ model = SimilarityModel(config=metadata)
362
+ model.load_state_dict(state_dict)
363
+
364
+ return model
365
+ #return torch.load(f,map_location=map_location,**kwargs)
366
+
367
+
368
+
369
+
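Putting the pieces together, a hedged end-to-end sketch. The repo id below is a placeholder, and trust_remote_code=True is needed because SimilarityModel is defined in this repository's similarity_model.py (see auto_map in config.json):

# Hypothetical end-to-end usage; '<user>/<repo>' is a placeholder repo id.
from transformers import AutoModel

model = AutoModel.from_pretrained('<user>/<repo>', trust_remote_code=True)

strings = ['ACME Inc', 'ACME Incorporated', 'Initech LLC']
groups = model.unite_similar(strings, threshold=0.5)   # returns a MatchGroups
print(groups.to_df())                                  # string / group / count columns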