Spaces:
Build error
Build error
| """Convert paras to sents.""" | |
| # pylint: disable=unused-import, too-many-branches, ungrouped-imports | |
| from typing import Callable, List, Optional, Tuple, Union | |
| from itertools import zip_longest | |
| import numpy as np | |
| import pandas as pd | |
| from logzero import logger | |
| from radiobee.align_sents import align_sents | |
| from radiobee.seg_text import seg_text | |
| from radiobee.detect import detect | |
# shuffle_sents is optional: when importing it fails for any reason, log the
# failure and install a stand-in with the same call signature that delegates
# to align_sents.
# NOTE(review): the catch is kept broad (Exception rather than ImportError) on
# the assumption that the module can fail at import time for reasons other
# than a missing package -- confirm before narrowing.
try:
    from radiobee.shuffle_sents import shuffle_sents
except Exception as exc:
    logger.error("shuffle_sents not available: %s, using align_sents", exc)

    def shuffle_sents(x1, x2, lang1="", lang2=""):  # pylint: disable=unused-argument
        """Stand in for the real shuffle_sents.

        Args:
            x1: first batch, passed through to align_sents
            x2: second batch, passed through to align_sents
            lang1: accepted for signature compatibility, ignored
            lang2: accepted for signature compatibility, ignored

        Returns:
            whatever align_sents(x1, x2) returns
        """
        return align_sents(x1, x2)
def paras2sents(
    paras_: Union[pd.DataFrame, List[Tuple[str, str, Union[str, float]]], np.ndarray],
    align_func: Optional[Union[Callable, str]] = None,
    lang1: Optional[str] = None,
    lang2: Optional[str] = None,
) -> List[Tuple[str, str, Union[str, float]]]:
    """Convert paras to sents using align_func.

    Each row's first two columns are split with seg_text. Rows are then
    grouped into batches: a row whose third column is a positive float forms
    a batch of its own; consecutive other rows are merged into one batch.
    Every batch is handed to the aligner and the results are concatenated.

    Args:
        paras_: list of 3-tuples or numpy or pd.DataFrame; only the first
            three columns are used, a missing third column is filled with ""
        align_func: func used in the sent level.
            None: align_sents.
            A str starting with "shuffle", or a callable named
            "shuffle_sents": shuffle_sents called with lang1/lang2.
            Any other str: align_sents.
            Any other callable: called as align_func(batch0, batch1).
        lang1: first lang code, detected from the first column if None
        lang2: second lang code, detected from the second column if None

    Returns:
        list of sents (possible with likelihood for shuffle_sents)

    Raises:
        ValueError: if paras_ has fewer than two columns.
    """
    # wrap everything in pd.DataFrame
    # necessary to make pyright happy
    paras = pd.DataFrame(paras_).fillna("")

    # take the first three columns at maximum
    paras = paras.iloc[:, :3]

    if len(paras.columns) < 2:
        logger.error("Need at least two columns, got %s", len(paras.columns))
        # ValueError is still caught by callers handling Exception
        raise ValueError("wrong data")

    # append the third col (all "") if there are only two cols
    if len(paras.columns) < 3:
        paras.insert(2, "likelihood", [""] * len(paras))

    if lang1 is None:
        lang1 = detect(" ".join(paras.iloc[:, 0]))
    if lang2 is None:
        lang2 = detect(" ".join(paras.iloc[:, 1]))

    # left[i] and right[i] always describe the same batch
    left: List[List[str]] = []
    right: List[List[str]] = []
    # sentences accumulated from rows without a positive likelihood
    row0: List[str] = []
    row1: List[str] = []

    for elm0, elm1, elm2 in paras.values:
        sents0 = seg_text(elm0, lang1)
        sents1 = seg_text(elm1, lang2)

        if isinstance(elm2, float) and elm2 > 0:
            # flush whatever was accumulated before this row
            if row0 or row1:
                left.append(row0)
                right.append(row1)
                row0, row1 = [], []

            # Append to both sides together, even when one side is empty,
            # otherwise left/right drift apart and later batches get paired
            # with the wrong counterpart.
            if sents0 or sents1:
                left.append(sents0 or [])
                right.append(sents1 or [])
        else:
            if sents0:
                row0.extend(sents0)
            if sents1:
                row1.extend(sents1)

    # collect possible last batch
    if row0 or row1:
        left.append(row0)
        right.append(row1)

    def _shuffle(batch0, batch1):
        """Call shuffle_sents with the resolved lang codes."""
        return shuffle_sents(batch0, batch1, lang1=lang1, lang2=lang2)

    # ready the aligner
    if align_func is None:
        aligner = align_sents
    elif isinstance(align_func, str):
        aligner = _shuffle if align_func.startswith("shuffle") else align_sents
    elif getattr(align_func, "__name__", "") == "shuffle_sents":
        # getattr: not every callable has __name__ (e.g. functools.partial)
        aligner = _shuffle
    elif callable(align_func):
        # honor a caller-supplied aligner instead of silently replacing it
        aligner = align_func
    else:
        aligner = align_sents

    # align each batch; a failing aligner degrades to positional pairing
    res = []
    for batch0, batch1 in zip(left, right):
        try:
            aligned = aligner(batch0, batch1)
        except Exception as exc:
            logger.error("errors: %s, resorting to zip_longest", exc)
            aligned = [*zip_longest(batch0, batch1, fillvalue="")]
        res.extend(aligned)

    return res