File size: 3,542 Bytes
d4b77ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
# --------------------------------------------------------
# SiamMask
# Licensed under The MIT License
# Written by Qiang Wang (wangqiang2015 at ia.ac.cn)
# --------------------------------------------------------
import numpy as np


class Meter(object):
    """Immutable-by-convention record of one metric: name, value, average.

    The textual form is ``"<name>: <val> (<avg>)"`` with both numbers
    rendered to six decimal places.
    """

    def __init__(self, name, val, avg):
        self.name, self.val, self.avg = name, val, avg

    def __repr__(self):
        fields = dict(name=self.name, val=self.val, avg=self.avg)
        return "{name}: {val:.6f} ({avg:.6f})".format(**fields)

    def __format__(self, *tuples, **kwargs):
        # Any format spec is ignored; formatting always yields the repr text.
        return self.__repr__()


class AverageMeter(object):
    """Track the latest value and the running average of named metrics.

    Metrics are created on first use by passing them as keyword arguments
    to ``update``. Reading a metric as an attribute (``meter.loss``)
    returns a ``Meter`` carrying its name, latest value and running average.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Drop every metric recorded so far."""
        self.val = {}    # metric name -> latest value divided by its batch
        self.sum = {}    # metric name -> running total of the raw values
        self.count = {}  # metric name -> running total of the batch sizes

    def update(self, batch=1, **kwargs):
        """Record one observation for each keyword argument.

        Args:
            batch: divisor applied to each value to obtain the latest
                ``val``, and the amount added to the metric's count.
            **kwargs: metric name -> raw value to accumulate.
        """
        scale = float(batch)
        # Compute all latest values first so a failing division leaves
        # the stored state untouched.
        latest = dict((key, raw / scale) for key, raw in kwargs.items())
        self.val.update(latest)
        for key, raw in kwargs.items():
            if key not in self.sum:
                self.sum[key] = 0
                self.count[key] = 0
            self.sum[key] += raw
            self.count[key] += batch

    def __repr__(self):
        return ''.join(self.format_str(key) for key in self.sum)

    def format_str(self, attr):
        """Return ``"<attr>: <val> (<avg>) "`` (note the trailing space)."""
        return "{name}: {val:.6f} ({avg:.6f}) ".format(
                    name=attr,
                    val=float(self.val[attr]),
                    avg=self.avg(attr))

    def __getattr__(self, attr):
        """Expose recorded metrics as attributes.

        Only invoked when normal attribute lookup fails. Unknown ordinary
        names print a warning and yield a zeroed ``Meter``.

        Raises:
            AttributeError: for dunder names, or when the instance has no
                ``sum`` dict yet (``__init__`` has not run).
        """
        # Look in __dict__ directly: touching self.sum here would re-enter
        # __getattr__ forever on an instance created without __init__
        # (copy.copy and unpickling both probe such instances).
        totals = self.__dict__.get('sum')
        is_dunder = attr.startswith('__') and attr.endswith('__')
        if totals is None or is_dunder:
            # Protocol probes (hasattr(obj, '__setstate__'), '__array__',
            # ...) must see a real miss, not a fake metric.
            raise AttributeError(
                "'{}' object has no attribute '{}'".format(
                    type(self).__name__, attr))
        if attr not in totals:
            # logger.warn("invalid key '{}'".format(attr))
            print("invalid key '{}'".format(attr))
            return Meter(attr, 0, 0)
        return Meter(attr, self.val[attr], self.avg(attr))

    def avg(self, attr):
        """Return the running average of metric ``attr`` (sum / count)."""
        return float(self.sum[attr]) / self.count[attr]


class IouMeter(object):
    """Accumulate per-sample IoU scores at several binarisation thresholds.

    Scores live in ``self.iou``, a float32 array with one row per sample
    (at most ``sz`` rows) and one column per entry of ``thrs``.
    """

    def __init__(self, thrs, sz):
        """
        Args:
            thrs: sequence of thresholds applied to ``output`` in ``add``.
            sz: maximum number of samples that can be stored.
        """
        self.sz = sz
        self.iou = np.zeros((sz, len(thrs)), dtype=np.float32)
        self.thrs = thrs
        self.reset()

    def reset(self):
        """Zero the score table and restart the sample counter."""
        self.iou.fill(0.)
        self.n = 0

    def add(self, output, target):
        """Score one prediction against its target and store the result.

        ``output`` is binarised with ``output > thr`` for every threshold;
        ``target`` is binarised with ``target > 0``. Both are squeezed
        first. Samples beyond the table capacity are silently ignored.
        """
        if self.n >= len(self.iou):
            return
        target, output = target.squeeze(), output.squeeze()
        truth = target > 0  # does not depend on the threshold
        for i, thr in enumerate(self.thrs):
            pred = output > thr
            intxn = np.sum(np.logical_and(pred, truth))
            union = np.sum(np.logical_or(pred, truth))
            if union > 0:
                # Explicit floats: plain `/` on integer counts truncates
                # under Python 2 division rules.
                self.iou[self.n, i] = float(intxn) / float(union)
            else:
                # Both masks empty: counted as a perfect match.
                self.iou[self.n, i] = 1
        self.n += 1

    def value(self, s):
        """Summarise the stored scores, one result per threshold.

        Args:
            s: ``'mean'``, ``'median'``, or a number (or numeric string)
                ``t``, which gives the fraction of samples with IoU > ``t``.

        Returns:
            Array with ``len(thrs)`` entries. With no samples added, the
            statistic is taken over a single all-zero row.

        Raises:
            ValueError: if ``s`` is none of the supported selectors.
        """
        # Use the number of samples actually added. Counting positive
        # matrix entries mixes up rows and columns and drops zero-IoU
        # samples.
        nb = max(self.n, 1)
        iou = self.iou[:nb]

        if s == 'mean':
            return np.mean(iou, axis=0)
        if s == 'median':
            return np.median(iou, axis=0)
        try:
            cutoff = float(s)
        except (TypeError, ValueError):
            raise ValueError(
                "unsupported statistic {!r}: expected 'mean', 'median' "
                "or a number".format(s))
        return np.sum(iou > cutoff, axis=0) / float(nb)


if __name__ == '__main__':
    # Manual smoke run: record two observations, then print the whole
    # meter, one metric, that metric's fields, and an unknown key.
    meter = AverageMeter()
    for observation in (dict(time=1.1, accuracy=.99),
                        dict(time=1.0, accuracy=.90)):
        meter.update(**observation)

    print(meter)

    print(meter.time)
    print(meter.time.avg)
    print(meter.time.val)
    # 'SS' was never recorded, so this exercises the invalid-key path.
    print(meter.SS)