File size: 1,170 Bytes
da90469
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from typing import List, Tuple
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
import pandas as pd

def plot_metrics(
        metrics: List[Tuple[pd.Series, pd.Series, str]] | List[List[Tuple[pd.Series, pd.Series, str]]],
        remove_na: bool = False,
        subAxes: Axes | None = None,
        title: str | None = None,
        xlabel: str | None = None,
        ylabel: str | None = None,
        figsize=(8, 6)):
    """Plot one or more metric series with matplotlib.

    Each top-level entry of ``metrics`` is either

    * a tuple ``(x, y, legend[, color])`` drawn as one line (color defaults
      to ``'blue'``), or
    * a list of such entries (nesting allowed), all overlaid on one axes.

    When ``subAxes`` is None, a new figure is created with one subplot per
    top-level entry, stacked vertically. The figure height is ``figsize[1]``
    times the number of entries. When ``subAxes`` is given, every entry is
    drawn onto that axes and no figure is created.

    Args:
        metrics: Series tuples, or lists of them, as described above.
        remove_na: Drop points whose y value is NA before plotting. x and y
            are paired by position, not by index label.
        subAxes: Existing axes to draw into instead of creating a figure.
        title: Axes title. ``None`` leaves any existing title untouched.
        xlabel: x-axis label. ``None`` leaves any existing label untouched.
        ylabel: y-axis label. ``None`` leaves any existing label untouched.
        figsize: (width, height) of a single subplot, passed to
            ``plt.subplots`` after the height is scaled.

    The title and labels go on ``subAxes`` when it is given. Otherwise they
    go on the last (bottom) subplot.
    """
    if subAxes is None:
        # squeeze=False always returns a 2-D array, so one metric needs no
        # special case when indexing the created axes.
        _, grid = plt.subplots(len(metrics), 1, squeeze=False,
                               figsize=(figsize[0], figsize[1] * len(metrics)))
        targets = [row[0] for row in grid]
    else:
        targets = [subAxes] * len(metrics)

    for ax, metric in zip(targets, metrics):
        if isinstance(metric, tuple):
            x, y, legend = metric[0:3]
            color = metric[3] if len(metric) > 3 else 'blue'
            if remove_na:
                x, y = _drop_na_points(x, y)
            ax.plot(x, y, color=color, label=legend)
            ax.legend()
        else:
            # A list of entries: overlay every series on this same axes.
            plot_metrics(metric, remove_na, ax)

    # Target the axes explicitly rather than pyplot's "current" axes, which
    # may not be subAxes. The last subplot matches where pyplot's current
    # axes points right after plt.subplots.
    # NOTE(review): for multi-subplot figures, a title on the bottom subplot
    # is unusual; fig.suptitle may be what was intended.
    label_ax = subAxes if subAxes is not None else targets[-1]
    # Skip None values: matplotlib converts None to '', which would blank
    # text set earlier (e.g. by the caller of a recursive call).
    if title is not None:
        label_ax.set_title(title)
    if xlabel is not None:
        label_ax.set_xlabel(xlabel)
    if ylabel is not None:
        label_ax.set_ylabel(ylabel)


def _drop_na_points(x, y):
    """Return (xs, ys) lists keeping only positions where y is not NA.

    x and y are paired positionally. An all-NA y yields two empty lists
    instead of failing to unpack.
    """
    kept = [(x_i, y_i) for x_i, y_i in zip(x, y) if pd.notna(y_i)]
    return [p[0] for p in kept], [p[1] for p in kept]