|
import pandas as pd |
|
import numpy as np |
|
from tqdm import tqdm |
|
from copy import copy |
|
from collections import Counter |
|
import torch |
|
from zipfile import ZipFile |
|
import pickle |
|
from io import BytesIO |
|
|
|
from .match_groups import MatchGroups |
|
|
|
|
|
class Embeddings(torch.nn.Module): |
|
""" |
|
Stores embeddings for a fixed array of strings and provides methods for |
|
clustering the strings to create MatchGroups objects according to different |
|
algorithms. |
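
    Illustrative usage (a minimal sketch; assumes an archive previously written
    with Embeddings.save and a hypothetical file name):

        embs = load_embeddings('embeddings.zip')
        groups = embs.unite_similar(threshold=0.8)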
|
""" |
|
def __init__(self,strings,V,score_model,weighting_function,counts,device='cpu'): |
|
super().__init__() |
|
|
|
self.strings = np.array(list(strings)) |
|
self.string_map = {s:i for i,s in enumerate(strings)} |
|
self.V = V |
|
self.counts = counts |
|
self.w = weighting_function(counts) |
|
self.score_model = score_model |
|
self.weighting_function = weighting_function |
|
self.device = device |
|
|
|
self.to(device) |
|
|
|
def __repr__(self): |
|
        return f'<nama.Embeddings containing {self.V.shape[1]}-d vectors for {len(self)} strings>'
|
|
|
def to(self,device): |
|
super().to(device) |
|
self.V = self.V.to(device) |
|
self.counts = self.counts.to(device) |
|
self.w = self.w.to(device) |
|
self.score_model.to(device) |
|
        self.device = device

        return self
|
|
|
def save(self,f): |
|
""" |
|
Save embeddings in a simple custom zipped archive format (torch.save |
|
works too, but it requires huge amounts of memory to serialize large |
|
embeddings objects). |
|
""" |
|
with ZipFile(f,'w') as zip: |
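            # Archive contents: the pickled score_model and weighting_function,
            # a CSV of strings with their counts, and the embedding matrix V as a .npy file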
|
|
|
|
|
zip.writestr('score_model.pkl',pickle.dumps(self.score_model)) |
|
|
|
|
|
zip.writestr('weighting_function.pkl',pickle.dumps(self.weighting_function)) |
|
|
|
|
|
strings_df = pd.DataFrame().assign( |
|
string=self.strings, |
|
count=self.counts.to('cpu').numpy()) |
|
zip.writestr('strings.csv',strings_df.to_csv(index=False)) |
|
|
|
|
|
byte_io = BytesIO() |
|
np.save(byte_io,self.V.to('cpu').numpy(),allow_pickle=False) |
|
zip.writestr('V.npy',byte_io.getvalue()) |
|
|
|
def __getitem__(self,arg): |
|
""" |
|
        Slice the Embeddings object, returning a new Embeddings restricted to the selected strings
|
""" |
|
if isinstance(arg,slice): |
|
i = arg |
|
elif isinstance(arg, MatchGroups): |
|
return self[arg.strings()] |
|
elif hasattr(arg,'__iter__'): |
|
|
|
string_map = self.string_map |
|
i = [string_map[s] for s in arg] |
|
|
|
if i == list(range(len(self))): |
|
|
|
return copy(self) |
|
else: |
|
            raise ValueError(f'Unknown slice input type ({type(arg)}). Can only slice Embeddings with a slice, MatchGroups, or iterable of strings.')
|
|
|
new = copy(self) |
|
new.strings = self.strings[i] |
|
new.V = self.V[i] |
|
new.counts = self.counts[i] |
|
new.w = self.w[i] |
|
new.string_map = {s:i for i,s in enumerate(new.strings)} |
|
|
|
return new |
|
|
|
def embed(self,grouping): |
|
""" |
|
Construct updated Embeddings with counts from the input MatchGroups |
|
""" |
|
new = self[grouping] |
|
new.counts = torch.tensor([grouping.counts[s] for s in new.strings],device=self.device) |
|
new.w = new.weighting_function(new.counts) |
|
|
|
return new |
|
|
|
def __len__(self): |
|
return len(self.strings) |
|
|
|
def _group_to_ids(self,grouping): |
|
group_id_map = {g:i for i,g in enumerate(grouping.groups.keys())} |
|
group_ids = torch.tensor([group_id_map[grouping[s]] for s in self.strings]).to(self.device) |
|
return group_ids |
|
|
|
def _ids_to_group(self,group_ids): |
|
if isinstance(group_ids,torch.Tensor): |
|
group_ids = group_ids.to('cpu').numpy() |
|
|
|
strings = self.strings |
|
counts = self.counts.to('cpu').numpy() |
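
        # Sort by group id (and by count within each group) so that each group forms
        # a contiguous block ending with its most common string, which is then used
        # as the group's label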
|
|
|
|
|
g_sort = np.lexsort((counts,group_ids)) |
|
group_ids = group_ids[g_sort] |
|
strings = strings[g_sort] |
|
counts = counts[g_sort] |
|
|
|
|
|
split_locs = np.nonzero(group_ids[1:] != group_ids[:-1])[0] + 1 |
|
|
|
|
|
groups = np.split(strings,split_locs) |
|
|
|
|
|
grouping = MatchGroups() |
|
grouping.counts = Counter({s:int(c) for s,c in zip(strings,counts)}) |
|
grouping.labels = {s:g[-1] for g in groups for s in g} |
|
grouping.groups = {g[-1]:list(g) for g in groups} |
|
|
|
return grouping |
|
|
|
@torch.no_grad() |
|
def _fast_unite_similar(self,group_ids,threshold=0.5,progress_bar=True,batch_size=64): |
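        """
        Greedy single-pass uniting: for each batch, any pair of strings whose
        cosine similarity clears the threshold (and whose current groups differ)
        has those groups merged. Used as the fast path of unite_similar when no
        group_threshold, never_match, or return_united options are requested.
        """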
|
|
|
V = self.V |
|
cos_threshold = self.score_model.score_to_cos(threshold) |
|
|
|
for batch_start in tqdm(range(0,len(self),batch_size), |
|
delay=1,desc='Predicting matches',disable=not progress_bar): |
|
|
|
i_slice = slice(batch_start,batch_start+batch_size) |
|
j_slice = slice(batch_start+1,None) |
|
|
|
g_i = group_ids[i_slice] |
|
g_j = group_ids[j_slice] |
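
            # Cosine similarities between this batch and all subsequent strings,
            # masking out pairs that are already in the same group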
|
|
|
|
|
batch_matched = (V[i_slice]@V[j_slice].T >= cos_threshold) \ |
|
* (g_i[:,None] != g_j[None,:]) |
|
|
|
for k,matched in enumerate(batch_matched): |
|
if matched.any(): |
|
|
|
matched_groups = g_j[matched] |
|
|
|
|
|
ids_to_group = torch.isin(group_ids,matched_groups) |
|
|
|
|
|
group_ids[ids_to_group] = g_i[k].clone() |
|
|
|
return self._ids_to_group(group_ids) |
|
|
|
@torch.no_grad() |
|
def unite_similar(self, |
|
threshold=0.5, |
|
group_threshold=None, |
|
always_match=None, |
|
never_match=None, |
|
batch_size=64, |
|
progress_bar=True, |
|
always_never_conflicts='warn', |
|
return_united=False): |
|
|
|
""" |
|
Unite embedding strings according to predicted pairwise similarity. |
|
|
|
- "theshold" sets the minimimum match similarity required to unite two strings. |
|
- Note that strings with similarity<threshold can end up matched if they are |
|
linked by a chain of sufficiently similar strings (matching is transitive). |
|
"group_threshold" can be used to add an additional constraing on the minimum |
|
similarity within each group. |
|
- "group_threshold" sets the minimum similarity required within a single group. |
|
- "always_match" takes any argument that can be used to unite strings. These |
|
strings will always be matched. |
|
- "never_match" takes a set, or a list of sets, where each set indicates two or |
|
more strings that should never be united with each other (these strings may |
|
still be united with other strings). |
|
- "always_never_conflicts" determines how to handle conflicts between |
|
"always_match" and "never_match": |
|
- always_never_conflicts="warn": Check for conflicts and print a warning |
|
if any are found (default) |
|
- always_never_conflicts="raise": Check for conflicts and raise an error |
|
if any are found |
|
- always_never_conflicts="ignore": Do not check for conflicts ("always_match" |
|
will take precedence) |
|
|
|
If "group_threshold" or "never_match" arguments are supplied, strings pairs are |
|
united in order of similarity. Highest similarity strings are matched first, and |
|
before each time a new pair of strings is united, the function checks if this will |
|
result in grouping any two strings with similarity<group_threshold. If so, this |
|
pair is skipped. This version of the algorithm requires more memory and processing |
|
time, but guaruntees deterministic output that is consistent with the constraints. |
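
        Illustrative usage (a sketch with hypothetical variable name and threshold values):

            united = embs.unite_similar(threshold=0.75,group_threshold=0.6)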
|
|
|
returns: MatchGroups object |
|
""" |
|
if group_threshold and group_threshold < threshold: |
|
raise ValueError('group_threshold must be greater than or equal to threshold') |
|
|
|
group_ids = torch.arange(len(self)).to(self.device) |
|
|
|
if always_match is not None: |
|
always_grouping = (MatchGroups(self.strings) |
|
.unite(always_match)) |
|
always_match_labels = always_grouping.labels |
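
        # Fast path: with no group_threshold, never_match, or return_united
        # constraints, a single greedy pass over candidate pairs is sufficient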
|
|
|
|
|
|
|
if not (return_united or group_threshold or (never_match is not None)): |
|
if always_match is not None: |
|
group_ids = self._group_to_ids(always_grouping) |
|
|
|
return self._fast_unite_similar( |
|
group_ids=group_ids, |
|
threshold=threshold, |
|
batch_size=batch_size, |
|
progress_bar=progress_bar) |
|
|
|
if never_match is not None: |
|
|
|
if all(isinstance(s,str) for s in never_match): |
|
never_match = [never_match] |
|
|
|
if always_match is not None: |
|
|
|
assert always_never_conflicts in ['raise','warn','ignore'] |
|
|
|
if always_never_conflicts != 'ignore': |
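                    # Flag never_match groups containing two or more strings that
                    # are already united with each other via always_match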
|
|
|
|
|
conflicts = [] |
|
for i,g in enumerate(never_match): |
|
g = sorted(list(g)) |
|
g_labels = [always_match_labels.get(s,s) for s in g] |
|
if len(set(g_labels)) < len(g): |
|
df = (pd.DataFrame() |
|
.assign( |
|
string=g, |
|
never_match_group=i, |
|
always_match_group=g_labels |
|
)) |
|
conflicts.append(df) |
|
|
|
if conflicts: |
|
conflicts_df = pd.concat(conflicts) |
|
|
|
if always_never_conflicts == 'warn': |
|
print(f'Warning: The following never_match groups are in conflict with always_match groups:\n{conflicts_df}') |
|
print('Conflicted never_match relationships will be ignored') |
|
else: |
|
raise ValueError(f'The following never_match groups are in conflict with always_match groups\n{conflicts_df}') |
|
|
|
|
|
|
|
|
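                # Re-express the never_match sets in terms of always_match group labels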
|
never_match = [{always_match_labels[s] for s in g if s in always_match_labels} for g in never_match] |
|
|
|
else: |
|
|
|
never_match = [set(s) for s in never_match] |
|
|
|
|
|
V = self.V |
|
cos_threshold = self.score_model.score_to_cos(threshold) |
|
if group_threshold is not None: |
|
separate_cos = self.score_model.score_to_cos(group_threshold) |
|
|
|
|
|
matches = [] |
|
cos_scores = [] |
|
for batch_start in tqdm(range(0,len(self),batch_size), |
|
desc='Scoring pairs', |
|
delay=1,disable=not progress_bar): |
|
|
|
i_slice = slice(batch_start,batch_start+batch_size) |
|
j_slice = slice(batch_start+1,None) |
|
|
|
|
|
batch_cos = V[i_slice]@V[j_slice].T |
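
            # Keep only the upper triangle so each within-batch pair is scored once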
|
|
|
|
|
|
|
batch_cos = torch.triu(batch_cos) |
|
|
|
bi,bj = torch.nonzero(batch_cos >= cos_threshold,as_tuple=True) |
|
|
|
if len(bi): |
|
|
|
i = bi + batch_start |
|
j = bj + batch_start + 1 |
|
|
|
cos = batch_cos[bi,bj] |
|
|
|
|
|
unmatched = group_ids[i] != group_ids[j] |
|
i = i[unmatched] |
|
j = j[unmatched] |
|
cos = cos[unmatched] |
|
|
|
if len(i): |
|
batch_matches = torch.hstack([i[:,None],j[:,None]]) |
|
|
|
matches.append(batch_matches.to('cpu').numpy()) |
|
cos_scores.append(cos.to('cpu').numpy()) |
|
|
|
|
|
|
|
united = [] |
|
if matches: |
|
matches = np.vstack(matches) |
|
cos_scores = np.hstack(cos_scores).T |
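
            # Process candidate pairs in order of decreasing cosine similarity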
|
|
|
|
|
m_sort = cos_scores.argsort()[::-1] |
|
matches = matches[m_sort] |
|
|
|
if return_united: |
|
|
|
cos_scores_df = pd.DataFrame(matches,columns=['i','j']) |
|
cos_scores_df['cos'] = cos_scores[m_sort] |
|
|
|
|
|
matches = torch.tensor(matches).to(self.device) |
|
|
|
|
|
if never_match is not None: |
|
never_match_map = {s:sep for sep in never_match for s in sep} |
|
|
|
if always_match is not None: |
|
|
|
never_match_array = np.array([never_match_map.get(always_match_labels[s],set()) for s in self.strings]) |
|
else: |
|
never_match_array = np.array([never_match_map.get(s,set()) for s in self.strings]) |
|
|
|
|
|
n_matches = matches.shape[0] |
|
with tqdm(total=n_matches,desc='Uniting matches', |
|
delay=1,disable=not progress_bar) as p_bar: |
|
|
|
while len(matches): |
|
|
|
|
|
match_pair = matches[0] |
|
matches = matches[1:] |
|
|
|
|
|
g = group_ids[match_pair] |
|
g0 = group_ids == g[0] |
|
g1 = group_ids == g[1] |
|
|
|
|
|
to_unite = g0 | g1 |
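
                    # Both strings are currently singletons if the union has fewer than 3 members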
|
|
|
|
|
singletons = to_unite.sum() < 3 |
|
|
|
|
|
unite_ok = True |
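
                    # Skip this match if the two strings share a never_match constraint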
|
|
|
|
|
if never_match is not None: |
|
never_0 = never_match_array[match_pair[0]] |
|
never_1 = never_match_array[match_pair[1]] |
|
|
|
if never_0 and never_1 and (never_0 & never_1): |
|
|
|
|
|
unite_ok = False |
|
|
|
|
|
|
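                    # Enforce the minimum within-group similarity before merging non-singleton groups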
|
if unite_ok and group_threshold and not singletons: |
|
V0 = V[g0,:] |
|
V1 = V[g1,:] |
|
|
|
unite_ok = (V0@V1.T).min() >= separate_cos |
|
|
|
|
|
if unite_ok: |
|
|
|
|
|
group_ids[to_unite] = g[0] |
|
|
|
if never_match and (never_0 or never_1): |
|
|
|
never_match_array[to_unite.detach().cpu().numpy()] = never_0 | never_1 |
|
|
|
|
|
|
|
if not singletons: |
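                            # Drop queued pairs whose strings now belong to the same group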
|
|
|
matches = matches[group_ids[matches[:,0]] != group_ids[matches[:,1]]] |
|
|
|
if return_united: |
|
match_record = np.empty(4,dtype=int) |
|
match_record[:2] = match_pair.cpu().numpy().ravel() |
|
match_record[2] = self.counts[g0].sum().item() |
|
match_record[3] = self.counts[g1].sum().item() |
|
|
|
united.append(match_record) |
|
else: |
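                        # Merge was rejected: drop remaining queued pairs that link these same two groups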
|
|
|
matches = matches[torch.isin(group_ids[matches[:,0]],g,invert=True) \ |
|
| torch.isin(group_ids[matches[:,1]],g,invert=True)] |
|
|
|
|
|
p_bar.update(n_matches - matches.shape[0]) |
|
n_matches = matches.shape[0] |
|
|
|
        predicted_grouping = self._ids_to_group(group_ids)
|
|
|
if always_match is not None: |
|
predicted_grouping = predicted_grouping.unite(always_grouping) |
|
|
|
if return_united: |
|
united_df = pd.DataFrame(np.vstack(united),columns=['i','j','n_i','n_j']) |
|
united_df = pd.merge(united_df,cos_scores_df,how='inner',on=['i','j']) |
|
united_df['score'] = self.score_model( |
|
torch.tensor(united_df['cos'].values).to(self.device) |
|
).cpu().numpy() |
|
|
|
united_df = united_df.drop('cos',axis=1) |
|
|
|
for c in ['i','j']: |
|
united_df[c] = [self.strings[i] for i in united_df[c]] |
|
|
|
if always_match is not None: |
|
united_df['always_match'] = [always_grouping[i] == always_grouping[j] |
|
for i,j in united_df[['i','j']].values] |
|
|
|
return predicted_grouping,united_df |
|
|
|
else: |
|
|
|
return predicted_grouping |
|
|
|
@torch.no_grad() |
|
def unite_nearest(self,target_strings,threshold=0,always_grouping=None,progress_bar=True,batch_size=64): |
|
""" |
|
Unite embedding strings with each string's most similar target string. |
|
|
|
- "always_grouping" will be used to inialize the group_ids before uniting new matches |
|
- "theshold" sets the minimimum match similarity required between a string and target string |
|
for the string to be matched. (i.e., setting theshold=0 will result in every embedding |
|
string to be matched its nearest target string, while setting threshold=0.9 will leave |
|
strings that have similarity<0.9 with their nearest target string unaffected) |
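
        Illustrative usage (a sketch with hypothetical target strings):

            groups = embs.unite_nearest(target_strings=['acme inc','acme corp'],threshold=0.5)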
|
|
|
returns: MatchGroups object |
|
""" |
|
|
|
if always_grouping is not None: |
|
|
|
group_ids = self._group_to_ids(always_grouping) |
|
else: |
|
group_ids = torch.arange(len(self)).to(self.device) |
|
|
|
V = self.V |
|
cos_threshold = self.score_model.score_to_cos(threshold) |
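
        # Each target string acts as a seed; every other string adopts the group of its
        # most similar seed, provided the similarity clears the threshold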
|
|
|
seed_ids = torch.tensor([self.string_map[s] for s in target_strings]).to(self.device) |
|
V_seed = V[seed_ids] |
|
g_seed = group_ids[seed_ids] |
|
is_seed = torch.zeros(V.shape[0],dtype=torch.bool).to(self.device) |
|
is_seed[g_seed] = True |
|
|
|
for batch_start in tqdm(range(0,len(self),batch_size), |
|
delay=1,desc='Predicting matches',disable=not progress_bar): |
|
|
|
batch_slice = slice(batch_start,batch_start+batch_size) |
|
|
|
batch_cos = V[batch_slice]@V_seed.T |
|
|
|
max_cos,max_seed = torch.max(batch_cos,dim=1) |
|
|
|
|
|
batch_i = torch.nonzero(max_cos > cos_threshold) |
|
|
|
if len(batch_i): |
|
|
|
|
|
batch_i = batch_i[~is_seed[batch_slice][batch_i]] |
|
|
|
if len(batch_i): |
|
|
|
i = batch_i + batch_start |
|
|
|
|
|
group_ids[i] = g_seed[max_seed[batch_i]] |
|
|
|
return self._ids_to_group(group_ids) |
|
|
|
@torch.no_grad() |
|
def score_pairs(self,string_pairs,batch_size=64,progress_bar=True): |
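        """
        Predict the match score for each pair of strings in string_pairs.
        Returns a numpy array with one score per pair.
        """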
|
string_pairs = np.array(string_pairs) |
|
|
|
scores = [] |
|
for batch_start in tqdm(range(0,string_pairs.shape[0],batch_size),desc='Scoring pairs',disable=not progress_bar): |
|
|
|
V0 = self[string_pairs[batch_start:batch_start+batch_size,0]].V |
|
V1 = self[string_pairs[batch_start:batch_start+batch_size,1]].V |
|
|
|
batch_cos = (V0*V1).sum(dim=1).ravel() |
|
batch_scores = self.score_model(batch_cos) |
|
|
|
scores.append(batch_scores.cpu().numpy()) |
|
|
|
return np.concatenate(scores) |
|
|
|
@torch.no_grad() |
|
def _batch_scores(self,group_ids,batch_start,batch_size, |
|
is_match=None, |
|
min_score=None,max_score=None, |
|
min_loss=None,max_loss=None): |
|
|
|
strings = self.strings |
|
V = self.V |
|
w = self.w |
|
|
|
|
|
i_slice = slice(batch_start,batch_start+batch_size) |
|
j_slice = slice(batch_start+1,None) |
|
|
|
X = V[i_slice]@V[j_slice].T |
|
Y = (group_ids[i_slice,None] == group_ids[None,j_slice]).float() |
|
if w is not None: |
|
W = w[i_slice,None]*w[None,j_slice] |
|
else: |
|
W = None |
|
|
|
scores = self.score_model(X) |
|
loss = self.score_model.loss(X,Y,weights=W) |
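
        # Keep the upper triangle only; the filters below zero out pairs that should
        # not be reported, and the remaining nonzero entries are returned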
|
|
|
|
|
|
|
scores = torch.triu(scores) |
|
|
|
|
|
if is_match is not None: |
|
if is_match: |
|
scores *= Y |
|
else: |
|
scores *= (1 - Y) |
|
|
|
|
|
if min_score is not None: |
|
scores *= (scores >= min_score) |
|
|
|
|
|
if max_score is not None: |
|
scores *= (scores <= max_score) |
|
|
|
|
|
if min_loss is not None: |
|
scores *= (loss >= min_loss) |
|
|
|
|
|
if max_loss is not None: |
|
scores *= (loss <= max_loss) |
|
|
|
|
|
i,j = torch.nonzero(scores,as_tuple=True) |
|
|
|
pairs = np.hstack([ |
|
strings[i.cpu().numpy() + batch_start][:,None], |
|
strings[j.cpu().numpy() + (batch_start + 1)][:,None] |
|
]) |
|
|
|
pair_groups = np.hstack([ |
|
strings[group_ids[i + batch_start].cpu().numpy()][:,None], |
|
strings[group_ids[j + (batch_start + 1)].cpu().numpy()][:,None] |
|
]) |
|
|
|
pair_scores = scores[i,j].cpu().numpy() |
|
pair_losses = loss[i,j].cpu().numpy() |
|
|
|
return pairs,pair_groups,pair_scores,pair_losses |
|
|
|
def iter_scores(self,grouping=None,batch_size=64,progress_bar=True,**kwargs): |
|
|
|
if grouping is not None: |
|
self = self.embed(grouping) |
|
group_ids = self._group_to_ids(grouping) |
|
else: |
|
group_ids = torch.arange(len(self)).to(self.device) |
|
|
|
for batch_start in tqdm(range(0,len(self),batch_size),desc='Scoring pairs',disable=not progress_bar): |
|
            pairs,pair_groups,scores,losses = self._batch_scores(group_ids,batch_start,batch_size,**kwargs)
|
for (s0,s1),(g0,g1),score,loss in zip(pairs,pair_groups,scores,losses): |
|
yield { |
|
'string0':s0, |
|
'string1':s1, |
|
'group0':g0, |
|
'group1':g1, |
|
'score':score, |
|
'loss':loss, |
|
} |
|
|
|
|
|
def load_embeddings(f): |
|
""" |
|
Load embeddings from custom zipped archive format |
|
""" |
|
with ZipFile(f,'r') as zip: |
|
score_model = pickle.loads(zip.read('score_model.pkl')) |
|
weighting_function = pickle.loads(zip.read('weighting_function.pkl')) |
|
strings_df = pd.read_csv(zip.open('strings.csv'),na_filter=False) |
|
V = np.load(zip.open('V.npy')) |
|
|
|
return Embeddings( |
|
strings=strings_df['string'].values, |
|
counts=torch.tensor(strings_df['count'].values), |
|
score_model=score_model, |
|
weighting_function=weighting_function, |
|
V=torch.tensor(V) |
|
) |
|
|
|
|