# models/research/maskgan/model_utils/variable_mapping.py
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Dependency imports
import tensorflow as tf
FLAGS = tf.app.flags.FLAGS
def rnn_nas(hparams, model):
  """Returns the NAS checkpoint Variable name to MaskGAN Variable mapping.

  This is a highly restrictive mapping valid only for a two-layer 'rnn_nas'
  generator or discriminator.

  Args:
    hparams: Hyperparameters for the MaskGAN.
    model: Model type, one of ['gen', 'dis'].

  Returns:
    variable_mapping: Dictionary with Key: ckpt_name, Value: model_var.
  """
  assert model == 'gen' or model == 'dis'

  # This logic is only valid for rnn_nas.
  if model == 'gen':
    assert FLAGS.generator_model == 'rnn_nas'
    assert hparams.gen_num_layers == 2
  if model == 'dis':
    assert FLAGS.discriminator_model == 'rnn_nas'
    assert hparams.dis_num_layers == 2

  def _var(name):
    # Unique trainable variable with the given op name (IndexError if absent).
    return [v for v in tf.trainable_variables() if v.op.name == name][0]

  # Output variables only for the Generator.  Discriminator output biases
  # will begin randomly initialized.
  if model == 'gen':
    softmax_b = _var('gen/rnn/softmax_b')

  # Common elements to Generator and Discriminator.
  cell_prefix = str(model) + '/rnn/GenericMultiRNNCell/Cell'
  embedding = _var(str(model) + '/rnn/embedding')
  lstm_w_0 = _var(cell_prefix + '0/Alien/rnn_builder/big_h_mat')
  lstm_b_0 = _var(cell_prefix + '0/Alien/rnn_builder/big_inputs_mat')
  lstm_w_1 = _var(cell_prefix + '1/Alien/rnn_builder/big_h_mat')
  lstm_b_1 = _var(cell_prefix + '1/Alien/rnn_builder/big_inputs_mat')

  # Dictionary mapping shared by both models; the Generator additionally
  # restores its output bias.
  variable_mapping = {
      'Model/embeddings/input_embedding':
          embedding,
      'Model/RNN/GenericMultiRNNCell/Cell0/Alien/rnn_builder/big_h_mat':
          lstm_w_0,
      'Model/RNN/GenericMultiRNNCell/Cell0/Alien/rnn_builder/big_inputs_mat':
          lstm_b_0,
      'Model/RNN/GenericMultiRNNCell/Cell1/Alien/rnn_builder/big_h_mat':
          lstm_w_1,
      'Model/RNN/GenericMultiRNNCell/Cell1/Alien/rnn_builder/big_inputs_mat':
          lstm_b_1
  }
  if model == 'gen':
    variable_mapping['Model/softmax_b'] = softmax_b
  return variable_mapping
def cnn():
  """Variable mapping for the CNN embedding.

  Returns:
    variable_mapping: Dictionary with Key: ckpt_name, Value: model_var.
  """
  # Only valid when the discriminator is the CNN model.
  assert FLAGS.discriminator_model == 'cnn'

  # Locate the discriminator embedding among the trainable variables.
  matches = [
      v for v in tf.trainable_variables() if v.op.name == 'dis/embedding'
  ]
  embedding = matches[0]

  # Single-entry checkpoint-name to variable mapping.
  return {'Model/embedding': embedding}
