# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""CNN-BiLSTM sentence encoder."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from base import embeddings
from model import model_helpers


class Encoder(object):
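  """CNN-BiLSTM sentence encoder.

  Combines pretrained word embeddings with (optional) character-level CNN
  features, then runs the result through two stacks of BiLSTM layers. The
  first ("unidirectional") stack exposes its forward and backward outputs
  separately; the second ("bidirectional") stack consumes their
  concatenation.
  """
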
  def __init__(self, config, inputs, pretrained_embeddings):
    self._config = config
    self._inputs = inputs

    self.word_reprs = self._get_word_reprs(pretrained_embeddings)
    self.uni_fw, self.uni_bw = self._get_unidirectional_reprs(self.word_reprs)
    self.uni_reprs = tf.concat([self.uni_fw, self.uni_bw], axis=-1)
    self.bi_fw, self.bi_bw, self.bi_reprs = self._get_bidirectional_reprs(
        self.uni_reprs)

  def _get_word_reprs(self, pretrained_embeddings):
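    """Looks up word embeddings and, if configured, appends char-CNN features.

    The character CNN applies 1-D convolutions of several filter widths over
    each word's characters and max-pools over character positions.
    """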
    with tf.variable_scope('word_embeddings'):
      word_embedding_matrix = tf.get_variable(
          'word_embedding_matrix', initializer=pretrained_embeddings)
      word_embeddings = tf.nn.embedding_lookup(
          word_embedding_matrix, self._inputs.words)
      word_embeddings = tf.nn.dropout(word_embeddings, self._inputs.keep_prob)
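      # Rescale the pretrained embeddings by a single learned scalar.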
      word_embeddings *= tf.get_variable('emb_scale', initializer=1.0)

    if not self._config.use_chars:
      return word_embeddings

    with tf.variable_scope('char_embeddings'):
      char_embedding_matrix = tf.get_variable(
          'char_embeddings',
          shape=[embeddings.NUM_CHARS, self._config.char_embedding_size])
      char_embeddings = tf.nn.embedding_lookup(char_embedding_matrix,
                                               self._inputs.chars)
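      # Fold the batch and time dimensions together so the 1-D convolutions
      # run over each word's character sequence independently.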
      shape = tf.shape(char_embeddings)
      char_embeddings = tf.reshape(
          char_embeddings,
          shape=[-1, shape[-2], self._config.char_embedding_size])
      char_reprs = []
      for filter_width in self._config.char_cnn_filter_widths:
        conv = tf.layers.conv1d(
            char_embeddings, self._config.char_cnn_n_filters, filter_width)
        conv = tf.nn.relu(conv)
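        # Max-pool over character positions, then restore the time dimension.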
        conv = tf.nn.dropout(tf.reduce_max(conv, axis=1),
                             self._inputs.keep_prob)
        conv = tf.reshape(conv, shape=[-1, shape[1],
                                       self._config.char_cnn_n_filters])
        char_reprs.append(conv)
      return tf.concat([word_embeddings] + char_reprs, axis=-1)

  def _get_unidirectional_reprs(self, word_reprs):
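    """Runs the first BiLSTM stack over the word representations.

    Returns the forward and backward outputs separately so callers can use
    either direction on its own.
    """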
    with tf.variable_scope('unidirectional_reprs'):
      word_lstm_input_size = (
          self._config.word_embedding_size if not self._config.use_chars else
          (self._config.word_embedding_size +
           len(self._config.char_cnn_filter_widths)
           * self._config.char_cnn_n_filters))
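      # Make the feature dimension static; the LSTM cells need a known input
      # size to build their weight matrices.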
      word_reprs.set_shape([None, None, word_lstm_input_size])
      (outputs_fw, outputs_bw), _ = tf.nn.bidirectional_dynamic_rnn(
          model_helpers.multi_lstm_cell(self._config.unidirectional_sizes,
                                        self._inputs.keep_prob,
                                        self._config.projection_size),
          model_helpers.multi_lstm_cell(self._config.unidirectional_sizes,
                                        self._inputs.keep_prob,
                                        self._config.projection_size),
          word_reprs,
          dtype=tf.float32,
          sequence_length=self._inputs.lengths,
          scope='unilstm'
      )
      return outputs_fw, outputs_bw

  def _get_bidirectional_reprs(self, uni_reprs):
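    """Runs the second BiLSTM stack over the concatenated representations.

    Returns the final forward outputs, backward outputs, and their
    concatenation.
    """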
    with tf.variable_scope('bidirectional_reprs'):
      current_outputs = uni_reprs
      outputs_fw, outputs_bw = None, None
      # Give each stacked layer its own scope; reusing a single scope would
      # raise a variable-reuse error with more than one bidirectional layer.
      for i, size in enumerate(self._config.bidirectional_sizes):
        (outputs_fw, outputs_bw), _ = tf.nn.bidirectional_dynamic_rnn(
            model_helpers.lstm_cell(size, self._inputs.keep_prob,
                                    self._config.projection_size),
            model_helpers.lstm_cell(size, self._inputs.keep_prob,
                                    self._config.projection_size),
            current_outputs,
            dtype=tf.float32,
            sequence_length=self._inputs.lengths,
            scope='bilstm_%d' % i
        )
        current_outputs = tf.concat([outputs_fw, outputs_bw], axis=-1)
      return outputs_fw, outputs_bw, current_outputs
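

# A minimal usage sketch: `config` and `inputs` stand in for any objects
# exposing the attributes referenced above (word_embedding_size, use_chars,
# words, chars, keep_prob, lengths, ...); `vocab_size` is assumed.
#
#   import numpy as np
#   pretrained = np.zeros(
#       [vocab_size, config.word_embedding_size], dtype=np.float32)
#   encoder = Encoder(config, inputs, pretrained)
#   token_reprs = encoder.bi_reprs  # [batch, max_len, 2 * lstm_size]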