File size: 1,170 Bytes
da90469 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
from typing import List, Tuple
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
import pandas as pd
def _drop_na_points(x, y):
    """Return ``(xs, ys)`` lists keeping only the points whose y value is not NA.

    Always returns two lists (possibly empty), so callers can unpack the result
    even when every y value is NA.
    """
    kept = [(x_i, y_i) for x_i, y_i in zip(x, y) if pd.notna(y_i)]
    return [x_i for x_i, _ in kept], [y_i for _, y_i in kept]


def plot_metrics(
        metrics: (List[Tuple[pd.Series, pd.Series, str]]
                  | List[List[Tuple[pd.Series, pd.Series, str]]]),
        remove_na: bool = False,
        subAxes: Axes | None = None,
        title: str | None = None,
        xlabel: str | None = None,
        ylabel: str | None = None,
        figsize=(8, 6)):
    """Plot one or more metric series with matplotlib.

    Each entry of ``metrics`` is either
      * a tuple ``(x, y, legend)`` or ``(x, y, legend, color)``, drawn as one
        line (``color`` defaults to ``'blue'``), or
      * a list of such tuples, all drawn on the same axes.

    When ``subAxes`` is None, a new figure is created with one subplot row per
    entry of ``metrics``. Otherwise every entry is drawn on ``subAxes``.

    Args:
        metrics: The series to plot, as described above.
        remove_na: If True, drop points whose y value is NA (``pd.notna``)
            before plotting. A series that is entirely NA plots as an empty line.
        subAxes: Existing axes to draw on instead of creating a new figure.
        title: Title for the target axes: ``subAxes`` if given, otherwise the
            last (bottom) subplot of the new figure. ``None`` leaves the
            existing title untouched.
        xlabel: X-axis label for the target axes; ``None`` leaves it untouched.
        ylabel: Y-axis label for the target axes; ``None`` leaves it untouched.
        figsize: ``(width, height per row)`` of the new figure. Ignored when
            ``subAxes`` is given.
    """
    if subAxes is None:
        n_rows = len(metrics)
        _, axes = plt.subplots(n_rows, 1, figsize=(figsize[0], figsize[1] * n_rows))
        # With the default squeeze=True, plt.subplots returns a bare Axes for a
        # single row and a 1-D array of Axes for several rows.
        axes_list = list(axes) if n_rows > 1 else [axes]
        target = axes_list[-1]
    else:
        axes_list = [subAxes] * len(metrics)
        target = subAxes

    for ax, metric in zip(axes_list, metrics):
        if isinstance(metric, tuple):
            x, y, legend = metric[0:3]
            color = metric[3] if len(metric) > 3 else 'blue'
            if remove_na:
                x, y = _drop_na_points(x, y)
            ax.plot(x, y, color=color, label=legend)
            ax.legend()
        else:
            # A nested group: draw all of its tuples onto this same axes.
            plot_metrics(metric, remove_na, ax)

    # Label the axes we actually drew on, rather than pyplot's "current" axes,
    # and skip None so nested calls don't wipe labels set by the caller.
    if title is not None:
        target.set_title(title)
    if xlabel is not None:
        target.set_xlabel(xlabel)
    if ylabel is not None:
        target.set_ylabel(ylabel)