|
import random |
|
import time |
|
import datetime |
|
import sys |
|
|
|
from torch.autograd import Variable |
|
import torch |
|
import numpy as np |
|
|
|
from torchvision.utils import save_image |
|
|
|
|
|
class ReplayBuffer:
    """Bounded pool of previously seen tensors with random replacement.

    Each element pushed through :meth:`push_and_pop` is either handed
    straight back or swapped against a randomly chosen element already held
    in the pool.
    NOTE(review): presumably used to feed a GAN discriminator a history of
    generated images -- confirm against the caller.
    """

    def __init__(self, max_size: int = 50) -> None:
        """Create an empty buffer that holds at most ``max_size`` elements.

        Raises:
            AssertionError: if ``max_size`` is not strictly positive.
        """
        assert max_size > 0, "Empty buffer or trying to create a black hole. Be careful."
        self.max_size = max_size
        # Stored elements; each one keeps a leading dimension of size 1.
        self.data = []

    def push_and_pop(self, data):
        """Push every element of ``data`` and return a batch of equal length.

        ``data.data`` is iterated along its first dimension. For each element:

        * while the buffer holds fewer than ``max_size`` items, the element
          is stored and also returned unchanged;
        * once the buffer is full, a random draw above 0.5 returns a clone of
          a randomly chosen stored element and overwrites that slot with the
          new element; otherwise the new element is returned and the buffer
          is left untouched.

        Args:
            data: object exposing a ``.data`` tensor that can be iterated
                along its first dimension.

        Returns:
            The selected elements concatenated along dimension 0, wrapped in
            ``Variable``.
        """
        to_return = []
        for element in data.data:
            # Restore a leading dimension of size 1 so the pieces can be
            # concatenated back together with torch.cat.
            element = torch.unsqueeze(element, 0)
            if len(self.data) < self.max_size:
                # Buffer still filling up: keep the element and pass it through.
                self.data.append(element)
                to_return.append(element)
            elif random.uniform(0, 1) > 0.5:
                # Swap: hand back an older element and store the new one in
                # its place.
                i = random.randint(0, self.max_size - 1)
                to_return.append(self.data[i].clone())
                self.data[i] = element
            else:
                # Pass the new element through without storing it.
                to_return.append(element)
        return Variable(torch.cat(to_return))
|
|
|
|
|
class LambdaLR:
    """Linear decay factor that stays at 1.0 until a given epoch.

    NOTE(review): presumably ``step`` is meant to be passed as ``lr_lambda``
    to ``torch.optim.lr_scheduler.LambdaLR`` -- confirm against the caller.
    """

    def __init__(self, n_epochs, offset, decay_start_epoch):
        """Store the schedule parameters.

        Args:
            n_epochs: epoch count at which the factor reaches 0.
            offset: value added to the epoch passed to :meth:`step`.
            decay_start_epoch: epoch at which the factor starts to decrease.

        Raises:
            AssertionError: if ``decay_start_epoch`` is not smaller than
                ``n_epochs``.
        """
        assert (n_epochs - decay_start_epoch) > 0, "Decay must start before the training session ends!"
        self.n_epochs = n_epochs
        self.offset = offset
        self.decay_start_epoch = decay_start_epoch

    def step(self, epoch):
        """Return the multiplicative factor for ``epoch``.

        The factor is 1.0 while ``epoch + offset`` is at or below
        ``decay_start_epoch`` and then falls linearly, reaching 0.0 when
        ``epoch + offset`` equals ``n_epochs``. It is not clamped below 0,
        so later epochs yield a negative value.
        """
        # Number of epochs spent in the decay phase so far (never negative).
        elapsed = max(0, epoch + self.offset - self.decay_start_epoch)
        # Length of the decay phase; positive by the constructor's assertion.
        decay_span = self.n_epochs - self.decay_start_epoch
        return 1.0 - elapsed / decay_span