|
import math |
|
import torch.nn as nn |
|
from torch.nn import functional as F |
|
|
|
|
|
def is_square_of_two(num):
    """
    Return True when ``num`` is a positive power of two (1, 2, 4, 8, ...).

    Despite the name, this checks for powers of two, not squares. Zero and
    negative values yield False.
    """
    non_positive = num <= 0
    # A power of two has a single set bit, so clearing the lowest set bit
    # with ``num & (num - 1)`` leaves zero.
    return False if non_positive else (num & (num - 1)) == 0
|
|
|
class CnnEncoder(nn.Module):
    """
    Convolutional encoder mapping an image batch to flat embedding vectors.

    Four unpadded 4x4, stride-2 convolutions (3 -> 32 -> 64 -> 128 -> 256
    channels) are followed by a linear layer. The flatten step is hard-wired
    to 1024 features, i.e. 256 channels x 2 x 2, which is what a 64x64 input
    produces after the four convolutions.

    Args:
        embedding_size: Width of the output embedding.
        activation_function: Name of a function in ``torch.nn.functional``
            applied after every convolution.
    """

    def __init__(self, embedding_size, activation_function='relu'):
        super().__init__()
        self.act_fn = getattr(F, activation_function)
        self.embedding_size = embedding_size
        # The projection is created before the convolutions; keeping that
        # order keeps the parameter registration order unchanged.
        self.fc = nn.Linear(in_features=1024, out_features=self.embedding_size)
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=4, stride=2)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2)
        self.conv4 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2)
        # NOTE(review): this plain list is stored under the name ``modules``,
        # which is also the name of the ``nn.Module.modules()`` method, so
        # that method cannot be called on instances. Left untouched because
        # other code may read this attribute as a list - confirm before
        # renaming.
        self.modules = [self.conv1, self.conv2, self.conv3, self.conv4]

    def forward(self, observation):
        """Encode ``observation`` into a ``(batch, embedding_size)`` tensor."""
        n_samples = observation.shape[0]
        features = observation
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            features = self.act_fn(conv(features))
        # Collapse channels and spatial dims into the 1024-wide vector that
        # the linear layer expects.
        flat = features.view(n_samples, 1024)
        return self.fc(flat)
|
|
|
|
|
class CnnDecoder(nn.Module):
    """
    Transposed-convolution decoder mapping embedding vectors to images.

    A linear layer expands the embedding to 128 features, which are viewed as
    a 128-channel 1x1 map and upsampled by four unpadded stride-2 transposed
    convolutions (kernel sizes 5, 5, 6, 6), growing the spatial side
    1 -> 5 -> 13 -> 30 -> 64 and ending with 3 output channels.

    Args:
        embedding_size: Width of the input embedding.
        activation_function: Name of a function in ``torch.nn.functional``
            applied after every transposed convolution except the last.
    """

    def __init__(self, embedding_size, activation_function='relu'):
        super().__init__()
        self.act_fn = getattr(F, activation_function)
        self.embedding_size = embedding_size
        self.fc = nn.Linear(in_features=embedding_size, out_features=128)
        self.conv1 = nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=5, stride=2)
        self.conv2 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=5, stride=2)
        self.conv3 = nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=6, stride=2)
        self.conv4 = nn.ConvTranspose2d(in_channels=32, out_channels=3, kernel_size=6, stride=2)
        # NOTE(review): this plain list is stored under the name ``modules``,
        # which is also the name of the ``nn.Module.modules()`` method, so
        # that method cannot be called on instances. Left untouched because
        # other code may read this attribute as a list - confirm before
        # renaming.
        self.modules = [self.conv1, self.conv2, self.conv3, self.conv4]

    def forward(self, embedding):
        """Decode ``embedding`` of shape ``(batch, embedding_size)`` to images."""
        n_samples = embedding.shape[0]
        # Treat the 128 projected features as channels of a 1x1 feature map.
        hidden = self.fc(embedding).view(n_samples, 128, 1, 1)
        for deconv in (self.conv1, self.conv2, self.conv3):
            hidden = self.act_fn(deconv(hidden))
        # The final layer is left without an activation.
        return self.conv4(hidden)
