# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A multi-task and semi-supervised NLP model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from model import encoder
from model import shared_inputs
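
# A single `Inference` graph (or an EMA-averaged copy of it) serves as the
# trainer, tester, and teacher; `Model` wires it up to per-task supervised
# train ops and to a joint unsupervised consistency train op.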


class Inference(object):
  """Runs the shared encoder and every task's module over the inputs."""

  def __init__(self, config, inputs, pretrained_embeddings, tasks):
    with tf.variable_scope('encoder'):
      self.encoder = encoder.Encoder(config, inputs, pretrained_embeddings)
    self.modules = {}
    for task in tasks:
      with tf.variable_scope(task.name):
        self.modules[task.name] = task.get_module(inputs, self.encoder)


class Model(object):
  """Ties the shared inputs, encoder, and task modules to training ops."""

  def __init__(self, config, pretrained_embeddings, tasks):
    self._config = config
    self._tasks = tasks
    self._global_step, self._optimizer = self._get_optimizer()
    self._inputs = shared_inputs.Inputs(config)

    with tf.variable_scope('model', reuse=tf.AUTO_REUSE) as scope:
      inference = Inference(config, self._inputs, pretrained_embeddings,
                            tasks)
      self._trainer = inference
      self._tester = inference
      self._teacher = inference
      if config.ema_test or config.ema_teacher:
        # Keep an exponential moving average (EMA) of the model weights; the
        # teacher and/or tester then read the averaged weights through a
        # custom getter on a second inference graph.
        ema = tf.train.ExponentialMovingAverage(config.ema_decay)
        model_vars = tf.get_collection("trainable_variables", "model")
        ema_op = ema.apply(model_vars)
        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, ema_op)
        def ema_getter(getter, name, *args, **kwargs):
          var = getter(name, *args, **kwargs)
          return ema.average(var)
        scope.set_custom_getter(ema_getter)
        inference_ema = Inference(
            config, self._inputs, pretrained_embeddings, tasks)
        if config.ema_teacher:
          self._teacher = inference_ema
        if config.ema_test:
          self._tester = inference_ema

    # One train op for the joint consistency loss on unlabeled data and one
    # per task for its supervised loss on labeled data.
    self._unlabeled_loss = self._get_consistency_loss(tasks)
    self._unlabeled_train_op = self._get_train_op(self._unlabeled_loss)
    self._labeled_train_ops = {}
    for task in self._tasks:
      task_loss = self._trainer.modules[task.name].supervised_loss
      self._labeled_train_ops[task.name] = self._get_train_op(task_loss)
  def _get_consistency_loss(self, tasks):
    """Sums the unsupervised (teacher-student consistency) losses."""
    return sum([self._trainer.modules[task.name].unsupervised_loss
                for task in tasks])

  def _get_optimizer(self):
    global_step = tf.get_variable('global_step', initializer=0,
                                  trainable=False)
    # Linear warmup over the first warm_up_steps steps, then decay the
    # learning rate proportionally to 1 / (1 + lr_decay * sqrt(step)).
    warm_up_multiplier = (tf.minimum(tf.to_float(global_step),
                                     self._config.warm_up_steps)
                          / self._config.warm_up_steps)
    decay_multiplier = 1.0 / (1 + self._config.lr_decay *
                              tf.sqrt(tf.to_float(global_step)))
    lr = self._config.lr * warm_up_multiplier * decay_multiplier
    optimizer = tf.train.MomentumOptimizer(lr, self._config.momentum)
    return global_step, optimizer
  def _get_train_op(self, loss):
    grads, vs = zip(*self._optimizer.compute_gradients(loss))
    grads, _ = tf.clip_by_global_norm(grads, self._config.grad_clip)
    # Run any pending update ops (e.g., the EMA update) with each step.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
      return self._optimizer.apply_gradients(
          zip(grads, vs), global_step=self._global_step)

  def _create_feed_dict(self, mb, model, is_training=True):
    feed = self._inputs.create_feed_dict(mb, is_training)
    if mb.task_name in model.modules:
      model.modules[mb.task_name].update_feed_dict(feed, mb)
    else:
      # Unlabeled minibatches have no single task, so every module
      # contributes its inputs.
      for module in model.modules.values():
        module.update_feed_dict(feed, mb)
    return feed
  def train_unlabeled(self, sess, mb):
    """Runs one consistency-training step; returns the unsupervised loss."""
    return sess.run([self._unlabeled_train_op, self._unlabeled_loss],
                    feed_dict=self._create_feed_dict(mb, self._trainer))[1]

  def train_labeled(self, sess, mb):
    """Runs one supervised step on mb's task; returns the supervised loss."""
    return sess.run([self._labeled_train_ops[mb.task_name],
                     self._trainer.modules[mb.task_name].supervised_loss],
                    feed_dict=self._create_feed_dict(mb, self._trainer))[1]
  def run_teacher(self, sess, mb):
    """Stores the teacher's soft predictions for each task on the minibatch."""
    result = sess.run({task.name: self._teacher.modules[task.name].probs
                       for task in self._tasks},
                      feed_dict=self._create_feed_dict(mb, self._teacher,
                                                       False))
    for task_name, probs in result.items():
      # Cast to float16 to halve the memory used by cached predictions.
      mb.teacher_predictions[task_name] = probs.astype('float16')
  def test(self, sess, mb):
    """Returns the supervised loss and predictions for mb's task."""
    return sess.run(
        [self._tester.modules[mb.task_name].supervised_loss,
         self._tester.modules[mb.task_name].preds],
        feed_dict=self._create_feed_dict(mb, self._tester, False))

  def get_global_step(self, sess):
    return sess.run(self._global_step)
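

# A minimal usage sketch (hypothetical driver code, not part of this module):
# it assumes `config`, `embeddings`, `tasks`, minibatch objects exposing the
# attributes used above (`task_name`, `teacher_predictions`, ...), and batch
# iterators, all of which are constructed elsewhere in the codebase.
#
#   model = Model(config, embeddings, tasks)
#   with tf.Session() as sess:
#     sess.run(tf.global_variables_initializer())
#     for labeled_mb, unlabeled_mb in zip(labeled_batches, unlabeled_batches):
#       loss = model.train_labeled(sess, labeled_mb)
#       model.run_teacher(sess, unlabeled_mb)  # fill in teacher predictions
#       loss = model.train_unlabeled(sess, unlabeled_mb)
#     eval_loss, preds = model.test(sess, eval_mb)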