| import numpy as np |
| import pdb |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| |
|
|
class OnlineBuffer(nn.Module):
    """Fixed-capacity replay buffer filled with a reservoir-style policy.

    Three tensors are registered as module buffers:

    * ``buf_data``    -- float32, shape ``(buffer_size, *input_size)``
    * ``buf_targets`` -- int64, shape ``(buffer_size,)``
    * ``buf_tasks``   -- int64, shape ``(buffer_size,)``

    Slots ``[0, current_index)`` hold real samples; the remaining slots are
    still zero-initialised placeholders.
    """

    def __init__(self, buffer_size, batch_size, input_size):
        """Allocate the zero-filled storage.

        Args:
            buffer_size: number of slots in the buffer.
            batch_size: accepted for interface compatibility; not used here.
            input_size: iterable with the shape of a single stored sample.
        """
        super().__init__()

        self.place_left = True
        self.strategy = None
        self.buffer_size = buffer_size
        print('buffer has %d slots' % buffer_size)

        buf_data = torch.zeros(buffer_size, *input_size, dtype=torch.float32)
        buf_targets = torch.zeros(buffer_size, dtype=torch.long)
        buf_tasks = torch.zeros(buffer_size, dtype=torch.long)

        self.current_index = 0   # next free slot while the buffer is filling
        self.n_seen_so_far = 0   # samples offered to the buffer so far
        self.is_full = 0
        self.total_classes = 0
        # No device is pinned until tensor_to_device()/add_reservoir() is
        # called; until then the buffers stay wherever they currently live.
        self.device = None

        self.register_buffer('buf_data', buf_data)
        self.register_buffer('buf_targets', buf_targets)
        self.register_buffer('buf_tasks', buf_tasks)

    def tensor_to_device(self, device):
        """Remember ``device`` and move the three storage tensors onto it."""
        self.device = device
        # Tensor.to() returns a (possibly new) tensor instead of moving in
        # place, so the result has to be assigned back. Assigning to a
        # registered buffer name keeps it registered on the module.
        self.buf_data = self.buf_data.to(device)
        self.buf_targets = self.buf_targets.to(device)
        self.buf_tasks = self.buf_tasks.to(device)

    def add_reservoir(self, x, y, task):
        """Offer a batch to the buffer.

        Free slots are filled first, in order. Whatever does not fit then
        competes for existing slots: each leftover sample draws a slot index
        uniformly from ``[0, n_seen_so_far)`` and overwrites that slot only
        if the index falls inside the buffer.

        Args:
            x: batch of samples, first dimension is the batch dimension.
            y: labels for ``x``; must have at least ``x.size(0)`` entries.
            task: task id stored alongside every sample of this batch.

        Raises:
            ValueError: if ``y`` has fewer entries than ``x``.
        """
        n_elem = x.size(0)
        if y.size(0) < n_elem:
            raise ValueError(
                'y has %d entries but x has %d' % (y.size(0), n_elem)
            )

        # Stored samples must not drag autograd history into the buffer.
        x = x.detach()

        # Keep storage on the same device as the incoming data.
        self.tensor_to_device(x.device)

        place_left = max(0, self.buffer_size - self.current_index)
        if place_left:
            offset = min(place_left, n_elem)
            start, stop = self.current_index, self.current_index + offset

            self.buf_data[start:stop].copy_(x[:offset])
            self.buf_targets[start:stop].copy_(y[:offset])
            self.buf_tasks[start:stop].fill_(task)
            self.current_index += offset
            self.n_seen_so_far += offset

            # The whole batch fit into free slots.
            if offset == n_elem:
                return

        self.place_left = False

        # Drop the part of the batch that was already stored above.
        x, y = x[place_left:], y[place_left:]
        n_rest = x.size(0)

        # One candidate slot per leftover sample. The range is the same for
        # the whole batch (n_seen_so_far is only advanced afterwards).
        indices = (
            torch.empty(n_rest, dtype=torch.float32, device=x.device)
            .uniform_(0, self.n_seen_so_far)
            .long()
        )
        idx_new_data = (indices < self.buf_data.size(0)).nonzero().squeeze(-1)
        idx_buffer = indices[idx_new_data]

        self.n_seen_so_far += n_rest

        if idx_buffer.numel() == 0:
            return

        self.buf_data[idx_buffer] = x[idx_new_data]
        self.buf_targets[idx_buffer] = y[idx_new_data]
        self.buf_tasks[idx_buffer] = task

    def sample(self, amount, exclude_task=None, ret_ind=False):
        """Draw up to ``amount`` stored samples without replacement.

        Only slots that have actually been written are considered. If fewer
        than ``amount`` candidates exist, all of them are returned in
        storage order.

        Args:
            amount: number of samples requested.
            exclude_task: if not ``None``, samples with this task id are
                left out.
            ret_ind: if true, also return the indices that were drawn. They
                index into the candidate set (after task filtering), not
                into the raw buffer.

        Returns:
            ``(data, targets, tasks)`` or, with ``ret_ind``,
            ``(data, targets, tasks, indices)``.
        """
        if self.device is not None:
            self.tensor_to_device(self.device)

        filled = self.current_index
        if exclude_task is not None:
            # Restrict the mask to written slots so zero-initialised
            # placeholders are never returned. squeeze(-1) keeps a 1-D index
            # even when exactly one slot matches.
            keep = (self.buf_tasks[:filled] != exclude_task).nonzero().squeeze(-1)
            bx = self.buf_data[keep]
            by = self.buf_targets[keep]
            bt = self.buf_tasks[keep]
        else:
            bx = self.buf_data[:filled]
            by = self.buf_targets[:filled]
            bt = self.buf_tasks[:filled]

        if bx.size(0) < amount:
            if ret_ind:
                return bx, by, bt, torch.arange(bx.size(0), device=bx.device)
            return bx, by, bt

        # NumPy's RNG is used for the draw, as in the original code.
        picked = np.random.choice(bx.size(0), amount, replace=False)
        indices = torch.from_numpy(picked).long().to(bx.device)

        if ret_ind:
            return bx[indices], by[indices], bt[indices], indices
        return bx[indices], by[indices], bt[indices]