def rnn_zaremba(hparams, model):
  """Returns the PTB Variable name to MaskGAN Variable dictionary mapping.

  This is a highly restrictive function just for testing. This will need to
  be generalized.

  Args:
    hparams: Hyperparameters for the MaskGAN.
    model: Model type, one of ['gen', 'dis'].

  Returns:
    variable_mapping: Dictionary with Key: ckpt_name, Value: model_var.
  """
  assert model == 'gen' or model == 'dis'

  # This logic is only valid for rnn_zaremba.
  if model == 'gen':
    assert FLAGS.generator_model == 'rnn_zaremba'
    assert hparams.gen_num_layers == 2
  if model == 'dis':
    assert (FLAGS.discriminator_model == 'rnn_zaremba' or
            FLAGS.discriminator_model == 'rnn_vd')
    assert hparams.dis_num_layers == 2

  def _var(name):
    # Unique trainable variable with the given op name (IndexError if absent).
    return [v for v in tf.trainable_variables() if v.op.name == name][0]

  # Output variables only for the Generator. Discriminator output weights
  # and biases will begin randomly initialized.
  if model == 'gen':
    softmax_w = _var('gen/rnn/softmax_w')
    softmax_b = _var('gen/rnn/softmax_b')

  # Common elements to Generator and Discriminator.  The embedding is only
  # needed when the discriminator does not share the generator's embedding.
  if not FLAGS.dis_share_embedding or model != 'dis':
    embedding = _var(str(model) + '/rnn/embedding')
  cell_prefix = str(model) + '/rnn/multi_rnn_cell/cell_'
  lstm_w_0 = _var(cell_prefix + '0/basic_lstm_cell/kernel')
  lstm_b_0 = _var(cell_prefix + '0/basic_lstm_cell/bias')
  lstm_w_1 = _var(cell_prefix + '1/basic_lstm_cell/kernel')
  lstm_b_1 = _var(cell_prefix + '1/basic_lstm_cell/bias')

  # Dictionary mapping: LSTM weights are always restored; the embedding only
  # when present above, and the softmax outputs only for the Generator.
  variable_mapping = {
      'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/kernel': lstm_w_0,
      'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/bias': lstm_b_0,
      'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/kernel': lstm_w_1,
      'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/bias': lstm_b_1
  }
  if model == 'gen':
    variable_mapping['Model/embedding'] = embedding
    variable_mapping['Model/softmax_w'] = softmax_w
    variable_mapping['Model/softmax_b'] = softmax_b
  elif not FLAGS.dis_share_embedding:
    variable_mapping['Model/embedding'] = embedding
  return variable_mapping
def gen_encoder_seq2seq_nas(hparams):
  """Returns the NAS Variable name to MaskGAN Variable dictionary mapping.

  This is a highly restrictive function just for testing.  This is for the
  *unidirectional* seq2seq_nas encoder.

  Args:
    hparams: Hyperparameters for the MaskGAN.

  Returns:
    variable_mapping: Dictionary with Key: ckpt_name, Value: model_var.
  """
  assert FLAGS.generator_model == 'seq2seq_nas'
  assert hparams.gen_num_layers == 2

  def _var(name):
    # Unique trainable variable with the given op name (IndexError if absent).
    return [v for v in tf.trainable_variables() if v.op.name == name][0]

  ## Encoder forward variables.
  # Embedding is only restored when it is not shared with the decoder.
  if not FLAGS.seq2seq_share_embedding:
    encoder_embedding = _var('gen/encoder/rnn/embedding')
  cell_prefix = 'gen/encoder/rnn/GenericMultiRNNCell/Cell'
  encoder_lstm_w_0 = _var(cell_prefix + '0/Alien/rnn_builder/big_h_mat')
  encoder_lstm_b_0 = _var(cell_prefix + '0/Alien/rnn_builder/big_inputs_mat')
  encoder_lstm_w_1 = _var(cell_prefix + '1/Alien/rnn_builder/big_h_mat')
  encoder_lstm_b_1 = _var(cell_prefix + '1/Alien/rnn_builder/big_inputs_mat')

  variable_mapping = {
      'Model/RNN/GenericMultiRNNCell/Cell0/Alien/rnn_builder/big_h_mat':
          encoder_lstm_w_0,
      'Model/RNN/GenericMultiRNNCell/Cell0/Alien/rnn_builder/big_inputs_mat':
          encoder_lstm_b_0,
      'Model/RNN/GenericMultiRNNCell/Cell1/Alien/rnn_builder/big_h_mat':
          encoder_lstm_w_1,
      'Model/RNN/GenericMultiRNNCell/Cell1/Alien/rnn_builder/big_inputs_mat':
          encoder_lstm_b_1
  }
  if not FLAGS.seq2seq_share_embedding:
    variable_mapping['Model/embeddings/input_embedding'] = encoder_embedding
  return variable_mapping
