Spaces:
Build error
Build error
File size: 4,228 Bytes
3d38118 4c04f50 3d38118 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
"""Gen proper alignment for a given triple_set.
cmat = fetch_sent_corr(src, tgt)
src_len, tgt_len = np.array(cmat).shape
r_ali = gen_row_alignment(cmat, tgt_len, src_len) # note the order
src[r_ali[1]], tgt[r_ali[0]], r_ali[2]
or !!! (target, source)
cmat = fetch_sent_corr(tgt, src) # note the order
src_len, tgt_len = np.array(cmat).shape
r_ali = gen_row_alignment(cmat, src_len, tgt_len)
src[r_ali[0]], tgt[r_ali[1]], r_ali[2]
---
src_txt = 'data/wu_ch2_en.txt'
tgt_txt = 'data/wu_ch2_zh.txt'
assert Path(src_txt).exists()
assert Path(tgt_txt).exists()
src_text, _ = load_paras(src_txt)
tgt_text, _ = load_paras(tgt_txt)
cos_matrix = gen_cos_matrix(src_text, tgt_text)
t_set, m_matrix = find_aligned_pairs(cos_matrix, thr=0.4, matrix=True)
resu = gen_row_alignment(t_set, src_len, tgt_len)
resu = np.array(resu)
idx = -1
idx += 1; (resu[idx], src_text[int(resu[idx, 0])],
tgt_text[int(resu[idx, 1])]) if all(resu[idx]) else resu[idx]
idx += 1; i0, i1, i2 = resu[idx]; '***' if i0 == ''
else src_text[int(i0)], '***' if i1 == '' else tgt_text[int(i1)], ''
if i2 == '' else i2
"""
# pylint: disable=line-too-long, unused-variable
from typing import List, Union
# natural extrapolation with slope equal to 1
from itertools import zip_longest as zip_longest_middle
import numpy as np
from logzero import logger
# from tinybee.zip_longest_middle import zip_longest_middle
# from tinybee.zip_longest_middle import zip_longest_middle
# from tinybee.find_pairs import find_pairs
# logger = logging.getLogger(__name__)
# logger.addHandler(logging.NullHandler())
def gen_row_alignment(  # pylint: disable=too-many-locals
    t_set,
    src_len,
    tgt_len,
) -> List[List[Union[str, float]]]:
    """Gen proper rows for a given triple_set.

    Candidate triples are visited from the largest third-column value to the
    smallest. A candidate is kept only if, once placed by its first column,
    its second column lies strictly between the second columns of its
    neighbours, so the kept triples are increasing in the second column.
    The index gaps between consecutive kept triples are then filled with
    rows that pair the skipped indices one by one, using "" where one side
    runs out and "" as the third entry.

    Arguments:
        t_set {np.array or list} -- rows of three values each; presumably
            (source index, target index, score) -- confirm against the
            caller that builds the triple set. May be empty.
        src_len {int} -- numb of source texts (para/sents)
        tgt_len {int} -- numb of target texts (para/sents); also the
            number of candidates taken from t_set, best score first

    Returns:
        list of [first index or "", second index or "", score or ""] rows
    """
    t_set = np.array(t_set, dtype="object")

    # buff holds the accepted triples ordered by their first column; the
    # leading [-1, -1, ""] sentinel gives every real triple a predecessor.
    buff = [[-1, -1, ""]]

    # An empty t_set has no third column to rank (unpacking it used to
    # raise ValueError), so the selection pass runs only with candidates.
    if len(t_set):
        scores = np.array([row[2] for row in t_set])
        # a value strictly below the minimum marks a candidate as consumed
        reset_v = np.min(scores) - 1

        for count in range(tgt_len):
            # start with the bigger values from the 3rd col
            idx_t = np.argmax(scores)
            scores[idx_t] = reset_v
            elm = t_set[idx_t]
            logger.debug("%s: %s, %s", count, idx_t, elm)

            elm0, elm1, elm2 = elm

            # find loc to insert: in front of the first entry whose first
            # column is larger, or at the very end when there is none
            idx = next(
                (pos for pos, loc in enumerate(buff) if loc[0] > elm0),
                len(buff),
            )

            # make sure elm1 is within the range
            # prev elm1 < elm1 < next elm1
            if elm1 > buff[idx - 1][1]:
                # past the last entry the upper bound is the target length
                next_elm = buff[idx][1] if idx < len(buff) else tgt_len
                if elm1 < next_elm:
                    buff.insert(idx, [elm0, elm1, elm2])

    # take care of the tail: a closing sentinel so that the indices after
    # the last accepted triple are filled in as well
    buff += [[src_len, tgt_len, ""]]

    resu = []
    for prev, curr in zip(buff, buff[1:]):
        # indices strictly between two neighbouring entries, paired up in
        # order; the shorter side is padded with ""
        gap = zip_longest_middle(
            list(range(prev[0] + 1, curr[0])),
            list(range(prev[1] + 1, curr[1])),
            fillvalue="",
        )
        # convert to list entries & attach an empty merit
        resu += [list(pair) + [""] for pair in gap]
        resu.append(curr)

    # remove the last entry (the closing sentinel)
    return resu[:-1]
|