xizaoqu
init
27ca8b3
import math
import torch.nn as nn
from torch.nn import functional as F
def is_square_of_two(num):
    """Return True if ``num`` is a positive power of two (1, 2, 4, 8, ...).

    Despite the name, this tests for a *power* of two, not a perfect square.
    Non-positive values return False. ``num`` must support ``&`` (an integer
    type); floats raise ``TypeError``.
    """
    if num <= 0:
        return False
    # A power of two has exactly one set bit, so clearing the lowest set bit
    # via ``num & (num - 1)`` leaves zero.
    lowest_bit_cleared = num & (num - 1)
    return lowest_bit_cleared == 0
class CnnEncoder(nn.Module):
    """
    Simple cnn encoder that encodes a 64x64 image to embeddings.

    Four kernel-4, stride-2 convolutions (no padding) map a 3-channel 64x64
    image to a 256 x 2 x 2 feature map (1024 values), which a linear layer
    projects to ``embedding_size``.

    Args:
        embedding_size: Width of the output embedding.
        activation_function: Name of a function in ``torch.nn.functional``
            (e.g. ``'relu'``) applied after every convolution.
    """

    def __init__(self, embedding_size, activation_function='relu'):
        super().__init__()
        self.act_fn = getattr(F, activation_function)
        self.embedding_size = embedding_size
        self.fc = nn.Linear(1024, self.embedding_size)
        self.conv1 = nn.Conv2d(3, 32, 4, stride=2)
        self.conv2 = nn.Conv2d(32, 64, 4, stride=2)
        self.conv3 = nn.Conv2d(64, 128, 4, stride=2)
        self.conv4 = nn.Conv2d(128, 256, 4, stride=2)
        # Ordered list of the conv layers. It must NOT be named ``modules``:
        # a plain-list attribute with that name shadows ``nn.Module.modules()``
        # and makes ``encoder.modules()`` raise TypeError.
        self.conv_layers = [self.conv1, self.conv2, self.conv3, self.conv4]

    def forward(self, observation):
        """Encode a batch of images.

        Args:
            observation: Tensor of shape (batch, 3, 64, 64). The hard-coded
                ``view(batch, 1024)`` fails for other spatial sizes.

        Returns:
            Tensor of shape (batch, embedding_size).
        """
        batch_size = observation.shape[0]
        hidden = self.act_fn(self.conv1(observation))   # 64 -> 31
        hidden = self.act_fn(self.conv2(hidden))        # 31 -> 14
        hidden = self.act_fn(self.conv3(hidden))        # 14 -> 6
        hidden = self.act_fn(self.conv4(hidden))        # 6 -> 2
        # 256 channels * 2 * 2 = 1024 features per sample.
        hidden = self.fc(hidden.view(batch_size, 1024))
        return hidden
class CnnDecoder(nn.Module):
    """
    Simple Cnn decoder that decodes an embedding to 64x64 images.

    A linear layer maps the embedding to 128 values, reshaped to a
    128 x 1 x 1 map. Four stride-2 transposed convolutions then grow it
    1 -> 5 -> 13 -> 30 -> 64, producing a 3-channel 64x64 output.

    Args:
        embedding_size: Width of the input embedding.
        activation_function: Name of a function in ``torch.nn.functional``
            (e.g. ``'relu'``) applied after every transposed conv except the
            last.
    """

    def __init__(self, embedding_size, activation_function='relu'):
        super().__init__()
        self.act_fn = getattr(F, activation_function)
        self.embedding_size = embedding_size
        self.fc = nn.Linear(embedding_size, 128)
        self.conv1 = nn.ConvTranspose2d(128, 128, 5, stride=2)
        self.conv2 = nn.ConvTranspose2d(128, 64, 5, stride=2)
        self.conv3 = nn.ConvTranspose2d(64, 32, 6, stride=2)
        self.conv4 = nn.ConvTranspose2d(32, 3, 6, stride=2)
        # Ordered list of the transposed-conv layers. It must NOT be named
        # ``modules``: a plain-list attribute with that name shadows
        # ``nn.Module.modules()`` and makes ``decoder.modules()`` raise.
        self.conv_layers = [self.conv1, self.conv2, self.conv3, self.conv4]

    def forward(self, embedding):
        """Decode a batch of embeddings.

        Args:
            embedding: Tensor of shape (batch, embedding_size).

        Returns:
            Tensor of shape (batch, 3, 64, 64). No activation is applied to
            the final layer, so values are unbounded.
        """
        batch_size = embedding.shape[0]
        # No activation between fc and the first transposed conv.
        hidden = self.fc(embedding)
        hidden = hidden.view(batch_size, 128, 1, 1)
        hidden = self.act_fn(self.conv1(hidden))   # 1 -> 5
        hidden = self.act_fn(self.conv2(hidden))   # 5 -> 13
        hidden = self.act_fn(self.conv3(hidden))   # 13 -> 30
        observation = self.conv4(hidden)           # 30 -> 64
        return observation