def gen_decoder_seq2seq_nas(hparams):
  """Returns the NAS Variable name to MaskGAN Variable dictionary mapping.

  This is for the seq2seq_nas decoder (embedding, two NAS cells and the
  output softmax bias).

  Args:
    hparams: Hyperparameters for the MaskGAN.

  Returns:
    variable_mapping: Dictionary with Key: ckpt_name, Value: model_var.
  """
  assert FLAGS.generator_model == 'seq2seq_nas'
  assert hparams.gen_num_layers == 2

  def _var(name):
    # Unique trainable variable with the given op name (IndexError if absent).
    return [v for v in tf.trainable_variables() if v.op.name == name][0]

  cell_prefix = 'gen/decoder/rnn/GenericMultiRNNCell/Cell'
  decoder_embedding = _var('gen/decoder/rnn/embedding')
  decoder_lstm_w_0 = _var(cell_prefix + '0/Alien/rnn_builder/big_h_mat')
  decoder_lstm_b_0 = _var(cell_prefix + '0/Alien/rnn_builder/big_inputs_mat')
  decoder_lstm_w_1 = _var(cell_prefix + '1/Alien/rnn_builder/big_h_mat')
  decoder_lstm_b_1 = _var(cell_prefix + '1/Alien/rnn_builder/big_inputs_mat')
  decoder_softmax_b = _var('gen/decoder/rnn/softmax_b')

  variable_mapping = {
      'Model/embeddings/input_embedding':
          decoder_embedding,
      'Model/RNN/GenericMultiRNNCell/Cell0/Alien/rnn_builder/big_h_mat':
          decoder_lstm_w_0,
      'Model/RNN/GenericMultiRNNCell/Cell0/Alien/rnn_builder/big_inputs_mat':
          decoder_lstm_b_0,
      'Model/RNN/GenericMultiRNNCell/Cell1/Alien/rnn_builder/big_h_mat':
          decoder_lstm_w_1,
      'Model/RNN/GenericMultiRNNCell/Cell1/Alien/rnn_builder/big_inputs_mat':
          decoder_lstm_b_1,
      'Model/softmax_b':
          decoder_softmax_b
  }
  return variable_mapping
def gen_encoder_seq2seq(hparams):
  """Returns the PTB Variable name to MaskGAN Variable dictionary mapping.

  This is a highly restrictive function just for testing.  This is for the
  *unidirectional* seq2seq_zaremba encoder.

  Args:
    hparams: Hyperparameters for the MaskGAN.

  Returns:
    variable_mapping: Dictionary with Key: ckpt_name, Value: model_var.
  """
  assert (FLAGS.generator_model == 'seq2seq_zaremba' or
          FLAGS.generator_model == 'seq2seq_vd')
  assert hparams.gen_num_layers == 2

  def _var(name):
    # Unique trainable variable with the given op name (IndexError if absent).
    return [v for v in tf.trainable_variables() if v.op.name == name][0]

  ## Encoder forward variables.
  # Embedding is only restored when it is not shared with the decoder.
  if not FLAGS.seq2seq_share_embedding:
    encoder_embedding = _var('gen/encoder/rnn/embedding')
  cell_prefix = 'gen/encoder/rnn/multi_rnn_cell/cell_'
  encoder_lstm_w_0 = _var(cell_prefix + '0/basic_lstm_cell/kernel')
  encoder_lstm_b_0 = _var(cell_prefix + '0/basic_lstm_cell/bias')
  encoder_lstm_w_1 = _var(cell_prefix + '1/basic_lstm_cell/kernel')
  encoder_lstm_b_1 = _var(cell_prefix + '1/basic_lstm_cell/bias')

  # PTB checkpoints use a capitalized top-level scope; other data sets do not.
  if FLAGS.data_set == 'ptb':
    model_str = 'Model'
  else:
    model_str = 'model'

  variable_mapping = {
      model_str + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/kernel':
          encoder_lstm_w_0,
      model_str + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/bias':
          encoder_lstm_b_0,
      model_str + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/kernel':
          encoder_lstm_w_1,
      model_str + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/bias':
          encoder_lstm_b_1
  }
  if not FLAGS.seq2seq_share_embedding:
    variable_mapping[model_str + '/embedding'] = encoder_embedding
  return variable_mapping
