File size: 2,866 Bytes
f4fac26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
# Plot the training loss recorded in trainer log files
from datetime import datetime
import numpy as np
import pandas as pd 

from matplotlib import pyplot as plt

import sys 
sys.path.extend(['.', '..'])

from config import PROJECT_ROOT

def str_to_timestamp(string: str) -> float:
    """Convert a log-style date string to a POSIX timestamp.

    Square brackets and surrounding whitespace are removed before parsing,
    so ``'[2023-10-01 08:44:39.303]'`` and ``'2023-10-01 08:44:39.303'`` are
    both accepted. The fractional-seconds part is optional.

    The text is parsed into a naive ``datetime``; ``datetime.timestamp()``
    treats naive values as local time.

    Args:
        string: Date text formatted as ``YYYY-mm-dd HH:MM:SS[.ffffff]``,
            optionally wrapped in ``[...]``.

    Returns:
        Seconds since the epoch as a float.

    Raises:
        ValueError: If the text matches neither supported layout.
    """
    cleaned = string.replace('[', '').replace(']', '').strip()

    # '%f' requires a fractional part, so fall back to a plain seconds
    # format when the input has none.
    if '.' in cleaned:
        date_fmt = '%Y-%m-%d %H:%M:%S.%f'
    else:
        date_fmt = '%Y-%m-%d %H:%M:%S'

    # Convert to a timestamp
    return datetime.strptime(cleaned, date_fmt).timestamp()

def _read_loss_records(log_file: str, start_timestamp: float, end_timestamp: float) -> list:
    """Collect ``[epoch, step, loss, device]`` rows logged inside the time window."""
    records = []
    with open(log_file, 'r', encoding='utf-8') as f:
        for line in f:
            if 'training loss: epoch:' not in line:
                continue

            # Drop the line ending up front so the last field is parsed the
            # same way whether or not the file ends with a newline.
            fields = line.rstrip('\n').split(' ')

            # The first two space-separated tokens form the bracketed date;
            # parse it once and reuse it for both bounds.
            timestamp = str_to_timestamp(' '.join(fields[0:2]))
            if timestamp < start_timestamp:
                continue

            if timestamp > end_timestamp:
                # Reading stops at the first matching entry past the window.
                break

            # Only lines with exactly nine tokens carry all expected fields.
            if len(fields) != 9:
                continue

            epoch = fields[5][6:-1]          # 'epoch:0,'
            step = fields[6][5:-1]           # 'step:0,'
            loss = float(fields[7][5:-1])    # 'loss:0.11086619377136231,'
            device = fields[8][7:]           # text after the 7-character field prefix
            records.append([epoch, step, loss, device])

    return records


def plot_traing_loss(log_file: str, start_date: str, end_date: str, pic_save_to_file: str=None) -> None:
    """Plot the training loss recorded in a log file, optionally saving the figure.

    Because the log contains many other messages, the caller specifies the
    start and end time of the loss entries to plot. The raw loss is drawn in
    green together with a polynomial fit (cubic when enough samples exist)
    in red.

    Example::

        plot_traing_loss('./logs/trainer.log', '[2023-10-01 08:44:39.303]', '[2023-10-01 11:29:12.376]')
        plot_traing_loss('./logs/trainer.log', '2023-10-01 08:44:39.303', '2023-10-01 11:29:12.376')

    Args:
        log_file: Path of the log file to read (UTF-8).
        start_date: Earliest entry time to include, with or without brackets.
        end_date: Latest entry time to include, with or without brackets.
        pic_save_to_file: When not ``None``, the figure is also saved to
            this path before being shown.

    Raises:
        ValueError: If a date cannot be parsed, or if no loss entries are
            found inside the requested window.
    """
    start_timestamp = str_to_timestamp(start_date)
    end_timestamp = str_to_timestamp(end_date)

    loss_list = _read_loss_records(log_file, start_timestamp, end_timestamp)
    if not loss_list:
        # Fail with a clear message instead of an opaque error from np.polyfit.
        raise ValueError(
            f'no training loss entries found in {log_file!r} '
            f'between {start_date!r} and {end_date!r}'
        )

    df = pd.DataFrame(loss_list, columns=['epoch', 'step', 'loss', 'device'])

    # Polynomial fit
    n_points = len(df['loss'])
    x = list(range(0, n_points))
    x_range = np.arange(0, n_points, step=0.005)
    # A degree-3 fit needs at least four samples; use a lower degree otherwise.
    degree = min(3, n_points - 1)
    fit_coeffs = np.polyfit(x, df['loss'], degree)
    p1d = np.poly1d(fit_coeffs)
    y_fit = p1d(x_range)

    plt.figure(figsize=(8, 6), dpi=100)
    plt.plot(df['loss'], 'g', label='loss')
    plt.plot(x_range, y_fit, 'r', label='fit loss')
    plt.ylabel('loss')
    plt.xlabel('sampling step')
    plt.legend()        # show the legend for the two curves

    if pic_save_to_file is not None:
        plt.savefig(pic_save_to_file)

    plt.show()


if __name__ == '__main__':
    
    # Example invocations, kept commented out. Uncomment one and adjust the
    # log path, time window and output image path to produce a plot.
    # plot_traing_loss(PROJECT_ROOT + '/logs/chat_trainer-20231011.log', '[2023-10-11 11:04:53.960]', '[2023-10-18 01:41:40.540]', pic_save_to_file=PROJECT_ROOT + '/img/train_loss.png')

    # plot_traing_loss(PROJECT_ROOT + '/logs/chat_trainer-20231018.log', '[2023-10-18 02:06:28.137]', '[2023-10-18 18:03:35.230]', pic_save_to_file=PROJECT_ROOT + '/img/finetune_loss.png')

    # Nothing runs by default.
    pass