Spaces:
Sleeping
Sleeping
from . import text_cleaners | |
from typing import Dict, List, Optional | |
from .constants import ALL_POSSIBLE_HARAQAT | |
import sentencepiece as spm | |
class TextEncoder: | |
pad = "P" | |
def __init__( | |
self, | |
input_chars: List[str], | |
target_charts: List[str], | |
cleaner_fn: Optional[str] = None, | |
reverse_input: bool = False, | |
reverse_target: bool = False, | |
sp_model_path=None, | |
): | |
if cleaner_fn: | |
self.cleaner_fn = getattr(text_cleaners, cleaner_fn) | |
else: | |
self.cleaner_fn = None | |
self.input_symbols: List[str] = [TextEncoder.pad] + input_chars | |
self.target_symbols: List[str] = [TextEncoder.pad] + target_charts | |
if sp_model_path is None: | |
self.input_symbol_to_id: Dict[str, int] = { | |
s: i for i, s in enumerate(self.input_symbols) | |
} | |
self.input_id_to_symbol: Dict[int, str] = { | |
i: s for i, s in enumerate(self.input_symbols) | |
} | |
else: | |
sp_model = spm.SentencePieceProcessor() | |
sp_model.load(sp_model_path + "/sp.model") | |
self.input_symbol_to_id: Dict[str, int] = { | |
s: sp_model.PieceToId(s+'▁') for s in self.input_symbols | |
} | |
self.input_symbol_to_id[" "] = sp_model.PieceToId("|") # encode space | |
self.input_symbol_to_id[TextEncoder.pad] = 0 # encode padding | |
self.input_space_id = sp_model.PieceToId("|") | |
self.input_id_to_symbol: Dict[int, str] = { | |
i: s for s, i in self.input_symbol_to_id.items() | |
} | |
self.target_symbol_to_id: Dict[str, int] = { | |
s: i for i, s in enumerate(self.target_symbols) | |
} | |
self.target_id_to_symbol: Dict[int, str] = { | |
i: s for i, s in enumerate(self.target_symbols) | |
} | |
self.reverse_input = reverse_input | |
self.reverse_target = reverse_target | |
self.input_pad_id = self.input_symbol_to_id[self.pad] | |
self.target_pad_id = self.target_symbol_to_id[self.pad] | |
self.start_symbol_id = None | |
def input_to_sequence(self, text: str) -> List[int]: | |
if self.reverse_input: | |
text = "".join(list(reversed(text))) | |
sequence = [self.input_symbol_to_id[s] for s in text if s not in [self.pad]] | |
return sequence | |
def target_to_sequence(self, text: str) -> List[int]: | |
if self.reverse_target: | |
text = "".join(list(reversed(text))) | |
sequence = [self.target_symbol_to_id[s] for s in text if s not in [self.pad]] | |
return sequence | |
def sequence_to_input(self, sequence: List[int]): | |
return [ | |
self.input_id_to_symbol[symbol] | |
for symbol in sequence | |
if symbol in self.input_id_to_symbol and symbol not in [self.input_pad_id] | |
] | |
def sequence_to_target(self, sequence: List[int]): | |
return [ | |
self.target_id_to_symbol[symbol] | |
for symbol in sequence | |
if symbol in self.target_id_to_symbol and symbol not in [self.target_pad_id] | |
] | |
def clean(self, text): | |
if self.cleaner_fn: | |
return self.cleaner_fn(text) | |
return text | |
def combine_text_and_haraqat(self, input_ids: List[int], output_ids: List[int]): | |
""" | |
Combines the input text with its corresponding haraqat | |
Args: | |
inputs: a list of ids representing the input text | |
outputs: a list of ids representing the output text | |
Returns: | |
text: the text after merging the inputs text representation with the output | |
representation | |
""" | |
output = "" | |
for i, input_id in enumerate(input_ids): | |
if input_id == self.input_pad_id: | |
break | |
output += self.input_id_to_symbol[input_id] | |
# if input_id == self.input_space_id: | |
# continue | |
output += self.target_id_to_symbol[output_ids[i]] | |
return output | |
def __str__(self): | |
return type(self).__name__ | |
class BasicArabicEncoder(TextEncoder): | |
def __init__( | |
self, | |
cleaner_fn="basic_cleaners", | |
reverse_input: bool = False, | |
reverse_target: bool = False, | |
sp_model_path=None, | |
): | |
input_chars: List[str] = list("بض.غىهظخة؟:طس،؛فندؤلوئآك-يذاصشحزءمأجإ ترقعث") | |
target_charts: List[str] = list(ALL_POSSIBLE_HARAQAT.keys()) | |
super().__init__( | |
input_chars, | |
target_charts, | |
cleaner_fn=cleaner_fn, | |
reverse_input=reverse_input, | |
reverse_target=reverse_target, | |
sp_model_path=sp_model_path, | |
) | |
class ArabicEncoderWithStartSymbol(TextEncoder): | |
def __init__( | |
self, | |
cleaner_fn="basic_cleaners", | |
reverse_input: bool = False, | |
reverse_target: bool = False, | |
sp_model_path=None, | |
): | |
input_chars: List[str] = list("بض.غىهظخة؟:طس،؛فندؤلوئآك-يذاصشحزءمأجإ ترقعث") | |
# the only difference from the basic encoder is adding the start symbol | |
target_charts: List[str] = list(ALL_POSSIBLE_HARAQAT.keys()) + ["s"] | |
super().__init__( | |
input_chars, | |
target_charts, | |
cleaner_fn=cleaner_fn, | |
reverse_input=reverse_input, | |
reverse_target=reverse_target, | |
sp_model_path=sp_model_path, | |
) | |
self.start_symbol_id = self.target_symbol_to_id["s"] | |