# R2-Tuning: models/generator.py
# Page metadata from the source listing: author yeliudev, commit bc120ce ("Add files"), file size 1.99 kB.
# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.
import torch
import torch.nn as nn
class BufferList(nn.Module):
    """Ordered container that registers a sequence of tensors as buffers.

    Every tensor is registered under its position in ``buffers`` ("0", "1",
    ...) with ``persistent=False``. It is therefore moved together with the
    parent module (``.to()``, ``.cuda()``) but excluded from ``state_dict()``.

    Args:
        buffers (Iterable[Tensor]): Tensors to register, in order.
    """

    def __init__(self, buffers):
        super().__init__()
        for idx, tensor in enumerate(buffers):
            self.register_buffer(f'{idx}', tensor, persistent=False)

    def __len__(self):
        """Return how many buffers are registered."""
        registered = self._buffers
        return len(registered)

    def __iter__(self):
        """Iterate over the registered tensors in registration order."""
        registered = self._buffers
        return iter(registered.values())
class PointGenerator(nn.Module):
    """Cache and serve per-level point tables for a feature pyramid.

    At construction, one table per pyramid level is built and stored as
    non-persistent buffers (through ``BufferList``). Each table has
    ``len(torch.arange(0, buffer_size, stride))`` rows and four columns::

        [point, range_low, range_high, stride]

    - ``point`` is ``k * stride``, plus ``0.5 * stride`` when ``offset`` is
      set.
    - ``(range_low, range_high)`` is the interval assigned to that level
      (see ``reg_range`` below).
    - ``stride`` is the level's stride.

    :meth:`forward` slices each table to the length of the matching pyramid
    level and concatenates the results.

    Args:
        strides (Sequence[int | float]): Stride of each pyramid level.
            NOTE(review): presumably increasing; this is not checked.
        buffer_size (int | float): Exclusive upper bound of the cached point
            coordinates. It bounds the longest level ``forward`` can serve.
        offset (bool): If True, shift every point by half a stride.
    """

    def __init__(self, strides, buffer_size, offset=False):
        super(PointGenerator, self).__init__()

        # Level 0 gets (0, strides[1]); level i >= 1 gets
        # (strides[i], strides[i + 1]); the last level is unbounded above.
        reg_range, last = [], 0
        for stride in strides[1:]:
            reg_range.append((last, stride))
            last = stride
        reg_range.append((last, float('inf')))

        self.strides = strides
        self.reg_range = reg_range
        self.buffer_size = buffer_size
        self.offset = offset
        self.buffer = self._cache_points()

    def _cache_points(self):
        """Build the ``(num_points, 4)`` point table of every level."""
        buffer_list = []
        for stride, reg_range in zip(self.strides, self.reg_range):
            reg_range = torch.Tensor([reg_range])
            lv_stride = torch.Tensor([stride])
            # Build coordinates directly in the float dtype of the other
            # columns. With integer arguments torch.arange yields int64, and
            # the in-place half-stride shift below would fail to cast the
            # float result back to int64 (offset=True used to crash).
            points = torch.arange(
                0, self.buffer_size, stride, dtype=reg_range.dtype)[:, None]
            if self.offset:
                points += 0.5 * stride
            reg_range = reg_range.repeat(points.size(0), 1)
            lv_stride = lv_stride.repeat(points.size(0), 1)
            buffer_list.append(torch.cat((points, reg_range, lv_stride), dim=1))
        return BufferList(buffer_list)

    def forward(self, pymid):
        """Return the cached points that match the given pyramid.

        Args:
            pymid (Sequence[Tensor]): One tensor per pyramid level, in the
                same order as ``strides``. Only ``p.size(1)`` is read from
                each. NOTE(review): presumably the temporal length of a
                ``(batch, length, channels)`` tensor; confirm against callers.
                Fewer levels than ``strides`` is allowed; the missing trailing
                levels contribute no points.

        Returns:
            Tensor: A ``(total_points, 4)`` tensor holding the first
            ``p.size(1)`` rows of each level's table, concatenated in level
            order. If no level has any points, the result is an empty
            ``(0, 4)`` tensor.

        Raises:
            AssertionError: If ``pymid`` has more levels than ``strides``, or
                if a level is longer than its cached table.
        """
        # zip() would otherwise silently drop the extra levels.
        assert len(pymid) <= len(self.buffer), \
            'more pyramid levels than strides'

        points = []
        for feat, buffer in zip(pymid, self.buffer):
            size = feat.size(1)
            if size == 0:
                continue
            assert size <= buffer.size(0), 'reached max buffer size'
            points.append(buffer[:size, :])

        if not points:
            # torch.cat rejects an empty list. Return an empty table instead,
            # matching the buffers' dtype/device when a buffer exists.
            ref = next(iter(self.buffer), None)
            if ref is None:
                return torch.zeros((0, 4))
            return ref.new_zeros((0, 4))

        return torch.cat(points)