boris committed on
Commit 894a546
Parents (2): 2f69241, 67221fc

Merge pull request #13 from borisdayma/model-generate-notebook

seq2seq/CustomBARTv4b_model-generate.ipynb ADDED
@@ -0,0 +1,566 @@
+ {
+   "nbformat": 4,
+   "nbformat_minor": 0,
+   "metadata": {
+     "colab": {
+       "name": "CustomBARTv4b-model-generate.ipynb",
+       "provenance": [],
+       "collapsed_sections": [],
+       "machine_shape": "hm"
+     },
+     "kernelspec": {
+       "name": "python3",
+       "display_name": "Python 3"
+     },
+     "language_info": {
+       "name": "python"
+     },
+     "accelerator": "TPU"
+   },
+   "cells": [
+     {
+       "cell_type": "markdown",
+       "metadata": {
+         "id": "ewer-Q-0w2xA"
+       },
+       "source": [
+         "# Installation"
+       ]
+     },
+     {
+       "cell_type": "code",
+       "metadata": {
+         "colab": {
+           "base_uri": "https://localhost:8080/"
+         },
+         "id": "NpsF9ipLLl2s",
+         "outputId": "10bf54aa-b89d-4e42-9777-bc97b00a5f32"
+       },
+       "source": [
+         "!pip install git+https://github.com/huggingface/transformers/\n",
+         "!pip install git+https://github.com/google/flax"
+       ],
+       "execution_count": 1,
+       "outputs": [
+         {
+           "output_type": "stream",
+           "text": [
+             "Collecting git+https://github.com/huggingface/transformers/\n",
+             "  Cloning https://github.com/huggingface/transformers/ to /tmp/pip-req-build-oxejx1op\n",
+             "  Running command git clone -q https://github.com/huggingface/transformers/ /tmp/pip-req-build-oxejx1op\n",
+             "  Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
+             "  Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
+             "  Preparing wheel metadata ... \u001b[?25l\u001b[?25hdone\n",
+             "Requirement already satisfied (use --upgrade to upgrade): transformers==4.9.0.dev0 from git+https://github.com/huggingface/transformers/ in /usr/local/lib/python3.7/dist-packages\n",
+             "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.7/dist-packages (from transformers==4.9.0.dev0) (1.19.5)\n",
+             "Requirement already satisfied: packaging in /usr/local/lib/python3.7/dist-packages (from transformers==4.9.0.dev0) (20.9)\n",
+             "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.7/dist-packages (from transformers==4.9.0.dev0) (5.4.1)\n",
+             "Requirement already satisfied: sacremoses in /usr/local/lib/python3.7/dist-packages (from transformers==4.9.0.dev0) (0.0.45)\n",
+             "Requirement already satisfied: importlib-metadata; python_version < \"3.8\" in /usr/local/lib/python3.7/dist-packages (from transformers==4.9.0.dev0) (4.6.0)\n",
+             "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.7/dist-packages (from transformers==4.9.0.dev0) (4.41.1)\n",
+             "Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from transformers==4.9.0.dev0) (3.0.12)\n",
+             "Requirement already satisfied: huggingface-hub==0.0.12 in /usr/local/lib/python3.7/dist-packages (from transformers==4.9.0.dev0) (0.0.12)\n",
+             "Requirement already satisfied: tokenizers<0.11,>=0.10.1 in /usr/local/lib/python3.7/dist-packages (from transformers==4.9.0.dev0) (0.10.3)\n",
+             "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.7/dist-packages (from transformers==4.9.0.dev0) (2019.12.20)\n",
+             "Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from transformers==4.9.0.dev0) (2.23.0)\n",
+             "Requirement already satisfied: pyparsing>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging->transformers==4.9.0.dev0) (2.4.7)\n",
+             "Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers==4.9.0.dev0) (1.15.0)\n",
+             "Requirement already satisfied: joblib in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers==4.9.0.dev0) (1.0.1)\n",
+             "Requirement already satisfied: click in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers==4.9.0.dev0) (7.1.2)\n",
+             "Requirement already satisfied: typing-extensions>=3.6.4; python_version < \"3.8\" in /usr/local/lib/python3.7/dist-packages (from importlib-metadata; python_version < \"3.8\"->transformers==4.9.0.dev0) (3.7.4.3)\n",
+             "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata; python_version < \"3.8\"->transformers==4.9.0.dev0) (3.4.1)\n",
+             "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->transformers==4.9.0.dev0) (2021.5.30)\n",
+             "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->transformers==4.9.0.dev0) (3.0.4)\n",
+             "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->transformers==4.9.0.dev0) (1.24.3)\n",
+             "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->transformers==4.9.0.dev0) (2.10)\n",
+             "Building wheels for collected packages: transformers\n",
+             "  Building wheel for transformers (PEP 517) ... \u001b[?25l\u001b[?25hdone\n",
+             "  Created wheel for transformers: filename=transformers-4.9.0.dev0-cp37-none-any.whl size=2582229 sha256=249c593273ccca3027c6427d2c6fd749a89f21d722d628d97eb438a2cf3185a8\n",
+             "  Stored in directory: /tmp/pip-ephem-wheel-cache-l2rqt1b7/wheels/61/69/33/974fccec4d0ab5feee9fe83bd93e680d269a805be9ede5ec60\n",
+             "Successfully built transformers\n",
+             "Collecting git+https://github.com/google/flax\n",
+             "  Cloning https://github.com/google/flax to /tmp/pip-req-build-rt9g1_wx\n",
+             "  Running command git clone -q https://github.com/google/flax /tmp/pip-req-build-rt9g1_wx\n",
+             "Requirement already satisfied (use --upgrade to upgrade): flax==0.3.4 from git+https://github.com/google/flax in /usr/local/lib/python3.7/dist-packages\n",
+             "Requirement already satisfied: numpy>=1.12 in /usr/local/lib/python3.7/dist-packages (from flax==0.3.4) (1.19.5)\n",
+             "Requirement already satisfied: jax>=0.2.13 in /usr/local/lib/python3.7/dist-packages (from flax==0.3.4) (0.2.13)\n",
+             "Requirement already satisfied: matplotlib in /usr/local/lib/python3.7/dist-packages (from flax==0.3.4) (3.2.2)\n",
+             "Requirement already satisfied: msgpack in /usr/local/lib/python3.7/dist-packages (from flax==0.3.4) (1.0.2)\n",
+             "Requirement already satisfied: optax in /usr/local/lib/python3.7/dist-packages (from flax==0.3.4) (0.0.9)\n",
+             "Requirement already satisfied: opt-einsum in /usr/local/lib/python3.7/dist-packages (from jax>=0.2.13->flax==0.3.4) (3.3.0)\n",
+             "Requirement already satisfied: absl-py in /usr/local/lib/python3.7/dist-packages (from jax>=0.2.13->flax==0.3.4) (0.12.0)\n",
+             "Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->flax==0.3.4) (2.8.1)\n",
+             "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.7/dist-packages (from matplotlib->flax==0.3.4) (0.10.0)\n",
+             "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->flax==0.3.4) (2.4.7)\n",
+             "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->flax==0.3.4) (1.3.1)\n",
+             "Requirement already satisfied: chex>=0.0.4 in /usr/local/lib/python3.7/dist-packages (from optax->flax==0.3.4) (0.0.8)\n",
+             "Requirement already satisfied: jaxlib>=0.1.37 in /usr/local/lib/python3.7/dist-packages (from optax->flax==0.3.4) (0.1.66+cuda110)\n",
+             "Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from absl-py->jax>=0.2.13->flax==0.3.4) (1.15.0)\n",
+             "Requirement already satisfied: dm-tree>=0.1.5 in /usr/local/lib/python3.7/dist-packages (from chex>=0.0.4->optax->flax==0.3.4) (0.1.6)\n",
+             "Requirement already satisfied: toolz>=0.9.0 in /usr/local/lib/python3.7/dist-packages (from chex>=0.0.4->optax->flax==0.3.4) (0.11.1)\n",
+             "Requirement already satisfied: flatbuffers in /usr/local/lib/python3.7/dist-packages (from jaxlib>=0.1.37->optax->flax==0.3.4) (1.12)\n",
+             "Requirement already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from jaxlib>=0.1.37->optax->flax==0.3.4) (1.4.1)\n",
+             "Building wheels for collected packages: flax\n",
+             "  Building wheel for flax (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
+             "  Created wheel for flax: filename=flax-0.3.4-cp37-none-any.whl size=184692 sha256=503b27995f372afe33631e71572d5edc1fffd4d2e0a4cd206d291ad6b0e4c299\n",
+             "  Stored in directory: /tmp/pip-ephem-wheel-cache-g1pzxnv6/wheels/3d/26/f4/0ea6051d7352289d9e4f8178348452b35a9a97bde6035405a5\n",
+             "Successfully built flax\n"
+           ],
+           "name": "stdout"
+         }
+       ]
+     },
+     {
+       "cell_type": "code",
+       "metadata": {
+         "id": "M1wVkrpjU6zO"
+       },
+       "source": [
+         "%load_ext autoreload\n",
+         "%autoreload 2"
+       ],
+       "execution_count": 2,
+       "outputs": []
+     },
+     {
+       "cell_type": "markdown",
+       "metadata": {
+         "id": "t47CH1H_IOT8"
+       },
+       "source": [
+         "# Custom BART Model"
+       ]
+     },
+     {
+       "cell_type": "code",
+       "metadata": {
+         "id": "9jQnM6S2vCpn"
+       },
+       "source": [
+         "# TODO: set those args in a config file\n",
+         "OUTPUT_VOCAB_SIZE = 16384 + 1  # encoded image token space + 1 for bos\n",
+         "OUTPUT_LENGTH = 256 + 1  # number of encoded tokens + 1 for bos\n",
+         "BOS_TOKEN_ID = 16384\n",
+         "BASE_MODEL = 'facebook/bart-large-cnn'"
+       ],
+       "execution_count": 3,
+       "outputs": []
+     },
+     {
+       "cell_type": "code",
+       "metadata": {
+         "id": "_eEaJVxAKpV5"
+       },
+       "source": [
+         "import jax\n",
+         "import flax.linen as nn\n",
+         "\n",
+         "from transformers.models.bart.modeling_flax_bart import *\n",
+         "from transformers import BartTokenizer, FlaxBartForConditionalGeneration\n",
+         "\n",
+         "class CustomFlaxBartModule(FlaxBartModule):\n",
+         "    def setup(self):\n",
+         "        # we keep shared to easily load pre-trained weights\n",
+         "        self.shared = nn.Embed(\n",
+         "            self.config.vocab_size,\n",
+         "            self.config.d_model,\n",
+         "            embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
+         "            dtype=self.dtype,\n",
+         "        )\n",
+         "        # a separate embedding is used for the decoder\n",
+         "        self.decoder_embed = nn.Embed(\n",
+         "            OUTPUT_VOCAB_SIZE,\n",
+         "            self.config.d_model,\n",
+         "            embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
+         "            dtype=self.dtype,\n",
+         "        )\n",
+         "        self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)\n",
+         "\n",
+         "        # the decoder has a different config\n",
+         "        decoder_config = BartConfig(**self.config.to_dict())\n",
+         "        decoder_config.max_position_embeddings = OUTPUT_LENGTH\n",
+         "        decoder_config.vocab_size = OUTPUT_VOCAB_SIZE\n",
+         "        self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)\n",
+         "\n",
+         "class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):\n",
+         "    def setup(self):\n",
+         "        self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)\n",
+         "        self.lm_head = nn.Dense(\n",
+         "            OUTPUT_VOCAB_SIZE,\n",
+         "            use_bias=False,\n",
+         "            dtype=self.dtype,\n",
+         "            kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
+         "        )\n",
+         "        self.final_logits_bias = self.param(\"final_logits_bias\", self.bias_init, (1, OUTPUT_VOCAB_SIZE))\n",
+         "\n",
+         "class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):\n",
+         "    module_class = CustomFlaxBartForConditionalGenerationModule"
+       ],
+       "execution_count": 4,
+       "outputs": []
+     },
+     {
+       "cell_type": "code",
+       "metadata": {
+         "id": "S7CP9Td9m2ge",
+         "colab": {
+           "base_uri": "https://localhost:8080/"
+         },
+         "outputId": "5638ef68-9c40-46f7-90ba-a4d05b61360d"
+       },
+       "source": [
+         "# load pre-trained model for encoder weights\n",
+         "base_model = FlaxBartForConditionalGeneration.from_pretrained(BASE_MODEL)"
+       ],
+       "execution_count": 5,
+       "outputs": [
+         {
+           "output_type": "stream",
+           "text": [
+             "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
+           ],
+           "name": "stderr"
+         }
+       ]
+     },
+     {
+       "cell_type": "code",
+       "metadata": {
+         "id": "6lmynR-poceH"
+       },
+       "source": [
+         "# set up our new model config\n",
+         "config = BartConfig.from_pretrained(BASE_MODEL)\n",
+         "config.tie_word_embeddings = False\n",
+         "config.decoder_start_token_id = BOS_TOKEN_ID\n",
+         "config.bos_token_id = BOS_TOKEN_ID  # should not be used\n",
+         "config.pos_token_id = BOS_TOKEN_ID  # should not be used\n",
+         "#config.eos_token_id = None  # prevents generation from stopping until we reach max_length"
+       ],
+       "execution_count": 6,
+       "outputs": []
+     },
+     {
+       "cell_type": "code",
+       "metadata": {
+         "id": "_6-XKK40oEfP"
+       },
+       "source": [
+         "# create our model and initialize it randomly\n",
+         "model = CustomFlaxBartForConditionalGeneration(config)"
+       ],
+       "execution_count": 7,
+       "outputs": []
+     },
+     {
+       "cell_type": "code",
+       "metadata": {
+         "id": "-r_hZestr-NR"
+       },
+       "source": [
+         "# use pretrained weights\n",
+         "model.params['model']['encoder'] = base_model.params['model']['encoder']\n",
+         "model.params['model']['shared'] = base_model.params['model']['shared']"
+       ],
+       "execution_count": 8,
+       "outputs": []
+     },
+     {
+       "cell_type": "code",
+       "metadata": {
+         "id": "5NEX8f62sVjx"
+       },
+       "source": [
+         "# no need for base_model anymore\n",
+         "del base_model"
+       ],
+       "execution_count": 9,
+       "outputs": []
+     },
+     {
+       "cell_type": "code",
+       "metadata": {
+         "colab": {
+           "base_uri": "https://localhost:8080/"
+         },
+         "id": "Jz032w73nHEf",
+         "outputId": "994d8e85-bff7-480b-8b69-f69dedc15c49"
+       },
+       "source": [
+         "# we verify that the shape has not been modified\n",
+         "model.params['final_logits_bias'].shape"
+       ],
+       "execution_count": 10,
+       "outputs": [
+         {
+           "output_type": "execute_result",
+           "data": {
+             "text/plain": [
+               "(1, 16385)"
+             ]
+           },
+           "metadata": {
+             "tags": []
+           },
+           "execution_count": 10
+         }
+       ]
+     },
+     {
+       "cell_type": "markdown",
+       "metadata": {
+         "id": "zLl24Ez5t7x1"
+       },
+       "source": [
+         "## Inference"
+       ]
+     },
+     {
+       "cell_type": "code",
+       "metadata": {
+         "id": "XLLA2NK3uDQr"
+       },
+       "source": [
+         "tokenizer = BartTokenizer.from_pretrained(BASE_MODEL)"
+       ],
+       "execution_count": 11,
+       "outputs": []
+     },
+     {
+       "cell_type": "code",
+       "metadata": {
+         "colab": {
+           "base_uri": "https://localhost:8080/"
+         },
+         "id": "Ntow53I_t81D",
+         "outputId": "59289cdd-1429-4720-cc87-88810c4b99ac"
+       },
+       "source": [
+         "text = \"My friends are cool but they eat too many carbs.\"\n",
+         "inputs = tokenizer(text, max_length=1024, return_tensors='jax')\n",
+         "encoder_outputs = model.encode(**inputs)"
+       ],
+       "execution_count": 12,
+       "outputs": [
+         {
+           "output_type": "stream",
+           "text": [
+             "Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.\n"
+           ],
+           "name": "stderr"
+         }
+       ]
+     },
+     {
+       "cell_type": "code",
+       "metadata": {
+         "colab": {
+           "base_uri": "https://localhost:8080/"
+         },
+         "id": "vcRNJnJ_uJOJ",
+         "outputId": "025afd54-7908-4a9c-fb59-e40bd3458711"
+       },
+       "source": [
+         "decoder_start_token_id = model.config.decoder_start_token_id\n",
+         "decoder_start_token_id"
+       ],
+       "execution_count": 13,
+       "outputs": [
+         {
+           "output_type": "execute_result",
+           "data": {
+             "text/plain": [
+               "16384"
+             ]
+           },
+           "metadata": {
+             "tags": []
+           },
+           "execution_count": 13
+         }
+       ]
+     },
+     {
+       "cell_type": "code",
+       "metadata": {
+         "id": "6QWmEwL_uMld"
+       },
+       "source": [
+         "decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype=\"i4\") * decoder_start_token_id\n",
+         "outputs = model.decode(decoder_input_ids, encoder_outputs)"
+       ],
+       "execution_count": 14,
+       "outputs": []
+     },
+     {
+       "cell_type": "code",
+       "metadata": {
+         "colab": {
+           "base_uri": "https://localhost:8080/"
+         },
+         "id": "c_ys3yWBothF",
+         "outputId": "40d4d584-e0a8-44cb-bbea-0ffa38d50a53"
+       },
+       "source": [
+         "outputs"
+       ],
+       "execution_count": 15,
+       "outputs": [
+         {
+           "output_type": "execute_result",
+           "data": {
+             "text/plain": [
+               "FlaxCausalLMOutputWithCrossAttentions([('logits',\n",
+               "              DeviceArray([[[ 0.5263986 , -2.0947676 , -0.18830685, ...,  0.7599884 ,\n",
+               "                              0.6746795 , -1.0411576 ]]], dtype=float32))])"
+             ]
+           },
+           "metadata": {
+             "tags": []
+           },
+           "execution_count": 15
+         }
+       ]
+     },
+     {
+       "cell_type": "code",
+       "metadata": {
+         "colab": {
+           "base_uri": "https://localhost:8080/"
+         },
+         "id": "O6s0wtB_uTC_",
+         "outputId": "bc0e9e80-e346-4e99-d28e-3f658eda1f66"
+       },
+       "source": [
+         "outputs.logits.shape"
+       ],
+       "execution_count": 16,
+       "outputs": [
+         {
+           "output_type": "execute_result",
+           "data": {
+             "text/plain": [
+               "(1, 1, 16385)"
+             ]
+           },
+           "metadata": {
+             "tags": []
+           },
+           "execution_count": 16
+         }
+       ]
+     },
+     {
+       "cell_type": "code",
+       "metadata": {
+         "colab": {
+           "base_uri": "https://localhost:8080/"
+         },
+         "id": "ELzemGP3uBzy",
+         "outputId": "dc12f98a-1ccf-450d-ba2a-9c29d7d14885"
+       },
+       "source": [
+         "outputs.logits.argmax(axis=-1)"
+       ],
+       "execution_count": 17,
+       "outputs": [
+         {
+           "output_type": "execute_result",
+           "data": {
+             "text/plain": [
+               "DeviceArray([[12459]], dtype=int32)"
+             ]
+           },
+           "metadata": {
+             "tags": []
+           },
+           "execution_count": 17
+         }
+       ]
+     },
+     {
+       "cell_type": "code",
+       "metadata": {
+         "colab": {
+           "base_uri": "https://localhost:8080/"
+         },
+         "id": "fQjikkGEunpx",
+         "outputId": "3dba0209-ad4e-4069-be38-6c599c677ef1"
+       },
+       "source": [
+         "model.config.bos_token_id, model.config.eos_token_id, model.config.pad_token_id"
+       ],
+       "execution_count": 18,
+       "outputs": [
+         {
+           "output_type": "execute_result",
+           "data": {
+             "text/plain": [
+               "(16384, 2, 1)"
+             ]
+           },
+           "metadata": {
+             "tags": []
+           },
+           "execution_count": 18
+         }
+       ]
+     },
+     {
+       "cell_type": "code",
+       "metadata": {
+         "id": "P32mJJSbrU1F"
+       },
+       "source": [
+         "input_ids_test = tokenizer.encode('I enjoy walking with my cute dog', return_tensors='jax')"
+       ],
+       "execution_count": 19,
+       "outputs": []
+     },
+     {
+       "cell_type": "code",
+       "metadata": {
+         "id": "C7cHbIHruELT"
+       },
+       "source": [
+         "greedy_output = model.generate(input_ids_test, max_length=50)"
+       ],
+       "execution_count": 20,
+       "outputs": []
+     },
+     {
+       "cell_type": "code",
+       "metadata": {
+         "colab": {
+           "base_uri": "https://localhost:8080/"
+         },
+         "id": "jYugh9cOuwc9",
+         "outputId": "19c3a2ee-e7bc-4f1f-9c86-06bd7337b537"
+       },
+       "source": [
+         "greedy_output[0]"
+       ],
+       "execution_count": 21,
+       "outputs": [
+         {
+           "output_type": "execute_result",
+           "data": {
+             "text/plain": [
+               "DeviceArray([[16384,     0,  3570, 13405, 10186,  2392, 16362,  1869,\n",
+               "              15772, 13546, 15772, 13546,  9348, 14791, 15772, 15772,\n",
+               "              15772, 11272, 15772, 13546, 15772, 15772, 13546, 15772,\n",
+               "              13546, 15772,  6642, 15772, 10776,  6431, 15772, 14567,\n",
+               "              13406, 15772, 14567,  6235, 15772,  4909, 16160,   568,\n",
+               "              4664,  6650,  8952,  9089, 15772,  5952,  7375, 10843,\n",
+               "              8952,     2]], dtype=int32)"
+             ]
+           },
+           "metadata": {
+             "tags": []
+           },
+           "execution_count": 21
+         }
+       ]
+     }
+   ]
+ }
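
Context for the final `generate` cell: the long runs of the repeated token id 15772 in `greedy_output[0]` are expected at this stage, since only the encoder and the shared embedding were copied from `facebook/bart-large-cnn` while the decoder is still randomly initialized. What `model.generate(input_ids_test, max_length=50)` does is essentially the greedy loop sketched below. This is an illustrative reconstruction, not transformers' actual implementation (which adds a key/value cache and configurable search strategies), and `greedy_generate` is a hypothetical helper name:

import jax.numpy as jnp

def greedy_generate(model, input_ids, max_length=50):
    # encode the prompt once; every decoding step reuses the encoder states
    encoder_outputs = model.encode(input_ids=input_ids)
    # decoding starts from decoder_start_token_id (16384, the image BOS, in this notebook)
    sequences = jnp.full((input_ids.shape[0], 1), model.config.decoder_start_token_id, dtype="i4")
    for _ in range(max_length - 1):
        outputs = model.decode(sequences, encoder_outputs)
        # greedy step: argmax over the 16385-way output vocabulary at the last position
        next_token = outputs.logits[:, -1].argmax(axis=-1).reshape(-1, 1).astype("i4")
        sequences = jnp.concatenate([sequences, next_token], axis=-1)
        # stop once every sequence has emitted EOS (id 2 here, inherited from the base config)
        if model.config.eos_token_id is not None and (next_token == model.config.eos_token_id).all():
            break
    return sequences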