from typing import List, Tuple

import matplotlib.pyplot as plt
from matplotlib.axes import Axes
import pandas as pd


def _plot_single(ax, metric, remove_na):
    """Draw one ``(x, y, legend[, color])`` tuple on ``ax`` and refresh its legend."""
    x, y, legend = metric[0:3]
    color = metric[3] if len(metric) > 3 else 'blue'
    if remove_na:
        kept = [(x_1, y_1) for x_1, y_1 in zip(x, y) if pd.notna(y_1)]
        # Split back into two sequences explicitly: an all-NA series must give
        # two empty sequences rather than failing to unpack an empty ``zip()``.
        x = [x_1 for x_1, _ in kept]
        y = [y_1 for _, y_1 in kept]
    ax.plot(x, y, color=color, label=legend)
    ax.legend()


def plot_metrics(
    metrics: List[Tuple[pd.Series, pd.Series, str]]
    | List[List[Tuple[pd.Series, pd.Series, str]]],
    remove_na: bool = False,
    subAxes: Axes | None = None,
    title: str | None = None,
    xlabel: str | None = None,
    ylabel: str | None = None,
    figsize=(8, 6),
):
    """Plot one or more metric series as line charts.

    Each element of ``metrics`` is either:

    * a tuple ``(x, y, legend)`` or ``(x, y, legend, color)``, drawn as one
      line on its own subplot (``color`` defaults to ``'blue'``), or
    * a list of such tuples, all drawn together on one shared subplot.

    Args:
        metrics: The series to plot, structured as described above.
        remove_na: If True, drop points whose ``y`` value is NA (according to
            ``pd.notna``) before plotting.
        subAxes: Existing axes to draw every metric on. When None, a new
            figure is created with one row of axes per element of ``metrics``.
        title: Title to set. It is applied to ``subAxes`` when given,
            otherwise to the last (bottom) subplot of the new figure. ``None``
            leaves the existing title untouched.
        xlabel: X-axis label, applied to the same axes as ``title``.
        ylabel: Y-axis label, applied to the same axes as ``title``.
        figsize: ``(width, height)`` of a single subplot. The created figure's
            height is multiplied by the number of subplots.

    Raises:
        ValueError: If ``metrics`` is empty and no ``subAxes`` is supplied, or
            if a metric tuple has fewer than three elements.
    """
    if subAxes is None:
        if not metrics:
            raise ValueError("metrics must contain at least one entry")
        _, grid = plt.subplots(
            len(metrics),
            1,
            figsize=(figsize[0], figsize[1] * len(metrics)),
            squeeze=False,  # always 2-D, so a single row needs no special case
        )
        targets = list(grid[:, 0])
    else:
        targets = [subAxes] * len(metrics)

    for ax, metric in zip(targets, metrics):
        if isinstance(metric, tuple):
            _plot_single(ax, metric, remove_na)
        else:
            # A list of tuples is overlaid on a single shared axes.
            plot_metrics(metric, remove_na, ax)

    # Label explicit axes instead of whatever ``plt.gca()`` happens to be, so
    # a caller-supplied ``subAxes`` receives the labels and recursive calls
    # (which pass no labels) do not wipe existing ones.
    label_ax = subAxes if subAxes is not None else targets[-1]
    if title is not None:
        label_ax.set_title(title)
    if xlabel is not None:
        label_ax.set_xlabel(xlabel)
    if ylabel is not None:
        label_ax.set_ylabel(ylabel)