# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
# Dependency imports | |
import tensorflow as tf | |
FLAGS = tf.app.flags.FLAGS | |
def rnn_nas(hparams, model):
  """Returns the NAS checkpoint Variable name to MaskGAN Variable mapping.

  This logic is only valid for the 'rnn_nas' generator/discriminator with
  exactly two layers.

  Args:
    hparams: Hyperparameters for the MaskGAN.
    model: Model type, one of ['gen', 'dis'].

  Returns:
    variable_mapping: Dictionary with Key: ckpt_name, Value: model_var.
  """
  assert model == 'gen' or model == 'dis'

  def _get_var(name):
    # Fetch the unique trainable variable whose op name matches `name`.
    # Raises IndexError if no such variable exists in the current graph.
    return [v for v in tf.trainable_variables() if v.op.name == name][0]

  if model == 'gen':
    assert FLAGS.generator_model == 'rnn_nas'
    assert hparams.gen_num_layers == 2
  if model == 'dis':
    assert FLAGS.discriminator_model == 'rnn_nas'
    assert hparams.dis_num_layers == 2

  # Common elements to Generator and Discriminator.
  cell_prefix = str(model) + '/rnn/GenericMultiRNNCell/'
  embedding = _get_var(str(model) + '/rnn/embedding')
  lstm_w_0 = _get_var(cell_prefix + 'Cell0/Alien/rnn_builder/big_h_mat')
  lstm_b_0 = _get_var(cell_prefix + 'Cell0/Alien/rnn_builder/big_inputs_mat')
  lstm_w_1 = _get_var(cell_prefix + 'Cell1/Alien/rnn_builder/big_h_mat')
  lstm_b_1 = _get_var(cell_prefix + 'Cell1/Alien/rnn_builder/big_inputs_mat')

  # Dictionary mapping from checkpoint names to MaskGAN variables.
  variable_mapping = {
      'Model/embeddings/input_embedding':
          embedding,
      'Model/RNN/GenericMultiRNNCell/Cell0/Alien/rnn_builder/big_h_mat':
          lstm_w_0,
      'Model/RNN/GenericMultiRNNCell/Cell0/Alien/rnn_builder/big_inputs_mat':
          lstm_b_0,
      'Model/RNN/GenericMultiRNNCell/Cell1/Alien/rnn_builder/big_h_mat':
          lstm_w_1,
      'Model/RNN/GenericMultiRNNCell/Cell1/Alien/rnn_builder/big_inputs_mat':
          lstm_b_1,
  }
  # Output variables only for the Generator. Discriminator output biases
  # will begin randomly initialized.
  if model == 'gen':
    variable_mapping['Model/softmax_b'] = _get_var('gen/rnn/softmax_b')
  return variable_mapping
def cnn():
  """Variable mapping for the CNN embedding.

  Returns:
    variable_mapping: Dictionary with Key: ckpt_name, Value: model_var.
  """
  # Only valid when the discriminator is the CNN model.
  assert FLAGS.discriminator_model == 'cnn'

  # Locate the discriminator embedding among the trainable variables.
  candidates = [
      v for v in tf.trainable_variables() if v.op.name == 'dis/embedding'
  ]
  # Single-entry checkpoint-name to model-variable mapping.
  return {'Model/embedding': candidates[0]}