|
|
|
|
|
class FullyConvEncoder(nn.Module):
    """
    Simple fully convolutional encoder, with 2D input and 2D output.

    The network is a stride-1 3x3 stem, followed by one stride-2 3x3
    convolution per halving of the spatial side (doubling the channel count
    each time), followed by a 1x1 projection to ``embedding_shape[0]``
    channels. The activation is applied after every layer, including the
    final 1x1 projection.

    Args:
        input_shape: ``(channels, height, width)`` of the input. Height and
            width must be equal and a power of two.
        embedding_shape: ``(channels, height, width)`` of the output. Height
            and width must be equal and divide the input side.
        activation_function: Name of a function in ``torch.nn.functional``.
        init_channels: Channel count of the stem; must be a power of two.

    Raises:
        AssertionError: If any of the shape constraints above is violated.
    """

    def __init__(self,
                 input_shape=(3, 64, 64),
                 embedding_shape=(8, 16, 16),
                 activation_function='relu',
                 init_channels=16,
                 ):
        super().__init__()

        assert len(input_shape) == 3, "input_shape must be a tuple of length 3"
        assert len(embedding_shape) == 3, "embedding_shape must be a tuple of length 3"
        assert input_shape[1] == input_shape[2] and is_square_of_two(input_shape[1]), "input_shape must be square"
        assert embedding_shape[1] == embedding_shape[2], "embedding_shape must be square"
        assert input_shape[1] % embedding_shape[1] == 0, "input_shape must be divisible by embedding_shape"
        assert is_square_of_two(init_channels), "init_channels must be a square of 2"

        # Every stride-2 layer halves the spatial side, so the number of such
        # layers is log2 of the downsampling ratio. The previous sqrt-based
        # formula only agreed with this for some ratios (e.g. 2, 4, 16) and
        # produced the wrong output size for others (e.g. 1, 8, 64).
        num_halvings = int(math.log2(input_shape[1] // embedding_shape[1]))
        depth = num_halvings + 1
        channels_per_layer = [init_channels * (2 ** i) for i in range(depth)]
        self.act_fn = getattr(F, activation_function)

        self.downs = nn.ModuleList([])
        # Stem: keeps the resolution, lifts the input to init_channels.
        self.downs.append(nn.Conv2d(input_shape[0], channels_per_layer[0], kernel_size=3, stride=1, padding=1))

        # Downsampling stages: 3x3, stride 2, padding 1 halves an even side.
        for i in range(1, depth):
            self.downs.append(nn.Conv2d(channels_per_layer[i-1], channels_per_layer[i],
                                        kernel_size=3, stride=2, padding=1))

        # 1x1 projection to the requested number of embedding channels.
        self.downs.append(nn.Conv2d(channels_per_layer[-1], embedding_shape[0], kernel_size=1, stride=1, padding=0))

    def forward(self, observation):
        """Run ``observation`` through every layer, activating after each."""
        hidden = observation
        for layer in self.downs:
            hidden = self.act_fn(layer(hidden))
        return hidden
|
|
|
|
|
class FullyConvDecoder(nn.Module):
    """
    Simple fully convolutional decoder, with 2D input and 2D output.

    The network is a 1x1 transposed convolution that lifts the embedding to
    the widest channel count, followed by one stride-2 3x3 transposed
    convolution per doubling of the spatial side (halving the channel count
    each time), followed by a stride-1 3x3 output layer. The activation is
    applied after every layer except the output layer.

    Args:
        embedding_shape: ``(channels, height, width)`` of the input. Height
            and width must be equal and divide the output side.
        output_shape: ``(channels, height, width)`` of the output. Height and
            width must be equal and a power of two.
        activation_function: Name of a function in ``torch.nn.functional``.
        init_channels: Channel count just before the output layer; must be a
            power of two.

    Raises:
        AssertionError: If any of the shape constraints above is violated.
    """

    def __init__(self,
                 embedding_shape=(8, 16, 16),
                 output_shape=(3, 64, 64),
                 activation_function='relu',
                 init_channels=16,
                 ):
        super().__init__()

        assert len(embedding_shape) == 3, "embedding_shape must be a tuple of length 3"
        assert len(output_shape) == 3, "output_shape must be a tuple of length 3"
        assert output_shape[1] == output_shape[2] and is_square_of_two(output_shape[1]), "output_shape must be square"
        # Messages below name embedding_shape; this class has no input_shape.
        assert embedding_shape[1] == embedding_shape[2], "embedding_shape must be square"
        assert output_shape[1] % embedding_shape[1] == 0, "output_shape must be divisible by embedding_shape"
        assert is_square_of_two(init_channels), "init_channels must be a square of 2"

        # Every stride-2 layer doubles the spatial side, so the number of
        # such layers is log2 of the upsampling ratio. The previous
        # sqrt-based formula only agreed with this for some ratios (e.g. 2,
        # 4, 16) and produced the wrong output size for others (e.g. 1, 8, 64).
        num_doublings = int(math.log2(output_shape[1] // embedding_shape[1]))
        depth = num_doublings + 1
        channels_per_layer = [init_channels * (2 ** i) for i in range(depth)]
        self.act_fn = getattr(F, activation_function)

        self.ups = nn.ModuleList([])
        # 1x1 lift from embedding channels to the widest channel count.
        self.ups.append(nn.ConvTranspose2d(embedding_shape[0], channels_per_layer[-1],
                                           kernel_size=1, stride=1, padding=0))

        # Upsampling stages, walking channels_per_layer from widest to
        # narrowest; kernel 3, stride 2, padding 1, output_padding 1 doubles
        # the side exactly.
        for i in range(1, depth):
            self.ups.append(nn.ConvTranspose2d(channels_per_layer[-i], channels_per_layer[-i-1],
                                               kernel_size=3, stride=2, padding=1, output_padding=1))

        # Resolution-preserving layer producing the output channels.
        self.output_layer = nn.ConvTranspose2d(channels_per_layer[0], output_shape[0],
                                               kernel_size=3, stride=1, padding=1)

    def forward(self, embedding):
        """Upsample ``embedding``; the output layer is not activated."""
        hidden = embedding
        for layer in self.ups:
            hidden = self.act_fn(layer(hidden))

        return self.output_layer(hidden)
|
|