{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "name": "GPT2(error).ipynb", "provenance": [], "collapsed_sections": [] }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" }, "widgets": { "application/vnd.jupyter.widget-state+json": { "1b266a2c1cf646a392a46e39586282b3": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_view_name": "HBoxView", "_dom_classes": [], "_model_name": "HBoxModel", "_view_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_view_count": null, "_view_module_version": "1.5.0", "box_style": "", "layout": "IPY_MODEL_8ecfcf14981c4d82b5d9d3839a496f0b", "_model_module": "@jupyter-widgets/controls", "children": [ "IPY_MODEL_16b07572ac0d46798b2c2a292c3f9143", "IPY_MODEL_cf412ff73fc647908154abc9b2847f38" ] } }, "8ecfcf14981c4d82b5d9d3839a496f0b": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_view_name": "LayoutView", "grid_template_rows": null, "right": null, "justify_content": null, "_view_module": "@jupyter-widgets/base", "overflow": null, "_model_module_version": "1.2.0", "_view_count": null, "flex_flow": null, "width": null, "min_width": null, "border": null, "align_items": null, "bottom": null, "_model_module": "@jupyter-widgets/base", "top": null, "grid_column": null, "overflow_y": null, "overflow_x": null, "grid_auto_flow": null, "grid_area": null, "grid_template_columns": null, "flex": null, "_model_name": "LayoutModel", "justify_items": null, "grid_row": null, "max_height": null, "align_content": null, "visibility": null, "align_self": null, "height": null, "min_height": null, "padding": null, "grid_auto_rows": null, "grid_gap": null, "max_width": null, "order": null, "_view_module_version": "1.2.0", "grid_template_areas": null, "object_position": null, "object_fit": null, "grid_auto_columns": null, "margin": null, "display": null, "left": null } }, "16b07572ac0d46798b2c2a292c3f9143": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_view_name": "ProgressView", "style": "IPY_MODEL_6ebc21286ae843e5b9ba4df8f4cebfe0", "_dom_classes": [], "description": "Downloading: 100%", "_model_name": "FloatProgressModel", "bar_style": "success", "max": 1042301, "_view_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "value": 1042301, "_view_count": null, "_view_module_version": "1.5.0", "orientation": "horizontal", "min": 0, "description_tooltip": null, "_model_module": "@jupyter-widgets/controls", "layout": "IPY_MODEL_48246be80e82429da2d48f9d4a1aaf0a" } }, "cf412ff73fc647908154abc9b2847f38": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_view_name": "HTMLView", "style": "IPY_MODEL_5754900e885d4f509ede058b186fcab6", "_dom_classes": [], "description": "", "_model_name": "HTMLModel", "placeholder": "​", "_view_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "value": " 1.04M/1.04M [00:06<00:00, 154kB/s]", "_view_count": null, "_view_module_version": "1.5.0", "description_tooltip": null, "_model_module": "@jupyter-widgets/controls", "layout": "IPY_MODEL_d0434381119c46489e17fcbccd9755ea" } }, "6ebc21286ae843e5b9ba4df8f4cebfe0": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_view_name": "StyleView", "_model_name": "ProgressStyleModel", "description_width": "initial", "_view_module": "@jupyter-widgets/base", "_model_module_version": "1.5.0", "_view_count": null, "_view_module_version": "1.2.0", "bar_color": null, "_model_module": "@jupyter-widgets/controls" } }, "48246be80e82429da2d48f9d4a1aaf0a": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_view_name": "LayoutView", "grid_template_rows": null, "right": null, "justify_content": null, "_view_module": "@jupyter-widgets/base", "overflow": null, "_model_module_version": "1.2.0", "_view_count": null, "flex_flow": null, "width": null, "min_width": null, "border": null, "align_items": null, "bottom": null, "_model_module": "@jupyter-widgets/base", "top": null, "grid_column": null, "overflow_y": null, "overflow_x": null, "grid_auto_flow": null, "grid_area": null, "grid_template_columns": null, "flex": null, "_model_name": "LayoutModel", "justify_items": null, "grid_row": null, "max_height": null, "align_content": null, "visibility": null, "align_self": null, "height": null, "min_height": null, "padding": null, "grid_auto_rows": null, "grid_gap": null, "max_width": null, "order": null, "_view_module_version": "1.2.0", "grid_template_areas": null, "object_position": null, "object_fit": null, "grid_auto_columns": null, "margin": null, "display": null, "left": null } }, "5754900e885d4f509ede058b186fcab6": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_view_name": "StyleView", "_model_name": "DescriptionStyleModel", "description_width": "", "_view_module": "@jupyter-widgets/base", "_model_module_version": "1.5.0", "_view_count": null, "_view_module_version": "1.2.0", "_model_module": "@jupyter-widgets/controls" } }, "d0434381119c46489e17fcbccd9755ea": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_view_name": "LayoutView", "grid_template_rows": null, "right": null, "justify_content": null, "_view_module": "@jupyter-widgets/base", "overflow": null, "_model_module_version": "1.2.0", "_view_count": null, "flex_flow": null, "width": null, "min_width": null, "border": null, "align_items": null, "bottom": null, "_model_module": "@jupyter-widgets/base", "top": null, "grid_column": null, "overflow_y": null, "overflow_x": null, "grid_auto_flow": null, "grid_area": null, "grid_template_columns": null, "flex": null, "_model_name": "LayoutModel", "justify_items": null, "grid_row": null, "max_height": null, "align_content": null, "visibility": null, "align_self": null, "height": null, "min_height": null, "padding": null, "grid_auto_rows": null, "grid_gap": null, "max_width": null, "order": null, "_view_module_version": "1.2.0", "grid_template_areas": null, "object_position": null, "object_fit": null, "grid_auto_columns": null, "margin": null, "display": null, "left": null } }, "73c4b8bc05f64477aa03d767f4483795": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_view_name": "HBoxView", "_dom_classes": [], "_model_name": "HBoxModel", "_view_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_view_count": null, "_view_module_version": "1.5.0", "box_style": "", "layout": "IPY_MODEL_6123827ad5964b4b8a17aaca618b4768", "_model_module": "@jupyter-widgets/controls", "children": [ "IPY_MODEL_5327d425e74d4a599214282b9b70d58b", "IPY_MODEL_974490d04f18407f9f5a5785b2802c0a" ] } }, "6123827ad5964b4b8a17aaca618b4768": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_view_name": "LayoutView", "grid_template_rows": null, "right": null, "justify_content": null, "_view_module": "@jupyter-widgets/base", "overflow": null, "_model_module_version": "1.2.0", "_view_count": null, "flex_flow": null, "width": null, "min_width": null, "border": null, "align_items": null, "bottom": null, "_model_module": "@jupyter-widgets/base", "top": null, "grid_column": null, "overflow_y": null, "overflow_x": null, "grid_auto_flow": null, "grid_area": null, "grid_template_columns": null, "flex": null, "_model_name": "LayoutModel", "justify_items": null, "grid_row": null, "max_height": null, "align_content": null, "visibility": null, "align_self": null, "height": null, "min_height": null, "padding": null, "grid_auto_rows": null, "grid_gap": null, "max_width": null, "order": null, "_view_module_version": "1.2.0", "grid_template_areas": null, "object_position": null, "object_fit": null, "grid_auto_columns": null, "margin": null, "display": null, "left": null } }, "5327d425e74d4a599214282b9b70d58b": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_view_name": "ProgressView", "style": "IPY_MODEL_c3cc1723c39a4d74b2ab83bd23b5fcce", "_dom_classes": [], "description": "Downloading: 100%", "_model_name": "FloatProgressModel", "bar_style": "success", "max": 456318, "_view_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "value": 456318, "_view_count": null, "_view_module_version": "1.5.0", "orientation": "horizontal", "min": 0, "description_tooltip": null, "_model_module": "@jupyter-widgets/controls", "layout": "IPY_MODEL_391d59bf8d2845f88a83dc25c7cf89f3" } }, "974490d04f18407f9f5a5785b2802c0a": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_view_name": "HTMLView", "style": "IPY_MODEL_d60fa9fe71444784b78bdfba6ed6a9e1", "_dom_classes": [], "description": "", "_model_name": "HTMLModel", "placeholder": "​", "_view_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "value": " 456k/456k [00:04<00:00, 96.1kB/s]", "_view_count": null, "_view_module_version": "1.5.0", "description_tooltip": null, "_model_module": "@jupyter-widgets/controls", "layout": "IPY_MODEL_41a3b55e5e264b85ada9558e5777790f" } }, "c3cc1723c39a4d74b2ab83bd23b5fcce": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_view_name": "StyleView", "_model_name": "ProgressStyleModel", "description_width": "initial", "_view_module": "@jupyter-widgets/base", "_model_module_version": "1.5.0", "_view_count": null, "_view_module_version": "1.2.0", "bar_color": null, "_model_module": "@jupyter-widgets/controls" } }, "391d59bf8d2845f88a83dc25c7cf89f3": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_view_name": "LayoutView", "grid_template_rows": null, "right": null, "justify_content": null, "_view_module": "@jupyter-widgets/base", "overflow": null, "_model_module_version": "1.2.0", "_view_count": null, "flex_flow": null, "width": null, "min_width": null, "border": null, "align_items": null, "bottom": null, "_model_module": "@jupyter-widgets/base", "top": null, "grid_column": null, "overflow_y": null, "overflow_x": null, "grid_auto_flow": null, "grid_area": null, "grid_template_columns": null, "flex": null, "_model_name": "LayoutModel", "justify_items": null, "grid_row": null, "max_height": null, "align_content": null, "visibility": null, "align_self": null, "height": null, "min_height": null, "padding": null, "grid_auto_rows": null, "grid_gap": null, "max_width": null, "order": null, "_view_module_version": "1.2.0", "grid_template_areas": null, "object_position": null, "object_fit": null, "grid_auto_columns": null, "margin": null, "display": null, "left": null } }, "d60fa9fe71444784b78bdfba6ed6a9e1": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_view_name": "StyleView", "_model_name": "DescriptionStyleModel", "description_width": "", "_view_module": "@jupyter-widgets/base", "_model_module_version": "1.5.0", "_view_count": null, "_view_module_version": "1.2.0", "_model_module": "@jupyter-widgets/controls" } }, "41a3b55e5e264b85ada9558e5777790f": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_view_name": "LayoutView", "grid_template_rows": null, "right": null, "justify_content": null, "_view_module": "@jupyter-widgets/base", "overflow": null, "_model_module_version": "1.2.0", "_view_count": null, "flex_flow": null, "width": null, "min_width": null, "border": null, "align_items": null, "bottom": null, "_model_module": "@jupyter-widgets/base", "top": null, "grid_column": null, "overflow_y": null, "overflow_x": null, "grid_auto_flow": null, "grid_area": null, "grid_template_columns": null, "flex": null, "_model_name": "LayoutModel", "justify_items": null, "grid_row": null, "max_height": null, "align_content": null, "visibility": null, "align_self": null, "height": null, "min_height": null, "padding": null, "grid_auto_rows": null, "grid_gap": null, "max_width": null, "order": null, "_view_module_version": "1.2.0", "grid_template_areas": null, "object_position": null, "object_fit": null, "grid_auto_columns": null, "margin": null, "display": null, "left": null } }, "aa4d6e2e9ac44e9bb40b7daccc91ee83": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_view_name": "HBoxView", "_dom_classes": [], "_model_name": "HBoxModel", "_view_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_view_count": null, "_view_module_version": "1.5.0", "box_style": "", "layout": "IPY_MODEL_c3b054972a6145d1ad03ca938a7ade9c", "_model_module": "@jupyter-widgets/controls", "children": [ "IPY_MODEL_644f4a69db534dd4a11172e5d010e8fe", "IPY_MODEL_0e19c2d5b060490399efbfcda773e9ba" ] } }, "c3b054972a6145d1ad03ca938a7ade9c": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_view_name": "LayoutView", "grid_template_rows": null, "right": null, "justify_content": null, "_view_module": "@jupyter-widgets/base", "overflow": null, "_model_module_version": "1.2.0", "_view_count": null, "flex_flow": null, "width": null, "min_width": null, "border": null, "align_items": null, "bottom": null, "_model_module": "@jupyter-widgets/base", "top": null, "grid_column": null, "overflow_y": null, "overflow_x": null, "grid_auto_flow": null, "grid_area": null, "grid_template_columns": null, "flex": null, "_model_name": "LayoutModel", "justify_items": null, "grid_row": null, "max_height": null, "align_content": null, "visibility": null, "align_self": null, "height": null, "min_height": null, "padding": null, "grid_auto_rows": null, "grid_gap": null, "max_width": null, "order": null, "_view_module_version": "1.2.0", "grid_template_areas": null, "object_position": null, "object_fit": null, "grid_auto_columns": null, "margin": null, "display": null, "left": null } }, "644f4a69db534dd4a11172e5d010e8fe": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_view_name": "ProgressView", "style": "IPY_MODEL_109479db406d4085acff84904cdac4ef", "_dom_classes": [], "description": "Downloading: 100%", "_model_name": "FloatProgressModel", "bar_style": "success", "max": 1355256, "_view_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "value": 1355256, "_view_count": null, "_view_module_version": "1.5.0", "orientation": "horizontal", "min": 0, "description_tooltip": null, "_model_module": "@jupyter-widgets/controls", "layout": "IPY_MODEL_1fb7fd7b44bf4b3a9e949f025487ff47" } }, "0e19c2d5b060490399efbfcda773e9ba": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_view_name": "HTMLView", "style": "IPY_MODEL_01c30370e3ff4b5caf9a3369841ad597", "_dom_classes": [], "description": "", "_model_name": "HTMLModel", "placeholder": "​", "_view_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "value": " 1.36M/1.36M [00:00<00:00, 1.73MB/s]", "_view_count": null, "_view_module_version": "1.5.0", "description_tooltip": null, "_model_module": "@jupyter-widgets/controls", "layout": "IPY_MODEL_d19eedc2acce48d4be7189b422b5fcb9" } }, "109479db406d4085acff84904cdac4ef": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_view_name": "StyleView", "_model_name": "ProgressStyleModel", "description_width": "initial", "_view_module": "@jupyter-widgets/base", "_model_module_version": "1.5.0", "_view_count": null, "_view_module_version": "1.2.0", "bar_color": null, "_model_module": "@jupyter-widgets/controls" } }, "1fb7fd7b44bf4b3a9e949f025487ff47": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_view_name": "LayoutView", "grid_template_rows": null, "right": null, "justify_content": null, "_view_module": "@jupyter-widgets/base", "overflow": null, "_model_module_version": "1.2.0", "_view_count": null, "flex_flow": null, "width": null, "min_width": null, "border": null, "align_items": null, "bottom": null, "_model_module": "@jupyter-widgets/base", "top": null, "grid_column": null, "overflow_y": null, "overflow_x": null, "grid_auto_flow": null, "grid_area": null, "grid_template_columns": null, "flex": null, "_model_name": "LayoutModel", "justify_items": null, "grid_row": null, "max_height": null, "align_content": null, "visibility": null, "align_self": null, "height": null, "min_height": null, "padding": null, "grid_auto_rows": null, "grid_gap": null, "max_width": null, "order": null, "_view_module_version": "1.2.0", "grid_template_areas": null, "object_position": null, "object_fit": null, "grid_auto_columns": null, "margin": null, "display": null, "left": null } }, "01c30370e3ff4b5caf9a3369841ad597": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_view_name": "StyleView", "_model_name": "DescriptionStyleModel", "description_width": "", "_view_module": "@jupyter-widgets/base", "_model_module_version": "1.5.0", "_view_count": null, "_view_module_version": "1.2.0", "_model_module": "@jupyter-widgets/controls" } }, "d19eedc2acce48d4be7189b422b5fcb9": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_view_name": "LayoutView", "grid_template_rows": null, "right": null, "justify_content": null, "_view_module": "@jupyter-widgets/base", "overflow": null, "_model_module_version": "1.2.0", "_view_count": null, "flex_flow": null, "width": null, "min_width": null, "border": null, "align_items": null, "bottom": null, "_model_module": "@jupyter-widgets/base", "top": null, "grid_column": null, "overflow_y": null, "overflow_x": null, "grid_auto_flow": null, "grid_area": null, "grid_template_columns": null, "flex": null, "_model_name": "LayoutModel", "justify_items": null, "grid_row": null, "max_height": null, "align_content": null, "visibility": null, "align_self": null, "height": null, "min_height": null, "padding": null, "grid_auto_rows": null, "grid_gap": null, "max_width": null, "order": null, "_view_module_version": "1.2.0", "grid_template_areas": null, "object_position": null, "object_fit": null, "grid_auto_columns": null, "margin": null, "display": null, "left": null } } } } }, "cells": [ { "cell_type": "code", "metadata": { "id": "hYCVkKKAwSjV" }, "source": [ "%%capture\n", "!pip install transformers\n", "!pip install datasets\n", "!pip install --upgrade git+https://github.com/google/flax.git" ], "execution_count": 1, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "2gcm5rxByOXO", "colab": { "base_uri": "https://localhost:8080/", "height": 164, "referenced_widgets": [ "1b266a2c1cf646a392a46e39586282b3", "8ecfcf14981c4d82b5d9d3839a496f0b", "16b07572ac0d46798b2c2a292c3f9143", "cf412ff73fc647908154abc9b2847f38", "6ebc21286ae843e5b9ba4df8f4cebfe0", "48246be80e82429da2d48f9d4a1aaf0a", "5754900e885d4f509ede058b186fcab6", "d0434381119c46489e17fcbccd9755ea", "73c4b8bc05f64477aa03d767f4483795", "6123827ad5964b4b8a17aaca618b4768", "5327d425e74d4a599214282b9b70d58b", "974490d04f18407f9f5a5785b2802c0a", "c3cc1723c39a4d74b2ab83bd23b5fcce", "391d59bf8d2845f88a83dc25c7cf89f3", "d60fa9fe71444784b78bdfba6ed6a9e1", "41a3b55e5e264b85ada9558e5777790f", "aa4d6e2e9ac44e9bb40b7daccc91ee83", "c3b054972a6145d1ad03ca938a7ade9c", "644f4a69db534dd4a11172e5d010e8fe", "0e19c2d5b060490399efbfcda773e9ba", "109479db406d4085acff84904cdac4ef", "1fb7fd7b44bf4b3a9e949f025487ff47", "01c30370e3ff4b5caf9a3369841ad597", "d19eedc2acce48d4be7189b422b5fcb9" ] }, "outputId": "5814323f-d04d-408c-e833-8522806ea73b" }, "source": [ "import jax\n", "from transformers.modeling_flax_utils import FlaxPreTrainedModel\n", "import flax.linen as nn\n", "import jax.numpy as jnp\n", "from transformers import GPT2Config\n", "from transformers import FlaxGPT2PreTrainedModel\n", "from transformers import FlaxGPT2Model\n", "from transformers import GPT2Tokenizer\n", "tokenizer = GPT2Tokenizer.from_pretrained(\"gpt2\",pad_token='<|endoftext|>')" ], "execution_count": 2, "outputs": [ { "output_type": "display_data", "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "1b266a2c1cf646a392a46e39586282b3", "version_minor": 0, "version_major": 2 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1042301.0, style=ProgressStyle(descript…" ] }, "metadata": { "tags": [] } }, { "output_type": "stream", "text": [ "\n" ], "name": "stdout" }, { "output_type": "display_data", "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "73c4b8bc05f64477aa03d767f4483795", "version_minor": 0, "version_major": 2 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=456318.0, style=ProgressStyle(descripti…" ] }, "metadata": { "tags": [] } }, { "output_type": "stream", "text": [ "\n" ], "name": "stdout" }, { "output_type": "display_data", "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "aa4d6e2e9ac44e9bb40b7daccc91ee83", "version_minor": 0, "version_major": 2 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1355256.0, style=ProgressStyle(descript…" ] }, "metadata": { "tags": [] } }, { "output_type": "stream", "text": [ "\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "GDokS6VEJI6C" }, "source": [ "#inputs = tokenizer([\"JAX/Flax is amazing \",\"tensorflow is also good\"],[\"pytorch is better\",\"keras is the best\"],return_tensors='jax',padding='max_length',max_length=30)" ], "execution_count": 3, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "hWiMk1TzyYim" }, "source": [ "class FlaxGPT2ForMultipleChoiceModule(nn.Module):\n", " config:GPT2Config\n", " dtype: jnp.dtype = jnp.float32\n", " def setup(self):\n", " self.gpt2 = FlaxGPT2Model(config=self.config, dtype=self.dtype)\n", " self.dropout = nn.Dropout(rate=0.2)\n", " self.classifier = nn.Dense(4, dtype=self.dtype)\n", "\n", " def __call__(self,input_ids,attention_mask,position_ids,return_dict=True,deterministic=True,*args):\n", " batch_size = input_ids.shape[0]\n", "\n", " rng=jax.random.PRNGKey(0)\n", " _, dropout_rng = jax.random.split(rng)\n", "\n", " outputs=self.gpt2(input_ids, attention_mask,position_ids,return_dict=return_dict)\n", " \n", "\n", " hidden_states = outputs[0]\n", "\n", " \n", " hidden_states= jnp.mean(hidden_states, axis=1)\n", "\n", " print(hidden_states.shape)\n", " \n", " hidden_states=hidden_states.reshape(batch_size,-1) #(32,8,768)->(32,8*768)\n", "\n", " dropout_output = self.dropout(hidden_states,deterministic=deterministic,rng=dropout_rng)\n", "\n", " print(dropout_output.shape)\n", "\n", " logits = self.classifier(dropout_output)\n", " reshaped_logits = logits.reshape(-1, 4) #(32,4)\n", " if not return_dict:\n", " return (reshaped_logits,) + outputs[2:]\n", " return reshaped_logits" ], "execution_count": 7, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "u1j00Ck255BC" }, "source": [ "class FlaxGPT2ForMultipleChoice(FlaxGPT2PreTrainedModel):\n", " module_class = FlaxGPT2ForMultipleChoiceModule" ], "execution_count": 8, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "h2MrRgKTRxZO", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "5a0fcc68-ca39-4df0-c854-734125d65f53" }, "source": [ "model = FlaxGPT2ForMultipleChoice.from_pretrained('gpt2') # getting warning" ], "execution_count": 9, "outputs": [ { "output_type": "stream", "text": [ "(1, 768)\n", "(1, 768)\n" ], "name": "stdout" }, { "output_type": "stream", "text": [ "Some weights of the model checkpoint at gpt2 were not used when initializing FlaxGPT2ForMultipleChoice: {('h', '1', 'ln_1', 'bias'), ('h', '6', 'ln_1', 'scale'), ('h', '1', 'attn', 'c_proj', 'kernel'), ('h', '11', 'mlp', 'c_fc', 'bias'), ('h', '7', 'ln_1', 'bias'), ('h', '5', 'ln_2', 'bias'), ('h', '10', 'ln_2', 'scale'), ('h', '4', 'mlp', 'c_proj', 'kernel'), ('h', '0', 'mlp', 'c_proj', 'bias'), ('h', '0', 'ln_1', 'bias'), ('h', '0', 'mlp', 'c_fc', 'kernel'), ('wpe', 'embedding'), ('h', '3', 'ln_1', 'scale'), ('h', '2', 'ln_1', 'scale'), ('h', '3', 'mlp', 'c_fc', 'kernel'), ('h', '7', 'ln_1', 'scale'), ('h', '8', 'mlp', 'c_proj', 'kernel'), ('h', '7', 'mlp', 'c_proj', 'kernel'), ('h', '3', 'ln_2', 'bias'), ('h', '9', 'attn', 'c_attn', 'kernel'), ('h', '0', 'mlp', 'c_fc', 'bias'), ('h', '3', 'attn', 'c_proj', 'bias'), ('h', '0', 'ln_1', 'scale'), ('h', '3', 'attn', 'c_attn', 'kernel'), ('h', '0', 'mlp', 'c_proj', 'kernel'), ('h', '5', 'ln_1', 'bias'), ('h', '7', 'attn', 'c_attn', 'bias'), ('h', '1', 'ln_2', 'bias'), ('h', '11', 'ln_2', 'scale'), ('h', '7', 'ln_2', 'bias'), ('h', '9', 'attn', 'c_proj', 'kernel'), ('h', '0', 'ln_2', 'bias'), ('h', '2', 'ln_2', 'scale'), ('h', '11', 'attn', 'c_attn', 'kernel'), ('h', '8', 'attn', 'c_proj', 'kernel'), ('h', '4', 'attn', 'c_attn', 'kernel'), ('h', '5', 'ln_1', 'scale'), ('h', '4', 'ln_1', 'bias'), ('h', '8', 'ln_2', 'bias'), ('h', '1', 'mlp', 'c_fc', 'kernel'), ('h', '9', 'ln_2', 'scale'), ('h', '1', 'mlp', 'c_proj', 'bias'), ('h', '2', 'mlp', 'c_proj', 'kernel'), ('h', '9', 'attn', 'c_proj', 'bias'), ('h', '11', 'ln_2', 'bias'), ('h', '6', 'mlp', 'c_proj', 'bias'), ('h', '3', 'ln_1', 'bias'), ('h', '1', 'attn', 'c_attn', 'kernel'), ('h', '9', 'ln_1', 'scale'), ('h', '10', 'attn', 'c_attn', 'bias'), ('h', '10', 'mlp', 'c_proj', 'kernel'), ('h', '2', 'attn', 'c_proj', 'kernel'), ('h', '0', 'attn', 'c_proj', 'kernel'), ('h', '6', 'attn', 'c_attn', 'kernel'), ('h', '4', 'mlp', 'c_fc', 'bias'), ('h', '3', 'attn', 'c_attn', 'bias'), ('h', '3', 'attn', 'c_proj', 'kernel'), ('h', '11', 'mlp', 'c_proj', 'bias'), ('h', '9', 'attn', 'c_attn', 'bias'), ('h', '7', 'mlp', 'c_proj', 'bias'), ('h', '7', 'mlp', 'c_fc', 'bias'), ('h', '6', 'attn', 'c_attn', 'bias'), ('h', '5', 'mlp', 'c_fc', 'kernel'), ('h', '0', 'attn', 'c_proj', 'bias'), ('h', '2', 'attn', 'c_proj', 'bias'), ('h', '10', 'attn', 'c_attn', 'kernel'), ('h', '10', 'mlp', 'c_proj', 'bias'), ('h', '1', 'attn', 'c_attn', 'bias'), ('h', '11', 'ln_1', 'bias'), ('h', '4', 'ln_2', 'bias'), ('h', '8', 'ln_1', 'bias'), ('h', '11', 'attn', 'c_proj', 'kernel'), ('h', '9', 'mlp', 'c_fc', 'kernel'), ('h', '7', 'ln_2', 'scale'), ('h', '9', 'mlp', 'c_proj', 'kernel'), ('h', '11', 'attn', 'c_attn', 'bias'), ('h', '10', 'mlp', 'c_fc', 'bias'), ('h', '6', 'attn', 'c_proj', 'kernel'), ('h', '0', 'ln_2', 'scale'), ('h', '2', 'ln_2', 'bias'), ('h', '3', 'mlp', 'c_proj', 'bias'), ('h', '5', 'mlp', 'c_proj', 'kernel'), ('h', '8', 'mlp', 'c_fc', 'bias'), ('h', '9', 'mlp', 'c_proj', 'bias'), ('h', '9', 'mlp', 'c_fc', 'bias'), ('h', '8', 'mlp', 'c_fc', 'kernel'), ('h', '9', 'ln_1', 'bias'), ('h', '10', 'ln_1', 'scale'), ('h', '6', 'ln_2', 'bias'), ('h', '2', 'mlp', 'c_fc', 'kernel'), ('h', '4', 'attn', 'c_proj', 'bias'), ('h', '1', 'ln_2', 'scale'), ('h', '5', 'mlp', 'c_fc', 'bias'), ('h', '7', 'mlp', 'c_fc', 'kernel'), ('h', '7', 'attn', 'c_proj', 'bias'), ('h', '5', 'attn', 'c_proj', 'kernel'), ('h', '2', 'mlp', 'c_fc', 'bias'), ('h', '6', 'ln_2', 'scale'), ('h', '11', 'ln_1', 'scale'), ('h', '4', 'mlp', 'c_fc', 'kernel'), ('h', '2', 'ln_1', 'bias'), ('h', '9', 'ln_2', 'bias'), ('h', '11', 'mlp', 'c_fc', 'kernel'), ('h', '1', 'attn', 'c_proj', 'bias'), ('h', '4', 'ln_2', 'scale'), ('h', '8', 'ln_1', 'scale'), ('h', '6', 'attn', 'c_proj', 'bias'), ('h', '5', 'attn', 'c_attn', 'kernel'), ('h', '3', 'ln_2', 'scale'), ('h', '8', 'attn', 'c_attn', 'bias'), ('h', '10', 'mlp', 'c_fc', 'kernel'), ('h', '1', 'ln_1', 'scale'), ('h', '10', 'attn', 'c_proj', 'bias'), ('h', '6', 'ln_1', 'bias'), ('h', '0', 'attn', 'c_attn', 'kernel'), ('wte', 'embedding'), ('h', '6', 'mlp', 'c_fc', 'kernel'), ('h', '4', 'attn', 'c_attn', 'bias'), ('h', '10', 'ln_2', 'bias'), ('h', '8', 'attn', 'c_proj', 'bias'), ('h', '11', 'attn', 'c_proj', 'bias'), ('h', '8', 'attn', 'c_attn', 'kernel'), ('h', '5', 'attn', 'c_attn', 'bias'), ('h', '5', 'ln_2', 'scale'), ('h', '2', 'attn', 'c_attn', 'bias'), ('ln_f', 'scale'), ('h', '7', 'attn', 'c_attn', 'kernel'), ('h', '4', 'ln_1', 'scale'), ('h', '8', 'ln_2', 'scale'), ('h', '11', 'mlp', 'c_proj', 'kernel'), ('h', '5', 'attn', 'c_proj', 'bias'), ('h', '7', 'attn', 'c_proj', 'kernel'), ('h', '8', 'mlp', 'c_proj', 'bias'), ('h', '3', 'mlp', 'c_fc', 'bias'), ('h', '10', 'ln_1', 'bias'), ('h', '2', 'attn', 'c_attn', 'kernel'), ('h', '6', 'mlp', 'c_proj', 'kernel'), ('h', '4', 'attn', 'c_proj', 'kernel'), ('h', '1', 'mlp', 'c_proj', 'kernel'), ('h', '2', 'mlp', 'c_proj', 'bias'), ('h', '1', 'mlp', 'c_fc', 'bias'), ('h', '4', 'mlp', 'c_proj', 'bias'), ('ln_f', 'bias'), ('h', '6', 'mlp', 'c_fc', 'bias'), ('h', '0', 'attn', 'c_attn', 'bias'), ('h', '10', 'attn', 'c_proj', 'kernel'), ('h', '5', 'mlp', 'c_proj', 'bias'), ('h', '3', 'mlp', 'c_proj', 'kernel')}\n", "- This IS expected if you are initializing FlaxGPT2ForMultipleChoice from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing FlaxGPT2ForMultipleChoice from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", "Some weights of FlaxGPT2ForMultipleChoice were not initialized from the model checkpoint at gpt2 and are newly initialized: {('classifier', 'bias'), ('classifier', 'kernel')}\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ], "name": "stderr" } ] }, { "cell_type": "code", "metadata": { "id": "CdSuQK9pRmw-" }, "source": [ "input_ids=jnp.ones((4,5,6))\n", "attention_mask=jnp.ones((4,5,6))" ], "execution_count": 10, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "d3Bu38KTkwWs", "colab": { "base_uri": "https://localhost:8080/", "height": 300 }, "outputId": "5470ccc9-6d49-427c-ad8e-5162343acfde" }, "source": [ "out1 = model(input_ids, attention_mask) #GPT2 will not take (batch_size,num_choice,sequence_length)" ], "execution_count": 11, "outputs": [ { "output_type": "error", "ename": "ValueError", "evalue": "ignored", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mout1\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput_ids\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mattention_mask\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m#GPT2 will not take (batch_size,num_choice,sequence_length)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/transformers/models/gpt2/modeling_flax_gpt2.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, input_ids, attention_mask, position_ids, params, past_key_values, dropout_rng, train, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 370\u001b[0m \u001b[0mreturn_dict\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mreturn_dict\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mreturn_dict\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreturn_dict\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 371\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 372\u001b[0;31m \u001b[0mbatch_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msequence_length\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0minput_ids\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 373\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 374\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mposition_ids\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mValueError\u001b[0m: too many values to unpack (expected 2)" ] } ] }, { "cell_type": "code", "metadata": { "id": "VZPlQfkhgLJd", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "7948688b-be77-4e9a-fc06-8644a2614d42" }, "source": [ "print(out1)" ], "execution_count": null, "outputs": [ { "output_type": "stream", "text": [ "[[ 1.1391759 -0.01598702 0.55463445 0.36025363]\n", " [ 0.32208228 0.37667227 0.87823874 0.19541818]\n", " [ 0.76971424 0.7187787 0.68642044 -0.31461257]\n", " [ 1.2375658 0.03325981 0.00153449 0.12019679]]\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "fgkIcD-mZWP7" }, "source": [ "" ], "execution_count": null, "outputs": [] } ] }