def gen_decoder_seq2seq(hparams):
  """Returns the PTB Variable name to MaskGAN Variable dictionary mapping.

  This is for the seq2seq_zaremba / seq2seq_vd decoder (embedding, two LSTM
  cells and the output softmax bias).

  Args:
    hparams: Hyperparameters for the MaskGAN.

  Returns:
    variable_mapping: Dictionary with Key: ckpt_name, Value: model_var.
  """
  assert (FLAGS.generator_model == 'seq2seq_zaremba' or
          FLAGS.generator_model == 'seq2seq_vd')
  assert hparams.gen_num_layers == 2

  def _var(name):
    # Unique trainable variable with the given op name (IndexError if absent).
    return [v for v in tf.trainable_variables() if v.op.name == name][0]

  cell_prefix = 'gen/decoder/rnn/multi_rnn_cell/cell_'
  decoder_embedding = _var('gen/decoder/rnn/embedding')
  decoder_lstm_w_0 = _var(cell_prefix + '0/basic_lstm_cell/kernel')
  decoder_lstm_b_0 = _var(cell_prefix + '0/basic_lstm_cell/bias')
  decoder_lstm_w_1 = _var(cell_prefix + '1/basic_lstm_cell/kernel')
  decoder_lstm_b_1 = _var(cell_prefix + '1/basic_lstm_cell/bias')
  decoder_softmax_b = _var('gen/decoder/rnn/softmax_b')

  # PTB checkpoints use a capitalized top-level scope; other data sets do not.
  if FLAGS.data_set == 'ptb':
    model_str = 'Model'
  else:
    model_str = 'model'

  variable_mapping = {
      model_str + '/embedding':
          decoder_embedding,
      model_str + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/kernel':
          decoder_lstm_w_0,
      model_str + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/bias':
          decoder_lstm_b_0,
      model_str + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/kernel':
          decoder_lstm_w_1,
      model_str + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/bias':
          decoder_lstm_b_1,
      model_str + '/softmax_b':
          decoder_softmax_b
  }
  return variable_mapping
def dis_fwd_bidirectional(hparams):
  """Returns the *forward* PTB Variable name to MaskGAN Variable dictionary
  mapping.  This is a highly restrictive function just for testing.  This is
  for the bidirectional_zaremba discriminator.

  Args:
    hparams: Hyperparameters for the MaskGAN.

  Returns:
    variable_mapping: Dictionary with Key: ckpt_name, Value: model_var.
  """
  assert (FLAGS.discriminator_model == 'bidirectional_zaremba' or
          FLAGS.discriminator_model == 'bidirectional_vd')
  assert hparams.dis_num_layers == 2

  def _var(name):
    # Unique trainable variable with the given op name (IndexError if absent).
    return [v for v in tf.trainable_variables() if v.op.name == name][0]

  # Forward Discriminator Elements.  Embedding is restored only when not
  # shared with the generator.
  if not FLAGS.dis_share_embedding:
    embedding = _var('dis/embedding')
  cell_prefix = 'dis/rnn/fw/multi_rnn_cell/cell_'
  fw_lstm_w_0 = _var(cell_prefix + '0/basic_lstm_cell/kernel')
  fw_lstm_b_0 = _var(cell_prefix + '0/basic_lstm_cell/bias')
  fw_lstm_w_1 = _var(cell_prefix + '1/basic_lstm_cell/kernel')
  fw_lstm_b_1 = _var(cell_prefix + '1/basic_lstm_cell/bias')

  variable_mapping = {
      'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/kernel': fw_lstm_w_0,
      'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/bias': fw_lstm_b_0,
      'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/kernel': fw_lstm_w_1,
      'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/bias': fw_lstm_b_1
  }
  if not FLAGS.dis_share_embedding:
    variable_mapping['Model/embedding'] = embedding
  return variable_mapping
