# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""A multi-task and semi-supervised NLP model."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from model import encoder
from model import shared_inputs


class Inference(object):
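  """The shared sentence encoder plus one prediction module per task.

  Every module reads the same encoder outputs, so all tasks share
  representations for multi-task and semi-supervised training.
  """
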
  def __init__(self, config, inputs, pretrained_embeddings, tasks):
    with tf.variable_scope('encoder'):
      self.encoder = encoder.Encoder(config, inputs, pretrained_embeddings)
    self.modules = {}
    for task in tasks:
      with tf.variable_scope(task.name):
        self.modules[task.name] = task.get_module(inputs, self.encoder)


class Model(object):
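  """Multi-task model with trainer, teacher, and tester views of the graph.

  The three views share weights; when config.ema_teacher or config.ema_test
  is set, the teacher and/or tester instead read exponential-moving-average
  copies of the weights.
  """
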
  def __init__(self, config, pretrained_embeddings, tasks):
    self._config = config
    self._tasks = tasks

    self._global_step, self._optimizer = self._get_optimizer()
    self._inputs = shared_inputs.Inputs(config)
    with tf.variable_scope('model', reuse=tf.AUTO_REUSE) as scope:
      inference = Inference(config, self._inputs, pretrained_embeddings,
                            tasks)
      self._trainer = inference
      self._tester = inference
      self._teacher = inference
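      # Optionally build a second view of the graph whose variables are read
      # through an EMA custom getter; the EMA update is added to UPDATE_OPS
      # so it runs alongside every training step.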
      if config.ema_test or config.ema_teacher:
        ema = tf.train.ExponentialMovingAverage(config.ema_decay)
        model_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                       'model')
        ema_op = ema.apply(model_vars)
        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, ema_op)

        def ema_getter(getter, name, *args, **kwargs):
          var = getter(name, *args, **kwargs)
          return ema.average(var)

        scope.set_custom_getter(ema_getter)
        inference_ema = Inference(
            config, self._inputs, pretrained_embeddings, tasks)
        if config.ema_teacher:
          self._teacher = inference_ema
        if config.ema_test:
          self._tester = inference_ema

    self._unlabeled_loss = self._get_consistency_loss(tasks)
    self._unlabeled_train_op = self._get_train_op(self._unlabeled_loss)
    self._labeled_train_ops = {}
    for task in self._tasks:
      task_loss = self._trainer.modules[task.name].supervised_loss
      self._labeled_train_ops[task.name] = self._get_train_op(task_loss)

  def _get_consistency_loss(self, tasks):
    return sum(self._trainer.modules[task.name].unsupervised_loss
               for task in tasks)

  def _get_optimizer(self):
    global_step = tf.get_variable('global_step', initializer=0, trainable=False)
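    # Learning rate schedule: linear warm-up over the first warm_up_steps,
    # then inverse-square-root decay 1 / (1 + lr_decay * sqrt(step)).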
    warm_up_multiplier = (tf.minimum(tf.to_float(global_step),
                                     self._config.warm_up_steps)
                          / self._config.warm_up_steps)
    decay_multiplier = 1.0 / (1 + self._config.lr_decay *
                              tf.sqrt(tf.to_float(global_step)))
    lr = self._config.lr * warm_up_multiplier * decay_multiplier
    optimizer = tf.train.MomentumOptimizer(lr, self._config.momentum)
    return global_step, optimizer

  def _get_train_op(self, loss):
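    # Clip gradients by global norm and run UPDATE_OPS (e.g. the EMA weight
    # update) before applying the gradients.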
    grads, vs = zip(*self._optimizer.compute_gradients(loss))
    grads, _ = tf.clip_by_global_norm(grads, self._config.grad_clip)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
      return self._optimizer.apply_gradients(
          zip(grads, vs), global_step=self._global_step)

  def _create_feed_dict(self, mb, model, is_training=True):
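    # A labeled minibatch feeds only its own task's module; a minibatch whose
    # task_name is not a module (i.e. unlabeled data) feeds every module.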
    feed = self._inputs.create_feed_dict(mb, is_training)
    if mb.task_name in model.modules:
      model.modules[mb.task_name].update_feed_dict(feed, mb)
    else:
      for module in model.modules.values():
        module.update_feed_dict(feed, mb)
    return feed

  def train_unlabeled(self, sess, mb):
    return sess.run([self._unlabeled_train_op, self._unlabeled_loss],
                    feed_dict=self._create_feed_dict(mb, self._trainer))[1]

  def train_labeled(self, sess, mb):
    return sess.run([self._labeled_train_ops[mb.task_name],
                     self._trainer.modules[mb.task_name].supervised_loss],
                    feed_dict=self._create_feed_dict(mb, self._trainer))[1]

  def run_teacher(self, sess, mb):
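    # Run the (possibly EMA) teacher on the minibatch and cache its soft
    # predictions as consistency targets; float16 halves the memory cost.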
    result = sess.run({task.name: self._teacher.modules[task.name].probs
                       for task in self._tasks},
                      feed_dict=self._create_feed_dict(mb, self._teacher,
                                                       False))
    for task_name, probs in result.items():
      mb.teacher_predictions[task_name] = probs.astype('float16')

  def test(self, sess, mb):
    return sess.run(
        [self._tester.modules[mb.task_name].supervised_loss,
         self._tester.modules[mb.task_name].preds],
        feed_dict=self._create_feed_dict(mb, self._tester, False))

  def get_global_step(self, sess):
    return sess.run(self._global_step)
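

# A minimal usage sketch, not part of the original file: `config`,
# `pretrained_embeddings`, `tasks`, the `minibatches` iterator, and the
# 'unlabeled' task_name marker are all assumptions standing in for objects
# built elsewhere in this repository.
#
#   model = Model(config, pretrained_embeddings, tasks)
#   with tf.Session() as sess:
#     sess.run(tf.global_variables_initializer())
#     for mb in minibatches:
#       if mb.task_name == 'unlabeled':
#         model.run_teacher(sess, mb)  # fills mb.teacher_predictions
#         loss = model.train_unlabeled(sess, mb)
#       else:
#         loss = model.train_labeled(sess, mb)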