diff --git "a/experiment/rwkv-x-exp/multi-size-train/v5-L12-D2048-part1.ipynb" "b/experiment/rwkv-x-exp/multi-size-train/v5-L12-D2048-part1.ipynb" new file mode 100644--- /dev/null +++ "b/experiment/rwkv-x-exp/multi-size-train/v5-L12-D2048-part1.ipynb" @@ -0,0 +1,15954 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "id": "403604b7", + "metadata": { + "papermill": { + "duration": 0.002912, + "end_time": "2023-09-29T05:56:34.178431", + "exception": false, + "start_time": "2023-09-29T05:56:34.175519", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# RWKV v5 multi-size training experiment\n", + "\n", + "**Note:** This project assumes you have the rwkv-infctx conda env setup" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "3a120d89", + "metadata": { + "papermill": { + "duration": 0.002318, + "end_time": "2023-09-29T05:56:34.184361", + "exception": false, + "start_time": "2023-09-29T05:56:34.182043", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# Basic Setup" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "8949fa01", + "metadata": { + "execution": { + "iopub.execute_input": "2023-09-29T05:56:34.188493Z", + "iopub.status.busy": "2023-09-29T05:56:34.188124Z", + "iopub.status.idle": "2023-09-29T05:56:34.854422Z", + "shell.execute_reply": "2023-09-29T05:56:34.853652Z" + }, + "papermill": { + "duration": 0.670486, + "end_time": "2023-09-29T05:56:34.856364", + "exception": false, + "start_time": "2023-09-29T05:56:34.185878", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "# First lets setup the various directories, and init the model\n", + "!mkdir -p ../../../../model/\n", + "!mkdir -p ../../../../datapath/\n", + "!mkdir -p ../../../../checkpoint/" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "dc84d15b", + "metadata": { + "execution": { + "iopub.execute_input": "2023-09-29T05:56:34.862583Z", + "iopub.status.busy": "2023-09-29T05:56:34.862335Z", + "iopub.status.idle": "2023-09-29T05:56:34.870376Z", + "shell.execute_reply": "2023-09-29T05:56:34.869641Z" + }, + "papermill": { + "duration": 0.012677, + "end_time": "2023-09-29T05:56:34.871741", + "exception": false, + "start_time": "2023-09-29T05:56:34.859064", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "DEEPSPEED_STRAT: deepspeed_stage_1\n", + "ENABLE_WANDB: True\n", + "GPU_DEVICES: auto\n", + "NOTEBOOK_DIR: /actions-runner/_work/RWKV-infctx-trainer/RWKV-infctx-trainer/notebook/experiment/rwkv-x-exp/multi-size-train\n", + "INFERENCE_DIR: /actions-runner/_work/RWKV-infctx-trainer/RWKV-infctx-trainer/RWKV-v5\n", + "TRAINER_DIR: /actions-runner/_work/RWKV-infctx-trainer/RWKV-infctx-trainer/RWKV-v5\n", + "PROJECT_DIR: /actions-runner/_work/RWKV-infctx-trainer/RWKV-infctx-trainer\n" + ] + } + ], + "source": [ + "DEEPSPEED_STRAT=\"deepspeed_stage_1\"\n", + "GPU_DEVICES=\"auto\"\n", + "ENABLE_WANDB=True\n", + "\n", + "EMBED_SCALE=0.01\n", + "EMBED_SCALE_LABEL=str(EMBED_SCALE).replace(\".\", \"_\")\n", + "\n", + "LAYER_COUNT=12\n", + "EMBED_SIZE=2048\n", + "\n", + "WANDB_PREFIX=f\"[Multi-size] v5-L{LAYER_COUNT}-D{EMBED_SIZE}-E{EMBED_SCALE}\"\n", + "FILENAME_PREFIX=f\"v5-L{LAYER_COUNT}-D{EMBED_SIZE}-E{EMBED_SCALE_LABEL}\"\n", + "\n", + "print(\"DEEPSPEED_STRAT:\", DEEPSPEED_STRAT)\n", + "print(\"ENABLE_WANDB:\", ENABLE_WANDB)\n", + "print(\"GPU_DEVICES:\", GPU_DEVICES)\n", + "\n", + "if ENABLE_WANDB:\n", + " WANDB_MODE=\"online\"\n", + "else:\n", + " WANDB_MODE=\"disabled\"\n", + "\n", + "# Computing the notebook, and various paths\n", + "import os\n", + "NOTEBOOK_DIR=os.path.dirname(os.path.abspath(\"__file__\"))\n", + "PROJECT_DIR=os.path.abspath(os.path.join(NOTEBOOK_DIR, \"../../../../\"))\n", + "TRAINER_DIR=os.path.abspath(os.path.join(PROJECT_DIR, \"./RWKV-v5/\"))\n", + "INFERENCE_DIR=os.path.abspath(os.path.join(PROJECT_DIR, \"./RWKV-v5/\"))\n", + "\n", + "print(\"NOTEBOOK_DIR:\", NOTEBOOK_DIR)\n", + "print(\"INFERENCE_DIR:\", INFERENCE_DIR)\n", + "print(\"TRAINER_DIR:\", TRAINER_DIR)\n", + "print(\"PROJECT_DIR:\", PROJECT_DIR)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "3e8278e4", + "metadata": { + "execution": { + "iopub.execute_input": "2023-09-29T05:56:34.877052Z", + "iopub.status.busy": "2023-09-29T05:56:34.876592Z", + "iopub.status.idle": "2023-09-29T05:57:04.417944Z", + "shell.execute_reply": "2023-09-29T05:57:04.417221Z" + }, + "papermill": { + "duration": 29.546096, + "end_time": "2023-09-29T05:57:04.419963", + "exception": false, + "start_time": "2023-09-29T05:56:34.873867", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[2023-09-29 05:56:37,806] [INFO] [real_accelerator.py:133:get_accelerator] Setting ds_accelerator to cuda (auto detect)\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[RWKV.model] Running RWKV model using 'torch-jit' with torch '2.0.1+cu118'\r\n", + "---- Initializing model ----\r\n", + "No of layers: 12\r\n", + "Embedding size: 2048\r\n", + "Output model path: ../model/v5-L12-D2048-E0_01-neox-v5base-init.pth\r\n", + "Vocab size: 50277\r\n", + "Emb scale: 0.01\r\n", + "Note: this process takes a significant time (and ram) for large models\r\n", + "---- ----- ----\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "50277 2048 -0.01 emb.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 1.0 blocks.0.att.gate.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 1.0 blocks.0.att.receptance.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 1.0 blocks.0.att.key.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 1.0 blocks.0.att.value.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 0 blocks.0.att.output.weight\r\n", + "7168 2048 1.0 blocks.0.ffn.key.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 0 blocks.0.ffn.receptance.weight\r\n", + "2048 7168 0 blocks.0.ffn.value.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 1.0 blocks.1.att.gate.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 1.0 blocks.1.att.receptance.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 1.0 blocks.1.att.key.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 1.0 blocks.1.att.value.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 0 blocks.1.att.output.weight\r\n", + "7168 2048 1.0 blocks.1.ffn.key.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 0 blocks.1.ffn.receptance.weight\r\n", + "2048 7168 0 blocks.1.ffn.value.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 1.0 blocks.2.att.gate.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 1.0 blocks.2.att.receptance.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 1.0 blocks.2.att.key.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 1.0 blocks.2.att.value.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 0 blocks.2.att.output.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "7168 2048 1.0 blocks.2.ffn.key.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 0 blocks.2.ffn.receptance.weight\r\n", + "2048 7168 0 blocks.2.ffn.value.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 1.0 blocks.3.att.gate.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 1.0 blocks.3.att.receptance.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 1.0 blocks.3.att.key.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 1.0 blocks.3.att.value.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 0 blocks.3.att.output.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "7168 2048 1.0 blocks.3.ffn.key.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 0 blocks.3.ffn.receptance.weight\r\n", + "2048 7168 0 blocks.3.ffn.value.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 1.0 blocks.4.att.gate.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 1.0 blocks.4.att.receptance.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 1.0 blocks.4.att.key.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 1.0 blocks.4.att.value.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 0 blocks.4.att.output.weight\r\n", + "7168 2048 1.0 blocks.4.ffn.key.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 0 blocks.4.ffn.receptance.weight\r\n", + "2048 7168 0 blocks.4.ffn.value.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 1.0 blocks.5.att.gate.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 1.0 blocks.5.att.receptance.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 1.0 blocks.5.att.key.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 1.0 blocks.5.att.value.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 0 blocks.5.att.output.weight\r\n", + "7168 2048 1.0 blocks.5.ffn.key.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 0 blocks.5.ffn.receptance.weight\r\n", + "2048 7168 0 blocks.5.ffn.value.weight\r\n", + "2048 2048 1.0 blocks.6.att.gate.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 1.0 blocks.6.att.receptance.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 1.0 blocks.6.att.key.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 1.0 blocks.6.att.value.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 0 blocks.6.att.output.weight\r\n", + "7168 2048 1.0 blocks.6.ffn.key.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 0 blocks.6.ffn.receptance.weight\r\n", + "2048 7168 0 blocks.6.ffn.value.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 1.0 blocks.7.att.gate.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 1.0 blocks.7.att.receptance.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 1.0 blocks.7.att.key.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 1.0 blocks.7.att.value.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 0 blocks.7.att.output.weight\r\n", + "7168 2048 1.0 blocks.7.ffn.key.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 0 blocks.7.ffn.receptance.weight\r\n", + "2048 7168 0 blocks.7.ffn.value.weight\r\n", + "2048 2048 1.0 blocks.8.att.gate.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 1.0 blocks.8.att.receptance.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 1.0 blocks.8.att.key.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 1.0 blocks.8.att.value.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 0 blocks.8.att.output.weight\r\n", + "7168 2048 1.0 blocks.8.ffn.key.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 0 blocks.8.ffn.receptance.weight\r\n", + "2048 7168 0 blocks.8.ffn.value.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 1.0 blocks.9.att.gate.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 1.0 blocks.9.att.receptance.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 1.0 blocks.9.att.key.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 1.0 blocks.9.att.value.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 0 blocks.9.att.output.weight\r\n", + "7168 2048 1.0 blocks.9.ffn.key.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 0 blocks.9.ffn.receptance.weight\r\n", + "2048 7168 0 blocks.9.ffn.value.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 1.0 blocks.10.att.gate.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 1.0 blocks.10.att.receptance.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 1.0 blocks.10.att.key.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 1.0 blocks.10.att.value.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 0 blocks.10.att.output.weight\r\n", + "7168 2048 1.0 blocks.10.ffn.key.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 0 blocks.10.ffn.receptance.weight\r\n", + "2048 7168 0 blocks.10.ffn.value.weight\r\n", + "2048 2048 1.0 blocks.11.att.gate.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 1.0 blocks.11.att.receptance.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 1.0 blocks.11.att.key.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 1.0 blocks.11.att.value.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 0 blocks.11.att.output.weight\r\n", + "7168 2048 1.0 blocks.11.ffn.key.weight\r\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2048 2048 0 blocks.11.ffn.receptance.weight\r\n", + "2048 7168 0 blocks.11.ffn.value.weight\r\n", + "50277 2048 0.5 head.weight\r\n" + ] + } + ], + "source": [ + "# Init the model\n", + "!cd \"{TRAINER_DIR}\" && \\\n", + " python3 ./init_model.py \\\n", + " --n_layer {LAYER_COUNT} --n_embd {EMBED_SIZE} \\\n", + " --emb-scale \"{EMBED_SCALE}\" \\\n", + " --vocab_size neox --skip-if-exists \\\n", + " \"../model/{FILENAME_PREFIX}-neox-v5base-init.pth\"" + ] + }, + { + "cell_type": "markdown", + "id": "09f5efbe", + "metadata": { + "papermill": { + "duration": 0.004414, + "end_time": "2023-09-29T05:57:04.432709", + "exception": false, + "start_time": "2023-09-29T05:57:04.428295", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "## Enwiki Stage 1 : Foundation 4k model training" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "0bc8ec73", + "metadata": { + "execution": { + "iopub.execute_input": "2023-09-29T05:57:04.441904Z", + "iopub.status.busy": "2023-09-29T05:57:04.441548Z", + "iopub.status.idle": "2023-09-29T06:02:22.847012Z", + "shell.execute_reply": "2023-09-29T06:02:22.846216Z" + }, + "papermill": { + "duration": 318.412248, + "end_time": "2023-09-29T06:02:22.848898", + "exception": false, + "start_time": "2023-09-29T05:57:04.436650", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\r", + "Map (num_proc=16): 0%| | 0/1000000 [00:00\r\n", + " model = SimpleRWKV(MODEL_PATH, device=DEVICE)\r\n", + " File \"/actions-runner/_work/RWKV-infctx-trainer/RWKV-infctx-trainer/RWKV-v5/src/model.py\", line 1420, in __init__\r\n", + " self.model = RWKV(**model_config)\r\n", + " File \"/actions-runner/_work/RWKV-infctx-trainer/RWKV-infctx-trainer/RWKV-v5/src/model.py\", line 566, in __init__\r\n", + " raise ValueError(f\"load_model file '{load_model}' does not exist\")\r\n", + "ValueError: load_model file '../model/v5-L12-D2048-E0_01-enwiki-4k-p1.pth' does not exist\r\n" + ] + } + ], + "source": [ + "# # Lets do a quick dragon prompt validation\n", + "!cd \"{INFERENCE_DIR}\" && \\\n", + " python3 dragon_test.py \"../model/{FILENAME_PREFIX}-enwiki-4k-p1.pth\" \"cuda fp32\"" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + }, + "papermill": { + "default_parameters": {}, + "duration": 354.822024, + "end_time": "2023-09-29T06:02:28.105571", + "environment_variables": {}, + "exception": null, + "input_path": "/actions-runner/_work/RWKV-infctx-trainer/RWKV-infctx-trainer/notebook/experiment/rwkv-x-exp/multi-size-train/v5-L12-D2048-part1.ipynb", + "output_path": "/actions-runner/_work/RWKV-infctx-trainer/RWKV-infctx-trainer/output/experiment/rwkv-x-exp/multi-size-train/v5-L12-D2048-part1.ipynb", + "parameters": {}, + "start_time": "2023-09-29T05:56:33.283547", + "version": "2.4.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file