File size: 8,306 Bytes
4d1ebf3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
import torch
from typing import List

class KeyValueMemoryStore:
    """
    Key/value storage for memory banks, e.g. working and long-term memory.

    Keys, shrinkage, selection and the usage statistics are each kept in a
    single tensor whose dim 2 (the last axis) indexes memory elements; they
    are shared by all objects. Values are kept in ``self.v``, a list indexed
    by object group. ``self.v[gi]`` is indexed along dim 0 by object and is
    grown along its last axis.

    An object group is created when new objects enter the video.
    Objects in the same group share the same temporal extent,
    i.e., objects initialized in the same frame are in the same group.
    For DAVIS/interactive, there is only one object group.
    For YouTubeVOS, there can be multiple object groups.
    """

    def __init__(self, count_usage: bool):
        """Create an empty store.

        Args:
            count_usage: if True, track per-element use/life counts so that
                ``get_usage`` and ``remove_obsolete_features`` can be used.
        """
        self.count_usage = count_usage

        # keys are stored in a single tensor and are shared between groups/objects
        # values are stored as a list indexed by object groups
        self.k = None
        self.v = []
        self.obj_groups = []
        # for debugging only
        self.all_objects = []

        # shrinkage and selection are also single tensors
        self.s = self.e = None

        # usage statistics (only present when counting usage)
        if self.count_usage:
            self.use_count = self.life_count = None

    def add(self, key, value, shrinkage, selection, objects: List[int]):
        """Append new memory elements to the store.

        Args:
            key: key tensor; ``key.shape[2]`` is the number of new elements
                (assumed layout (B, C, N) — TODO confirm against callers).
            value: when ``objects`` is given (working memory), a tensor indexed
                along dim 0 by object index. Otherwise (long-term memory), a
                list of per-group tensors in group order, where ``None``
                entries are skipped.
            shrinkage: tensor concatenated alongside the keys, or None.
            selection: tensor concatenated alongside the keys, or None.
            objects: object ids contained in ``value``, or None. Ids are
                shifted down by one because the background is not part of
                ``value``.

        Raises:
            ValueError: if an already-stored group contains an object that is
                missing from ``objects`` (raised by ``list.remove``).
            AssertionError: if new objects are not inserted in sorted order.
        """
        if self.count_usage:
            # New elements start unused. Life starts at a tiny epsilon so the
            # use/life division in get_usage never divides by zero.
            count_shape = (key.shape[0], 1, key.shape[2])
            new_count = torch.zeros(count_shape, device=key.device, dtype=torch.float32)
            new_life = torch.zeros(count_shape, device=key.device, dtype=torch.float32) + 1e-7

        # add the key
        if self.k is None:
            self.k = key
            self.s = shrinkage
            self.e = selection
            if self.count_usage:
                self.use_count = new_count
                self.life_count = new_life
        else:
            self.k = torch.cat([self.k, key], -1)
            if shrinkage is not None:
                self.s = torch.cat([self.s, shrinkage], -1)
            if selection is not None:
                self.e = torch.cat([self.e, selection], -1)
            if self.count_usage:
                self.use_count = torch.cat([self.use_count, new_count], -1)
                self.life_count = torch.cat([self.life_count, new_life], -1)

        # add the value
        if objects is not None:
            # When objects is given, v is a tensor; used in working memory
            assert isinstance(value, torch.Tensor)
            # First consume objects that are already in the memory bank.
            # A set cannot be used here because order must be preserved.
            # Shift by one as background is not part of value.
            remaining_objects = [obj - 1 for obj in objects]
            for gi, group in enumerate(self.obj_groups):
                for obj in group:
                    # should properly raise an error if there are overlaps in obj_groups
                    remaining_objects.remove(obj)
                self.v[gi] = torch.cat([self.v[gi], value[group]], -1)

            # If there are remaining objects, add them as a new group
            if len(remaining_objects) > 0:
                new_group = list(remaining_objects)
                self.v.append(value[new_group])
                self.obj_groups.append(new_group)
                self.all_objects.extend(new_group)

                assert sorted(self.all_objects) == self.all_objects, 'Objects MUST be inserted in sorted order '
        else:
            # When objects is not given, v is a list that already has the object groups sorted
            # used in long-term memory
            assert isinstance(value, list)
            for gi, gv in enumerate(value):
                if gv is None:
                    continue
                if gi < self.num_groups:
                    self.v[gi] = torch.cat([self.v[gi], gv], -1)
                else:
                    self.v.append(gv)

    def update_usage(self, usage):
        """Accumulate usage for every stored element (no-op when not counting usage).

        Args:
            usage: per-element usage increments. They are reshaped with
                ``view_as`` onto ``use_count``, so the element count must
                match.
        """
        if not self.count_usage:
            return

        # increase use of indexed elements; increase all life counts by 1
        self.use_count += usage.view_as(self.use_count)
        self.life_count += 1

    @staticmethod
    def _without_range(x, start: int, end: int):
        """Return ``x`` with elements ``[start:end]`` along dim 2 removed.

        ``end == 0`` means "keep only ``x[:, :, :start]``", because a
        negative-zero end index would not work.
        """
        if end == 0:
            # negative 0 would not work as the end index!
            return x[:, :, :start]
        return torch.cat([x[:, :, :start], x[:, :, end:]], -1)

    @staticmethod
    def _within_range(x, start: int, end: int):
        """Return ``x[:, :, start:end]``; ``end == 0`` means "to the end"."""
        if end == 0:
            # negative 0 would not work as the end index!
            return x[:, :, start:]
        return x[:, :, start:end]

    def sieve_by_range(self, start: int, end: int, min_size: int):
        """Keep only the elements *outside* ``[start:end]``.

        In effect this is ``concat(a[:start], a[end:])``; ``end == 0`` keeps
        ``a[:start]``.

        Args:
            start: first index of the removed range.
            end: end index of the removed range (typically negative), or 0.
            min_size: only value tensors whose last axis is at least this long
                are sieved. Smaller ones are not consolidated and are left
                untouched.
        """
        self.k = self._without_range(self.k, start, end)
        if self.count_usage:
            self.use_count = self._without_range(self.use_count, start, end)
            self.life_count = self._without_range(self.life_count, start, end)
        if self.s is not None:
            self.s = self._without_range(self.s, start, end)
        if self.e is not None:
            self.e = self._without_range(self.e, start, end)

        for gi in range(self.num_groups):
            if self.v[gi].shape[-1] >= min_size:
                self.v[gi] = self._without_range(self.v[gi], start, end)

    def remove_obsolete_features(self, max_size: int):
        """Evict the least-used elements so that exactly ``max_size`` remain.

        Usage is normalized by life duration (see ``get_usage``). If the store
        already holds ``max_size`` elements or fewer, nothing happens.
        Validation happens before any state is modified, so a raised error
        leaves the store intact.

        NOTE(review): usage is flattened before the eviction mask is built
        over the last axis, which assumes a batch dimension of 1. Confirm
        with callers.

        Raises:
            RuntimeError: if usage is not being counted.
            NotImplementedError: if more than one object group is stored.
        """
        # normalize with life duration
        usage = self.get_usage().flatten()

        num_remove = self.size - max_size
        if num_remove <= 0:
            # Nothing to evict. topk(k<=0) would otherwise fail.
            return

        if self.num_groups > 1:
            raise NotImplementedError("""The current data structure does not support feature removal with 
            multiple object groups (e.g., some objects start to appear later in the video)
            The indices for "survived" is based on keys but not all values are present for every key
            Basically we need to remap the indices for keys to values
            """)

        # Evict exactly the num_remove least-used elements. A threshold
        # comparison would also drop every element tied with the cut-off.
        _, remove_idx = torch.topk(usage, k=num_remove, largest=False, sorted=False)
        survived = torch.ones_like(usage, dtype=torch.bool)
        survived[remove_idx] = False

        self.k = self.k[:, :, survived]
        self.s = self.s[:, :, survived] if self.s is not None else None
        # Long-term memory does not store ek so this should not be needed
        self.e = self.e[:, :, survived] if self.e is not None else None
        for gi in range(self.num_groups):
            self.v[gi] = self.v[gi][:, :, survived]

        self.use_count = self.use_count[:, :, survived]
        self.life_count = self.life_count[:, :, survived]

    def get_usage(self):
        """Return usage normalized by life duration (``use_count / life_count``).

        Raises:
            RuntimeError: if the store was created with ``count_usage=False``.
        """
        if not self.count_usage:
            raise RuntimeError('I did not count usage!')
        return self.use_count / self.life_count

    def get_all_sliced(self, start: int, end: int):
        """Return ``(k, sk, ek, usage)`` sliced to ``[start:end]`` along dim 2.

        ``end == 0`` means "to the end". ``sk``/``ek`` are None when
        shrinkage/selection are not stored.

        Raises:
            RuntimeError: if usage is not being counted (via ``get_usage``).
        """
        k = self._within_range(self.k, start, end)
        sk = self._within_range(self.s, start, end) if self.s is not None else None
        ek = self._within_range(self.e, start, end) if self.e is not None else None
        usage = self._within_range(self.get_usage(), start, end)

        return k, sk, ek, usage

    def get_v_size(self, ni: int):
        """Return the number of stored value elements (``shape[2]``) of group ``ni``."""
        return self.v[ni].shape[2]

    def engaged(self):
        """Return True once at least one key has been added."""
        return self.k is not None

    @property
    def size(self):
        """Number of stored elements (last axis of the key tensor); 0 when empty."""
        if self.k is None:
            return 0
        return self.k.shape[-1]

    @property
    def num_groups(self):
        """Number of object groups that currently hold values."""
        return len(self.v)

    @property
    def key(self):
        return self.k

    @property
    def value(self):
        return self.v

    @property
    def shrinkage(self):
        return self.s

    @property
    def selection(self):
        return self.e