def rnn_zaremba(hparams, model):
  """Returns the PTB Variable name to MaskGAN Variable dictionary mapping.

  This is a highly restrictive function just for testing. This will need to
  be generalized.

  Args:
    hparams: Hyperparameters for the MaskGAN.
    model: Model type, one of ['gen', 'dis'].

  Returns:
    variable_mapping: Dictionary with Key: ckpt_name, Value: model_var.
  """
  assert model == 'gen' or model == 'dis'

  def _get_var(name):
    # Fetch the unique trainable variable whose op name matches `name`.
    return [v for v in tf.trainable_variables() if v.op.name == name][0]

  # This logic is only valid for rnn_zaremba-style models with two layers.
  if model == 'gen':
    assert FLAGS.generator_model == 'rnn_zaremba'
    assert hparams.gen_num_layers == 2
  if model == 'dis':
    assert (FLAGS.discriminator_model == 'rnn_zaremba' or
            FLAGS.discriminator_model == 'rnn_vd')
    assert hparams.dis_num_layers == 2

  # Common LSTM cell variables for Generator and Discriminator.
  src = str(model) + '/rnn/multi_rnn_cell/'
  variable_mapping = {
      'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/kernel':
          _get_var(src + 'cell_0/basic_lstm_cell/kernel'),
      'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/bias':
          _get_var(src + 'cell_0/basic_lstm_cell/bias'),
      'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/kernel':
          _get_var(src + 'cell_1/basic_lstm_cell/kernel'),
      'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/bias':
          _get_var(src + 'cell_1/basic_lstm_cell/bias'),
  }
  # The embedding is restored unless the discriminator shares the
  # generator's embedding.
  if not FLAGS.dis_share_embedding or model != 'dis':
    variable_mapping['Model/embedding'] = _get_var(
        str(model) + '/rnn/embedding')
  # Output variables only for the Generator. Discriminator output weights
  # and biases will begin randomly initialized.
  if model == 'gen':
    variable_mapping['Model/softmax_w'] = _get_var('gen/rnn/softmax_w')
    variable_mapping['Model/softmax_b'] = _get_var('gen/rnn/softmax_b')
  return variable_mapping
def gen_encoder_seq2seq_nas(hparams):
  """Returns the NAS Variable name to MaskGAN Variable dictionary mapping.

  This is a highly restrictive function just for testing. This is for the
  *unidirectional* seq2seq_nas encoder.

  Args:
    hparams: Hyperparameters for the MaskGAN.

  Returns:
    variable_mapping: Dictionary with Key: ckpt_name, Value: model_var.
  """
  assert FLAGS.generator_model == 'seq2seq_nas'
  assert hparams.gen_num_layers == 2

  def _get_var(name):
    # Fetch the unique trainable variable whose op name matches `name`.
    return [v for v in tf.trainable_variables() if v.op.name == name][0]

  ## Encoder forward variables.
  src = 'gen/encoder/rnn/GenericMultiRNNCell/'
  ckpt = 'Model/RNN/GenericMultiRNNCell/'
  variable_mapping = {
      ckpt + 'Cell0/Alien/rnn_builder/big_h_mat':
          _get_var(src + 'Cell0/Alien/rnn_builder/big_h_mat'),
      ckpt + 'Cell0/Alien/rnn_builder/big_inputs_mat':
          _get_var(src + 'Cell0/Alien/rnn_builder/big_inputs_mat'),
      ckpt + 'Cell1/Alien/rnn_builder/big_h_mat':
          _get_var(src + 'Cell1/Alien/rnn_builder/big_h_mat'),
      ckpt + 'Cell1/Alien/rnn_builder/big_inputs_mat':
          _get_var(src + 'Cell1/Alien/rnn_builder/big_inputs_mat'),
  }
  # The encoder keeps its own embedding only when it is not shared with the
  # decoder.
  if not FLAGS.seq2seq_share_embedding:
    variable_mapping['Model/embeddings/input_embedding'] = _get_var(
        'gen/encoder/rnn/embedding')
  return variable_mapping
def gen_decoder_seq2seq_nas(hparams):
  """Returns the NAS Variable name to MaskGAN Variable dictionary mapping
  for the seq2seq_nas decoder.

  Args:
    hparams: Hyperparameters for the MaskGAN.

  Returns:
    variable_mapping: Dictionary with Key: ckpt_name, Value: model_var.
  """
  assert FLAGS.generator_model == 'seq2seq_nas'
  assert hparams.gen_num_layers == 2

  def _get_var(name):
    # Fetch the unique trainable variable whose op name matches `name`.
    return [v for v in tf.trainable_variables() if v.op.name == name][0]

  src = 'gen/decoder/rnn/GenericMultiRNNCell/'
  ckpt = 'Model/RNN/GenericMultiRNNCell/'
  variable_mapping = {
      'Model/embeddings/input_embedding':
          _get_var('gen/decoder/rnn/embedding'),
      ckpt + 'Cell0/Alien/rnn_builder/big_h_mat':
          _get_var(src + 'Cell0/Alien/rnn_builder/big_h_mat'),
      ckpt + 'Cell0/Alien/rnn_builder/big_inputs_mat':
          _get_var(src + 'Cell0/Alien/rnn_builder/big_inputs_mat'),
      ckpt + 'Cell1/Alien/rnn_builder/big_h_mat':
          _get_var(src + 'Cell1/Alien/rnn_builder/big_h_mat'),
      ckpt + 'Cell1/Alien/rnn_builder/big_inputs_mat':
          _get_var(src + 'Cell1/Alien/rnn_builder/big_inputs_mat'),
      'Model/softmax_b':
          _get_var('gen/decoder/rnn/softmax_b'),
  }
  return variable_mapping
