| import math |
| import torch.nn as nn |
| from torch.nn import functional as F |
|
|
|
|
def is_square_of_two(num):
    """
    Return True if ``num`` is a positive integer power of two (1, 2, 4, 8, ...).

    Despite the name, this does not test for perfect squares. It tests for
    powers of two.

    Args:
        num: Integer to test.

    Returns:
        bool: True for 1, 2, 4, 8, ...; False for zero, negatives and any
        other integer.
    """
    if num > 0:
        # A power of two has exactly one set bit. Subtracting 1 clears that
        # bit and sets every lower bit, so the AND is zero only for powers of two.
        return (num & (num - 1)) == 0
    return False
|
|
|
|
class CnnEncoder(nn.Module):
    """
    Simple CNN encoder that encodes a 64x64 RGB image into a flat embedding.

    Four 4x4 stride-2 convolutions shrink a (batch, 3, 64, 64) input to a
    (batch, 256, 2, 2) feature map, i.e. 256 * 2 * 2 = 1024 features. A
    linear layer then projects these to ``embedding_size``.

    Args:
        embedding_size: Size of the output embedding.
        activation_function: Name of a function in ``torch.nn.functional``
            applied after every convolution (default ``"relu"``).
    """

    def __init__(self, embedding_size, activation_function="relu"):
        super().__init__()
        # Looked up by name, so any torch.nn.functional callable can be used.
        self.act_fn = getattr(F, activation_function)
        self.embedding_size = embedding_size
        # 1024 = 256 channels * 2 * 2 spatial positions after conv4 on a 64x64 input.
        self.fc = nn.Linear(1024, self.embedding_size)
        self.conv1 = nn.Conv2d(3, 32, 4, stride=2)  # 64 -> 31
        self.conv2 = nn.Conv2d(32, 64, 4, stride=2)  # 31 -> 14
        self.conv3 = nn.Conv2d(64, 128, 4, stride=2)  # 14 -> 6
        self.conv4 = nn.Conv2d(128, 256, 4, stride=2)  # 6 -> 2
        # NOTE: this plain-list attribute shadows nn.Module.modules(), so
        # calling ``encoder.modules()`` on this instance fails. It is kept
        # unchanged (and excludes ``fc``) because other code may read it.
        self.modules = [self.conv1, self.conv2, self.conv3, self.conv4]

    def forward(self, observation):
        """
        Encode a batch of images.

        Args:
            observation: Tensor of shape (batch, 3, 64, 64).

        Returns:
            Tensor of shape (batch, embedding_size).
        """
        batch_size = observation.shape[0]
        hidden = self.act_fn(self.conv1(observation))
        hidden = self.act_fn(self.conv2(hidden))
        hidden = self.act_fn(self.conv3(hidden))
        hidden = self.act_fn(self.conv4(hidden))
        # reshape rather than view: the conv output can be non-contiguous
        # (e.g. channels_last when the input was permuted from NHWC), and
        # view would raise on it. For contiguous tensors reshape is a view.
        hidden = self.fc(hidden.reshape(batch_size, 1024))
        return hidden
|
|
|
|
class CnnDecoder(nn.Module):
    """
    Simple CNN decoder that decodes an embedding into 64x64 images.

    A linear layer maps the (batch, embedding_size) input to 128 features,
    which are reshaped into a 128-channel 1x1 feature map. Four stride-2
    transposed convolutions then grow it 1 -> 5 -> 13 -> 30 -> 64 pixels per
    side, ending with 3 channels. The activation follows conv1-conv3 only.
    There is none after ``fc`` and none on the output layer.

    Args:
        embedding_size: Size of the input embedding.
        activation_function: Name of a function in ``torch.nn.functional``
            (default ``"relu"``).
    """

    def __init__(self, embedding_size, activation_function="relu"):
        super().__init__()
        self.act_fn = getattr(F, activation_function)
        self.embedding_size = embedding_size
        # Layers are created in this order so parameter registration and
        # random-initialisation order stay stable.
        self.fc = nn.Linear(embedding_size, 128)
        self.conv1 = nn.ConvTranspose2d(
            in_channels=128, out_channels=128, kernel_size=5, stride=2
        )  # 1 -> 5
        self.conv2 = nn.ConvTranspose2d(
            in_channels=128, out_channels=64, kernel_size=5, stride=2
        )  # 5 -> 13
        self.conv3 = nn.ConvTranspose2d(
            in_channels=64, out_channels=32, kernel_size=6, stride=2
        )  # 13 -> 30
        self.conv4 = nn.ConvTranspose2d(
            in_channels=32, out_channels=3, kernel_size=6, stride=2
        )  # 30 -> 64
        # NOTE: this plain list shadows nn.Module.modules() on this instance
        # and excludes ``fc``; kept as-is because other code may read it.
        self.modules = [self.conv1, self.conv2, self.conv3, self.conv4]

    def forward(self, embedding):
        """
        Decode a batch of embeddings.

        Args:
            embedding: Tensor of shape (batch, embedding_size).

        Returns:
            Tensor of shape (batch, 3, 64, 64) with no final activation applied.
        """
        n = embedding.shape[0]
        x = self.fc(embedding).view(n, 128, 1, 1)
        for deconv in (self.conv1, self.conv2, self.conv3):
            x = self.act_fn(deconv(x))
        return self.conv4(x)
