|
from typing import List, Tuple |
|
import matplotlib.pyplot as plt |
|
from matplotlib.axes import Axes |
|
import pandas as pd |
|
|
|
def plot_metrics(
    metrics: List[Tuple[pd.Series, pd.Series, str]] | List[List[Tuple[pd.Series, pd.Series, str]]],
    remove_na: bool = False,
    subAxes: Axes | None = None,
    title: str | None = None,
    xlabel: str | None = None,
    ylabel: str | None = None,
    figsize: Tuple[float, float] = (8, 6),
):
    """Plot metric series, one vertically stacked subplot per top-level entry.

    Each metric is a tuple ``(x, y, legend)`` or ``(x, y, legend, color)``;
    ``color`` defaults to ``'blue'``. A top-level entry that is a list (of
    metric tuples or, recursively, further lists) is drawn as several lines
    on one shared subplot.

    Args:
        metrics: Metric tuples and/or lists of metric tuples.
        remove_na: If true, drop every point whose y value is NA (as judged
            by ``pd.notna``) before plotting. x and y are paired by
            position, not by index label.
        subAxes: Existing axes to draw every entry on. When given, no new
            figure is created.
        title: Title for ``subAxes`` if given, otherwise for the bottom
            subplot. ``None`` leaves the existing title untouched.
        xlabel: x-axis label, applied to the same axes as ``title``.
        ylabel: y-axis label, applied to the same axes as ``title``.
        figsize: ``(width, height)`` of a single subplot; the figure height
            is multiplied by the number of top-level entries.
    """
    if subAxes is None:
        if not metrics:
            # plt.subplots rejects a 0-row grid; with nothing to plot, do nothing.
            return
        _, axes = plt.subplots(
            len(metrics), 1,
            figsize=(figsize[0], figsize[1] * len(metrics)),
            squeeze=False,  # always an (n, 1) array, even for a single subplot
        )
        targets = list(axes[:, 0])
        # Labels go on the bottom subplot: that is the axes plt.subplots
        # leaves "current", where pyplot-level labelling used to put them.
        label_ax = targets[-1]
    else:
        targets = [subAxes] * len(metrics)
        label_ax = subAxes

    for ax, metric in zip(targets, metrics):
        _plot_metric(ax, metric, remove_na)

    # Label an explicit Axes instead of calling plt.title/xlabel/ylabel,
    # which act on whichever axes is current and need not be the one drawn on.
    if title is not None:
        label_ax.set_title(title)
    if xlabel is not None:
        label_ax.set_xlabel(xlabel)
    if ylabel is not None:
        label_ax.set_ylabel(ylabel)


def _plot_metric(ax: Axes, metric, remove_na: bool) -> None:
    """Draw one metric tuple, or recursively a (nested) list of them, on ``ax``."""
    if not isinstance(metric, tuple):
        # A list of metrics: every member shares the same axes.
        for sub_metric in metric:
            _plot_metric(ax, sub_metric, remove_na)
        return

    x, y, legend = metric[:3]
    color = metric[3] if len(metric) > 3 else 'blue'
    if remove_na:
        x, y = _drop_na_points(x, y)
    ax.plot(x, y, color=color, label=legend)
    ax.legend()


def _drop_na_points(x, y):
    """Return ``(xs, ys)`` lists keeping only positions where ``y`` is not NA.

    Pairs are formed positionally via ``zip``. An all-NA input yields two
    empty lists rather than raising.
    """
    kept = [(x_i, y_i) for x_i, y_i in zip(x, y) if pd.notna(y_i)]
    return [x_i for x_i, _ in kept], [y_i for _, y_i in kept]