# radiobee-aligner/radiobee/gradiobee.py
"""Gradiobee."""
# pylint: disable=invalid-name, too-many-arguments, too-many-branches, too-many-locals, too-many-statements, unused-variable, too-many-return-statements, unused-import
from pathlib import Path
import platform
import inspect
from itertools import zip_longest
# import tempfile
from sklearn.cluster import DBSCAN
from fastlid import fastlid
from logzero import logger
from icecream import ic
import numpy as np # noqa
import pandas as pd
import matplotlib # noqa
import matplotlib.pyplot as plt
import seaborn as sns
# from radiobee.process_upload import process_upload
from radiobee.files2df import files2df
from radiobee.file2text import file2text
from radiobee.lists2cmat import lists2cmat
from radiobee.gen_pset import gen_pset
from radiobee.gen_aset import gen_aset
from radiobee.align_texts import align_texts
from radiobee.cmat2tset import cmat2tset
from radiobee.trim_df import trim_df
from radiobee.error_msg import error_msg
from radiobee.text2lists import text2lists
from radiobee.align_sents import align_sents
from radiobee.shuffle_sents import shuffle_sents # type: ignore
from radiobee.paras2sents import paras2sents # type: ignore
uname = platform.uname()
HFSPACES = False
if "amzn2" in uname.release: # on hf spaces
HFSPACES = True
sns.set()
sns.set_style("darkgrid")
pd.options.display.float_format = "{:,.2f}".format
debug = False
debug = True  # debug toggle: comment out this line to silence debug output
def gradiobee( # noqa
file1,
file2,
tf_type,
idf_type,
dl_type,
norm,
eps,
min_samples,
# debug=False,
sent_ali_algo,
):
"""Process inputs and return outputs."""
logger.debug(" *debug* ")
    # possible further switches
# para_sent: para/sent
# sent_ali: default/radio/gale-church
plot_dia = True # noqa
# outputs: check return
# if outputs is modified, also need to modify error_msg's outputs
    # convert the "None" string (from the Radio inputs) to None
    idf_type = None if idf_type == "None" else idf_type
    dl_type = None if dl_type == "None" else dl_type
    norm = None if norm == "None" else norm
# logger.info("file1: *%s*, file2: *%s*", file1, file2)
if file2 is not None:
logger.info("file1.name: *%s*, file2.name: *%s*", file1.name, file2.name)
else:
logger.info("file1.name: *%s*, file2: *%s*", file1.name, file2)
# bypass if file1 or file2 is str input
# if not (isinstance(file1, str) or isinstance(file2, str)):
text1 = file2text(file1)
if file2 is None:
logger.debug("file2 is None")
text2 = ""
else:
logger.debug("file2.name: %s", file2.name)
text2 = file2text(file2)
# if not text1.strip() or not text2.strip():
if not text1.strip():
msg = (
"file 1 is apparently empty... Upload a none empty file and try again."
# f"text1[:10]: [{text1[:10]}], "
# f"text2[:10]: [{text2[:10]}]"
)
return error_msg(msg)
# single file
# when text2 is empty
    # process file1/text1: split the dual-language text1 into text1/text2 (e.g. zh-en)
len_max = 2000
if not text2.strip(): # empty file2
_ = [elm.strip() for elm in text1.splitlines() if elm.strip()]
if not _: # essentially empty file1
return error_msg("Nothing worthy of processing in file 1")
logger.info(
"single file: len %s, max %s",
len(_), 2 * len_max
)
# exit if there are too many lines
if len(_) > 2 * len_max:
return error_msg(f" Too many lines ({len(_)}) > {2 * len_max}, alignment op halted, sorry.", "info")
        _ = zip_longest(_, [], fillvalue="")  # pair every line with an empty text2
_ = pd.DataFrame(_, columns=["text1", "text2"])
df_trimmed = trim_df(_)
# text1 = loadtext("data/test-dual.txt")
list1, list2 = text2lists(text1)
lang1 = text2lists.lang1
lang2 = text2lists.lang2
offset = text2lists.offset # noqa
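        # text2lists splits the dual-language text into two monolingual lists and
        # attaches the detected languages (lang1/lang2) and the split offset as
        # attributes on the function object itself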
_ = """
ax = sns.heatmap(lists2cmat(list1, list2), cmap="gist_earth_r") # ax=plt.gca()
ax.invert_yaxis()
ax.set(
xlabel=lang1,
ylabel=lang2,
title=f"cos similary heatmap \n(offset={offset})",
)
plt_loc = "img/plt.png"
plt.savefig(plt_loc)
# """
# output_plot = plt_loc # for gr.outputs.Image
#
_ = zip_longest(list1, list2, fillvalue="")
df_aligned = pd.DataFrame(
_,
columns=["text1", "tex2"]
)
file_dl = Path(f"{Path(file1.name).stem[:-8]}-{lang1}-{lang2}.csv")
file_dl_xlsx = Path(
f"{Path(file1.name).stem[:-8]}-{lang1}-{lang2}.xlsx"
)
# return df_trimmed, output_plot, file_dl, file_dl_xlsx, df_aligned
# end if single file
# not single file
    else:  # file1 and file2 both provided: proceed with two-file alignment
fastlid.set_languages = None
lang1, _ = fastlid(text1)
lang2, _ = fastlid(text2)
df1 = files2df(file1, file2)
list1 = [elm for elm in df1.text1 if elm]
list2 = [elm for elm in df1.text2 if elm]
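        # files2df loads both uploads into a two-column dataframe (text1/text2);
        # blank cells are dropped so only non-empty lines take part in the alignment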
# len1 = len(list1) # noqa
# len2 = len(list2) # noqa
# exit if there are too many lines
len12 = len(list1) + len(list2)
logger.info(
"fast track: len1 %s, len2 %s, tot %s, max %s",
len(list1), len(list2), len(list1) + len(list2), 3 * len_max
)
if len12 > 3 * len_max:
return error_msg(f" Too many lines ({len(list1)} + {len(list2)} > {3 * len_max}), alignment op halted, sorry.", "info")
file_dl = Path(f"{Path(file1.name).stem[:-8]}-{Path(file2.name).stem[:-8]}.csv")
file_dl_xlsx = Path(
f"{Path(file1.name).stem[:-8]}-{Path(file2.name).stem[:-8]}.xlsx"
)
df_trimmed = trim_df(df1)
# --- end else single
lang_en_zh = ["en", "zh"]
logger.debug("lang1: %s, lang2: %s", lang1, lang2)
if debug:
ic(f" lang1: {lang1}, lang2: {lang2}")
ic("fast track? ", lang1 in lang_en_zh and lang2 in lang_en_zh)
# fast track
if lang1 in lang_en_zh and lang2 in lang_en_zh:
try:
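            # lists2cmat builds the correlation matrix between the two line lists,
            # honoring the tf/idf/doc-length weighting and normalization options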
cmat = lists2cmat(
list1,
list2,
tf_type=tf_type,
idf_type=idf_type,
dl_type=dl_type,
norm=norm,
)
except Exception as exc:
logger.error(exc)
return error_msg(exc)
# slow track
else:
logger.info(
"slow track: len1 %s, len2 %s, tot: %s, max %s",
len(list1), len(list2), len(list1) + len(list2),
3 * len_max
)
if len(list1) + len(list2) > 3 * len_max:
msg = (
f" len1 {len(list1)} + len2 {len(list2)} > {3 * len_max}. "
"This will take too long to complete "
"and will hog this experimental server and hinder "
"other users from trying the service. "
"Aborted...sorry"
)
            return error_msg(msg, "info")
try:
from radiobee.model_s import model_s # pylint: disable=import-outside-toplevel
vec1 = model_s.encode(list1)
vec2 = model_s.encode(list2)
# cmat = vec1.dot(vec2.T)
cmat = vec2.dot(vec1.T)
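            # vec2.dot(vec1.T) has shape (len(list2), len(list1)): rows index text2,
            # columns index text1, matching the fast-track cmat orientation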
except Exception as exc:
logger.error(exc)
_ = inspect.currentframe().f_lineno # type: ignore
return error_msg(
f"{exc}, {Path(__file__).name} ln{_}, period"
)
tset = pd.DataFrame(cmat2tset(cmat))
tset.columns = ["x", "y", "cos"]
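    # cmat2tset reduces cmat to (x, y, cos) triples -- apparently the best-matching
    # row y and its cosine score for each column x (cf. "max along columns" below)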
_ = """
df_trimmed = pd.concat(
[
df1.iloc[:4, :],
pd.DataFrame(
[
[
"...",
"...",
]
],
columns=df1.columns,
),
df1.iloc[-4:, :],
],
ignore_index=1,
)
# """
    # process list1, list2 to obtain df_aligned
    # quick fix for ValueError: not enough values to unpack (expected at least 1, got 0)
    # fixed in gen_pset, but we leave the retry loop here as a safeguard
for min_s in range(min_samples):
logger.info(" min_samples, using %s", min_samples - min_s)
try:
pset = gen_pset(
cmat,
eps=eps,
min_samples=min_samples - min_s,
delta=7,
)
break
except ValueError:
logger.info(" decrease min_samples by %s", min_s + 1)
continue
except Exception as e:
logger.error(e)
continue
else:
# break should happen above when min_samples = 2
raise Exception("bummer, this shouldn't happen, probably another bug")
min_samples = gen_pset.min_samples
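    # gen_pset raises ValueError when DBSCAN cannot find enough anchor points with the
    # requested min_samples, hence the retry loop above; gen_pset.min_samples records
    # the value that finally succeeded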
# will result in error message:
# UserWarning: Starting a Matplotlib GUI outside of
# the main thread will likely fail."
_ = """
plot_cmat(
cmat,
eps=eps,
min_samples=min_samples,
xlabel=lang1,
ylabel=lang2,
)
# """
# move plot_cmat's code to the main thread here
# to make it work
xlabel = lang1
ylabel = lang2
len1, len2 = cmat.shape
ylim, xlim = len1, len2
# does not seem to show up
ic(f" len1 (ylim): {len1}, len2 (xlim): {len2}")
logger.debug(" len1 (ylim): %s, len2 (xlim): %s", len1, len2)
if debug:
print(f" len1 (ylim): {len1}, len2 (xlim): {len2}")
df_ = pd.DataFrame(cmat2tset(cmat))
df_.columns = ["x", "y", "cos"]
sns.set()
sns.set_style("darkgrid")
    # close all existing figures, necessary for hf spaces
plt.close("all")
# if sys.platform not in ["win32", "linux"]:
# going for noninterative
# to cater for Mac, thanks to WhiteFox
plt.switch_backend("Agg")
# figsize=(13, 8), (339, 212) mm on '1280x800+0+0'
fig = plt.figure(figsize=(13, 8))
# gs = fig.add_gridspec(2, 2, wspace=0.4, hspace=0.58)
gs = fig.add_gridspec(1, 2, wspace=0.4, hspace=0.58)
ax_heatmap = fig.add_subplot(gs[0, 0]) # ax2
ax0 = fig.add_subplot(gs[0, 1])
# ax1 = fig.add_subplot(gs[1, 0])
cmap = "viridis_r"
sns.heatmap(cmat, cmap=cmap, ax=ax_heatmap).invert_yaxis()
ax_heatmap.set_xlabel(xlabel)
ax_heatmap.set_ylabel(ylabel)
ax_heatmap.set_title("cos similarity heatmap")
fig.suptitle(f"alignment projection\n(eps={eps}, min_samples={min_samples})")
_ = DBSCAN(min_samples=min_samples, eps=eps).fit(df_).labels_ > -1
# _x = DBSCAN(min_samples=min_samples, eps=eps).fit(df_).labels_ < 0
_x = ~_
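    # DBSCAN labels: > -1 means clustered (treated as aligned pairs), -1 means noise,
    # so _x is the outlier mask plotted as red crosses below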
# max cos along columns
df_.plot.scatter("x", "y", c="cos", cmap=cmap, ax=ax0)
# outliers
df_[_x].plot.scatter("x", "y", c="r", marker="x", alpha=0.6, ax=ax0)
ax0.set_xlabel(xlabel)
ax0.set_ylabel(ylabel)
ax0.set_xlim(xmin=0, xmax=xlim)
ax0.set_ylim(ymin=0, ymax=ylim)
ax0.set_title(
"max along columns (x: outliers)\n"
"potential aligned pairs (green line), "
f"{round(sum(_) / xlim, 2):.0%}"
)
plt_loc = "img/plt.png"
ic(f" plotting to {plt_loc}")
plt.savefig(plt_loc)
# clustered
# df_[_].plot.scatter("x", "y", c="cos", cmap=cmap, ax=ax1)
# ax1.set_xlabel(xlabel)
# ax1.set_ylabel(ylabel)
# ax1.set_xlim(0, len1)
# ax1.set_title(f"potential aligned pairs ({round(sum(_) / len1, 2):.0%})")
# end of plot_cmat
src_len, tgt_len = cmat.shape
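    # gen_aset presumably expands the anchor pairs (pset) into an alignment set spanning
    # src_len x tgt_len, which align_texts then uses to stitch the aligned rows together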
aset = gen_aset(pset, src_len, tgt_len)
final_list = align_texts(aset, list2, list1) # note the order
# df_aligned
df_aligned = pd.DataFrame(final_list, columns=["text1", "text2", "likelihood"])
    # swap text1/text2: align_texts was called with (list2, list1), so the columns come back reversed
df_aligned = df_aligned[["text2", "text1", "likelihood"]]
df_aligned.columns = ["text1", "text2", "likelihood"]
ic("paras aligned: ", df_aligned.head(10))
# round the last column to 2
# df_aligned.likelihood = df_aligned.likelihood.round(2)
# df_aligned = df_aligned.round({"likelihood": 2})
# df_aligned.likelihood = df_aligned.likelihood.apply(lambda x: np.nan if x in [""] else x)
if len(df_aligned) > 200:
df_html = None
    else:  # small enough: show the whole table in html
# style
styled = df_aligned.style.set_properties(
**{
"font-size": "10pt",
"border-color": "black",
"border": "1px black solid !important"
}
# border-color="black",
).set_table_styles([{
"selector": "", # noqs
"props": [("border", "2px black solid !important")]}] # noqs
).format(
precision=2
)
# .bar(subset="likelihood", color="#5fba7d")
# .background_gradient("Greys")
# df_html = df_aligned.to_html()
df_html = styled.to_html()
# ===
if plot_dia:
output_plot = "img/plt.png"
else:
output_plot = None
_ = df_aligned.to_csv(index=False)
file_dl.write_text(_, encoding="utf8")
# file_dl.write_text(_, encoding="gb2312") # no go
df_aligned.to_excel(file_dl_xlsx)
# return df_trimmed, plt
# return df_trimmed, plt, file_dl, file_dl_xlsx, df_aligned
# output_plot: gr.outputs.Image(type="auto", label="...")
# return df_trimmed, output_plot, file_dl, file_dl_xlsx, df_aligned
# return df_trimmed, output_plot, file_dl, file_dl_xlsx, styled, df_html # gradio cant handle style
ic("sent-ali-algo: ", sent_ali_algo)
# ### sent-ali-algo is None: para align
if sent_ali_algo in ["None"]:
ic("returning para-ali outputs")
return df_trimmed, output_plot, file_dl, file_dl_xlsx, None, None, df_aligned, df_html
# ### proceed with sent align
if sent_ali_algo in ["fast"]:
ic(sent_ali_algo)
align_func = align_sents
ic(df_aligned.shape, df_aligned.columns)
aligned_sents = paras2sents(df_aligned, align_func)
# ic(pd.DataFrame(aligned_sents).shape, aligned_sents)
ic(pd.DataFrame(aligned_sents).shape)
df_aligned_sents = pd.DataFrame(aligned_sents, columns=["text1", "text2"])
else: # ["slow"]
ic(sent_ali_algo)
align_func = shuffle_sents
aligned_sents = paras2sents(df_aligned, align_func, lang1, lang2)
# add extra entry if necessary
aligned_sents = [list(sent) + [""] if len(sent) == 2 else list(sent) for sent in aligned_sents]
df_aligned_sents = pd.DataFrame(aligned_sents, columns=["text1", "text2", "likelihood"])
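    # "fast" uses align_sents (no likelihood column); "slow" uses shuffle_sents, which
    # needs lang1/lang2 and yields a likelihood column (short rows padded above)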
# prepare sents downloads
file_dl_sents = Path(f"{file_dl.stem}-sents{file_dl.suffix}")
file_dl_xlsx_sents = Path(f"{file_dl_xlsx.stem}-sents{file_dl_xlsx.suffix}")
_ = df_aligned_sents.to_csv(index=False)
file_dl_sents.write_text(_, encoding="utf8")
df_aligned_sents.to_excel(file_dl_xlsx_sents)
# prepare html output
if len(df_aligned_sents) > 200:
df_html = None
    else:  # small enough: show the whole table in html
# style
styled = df_aligned_sents.style.set_properties(
**{
"font-size": "10pt",
"border-color": "black",
"border": "1px black solid !important"
}
# border-color="black",
).set_table_styles([{
"selector": "", # noqs
"props": [("border", "2px black solid !important")]}] # noqs
).format(
precision=2
)
df_html = styled.to_html()
# aligned sents outputs
ic("aligned sents outputs")
# return df_trimmed, output_plot, file_dl, file_dl_xlsx, None, None, df_aligned, df_html
return df_trimmed, output_plot, file_dl, file_dl_xlsx, file_dl_sents, file_dl_xlsx_sents, df_aligned_sents, df_html