|
|
|
|
|
|
|
import torch |
|
from scipy.optimize import linear_sum_assignment |
|
|
|
|
|
@torch.no_grad() |
|
def match_label(target, matching, shape, device, compute_mask=True): |
|
idx = _get_src_permutation_idx(matching) |
|
|
|
target_classes = torch.zeros(shape, dtype=torch.long, device=device) |
|
target_classes[idx] = torch.cat([t[J] for t, (_, J) in zip(target, matching)]) |
|
|
|
return target_classes |
|
|
|
|
|
@torch.no_grad() |
|
def match_anchor(anchor, matching, shape, device): |
|
target, _ = anchor |
|
|
|
idx = _get_src_permutation_idx(matching) |
|
target_classes = torch.zeros(shape, dtype=torch.long, device=device) |
|
target_classes[idx] = torch.cat([t[J, :] for t, (_, J) in zip(target, matching)]) |
|
|
|
matched_mask = torch.ones(shape[:2], dtype=torch.bool, device=device) |
|
matched_mask[idx] = False |
|
|
|
return target_classes, matched_mask |
|
|
|
|
|
def _get_src_permutation_idx(indices): |
|
|
|
batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) |
|
src_idx = torch.cat([src for (src, _) in indices]) |
|
return batch_idx, src_idx |
|
|
|
|
|
@torch.no_grad() |
|
def get_matching(cost_matrices): |
|
output = [] |
|
for cost_matrix in cost_matrices: |
|
indices = linear_sum_assignment(cost_matrix, maximize=True) |
|
indices = (torch.tensor(indices[0], dtype=torch.long), torch.tensor(indices[1], dtype=torch.long)) |
|
output.append(indices) |
|
|
|
return output |
|
|
|
|
|
def sort_by_target(matchings): |
|
new_matching = [] |
|
for matching in matchings: |
|
source, target = matching |
|
target, indices = target.sort() |
|
source = source[indices] |
|
new_matching.append((source, target)) |
|
return new_matching |
|
|
|
|
|
def reorder(hidden, matchings, max_length): |
|
batch_size, _, hidden_dim = hidden.shape |
|
matchings = sort_by_target(matchings) |
|
|
|
result = torch.zeros(batch_size, max_length, hidden_dim, device=hidden.device) |
|
for b in range(batch_size): |
|
indices = matchings[b][0] |
|
result[b, : len(indices), :] = hidden[b, indices, :] |
|
|
|
return result |
|
|