freemt
Update before sent-align
4c04f50
"""Plot pandas.DataFrame with DBSCAN clustering."""
# pylint: disable=invalid-name, too-many-arguments, unused-import
import numpy as np # noqa
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.cluster import DBSCAN
from logzero import logger # noqa
# from radiobee.cmat2tset import cmat2tset
# Disabled snippet, kept only as text bound to the throwaway name `_`
# (nothing inside the string is run as code). If re-enabled it would turn on
# matplotlib's interactive mode inside an IPython session and switch to the
# 'Agg' backend everywhere else.
_ = """
if "get_ipython" in globals():
plt.ion()
else:
plt.switch_backend('Agg')
# """
# fastlid.set_languages = ["en", "zh"]
# fmt: off
def plot_df(
    df_: pd.DataFrame,
    eps: float = 10,
    min_samples: int = 6,
    xlabel: str = "",
    ylabel: str = "",
    xlim: int = 0,
    ylim: int = 0,
    backend: str = "TkAgg",
) -> plt:
    # fmt: on
    """Scatter-plot a three-column frame, split by DBSCAN into clusters and outliers.

    Rows that DBSCAN assigns to a cluster are drawn coloured by their "cos"
    value; rows labelled as noise (label -1) are drawn as red crosses.

    Args:
        df_: DataFrame with exactly three columns, interpreted positionally
            as "x", "y" and "cos". The caller's frame is not modified.
        eps: ``eps`` passed to ``sklearn.cluster.DBSCAN``.
        min_samples: ``min_samples`` passed to ``sklearn.cluster.DBSCAN``.
        xlabel: x-axis label; defaults to ``str(xlim)`` when empty.
        ylabel: y-axis label; defaults to ``str(ylim)`` when empty.
        xlim: only used to build the default ``xlabel`` (axis limits are not
            applied); defaults to ``len(df_)`` when falsy.
        ylim: only used to build the default ``ylabel`` (axis limits are not
            applied); defaults to the maximum of column "y" when falsy.
        backend: matplotlib backend used while drawing; the previously
            active backend is restored before returning.

    Returns:
        The ``matplotlib.pyplot`` module, for possible use in gradio.

    Raises:
        ValueError: if ``df_`` does not have exactly three columns.

    Example (``cmat2tset`` and ``smat`` come from the calling project)::

        plot_df(pd.DataFrame(cmat2tset(smat), columns=['x', 'y', 'cos']))

    Note:
        Sorting the rows beforehand, e.g.
        ``df_.sort_values('x', axis=0, ignore_index=True)``, does not seem
        to impact the clustering.
    """
    if df_.shape[1] != 3:
        logger.error(" shape mismatch: %s, expected (x, 3)", df_.shape)
        # ValueError is a subclass of Exception, so callers that catch
        # Exception keep working.
        raise ValueError(" df_.shape[1] not equal to 3 ")

    # Work on a copy so that renaming the columns does not leak into the
    # caller's DataFrame.
    df_ = df_.copy()
    df_.columns = ["x", "y", "cos"]

    if not xlim:
        xlim = len(df_)
    if not ylim:
        ylim = df_.y.max()
    if not xlabel:
        xlabel = str(xlim)
    if not ylabel:
        ylabel = str(ylim)

    backend_saved = matplotlib.get_backend()
    switched = backend_saved != backend
    if switched:
        plt.switch_backend(backend)

    try:
        sns.set()
        sns.set_style("darkgrid")

        fig = plt.figure(figsize=(13, 8))
        gs = fig.add_gridspec(1, 1, wspace=0.4, hspace=0.58)
        ax0 = fig.add_subplot(gs[0, 0])

        cmap = "viridis_r"

        # DBSCAN labels noise samples -1; every other label is a cluster id.
        labels = DBSCAN(min_samples=min_samples, eps=eps).fit(df_).labels_
        clustered = labels > -1
        outliers = ~clustered

        df_[clustered].plot.scatter("x", "y", c="cos", cmap=cmap, ax=ax0)
        df_[outliers].plot.scatter(
            "x", "y", c="r", marker="x", alpha=0.6, ax=ax0
        )

        ax0.set_xlabel(xlabel)
        ax0.set_ylabel(ylabel)
        ax0.set_title("max cos ('x': outliers)")
    finally:
        # Restore the caller's backend even when clustering/plotting fails.
        if switched:
            plt.switch_backend(backend_saved)

    return plt
# Scratch note bound to the throwaway name `_`; it holds text only and nothing
# inside it is run as code. The text repeats plot_df's tunable parameters with
# their default values.
_ = """
eps: float = 10
min_samples: int = 6
xlabel: str = ""
ylabel: str = ""
xlim: int = 0
ylim: int = 0
"""