File size: 5,007 Bytes
5672777
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""BERT Question Answering model."""
# pylint: disable=g-classes-have-attributes
import collections
import tensorflow as tf, tf_keras

from official.nlp.modeling import networks


@tf_keras.utils.register_keras_serializable(package='Text')
class BertSpanLabeler(tf_keras.Model):
  """Span labeling head stacked on a BERT-style transformer encoder.

  Follows the network structure placed around a transformer encoder in "BERT:
  Pre-training of Deep Bidirectional Transformers for Language Understanding"
  (https://arxiv.org/abs/1810.04805).

  The caller passes in the encoder; this model feeds the encoder's sequence
  output into a `networks.SpanLabeling` head and exposes two outputs, named
  `start_positions` and `end_positions`.

  *Note* that the model is constructed by
  [Keras Functional API](https://keras.io/guides/functional_api/).

  Args:
    network: A transformer network. Its `inputs` become the inputs of this
      model. Calling it on those inputs must return either a list whose first
      element is the sequence output, or an object indexable with the
      `sequence_output` key. If the sequence output is itself a list (e.g. one
      entry per layer), only the last entry is used.
    initializer: The initializer to use in the span labeling network. Defaults
      to a Glorot uniform initializer.
    output: The output style for this network. Can be either `logits` or
      `predictions`.
    **kwargs: Forwarded unchanged to the `tf_keras.Model` constructor.
  """

  def __init__(self,
               network,
               initializer='glorot_uniform',
               output='logits',
               **kwargs):
    # The encoder's own symbolic inputs double as the inputs of this model, so
    # hold on to them until the functional Model is assembled further down.
    encoder_inputs = network.inputs

    # Calling the encoder on its own input tensors yields the symbolic outputs
    # that the span head is built on.
    encoder_outputs = network(encoder_inputs)
    hidden_states = (
        encoder_outputs[0] if isinstance(encoder_outputs, list)
        else encoder_outputs['sequence_output'])

    # Some encoders hand back the outputs of every layer; in that case only
    # the final one is consumed.
    if isinstance(hidden_states, list):
      hidden_states = hidden_states[-1]

    # The task head. It is also published as `self.span_labeling` once the
    # base constructor has run.
    span_head = networks.SpanLabeling(
        input_width=hidden_states.shape[-1],
        initializer=initializer,
        output=output,
        name='span_labeling')
    start_output, end_output = span_head(hidden_states)

    # Route each output through a named identity Lambda so that the output
    # tensors carry explicit names, which makes string-keyed dicts usable in
    # Keras fit/predict/evaluate calls.
    start_layer = tf_keras.layers.Lambda(tf.identity, name='start_positions')
    start_output = start_layer(start_output)
    end_layer = tf_keras.layers.Lambda(tf.identity, name='end_positions')
    end_output = end_layer(end_output)

    # b/164516224
    # With the graph built, invoke the base constructor exactly as the
    # Functional API Model constructor would be invoked, so this object gains
    # every property of a functional model. Attributes may only be set on
    # `self` after this call, hence all `self` assignments come below it.
    super().__init__(
        inputs=encoder_inputs,
        outputs=[start_output, end_output],
        **kwargs)
    self._network = network

    # The config is kept as a namedtuple rather than a dict for checkpoint
    # compatibility with an earlier version of this model that did not track a
    # config attribute: TF does not track immutable attrs which do not contain
    # Trackables, so a namedtuple stays out of the checkpoint.
    config_cls = collections.namedtuple(
        'Config', ('network', 'initializer', 'output'))
    self._config = config_cls(
        network=network, initializer=initializer, output=output)
    self.span_labeling = span_head

  @property
  def checkpoint_items(self):
    """Returns a dict mapping `encoder` to the wrapped encoder network."""
    return {'encoder': self._network}

  def get_config(self):
    """Returns the constructor arguments as a fresh, plain dict."""
    stored = self._config._asdict()
    return dict(stored)

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Builds a model from a `get_config`-style dict.

    Args:
      config: Mapping of constructor keyword arguments.
      custom_objects: Accepted for signature compatibility; not used.

    Returns:
      A new instance of `cls` constructed as `cls(**config)`.
    """
    return cls(**config)