{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# Plotting and data-frame libraries used by the training-log charts in this notebook.\n", "import matplotlib.pyplot as plt\n", "import pandas as pd\n", "import seaborn as sns" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Relative paths of the two CSV training logs that the cells below load with pandas.\n", "sft_log_file = '../logs/sft_train_log_20231211-2250.csv'\n", "dpo_log_file = '../logs/dpo_train_log_20231213-0214.csv'" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Unnamed: 0epochlearning_ratelossstep
000.001.400000e-082.59861
110.001.400000e-062.6353100
220.012.800000e-062.4905200
330.014.200000e-062.3610300
440.015.600000e-062.2837400
\n", "
" ], "text/plain": [ " Unnamed: 0 epoch learning_rate loss step\n", "0 0 0.00 1.400000e-08 2.5986 1\n", "1 1 0.00 1.400000e-06 2.6353 100\n", "2 2 0.01 2.800000e-06 2.4905 200\n", "3 3 0.01 4.200000e-06 2.3610 300\n", "4 4 0.01 5.600000e-06 2.2837 400" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Read both training logs into DataFrames, then preview the first rows of the SFT log.\n", "sft_df = pd.read_csv(sft_log_file)\n", "dpo_df = pd.read_csv(dpo_log_file)\n", "sft_df.head()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# SFT learning rate against training step; this chart is shown inline only (no title, not saved).\n", "sns.lineplot(data=sft_df, x=\"step\", y=\"learning_rate\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# SFT loss against training step, written to ../img/sft_loss.png.\n", "# The title is set first; seaborn then draws onto the current matplotlib axes.\n", "plt.title('sft loss')\n", "sns.lineplot(data=sft_df, x=\"step\", y=\"loss\", color='dodgerblue')\n", "plt.savefig('../img/sft_loss.png')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Unnamed: 0epochlearning_ratelogits/chosenlogits/rejectedlogps/chosenlogps/rejectedlossrewards/accuraciesrewards/chosenrewards/marginsrewards/rejectedstep
000.001.000000e-08-3.525447-3.550683-256.702698-143.3082430.76890.437500-0.044875-0.0728440.0279691
110.012.000000e-07-3.509013-3.557282-270.281708-150.8504330.74380.4868420.002034-0.0201940.02222820
220.014.000000e-07-3.509622-3.544898-286.783966-162.9469150.70380.5296880.0242290.046643-0.02241440
330.026.000000e-07-3.521220-3.554179-267.424896-151.9845730.72180.5078120.0049730.008775-0.00380360
440.038.000000e-07-3.513215-3.551011-281.538208-157.7845460.69950.5484370.0571790.069537-0.01235880
\n", "
" ], "text/plain": [ " Unnamed: 0 epoch learning_rate logits/chosen logits/rejected \\\n", "0 0 0.00 1.000000e-08 -3.525447 -3.550683 \n", "1 1 0.01 2.000000e-07 -3.509013 -3.557282 \n", "2 2 0.01 4.000000e-07 -3.509622 -3.544898 \n", "3 3 0.02 6.000000e-07 -3.521220 -3.554179 \n", "4 4 0.03 8.000000e-07 -3.513215 -3.551011 \n", "\n", " logps/chosen logps/rejected loss rewards/accuracies rewards/chosen \\\n", "0 -256.702698 -143.308243 0.7689 0.437500 -0.044875 \n", "1 -270.281708 -150.850433 0.7438 0.486842 0.002034 \n", "2 -286.783966 -162.946915 0.7038 0.529688 0.024229 \n", "3 -267.424896 -151.984573 0.7218 0.507812 0.004973 \n", "4 -281.538208 -157.784546 0.6995 0.548437 0.057179 \n", "\n", " rewards/margins rewards/rejected step \n", "0 -0.072844 0.027969 1 \n", "1 -0.020194 0.022228 20 \n", "2 0.046643 -0.022414 40 \n", "3 0.008775 -0.003803 60 \n", "4 0.069537 -0.012358 80 " ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dpo_df.head(5)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Only checkpoints up to step 6000 were used; later steps show signs of overfitting,\n", "# so the curve is cut there. Filtering on the `step` column (instead of taking the\n", "# first 6000 // 20 rows) keeps the step-6000 row and does not assume a fixed logging interval.\n", "DPO_MAX_STEP = 6000\n", "\n", "plt.title('dpo loss')\n", "sns.lineplot(\n", "    x=\"step\",\n", "    y=\"loss\",\n", "    color='orange',\n", "    data=dpo_df[dpo_df['step'] <= DPO_MAX_STEP],\n", ")\n", "plt.savefig('../img/dpo_loss.png')" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import os\n", "import sys\n", "\n", "import torch\n", "\n", "# Put the parent of the current working directory on sys.path so that `model.infer`\n", "# and `config` (imported below) can be found; presumably they live in that directory.\n", "root = os.path.dirname(os.path.realpath('.'))\n", "if root not in sys.path:  # keep re-runs of this cell from adding duplicate entries\n", "    sys.path.append(root)\n", "\n", "from model.infer import ChatBot\n", "from config import InferConfig\n", "\n", "# Build the chat bot from its inference config and keep a handle on the underlying model.\n", "bot = ChatBot(InferConfig())\n", "model = bot.model" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "model parameters size: 210.19 M = 0.21B\n", "GPU memory used: 0.40GB\n" ] } ], "source": [ "# Total parameter count, expressed in millions (decimal: 1 M = 1000 * 1000).\n", "param_size = sum(p.numel() for p in model.parameters()) / 1000 / 1000\n", "print('model parameters size: {:.2f} M = {:.2f}B'.format(param_size, param_size / 1000))\n", "\n", "# CUDA memory currently allocated on the current device, reported in units of 1024 ** 3 bytes.\n", "print('GPU memory used: {:.2f}GB'.format(torch.cuda.memory_allocated() / (1024 ** 3)))" ] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 2 }