def gen_encoder_seq2seq(hparams):
  """Returns the PTB Variable name to MaskGAN Variable dictionary mapping.

  This is a highly restrictive function just for testing. This is for the
  *unidirectional* seq2seq_zaremba encoder.

  Args:
    hparams: Hyperparameters for the MaskGAN.

  Returns:
    variable_mapping: Dictionary with Key: ckpt_name, Value: model_var.
  """
  assert (FLAGS.generator_model == 'seq2seq_zaremba' or
          FLAGS.generator_model == 'seq2seq_vd')
  assert hparams.gen_num_layers == 2

  def _get_var(name):
    # Fetch the unique trainable variable whose op name matches `name`.
    return [v for v in tf.trainable_variables() if v.op.name == name][0]

  # PTB checkpoints use 'Model' as the scope name; other data sets use
  # 'model'.
  model_str = 'Model' if FLAGS.data_set == 'ptb' else 'model'

  ## Encoder forward variables.
  src = 'gen/encoder/rnn/multi_rnn_cell/'
  ckpt = model_str + '/RNN/multi_rnn_cell/'
  variable_mapping = {
      ckpt + 'cell_0/basic_lstm_cell/kernel':
          _get_var(src + 'cell_0/basic_lstm_cell/kernel'),
      ckpt + 'cell_0/basic_lstm_cell/bias':
          _get_var(src + 'cell_0/basic_lstm_cell/bias'),
      ckpt + 'cell_1/basic_lstm_cell/kernel':
          _get_var(src + 'cell_1/basic_lstm_cell/kernel'),
      ckpt + 'cell_1/basic_lstm_cell/bias':
          _get_var(src + 'cell_1/basic_lstm_cell/bias'),
  }
  # The encoder keeps its own embedding only when it is not shared with the
  # decoder.
  if not FLAGS.seq2seq_share_embedding:
    variable_mapping[model_str + '/embedding'] = _get_var(
        'gen/encoder/rnn/embedding')
  return variable_mapping
def gen_decoder_seq2seq(hparams):
  """Returns the PTB Variable name to MaskGAN Variable dictionary mapping
  for the seq2seq_zaremba / seq2seq_vd decoder.

  Args:
    hparams: Hyperparameters for the MaskGAN.

  Returns:
    variable_mapping: Dictionary with Key: ckpt_name, Value: model_var.
  """
  assert (FLAGS.generator_model == 'seq2seq_zaremba' or
          FLAGS.generator_model == 'seq2seq_vd')
  assert hparams.gen_num_layers == 2

  def _get_var(name):
    # Fetch the unique trainable variable whose op name matches `name`.
    return [v for v in tf.trainable_variables() if v.op.name == name][0]

  # PTB checkpoints use 'Model' as the scope name; other data sets use
  # 'model'.
  model_str = 'Model' if FLAGS.data_set == 'ptb' else 'model'

  src = 'gen/decoder/rnn/multi_rnn_cell/'
  ckpt = model_str + '/RNN/multi_rnn_cell/'
  variable_mapping = {
      model_str + '/embedding':
          _get_var('gen/decoder/rnn/embedding'),
      ckpt + 'cell_0/basic_lstm_cell/kernel':
          _get_var(src + 'cell_0/basic_lstm_cell/kernel'),
      ckpt + 'cell_0/basic_lstm_cell/bias':
          _get_var(src + 'cell_0/basic_lstm_cell/bias'),
      ckpt + 'cell_1/basic_lstm_cell/kernel':
          _get_var(src + 'cell_1/basic_lstm_cell/kernel'),
      ckpt + 'cell_1/basic_lstm_cell/bias':
          _get_var(src + 'cell_1/basic_lstm_cell/bias'),
      model_str + '/softmax_b':
          _get_var('gen/decoder/rnn/softmax_b'),
  }
  return variable_mapping
