Spaces:
Runtime error
Runtime error
File size: 2,866 Bytes
f4fac26 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
# log 画图
from datetime import datetime
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import sys
sys.path.extend(['.', '..'])
from config import PROJECT_ROOT
def str_to_timestamp(string: str) -> float:
    '''
    Convert a log date string into a POSIX timestamp (seconds since the epoch).

    Every '[' and ']' character is removed before parsing. So both
    '[2023-10-01 08:44:39.303]' and '2023-10-01 08:44:39.303' are accepted.
    The remaining text must match '%Y-%m-%d %H:%M:%S.%f', otherwise
    datetime.strptime raises ValueError.

    The parsed datetime is naive, so .timestamp() interprets it in the local timezone.
    '''
    # Delete both bracket characters in a single pass.
    without_brackets = string.translate(str.maketrans('', '', '[]'))
    parsed = datetime.strptime(without_brackets, '%Y-%m-%d %H:%M:%S.%f')
    return parsed.timestamp()
def _parse_training_loss(log_file: str, start_timestamp: float, end_timestamp: float) -> pd.DataFrame:
    '''
    Collect the 'training loss' records logged within [start_timestamp, end_timestamp].

    A line is used when it contains 'training loss: epoch:' and splits on single spaces
    into exactly 9 tokens. Presumably it looks like
    '[YYYY-mm-dd HH:MM:SS.fff] <x> training loss: epoch:E, step:S, loss:L, device:D'
    (TODO confirm against the logger that writes it).

    Returns a DataFrame with columns ['epoch', 'step', 'loss', 'device']. 'loss' is a float;
    the other columns keep the raw text sliced out of the line.
    '''
    records = []
    with open(log_file, 'r', encoding='utf-8') as f:
        for raw_line in f:
            if 'training loss: epoch:' not in raw_line:
                continue
            # Strip the terminator first, so the last field parses the same way whether or not
            # the file's final line ends with '\n'.
            tokens = raw_line.rstrip('\n').split(' ')
            # The first two tokens form the date, e.g. '[2023-10-01' + '08:44:39.303]'.
            timestamp = str_to_timestamp(' '.join(tokens[0:2]))
            if timestamp < start_timestamp:
                continue
            if timestamp > end_timestamp:
                # Assumes the log is chronological: nothing after this line can be in range.
                break
            # Skip lines that contain the marker but do not have the expected 9 fields.
            if len(tokens) != 9:
                continue
            epoch = tokens[5][6:-1]        # 'epoch:0,' -> '0'
            step = tokens[6][5:-1]         # 'step:0,'  -> '0'
            loss = float(tokens[7][5:-1])  # drop the 'loss:' prefix and the trailing separator char
            device = tokens[8][7:]         # drop the 'device:' prefix
            records.append([epoch, step, loss, device])
    return pd.DataFrame(records, columns=['epoch', 'step', 'loss', 'device'])


def plot_traing_loss(log_file: str, start_date: str, end_date: str, pic_save_to_file: str=None) -> None:
    '''
    Plot the training loss recorded in a log file with a cubic trend line, optionally saving it.

    The log contains a lot of other output, so the time window whose loss records
    should be plotted must be given explicitly.

    Args:
        log_file: path to the UTF-8 encoded log file.
        start_date: start of the window, with or without surrounding brackets,
            in the format '%Y-%m-%d %H:%M:%S.%f'.
        end_date: end of the window, in the same format.
        pic_save_to_file: if not None, the figure is saved to this path before being shown.

    Raises:
        ValueError: if a date cannot be parsed, or if no loss records fall within the window.

    Example:
        plot_traing_loss('./logs/trainer.log', '[2023-10-01 08:44:39.303]', '[2023-10-01 11:29:12.376]')
        plot_traing_loss('./logs/trainer.log', '2023-10-01 08:44:39.303', '2023-10-01 11:29:12.376')
    '''
    start_timestamp = str_to_timestamp(start_date)
    end_timestamp = str_to_timestamp(end_date)
    df = _parse_training_loss(log_file, start_timestamp, end_timestamp)
    if df.empty:
        raise ValueError(
            f'no training loss records found in {log_file!r} between {start_date} and {end_date}'
        )

    n = len(df)
    plt.figure(figsize=(8, 6), dpi=100)
    plt.plot(df['loss'], 'g', label='loss')

    # Polynomial (cubic) trend line over the sampled loss values.
    fit_degree = 3
    # Fitting a cubic needs at least 4 points; with fewer, the fit is rank-deficient, so skip it.
    if n > fit_degree:
        coeffs = np.polyfit(np.arange(n), df['loss'], fit_degree)
        # A smooth cubic only needs a fixed number of evaluation points over the data range.
        # Sampling 200 points per record would reach millions of points on long logs.
        x_fit = np.linspace(0, n - 1, num=1000)
        plt.plot(x_fit, np.poly1d(coeffs)(x_fit), 'r', label='fit loss')

    plt.ylabel('loss')
    plt.xlabel('sampling step')
    plt.legend()
    if pic_save_to_file is not None:
        plt.savefig(pic_save_to_file)
    plt.show()
if __name__ == '__main__':
    # Script entry point. As written it does nothing; the calls below are kept as usage
    # examples. Uncomment one and adjust the log file, time window and output path
    # to plot a specific run.
    # plot_traing_loss(PROJECT_ROOT + '/logs/chat_trainer-20231011.log', '[2023-10-11 11:04:53.960]', '[2023-10-18 01:41:40.540]', pic_save_to_file=PROJECT_ROOT + '/img/train_loss.png')
    # plot_traing_loss(PROJECT_ROOT + '/logs/chat_trainer-20231018.log', '[2023-10-18 02:06:28.137]', '[2023-10-18 18:03:35.230]', pic_save_to_file=PROJECT_ROOT + '/img/finetune_loss.png')
    pass
|