# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Input utils for virtual adversarial text classification.""" | |
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import os | |
# Dependency imports | |
import tensorflow as tf | |
from data import data_utils | |

class VatxtInput(object):
  """Wrapper around NextQueuedSequenceBatch."""

  def __init__(self,
               batch,
               state_name=None,
               tokens=None,
               num_states=0,
               eos_id=None):
    """Construct VatxtInput.

    Args:
      batch: NextQueuedSequenceBatch.
      state_name: str, name of state to fetch and save.
      tokens: int Tensor, tokens. Defaults to batch's F_TOKEN_ID sequence.
      num_states: int, the number of states to store.
      eos_id: int, id of the end-of-sequence token.
    """
    self._batch = batch
    self._state_name = state_name
    self._tokens = (tokens if tokens is not None else
                    batch.sequences[data_utils.SequenceWrapper.F_TOKEN_ID])
    self._num_states = num_states
    self._weights = batch.sequences[data_utils.SequenceWrapper.F_WEIGHT]
    self._labels = batch.sequences[data_utils.SequenceWrapper.F_LABEL]

    # Mask that is 1.0 at end-of-sequence token positions and 0.0 elsewhere.
    self._eos_weights = None
    if eos_id:
      self._eos_weights = tf.cast(tf.equal(self._tokens, eos_id), tf.float32)
  def tokens(self):
    return self._tokens

  def weights(self):
    return self._weights

  def eos_weights(self):
    return self._eos_weights

  def labels(self):
    return self._labels

  def length(self):
    return self._batch.length

  def state_name(self):
    return self._state_name

  def state(self):
    # LSTM tuple states
    state_names = _get_tuple_state_names(self._num_states, self._state_name)
    return tuple([
        tf.contrib.rnn.LSTMStateTuple(
            self._batch.state(c_name), self._batch.state(h_name))
        for c_name, h_name in state_names
    ])
  def save_state(self, value):
    # LSTM tuple states
    state_names = _get_tuple_state_names(self._num_states, self._state_name)
    save_ops = []
    for (c_state, h_state), (c_name, h_name) in zip(value, state_names):
      save_ops.append(self._batch.save_state(c_name, c_state))
      save_ops.append(self._batch.save_state(h_name, h_state))
    return tf.group(*save_ops)
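# Illustrative sketch (not part of the original file): state() and save_state()
# are intended to be paired across truncated-backprop chunks, e.g.
#   initial_state = vatxt_input.state()
#   # ... run the recurrent model from initial_state to obtain `final_state`,
#   # a tuple of LSTMStateTuples with one entry per stored layer ...
#   save_op = vatxt_input.save_state(final_state)
# Here `vatxt_input` and `final_state` are hypothetical names for illustration.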

def _get_tuple_state_names(num_states, base_name):
  """Returns state names for use with LSTM tuple state."""
  state_names = [('{}_{}_c'.format(i, base_name), '{}_{}_h'.format(
      i, base_name)) for i in range(num_states)]
  return state_names
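# For example, _get_tuple_state_names(2, 'lstm') returns
# [('0_lstm_c', '0_lstm_h'), ('1_lstm_c', '1_lstm_h')].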

def _split_bidir_tokens(batch):
  tokens = batch.sequences[data_utils.SequenceWrapper.F_TOKEN_ID]
  # Tokens have shape [batch, time, 2];
  # forward and reverse have shape [batch, time].
  forward, reverse = [
      tf.squeeze(t, axis=[2]) for t in tf.split(tokens, 2, axis=2)
  ]
  return forward, reverse
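# Shape example: with batch_size=32 and unroll_steps=100 the stacked tokens
# have shape [32, 100, 2], and the returned forward and reverse tensors each
# have shape [32, 100].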

def _filenames_for_data_spec(phase, bidir, pretrain, use_seq2seq):
  """Returns input filenames for configuration.

  Args:
    phase: str, 'train', 'test', or 'valid'.
    bidir: bool, bidirectional model.
    pretrain: bool, pretraining or classification.
    use_seq2seq: bool, seq2seq data, only valid if pretrain=True.

  Returns:
    Tuple of filenames.

  Raises:
    ValueError: if an invalid combination of arguments is provided that does
      not map to any data files (e.g. pretrain=False, use_seq2seq=True).
  """
  data_spec = (phase, bidir, pretrain, use_seq2seq)
  data_specs = {
      ('train', True, True, False): (data_utils.TRAIN_LM,
                                     data_utils.TRAIN_REV_LM),
      ('train', True, False, False): (data_utils.TRAIN_BD_CLASS,),
      ('train', False, True, False): (data_utils.TRAIN_LM,),
      ('train', False, True, True): (data_utils.TRAIN_SA,),
      ('train', False, False, False): (data_utils.TRAIN_CLASS,),
      ('test', True, True, False): (data_utils.TEST_LM,
                                    data_utils.TRAIN_REV_LM),
      ('test', True, False, False): (data_utils.TEST_BD_CLASS,),
      ('test', False, True, False): (data_utils.TEST_LM,),
      ('test', False, True, True): (data_utils.TEST_SA,),
      ('test', False, False, False): (data_utils.TEST_CLASS,),
      ('valid', True, False, False): (data_utils.VALID_BD_CLASS,),
      ('valid', False, False, False): (data_utils.VALID_CLASS,),
  }
  if data_spec not in data_specs:
    raise ValueError(
        'Data specification (phase, bidir, pretrain, use_seq2seq) %s not '
        'supported' % str(data_spec))
  return data_specs[data_spec]
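# For example, _filenames_for_data_spec('train', bidir=False, pretrain=True,
# use_seq2seq=False) returns (data_utils.TRAIN_LM,), while the same call with
# bidir=True maps to both the forward and reverse language-model files.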

def _read_single_sequence_example(file_list, tokens_shape=None):
  """Reads and parses SequenceExamples from TFRecord-encoded file_list."""
  tf.logging.info('Constructing TFRecordReader from files: %s', file_list)
  file_queue = tf.train.string_input_producer(file_list)
  reader = tf.TFRecordReader()
  seq_key, serialized_record = reader.read(file_queue)
  ctx, sequence = tf.parse_single_sequence_example(
      serialized_record,
      sequence_features={
          data_utils.SequenceWrapper.F_TOKEN_ID:
              tf.FixedLenSequenceFeature(tokens_shape or [], dtype=tf.int64),
          data_utils.SequenceWrapper.F_LABEL:
              tf.FixedLenSequenceFeature([], dtype=tf.int64),
          data_utils.SequenceWrapper.F_WEIGHT:
              tf.FixedLenSequenceFeature([], dtype=tf.float32),
      })
  return seq_key, ctx, sequence
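# Note: with the default tokens_shape=None the token feature is parsed as one
# scalar per timestep ([time]); _read_and_batch passes tokens_shape=[2] for
# bidirectional classification data, yielding a [time, 2] token tensor.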

def _read_and_batch(data_dir,
                    fname,
                    state_name,
                    state_size,
                    num_layers,
                    unroll_steps,
                    batch_size,
                    bidir_input=False):
  """Inputs for text model.

  Args:
    data_dir: str, directory containing TFRecord files of SequenceExample.
    fname: str, input file name.
    state_name: string, key for saved state of LSTM.
    state_size: int, size of LSTM state.
    num_layers: int, the number of layers in the LSTM.
    unroll_steps: int, number of timesteps to unroll for TBPTT.
    batch_size: int, batch size.
    bidir_input: bool, whether the input is bidirectional. If True, creates 2
      states, state_name and state_name + '_reverse'.

  Returns:
    Instance of NextQueuedSequenceBatch.

  Raises:
    ValueError: if file for input specification is not found.
  """
  data_path = os.path.join(data_dir, fname)
  if not tf.gfile.Exists(data_path):
    raise ValueError('Failed to find file: %s' % data_path)

  tokens_shape = [2] if bidir_input else []
  seq_key, ctx, sequence = _read_single_sequence_example(
      [data_path], tokens_shape=tokens_shape)

  # Set up stateful queue reader.
  state_names = _get_tuple_state_names(num_layers, state_name)
  initial_states = {}
  for c_state, h_state in state_names:
    initial_states[c_state] = tf.zeros(state_size)
    initial_states[h_state] = tf.zeros(state_size)
  if bidir_input:
    rev_state_names = _get_tuple_state_names(num_layers,
                                             '{}_reverse'.format(state_name))
    for rev_c_state, rev_h_state in rev_state_names:
      initial_states[rev_c_state] = tf.zeros(state_size)
      initial_states[rev_h_state] = tf.zeros(state_size)
  batch = tf.contrib.training.batch_sequences_with_states(
      input_key=seq_key,
      input_sequences=sequence,
      input_context=ctx,
      input_length=tf.shape(sequence['token_id'])[0],
      initial_states=initial_states,
      num_unroll=unroll_steps,
      batch_size=batch_size,
      allow_small_batch=False,
      num_threads=4,
      capacity=batch_size * 10,
      make_keys_unique=True,
      make_keys_unique_seed=29392)
  return batch
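# Illustrative call (the path and hyperparameters below are assumptions, not
# values taken from this file): a unidirectional LM reader might be built as
#   batch = _read_and_batch('/tmp/imdb_tfrecords', data_utils.TRAIN_LM, 'lstm',
#                           state_size=1024, num_layers=1, unroll_steps=400,
#                           batch_size=64)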

def inputs(data_dir=None,
           phase='train',
           bidir=False,
           pretrain=False,
           use_seq2seq=False,
           state_name='lstm',
           state_size=None,
           num_layers=0,
           batch_size=32,
           unroll_steps=100,
           eos_id=None):
  """Inputs for text model.

  Args:
    data_dir: str, directory containing TFRecord files of SequenceExample.
    phase: str, dataset for evaluation {'train', 'valid', 'test'}.
    bidir: bool, bidirectional LSTM.
    pretrain: bool, whether to read pretraining data or classification data.
    use_seq2seq: bool, whether to read seq2seq data or the language model data.
    state_name: string, key for saved state of LSTM.
    state_size: int, size of LSTM state.
    num_layers: int, the number of LSTM layers.
    batch_size: int, batch size.
    unroll_steps: int, number of timesteps to unroll for TBPTT.
    eos_id: int, id of the end-of-sequence token. Used for the KL weights in
      VAT.

  Returns:
    Instance of VatxtInput (x2 if bidir=True and pretrain=True, i.e. forward
    and reverse).
  """
  with tf.name_scope('inputs'):
    filenames = _filenames_for_data_spec(phase, bidir, pretrain, use_seq2seq)

    if bidir and pretrain:
      # Bidirectional pretraining.
      # Requires separate forward and reverse language model data.
      forward_fname, reverse_fname = filenames
      forward_batch = _read_and_batch(data_dir, forward_fname, state_name,
                                      state_size, num_layers, unroll_steps,
                                      batch_size)
      state_name_rev = state_name + '_reverse'
      reverse_batch = _read_and_batch(data_dir, reverse_fname, state_name_rev,
                                      state_size, num_layers, unroll_steps,
                                      batch_size)
      forward_input = VatxtInput(
          forward_batch,
          state_name=state_name,
          num_states=num_layers,
          eos_id=eos_id)
      reverse_input = VatxtInput(
          reverse_batch,
          state_name=state_name_rev,
          num_states=num_layers,
          eos_id=eos_id)
      return forward_input, reverse_input
    elif bidir:
      # Classifier bidirectional LSTM.
      # Shared data source, but separate token/state streams.
      fname, = filenames
      batch = _read_and_batch(
          data_dir,
          fname,
          state_name,
          state_size,
          num_layers,
          unroll_steps,
          batch_size,
          bidir_input=True)
      forward_tokens, reverse_tokens = _split_bidir_tokens(batch)
      forward_input = VatxtInput(
          batch,
          state_name=state_name,
          tokens=forward_tokens,
          num_states=num_layers)
      reverse_input = VatxtInput(
          batch,
          state_name=state_name + '_reverse',
          tokens=reverse_tokens,
          num_states=num_layers)
      return forward_input, reverse_input
    else:
      # Unidirectional LM or classifier.
      fname, = filenames
      batch = _read_and_batch(
          data_dir,
          fname,
          state_name,
          state_size,
          num_layers,
          unroll_steps,
          batch_size,
          bidir_input=False)
      return VatxtInput(
          batch, state_name=state_name, num_states=num_layers, eos_id=eos_id)
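# Example usage (illustrative only; the directory, sizes, and eos id below are
# assumptions, not values from this file):
#   forward_input, reverse_input = inputs(
#       data_dir='/tmp/imdb_tfrecords',
#       phase='train',
#       bidir=True,
#       pretrain=True,
#       state_size=1024,
#       num_layers=1,
#       batch_size=64,
#       unroll_steps=400,
#       eos_id=2)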