{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import seaborn as sns \n",
"from matplotlib import pyplot as plt\n",
"import pandas as pd"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Relative paths to the SFT and DPO training-log CSV files.\n",
"sft_log_file, dpo_log_file = (\n",
"    '../logs/sft_train_log_20231211-2250.csv',\n",
"    '../logs/dpo_train_log_20231213-0214.csv',\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
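"<div>\n",
"<style scoped>\n",
"    .dataframe tbody tr th:only-of-type {\n",
"        vertical-align: middle;\n",
"    }\n",
"\n",
"    .dataframe tbody tr th {\n",
"        vertical-align: top;\n",
"    }\n",
"\n",
"    .dataframe thead th {\n",
"        text-align: right;\n",
"    }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\">\n",
"      <th></th>\n",
"      <th>Unnamed: 0</th>\n",
"      <th>epoch</th>\n",
"      <th>learning_rate</th>\n",
"      <th>loss</th>\n",
"      <th>step</th>\n",
"    </tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr>\n",
"      <th>0</th>\n",
"      <td>0</td>\n",
"      <td>0.00</td>\n",
"      <td>1.400000e-08</td>\n",
"      <td>2.5986</td>\n",
"      <td>1</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>1</th>\n",
"      <td>1</td>\n",
"      <td>0.00</td>\n",
"      <td>1.400000e-06</td>\n",
"      <td>2.6353</td>\n",
"      <td>100</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>2</th>\n",
"      <td>2</td>\n",
"      <td>0.01</td>\n",
"      <td>2.800000e-06</td>\n",
"      <td>2.4905</td>\n",
"      <td>200</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>3</th>\n",
"      <td>3</td>\n",
"      <td>0.01</td>\n",
"      <td>4.200000e-06</td>\n",
"      <td>2.3610</td>\n",
"      <td>300</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>4</th>\n",
"      <td>4</td>\n",
"      <td>0.01</td>\n",
"      <td>5.600000e-06</td>\n",
"      <td>2.2837</td>\n",
"      <td>400</td>\n",
"    </tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>"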
],
"text/plain": [
" Unnamed: 0 epoch learning_rate loss step\n",
"0 0 0.00 1.400000e-08 2.5986 1\n",
"1 1 0.00 1.400000e-06 2.6353 100\n",
"2 2 0.01 2.800000e-06 2.4905 200\n",
"3 3 0.01 4.200000e-06 2.3610 300\n",
"4 4 0.01 5.600000e-06 2.2837 400"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Load the SFT and DPO training logs into DataFrames, then preview the SFT one.\n",
"sft_df, dpo_df = (pd.read_csv(log_path) for log_path in (sft_log_file, dpo_log_file))\n",
"sft_df.head(n=5)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Learning rate against training step for the SFT run (the plot is left untitled).\n",
"sns.lineplot(data=sft_df, x=\"step\", y=\"learning_rate\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# SFT loss against training step; the figure is also saved to ../img/sft_loss.png.\n",
"sft_loss_ax = sns.lineplot(data=sft_df, x=\"step\", y=\"loss\", color='dodgerblue')\n",
"sft_loss_ax.set_title('sft loss')\n",
"sft_loss_ax.figure.savefig('../img/sft_loss.png')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
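"<div>\n",
"<style scoped>\n",
"    .dataframe tbody tr th:only-of-type {\n",
"        vertical-align: middle;\n",
"    }\n",
"\n",
"    .dataframe tbody tr th {\n",
"        vertical-align: top;\n",
"    }\n",
"\n",
"    .dataframe thead th {\n",
"        text-align: right;\n",
"    }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\">\n",
"      <th></th>\n",
"      <th>Unnamed: 0</th>\n",
"      <th>epoch</th>\n",
"      <th>learning_rate</th>\n",
"      <th>logits/chosen</th>\n",
"      <th>logits/rejected</th>\n",
"      <th>logps/chosen</th>\n",
"      <th>logps/rejected</th>\n",
"      <th>loss</th>\n",
"      <th>rewards/accuracies</th>\n",
"      <th>rewards/chosen</th>\n",
"      <th>rewards/margins</th>\n",
"      <th>rewards/rejected</th>\n",
"      <th>step</th>\n",
"    </tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr>\n",
"      <th>0</th>\n",
"      <td>0</td>\n",
"      <td>0.00</td>\n",
"      <td>1.000000e-08</td>\n",
"      <td>-3.525447</td>\n",
"      <td>-3.550683</td>\n",
"      <td>-256.702698</td>\n",
"      <td>-143.308243</td>\n",
"      <td>0.7689</td>\n",
"      <td>0.437500</td>\n",
"      <td>-0.044875</td>\n",
"      <td>-0.072844</td>\n",
"      <td>0.027969</td>\n",
"      <td>1</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>1</th>\n",
"      <td>1</td>\n",
"      <td>0.01</td>\n",
"      <td>2.000000e-07</td>\n",
"      <td>-3.509013</td>\n",
"      <td>-3.557282</td>\n",
"      <td>-270.281708</td>\n",
"      <td>-150.850433</td>\n",
"      <td>0.7438</td>\n",
"      <td>0.486842</td>\n",
"      <td>0.002034</td>\n",
"      <td>-0.020194</td>\n",
"      <td>0.022228</td>\n",
"      <td>20</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>2</th>\n",
"      <td>2</td>\n",
"      <td>0.01</td>\n",
"      <td>4.000000e-07</td>\n",
"      <td>-3.509622</td>\n",
"      <td>-3.544898</td>\n",
"      <td>-286.783966</td>\n",
"      <td>-162.946915</td>\n",
"      <td>0.7038</td>\n",
"      <td>0.529688</td>\n",
"      <td>0.024229</td>\n",
"      <td>0.046643</td>\n",
"      <td>-0.022414</td>\n",
"      <td>40</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>3</th>\n",
"      <td>3</td>\n",
"      <td>0.02</td>\n",
"      <td>6.000000e-07</td>\n",
"      <td>-3.521220</td>\n",
"      <td>-3.554179</td>\n",
"      <td>-267.424896</td>\n",
"      <td>-151.984573</td>\n",
"      <td>0.7218</td>\n",
"      <td>0.507812</td>\n",
"      <td>0.004973</td>\n",
"      <td>0.008775</td>\n",
"      <td>-0.003803</td>\n",
"      <td>60</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>4</th>\n",
"      <td>4</td>\n",
"      <td>0.03</td>\n",
"      <td>8.000000e-07</td>\n",
"      <td>-3.513215</td>\n",
"      <td>-3.551011</td>\n",
"      <td>-281.538208</td>\n",
"      <td>-157.784546</td>\n",
"      <td>0.6995</td>\n",
"      <td>0.548437</td>\n",
"      <td>0.057179</td>\n",
"      <td>0.069537</td>\n",
"      <td>-0.012358</td>\n",
"      <td>80</td>\n",
"    </tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>"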
],
"text/plain": [
" Unnamed: 0 epoch learning_rate logits/chosen logits/rejected \\\n",
"0 0 0.00 1.000000e-08 -3.525447 -3.550683 \n",
"1 1 0.01 2.000000e-07 -3.509013 -3.557282 \n",
"2 2 0.01 4.000000e-07 -3.509622 -3.544898 \n",
"3 3 0.02 6.000000e-07 -3.521220 -3.554179 \n",
"4 4 0.03 8.000000e-07 -3.513215 -3.551011 \n",
"\n",
" logps/chosen logps/rejected loss rewards/accuracies rewards/chosen \\\n",
"0 -256.702698 -143.308243 0.7689 0.437500 -0.044875 \n",
"1 -270.281708 -150.850433 0.7438 0.486842 0.002034 \n",
"2 -286.783966 -162.946915 0.7038 0.529688 0.024229 \n",
"3 -267.424896 -151.984573 0.7218 0.507812 0.004973 \n",
"4 -281.538208 -157.784546 0.6995 0.548437 0.057179 \n",
"\n",
" rewards/margins rewards/rejected step \n",
"0 -0.072844 0.027969 1 \n",
"1 -0.020194 0.022228 20 \n",
"2 0.046643 -0.022414 40 \n",
"3 0.008775 -0.003803 60 \n",
"4 0.069537 -0.012358 80 "
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Preview the first five rows of the DPO training log.\n",
"dpo_df.head(n=5)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Only the checkpoint up to step 6000 was used; later steps show signs of\n",
"# overfitting, so the curve is cut off there. The cutoff filters on the `step`\n",
"# column rather than slicing rows by position: a positional slice silently\n",
"# depends on the logging interval and, because the first logged row is step 1,\n",
"# drops the step-6000 row itself.\n",
"DPO_MAX_STEP = 6000\n",
"dpo_loss_df = dpo_df[dpo_df[\"step\"] <= DPO_MAX_STEP]\n",
"\n",
"plt.title('dpo loss')\n",
"sns.lineplot(\n",
"    x=\"step\",\n",
"    y=\"loss\",\n",
"    color='orange',\n",
"    data=dpo_loss_df,\n",
")\n",
"plt.savefig('../img/dpo_loss.png')"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import sys\n",
"\n",
"import torch\n",
"\n",
"# Put the parent of the current working directory on sys.path so the\n",
"# project-local `model` and `config` imports below resolve (presumably the\n",
"# parent is the project root - confirm if the notebook is moved).\n",
"# os.path.dirname replaces manual separator replacement and splitting, which\n",
"# produced an empty string when the working directory was a filesystem root.\n",
"root = os.path.dirname(os.path.realpath('.'))\n",
"if root not in sys.path:  # re-running this cell must not add duplicate entries\n",
"    sys.path.append(root)\n",
"\n",
"from model.infer import ChatBot\n",
"from config import InferConfig\n",
"\n",
"# Build the chatbot from the inference config and keep a handle on its model.\n",
"bot = ChatBot(InferConfig())\n",
"model = bot.model"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"model parameters size: 210.19 M = 0.21B\n",
"GPU memory used: 0.40GB\n"
]
}
],
"source": [
"# Total number of parameter elements in the model, expressed in millions.\n",
"total_param_count = sum(p.numel() for p in model.parameters())\n",
"param_size = total_param_count / 1000 / 1000\n",
"print('model parameters size: {:.2f} M = {:.2f}B'.format(param_size, param_size / 1000))\n",
"\n",
"# Bytes reported by torch.cuda.memory_allocated(), converted to GiB (1024 ** 3).\n",
"gpu_bytes_allocated = torch.cuda.memory_allocated()\n",
"print('GPU memory used: {:.2f}GB'.format(gpu_bytes_allocated / (1024 ** 3)))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "py310",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}