|
|
|
|
class FullyConvEncoder(nn.Module):
    """
    Simple fully convolutional encoder, with 2D input and 2D output.

    Maps a ``(batch, C_in, H, H)`` tensor to a ``(batch, C_emb, h, h)`` tensor,
    where ``(C_in, H, H) = input_shape`` and ``(C_emb, h, h) = embedding_shape``.
    The activation is applied after every layer, including the last one.
    Layers, in order:

    * 3x3 stride-1 conv ``C_in -> init_channels``, spatial size unchanged;
    * ``log2(H / h)`` 3x3 stride-2 convs, each halving the spatial size and
      doubling the channel count;
    * 1x1 conv projecting the last hidden channel count to ``C_emb``.

    Args:
        input_shape: ``(channels, height, width)`` of the input. Height and
            width must be equal and a power of two.
        embedding_shape: ``(channels, height, width)`` of the output. Height
            and width must be equal, positive, and divide the input height.
        activation_function: Name of a function in ``torch.nn.functional``.
        init_channels: Channel count after the first conv; must be a power
            of two.

    Raises:
        AssertionError: If any of the constraints above is violated (checked
            with ``assert``).
    """

    def __init__(
        self,
        input_shape=(3, 64, 64),
        embedding_shape=(8, 16, 16),
        activation_function="relu",
        init_channels=16,
    ):
        super().__init__()

        assert len(input_shape) == 3, "input_shape must be a tuple of length 3"
        assert len(embedding_shape) == 3, "embedding_shape must be a tuple of length 3"
        assert input_shape[1] == input_shape[2] and is_square_of_two(
            input_shape[1]
        ), "input_shape must be square with a power-of-two side length"
        assert (
            embedding_shape[1] == embedding_shape[2]
        ), "embedding_shape must be square"
        # Checked before the modulo below so a zero side cannot raise
        # ZeroDivisionError and a negative side cannot slip through.
        assert embedding_shape[1] > 0, "embedding_shape side length must be positive"
        assert (
            input_shape[1] % embedding_shape[1] == 0
        ), "input_shape must be divisible by embedding_shape"
        assert is_square_of_two(init_channels), "init_channels must be a power of 2"

        # The input side is a power of two and the embedding side divides it,
        # so their ratio is an exact power of two. Each stride-2 conv halves
        # the side, so log2(ratio) downsampling layers are needed. A sqrt of
        # the ratio only matches this for ratios 2, 4, 16 and 32.
        num_downsamples = int(math.log2(input_shape[1] // embedding_shape[1]))
        depth = num_downsamples + 1
        channels_per_layer = [init_channels * (2**i) for i in range(depth)]
        self.act_fn = getattr(F, activation_function)

        self.downs = nn.ModuleList([])
        # Stem: changes the channel count, keeps the spatial size (k3, s1, p1).
        self.downs.append(
            nn.Conv2d(
                input_shape[0],
                channels_per_layer[0],
                kernel_size=3,
                stride=1,
                padding=1,
            )
        )

        # k3 / s2 / p1 maps an even side H to exactly H // 2.
        for in_channels, out_channels in zip(channels_per_layer, channels_per_layer[1:]):
            self.downs.append(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                )
            )

        # 1x1 projection to the requested embedding channel count.
        self.downs.append(
            nn.Conv2d(
                channels_per_layer[-1],
                embedding_shape[0],
                kernel_size=1,
                stride=1,
                padding=0,
            )
        )

    def forward(self, observation):
        """
        Encode ``observation`` into a spatial embedding.

        Args:
            observation: Tensor shaped like ``(batch, *input_shape)``.

        Returns:
            Tensor shaped like ``(batch, *embedding_shape)``. The activation
            has been applied to it.
        """
        hidden = observation
        for layer in self.downs:
            hidden = self.act_fn(layer(hidden))
        return hidden
|
|
|
|
class FullyConvDecoder(nn.Module):
    """
    Simple fully convolutional decoder, with 2D input and 2D output.

    Maps a ``(batch, C_emb, h, h)`` tensor to a ``(batch, C_out, H, H)`` tensor,
    where ``(C_emb, h, h) = embedding_shape`` and ``(C_out, H, H) = output_shape``.
    Layers, in order:

    * 1x1 transposed conv ``C_emb -> init_channels * 2**n``, where
      ``n = log2(H / h)``;
    * ``n`` 3x3 stride-2 transposed convs (``output_padding=1``), each
      doubling the spatial size and halving the channel count;
    * 3x3 stride-1 transposed conv ``init_channels -> C_out``.

    The activation follows every layer except the output layer, whose result
    is returned as-is.

    Args:
        embedding_shape: ``(channels, height, width)`` of the input
            embedding. Height and width must be equal, positive, and divide
            the output height.
        output_shape: ``(channels, height, width)`` of the output. Height and
            width must be equal and a power of two.
        activation_function: Name of a function in ``torch.nn.functional``.
        init_channels: Channel count before the output layer; must be a
            power of two.

    Raises:
        AssertionError: If any of the constraints above is violated (checked
            with ``assert``).
    """

    def __init__(
        self,
        embedding_shape=(8, 16, 16),
        output_shape=(3, 64, 64),
        activation_function="relu",
        init_channels=16,
    ):
        super().__init__()

        assert len(embedding_shape) == 3, "embedding_shape must be a tuple of length 3"
        assert len(output_shape) == 3, "output_shape must be a tuple of length 3"
        assert output_shape[1] == output_shape[2] and is_square_of_two(
            output_shape[1]
        ), "output_shape must be square with a power-of-two side length"
        assert embedding_shape[1] == embedding_shape[2], "embedding_shape must be square"
        # Checked before the modulo below so a zero side cannot raise
        # ZeroDivisionError and a negative side cannot slip through.
        assert embedding_shape[1] > 0, "embedding_shape side length must be positive"
        assert (
            output_shape[1] % embedding_shape[1] == 0
        ), "output_shape must be divisible by embedding_shape"
        assert is_square_of_two(init_channels), "init_channels must be a power of 2"

        # The output side is a power of two and the embedding side divides
        # it, so their ratio is an exact power of two. Each upsampling layer
        # doubles the side, so log2(ratio) of them are needed. A sqrt of the
        # ratio only matches this for ratios 2, 4, 16 and 32.
        num_upsamples = int(math.log2(output_shape[1] // embedding_shape[1]))
        depth = num_upsamples + 1
        channels_per_layer = [init_channels * (2**i) for i in range(depth)]
        self.act_fn = getattr(F, activation_function)

        self.ups = nn.ModuleList([])
        # 1x1 projection from the embedding channels to the widest hidden width.
        self.ups.append(
            nn.ConvTranspose2d(
                embedding_shape[0],
                channels_per_layer[-1],
                kernel_size=1,
                stride=1,
                padding=0,
            )
        )

        # k3 / s2 / p1 / output_padding=1 maps side H to exactly 2 * H.
        widest_first = channels_per_layer[::-1]
        for in_channels, out_channels in zip(widest_first, widest_first[1:]):
            self.ups.append(
                nn.ConvTranspose2d(
                    in_channels,
                    out_channels,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    output_padding=1,
                )
            )

        # Size-preserving (k3, s1, p1) projection to the output channels.
        self.output_layer = nn.ConvTranspose2d(
            channels_per_layer[0], output_shape[0], kernel_size=3, stride=1, padding=1
        )

    def forward(self, embedding):
        """
        Decode ``embedding`` into an output map.

        Args:
            embedding: Tensor shaped like ``(batch, *embedding_shape)``.

        Returns:
            Tensor shaped like ``(batch, *output_shape)``, with no activation
            applied to the output layer.
        """
        hidden = embedding
        for layer in self.ups:
            hidden = self.act_fn(layer(hidden))
        return self.output_layer(hidden)
|
|