from typing import Optional, Tuple

import numpy as np
import torch
from scipy.optimize import linear_sum_assignment
from torch.nn.utils.rnn import pad_sequence


def num2mask(
        nums: torch.Tensor,
        max_length: Optional[int] = None
) -> torch.Tensor:
    """
    Convert lengths to a binary mask.
    E.g. input a tensor [2, 3, 4], return [[T T F F], [T T T F], [T T T T]]
    :param nums: Length tensor of shape [batch]; higher-rank inputs are flattened and the shape is restored.
    :param max_length: Maximum length. If not provided, the largest number in nums is used.
    :return: Boolean mask of shape [*nums.shape, max_length].
    """
    shape_backup = nums.shape
    nums = nums.flatten()
    max_length = max_length or int(nums.max())
    batch_size = len(nums)
    # Compare every position in [0, max_length) against each length: position < length -> True.
    range_nums = torch.arange(0, max_length, device=nums.device).unsqueeze(0).expand([batch_size, max_length])
    ret = range_nums < nums.unsqueeze(1)
    return ret.reshape(*shape_backup, max_length)
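
# Usage sketch (illustrative; example values assumed, not from the original module):
# >>> num2mask(torch.tensor([2, 3, 4]))
# tensor([[ True,  True, False, False],
#         [ True,  True,  True, False],
#         [ True,  True,  True,  True]])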


def mask2idx(
        mask: torch.Tensor,
        max_length: Optional[int] = None,
        padding_value: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Convert a binary mask to the indices of its True positions.
    E.g. input a tensor [[T T F F], [T T T F], [F F F T]] with padding value -1,
    return [[0, 1, -1], [0, 1, 2], [3, -1, -1]]
    :param mask: Boolean mask tensor. The last dimension is the mask dimension;
        any leading dimensions are treated as batch dimensions.
    :param max_length: If provided, the index tensor is truncated to this length.
    :param padding_value: Padding value. Defaults to 0.
    :return: Tuple of (index tensor, number of True entries along the last dimension).
    """
    shape_prefix, mask_length = mask.shape[:-1], mask.shape[-1]
    flat_mask = mask.flatten(0, -2)
    # Collect the positions of True entries row by row, then pad to a rectangular tensor.
    index_list = [torch.arange(mask_length, device=mask.device)[one_mask] for one_mask in flat_mask.unbind(0)]
    index_tensor = pad_sequence(index_list, batch_first=True, padding_value=padding_value)
    if max_length is not None:
        index_tensor = index_tensor[:, :max_length]
    index_tensor = index_tensor.reshape(*shape_prefix, -1)
    return index_tensor, mask.sum(-1)
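
# Usage sketch (illustrative; example values assumed, matching the docstring example):
# >>> m = torch.tensor([[True, True, False, False],
# ...                   [True, True, True, False],
# ...                   [False, False, False, True]])
# >>> mask2idx(m, padding_value=-1)
# (tensor([[ 0,  1, -1],
#          [ 0,  1,  2],
#          [ 3, -1, -1]]), tensor([2, 3, 1]))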


def one_hot(tags: torch.Tensor, num_tags: Optional[int] = None) -> torch.Tensor:
    """
    Convert an integer tag tensor of shape [...] into a boolean one-hot tensor of shape [..., num_tags].
    """
    # Tags are 0-indexed, so the class count is max tag + 1 when not given explicitly.
    num_tags = num_tags or int(tags.max()) + 1
    ret = tags.new_zeros(size=[*tags.shape, num_tags], dtype=torch.bool)
    # Scatter along the last dimension so inputs of any rank are supported.
    ret.scatter_(-1, tags.unsqueeze(-1), tags.new_ones([*tags.shape, 1], dtype=torch.bool))
    return ret
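
# Usage sketch (illustrative; example values assumed):
# >>> one_hot(torch.tensor([[0, 2], [1, 1]])).shape
# torch.Size([2, 2, 3])
# The entry at [i, j, tags[i, j]] is True; all other entries are False.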


def numpy2torch(
        dict_obj: dict
) -> dict:
    """
    Convert list/np.ndarray values to torch.Tensor and add a batch dim; leave other values untouched.
    """
    ret = dict()
    for k, v in dict_obj.items():
        if isinstance(v, (list, np.ndarray)):
            ret[k] = torch.tensor(v).unsqueeze(0)
        else:
            ret[k] = v
    return ret
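
# Usage sketch (illustrative; keys and values assumed):
# >>> numpy2torch({"ids": np.array([1, 2, 3]), "name": "sample"})
# {'ids': tensor([[1, 2, 3]]), 'name': 'sample'}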


def max_match(mat: np.ndarray):
    """
    Maximum-weight bipartite matching: solve the linear sum assignment problem on mat
    (maximizing) and return the total weight of the matched entries.
    """
    row_idx, col_idx = linear_sum_assignment(mat, maximize=True)
    return mat[row_idx, col_idx].sum()
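
# Usage sketch (illustrative; matrix values assumed):
# >>> max_match(np.array([[1, 5], [2, 4]]))
# 7
# The optimal assignment pairs row 0 with column 1 (weight 5) and row 1 with column 0 (weight 2).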