File size: 2,293 Bytes
a3ffd31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import os
import json

def create_graph(lora_path, lora_name):
    """Render learning-rate and loss curves for a training run as a PNG.

    Reads ``training_graph.json`` from *lora_path*. The file must be a list of
    records, each holding ``epoch``, ``learning_rate`` and ``loss`` keys. The
    function writes ``training_graph.png`` into the same directory. Learning
    rate is drawn in blue on the left y-axis (scientific notation); loss is
    drawn in red on a second y-axis that shares the epoch x-axis.

    Args:
        lora_path: Directory containing ``training_graph.json``; the PNG is
            written there as well.
        lora_name: Label used in the chart title.

    Returns:
        None. Outcomes are reported with ``print``. A missing JSON file or a
        missing matplotlib installation is reported, not raised.

    Raises:
        KeyError: if a record lacks ``epoch``, ``learning_rate`` or ``loss``.
        json.JSONDecodeError: if ``training_graph.json`` is not valid JSON.
    """
    # matplotlib is an optional dependency: only its absence is treated as
    # "cannot plot". Other ImportErrors (e.g. raised later while saving) are
    # no longer misreported as a missing installation.
    try:
        from matplotlib.figure import Figure
        from matplotlib.ticker import ScalarFormatter
    except ImportError:
        print("matplotlib is not installed. Please install matplotlib to create PNG graphs")
        return

    peft_model_path = os.path.join(lora_path, 'training_graph.json')
    image_model_path = os.path.join(lora_path, 'training_graph.png')

    if not os.path.exists(peft_model_path):
        print(f"File 'training_graph.json' does not exist in the {lora_path}")
        return

    with open(peft_model_path, 'r', encoding='utf-8') as file:
        data = json.load(file)

    epochs = [item['epoch'] for item in data]
    learning_rates = [item['learning_rate'] for item in data]
    losses = [item['loss'] for item in data]

    # Use the object-oriented Figure API instead of pyplot. The figure is never
    # registered in pyplot's global figure manager, so repeated calls do not
    # accumulate open figures, and no interactive/GUI backend is needed.
    fig = Figure(figsize=(10, 6))
    ax1 = fig.subplots()

    # Learning rate on the primary (left) y-axis.
    ax1.plot(epochs, learning_rates, 'b-', label='Learning Rate')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Learning Rate', color='b')
    ax1.tick_params('y', colors='b')

    # Loss on a secondary (right) y-axis sharing the same epoch x-axis.
    ax2 = ax1.twinx()
    ax2.plot(epochs, losses, 'r-', label='Loss')
    ax2.set_ylabel('Loss', color='r')
    ax2.tick_params('y', colors='r')

    # Learning rates are small, so always show them in scientific notation
    # (scilimits=(0, 0) forces it regardless of magnitude).
    ax1.yaxis.set_major_formatter(ScalarFormatter(useMathText=True))
    ax1.ticklabel_format(style='sci', axis='y', scilimits=(0, 0))

    ax1.grid(True)

    # Each axis only knows its own line, so merge both into a single legend.
    # It is drawn on ax2 because the twin axes sit on top of ax1.
    lines, labels = ax1.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax2.legend(lines + lines2, labels + labels2, loc='best')

    ax1.set_title(f'{lora_name} LR and Loss vs Epoch')

    fig.savefig(image_model_path)
    print(f"Graph saved in {image_model_path}")