File size: 5,578 Bytes
77a12fd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
import os
from typing import Dict, List, Optional

import sentencepiece as spm

from . import text_cleaners
from .constants import ALL_POSSIBLE_HARAQAT
class TextEncoder:
    """Map input characters and target symbols to integer ids and back.

    Both vocabularies are prefixed with the padding symbol ``pad``; the pad id
    is 0 on the target side and is forced to 0 on the input side as well.
    """

    # Reserved padding symbol. Because it is a real character, any literal "P"
    # in text passed to input_to_sequence / target_to_sequence is dropped.
    pad = "P"

    def __init__(
        self,
        input_chars: List[str],
        target_charts: List[str],
        cleaner_fn: Optional[str] = None,
        reverse_input: bool = False,
        reverse_target: bool = False,
        sp_model_path: Optional[str] = None,
    ):
        """Build the symbol <-> id lookup tables.

        Args:
            input_chars: symbols accepted on the input side.
            target_charts: symbols produced on the target side (the parameter
                name is kept as-is for backward compatibility).
            cleaner_fn: name of a function in ``text_cleaners`` used by
                :meth:`clean`; ``None`` or empty disables cleaning.
            reverse_input: reverse input text before encoding it.
            reverse_target: reverse target text before encoding it.
            sp_model_path: directory containing ``sp.model``. When given, input
                ids come from that SentencePiece model instead of each symbol's
                position in ``input_chars``.
        """
        if cleaner_fn:
            self.cleaner_fn = getattr(text_cleaners, cleaner_fn)
        else:
            self.cleaner_fn = None

        self.input_symbols: List[str] = [TextEncoder.pad] + input_chars
        self.target_symbols: List[str] = [TextEncoder.pad] + target_charts

        if sp_model_path is None:
            # Default mode: an input symbol's id is its position in input_symbols.
            self.input_symbol_to_id: Dict[str, int] = {
                s: i for i, s in enumerate(self.input_symbols)
            }
            self.input_id_to_symbol: Dict[int, str] = {
                i: s for i, s in enumerate(self.input_symbols)
            }
            # Mirror the SentencePiece branch so this attribute always exists;
            # None when the input vocabulary has no space character.
            self.input_space_id: Optional[int] = self.input_symbol_to_id.get(" ")
        else:
            sp_model = spm.SentencePieceProcessor()
            sp_model.load(os.path.join(sp_model_path, "sp.model"))
            # Each symbol is looked up as the piece "<symbol>▁".
            # NOTE(review): presumably the model was trained on characters
            # followed by the word-boundary marker; confirm.
            self.input_symbol_to_id: Dict[str, int] = {
                s: sp_model.PieceToId(s + "▁") for s in self.input_symbols
            }
            space_id = sp_model.PieceToId("|")
            self.input_symbol_to_id[" "] = space_id  # space is encoded as "|"
            self.input_symbol_to_id[TextEncoder.pad] = 0  # padding is always id 0
            self.input_space_id = space_id
            # NOTE(review): if PieceToId maps several symbols to the same id
            # (e.g. out-of-vocabulary pieces), this inverted map keeps only the
            # last symbol seen for that id.
            self.input_id_to_symbol: Dict[int, str] = {
                i: s for s, i in self.input_symbol_to_id.items()
            }

        self.target_symbol_to_id: Dict[str, int] = {
            s: i for i, s in enumerate(self.target_symbols)
        }
        self.target_id_to_symbol: Dict[int, str] = {
            i: s for i, s in enumerate(self.target_symbols)
        }
        self.reverse_input = reverse_input
        self.reverse_target = reverse_target
        self.input_pad_id = self.input_symbol_to_id[self.pad]
        self.target_pad_id = self.target_symbol_to_id[self.pad]
        # Subclasses that use a start symbol overwrite this.
        self.start_symbol_id = None

    def input_to_sequence(self, text: str) -> List[int]:
        """Encode ``text`` to input ids (reversed if configured), skipping pad chars.

        Raises:
            KeyError: if ``text`` contains a character missing from the input vocabulary.
        """
        if self.reverse_input:
            text = text[::-1]
        return [self.input_symbol_to_id[s] for s in text if s != self.pad]

    def target_to_sequence(self, text: str) -> List[int]:
        """Encode ``text`` to target ids (reversed if configured), skipping pad chars.

        Raises:
            KeyError: if ``text`` contains a character missing from the target vocabulary.
        """
        if self.reverse_target:
            text = text[::-1]
        return [self.target_symbol_to_id[s] for s in text if s != self.pad]

    def sequence_to_input(self, sequence: List[int]) -> List[str]:
        """Decode input ids to symbols, dropping pad ids and ids not in the vocabulary.

        Unlike :meth:`input_to_sequence`, this does not undo ``reverse_input``.
        """
        return [
            self.input_id_to_symbol[i]
            for i in sequence
            if i in self.input_id_to_symbol and i != self.input_pad_id
        ]

    def sequence_to_target(self, sequence: List[int]) -> List[str]:
        """Decode target ids to symbols, dropping pad ids and ids not in the vocabulary.

        Unlike :meth:`target_to_sequence`, this does not undo ``reverse_target``.
        """
        return [
            self.target_id_to_symbol[i]
            for i in sequence
            if i in self.target_id_to_symbol and i != self.target_pad_id
        ]

    def clean(self, text):
        """Apply the configured cleaner to ``text``, or return it unchanged if none is set."""
        if self.cleaner_fn:
            return self.cleaner_fn(text)
        return text

    def combine_text_and_haraqat(self, input_ids: List[int], output_ids: List[int]) -> str:
        """Interleave input symbols with their position-aligned target symbols.

        Args:
            input_ids: ids of the input text; decoding stops at the first input pad id.
            output_ids: target ids aligned position-by-position with ``input_ids``.

        Returns:
            ``input[0] + target[0] + input[1] + target[1] + ...`` up to, but not
            including, the first input pad.

        Raises:
            IndexError: if ``output_ids`` is shorter than the unpadded prefix of ``input_ids``.
            KeyError: if an id is missing from the corresponding vocabulary.
        """
        pieces: List[str] = []
        for i, input_id in enumerate(input_ids):
            if input_id == self.input_pad_id:
                break
            pieces.append(self.input_id_to_symbol[input_id])
            pieces.append(self.target_id_to_symbol[output_ids[i]])
        # Join once instead of repeated `+=`, which is quadratic in the output length.
        return "".join(pieces)

    def __str__(self):
        return type(self).__name__
