"""Plot pandas.DataFrame with DBSCAN clustering."""
# pylint: disable=invalid-name, too-many-arguments, too-many-locals
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.cluster import DBSCAN
from fastlid import fastlid
import logzero
from logzero import logger

from radiobee.cmat2tset import cmat2tset

# Turn interactive when in an ipython session.
# Disabled: the snippet is kept as an inert string and is never executed.
_ = """
if "get_ipython" in globals():
    plt.ion()
else:
    plt.switch_backend("Agg")
# """

logzero.loglevel(20)  # 10: debug on
fastlid.set_languages = ["en", "zh"]


def plot_cmat(
    cmat: np.ndarray,
    eps: float = 10,
    min_samples: int = 6,
    xlabel: str = "zh",
    ylabel: str = "en",
    backend: str = "Agg",
    showfig: bool = False,
) -> None:
    """Plot a matrix as a heatmap plus DBSCAN-filtered scatter plots.

    Three axes are drawn on one new figure:

    * a seaborn heatmap of ``cmat`` (y axis inverted),
    * a scatter of every ``(x, y, cos)`` triple produced by
      ``cmat2tset(cmat)``, with DBSCAN outliers over-plotted as red crosses,
    * a scatter of only the triples DBSCAN assigned to a cluster.

    Args:
        cmat: 2-D array; its shape is unpacked into two lengths that are
            used as the upper axis limits of both scatter plots.
        eps: ``eps`` passed to ``sklearn.cluster.DBSCAN``.
        min_samples: ``min_samples`` passed to ``sklearn.cluster.DBSCAN``.
        xlabel: x-axis label applied to all three axes.
        ylabel: y-axis label applied to all three axes.
        backend: matplotlib backend to use while plotting. The backend that
            was active on entry is restored before returning, including when
            plotting raises.
        showfig: if True, call ``plt.show(block=True)`` after drawing.

    Returns:
        None. The figure is created through ``pyplot``; all previously open
        figures are closed first.

    Notes:
        DBSCAN is fitted on all three columns (``x``, ``y``, ``cos``).
        Sorting the rows beforehand does not seem to affect the clustering.
    """
    logger.debug('"get_ipython" in globals(): %s', "get_ipython" in globals())

    len1, len2 = cmat.shape

    df_ = pd.DataFrame(cmat2tset(cmat))
    df_.columns = ["x", "y", "cos"]

    backend_saved = matplotlib.get_backend()
    switched = backend_saved != backend

    # switch backend if necessary
    if switched:
        plt.switch_backend(backend)

    # try/finally so that a failure anywhere below cannot leave the process
    # stuck on the temporary backend.
    try:
        sns.set()
        sns.set_style("darkgrid")

        # close all existing figures, necessary for hf spaces
        plt.close("all")

        fig = plt.figure()
        gs = fig.add_gridspec(2, 2, wspace=0.4, hspace=0.58)
        ax2 = fig.add_subplot(gs[0, 0])  # heatmap
        ax0 = fig.add_subplot(gs[0, 1])  # all points + outliers
        ax1 = fig.add_subplot(gs[1, 0])  # clustered points only

        cmap = "viridis_r"
        sns.heatmap(cmat, cmap=cmap, ax=ax2).invert_yaxis()
        ax2.set_xlabel(xlabel)
        ax2.set_ylabel(ylabel)
        ax2.set_title("cos similarity heatmap")

        fig.suptitle("alignment projection")

        # DBSCAN labels noise samples -1; everything else is in a cluster.
        labels = DBSCAN(min_samples=min_samples, eps=eps).fit(df_).labels_
        clustered = labels > -1
        outliers = ~clustered

        df_.plot.scatter("x", "y", c="cos", cmap=cmap, ax=ax0)

        # clustered
        df_[clustered].plot.scatter("x", "y", c="cos", cmap=cmap, ax=ax1)

        # outliers
        df_[outliers].plot.scatter(
            "x", "y", c="r", marker="x", alpha=0.6, ax=ax0
        )

        # NOTE(review): x is capped by cmat.shape[0] and y by cmat.shape[1];
        # confirm this matches the (x, y) orientation cmat2tset produces.
        ax0.set_xlabel(xlabel)
        ax0.set_ylabel(ylabel)
        ax0.set_xlim(0, len1)
        ax0.set_ylim(0, len2)
        ax0.set_title("max along columns ('x': outliers)")

        ax1.set_xlabel(xlabel)
        ax1.set_ylabel(ylabel)
        ax1.set_xlim(0, len1)
        ax1.set_ylim(0, len2)
        ax1.set_title(
            f"potential aligned pairs ({round(clustered.sum() / len1, 2):.0%})"
        )

        logger.debug(" matplotlib.get_backend(): %s", matplotlib.get_backend())

        if showfig:
            # block=True, or we'll just see the plot show and disappear
            plt.show(block=True)
    finally:
        # restore if necessary
        if switched:
            plt.switch_backend(backend_saved)