class FullyConvEncoder(nn.Module):
    """
    Simple fully convolutional encoder, with 2D input and 2D output.

    Layout: a stride-1 3x3 conv to ``init_channels``, then log2(H_in / H_emb)
    stride-2 3x3 convs that each halve H and W and double the channel count,
    then a 1x1 bottleneck conv to ``embedding_shape[0]`` channels. The
    activation is applied after every layer, including the bottleneck.

    Args:
        input_shape: (C, H, W) of the input. H == W, and H must be a power of 2.
        embedding_shape: (C, H, W) of the output. H == W, and H must divide
            ``input_shape[1]``.
        activation_function: Name of a function in ``torch.nn.functional``.
        init_channels: Channel width of the first conv. Must be a power of 2.
    """

    def __init__(self,
                 input_shape=(3, 64, 64),
                 embedding_shape=(8, 16, 16),
                 activation_function='relu',
                 init_channels=16,
                 ):
        super().__init__()
        assert len(input_shape) == 3, "input_shape must be a tuple of length 3"
        assert len(embedding_shape) == 3, "embedding_shape must be a tuple of length 3"
        assert input_shape[1] == input_shape[2] and is_square_of_two(input_shape[1]), \
            "input_shape must be square with a power-of-2 side length"
        assert embedding_shape[1] == embedding_shape[2], "embedding_shape must be square"
        assert input_shape[1] % embedding_shape[1] == 0, "input_shape must be divisible by embedding_shape"
        assert is_square_of_two(init_channels), "init_channels must be a power of 2"
        # The side ratio divides a power of two, so it is itself a power of two.
        # Each stride-2 conv halves the side, so log2(ratio) of them are needed
        # (not sqrt(ratio), which only matches log2 for ratios 2, 4, 16, 32).
        # For a power of two, bit_length() - 1 is an exact integer log2.
        ratio = int(input_shape[1]) // int(embedding_shape[1])
        num_downsamples = ratio.bit_length() - 1
        depth = num_downsamples + 1
        channels_per_layer = [init_channels * (2 ** i) for i in range(depth)]
        self.act_fn = getattr(F, activation_function)
        self.downs = nn.ModuleList([])
        # Stride-1 stem: changes channels, keeps spatial size.
        self.downs.append(nn.Conv2d(input_shape[0], channels_per_layer[0], kernel_size=3, stride=1, padding=1))
        # With kernel 3, padding 1, stride 2, an even side H maps to H / 2.
        for i in range(1, depth):
            self.downs.append(nn.Conv2d(channels_per_layer[i - 1], channels_per_layer[i],
                                        kernel_size=3, stride=2, padding=1))
        # Bottleneck layer
        self.downs.append(nn.Conv2d(channels_per_layer[-1], embedding_shape[0], kernel_size=1, stride=1, padding=0))

    def forward(self, observation):
        """Encode a batch shaped (batch, *input_shape) to (batch, *embedding_shape).

        The activation also follows the bottleneck; with ``'relu'`` the
        embedding is therefore non-negative.
        """
        hidden = observation
        for layer in self.downs:
            hidden = self.act_fn(layer(hidden))
        return hidden
class FullyConvDecoder(nn.Module):
    """
    Simple fully convolutional decoder, with 2D input and 2D output.

    Layout: a 1x1 transposed conv from ``embedding_shape[0]`` channels to the
    widest feature width, then log2(H_out / H_emb) stride-2 3x3 transposed
    convs that each double H and W and halve the channel count. A final
    stride-1 3x3 transposed conv maps to ``output_shape[0]`` channels. The
    activation follows every layer except the output layer.

    Args:
        embedding_shape: (C, H, W) of the input. H == W, and H must divide
            ``output_shape[1]``.
        output_shape: (C, H, W) of the output. H == W, and H must be a power of 2.
        activation_function: Name of a function in ``torch.nn.functional``.
        init_channels: Channel width before the output layer. Must be a power of 2.
    """

    def __init__(self,
                 embedding_shape=(8, 16, 16),
                 output_shape=(3, 64, 64),
                 activation_function='relu',
                 init_channels=16,
                 ):
        super().__init__()
        assert len(embedding_shape) == 3, "embedding_shape must be a tuple of length 3"
        assert len(output_shape) == 3, "output_shape must be a tuple of length 3"
        assert output_shape[1] == output_shape[2] and is_square_of_two(output_shape[1]), \
            "output_shape must be square with a power-of-2 side length"
        assert embedding_shape[1] == embedding_shape[2], "embedding_shape must be square"
        assert output_shape[1] % embedding_shape[1] == 0, "output_shape must be divisible by embedding_shape"
        assert is_square_of_two(init_channels), "init_channels must be a power of 2"
        # The side ratio divides a power of two, so it is itself a power of two.
        # Each stride-2 layer doubles the side, so log2(ratio) of them are needed
        # (not sqrt(ratio), which only matches log2 for ratios 2, 4, 16, 32).
        # For a power of two, bit_length() - 1 is an exact integer log2.
        ratio = int(output_shape[1]) // int(embedding_shape[1])
        num_upsamples = ratio.bit_length() - 1
        depth = num_upsamples + 1
        channels_per_layer = [init_channels * (2 ** i) for i in range(depth)]
        self.act_fn = getattr(F, activation_function)
        self.ups = nn.ModuleList([])
        # 1x1 projection from embedding channels to the widest feature width.
        self.ups.append(nn.ConvTranspose2d(embedding_shape[0], channels_per_layer[-1],
                                           kernel_size=1, stride=1, padding=0))
        # Output side: (H - 1) * 2 - 2 * 1 + (3 - 1) + 1 + 1 = 2 * H.
        for i in range(1, depth):
            self.ups.append(nn.ConvTranspose2d(channels_per_layer[-i], channels_per_layer[-i - 1],
                                               kernel_size=3, stride=2, padding=1, output_padding=1))
        # Stride-1, padding-1, kernel-3 keeps the spatial size unchanged.
        self.output_layer = nn.ConvTranspose2d(channels_per_layer[0], output_shape[0],
                                               kernel_size=3, stride=1, padding=1)

    def forward(self, embedding):
        """Decode a batch shaped (batch, *embedding_shape) to (batch, *output_shape).

        The output layer has no activation, so values are unbounded.
        """
        hidden = embedding
        for layer in self.ups:
            hidden = self.act_fn(layer(hidden))
        return self.output_layer(hidden)