def dis_fwd_bidirectional(hparams):
  """Returns the *forward* PTB Variable name to MaskGAN Variable dictionary
  mapping. This is a highly restrictive function just for testing. This is
  for the bidirectional_zaremba discriminator.

  Args:
    hparams: Hyperparameters for the MaskGAN.

  Returns:
    variable_mapping: Dictionary with Key: ckpt_name, Value: model_var.
  """
  assert (FLAGS.discriminator_model == 'bidirectional_zaremba' or
          FLAGS.discriminator_model == 'bidirectional_vd')
  assert hparams.dis_num_layers == 2

  def _get_var(name):
    # Fetch the unique trainable variable whose op name matches `name`.
    return [v for v in tf.trainable_variables() if v.op.name == name][0]

  # Forward Discriminator Elements.
  src = 'dis/rnn/fw/multi_rnn_cell/'
  variable_mapping = {
      'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/kernel':
          _get_var(src + 'cell_0/basic_lstm_cell/kernel'),
      'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/bias':
          _get_var(src + 'cell_0/basic_lstm_cell/bias'),
      'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/kernel':
          _get_var(src + 'cell_1/basic_lstm_cell/kernel'),
      'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/bias':
          _get_var(src + 'cell_1/basic_lstm_cell/bias'),
  }
  # Embedding only restored when the discriminator does not share it.
  if not FLAGS.dis_share_embedding:
    variable_mapping['Model/embedding'] = _get_var('dis/embedding')
  return variable_mapping
def dis_bwd_bidirectional(hparams):
  """Returns the *backward* PTB Variable name to MaskGAN Variable dictionary
  mapping. This is a highly restrictive function just for testing. This is
  for the bidirectional_zaremba discriminator.

  Args:
    hparams: Hyperparameters for the MaskGAN.

  Returns:
    variable_mapping: Dictionary with Key: ckpt_name, Value: model_var.
  """
  assert (FLAGS.discriminator_model == 'bidirectional_zaremba' or
          FLAGS.discriminator_model == 'bidirectional_vd')
  assert hparams.dis_num_layers == 2

  # Backward Discriminator Elements: both layers' kernel and bias are mapped
  # from the 'dis/rnn/bw' scope to the checkpoint's 'Model/RNN' scope.
  variable_mapping = {}
  for cell in ('cell_0', 'cell_1'):
    for param in ('kernel', 'bias'):
      suffix = 'multi_rnn_cell/%s/basic_lstm_cell/%s' % (cell, param)
      matches = [
          v for v in tf.trainable_variables()
          if v.op.name == 'dis/rnn/bw/' + suffix
      ]
      variable_mapping['Model/RNN/' + suffix] = matches[0]
  return variable_mapping
def dis_encoder_seq2seq(hparams):
  """Returns the PTB Variable name to MaskGAN Variable
  dictionary mapping.

  Args:
    hparams: Hyperparameters for the MaskGAN.

  Returns:
    variable_mapping: Dictionary with Key: ckpt_name, Value: model_var.
  """
  assert FLAGS.discriminator_model == 'seq2seq_vd'
  assert hparams.dis_num_layers == 2

  def _get_var(name):
    # Fetch the unique trainable variable whose op name matches `name`.
    return [v for v in tf.trainable_variables() if v.op.name == name][0]

  # PTB checkpoints use 'Model' as the scope name; other data sets use
  # 'model'.
  model_str = 'Model' if FLAGS.data_set == 'ptb' else 'model'

  ## Encoder forward variables.
  src = 'dis/encoder/rnn/multi_rnn_cell/'
  ckpt = model_str + '/RNN/multi_rnn_cell/'
  variable_mapping = {
      ckpt + 'cell_0/basic_lstm_cell/kernel':
          _get_var(src + 'cell_0/basic_lstm_cell/kernel'),
      ckpt + 'cell_0/basic_lstm_cell/bias':
          _get_var(src + 'cell_0/basic_lstm_cell/bias'),
      ckpt + 'cell_1/basic_lstm_cell/kernel':
          _get_var(src + 'cell_1/basic_lstm_cell/kernel'),
      ckpt + 'cell_1/basic_lstm_cell/bias':
          _get_var(src + 'cell_1/basic_lstm_cell/bias'),
  }
  return variable_mapping
