File size: 3,457 Bytes
4c04f50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
"""Convert paras to sents."""
# pylint: disable=unused-import, too-many-branches, ungrouped-imports

from typing import Callable, List, Optional, Tuple, Union

from itertools import zip_longest
import numpy as np
import pandas as pd
from logzero import logger

from radiobee.align_sents import align_sents
from radiobee.seg_text import seg_text
from radiobee.detect import detect

# shuffle_sents is treated as optional: if importing it fails for any reason,
# log the failure and install a stand-in with the same call signature.
try:
    from radiobee.shuffle_sents import shuffle_sents
except Exception as exc:
    logger.error("shuffle_sents not available: %s, using align_sents", exc)
    # The stand-in accepts lang1/lang2 but ignores them and delegates to
    # align_sents(x1, x2). Being a lambda, its __name__ is "<lambda>", not
    # "shuffle_sents".
    shuffle_sents = lambda x1, x2, lang1="", lang2="": align_sents(x1, x2)  # noqa


def paras2sents(
    paras_: Union[pd.DataFrame, List[Tuple[str, str, Union[str, float]]], np.ndarray],
    align_func: Optional[Union[Callable, str]] = None,
    lang1: Optional[str] = None,
    lang2: Optional[str] = None,
) -> List[Tuple[str, str, Union[str, float]]]:
    """Convert paras to sents using align_func.

    The input is wrapped in a DataFrame, NaNs are replaced by "" and only the
    first three columns are kept. Each row's first two cells are segmented
    with ``seg_text``. Rows are then grouped into batches:

    * a row whose third cell is a positive float forms a batch of its own;
    * consecutive other rows are pooled into a single batch.

    Every batch is passed to the aligner and the results are concatenated.

    Args:
        paras_: list of 3-tuples or numpy or pd.DataFrame; at least two
            columns are required, a missing third column is filled with "".
        align_func: func used in the sent level.
            None (default): ``align_sents``.
            A string starting with "shuffle", or a callable whose
            ``__name__`` is "shuffle_sents": the module-level
            ``shuffle_sents`` called with ``lang1``/``lang2``.
            Any other string: ``align_sents``.
            Any other callable: used as given, called as
            ``align_func(sents_left, sents_right)``.
        lang1: first lang code; detected from the first column if None.
        lang2: second lang code; detected from the second column if None.

    Returns:
        list of sents (possible with likelihood for shuffle_sents)

    Raises:
        ValueError: if the input has fewer than two columns.
    """
    # wrap everything in pd.DataFrame
    # necessary to make pyright happy
    paras = pd.DataFrame(paras_).fillna("")

    # take the first three columns at maximum
    paras = paras.iloc[:, :3]

    n_cols = len(paras.columns)
    if n_cols < 2:
        logger.error("Need at least two columns, got %s", n_cols)
        raise ValueError("wrong data")

    # append the third col (all "") if there are only two cols
    if n_cols < 3:
        paras.insert(2, "likelihood", [""] * len(paras))

    if lang1 is None:
        lang1 = detect(" ".join(paras.iloc[:, 0]))
    if lang2 is None:
        lang2 = detect(" ".join(paras.iloc[:, 1]))

    # left[i] and right[i] always describe the same batch, so both lists
    # must grow in lockstep (they are zipped together below).
    left, right = [], []
    row0, row1 = [], []
    for elm0, elm1, elm2 in paras.values:
        sents0 = seg_text(elm0, lang1)
        sents1 = seg_text(elm1, lang2)
        if isinstance(elm2, float) and elm2 > 0:
            # flush whatever has been pooled so far as its own batch
            if row0 or row1:
                left.append(row0)
                right.append(row1)
                row0, row1 = [], []

            # Append to BOTH sides together; appending only the non-empty
            # side would shift every later batch against its counterpart.
            if sents0 or sents1:
                left.append(sents0 or [])
                right.append(sents1 or [])
        else:
            if sents0:
                row0.extend(sents0)
            if sents1:
                row1.extend(sents1)
    # collect possible last batch
    if row0 or row1:
        left.append(row0)
        right.append(row1)

    # ready the aligner
    def _shuffle_with_langs(sents_l, sents_r):
        """Call the module-level shuffle_sents with the resolved languages."""
        return shuffle_sents(sents_l, sents_r, lang1=lang1, lang2=lang2)

    if align_func is None:
        aligner = align_sents
    elif isinstance(align_func, str):
        aligner = (
            _shuffle_with_langs if align_func.startswith("shuffle") else align_sents
        )
    elif getattr(align_func, "__name__", "") == "shuffle_sents":
        # getattr: callables such as functools.partial carry no __name__
        aligner = _shuffle_with_langs
    elif callable(align_func):
        # honour a caller-supplied aligner instead of silently replacing it
        aligner = align_func
    else:
        logger.warning(
            "align_func %r is neither callable nor str, using align_sents",
            align_func,
        )
        aligner = align_sents

    # align each batch; a failing aligner degrades to naive positional pairing
    res = []
    for batch0, batch1 in zip(left, right):
        try:
            pairs = aligner(batch0, batch1)
        except Exception as exc:
            logger.error("errors: %s, resorting to zip_longest", exc)
            pairs = [*zip_longest(batch0, batch1, fillvalue="")]

        res.extend(pairs)

    return res