ssa-perin / utility /hungarian_matching.py
larkkin's picture
Add code and readme
c45d283
raw
history blame
No virus
2.08 kB
#!/usr/bin/env python3
# coding=utf-8
import torch
from scipy.optimize import linear_sum_assignment
@torch.no_grad()
def match_label(target, matching, shape, device, compute_mask=True):
idx = _get_src_permutation_idx(matching)
target_classes = torch.zeros(shape, dtype=torch.long, device=device)
target_classes[idx] = torch.cat([t[J] for t, (_, J) in zip(target, matching)])
return target_classes
@torch.no_grad()
def match_anchor(anchor, matching, shape, device):
target, _ = anchor
idx = _get_src_permutation_idx(matching)
target_classes = torch.zeros(shape, dtype=torch.long, device=device)
target_classes[idx] = torch.cat([t[J, :] for t, (_, J) in zip(target, matching)])
matched_mask = torch.ones(shape[:2], dtype=torch.bool, device=device)
matched_mask[idx] = False
return target_classes, matched_mask
def _get_src_permutation_idx(indices):
# permute predictions following indices
batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
src_idx = torch.cat([src for (src, _) in indices])
return batch_idx, src_idx
@torch.no_grad()
def get_matching(cost_matrices):
output = []
for cost_matrix in cost_matrices:
indices = linear_sum_assignment(cost_matrix, maximize=True)
indices = (torch.tensor(indices[0], dtype=torch.long), torch.tensor(indices[1], dtype=torch.long))
output.append(indices)
return output
def sort_by_target(matchings):
new_matching = []
for matching in matchings:
source, target = matching
target, indices = target.sort()
source = source[indices]
new_matching.append((source, target))
return new_matching
def reorder(hidden, matchings, max_length):
batch_size, _, hidden_dim = hidden.shape
matchings = sort_by_target(matchings)
result = torch.zeros(batch_size, max_length, hidden_dim, device=hidden.device)
for b in range(batch_size):
indices = matchings[b][0]
result[b, : len(indices), :] = hidden[b, indices, :]
return result