class BasicArabicEncoder(TextEncoder):
    """Encoder over a fixed Arabic character inventory, with the keys of
    ``ALL_POSSIBLE_HARAQAT`` as the target vocabulary."""

    def __init__(
        self,
        cleaner_fn="basic_cleaners",
        reverse_input: bool = False,
        reverse_target: bool = False,
        sp_model_path=None,
    ):
        # Input-side character inventory; note that it includes the space character.
        arabic_alphabet = "بض.غىهظخة؟:طس،؛فندؤلوئآك-يذاصشحزءمأجإ ترقعث"
        haraqat_symbols: List[str] = [*ALL_POSSIBLE_HARAQAT.keys()]
        super().__init__(
            [*arabic_alphabet],
            haraqat_symbols,
            cleaner_fn=cleaner_fn,
            reverse_input=reverse_input,
            reverse_target=reverse_target,
            sp_model_path=sp_model_path,
        )
class ArabicEncoderWithStartSymbol(TextEncoder):
    """Arabic encoder whose target vocabulary is the keys of
    ``ALL_POSSIBLE_HARAQAT`` plus a trailing ``"s"`` start symbol.

    After construction, ``start_symbol_id`` holds the target id of ``"s"``.
    """

    def __init__(
        self,
        cleaner_fn="basic_cleaners",
        reverse_input: bool = False,
        reverse_target: bool = False,
        sp_model_path=None,
    ):
        # Input-side character inventory; note that it includes the space character.
        arabic_alphabet = "بض.غىهظخة؟:طس،؛فندؤلوئآك-يذاصشحزءمأجإ ترقعث"
        start_symbol = "s"
        # Identical to the basic encoder's targets, except the start symbol goes last.
        haraqat_symbols: List[str] = list(ALL_POSSIBLE_HARAQAT.keys())
        haraqat_symbols.append(start_symbol)
        super().__init__(
            list(arabic_alphabet),
            haraqat_symbols,
            cleaner_fn=cleaner_fn,
            reverse_input=reverse_input,
            reverse_target=reverse_target,
            sp_model_path=sp_model_path,
        )
        self.start_symbol_id = self.target_symbol_to_id[start_symbol]
|