|
|
|
|
|
import torch |
|
import torch.nn as nn |
|
|
|
|
|
class BufferList(nn.Module):
    """List-like container that registers tensors as non-persistent buffers.

    Each tensor is registered under the string form of its position
    (``"0"``, ``"1"``, ...) with ``persistent=False``, so the tensors are
    tracked as module buffers but are left out of ``state_dict()``.

    Args:
        buffers: iterable of tensors to register, in order.
    """

    def __init__(self, buffers):
        super().__init__()
        for i, buffer in enumerate(buffers):
            # Buffer names must be strings; the position doubles as the key.
            self.register_buffer(str(i), buffer, persistent=False)

    def __len__(self):
        """Return the number of registered buffers."""
        return len(self._buffers)

    def __getitem__(self, idx):
        """Return the buffer at integer position ``idx``.

        Negative indices count from the end, as with a regular list.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        n = len(self)
        if not -n <= idx < n:
            raise IndexError('BufferList index out of range')
        # Normalise negative indices before mapping to the registered name.
        return self._buffers[str(idx % n)]

    def __iter__(self):
        """Iterate over the buffers in registration order."""
        return iter(self._buffers.values())
|
|
|
|
|
class PointGenerator(nn.Module):
    """Caches one table of 1-D points per stride and slices it on demand.

    For every stride a float tensor with rows
    ``(position, range_low, range_high, stride)`` is built once at
    construction time and stored in a ``BufferList``. ``forward`` then
    returns the leading rows of each table, concatenated over levels.

    Args:
        strides: sequence of strides, one per level.
        buffer_size: exclusive upper bound of the un-shifted positions
            generated for each level.
        offset: if true, every position is shifted by half a stride.
    """

    def __init__(self, strides, buffer_size, offset=False):
        super().__init__()

        # Level 0 gets the range (0, strides[1]); level i > 0 gets
        # (strides[i], strides[i + 1]); the final range is open-ended.
        inner = list(strides[1:])
        reg_range = list(zip([0] + inner, inner + [float('inf')]))

        self.strides = strides
        self.reg_range = reg_range
        self.buffer_size = buffer_size
        self.offset = offset

        self.buffer = self._cache_points()

    @staticmethod
    def _level_points(stride, reg_range, buffer_size, offset):
        """Build the ``(num_points, 4)`` float table for a single level."""
        # Create the grid as float up front: an integer arange cannot take
        # the in-place fractional shift below (PyTorch refuses to cast the
        # Float result back to Long), which made ``offset=True`` crash.
        points = torch.arange(
            0, buffer_size, stride, dtype=torch.float)[:, None]
        if offset:
            points = points + 0.5 * stride
        low, high = reg_range
        # The per-level constants are identical on every row.
        constants = torch.tensor([[low, high, stride]], dtype=torch.float)
        constants = constants.expand(points.size(0), -1)
        return torch.cat((points, constants), dim=1)

    def _cache_points(self):
        """Pre-compute the point table of every level as a ``BufferList``."""
        buffer_list = [
            self._level_points(
                stride, reg_range, self.buffer_size, self.offset)
            for stride, reg_range in zip(self.strides, self.reg_range)
        ]
        return BufferList(buffer_list)

    def forward(self, pymid):
        """Return the cached points matching the given per-level inputs.

        Args:
            pymid: sequence of tensors, one per level starting at level 0;
                ``size(1)`` of each tensor is the number of points taken
                from that level. Levels missing at the end, or with size
                0, contribute no points.

        Returns:
            Tensor with the selected rows of all levels concatenated along
            dimension 0.

        Raises:
            AssertionError: if a level requests more points than cached.
        """
        # Pad with zeros so that trailing levels without input are skipped.
        sizes = [p.size(1) for p in pymid]
        sizes += [0] * (len(self.buffer) - len(pymid))

        points = []
        for size, buffer in zip(sizes, self.buffer):
            if size == 0:
                continue
            assert size <= buffer.size(0), 'reached max buffer size'
            points.append(buffer[:size, :])
        return torch.cat(points)
|
|