def dis_decoder_seq2seq(hparams):
  """Returns the PTB Variable name to MaskGAN Variable dictionary mapping
  for the seq2seq_vd discriminator decoder.

  Args:
    hparams: Hyperparameters for the MaskGAN.

  Returns:
    variable_mapping: Dictionary with Key: ckpt_name, Value: model_var.
  """
  assert FLAGS.discriminator_model == 'seq2seq_vd'
  assert hparams.dis_num_layers == 2

  def _get_var(name):
    # Fetch the unique trainable variable whose op name matches `name`.
    return [v for v in tf.trainable_variables() if v.op.name == name][0]

  # PTB checkpoints use 'Model' as the scope name; other data sets use
  # 'model'.
  model_str = 'Model' if FLAGS.data_set == 'ptb' else 'model'

  src = 'dis/decoder/rnn/multi_rnn_cell/'
  ckpt = model_str + '/RNN/multi_rnn_cell/'
  variable_mapping = {
      ckpt + 'cell_0/basic_lstm_cell/kernel':
          _get_var(src + 'cell_0/basic_lstm_cell/kernel'),
      ckpt + 'cell_0/basic_lstm_cell/bias':
          _get_var(src + 'cell_0/basic_lstm_cell/bias'),
      ckpt + 'cell_1/basic_lstm_cell/kernel':
          _get_var(src + 'cell_1/basic_lstm_cell/kernel'),
      ckpt + 'cell_1/basic_lstm_cell/bias':
          _get_var(src + 'cell_1/basic_lstm_cell/bias'),
  }
  # Embedding only restored when the discriminator does not share it.
  if not FLAGS.dis_share_embedding:
    variable_mapping[model_str + '/embedding'] = _get_var(
        'dis/decoder/rnn/embedding')
  return variable_mapping
def dis_seq2seq_vd(hparams):
  """Returns the generator Variable name to discriminator Variable mapping.

  Used to initialize the seq2seq_vd discriminator's encoder/decoder LSTM
  cells (and optionally its embedding and attention weights) from the
  corresponding generator variables.

  Args:
    hparams: Hyperparameters for the MaskGAN.

  Returns:
    variable_mapping: Dictionary with Key: ckpt_name, Value: model_var.
  """
  assert FLAGS.discriminator_model == 'seq2seq_vd'
  assert hparams.dis_num_layers == 2

  def _get_var(name):
    # Fetch the unique trainable variable whose op name matches `name`.
    return [v for v in tf.trainable_variables() if v.op.name == name][0]

  # Standard variable mappings: every encoder/decoder LSTM kernel and bias
  # of the discriminator maps from the same-named generator variable.
  variable_mapping = {}
  for tower in ('encoder', 'decoder'):
    for cell in ('cell_0', 'cell_1'):
      for param in ('kernel', 'bias'):
        suffix = '/rnn/multi_rnn_cell/%s/basic_lstm_cell/%s' % (cell, param)
        variable_mapping['gen/' + tower + suffix] = _get_var(
            'dis/' + tower + suffix)

  # Optional variable mappings.
  if not FLAGS.dis_share_embedding:
    variable_mapping['gen/decoder/rnn/embedding'] = _get_var(
        'dis/decoder/rnn/embedding')
  if FLAGS.attention_option is not None:
    variable_mapping['gen/decoder/attention_keys/weights'] = _get_var(
        'dis/decoder/attention_keys/weights')
    variable_mapping['gen/decoder/rnn/attention_construct/weights'] = _get_var(
        'dis/decoder/rnn/attention_construct/weights')
  return variable_mapping