def dis_bwd_bidirectional(hparams):
  """Returns the *backward* PTB Variable name to MaskGAN Variable dictionary
  mapping.  This is a highly restrictive function just for testing.  This is
  for the bidirectional_zaremba discriminator.

  Args:
    hparams: Hyperparameters for the MaskGAN.

  Returns:
    variable_mapping: Dictionary with Key: ckpt_name, Value: model_var.
  """
  assert (FLAGS.discriminator_model == 'bidirectional_zaremba' or
          FLAGS.discriminator_model == 'bidirectional_vd')
  assert hparams.dis_num_layers == 2

  # Backward Discriminator Elements: map each checkpoint cell weight to the
  # corresponding 'dis/rnn/bw' trainable variable.
  trainables = tf.trainable_variables()
  variable_mapping = {}
  for layer in (0, 1):
    for part in ('kernel', 'bias'):
      source = ('dis/rnn/bw/multi_rnn_cell/cell_%d/basic_lstm_cell/%s' %
                (layer, part))
      target = ('Model/RNN/multi_rnn_cell/cell_%d/basic_lstm_cell/%s' %
                (layer, part))
      variable_mapping[target] = [
          v for v in trainables if v.op.name == source
      ][0]
  return variable_mapping
def dis_encoder_seq2seq(hparams):
  """Returns the PTB Variable name to MaskGAN Variable dictionary mapping.

  Args:
    hparams: Hyperparameters for the MaskGAN.

  Returns:
    variable_mapping: Dictionary with Key: ckpt_name, Value: model_var.
  """
  assert FLAGS.discriminator_model == 'seq2seq_vd'
  assert hparams.dis_num_layers == 2

  # PTB checkpoints use a capitalized top-level scope; other data sets do not.
  model_str = 'Model' if FLAGS.data_set == 'ptb' else 'model'

  ## Encoder forward variables: map each checkpoint cell weight onto the
  ## matching 'dis/encoder' trainable variable.
  trainables = tf.trainable_variables()
  variable_mapping = {}
  for layer in (0, 1):
    for part in ('kernel', 'bias'):
      ckpt_name = ('%s/RNN/multi_rnn_cell/cell_%d/basic_lstm_cell/%s' %
                   (model_str, layer, part))
      var_name = ('dis/encoder/rnn/multi_rnn_cell/cell_%d/basic_lstm_cell/%s' %
                  (layer, part))
      variable_mapping[ckpt_name] = [
          v for v in trainables if v.op.name == var_name
      ][0]
  return variable_mapping
def dis_decoder_seq2seq(hparams):
  """Returns the PTB Variable name to MaskGAN Variable dictionary mapping.

  This is for the seq2seq_vd discriminator decoder.

  Args:
    hparams: Hyperparameters for the MaskGAN.

  Returns:
    variable_mapping: Dictionary with Key: ckpt_name, Value: model_var.
  """
  assert FLAGS.discriminator_model == 'seq2seq_vd'
  assert hparams.dis_num_layers == 2

  def _var(name):
    # Unique trainable variable with the given op name (IndexError if absent).
    return [v for v in tf.trainable_variables() if v.op.name == name][0]

  # Embedding is restored only when not shared with the generator.
  if not FLAGS.dis_share_embedding:
    decoder_embedding = _var('dis/decoder/rnn/embedding')
  cell_prefix = 'dis/decoder/rnn/multi_rnn_cell/cell_'
  decoder_lstm_w_0 = _var(cell_prefix + '0/basic_lstm_cell/kernel')
  decoder_lstm_b_0 = _var(cell_prefix + '0/basic_lstm_cell/bias')
  decoder_lstm_w_1 = _var(cell_prefix + '1/basic_lstm_cell/kernel')
  decoder_lstm_b_1 = _var(cell_prefix + '1/basic_lstm_cell/bias')

  # PTB checkpoints use a capitalized top-level scope; other data sets do not.
  if FLAGS.data_set == 'ptb':
    model_str = 'Model'
  else:
    model_str = 'model'

  variable_mapping = {
      model_str + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/kernel':
          decoder_lstm_w_0,
      model_str + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/bias':
          decoder_lstm_b_0,
      model_str + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/kernel':
          decoder_lstm_w_1,
      model_str + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/bias':
          decoder_lstm_b_1
  }
  if not FLAGS.dis_share_embedding:
    variable_mapping[model_str + '/embedding'] = decoder_embedding
  return variable_mapping
