"""Input pipe for feeding examples to a Seq2Label model graph.""" |
|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import numpy as np |
|
import tensorflow as tf |
|
from google.protobuf import text_format |
|
|
|
from protos import seq2label_pb2 |
|
import seq2label_utils |
|
|

DNA_BASES = tuple('ACGT')
NUM_DNA_BASES = len(DNA_BASES)

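# IUPAC ambiguity codes, mapping each code to the set of DNA bases it can
# represent; 'X' and 'N' both stand for any base.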
AMBIGUITY_CODES = {
    'K': 'GT',
    'M': 'AC',
    'R': 'AG',
    'Y': 'CT',
    'S': 'CG',
    'W': 'AT',
    'B': 'CGT',
    'V': 'ACG',
    'H': 'ACT',
    'D': 'AGT',
    'X': 'ACGT',
    'N': 'ACGT'
}


def load_dataset_info(dataset_info_path):
  """Load a `Seq2LabelDatasetInfo` from a serialized text proto file."""
  dataset_info = seq2label_pb2.Seq2LabelDatasetInfo()
  with tf.gfile.Open(dataset_info_path, 'r') as f:
    text_format.Parse(f.read(), dataset_info)
  return dataset_info
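
# Example usage (hypothetical path):
#   dataset_info = load_dataset_info('/path/to/dataset_info.pbtxt')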


class _InputEncoding(object):
  """A helper class providing the graph operations needed to encode input.

  Instantiating an _InputEncoding writes to the default TF graph, so it
  should only be done inside the `input_fn`.

  Attributes:
    mode: `tf.estimator.ModeKeys`; the execution mode {TRAIN, EVAL, PREDICT}.
    targets: list of strings; the names of the labels of interest (e.g.
      "species").
    dna_bases: a tuple of the recognized DNA alphabet.
    n_bases: the size of the DNA alphabet.
    all_characters: list of the recognized alphabet, including ambiguity
      codes.
    character_encodings: array of shape (len(all_characters), n_bases) whose
      rows give each character's probability distribution over the DNA bases.
    characters_table: lookup table mapping each recognized character to its
      row index in character_encodings.
    label_tables: dict mapping each target name to a lookup table from label
      values to integer indices.
    fixed_read_length: an integer value of the statically-known read length,
      or None if the read length is to be determined dynamically.
    noise_rate: float in [0.0, 1.0]; the rate at which base-flipping noise is
      injected into read sequences.
  """

  def __init__(self,
               dataset_info,
               mode,
               targets,
               noise_rate=0.0,
               fixed_read_length=None):
    self.mode = mode
    self.targets = targets
    self.dna_bases = DNA_BASES
    self.n_bases = NUM_DNA_BASES
    self.all_characters = list(DNA_BASES) + sorted(AMBIGUITY_CODES.keys())
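    # One row per recognized character: its probability distribution over the
    # four DNA bases (one-hot for A/C/G/T, fractional for ambiguity codes).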
    self.character_encodings = np.concatenate(
        [[self._character_to_base_distribution(char)]
         for char in self.all_characters],
        axis=0)
    all_legal_label_values = seq2label_utils.get_all_label_values(dataset_info)
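
    # In-graph lookup tables mapping characters and label strings to integer
    # indices; their initializer ops are registered in the TABLE_INITIALIZERS
    # collection.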
    self.characters_table = tf.contrib.lookup.index_table_from_tensor(
        mapping=self.all_characters)
    self.label_tables = {
        target: tf.contrib.lookup.index_table_from_tensor(
            all_legal_label_values[target])
        for target in targets
    }
    self.fixed_read_length = fixed_read_length
    self.noise_rate = noise_rate

  def _character_to_base_distribution(self, char):
    """Maps the given character to a probability distribution over DNA bases.

    Args:
      char: character to be encoded as a probability distribution over bases.

    Returns:
      Array of size (self.n_bases,) representing the identity of the given
      character as a distribution over the possible DNA bases,
      self.dna_bases.

    Raises:
      ValueError: if the given character is not contained in the recognized
        alphabet, self.all_characters.
    """
    if char not in self.all_characters:
      raise ValueError(
          'Base distribution requested for unrecognized character %s.' % char)
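    # Spread probability mass uniformly over the bases the character can
    # represent, e.g. 'R' -> [0.5, 0, 0.5, 0]; plain bases become one-hot.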
    possible_bases = AMBIGUITY_CODES.get(char, char)
    base_indices = [self.dna_bases.index(base) for base in possible_bases]
    probability_weight = 1.0 / len(possible_bases)
    distribution = np.zeros(self.n_bases)
    distribution[base_indices] = probability_weight
    return distribution

  def encode_read(self, string_seq):
    """Converts the input read sequence to a one-hot encoding.

    Args:
      string_seq: tf.String; input read sequence.

    Returns:
      Input read sequence as a one-hot encoded Tensor, with the depth and
      ordering of the one-hot encoding determined by self.dna_bases.
      Ambiguous characters such as "N" and "S" are encoded as a probability
      distribution over the possible bases they represent.
    """
    with tf.variable_scope('encode_read'):
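      # Split the read into single characters, look up each character's row
      # index, then gather the corresponding base distributions.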
      read = tf.string_split([string_seq], delimiter='').values
      read = self.characters_table.lookup(read)
      read = tf.cast(tf.gather(self.character_encodings, read), tf.float32)
      if self.fixed_read_length:
        read = tf.reshape(read, (self.fixed_read_length, self.n_bases))
      return read

  def encode_label(self, target, string_label):
    """Converts the label value to an integer encoding.

    Args:
      target: str; the target name.
      string_label: tf.String; value of the label for the current input read.

    Returns:
      Given label value as an index into the possible label values for the
      target.
    """
    with tf.variable_scope('encode_label/{}'.format(target)):
      return tf.cast(self.label_tables[target].lookup(string_label), tf.int32)

  def _empty_label(self):
    return tf.constant((), dtype=tf.int32, shape=())

  def parse_single_tfexample(self, serialized_example):
    """Parses a tf.train.Example proto to a one-hot encoded read, label pair.

    Injects noise into the incoming tf.train.Example's read sequence
    when noise_rate is non-zero.

    Args:
      serialized_example: string; the serialized tf.train.Example proto
        containing the read sequence and label value of interest as
        tf.FixedLenFeatures.

    Returns:
      Tuple (features, labels) of dicts for the input features and prediction
      targets.
    """
    with tf.variable_scope('parse_single_tfexample'):
      features_spec = {'sequence': tf.FixedLenFeature([], tf.string)}
      for target in self.targets:
        features_spec[target] = tf.FixedLenFeature([], tf.string)
      features = tf.parse_single_example(
          serialized_example, features=features_spec)
      if self.noise_rate > 0.0:
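        # Inject base-flipping noise via a py_func; note that py_func ops run
        # in the client Python process and are not serialized into the
        # GraphDef.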
        read_sequence = tf.py_func(seq2label_utils.add_read_noise,
                                   [features['sequence'], self.noise_rate],
                                   tf.string)
      else:
        read_sequence = features['sequence']
      read_sequence = self.encode_read(read_sequence)
      read_features = {'sequence': read_sequence}
      if self.mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
        label = {
            target: self.encode_label(target, features[target])
            for target in self.targets
        }
      else:
        label = {target: self._empty_label() for target in self.targets}
      return read_features, label


class InputDataset(object):
  """A class providing access to input data for the Seq2Label model.

  Attributes:
    mode: `tf.estimator.ModeKeys`; the execution mode {TRAIN, EVAL, PREDICT}.
    targets: list of strings; the names of the labels of interest (e.g.
      "species").
    dataset_info: a `Seq2LabelDatasetInfo` message reflecting the dataset
      metadata.
    initializer: the TF initializer op for the underlying iterator; running
      it rewinds the iterator.
    is_train: Boolean indicating whether or not the execution mode is TRAIN.
  """

  def __init__(self,
               mode,
               targets,
               dataset_info,
               train_epochs=None,
               noise_rate=0.0,
               random_seed=None,
               input_tfrecord_files=None,
               fixed_read_length=None,
               ensure_constant_batch_size=False,
               num_parallel_calls=32):
    """Constructor for InputDataset.

    Args:
      mode: `tf.estimator.ModeKeys`; the execution mode {TRAIN, EVAL,
        PREDICT}.
      targets: list of strings; the names of the labels of interest (e.g.
        "species").
      dataset_info: a `Seq2LabelDatasetInfo` message reflecting the dataset
        metadata.
      train_epochs: the number of training epochs to perform, if mode==TRAIN.
      noise_rate: float in [0.0, 1.0] specifying the rate at which to inject
        base-flipping noise into the read sequences.
      random_seed: seed to be used for shuffling, if mode==TRAIN.
      input_tfrecord_files: a list of filenames for TFRecords of TF examples.
      fixed_read_length: an integer value of the statically-known read
        length, or None if the read length is to be determined dynamically.
        The read length must be known statically for TPU execution.
      ensure_constant_batch_size: ensure a constant batch size at the expense
        of discarding the last "short" batch. This also yields a statically
        known batch size, which is essential for e.g. the TPU platform.
      num_parallel_calls: the number of dataset elements to process in
        parallel. If None, elements will be processed sequentially.
    """
    self.input_tfrecord_files = input_tfrecord_files
    self.mode = mode
    self.targets = targets
    self.dataset_info = dataset_info
    self._train_epochs = train_epochs
    self._noise_rate = noise_rate
    self._random_seed = random_seed
    if random_seed is not None:
      np.random.seed(random_seed)
    self._fixed_read_length = fixed_read_length
    self._ensure_constant_batch_size = ensure_constant_batch_size
    self._num_parallel_calls = num_parallel_calls

  @staticmethod
  def from_tfrecord_files(input_tfrecord_files, *args, **kwargs):
    return InputDataset(
        *args, input_tfrecord_files=input_tfrecord_files, **kwargs)

  @property
  def is_train(self):
    return self.mode == tf.estimator.ModeKeys.TRAIN

  def input_fn(self, params):
    """Supplies input for the model.

    This function supplies input to our model as a function of the mode.

    Args:
      params: a dictionary, containing:
        - params['batch_size']: the integer batch size.

    Returns:
      A tuple of two values as follows:
        1) the *features* dict, containing a tensor value for keys as
          follows:
          - "sequence" - the encoded read input sequence.
        2) the *labels* dict, containing a key for each target, whose value
          is:
          - an integer-encoded label Tensor (in TRAIN/EVAL mode), or
          - an empty Tensor (PREDICT mode).
    """
    randomize_input = self.is_train
    batch_size = params['batch_size']

    encoding = _InputEncoding(
        self.dataset_info,
        self.mode,
        self.targets,
        noise_rate=self._noise_rate,
        fixed_read_length=self._fixed_read_length)

    dataset = tf.data.TFRecordDataset(self.input_tfrecord_files)
    dataset = dataset.map(
        encoding.parse_single_tfexample,
        num_parallel_calls=self._num_parallel_calls)

    dataset = dataset.repeat(self._train_epochs if self.is_train else 1)
    if randomize_input:
      dataset = dataset.shuffle(
          buffer_size=max(1000, batch_size), seed=self._random_seed)

    if self._ensure_constant_batch_size:
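      # Drop the final partial batch so every batch has exactly `batch_size`
      # elements; a statically known batch size is required on TPU.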
      dataset = dataset.batch(batch_size, drop_remainder=True)
    else:
      dataset = dataset.batch(batch_size)
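
    # Prefetch a small number of batches so input processing overlaps with
    # model computation.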
    dataset = dataset.prefetch(2)
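
    # Use an initializable (rather than one-shot) iterator so that it can be
    # explicitly re-initialized, i.e. rewound, via `self.initializer`.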
    iterator = dataset.make_initializable_iterator()
    self.initializer = iterator.initializer
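    # Registering the iterator's initializer in the TABLE_INITIALIZERS
    # collection makes it run alongside the lookup-table initializers during
    # session initialization.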
    tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer)

    features, labels = iterator.get_next()
    return (features, labels)