# NOTE(review): the following header lines are residue from a web file-viewer
# scrape (author avatar / commit "init" eaf2e33 / raw / history blame / 7.7 kB)
# and are not Python code; commented out so the module is importable.
import torch
import torch.utils.data
from torch import nn
from torch.nn import functional as F
from rlkit.pythonplusplus import identity
from rlkit.torch import pytorch_util as ptu
import numpy as np
from rlkit.torch.conv_networks import CNN, DCNN
from rlkit.torch.vae.vae_base import GaussianLatentVAE
###### DEFAULT ARCHITECTURES #########

# Default ConvVAE architecture for 48x48 images: three conv layers reduce the
# image to a 3x3x64 feature map, mirrored by the transposed-conv decoder.
imsize48_default_architecture = {
    'conv_args': {
        'kernel_sizes': [5, 3, 3],
        'n_channels': [16, 32, 64],
        'strides': [3, 2, 2],
    },
    'conv_kwargs': {
        'hidden_sizes': [],
        'batch_norm_conv': False,
        'batch_norm_fc': False,
    },
    'deconv_args': {
        'hidden_sizes': [],
        'deconv_input_width': 3,
        'deconv_input_height': 3,
        'deconv_input_channels': 64,
        'deconv_output_kernel_size': 6,
        'deconv_output_strides': 3,
        'deconv_output_channels': 3,
        'kernel_sizes': [3, 3],
        'n_channels': [32, 16],
        'strides': [2, 2],
    },
    'deconv_kwargs': {
        'batch_norm_deconv': False,
        'batch_norm_fc': False,
    },
}
# Variant of the 48x48 architecture with fully-connected hidden layers
# between the conv stack and the latent (and symmetrically in the decoder).
imsize48_default_architecture_with_more_hidden_layers = {
    'conv_args': {
        'kernel_sizes': [5, 3, 3],
        'n_channels': [16, 32, 64],
        'strides': [3, 2, 2],
    },
    'conv_kwargs': {
        'hidden_sizes': [500, 300, 150],
    },
    'deconv_args': {
        'hidden_sizes': [150, 300, 500],
        'deconv_input_width': 3,
        'deconv_input_height': 3,
        'deconv_input_channels': 64,
        'deconv_output_kernel_size': 6,
        'deconv_output_strides': 3,
        'deconv_output_channels': 3,
        'kernel_sizes': [3, 3],
        'n_channels': [32, 16],
        'strides': [2, 2],
    },
    'deconv_kwargs': {},
}
# Default ConvVAE architecture for 84x84 images: three stride-3 conv layers
# down to a 2x2x32 feature map; conv batch-norm enabled for this size.
imsize84_default_architecture = {
    'conv_args': {
        'kernel_sizes': [5, 5, 5],
        'n_channels': [16, 32, 32],
        'strides': [3, 3, 3],
    },
    'conv_kwargs': {
        'hidden_sizes': [],
        'batch_norm_conv': True,
        'batch_norm_fc': False,
    },
    'deconv_args': {
        'hidden_sizes': [],
        'deconv_input_width': 2,
        'deconv_input_height': 2,
        'deconv_input_channels': 32,
        'deconv_output_kernel_size': 6,
        'deconv_output_strides': 3,
        'deconv_output_channels': 3,
        'kernel_sizes': [5, 6],
        'n_channels': [32, 16],
        'strides': [3, 3],
    },
    'deconv_kwargs': {
        'batch_norm_deconv': False,
        'batch_norm_fc': False,
    },
}
class ConvVAE(GaussianLatentVAE):
    """Convolutional VAE over flattened square images.

    A CNN encoder maps an image to the mean and log-variance of a diagonal
    Gaussian latent; a transposed-CNN decoder reconstructs the flattened
    image. Images are treated as flat vectors of length
    ``imsize * imsize * input_channels``.
    """

    def __init__(
            self,
            representation_size,
            architecture,

            encoder_class=CNN,
            decoder_class=DCNN,
            decoder_output_activation=identity,
            decoder_distribution='bernoulli',

            input_channels=1,
            imsize=48,
            init_w=1e-3,
            min_variance=1e-3,
            hidden_init=ptu.fanin_init,
    ):
        """
        :param representation_size: dimensionality of the latent code
        :param architecture: dict with keys ``conv_args``, ``conv_kwargs``,
            ``deconv_args``, ``deconv_kwargs``; see the module-level default
            architectures for the expected contents:
            conv_args must specify:
                kernel_sizes
                n_channels
                strides
            conv_kwargs may specify:
                hidden_sizes
                batch_norm flags
            deconv_args must specify:
                hidden_sizes
                deconv_input_width
                deconv_input_height
                deconv_input_channels
                deconv_output_kernel_size
                deconv_output_strides
                deconv_output_channels
                kernel_sizes
                n_channels
                strides
            deconv_kwargs may specify:
                batch_norm flags
        :param encoder_class: class used to build the encoder network
        :param decoder_class: class used to build the decoder network
        :param decoder_output_activation: activation applied to decoder output
        :param decoder_distribution: 'bernoulli' or
            'gaussian_identity_variance'
        :param input_channels: number of image channels
        :param imsize: image side length (images are square)
        :param init_w: half-width of the uniform init for the latent heads
        :param min_variance: floor on the latent variance; None disables it
        :param hidden_init: weight initializer passed to encoder/decoder
        """
        super().__init__(representation_size)
        if min_variance is None:
            self.log_min_variance = None
        else:
            self.log_min_variance = float(np.log(min_variance))
        self.input_channels = input_channels
        self.imsize = imsize
        # Length of a flattened image vector.
        self.imlength = self.imsize * self.imsize * self.input_channels

        conv_args, conv_kwargs, deconv_args, deconv_kwargs = \
            architecture['conv_args'], architecture['conv_kwargs'], \
            architecture['deconv_args'], architecture['deconv_kwargs']
        # The encoder's FC head must output exactly the flattened size of the
        # decoder's deconv input volume.
        conv_output_size = deconv_args['deconv_input_width'] * \
                           deconv_args['deconv_input_height'] * \
                           deconv_args['deconv_input_channels']

        self.encoder = encoder_class(
            **conv_args,
            paddings=np.zeros(len(conv_args['kernel_sizes']), dtype=np.int64),
            input_height=self.imsize,
            input_width=self.imsize,
            input_channels=self.input_channels,
            output_size=conv_output_size,
            init_w=init_w,
            hidden_init=hidden_init,
            **conv_kwargs)

        # Latent heads: fc1 -> mean, fc2 -> log-variance.
        self.fc1 = nn.Linear(self.encoder.output_size, representation_size)
        self.fc2 = nn.Linear(self.encoder.output_size, representation_size)

        self.fc1.weight.data.uniform_(-init_w, init_w)
        self.fc1.bias.data.uniform_(-init_w, init_w)

        self.fc2.weight.data.uniform_(-init_w, init_w)
        self.fc2.bias.data.uniform_(-init_w, init_w)

        self.decoder = decoder_class(
            **deconv_args,
            fc_input_size=representation_size,
            init_w=init_w,
            output_activation=decoder_output_activation,
            paddings=np.zeros(len(deconv_args['kernel_sizes']), dtype=np.int64),
            hidden_init=hidden_init,
            **deconv_kwargs)

        self.epoch = 0
        self.decoder_distribution = decoder_distribution

    def encode(self, input):
        """Return (mu, logvar) of the Gaussian latent for a batch of images."""
        h = self.encoder(input)
        mu = self.fc1(h)
        if self.log_min_variance is None:
            logvar = self.fc2(h)
        else:
            # Floor the variance at min_variance: logvar >= log(min_variance).
            logvar = self.log_min_variance + torch.abs(self.fc2(h))
        return (mu, logvar)

    def decode(self, latents):
        """Decode latents to (reconstruction, obs_distribution_params).

        The reconstruction is flattened to (batch, imlength). For the
        Gaussian case the mean is clamped to [0, 1] and unit variance is
        assumed.
        """
        decoded = self.decoder(latents).view(-1, self.imlength)
        if self.decoder_distribution == 'bernoulli':
            return decoded, [decoded]
        elif self.decoder_distribution == 'gaussian_identity_variance':
            # Hoisted: original computed torch.clamp(decoded, 0, 1) twice.
            clamped = torch.clamp(decoded, 0, 1)
            return clamped, [clamped, torch.ones_like(decoded)]
        else:
            raise NotImplementedError('Distribution {} not supported'.format(
                self.decoder_distribution))

    def logprob(self, inputs, obs_distribution_params):
        """Return the mean per-image reconstruction log-likelihood.

        The loss is averaged over every element and rescaled by imlength so
        the result is a per-image (not per-pixel) log-probability.
        """
        if self.decoder_distribution == 'bernoulli':
            inputs = inputs.narrow(start=0, length=self.imlength,
                                   dim=1).contiguous().view(-1, self.imlength)
            # Fix: reduction='elementwise_mean' was deprecated in PyTorch
            # 0.4.1 and removed in later releases; 'mean' is equivalent.
            log_prob = - F.binary_cross_entropy(
                obs_distribution_params[0],
                inputs,
                reduction='mean'
            ) * self.imlength
            return log_prob
        elif self.decoder_distribution == 'gaussian_identity_variance':
            inputs = inputs.narrow(start=0, length=self.imlength,
                                   dim=1).contiguous().view(-1, self.imlength)
            log_prob = -1 * F.mse_loss(inputs, obs_distribution_params[0],
                                       reduction='mean')
            return log_prob
        else:
            raise NotImplementedError('Distribution {} not supported'.format(
                self.decoder_distribution))