{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "accelerator": "GPU", "colab": { "name": "model_train_upload_workflow.ipynb", "provenance": [], "collapsed_sections": [] }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.4" }, "widgets": { "application/vnd.jupyter.widget-state+json": { "ad1f7a8dbb624d8988ade18adab9256a": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_view_name": "HBoxView", "_dom_classes": [], "_model_name": "HBoxModel", "_view_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_view_count": null, "_view_module_version": "1.5.0", "box_style": "", "layout": "IPY_MODEL_49da911827a943e0aadb2efd0329d147", "_model_module": "@jupyter-widgets/controls", "children": [ "IPY_MODEL_a7ec07a01a694048a05ed29df99984f6", "IPY_MODEL_1fa89f5112884cf6b43288bae7e710cb" ] } }, "49da911827a943e0aadb2efd0329d147": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_view_name": "LayoutView", "grid_template_rows": null, "right": null, "justify_content": null, "_view_module": "@jupyter-widgets/base", "overflow": null, "_model_module_version": "1.2.0", "_view_count": null, "flex_flow": null, "width": null, "min_width": null, "border": null, "align_items": null, "bottom": null, "_model_module": "@jupyter-widgets/base", "top": null, "grid_column": null, "overflow_y": null, "overflow_x": null, "grid_auto_flow": null, "grid_area": null, "grid_template_columns": null, "flex": null, "_model_name": "LayoutModel", "justify_items": null, "grid_row": null, "max_height": null, "align_content": null, "visibility": null, "align_self": null, "height": null, "min_height": null, "padding": null, "grid_auto_rows": null, "grid_gap": 
null, "max_width": null, "order": null, "_view_module_version": "1.2.0", "grid_template_areas": null, "object_position": null, "object_fit": null, "grid_auto_columns": null, "margin": null, "display": null, "left": null } }, "a7ec07a01a694048a05ed29df99984f6": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_view_name": "ProgressView", "style": "IPY_MODEL_bd094c08455d4f0a84f18afbb3b69ef8", "_dom_classes": [], "description": "Downloading: 100%", "_model_name": "FloatProgressModel", "bar_style": "success", "max": 26, "_view_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "value": 26, "_view_count": null, "_view_module_version": "1.5.0", "orientation": "horizontal", "min": 0, "description_tooltip": null, "_model_module": "@jupyter-widgets/controls", "layout": "IPY_MODEL_c73a9873c97e4a41be77376199e1fb32" } }, "1fa89f5112884cf6b43288bae7e710cb": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_view_name": "HTMLView", "style": "IPY_MODEL_e7afc24c895d48ca951144df72c30278", "_dom_classes": [], "description": "", "_model_name": "HTMLModel", "placeholder": "​", "_view_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "value": " 26.0/26.0 [00:00<00:00, 227B/s]", "_view_count": null, "_view_module_version": "1.5.0", "description_tooltip": null, "_model_module": "@jupyter-widgets/controls", "layout": "IPY_MODEL_52cb01ad3daf42c98699983c46eaa4c5" } }, "bd094c08455d4f0a84f18afbb3b69ef8": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_view_name": "StyleView", "_model_name": "ProgressStyleModel", "description_width": "initial", "_view_module": "@jupyter-widgets/base", "_model_module_version": "1.5.0", "_view_count": null, "_view_module_version": "1.2.0", "bar_color": null, "_model_module": "@jupyter-widgets/controls" } }, "c73a9873c97e4a41be77376199e1fb32": { "model_module": "@jupyter-widgets/base", 
"model_name": "LayoutModel", "state": { "_view_name": "LayoutView", "grid_template_rows": null, "right": null, "justify_content": null, "_view_module": "@jupyter-widgets/base", "overflow": null, "_model_module_version": "1.2.0", "_view_count": null, "flex_flow": null, "width": null, "min_width": null, "border": null, "align_items": null, "bottom": null, "_model_module": "@jupyter-widgets/base", "top": null, "grid_column": null, "overflow_y": null, "overflow_x": null, "grid_auto_flow": null, "grid_area": null, "grid_template_columns": null, "flex": null, "_model_name": "LayoutModel", "justify_items": null, "grid_row": null, "max_height": null, "align_content": null, "visibility": null, "align_self": null, "height": null, "min_height": null, "padding": null, "grid_auto_rows": null, "grid_gap": null, "max_width": null, "order": null, "_view_module_version": "1.2.0", "grid_template_areas": null, "object_position": null, "object_fit": null, "grid_auto_columns": null, "margin": null, "display": null, "left": null } }, "e7afc24c895d48ca951144df72c30278": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_view_name": "StyleView", "_model_name": "DescriptionStyleModel", "description_width": "", "_view_module": "@jupyter-widgets/base", "_model_module_version": "1.5.0", "_view_count": null, "_view_module_version": "1.2.0", "_model_module": "@jupyter-widgets/controls" } }, "52cb01ad3daf42c98699983c46eaa4c5": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_view_name": "LayoutView", "grid_template_rows": null, "right": null, "justify_content": null, "_view_module": "@jupyter-widgets/base", "overflow": null, "_model_module_version": "1.2.0", "_view_count": null, "flex_flow": null, "width": null, "min_width": null, "border": null, "align_items": null, "bottom": null, "_model_module": "@jupyter-widgets/base", "top": null, "grid_column": null, "overflow_y": null, "overflow_x": null, 
"grid_auto_flow": null, "grid_area": null, "grid_template_columns": null, "flex": null, "_model_name": "LayoutModel", "justify_items": null, "grid_row": null, "max_height": null, "align_content": null, "visibility": null, "align_self": null, "height": null, "min_height": null, "padding": null, "grid_auto_rows": null, "grid_gap": null, "max_width": null, "order": null, "_view_module_version": "1.2.0", "grid_template_areas": null, "object_position": null, "object_fit": null, "grid_auto_columns": null, "margin": null, "display": null, "left": null } }, "b5ced843807b41d8b3620e841785e0e0": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_view_name": "HBoxView", "_dom_classes": [], "_model_name": "HBoxModel", "_view_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_view_count": null, "_view_module_version": "1.5.0", "box_style": "", "layout": "IPY_MODEL_4fa3061c511d4afb847c89b3fe1cf72d", "_model_module": "@jupyter-widgets/controls", "children": [ "IPY_MODEL_f9c5ce7bb2394b44b6915d0e6d3385fe", "IPY_MODEL_ce14e3b5f648478f84bebce1c451b1db" ] } }, "4fa3061c511d4afb847c89b3fe1cf72d": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_view_name": "LayoutView", "grid_template_rows": null, "right": null, "justify_content": null, "_view_module": "@jupyter-widgets/base", "overflow": null, "_model_module_version": "1.2.0", "_view_count": null, "flex_flow": null, "width": null, "min_width": null, "border": null, "align_items": null, "bottom": null, "_model_module": "@jupyter-widgets/base", "top": null, "grid_column": null, "overflow_y": null, "overflow_x": null, "grid_auto_flow": null, "grid_area": null, "grid_template_columns": null, "flex": null, "_model_name": "LayoutModel", "justify_items": null, "grid_row": null, "max_height": null, "align_content": null, "visibility": null, "align_self": null, "height": null, "min_height": null, "padding": null, "grid_auto_rows": null, "grid_gap": 
null, "max_width": null, "order": null, "_view_module_version": "1.2.0", "grid_template_areas": null, "object_position": null, "object_fit": null, "grid_auto_columns": null, "margin": null, "display": null, "left": null } }, "f9c5ce7bb2394b44b6915d0e6d3385fe": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_view_name": "ProgressView", "style": "IPY_MODEL_f3f5b0d3ffec4b23b149ccbe9f9cb106", "_dom_classes": [], "description": "Downloading: 100%", "_model_name": "FloatProgressModel", "bar_style": "success", "max": 641, "_view_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "value": 641, "_view_count": null, "_view_module_version": "1.5.0", "orientation": "horizontal", "min": 0, "description_tooltip": null, "_model_module": "@jupyter-widgets/controls", "layout": "IPY_MODEL_b5b8b2092a3a4ee18dcbcc0651a7163f" } }, "ce14e3b5f648478f84bebce1c451b1db": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_view_name": "HTMLView", "style": "IPY_MODEL_e943740e8a2a4add827227dfd16dbe00", "_dom_classes": [], "description": "", "_model_name": "HTMLModel", "placeholder": "​", "_view_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "value": " 641/641 [00:00<00:00, 1.38kB/s]", "_view_count": null, "_view_module_version": "1.5.0", "description_tooltip": null, "_model_module": "@jupyter-widgets/controls", "layout": "IPY_MODEL_5f95112a99ca4609aa57d0faabfc47f0" } }, "f3f5b0d3ffec4b23b149ccbe9f9cb106": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_view_name": "StyleView", "_model_name": "ProgressStyleModel", "description_width": "initial", "_view_module": "@jupyter-widgets/base", "_model_module_version": "1.5.0", "_view_count": null, "_view_module_version": "1.2.0", "bar_color": null, "_model_module": "@jupyter-widgets/controls" } }, "b5b8b2092a3a4ee18dcbcc0651a7163f": { "model_module": "@jupyter-widgets/base", 
"model_name": "LayoutModel", "state": { "_view_name": "LayoutView", "grid_template_rows": null, "right": null, "justify_content": null, "_view_module": "@jupyter-widgets/base", "overflow": null, "_model_module_version": "1.2.0", "_view_count": null, "flex_flow": null, "width": null, "min_width": null, "border": null, "align_items": null, "bottom": null, "_model_module": "@jupyter-widgets/base", "top": null, "grid_column": null, "overflow_y": null, "overflow_x": null, "grid_auto_flow": null, "grid_area": null, "grid_template_columns": null, "flex": null, "_model_name": "LayoutModel", "justify_items": null, "grid_row": null, "max_height": null, "align_content": null, "visibility": null, "align_self": null, "height": null, "min_height": null, "padding": null, "grid_auto_rows": null, "grid_gap": null, "max_width": null, "order": null, "_view_module_version": "1.2.0", "grid_template_areas": null, "object_position": null, "object_fit": null, "grid_auto_columns": null, "margin": null, "display": null, "left": null } }, "e943740e8a2a4add827227dfd16dbe00": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_view_name": "StyleView", "_model_name": "DescriptionStyleModel", "description_width": "", "_view_module": "@jupyter-widgets/base", "_model_module_version": "1.5.0", "_view_count": null, "_view_module_version": "1.2.0", "_model_module": "@jupyter-widgets/controls" } }, "5f95112a99ca4609aa57d0faabfc47f0": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_view_name": "LayoutView", "grid_template_rows": null, "right": null, "justify_content": null, "_view_module": "@jupyter-widgets/base", "overflow": null, "_model_module_version": "1.2.0", "_view_count": null, "flex_flow": null, "width": null, "min_width": null, "border": null, "align_items": null, "bottom": null, "_model_module": "@jupyter-widgets/base", "top": null, "grid_column": null, "overflow_y": null, "overflow_x": null, 
"grid_auto_flow": null, "grid_area": null, "grid_template_columns": null, "flex": null, "_model_name": "LayoutModel", "justify_items": null, "grid_row": null, "max_height": null, "align_content": null, "visibility": null, "align_self": null, "height": null, "min_height": null, "padding": null, "grid_auto_rows": null, "grid_gap": null, "max_width": null, "order": null, "_view_module_version": "1.2.0", "grid_template_areas": null, "object_position": null, "object_fit": null, "grid_auto_columns": null, "margin": null, "display": null, "left": null } }, "fdeb9c5964b1450394ea02d57387f87a": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_view_name": "HBoxView", "_dom_classes": [], "_model_name": "HBoxModel", "_view_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_view_count": null, "_view_module_version": "1.5.0", "box_style": "", "layout": "IPY_MODEL_e9cb0a8b41c844d3933e1b6dffef3455", "_model_module": "@jupyter-widgets/controls", "children": [ "IPY_MODEL_eefe6217955d49c4a488c970ed496337", "IPY_MODEL_149232314ef24e5cb54e5264217b01ae" ] } }, "e9cb0a8b41c844d3933e1b6dffef3455": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_view_name": "LayoutView", "grid_template_rows": null, "right": null, "justify_content": null, "_view_module": "@jupyter-widgets/base", "overflow": null, "_model_module_version": "1.2.0", "_view_count": null, "flex_flow": null, "width": null, "min_width": null, "border": null, "align_items": null, "bottom": null, "_model_module": "@jupyter-widgets/base", "top": null, "grid_column": null, "overflow_y": null, "overflow_x": null, "grid_auto_flow": null, "grid_area": null, "grid_template_columns": null, "flex": null, "_model_name": "LayoutModel", "justify_items": null, "grid_row": null, "max_height": null, "align_content": null, "visibility": null, "align_self": null, "height": null, "min_height": null, "padding": null, "grid_auto_rows": null, "grid_gap": 
null, "max_width": null, "order": null, "_view_module_version": "1.2.0", "grid_template_areas": null, "object_position": null, "object_fit": null, "grid_auto_columns": null, "margin": null, "display": null, "left": null } }, "eefe6217955d49c4a488c970ed496337": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_view_name": "ProgressView", "style": "IPY_MODEL_0a26c722f94e42bc9c7162b4b3aa9b8d", "_dom_classes": [], "description": "Downloading: 100%", "_model_name": "FloatProgressModel", "bar_style": "success", "max": 1042301, "_view_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "value": 1042301, "_view_count": null, "_view_module_version": "1.5.0", "orientation": "horizontal", "min": 0, "description_tooltip": null, "_model_module": "@jupyter-widgets/controls", "layout": "IPY_MODEL_5f30e1dd6ebc487b8e156e29cc8cfb4a" } }, "149232314ef24e5cb54e5264217b01ae": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_view_name": "HTMLView", "style": "IPY_MODEL_a103897a377146558564797867a8c0e2", "_dom_classes": [], "description": "", "_model_name": "HTMLModel", "placeholder": "​", "_view_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "value": " 1.04M/1.04M [00:00<00:00, 2.93MB/s]", "_view_count": null, "_view_module_version": "1.5.0", "description_tooltip": null, "_model_module": "@jupyter-widgets/controls", "layout": "IPY_MODEL_e304e8e3ac814a89aebefa619b1b74db" } }, "0a26c722f94e42bc9c7162b4b3aa9b8d": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_view_name": "StyleView", "_model_name": "ProgressStyleModel", "description_width": "initial", "_view_module": "@jupyter-widgets/base", "_model_module_version": "1.5.0", "_view_count": null, "_view_module_version": "1.2.0", "bar_color": null, "_model_module": "@jupyter-widgets/controls" } }, "5f30e1dd6ebc487b8e156e29cc8cfb4a": { "model_module": 
"@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_view_name": "LayoutView", "grid_template_rows": null, "right": null, "justify_content": null, "_view_module": "@jupyter-widgets/base", "overflow": null, "_model_module_version": "1.2.0", "_view_count": null, "flex_flow": null, "width": null, "min_width": null, "border": null, "align_items": null, "bottom": null, "_model_module": "@jupyter-widgets/base", "top": null, "grid_column": null, "overflow_y": null, "overflow_x": null, "grid_auto_flow": null, "grid_area": null, "grid_template_columns": null, "flex": null, "_model_name": "LayoutModel", "justify_items": null, "grid_row": null, "max_height": null, "align_content": null, "visibility": null, "align_self": null, "height": null, "min_height": null, "padding": null, "grid_auto_rows": null, "grid_gap": null, "max_width": null, "order": null, "_view_module_version": "1.2.0", "grid_template_areas": null, "object_position": null, "object_fit": null, "grid_auto_columns": null, "margin": null, "display": null, "left": null } }, "a103897a377146558564797867a8c0e2": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_view_name": "StyleView", "_model_name": "DescriptionStyleModel", "description_width": "", "_view_module": "@jupyter-widgets/base", "_model_module_version": "1.5.0", "_view_count": null, "_view_module_version": "1.2.0", "_model_module": "@jupyter-widgets/controls" } }, "e304e8e3ac814a89aebefa619b1b74db": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_view_name": "LayoutView", "grid_template_rows": null, "right": null, "justify_content": null, "_view_module": "@jupyter-widgets/base", "overflow": null, "_model_module_version": "1.2.0", "_view_count": null, "flex_flow": null, "width": null, "min_width": null, "border": null, "align_items": null, "bottom": null, "_model_module": "@jupyter-widgets/base", "top": null, "grid_column": null, "overflow_y": null, 
"overflow_x": null, "grid_auto_flow": null, "grid_area": null, "grid_template_columns": null, "flex": null, "_model_name": "LayoutModel", "justify_items": null, "grid_row": null, "max_height": null, "align_content": null, "visibility": null, "align_self": null, "height": null, "min_height": null, "padding": null, "grid_auto_rows": null, "grid_gap": null, "max_width": null, "order": null, "_view_module_version": "1.2.0", "grid_template_areas": null, "object_position": null, "object_fit": null, "grid_auto_columns": null, "margin": null, "display": null, "left": null } }, "b579736c4c224e3e99399322e96a2a6b": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_view_name": "HBoxView", "_dom_classes": [], "_model_name": "HBoxModel", "_view_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_view_count": null, "_view_module_version": "1.5.0", "box_style": "", "layout": "IPY_MODEL_981e5ff3c1224acb93f775660c8ba65a", "_model_module": "@jupyter-widgets/controls", "children": [ "IPY_MODEL_d7cfadc533fa42c6885ba97ab4ca7a49", "IPY_MODEL_72f7b08bf77146fa81a3980f8785bdb1" ] } }, "981e5ff3c1224acb93f775660c8ba65a": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_view_name": "LayoutView", "grid_template_rows": null, "right": null, "justify_content": null, "_view_module": "@jupyter-widgets/base", "overflow": null, "_model_module_version": "1.2.0", "_view_count": null, "flex_flow": null, "width": null, "min_width": null, "border": null, "align_items": null, "bottom": null, "_model_module": "@jupyter-widgets/base", "top": null, "grid_column": null, "overflow_y": null, "overflow_x": null, "grid_auto_flow": null, "grid_area": null, "grid_template_columns": null, "flex": null, "_model_name": "LayoutModel", "justify_items": null, "grid_row": null, "max_height": null, "align_content": null, "visibility": null, "align_self": null, "height": null, "min_height": null, "padding": null, "grid_auto_rows": 
null, "grid_gap": null, "max_width": null, "order": null, "_view_module_version": "1.2.0", "grid_template_areas": null, "object_position": null, "object_fit": null, "grid_auto_columns": null, "margin": null, "display": null, "left": null } }, "d7cfadc533fa42c6885ba97ab4ca7a49": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_view_name": "ProgressView", "style": "IPY_MODEL_1341f329f46c4ee29f63beee956b05fa", "_dom_classes": [], "description": "Downloading: 100%", "_model_name": "FloatProgressModel", "bar_style": "success", "max": 456318, "_view_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "value": 456318, "_view_count": null, "_view_module_version": "1.5.0", "orientation": "horizontal", "min": 0, "description_tooltip": null, "_model_module": "@jupyter-widgets/controls", "layout": "IPY_MODEL_82b0e5e7c26e4de586fad746efc3f7e0" } }, "72f7b08bf77146fa81a3980f8785bdb1": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_view_name": "HTMLView", "style": "IPY_MODEL_9c7e761ce078459585eda1041e178b32", "_dom_classes": [], "description": "", "_model_name": "HTMLModel", "placeholder": "​", "_view_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "value": " 456k/456k [00:00<00:00, 2.14MB/s]", "_view_count": null, "_view_module_version": "1.5.0", "description_tooltip": null, "_model_module": "@jupyter-widgets/controls", "layout": "IPY_MODEL_33bf48e6f4464faf8b530bfec3942539" } }, "1341f329f46c4ee29f63beee956b05fa": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_view_name": "StyleView", "_model_name": "ProgressStyleModel", "description_width": "initial", "_view_module": "@jupyter-widgets/base", "_model_module_version": "1.5.0", "_view_count": null, "_view_module_version": "1.2.0", "bar_color": null, "_model_module": "@jupyter-widgets/controls" } }, "82b0e5e7c26e4de586fad746efc3f7e0": { "model_module": 
"@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_view_name": "LayoutView", "grid_template_rows": null, "right": null, "justify_content": null, "_view_module": "@jupyter-widgets/base", "overflow": null, "_model_module_version": "1.2.0", "_view_count": null, "flex_flow": null, "width": null, "min_width": null, "border": null, "align_items": null, "bottom": null, "_model_module": "@jupyter-widgets/base", "top": null, "grid_column": null, "overflow_y": null, "overflow_x": null, "grid_auto_flow": null, "grid_area": null, "grid_template_columns": null, "flex": null, "_model_name": "LayoutModel", "justify_items": null, "grid_row": null, "max_height": null, "align_content": null, "visibility": null, "align_self": null, "height": null, "min_height": null, "padding": null, "grid_auto_rows": null, "grid_gap": null, "max_width": null, "order": null, "_view_module_version": "1.2.0", "grid_template_areas": null, "object_position": null, "object_fit": null, "grid_auto_columns": null, "margin": null, "display": null, "left": null } }, "9c7e761ce078459585eda1041e178b32": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_view_name": "StyleView", "_model_name": "DescriptionStyleModel", "description_width": "", "_view_module": "@jupyter-widgets/base", "_model_module_version": "1.5.0", "_view_count": null, "_view_module_version": "1.2.0", "_model_module": "@jupyter-widgets/controls" } }, "33bf48e6f4464faf8b530bfec3942539": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_view_name": "LayoutView", "grid_template_rows": null, "right": null, "justify_content": null, "_view_module": "@jupyter-widgets/base", "overflow": null, "_model_module_version": "1.2.0", "_view_count": null, "flex_flow": null, "width": null, "min_width": null, "border": null, "align_items": null, "bottom": null, "_model_module": "@jupyter-widgets/base", "top": null, "grid_column": null, "overflow_y": null, 
"overflow_x": null, "grid_auto_flow": null, "grid_area": null, "grid_template_columns": null, "flex": null, "_model_name": "LayoutModel", "justify_items": null, "grid_row": null, "max_height": null, "align_content": null, "visibility": null, "align_self": null, "height": null, "min_height": null, "padding": null, "grid_auto_rows": null, "grid_gap": null, "max_width": null, "order": null, "_view_module_version": "1.2.0", "grid_template_areas": null, "object_position": null, "object_fit": null, "grid_auto_columns": null, "margin": null, "display": null, "left": null } }, "b59e043e9f1e4ab8ac2e400cf77948b8": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_view_name": "HBoxView", "_dom_classes": [], "_model_name": "HBoxModel", "_view_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_view_count": null, "_view_module_version": "1.5.0", "box_style": "", "layout": "IPY_MODEL_44d179e899ae4fb9b646af7c1d54e588", "_model_module": "@jupyter-widgets/controls", "children": [ "IPY_MODEL_c3634536ccc446abb32c90f13629d9fb", "IPY_MODEL_d63b5bcbee684bc7a967585798406780" ] } }, "44d179e899ae4fb9b646af7c1d54e588": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_view_name": "LayoutView", "grid_template_rows": null, "right": null, "justify_content": null, "_view_module": "@jupyter-widgets/base", "overflow": null, "_model_module_version": "1.2.0", "_view_count": null, "flex_flow": null, "width": null, "min_width": null, "border": null, "align_items": null, "bottom": null, "_model_module": "@jupyter-widgets/base", "top": null, "grid_column": null, "overflow_y": null, "overflow_x": null, "grid_auto_flow": null, "grid_area": null, "grid_template_columns": null, "flex": null, "_model_name": "LayoutModel", "justify_items": null, "grid_row": null, "max_height": null, "align_content": null, "visibility": null, "align_self": null, "height": null, "min_height": null, "padding": null, "grid_auto_rows": 
null, "grid_gap": null, "max_width": null, "order": null, "_view_module_version": "1.2.0", "grid_template_areas": null, "object_position": null, "object_fit": null, "grid_auto_columns": null, "margin": null, "display": null, "left": null } }, "c3634536ccc446abb32c90f13629d9fb": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_view_name": "ProgressView", "style": "IPY_MODEL_18004e8f29c84e6ea834dfd4629c3ab4", "_dom_classes": [], "description": "Downloading: 100%", "_model_name": "FloatProgressModel", "bar_style": "success", "max": 351265583, "_view_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "value": 351265583, "_view_count": null, "_view_module_version": "1.5.0", "orientation": "horizontal", "min": 0, "description_tooltip": null, "_model_module": "@jupyter-widgets/controls", "layout": "IPY_MODEL_2c728d7d5cd2468ebf6695d7df7de2aa" } }, "d63b5bcbee684bc7a967585798406780": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_view_name": "HTMLView", "style": "IPY_MODEL_92ff15cfcb0049828b2a1ae93be73567", "_dom_classes": [], "description": "", "_model_name": "HTMLModel", "placeholder": "​", "_view_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "value": " 351M/351M [00:20<00:00, 16.9MB/s]", "_view_count": null, "_view_module_version": "1.5.0", "description_tooltip": null, "_model_module": "@jupyter-widgets/controls", "layout": "IPY_MODEL_121fdc89bcef4443ad83d34f24891d79" } }, "18004e8f29c84e6ea834dfd4629c3ab4": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_view_name": "StyleView", "_model_name": "ProgressStyleModel", "description_width": "initial", "_view_module": "@jupyter-widgets/base", "_model_module_version": "1.5.0", "_view_count": null, "_view_module_version": "1.2.0", "bar_color": null, "_model_module": "@jupyter-widgets/controls" } }, "2c728d7d5cd2468ebf6695d7df7de2aa": { 
"model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_view_name": "LayoutView", "grid_template_rows": null, "right": null, "justify_content": null, "_view_module": "@jupyter-widgets/base", "overflow": null, "_model_module_version": "1.2.0", "_view_count": null, "flex_flow": null, "width": null, "min_width": null, "border": null, "align_items": null, "bottom": null, "_model_module": "@jupyter-widgets/base", "top": null, "grid_column": null, "overflow_y": null, "overflow_x": null, "grid_auto_flow": null, "grid_area": null, "grid_template_columns": null, "flex": null, "_model_name": "LayoutModel", "justify_items": null, "grid_row": null, "max_height": null, "align_content": null, "visibility": null, "align_self": null, "height": null, "min_height": null, "padding": null, "grid_auto_rows": null, "grid_gap": null, "max_width": null, "order": null, "_view_module_version": "1.2.0", "grid_template_areas": null, "object_position": null, "object_fit": null, "grid_auto_columns": null, "margin": null, "display": null, "left": null } }, "92ff15cfcb0049828b2a1ae93be73567": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_view_name": "StyleView", "_model_name": "DescriptionStyleModel", "description_width": "", "_view_module": "@jupyter-widgets/base", "_model_module_version": "1.5.0", "_view_count": null, "_view_module_version": "1.2.0", "_model_module": "@jupyter-widgets/controls" } }, "121fdc89bcef4443ad83d34f24891d79": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_view_name": "LayoutView", "grid_template_rows": null, "right": null, "justify_content": null, "_view_module": "@jupyter-widgets/base", "overflow": null, "_model_module_version": "1.2.0", "_view_count": null, "flex_flow": null, "width": null, "min_width": null, "border": null, "align_items": null, "bottom": null, "_model_module": "@jupyter-widgets/base", "top": null, "grid_column": null, "overflow_y": 
null, "overflow_x": null, "grid_auto_flow": null, "grid_area": null, "grid_template_columns": null, "flex": null, "_model_name": "LayoutModel", "justify_items": null, "grid_row": null, "max_height": null, "align_content": null, "visibility": null, "align_self": null, "height": null, "min_height": null, "padding": null, "grid_auto_rows": null, "grid_gap": null, "max_width": null, "order": null, "_view_module_version": "1.2.0", "grid_template_areas": null, "object_position": null, "object_fit": null, "grid_auto_columns": null, "margin": null, "display": null, "left": null } } } } }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "VTze-VbeU1c0" }, "source": [ "# Fine-tune a DialoGPT model\n", "\n", "Adapted from the notebook in [this Medium post](https://towardsdatascience.com/make-your-own-rick-sanchez-bot-with-transformers-and-dialogpt-fine-tuning-f85e6d1f4e30?gi=e4a72d1510f0)." ] }, { "cell_type": "markdown", "metadata": { "id": "Y17kuzFNUSrZ" }, "source": [ "## Setup" ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "GBfltjGHT6KG", "outputId": "1c6cbc4d-7451-4f4c-cf36-97fea0808e92" }, "source": [ "from google.colab import drive\n", "drive.mount('/content/drive/')" ], "execution_count": 1, "outputs": [ { "output_type": "stream", "text": [ "Mounted at /content/drive/\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "T8fgmjaqUErq", "outputId": "ecf98ec1-44ed-4619-d290-acf5ac422a8d", "colab": { "base_uri": "https://localhost:8080/" } }, "source": [ "!pip -q install transformers" ], "execution_count": 2, "outputs": [ { "output_type": "stream", "text": [ "\u001b[K |████████████████████████████████| 2.5MB 33.5MB/s \n", "\u001b[K |████████████████████████████████| 901kB 42.1MB/s \n", "\u001b[K |████████████████████████████████| 3.3MB 34.4MB/s \n", "\u001b[?25h" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "EtCreyG8UG1s" }, "source": [ "import os\n", 
"os.chdir(\"/content/drive/My Drive/Colab Notebooks\")" ], "execution_count": 3, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "dnv5kT-mLsB-" }, "source": [ "# all the imports\n", "\n", "import glob\n", "import logging\n", "import os\n", "import pickle\n", "import random\n", "import re\n", "import shutil\n", "from typing import Dict, List, Tuple\n", "\n", "import numpy as np\n", "import pandas as pd\n", "\n", "from sklearn.model_selection import train_test_split\n", "\n", "from torch.nn.utils.rnn import pad_sequence\n", "from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler\n", "from torch.utils.data.distributed import DistributedSampler\n", "from tqdm.notebook import tqdm, trange\n", "\n", "from pathlib import Path\n", "\n", "from transformers import (\n", " MODEL_WITH_LM_HEAD_MAPPING,\n", " WEIGHTS_NAME,\n", " AdamW,\n", " AutoConfig,\n", " PreTrainedModel,\n", " PreTrainedTokenizer,\n", " get_linear_schedule_with_warmup,\n", ")\n", "\n", "\n", "try:\n", " from torch.utils.tensorboard import SummaryWriter\n", "except ImportError:\n", " from tensorboardX import SummaryWriter" ], "execution_count": 4, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "BmrbGB8aUmBm" }, "source": [ "## Get Data from Kaggle" ] }, { "cell_type": "code", "metadata": { "id": "ftBYBoOoV_Er" }, "source": [ "!mkdir ~/.kaggle\n", "!cp kaggle.json ~/.kaggle/kaggle.json" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "fbITTMcLVbI_" }, "source": [ "!kaggle datasets download ruolinzheng/twewy-game-script -f twewy-name-line-full.csv" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "RXdJTSVwWGHj" }, "source": [ "data = pd.read_csv('mcu.csv')" ], "execution_count": 5, "outputs": [] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 792 }, "id": "h6kGx-9eG7qA", "outputId": "0a92c7ac-34a9-4852-fdc7-1a46c259f5de" }, 
"source": [ "data.sample(6)" ], "execution_count": 6, "outputs": [ { "output_type": "execute_result", "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Unnamed: 0characterlinemovieyearwordsAdam McKayAnna BodenArt MarcumAshley Edward MillerChris McKennaChristopher FordChristopher MarkusChristopher YostCraig KyleDon PayneDrew PearceEdgar WrightEric PearsonErik SommersGeneva Robertson-DworetHawk OstbyJames GunnJoe CornishJoe Robert ColeJohn Francis DaleyJon WattsJonathan GoldsteinJoss WhedonJustin TherouxMark FergusMatt HollowayPaul RuddRyan CooglerRyan FleckShane BlackStephen McFeelyZack Stentz
46114611THE MANDARINWell then. What are we waiting for?Iron Man 320137FalseFalseFalseFalseFalseFalseFalseFalseFalseFalseTrueFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseTrueFalseFalse
69996999PIETRO MAXIMOFFI know what they are.Avengers: Age of Ultron20155FalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseTrueFalseFalseFalseFalseFalseFalseFalseFalseFalse
1224912249THORShe’s too strong. Without my hammer I cannot--Thor: Ragnarok20179FalseFalseFalseFalseFalseFalseFalseTrueTrueFalseFalseFalseTrueFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
52475247JANE FOSTEROh, they're kids.Thor: The Dark World20133FalseFalseFalseFalseFalseFalseTrueTrueFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseTrueFalse
51865186THORI thank you for your sword and for your counse...Thor: The Dark World201313FalseFalseFalseFalseFalseFalseTrueTrueFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseTrueFalse
1348113481PETER PARKERBut it was such a long way down and I just tho...Avengers: Infinity War201817FalseFalseFalseFalseFalseFalseTrueFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseFalseTrueFalse
\n", "
" ], "text/plain": [ " Unnamed: 0 character ... Stephen McFeely Zack Stentz\n", "4611 4611 THE MANDARIN ... False False\n", "6999 6999 PIETRO MAXIMOFF ... False False\n", "12249 12249 THOR ... False False\n", "5247 5247 JANE FOSTER ... True False\n", "5186 5186 THOR ... True False\n", "13481 13481 PETER PARKER ... True False\n", "\n", "[6 rows x 38 columns]" ] }, "metadata": { "tags": [] }, "execution_count": 6 } ] }, { "cell_type": "code", "metadata": { "id": "PG8v6--qWUwj" }, "source": [ "CHARACTER_NAME = 'TONY STARK'" ], "execution_count": 7, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "GZUcEMd2WLDT" }, "source": [ "contexted = []\n", "\n", "# context window of size 7\n", "n = 7\n", "\n", "for i in data[data.character == CHARACTER_NAME].index:\n", " if i < n:\n", " continue\n", " row = []\n", " prev = i - 1 - n # we additionally substract 1, so row will contain current responce and 7 previous responces \n", " for j in range(i, prev, -1):\n", " row.append(data.line[j])\n", " contexted.append(row)\n", "\n", "columns = ['response', 'context'] \n", "columns = columns + ['context/' + str(i) for i in range(n - 1)]\n", "\n", "df = pd.DataFrame.from_records(contexted, columns=columns)" ], "execution_count": 8, "outputs": [] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 703 }, "id": "4T5OlNZHUxij", "outputId": "514686d3-5edc-4f47-9fee-aa221ad3ab5a" }, "source": [ "df.sample(6)" ], "execution_count": 9, "outputs": [ { "output_type": "execute_result", "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
responsecontextcontext/0context/1context/2context/3context/4context/5
1450Ah, please hold.Tony, we have a problem.Yeah, put him through.Priority call from Secretary Ross. There's bee...Tony, I'm glad you're back at the compound. I ...Yes, this is--this is Tony \"Stank\". You're in ...Are you Tony \"Stank\"?Oh yeah.
1065Six, high, back. Alright. You see that? Naile...Stay on my six, cover high and don't shoot me ...Yep. What do I do?ls your gun up?Okay. That's good. Now give me cameras A throu...Broadcast will commence shortly. Take final po...Yeah, death by oil.Viking funeral. Public execution.
1080I got you covered.What does that mean?Oh, I'm sorry, they're only coded to me.Oh, yeah. That's awesome. Give me a suit, okay?Everybody needs a hobby. Heartbreaker, help Re...This is how you've been managing your down tim...Incoming! Jarvis, get Igor to steady this thing.Gentlemen.
89What happened over there? I had my eyes opene...The official report was sketchy. What happened...Yes. That’s right.Something besides weapons?No, I don’t want to retire. I want to do some...You mean you’re retiring?I...can’t do this anymore.I’m sure he will. Now if you could just take ...
67Or something very big for fifteen minutes. Let...That could run your heart for fifty lifetimes.Three gigajoules -- per second.It’s pretty small, what can it generate?Yeah, but this one is going to last a bit long...What an original invention.That’s because it’s a miniature ARK reactor. ...Look like an animal, and soon you’ll start beh...
1264Romanoff? You and Banner better not be playing...Yeah.You good?Avengers, time to work for a living.On it.Rhodey, get the rest of the people on board th...We're out of time. They're coming for the core.Thor, I got a plan!
\n", "
" ], "text/plain": [ " response ... context/5\n", "1450 Ah, please hold. ... Oh yeah.\n", "1065 Six, high, back. Alright. You see that? Naile... ... Viking funeral. Public execution.\n", "1080 I got you covered. ... Gentlemen.\n", "89 What happened over there? I had my eyes opene... ... I’m sure he will. Now if you could just take ...\n", "67 Or something very big for fifteen minutes. Let... ... Look like an animal, and soon you’ll start beh...\n", "1264 Romanoff? You and Banner better not be playing... ... Thor, I got a plan!\n", "\n", "[6 rows x 8 columns]" ] }, "metadata": { "tags": [] }, "execution_count": 9 } ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 603 }, "id": "NGy0MxMQVIAP", "outputId": "6e38cdfe-eb03-495b-bd5b-5cbb270f16e5" }, "source": [ "trn_df, val_df = train_test_split(df, test_size=0.2)\n", "trn_df.head()" ], "execution_count": 10, "outputs": [ { "output_type": "execute_result", "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
responsecontextcontext/0context/1context/2context/3context/4context/5
218Where’s Pepper?TONY--Clear.I’ll ride with her.--I don’t know, he’s not answering his phone. ...What do you mean, he paid to have Tony killed?...Tony. Please. I’ll be out of here in a minut...Pepper. I should take that.
1039What? No. But I do know it's happening off the...Do you know what they did to my suit?Spill.Doing what? Ow, I get it! Ow! That hurt. I get...You tell him where Pepper is and he'll stop do...Here's how it works, Meryl Streep.Hi, Trevor. Trevor Slattery. I know I'm shorte...Yeah, I know, it's... It's embarrassing.
1195I'm calling in VERONICA. Alright everybody, s...Well, that's not gonna happen. Not for a while...News or footage, keyword: Hulk. Natasha, I co...Of course not, I'm already there. You'll catch...And you're not going anywhere.Ah, the Vibranium's getting away.No. I'm over it. I want...I want to finish the...I'm gonna kill him. I'll be right back.
1459Three, two, one. Hey, May. My gosh, uh, I want...Mm-hmm.Hey, May. How you doing? What are you wearing?...Okay.Get in the frame.An alibi? Sure.We rolling?Yeah, hold on.
1718Okay, Cap, I make ten of them, just passing th...Hulk take stairs.All right. Flick me.All right, you’re up, Little Buddy. There’s ou...One sec, just packing my lunch.Yes, that’s much better. You coming, Stark?“I’m on my way down to coordinate search and r...I’m on my way down to coordinate search and re...
\n", "
" ], "text/plain": [ " response ... context/5\n", "218 Where’s Pepper? ... Pepper. I should take that.\n", "1039 What? No. But I do know it's happening off the... ... Yeah, I know, it's... It's embarrassing.\n", "1195 I'm calling in VERONICA. Alright everybody, s... ... I'm gonna kill him. I'll be right back.\n", "1459 Three, two, one. Hey, May. My gosh, uh, I want... ... Yeah, hold on.\n", "1718 Okay, Cap, I make ten of them, just passing th... ... I’m on my way down to coordinate search and re...\n", "\n", "[5 rows x 8 columns]" ] }, "metadata": { "tags": [] }, "execution_count": 10 } ] }, { "cell_type": "code", "metadata": { "id": "aEeJQlAKWtiJ" }, "source": [ "# create dataset suitable for our model\n", "def construct_conv(row, tokenizer, eos = True):\n", " flatten = lambda l: [item for sublist in l for item in sublist]\n", " conv = list(reversed([tokenizer.encode(x) + [tokenizer.eos_token_id] for x in row]))\n", " conv = flatten(conv)\n", " return conv\n", "\n", "class ConversationDataset(Dataset):\n", " def __init__(self, tokenizer: PreTrainedTokenizer, args, df, block_size=512):\n", "\n", " block_size = block_size - (tokenizer.model_max_length - tokenizer.max_len_single_sentence)\n", "\n", " directory = args.cache_dir\n", " cached_features_file = os.path.join(\n", " directory, args.model_type + \"_cached_lm_\" + str(block_size)\n", " )\n", "\n", " if os.path.exists(cached_features_file) and not args.overwrite_cache:\n", " logger.info(\"Loading features from cached file %s\", cached_features_file)\n", " with open(cached_features_file, \"rb\") as handle:\n", " self.examples = pickle.load(handle)\n", " else:\n", " logger.info(\"Creating features from dataset file at %s\", directory)\n", "\n", " self.examples = []\n", " for _, row in df.iterrows():\n", " conv = construct_conv(row, tokenizer)\n", " self.examples.append(conv)\n", "\n", " logger.info(\"Saving features into cached file %s\", cached_features_file)\n", " with open(cached_features_file, \"wb\") as 
handle:\n", " pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)\n", "\n", " def __len__(self):\n", " return len(self.examples)\n", "\n", " def __getitem__(self, item):\n", " return torch.tensor(self.examples[item], dtype=torch.long)" ], "execution_count": 11, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "-3iHwoKlWyrs" }, "source": [ "# Cacheing and storing of data/checkpoints\n", "\n", "def load_and_cache_examples(args, tokenizer, df_trn, df_val, evaluate=False):\n", " return ConversationDataset(tokenizer, args, df_val if evaluate else df_trn)\n", "\n", "\n", "def set_seed(args):\n", " random.seed(args.seed)\n", " np.random.seed(args.seed)\n", " torch.manual_seed(args.seed)\n", " if args.n_gpu > 0:\n", " torch.cuda.manual_seed_all(args.seed)\n", "\n", "\n", "def _sorted_checkpoints(args, checkpoint_prefix=\"checkpoint\", use_mtime=False) -> List[str]:\n", " ordering_and_checkpoint_path = []\n", "\n", " glob_checkpoints = glob.glob(os.path.join(args.output_dir, \"{}-*\".format(checkpoint_prefix)))\n", "\n", " for path in glob_checkpoints:\n", " if use_mtime:\n", " ordering_and_checkpoint_path.append((os.path.getmtime(path), path))\n", " else:\n", " regex_match = re.match(\".*{}-([0-9]+)\".format(checkpoint_prefix), path)\n", " if regex_match and regex_match.groups():\n", " ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))\n", "\n", " checkpoints_sorted = sorted(ordering_and_checkpoint_path)\n", " checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]\n", " return checkpoints_sorted\n", "\n", "\n", "def _rotate_checkpoints(args, checkpoint_prefix=\"checkpoint\", use_mtime=False) -> None:\n", " if not args.save_total_limit:\n", " return\n", " if args.save_total_limit <= 0:\n", " return\n", "\n", " # Check if we should delete older checkpoint(s)\n", " checkpoints_sorted = _sorted_checkpoints(args, checkpoint_prefix, use_mtime)\n", " if len(checkpoints_sorted) <= args.save_total_limit:\n", " 
return\n", "\n", " number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - args.save_total_limit)\n", " checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]\n", " for checkpoint in checkpoints_to_be_deleted:\n", " logger.info(\"Deleting older checkpoint [{}] due to args.save_total_limit\".format(checkpoint))\n", " shutil.rmtree(checkpoint)" ], "execution_count": 12, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "EEDdTJTqUwZJ" }, "source": [ "## Build Model" ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 325, "referenced_widgets": [ "ad1f7a8dbb624d8988ade18adab9256a", "49da911827a943e0aadb2efd0329d147", "a7ec07a01a694048a05ed29df99984f6", "1fa89f5112884cf6b43288bae7e710cb", "bd094c08455d4f0a84f18afbb3b69ef8", "c73a9873c97e4a41be77376199e1fb32", "e7afc24c895d48ca951144df72c30278", "52cb01ad3daf42c98699983c46eaa4c5", "b5ced843807b41d8b3620e841785e0e0", "4fa3061c511d4afb847c89b3fe1cf72d", "f9c5ce7bb2394b44b6915d0e6d3385fe", "ce14e3b5f648478f84bebce1c451b1db", "f3f5b0d3ffec4b23b149ccbe9f9cb106", "b5b8b2092a3a4ee18dcbcc0651a7163f", "e943740e8a2a4add827227dfd16dbe00", "5f95112a99ca4609aa57d0faabfc47f0", "fdeb9c5964b1450394ea02d57387f87a", "e9cb0a8b41c844d3933e1b6dffef3455", "eefe6217955d49c4a488c970ed496337", "149232314ef24e5cb54e5264217b01ae", "0a26c722f94e42bc9c7162b4b3aa9b8d", "5f30e1dd6ebc487b8e156e29cc8cfb4a", "a103897a377146558564797867a8c0e2", "e304e8e3ac814a89aebefa619b1b74db", "b579736c4c224e3e99399322e96a2a6b", "981e5ff3c1224acb93f775660c8ba65a", "d7cfadc533fa42c6885ba97ab4ca7a49", "72f7b08bf77146fa81a3980f8785bdb1", "1341f329f46c4ee29f63beee956b05fa", "82b0e5e7c26e4de586fad746efc3f7e0", "9c7e761ce078459585eda1041e178b32", "33bf48e6f4464faf8b530bfec3942539", "b59e043e9f1e4ab8ac2e400cf77948b8", "44d179e899ae4fb9b646af7c1d54e588", "c3634536ccc446abb32c90f13629d9fb", "d63b5bcbee684bc7a967585798406780", "18004e8f29c84e6ea834dfd4629c3ab4", 
"2c728d7d5cd2468ebf6695d7df7de2aa", "92ff15cfcb0049828b2a1ae93be73567", "121fdc89bcef4443ad83d34f24891d79" ] }, "id": "r2cE0fY5UHpz", "outputId": "d9e32046-31d3-4be3-eb69-1d466748d762" }, "source": [ "from transformers import AutoModelWithLMHead, AutoModelForCausalLM, AutoTokenizer\n", "import torch\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(\"microsoft/DialoGPT-small\")\n", "model = AutoModelWithLMHead.from_pretrained(\"microsoft/DialoGPT-small\")" ], "execution_count": 13, "outputs": [ { "output_type": "display_data", "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ad1f7a8dbb624d8988ade18adab9256a", "version_minor": 0, "version_major": 2 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=26.0, style=ProgressStyle(description_w…" ] }, "metadata": { "tags": [] } }, { "output_type": "stream", "text": [ "\n" ], "name": "stdout" }, { "output_type": "display_data", "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "b5ced843807b41d8b3620e841785e0e0", "version_minor": 0, "version_major": 2 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=641.0, style=ProgressStyle(description_…" ] }, "metadata": { "tags": [] } }, { "output_type": "stream", "text": [ "\n" ], "name": "stdout" }, { "output_type": "display_data", "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "fdeb9c5964b1450394ea02d57387f87a", "version_minor": 0, "version_major": 2 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1042301.0, style=ProgressStyle(descript…" ] }, "metadata": { "tags": [] } }, { "output_type": "stream", "text": [ "\n" ], "name": "stdout" }, { "output_type": "display_data", "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "b579736c4c224e3e99399322e96a2a6b", "version_minor": 0, "version_major": 2 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, description='Downloading', 
max=456318.0, style=ProgressStyle(descripti…" ] }, "metadata": { "tags": [] } }, { "output_type": "stream", "text": [ "\n" ], "name": "stdout" }, { "output_type": "stream", "text": [ "/usr/local/lib/python3.7/dist-packages/transformers/models/auto/modeling_auto.py:847: FutureWarning: The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use `AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and `AutoModelForSeq2SeqLM` for encoder-decoder models.\n", "  FutureWarning,\n" ], "name": "stderr" }, { "output_type": "display_data", "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "b59e043e9f1e4ab8ac2e400cf77948b8", "version_minor": 0, "version_major": 2 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=351265583.0, style=ProgressStyle(descri…" ] }, "metadata": { "tags": [] } }, { "output_type": "stream", "text": [ "\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "ra2vsRp-UMXo" }, "source": [ "\"\"\"\n", "Fine-tuning the library models for language modeling on a text file (GPT, GPT-2, BERT, RoBERTa).\n", "GPT and GPT-2 are fine-tuned using a causal language modeling (CLM) loss while BERT and RoBERTa are fine-tuned\n", "using a masked language modeling (MLM) loss.\n", "\"\"\"\n", "\n", "# Configs\n", "logger = logging.getLogger(__name__)\n", "\n", "MODEL_CONFIG_CLASSES = list(MODEL_WITH_LM_HEAD_MAPPING.keys())\n", "MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)" ], "execution_count": 14, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "2OnASqJjUNJa" }, "source": [ "# Args to allow for easy conversion of python script to notebook\n", "class Args():\n", "    def __init__(self):\n", "        self.output_dir = 'output-medium'\n", "        self.model_type = 'gpt2'\n", "        self.model_name_or_path = 'microsoft/DialoGPT-small'\n", "        self.config_name = 'microsoft/DialoGPT-small'\n", "        
self.tokenizer_name = 'microsoft/DialoGPT-small'\n", " self.cache_dir = 'cached'\n", " self.block_size = 512\n", " self.do_train = True\n", " self.do_eval = True\n", " self.evaluate_during_training = False\n", " self.per_gpu_train_batch_size = 4\n", " self.per_gpu_eval_batch_size = 4\n", " self.gradient_accumulation_steps = 1\n", " self.learning_rate = 5e-5\n", " self.weight_decay = 0.0\n", " self.adam_epsilon = 1e-8\n", " self.max_grad_norm = 1.0\n", " self.num_train_epochs = 4\n", " self.max_steps = -1\n", " self.warmup_steps = 0\n", " self.logging_steps = 1000\n", " self.save_steps = 3500\n", " self.save_total_limit = None\n", " self.eval_all_checkpoints = False\n", " self.no_cuda = False\n", " self.overwrite_output_dir = True\n", " self.overwrite_cache = True\n", " self.should_continue = False\n", " self.seed = 42\n", " self.local_rank = -1\n", " self.fp16 = False\n", " self.fp16_opt_level = 'O1'\n", "\n", "args = Args()" ], "execution_count": 15, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "9Q1dTFXxW9NE" }, "source": [ "## Train and Evaluate" ] }, { "cell_type": "code", "metadata": { "id": "PaarIDZrW81h" }, "source": [ "def train(args, train_dataset, model: PreTrainedModel, tokenizer: PreTrainedTokenizer) -> Tuple[int, float]:\n", " \"\"\" Train the model \"\"\"\n", " if args.local_rank in [-1, 0]:\n", " tb_writer = SummaryWriter()\n", "\n", " args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)\n", "\n", " def collate(examples: List[torch.Tensor]):\n", " if tokenizer._pad_token is None:\n", " return pad_sequence(examples, batch_first=True)\n", " return pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id)\n", "\n", " train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)\n", " train_dataloader = DataLoader(\n", " train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, collate_fn=collate, drop_last = True\n", " )\n", "\n", " 
if args.max_steps > 0:\n", " t_total = args.max_steps\n", " args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1\n", " else:\n", " t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs\n", "\n", " model = model.module if hasattr(model, \"module\") else model # Take care of distributed/parallel training\n", " model.resize_token_embeddings(len(tokenizer))\n", " # add_special_tokens_(model, tokenizer)\n", "\n", "\n", " # Prepare optimizer and schedule (linear warmup and decay)\n", " no_decay = [\"bias\", \"LayerNorm.weight\"]\n", " optimizer_grouped_parameters = [\n", " {\n", " \"params\": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],\n", " \"weight_decay\": args.weight_decay,\n", " },\n", " {\"params\": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], \"weight_decay\": 0.0},\n", " ]\n", " optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)\n", " scheduler = get_linear_schedule_with_warmup(\n", " optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total\n", " )\n", "\n", " # Check if saved optimizer or scheduler states exist\n", " if (\n", " args.model_name_or_path\n", " and os.path.isfile(os.path.join(args.model_name_or_path, \"optimizer.pt\"))\n", " and os.path.isfile(os.path.join(args.model_name_or_path, \"scheduler.pt\"))\n", " ):\n", " # Load in optimizer and scheduler states\n", " optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, \"optimizer.pt\")))\n", " scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, \"scheduler.pt\")))\n", "\n", " if args.fp16:\n", " try:\n", " from apex import amp\n", " except ImportError:\n", " raise ImportError(\"Please install apex from https://www.github.com/nvidia/apex to use fp16 training.\")\n", " model, optimizer = amp.initialize(model, optimizer, 
opt_level=args.fp16_opt_level)\n", "\n", " # multi-gpu training (should be after apex fp16 initialization)\n", " if args.n_gpu > 1:\n", " model = torch.nn.DataParallel(model)\n", "\n", " # Distributed training (should be after apex fp16 initialization)\n", " if args.local_rank != -1:\n", " model = torch.nn.parallel.DistributedDataParallel(\n", " model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True\n", " )\n", "\n", " # Train!\n", " logger.info(\"***** Running training *****\")\n", " logger.info(\" Num examples = %d\", len(train_dataset))\n", " logger.info(\" Num Epochs = %d\", args.num_train_epochs)\n", " logger.info(\" Instantaneous batch size per GPU = %d\", args.per_gpu_train_batch_size)\n", " logger.info(\n", " \" Total train batch size (w. parallel, distributed & accumulation) = %d\",\n", " args.train_batch_size\n", " * args.gradient_accumulation_steps\n", " * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),\n", " )\n", " logger.info(\" Gradient Accumulation steps = %d\", args.gradient_accumulation_steps)\n", " logger.info(\" Total optimization steps = %d\", t_total)\n", "\n", " global_step = 0\n", " epochs_trained = 0\n", " steps_trained_in_current_epoch = 0\n", " # Check if continuing training from a checkpoint\n", " if args.model_name_or_path and os.path.exists(args.model_name_or_path):\n", " try:\n", " # set global_step to gobal_step of last saved checkpoint from model path\n", " checkpoint_suffix = args.model_name_or_path.split(\"-\")[-1].split(\"/\")[0]\n", " global_step = int(checkpoint_suffix)\n", " epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)\n", " steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)\n", "\n", " logger.info(\" Continuing training from checkpoint, will skip to saved global_step\")\n", " logger.info(\" Continuing training from epoch %d\", epochs_trained)\n", " logger.info(\" 
Continuing training from global step %d\", global_step)\n", " logger.info(\" Will skip the first %d steps in the first epoch\", steps_trained_in_current_epoch)\n", " except ValueError:\n", " logger.info(\" Starting fine-tuning.\")\n", "\n", " tr_loss, logging_loss = 0.0, 0.0\n", "\n", " model.zero_grad()\n", " train_iterator = trange(\n", " epochs_trained, int(args.num_train_epochs), desc=\"Epoch\", disable=args.local_rank not in [-1, 0]\n", " )\n", " set_seed(args) # Added here for reproducibility\n", " for _ in train_iterator:\n", " epoch_iterator = tqdm(train_dataloader, desc=\"Iteration\", disable=args.local_rank not in [-1, 0])\n", " for step, batch in enumerate(epoch_iterator):\n", "\n", " # Skip past any already trained steps if resuming training\n", " if steps_trained_in_current_epoch > 0:\n", " steps_trained_in_current_epoch -= 1\n", " continue\n", "\n", " inputs, labels = (batch, batch)\n", " if inputs.shape[1] > 1024: continue\n", " inputs = inputs.to(args.device)\n", " labels = labels.to(args.device)\n", " model.train()\n", " outputs = model(inputs, labels=labels)\n", " loss = outputs[0] # model outputs are always tuple in transformers (see doc)\n", "\n", " if args.n_gpu > 1:\n", " loss = loss.mean() # mean() to average on multi-gpu parallel training\n", " if args.gradient_accumulation_steps > 1:\n", " loss = loss / args.gradient_accumulation_steps\n", "\n", " if args.fp16:\n", " with amp.scale_loss(loss, optimizer) as scaled_loss:\n", " scaled_loss.backward()\n", " else:\n", " loss.backward()\n", "\n", " tr_loss += loss.item()\n", " if (step + 1) % args.gradient_accumulation_steps == 0:\n", " if args.fp16:\n", " torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)\n", " else:\n", " torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)\n", " optimizer.step()\n", " scheduler.step() # Update learning rate schedule\n", " model.zero_grad()\n", " global_step += 1\n", "\n", " if args.local_rank in [-1, 0] and 
args.logging_steps > 0 and global_step % args.logging_steps == 0:\n", " # Log metrics\n", " if (\n", " args.local_rank == -1 and args.evaluate_during_training\n", " ): # Only evaluate when single GPU otherwise metrics may not average well\n", " results = evaluate(args, model, tokenizer)\n", " for key, value in results.items():\n", " tb_writer.add_scalar(\"eval_{}\".format(key), value, global_step)\n", " tb_writer.add_scalar(\"lr\", scheduler.get_lr()[0], global_step)\n", " tb_writer.add_scalar(\"loss\", (tr_loss - logging_loss) / args.logging_steps, global_step)\n", " logging_loss = tr_loss\n", "\n", " if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:\n", " checkpoint_prefix = \"checkpoint\"\n", " # Save model checkpoint\n", " output_dir = os.path.join(args.output_dir, \"{}-{}\".format(checkpoint_prefix, global_step))\n", " os.makedirs(output_dir, exist_ok=True)\n", " model_to_save = (\n", " model.module if hasattr(model, \"module\") else model\n", " ) # Take care of distributed/parallel training\n", " model_to_save.save_pretrained(output_dir)\n", " tokenizer.save_pretrained(output_dir)\n", "\n", " torch.save(args, os.path.join(output_dir, \"training_args.bin\"))\n", " logger.info(\"Saving model checkpoint to %s\", output_dir)\n", "\n", " _rotate_checkpoints(args, checkpoint_prefix)\n", "\n", " torch.save(optimizer.state_dict(), os.path.join(output_dir, \"optimizer.pt\"))\n", " torch.save(scheduler.state_dict(), os.path.join(output_dir, \"scheduler.pt\"))\n", " logger.info(\"Saving optimizer and scheduler states to %s\", output_dir)\n", "\n", " if args.max_steps > 0 and global_step > args.max_steps:\n", " epoch_iterator.close()\n", " break\n", " if args.max_steps > 0 and global_step > args.max_steps:\n", " train_iterator.close()\n", " break\n", "\n", " if args.local_rank in [-1, 0]:\n", " tb_writer.close()\n", "\n", " return global_step, tr_loss / global_step\n", "\n", "# Evaluation of some model\n", "\n", "def 
evaluate(args, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, df_trn, df_val, prefix=\"\") -> Dict:\n", " # Loop to handle MNLI double evaluation (matched, mis-matched)\n", " eval_output_dir = args.output_dir\n", "\n", " eval_dataset = load_and_cache_examples(args, tokenizer, df_trn, df_val, evaluate=True)\n", " os.makedirs(eval_output_dir, exist_ok=True)\n", " args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)\n", " # Note that DistributedSampler samples randomly\n", "\n", " def collate(examples: List[torch.Tensor]):\n", " if tokenizer._pad_token is None:\n", " return pad_sequence(examples, batch_first=True)\n", " return pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id)\n", "\n", " eval_sampler = SequentialSampler(eval_dataset)\n", " eval_dataloader = DataLoader(\n", " eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, collate_fn=collate, drop_last = True\n", " )\n", "\n", " # multi-gpu evaluate\n", " if args.n_gpu > 1:\n", " model = torch.nn.DataParallel(model)\n", "\n", " # Eval!\n", " logger.info(\"***** Running evaluation {} *****\".format(prefix))\n", " logger.info(\" Num examples = %d\", len(eval_dataset))\n", " logger.info(\" Batch size = %d\", args.eval_batch_size)\n", " eval_loss = 0.0\n", " nb_eval_steps = 0\n", " model.eval()\n", "\n", " for batch in tqdm(eval_dataloader, desc=\"Evaluating\"):\n", " inputs, labels = (batch, batch)\n", " inputs = inputs.to(args.device)\n", " labels = labels.to(args.device)\n", "\n", " with torch.no_grad():\n", " outputs = model(inputs, labels=labels)\n", " lm_loss = outputs[0]\n", " eval_loss += lm_loss.mean().item()\n", " nb_eval_steps += 1\n", "\n", " eval_loss = eval_loss / nb_eval_steps\n", " perplexity = torch.exp(torch.tensor(eval_loss))\n", "\n", " result = {\"perplexity\": perplexity}\n", "\n", " output_eval_file = os.path.join(eval_output_dir, prefix, \"eval_results.txt\")\n", " with open(output_eval_file, \"w\") as writer:\n", " 
logger.info(\"***** Eval results {} *****\".format(prefix))\n", " for key in sorted(result.keys()):\n", " logger.info(\" %s = %s\", key, str(result[key]))\n", " writer.write(\"%s = %s\\n\" % (key, str(result[key])))\n", "\n", " return result" ], "execution_count": 16, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "SCnGAJWbXD9C" }, "source": [ "# Main runner\n", "\n", "def main(df_trn, df_val):\n", " args = Args()\n", " \n", " if args.should_continue:\n", " sorted_checkpoints = _sorted_checkpoints(args)\n", " if len(sorted_checkpoints) == 0:\n", " raise ValueError(\"Used --should_continue but no checkpoint was found in --output_dir.\")\n", " else:\n", " args.model_name_or_path = sorted_checkpoints[-1]\n", "\n", " if (\n", " os.path.exists(args.output_dir)\n", " and os.listdir(args.output_dir)\n", " and args.do_train\n", " and not args.overwrite_output_dir\n", " and not args.should_continue\n", " ):\n", " raise ValueError(\n", " \"Output directory ({}) already exists and is not empty. 
Use --overwrite_output_dir to overcome.\".format(\n", " args.output_dir\n", " )\n", " )\n", "\n", " # Setup CUDA, GPU & distributed training\n", " device = torch.device(\"cuda\")\n", " args.n_gpu = torch.cuda.device_count()\n", " args.device = device\n", "\n", " # Setup logging\n", " logging.basicConfig(\n", " format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n", " datefmt=\"%m/%d/%Y %H:%M:%S\",\n", " level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,\n", " )\n", " logger.warning(\n", " \"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s\",\n", " args.local_rank,\n", " device,\n", " args.n_gpu,\n", " bool(args.local_rank != -1),\n", " args.fp16,\n", " )\n", "\n", " # Set seed\n", " set_seed(args)\n", "\n", " config = AutoConfig.from_pretrained(args.config_name, cache_dir=args.cache_dir)\n", " tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, cache_dir=args.cache_dir)\n", " model = AutoModelWithLMHead.from_pretrained(\n", " args.model_name_or_path,\n", " from_tf=False,\n", " config=config,\n", " cache_dir=args.cache_dir,\n", " )\n", " model.to(args.device)\n", " \n", " logger.info(\"Training/evaluation parameters %s\", args)\n", "\n", " # Training\n", " if args.do_train:\n", " train_dataset = load_and_cache_examples(args, tokenizer, df_trn, df_val, evaluate=False)\n", "\n", " global_step, tr_loss = train(args, train_dataset, model, tokenizer)\n", " logger.info(\" global_step = %s, average loss = %s\", global_step, tr_loss)\n", "\n", " # Saving best-practices: if you use save_pretrained for the model and tokenizer, you can reload them using from_pretrained()\n", " if args.do_train:\n", " # Create output directory if needed\n", " os.makedirs(args.output_dir, exist_ok=True)\n", "\n", " logger.info(\"Saving model checkpoint to %s\", args.output_dir)\n", " # Save a trained model, configuration and tokenizer using `save_pretrained()`.\n", " # They can then be reloaded using 
`from_pretrained()`\n", " model_to_save = (\n", " model.module if hasattr(model, \"module\") else model\n", " ) # Take care of distributed/parallel training\n", " model_to_save.save_pretrained(args.output_dir)\n", " tokenizer.save_pretrained(args.output_dir)\n", "\n", " # Good practice: save your training arguments together with the trained model\n", " torch.save(args, os.path.join(args.output_dir, \"training_args.bin\"))\n", "\n", " # Load a trained model and vocabulary that you have fine-tuned\n", " model = AutoModelWithLMHead.from_pretrained(args.output_dir)\n", " tokenizer = AutoTokenizer.from_pretrained(args.output_dir)\n", " model.to(args.device)\n", "\n", " # Evaluation\n", " results = {}\n", " if args.do_eval and args.local_rank in [-1, 0]:\n", " checkpoints = [args.output_dir]\n", " if args.eval_all_checkpoints:\n", " checkpoints = list(\n", " os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + \"/**/\" + WEIGHTS_NAME, recursive=True))\n", " )\n", " logging.getLogger(\"transformers.modeling_utils\").setLevel(logging.WARN) # Reduce logging\n", " logger.info(\"Evaluate the following checkpoints: %s\", checkpoints)\n", " for checkpoint in checkpoints:\n", " global_step = checkpoint.split(\"-\")[-1] if len(checkpoints) > 1 else \"\"\n", " prefix = checkpoint.split(\"/\")[-1] if checkpoint.find(\"checkpoint\") != -1 else \"\"\n", "\n", " model = AutoModelWithLMHead.from_pretrained(checkpoint)\n", " model.to(args.device)\n", " result = evaluate(args, model, tokenizer, df_trn, df_val, prefix=prefix)\n", " result = dict((k + \"_{}\".format(global_step), v) for k, v in result.items())\n", " results.update(result)\n", "\n", " return results" ], "execution_count": 17, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "7NWvkdR-XHeB" }, "source": [ "## Run the Main Function" ] }, { "cell_type": "code", "metadata": { "id": "e61zo2JtXGNX" }, "source": [ "main(trn_df, val_df)" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", 
"metadata": { "id": "YRpQ_n2zXQj-" }, "source": [ "## Load the Trained Model" ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "HGw3qgfaXQHX", "outputId": "ec6cb27f-7d33-42c2-e10e-0e6993472fcf" }, "source": [ "tokenizer = AutoTokenizer.from_pretrained('microsoft/DialoGPT-small')\n", "model = AutoModelWithLMHead.from_pretrained('output-medium')" ], "execution_count": 21, "outputs": [ { "output_type": "stream", "text": [ "/usr/local/lib/python3.7/dist-packages/transformers/models/auto/modeling_auto.py:847: FutureWarning: The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use `AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and `AutoModelForSeq2SeqLM` for encoder-decoder models.\n", " FutureWarning,\n" ], "name": "stderr" } ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "lAWsiAvNXbxd", "outputId": "40a7c370-cc54-45b6-e5f6-f6861ce7c1d1" }, "source": [ "# Let's chat for 4 lines\n", "for step in range(4):\n", " # encode the new user input, add the eos_token and return a tensor in Pytorch\n", " new_user_input_ids = tokenizer.encode(input(\">> User:\") + tokenizer.eos_token, return_tensors='pt')\n", " # print(new_user_input_ids)\n", "\n", " # append the new user input tokens to the chat history\n", " bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1) if step > 0 else new_user_input_ids\n", "\n", " # generated a response while limiting the total chat history to 1000 tokens, \n", " chat_history_ids = model.generate(\n", " bot_input_ids, max_length=200,\n", " pad_token_id=tokenizer.eos_token_id, \n", " no_repeat_ngram_size=3, \n", " do_sample=True, \n", " top_k=100, \n", " top_p=0.7,\n", " temperature=0.8\n", " )\n", " \n", " # pretty print last ouput tokens from bot\n", " print(\"JoshuaBot: {}\".format(tokenizer.decode(chat_history_ids[:, 
bot_input_ids.shape[-1]:][0], skip_special_tokens=True)))" ], "execution_count": 22, "outputs": [ { "output_type": "stream", "text": [ ">> User:hi\n", "JoshuaBot: Hi.\n", ">> User:Who are you\n", "JoshuaBot: I’m Tony Stark.\n", ">> User:ok\n", "JoshuaBot: You are.\n", ">> User:Your mom\n", "JoshuaBot: !!!\n" ], "name": "stdout" } ] }, { "cell_type": "markdown", "metadata": { "id": "ANSQlQezXqwn" }, "source": [ "## Push Model to Hugging Face" ] }, { "cell_type": "code", "metadata": { "id": "VgnHRgHKXwDd", "outputId": "18ca84a0-9b74-4719-befa-c84016b7bf50", "colab": { "base_uri": "https://localhost:8080/" } }, "source": [ "!sudo apt-get install git-lfs" ], "execution_count": 23, "outputs": [ { "output_type": "stream", "text": [ "Reading package lists... Done\n", "Building dependency tree \n", "Reading state information... Done\n", "The following NEW packages will be installed:\n", " git-lfs\n", "0 upgraded, 1 newly installed, 0 to remove and 39 not upgraded.\n", "Need to get 2,129 kB of archives.\n", "After this operation, 7,662 kB of additional disk space will be used.\n", "Get:1 http://archive.ubuntu.com/ubuntu bionic/universe amd64 git-lfs amd64 2.3.4-1 [2,129 kB]\n", "Fetched 2,129 kB in 1s (2,904 kB/s)\n", "debconf: unable to initialize frontend: Dialog\n", "debconf: (No usable dialog-like program is installed, so the dialog based frontend cannot be used. at /usr/share/perl5/Debconf/FrontEnd/Dialog.pm line 76, <> line 1.)\n", "debconf: falling back to frontend: Readline\n", "debconf: unable to initialize frontend: Readline\n", "debconf: (This frontend requires a controlling tty.)\n", "debconf: falling back to frontend: Teletype\n", "dpkg-preconfigure: unable to re-open stdin: \n", "Selecting previously unselected package git-lfs.\n", "(Reading database ... 
160837 files and directories currently installed.)\n", "Preparing to unpack .../git-lfs_2.3.4-1_amd64.deb ...\n", "Unpacking git-lfs (2.3.4-1) ...\n", "Setting up git-lfs (2.3.4-1) ...\n", "Processing triggers for man-db (2.8.3-2ubuntu0.1) ...\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "tfUsrKR7YLT1" }, "source": [ "MY_MODEL_NAME = 'DialoGPT-small-Tony'\n", "\n", "HUGGINGFACE_API_KEY = \"\"" ], "execution_count": 26, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "L4nrBX8D6Rm6", "outputId": "94929339-1dc3-4a73-e62e-f86764af1093", "colab": { "base_uri": "https://localhost:8080/" } }, "source": [ "!huggingface-cli login\n" ], "execution_count": 25, "outputs": [ { "output_type": "stream", "text": [ "\n", " _| _| _| _| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _|_|_|_| _|_| _|_|_| _|_|_|_|\n", " _| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _|\n", " _|_|_|_| _| _| _| _|_| _| _|_| _| _| _| _| _| _|_| _|_|_| _|_|_|_| _| _|_|_|\n", " _| _| _| _| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _|\n", " _| _| _|_| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _| _| _| _|_|_| _|_|_|_|\n", "\n", " \n", "Username: imrit450\n", "Password: \n", "Login successful\n", "Your token: [REDACTED]\n", "\n", "Your token has been saved to /root/.huggingface/token\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "uhqMtvfmXei8" }, "source": [ "!git config --global user.email \"imrit450@gmail.com\"\n", "# Tip: using the same email as your huggingface.co account will link your commits to your profile\n", "!git config --global user.name \"Ismail\"" ], "execution_count": 28, "outputs": [] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "_65nsiLcYNXI", "outputId": "b6d3a728-836a-49bc-acdd-b273941249a9" }, "source": [ "model.push_to_hub(MY_MODEL_NAME)\n", "tokenizer.push_to_hub(MY_MODEL_NAME)"
], "execution_count": null, "outputs": [ { "output_type": "stream", "text": [ "07/17/2021 00:45:53 - INFO - huggingface_hub.repository - git version 2.17.1\n", "Sorry, no usage text found for \"git-lfs\"\n" ], "name": "stderr" } ] }, { "cell_type": "code", "metadata": { "id": "a9Ikuswq7LQ3" }, "source": [ "" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "D_XfXTCrZKmO" }, "source": [ "## All Done!" ] }, { "cell_type": "code", "metadata": { "id": "_tIwK7G8ZLrd" }, "source": [ "" ], "execution_count": null, "outputs": [] } ] }