farrell236 committed
Commit 77c8482
1 Parent(s): 6692ae2

Upload 17 files

.gitattributes CHANGED
@@ -32,3 +32,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ assets/attn_plot.png filter=lfs diff=lfs merge=lfs -text
+ assets/examples.png filter=lfs diff=lfs merge=lfs -text
+ checkpoints/cxr_validator_model.tf/variables/variables.data-00000-of-00001 filter=lfs diff=lfs merge=lfs -text
+ checkpoints/RATCHET.tf/variables/variables.data-00000-of-00001 filter=lfs diff=lfs merge=lfs -text
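Entries of this form are typically what "git lfs track <path>" appends to .gitattributes; the four new lines route the two PNG assets and the two large checkpoint shards added in this commit through Git LFS rather than plain Git storage.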
app.py ADDED
@@ -0,0 +1,209 @@
+ import tqdm
+ import datetime
+
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import streamlit as st
+ import tensorflow as tf
+
+ from skimage import io
+ from transformer import Transformer, default_hparams
+ from tokenizers import ByteLevelBPETokenizer
+
+
+ @st.cache_resource
+ def load_validator():
+     validator_model = tf.keras.models.load_model('checkpoints/cxr_validator_model.tf')
+     print('Validator Model Loaded!')
+     return validator_model
+
+
+ @st.cache_resource
+ def load_model():
+
+     # Load tokenizer
+     tokenizer = ByteLevelBPETokenizer(
+         'mimic/mimic-vocab.json',
+         'mimic/mimic-merges.txt',
+     )
+
+     # Load model
+     hparams = default_hparams()
+     transformer = Transformer(
+         num_layers=hparams['num_layers'],
+         d_model=hparams['d_model'],
+         num_heads=hparams['num_heads'],
+         dff=hparams['dff'],
+         target_vocab_size=tokenizer.get_vocab_size(),
+         dropout_rate=hparams['dropout_rate'])
+     transformer.load_weights('checkpoints/RATCHET.tf')
+     print('Model Loaded! Checkpoint file: checkpoints/RATCHET.tf')
+
+     return transformer, tokenizer
+
+
+ def top_k_logits(logits, k):
+     if k == 0:
+         # no truncation
+         return logits
+
+     def _top_k():
+         values, _ = tf.nn.top_k(logits, k=k)
+         min_values = values[:, -1, tf.newaxis]
+         return tf.where(
+             logits < min_values,
+             tf.ones_like(logits, dtype=logits.dtype) * -1e10,
+             logits,
+         )
+     return tf.cond(
+         tf.equal(k, 0),
+         lambda: logits,
+         lambda: _top_k(),
+     )
+
+
+ def top_p_logits(logits, p):
+     """Nucleus sampling"""
+     batch, _ = logits.shape.as_list()
+     sorted_logits = tf.sort(logits, direction='DESCENDING', axis=-1)
+     cumulative_probs = tf.cumsum(tf.nn.softmax(sorted_logits, axis=-1), axis=-1)
+     indices = tf.stack([
+         tf.range(0, batch),
+         # number of indices to include
+         tf.maximum(tf.reduce_sum(tf.cast(cumulative_probs <= p, tf.int32), axis=-1) - 1, 0),
+     ], axis=-1)
+     min_values = tf.gather_nd(sorted_logits, indices)
+     return tf.where(
+         logits < min_values,
+         tf.ones_like(logits) * -1e10,
+         logits,
+     )
+
+
+ def evaluate(inp_img, tokenizer, transformer, temperature, top_k, top_p, options, seed, MAX_LENGTH=128):
+
+     # The first token to the transformer should be the start token
+     output = tf.convert_to_tensor([[tokenizer.token_to_id('<s>')]])
+
+     my_bar = st.progress(0)
+     for i in tqdm.tqdm(range(MAX_LENGTH)):
+         my_bar.progress(i / MAX_LENGTH)
+
+         # predictions.shape == (batch_size, seq_len, vocab_size)
+         predictions = transformer([inp_img, output], training=False)
+
+         # select the last word from the seq_len dimension
+         predictions = predictions[:, -1, :] / temperature  # (batch_size, vocab_size)
+         predictions = top_k_logits(predictions, k=top_k)
+         predictions = top_p_logits(predictions, p=top_p)
+
+         if options == 'Greedy':
+             predicted_id = tf.cast(tf.argmax(predictions, axis=-1), tf.int32)[:, tf.newaxis]
+         elif options == 'Sampling':
+             predicted_id = tf.random.categorical(predictions, num_samples=1, dtype=tf.int32, seed=seed)
+         else:
+             st.write('SHOULD NOT HAPPEN')
+
+         # return the result if the predicted_id is equal to the end token
+         if predicted_id == 2:  # stop token
+             my_bar.empty()
+             break
+
+         # concatenate the predicted_id to the output, which is given to the decoder
+         # as its input.
+         output = tf.concat([output, predicted_id], axis=-1)
+
+     my_bar.empty()
+
+     # transformer([inp_img, output[:, :-1]], training=False)
+     return tf.squeeze(output, axis=0)[1:], transformer.decoder.last_attn_scores
+
+
+ def main():
+
+     st.title('Chest X-ray AI Diagnosis Demo')
+     st.text('Made with Streamlit and an attention-based Transformer')
+
+     transformer, tokenizer = load_model()
+     cxr_validator_model = load_validator()
+
+     st.sidebar.title('Configuration')
+     options = st.sidebar.selectbox('Generation Method', ('Greedy', 'Sampling'))
+     seed = st.sidebar.number_input('Sampling Seed:', value=42)
+     temperature = st.sidebar.number_input('Temperature', value=1.)
+     top_k = st.sidebar.slider('top_k', min_value=0, max_value=tokenizer.get_vocab_size(), value=6, step=1)
+     top_p = st.sidebar.slider('top_p', min_value=0., max_value=1., value=1., step=0.01)
+     attention_head = st.sidebar.slider('attention_head', min_value=-1, max_value=7, value=-1, step=1)
+
+     st.sidebar.info('PRIVACY POLICY: Uploaded images are never stored on disk.')
+
+     st.set_option('deprecation.showfileUploaderEncoding', False)
+     uploaded_file = st.file_uploader('Choose an image...', type=('png', 'jpg', 'jpeg'))
+
+     if uploaded_file:
+
+         # Read input image with size [1, H, W, 1] and range (0, 255)
+         img_array = io.imread(uploaded_file, as_gray=True)[None, ..., None]
+
+         # Convert image to float values in (0, 1)
+         img_array = tf.image.convert_image_dtype(img_array, tf.float32)
+
+         # Resize image with padding to [1, 224, 224, 1]
+         img_array = tf.image.resize_with_pad(img_array, 224, 224, method=tf.image.ResizeMethod.BILINEAR)
+
+         # Display input image
+         st.image(np.squeeze(img_array.numpy()), caption='Uploaded Image')
+
+         # Check that the uploaded image is a chest X-ray
+         valid = tf.nn.sigmoid(cxr_validator_model(img_array))
+         if valid < 0.1:
+             st.info('Image is not a Chest X-ray')
+             return
+
+         # Log datetime
+         print('[{}] Running Analysis...'
+               .format(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")))
+
+         # Generate radiology report
+         with st.spinner('Generating report... Do not refresh or close window.'):
+             result, attention_weights = evaluate(img_array, tokenizer, transformer,
+                                                  temperature, top_k, top_p,
+                                                  options, seed)
+             predicted_sentence = tokenizer.decode(result)
+
+         # Display generated text
+         st.subheader('Generated Report:')
+         st.write(predicted_sentence)
+         # st.info(predicted_sentence)
+
+         st.subheader('Attention Plot:')
+
+         attn_map = attention_weights[0]  # drop the batch dimension
+         if attention_head == -1:  # average attention heads
+             attn_map = tf.reduce_mean(attn_map, axis=0)
+         else:  # select a single attention head
+             attn_map = attn_map[attention_head]
+         attn_map = attn_map / attn_map.numpy().max() * 255
+
+         fig = plt.figure(figsize=(40, 80))
+
+         for i in range(attn_map.shape[0] - 1):
+             attn_token = attn_map[i, ...]
+             attn_token = tf.reshape(attn_token, [7, 7])
+
+             ax = fig.add_subplot(16, 8, i + 1)
+             ax.set_title(tokenizer.decode([result.numpy()[i]]))
+             img = ax.imshow(np.squeeze(img_array))
+             ax.imshow(attn_token, cmap='gray', alpha=0.6, extent=img.get_extent())
+
+         st.pyplot(plt)
+
+         # Run again?
+         st.button('Regenerate Report')
+
+
+ if __name__ == '__main__':
+
+     tf.config.set_visible_devices([], 'GPU')
+
+     main()
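A minimal sketch (not part of the committed file) of how the top_k_logits and top_p_logits filters above behave on a toy batch, assuming TensorFlow 2.x eager execution and that app.py is importable from the repository root:

    import tensorflow as tf
    from app import top_k_logits, top_p_logits  # helpers defined in app.py above

    logits = tf.constant([[2.0, 1.0, 0.5, 0.1, -1.0]])  # (batch=1, vocab=5)

    # top-k: every logit outside the k largest is pushed to -1e10
    print(top_k_logits(logits, k=2).numpy())

    # top-p (nucleus): logits below the cumulative-probability cut-off are pushed to -1e10
    print(top_p_logits(logits, p=0.9).numpy())

Masked positions end up with near-zero probability after the softmax inside tf.random.categorical, so both the Greedy and Sampling branches in evaluate() only ever see the surviving candidates.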
assets/attn_plot.png ADDED

Git LFS Details

  • SHA256: 660b2fe611515e076e8d7e154c073e0b914a9af203272781a0e413651e5ca8d9
  • Pointer size: 132 Bytes
  • Size of remote file: 1.86 MB
assets/examples.png ADDED

Git LFS Details

  • SHA256: ed6618d777b28aacbac881686c50f6ec756c48cc9416dc4259e8255ea5387bd2
  • Pointer size: 132 Bytes
  • Size of remote file: 1.5 MB
assets/model_transformer.png ADDED
checkpoints/RATCHET.tf/keras_metadata.pb ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a8fa018ac83d10617e20e3f03de3718d9d3d6e1b89673707cb510318fd3198b3
+ size 1065144
checkpoints/RATCHET.tf/saved_model.pb ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:84e9d837b881c58edee113c7bbdc793159e6e57c2ddcf9d2a3e4da7c5104a7db
+ size 26013311
checkpoints/RATCHET.tf/variables/variables.data-00000-of-00001 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ae18face6fa821f8c6c62923ef5533fca681e01b6bb8ae511a9c94844f618c8e
+ size 1669994429
checkpoints/RATCHET.tf/variables/variables.index ADDED
Binary file (121 kB).
 
checkpoints/cxr_validator_model.tf/fingerprint.pb ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:21a31ac72a46d124de283ecbd75c35efc8ac0c5f597efd3040ed8dd00d071ef2
+ size 53
checkpoints/cxr_validator_model.tf/keras_metadata.pb ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:19106ee698a03e8b9ec11b0092fd65c32654380171a3c55a7976d56313e4438a
+ size 2538679
checkpoints/cxr_validator_model.tf/saved_model.pb ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:16e7434007981626733e6f925cd0b226e1f4130cfaec7e79ba81ffd16d7ab1cb
+ size 14320368
checkpoints/cxr_validator_model.tf/variables/variables.data-00000-of-00001 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e2edd5cef46c1624f31464e13f3b5fb8c0ceb4ce8a1d834a6cde9c2e71dd509e
+ size 224256098
checkpoints/cxr_validator_model.tf/variables/variables.index ADDED
Binary file (51.9 kB).
 
mimic/mimic-merges.txt ADDED
The diff for this file is too large to render.
 
mimic/mimic-vocab.json ADDED
The diff for this file is too large to render.
 
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ matplotlib
+ numpy
+ scikit-image
+ tensorflow
+ tokenizers
+ tqdm
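Note that streamlit itself is not pinned here, presumably because the Space's Streamlit SDK runtime provides it; for a local run it would additionally need to be installed (pip install streamlit) before launching the demo with streamlit run app.py.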
transformer.py ADDED
@@ -0,0 +1,263 @@
+ from __future__ import absolute_import
+ from __future__ import division
+ from __future__ import print_function
+ from __future__ import unicode_literals
+
+ import datetime
+
+ import numpy as np
+ import tensorflow as tf
+
+
+ def default_hparams():
+     return {
+         'img_x': 224,
+         'img_y': 224,
+         'img_ch': 1,
+         'd_model': 512,
+         'dff': 2048,
+         'num_heads': 8,
+         'num_layers': 6,
+         'dropout_rate': 0.1
+     }
+
+
+ def positional_encoding(length, depth):
+     depth = depth / 2
+
+     positions = np.arange(length)[:, np.newaxis]       # (seq, 1)
+     depths = np.arange(depth)[np.newaxis, :] / depth   # (1, depth)
+
+     angle_rates = 1 / (10000 ** depths)                # (1, depth)
+     angle_rads = positions * angle_rates               # (pos, depth)
+
+     pos_encoding = np.concatenate(
+         [np.sin(angle_rads), np.cos(angle_rads)],
+         axis=-1)
+
+     return tf.cast(pos_encoding, dtype=tf.float32)
+
+
+ class PositionalEmbedding(tf.keras.layers.Layer):
+     def __init__(self, vocab_size, d_model):
+         super().__init__()
+         self.d_model = d_model
+         self.embedding = tf.keras.layers.Embedding(vocab_size, d_model, mask_zero=True)
+         self.pos_encoding = positional_encoding(length=2048, depth=d_model)
+
+     def compute_mask(self, *args, **kwargs):
+         return self.embedding.compute_mask(*args, **kwargs)
+
+     def call(self, x):
+         length = tf.shape(x)[1]
+         x = self.embedding(x)
+         # This factor sets the relative scale of the embedding and positional_encoding.
+         x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
+         x = x + self.pos_encoding[tf.newaxis, :length, :]
+         return x
+
+
+ class BaseAttention(tf.keras.layers.Layer):
+     def __init__(self, **kwargs):
+         super().__init__()
+         self.mha = tf.keras.layers.MultiHeadAttention(**kwargs)
+         self.layernorm = tf.keras.layers.LayerNormalization()
+         self.add = tf.keras.layers.Add()
+
+
+ class CrossAttention(BaseAttention):
+     def call(self, x, context):
+         attn_output, attn_scores = self.mha(
+             query=x,
+             key=context,
+             value=context,
+             return_attention_scores=True)
+
+         # Cache the attention scores for plotting later.
+         self.last_attn_scores = attn_scores
+
+         x = self.add([x, attn_output])
+         x = self.layernorm(x)
+
+         return x
+
+
+ class CausalSelfAttention(BaseAttention):
+     def call(self, x):
+         attn_output = self.mha(
+             query=x,
+             value=x,
+             key=x,
+             use_causal_mask=True)
+         x = self.add([x, attn_output])
+         x = self.layernorm(x)
+         return x
+
+
+ class FeedForward(tf.keras.layers.Layer):
+     def __init__(self, d_model, dff, dropout_rate=0.1):
+         super().__init__()
+         self.seq = tf.keras.Sequential([
+             tf.keras.layers.Dense(dff, activation='relu'),
+             tf.keras.layers.Dense(d_model),
+             tf.keras.layers.Dropout(dropout_rate)
+         ])
+         self.add = tf.keras.layers.Add()
+         self.layer_norm = tf.keras.layers.LayerNormalization()
+
+     def call(self, x):
+         x = self.add([x, self.seq(x)])
+         x = self.layer_norm(x)
+         return x
+
+
+ class DecoderLayer(tf.keras.layers.Layer):
+     def __init__(self,
+                  *,
+                  d_model,
+                  num_heads,
+                  dff,
+                  dropout_rate=0.1):
+         super(DecoderLayer, self).__init__()
+
+         self.causal_self_attention = CausalSelfAttention(
+             num_heads=num_heads,
+             key_dim=d_model,
+             dropout=dropout_rate)
+
+         self.cross_attention = CrossAttention(
+             num_heads=num_heads,
+             key_dim=d_model,
+             dropout=dropout_rate)
+
+         self.ffn = FeedForward(d_model, dff)
+
+     def call(self, x, context):
+         x = self.causal_self_attention(x=x)
+         x = self.cross_attention(x=x, context=context)
+
+         # Cache the last attention scores for plotting later
+         self.last_attn_scores = self.cross_attention.last_attn_scores
+
+         x = self.ffn(x)  # Shape `(batch_size, seq_len, d_model)`.
+         return x
+
+
+ class Encoder(tf.keras.layers.Layer):
+     def __init__(self, embedding_dim, input_shape, pretrain_weights=None):
+         super(Encoder, self).__init__()
+
+         # shape after fc == (batch_size, nf * nf, embedding_dim)
+         self.fc = tf.keras.layers.Dense(embedding_dim, activation='relu')
+
+         # Use DenseNet-121 as the feature extraction model
+         self.base_model = tf.keras.applications.DenseNet121(
+             include_top=False, weights=None, input_shape=input_shape)
+
+         # Load pre-trained weights if present
+         if pretrain_weights:
+             print(f'{datetime.datetime.now()}: I Loading Pretrained DenseNet-121 weights: {pretrain_weights}')
+             self.base_model.load_weights(pretrain_weights)
+         else:
+             print(f'{datetime.datetime.now()}: I No Pretrained DenseNet-121 weights specified')
+
+     def call(self, x, **kwargs):
+         x = self.base_model(x)
+         # DenseNet-121 output is (batch_size, ?, ?, 1024)
+         s = tf.shape(x)
+         x = tf.reshape(x, (s[0], s[1] * s[2], x.shape[3]))
+         x = self.fc(x)
+         return x
+
+
+ class Decoder(tf.keras.layers.Layer):
+     def __init__(self, *, num_layers, d_model, num_heads, dff, vocab_size,
+                  dropout_rate=0.1):
+         super(Decoder, self).__init__()
+
+         self.d_model = d_model
+         self.num_layers = num_layers
+
+         self.pos_embedding = PositionalEmbedding(vocab_size=vocab_size,
+                                                  d_model=d_model)
+         self.dropout = tf.keras.layers.Dropout(dropout_rate)
+         self.dec_layers = [
+             DecoderLayer(d_model=d_model, num_heads=num_heads,
+                          dff=dff, dropout_rate=dropout_rate)
+             for _ in range(num_layers)]
+
+         self.last_attn_scores = None
+
+     def call(self, x, context):
+         # `x` is token-IDs shape (batch, target_seq_len)
+         x = self.pos_embedding(x)  # (batch_size, target_seq_len, d_model)
+
+         x = self.dropout(x)
+
+         for i in range(self.num_layers):
+             x = self.dec_layers[i](x, context)
+
+         self.last_attn_scores = self.dec_layers[-1].last_attn_scores
+
+         # The shape of x is (batch_size, target_seq_len, d_model).
+         return x
+
+
+ class Transformer(tf.keras.Model):
+     def __init__(self, num_layers, d_model, num_heads, dff,
+                  target_vocab_size, dropout_rate=0.1, input_shape=(224, 224, 1),
+                  classifier_weights=None):
+         super(Transformer, self).__init__()
+
+         self.encoder = Encoder(d_model, input_shape,
+                                pretrain_weights=classifier_weights)
+
+         self.decoder = Decoder(num_layers=num_layers, d_model=d_model,
+                                num_heads=num_heads, dff=dff,
+                                vocab_size=target_vocab_size,
+                                dropout_rate=dropout_rate)
+
+         self.final_layer = tf.keras.layers.Dense(target_vocab_size)
+
+     def call(self, inputs):
+         # To use a Keras model with `.fit` you must pass all your inputs in the
+         # first argument.
+         context, x = inputs
+
+         context = self.encoder(context)  # (batch_size, context_len, d_model)
+
+         x = self.decoder(x, context)  # (batch_size, target_len, d_model)
+
+         # Final linear layer output.
+         logits = self.final_layer(x)  # (batch_size, target_len, target_vocab_size)
+
+         try:
+             # Drop the keras mask, so it doesn't scale the losses/metrics.
+             # b/250038731
+             del logits._keras_mask
+         except AttributeError:
+             pass
+
+         # Return the final logits.
+         return logits
+
+
+ if __name__ == "__main__":
+
+     hparams = default_hparams()
+
+     # Smoke test: build the model and run a random image/text pair through it.
+     transformer = Transformer(
+         num_layers=hparams['num_layers'],
+         d_model=hparams['d_model'],
+         num_heads=hparams['num_heads'],
+         dff=hparams['dff'],
+         target_vocab_size=2048,
+         dropout_rate=hparams['dropout_rate'])
+
+     image = np.random.rand(1, 224, 224, 1).astype('float32')
+     text = np.random.randint(0, 2048, size=(1, 27))
+
+     output = transformer((image, text))
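With the default hparams, the smoke test above should yield decoder logits over the 2048-token vocabulary; a quick sanity check (not part of the committed file):

    print(output.shape)  # expected: (1, 27, 2048) -> (batch, target_len, target_vocab_size)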