def dis_seq2seq_vd(hparams):
  """Returns the generator-scope name to discriminator Variable mapping.

  Unlike the other mappings in this file, the checkpoint keys here are under
  the 'gen/' scope while the mapped variables live under 'dis/' — presumably
  used to initialize the seq2seq_vd discriminator from generator weights
  (TODO: confirm against the caller).

  Args:
    hparams: Hyperparameters for the MaskGAN.

  Returns:
    variable_mapping: Dictionary with Key: ckpt_name, Value: model_var.
  """
  assert FLAGS.discriminator_model == 'seq2seq_vd'
  assert hparams.dis_num_layers == 2

  def _var(name):
    # Unique trainable variable with the given op name (IndexError if absent).
    return [v for v in tf.trainable_variables() if v.op.name == name][0]

  # Embedding is restored only when not shared with the generator.
  if not FLAGS.dis_share_embedding:
    decoder_embedding = _var('dis/decoder/rnn/embedding')

  ## Encoder variables.
  enc_prefix = 'dis/encoder/rnn/multi_rnn_cell/cell_'
  encoder_lstm_w_0 = _var(enc_prefix + '0/basic_lstm_cell/kernel')
  encoder_lstm_b_0 = _var(enc_prefix + '0/basic_lstm_cell/bias')
  encoder_lstm_w_1 = _var(enc_prefix + '1/basic_lstm_cell/kernel')
  encoder_lstm_b_1 = _var(enc_prefix + '1/basic_lstm_cell/bias')

  ## Attention (only present when an attention option is configured).
  if FLAGS.attention_option is not None:
    decoder_attention_keys = _var('dis/decoder/attention_keys/weights')
    decoder_attention_construct_weights = _var(
        'dis/decoder/rnn/attention_construct/weights')

  ## Decoder.
  dec_prefix = 'dis/decoder/rnn/multi_rnn_cell/cell_'
  decoder_lstm_w_0 = _var(dec_prefix + '0/basic_lstm_cell/kernel')
  decoder_lstm_b_0 = _var(dec_prefix + '0/basic_lstm_cell/bias')
  decoder_lstm_w_1 = _var(dec_prefix + '1/basic_lstm_cell/kernel')
  decoder_lstm_b_1 = _var(dec_prefix + '1/basic_lstm_cell/bias')

  # Standard variable mappings.
  variable_mapping = {
      'gen/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel':
          encoder_lstm_w_0,
      'gen/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/bias':
          encoder_lstm_b_0,
      'gen/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/kernel':
          encoder_lstm_w_1,
      'gen/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/bias':
          encoder_lstm_b_1,
      'gen/decoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel':
          decoder_lstm_w_0,
      'gen/decoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/bias':
          decoder_lstm_b_0,
      'gen/decoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/kernel':
          decoder_lstm_w_1,
      'gen/decoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/bias':
          decoder_lstm_b_1
  }

  # Optional variable mappings.
  if not FLAGS.dis_share_embedding:
    variable_mapping['gen/decoder/rnn/embedding'] = decoder_embedding
  if FLAGS.attention_option is not None:
    variable_mapping[
        'gen/decoder/attention_keys/weights'] = decoder_attention_keys
    variable_mapping[
        'gen/decoder/rnn/attention_construct/weights'] = (
            decoder_attention_construct_weights)
  return variable_mapping