Spaces:
Sleeping
Sleeping
File size: 1,272 Bytes
7694c84 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 |
import torch
from .symbols import (RNN_BIG_CHARACTERS_MAPPING,
DIACRITICS_LIST, ARABIC_LETTERS_LIST, RNN_REV_CLASSES_MAPPING, RNN_SMALL_CHARACTERS_MAPPING)
def remove_diacritics(data, DIACRITICS_LIST):
    """Return ``data`` with every diacritic character stripped out.

    Args:
        data: The string to clean.
        DIACRITICS_LIST: Iterable of strings; every individual character
            contained in any of its entries is deleted from ``data``.

    Returns:
        A new string with the listed characters removed; all other
        characters keep their original order.
    """
    # Flatten the entries so multi-character items contribute each of
    # their characters to the deletion set.
    unwanted = ''.join(DIACRITICS_LIST)
    # A translate table that maps a code point to None deletes it.
    deletion_table = dict.fromkeys(map(ord, unwanted))
    return data.translate(deletion_table)
# Lookup table encode() uses to turn each input character (and the
# '<SOS>' / '<EOS>' / '<UNK>' markers) into a model input id.
# NOTE(review): presumably this must match the vocabulary the model was
# trained with - confirm before switching tables.
CHARACTERS_MAPPING = RNN_BIG_CHARACTERS_MAPPING
# Disabled alternative: the "small" character table.
# CHARACTERS_MAPPING = RNN_SMALL_CHARACTERS_MAPPING
# Lookup table decode() uses to turn a predicted class index back into
# the string appended after a letter.
REV_CLASSES_MAPPING = RNN_REV_CLASSES_MAPPING
def encode(input_text: str):
    """Convert text into the id sequence looked up in ``CHARACTERS_MAPPING``.

    Characters found in ``DIACRITICS_LIST`` are skipped. Every other
    character is replaced by its ``CHARACTERS_MAPPING`` entry, or by the
    ``'<UNK>'`` entry when the character is not in the mapping. The result
    is wrapped with the ``'<SOS>'`` and ``'<EOS>'`` entries.

    Args:
        input_text: Text to encode; it may already contain diacritics.

    Returns:
        A list: the ``'<SOS>'`` id, one id per non-diacritic character,
        then the ``'<EOS>'`` id.

    Raises:
        KeyError: If ``CHARACTERS_MAPPING`` lacks ``'<SOS>'`` or ``'<EOS>'``,
            or lacks ``'<UNK>'`` when an unmapped character is encountered.
    """
    ids = [CHARACTERS_MAPPING['<SOS>']]
    for char in input_text:
        if char in DIACRITICS_LIST:
            # Existing diacritics are not part of the encoded sequence.
            continue
        if char in CHARACTERS_MAPPING:
            ids.append(CHARACTERS_MAPPING[char])
        else:
            # '<UNK>' is looked up only when needed, so text made of known
            # characters never depends on that key being present.
            ids.append(CHARACTERS_MAPPING['<UNK>'])
    ids.append(CHARACTERS_MAPPING['<EOS>'])
    return ids
def decode(probs, input_text: str):
    """Rebuild ``input_text`` with a predicted mark after each Arabic letter.

    The text is first stripped of the characters in ``DIACRITICS_LIST`` and
    then walked in lockstep with the predictions. Every character is copied
    to the output. For characters in ``ARABIC_LETTERS_LIST`` the arg-max
    class of the matching prediction is mapped through
    ``REV_CLASSES_MAPPING`` and appended, unless the mapped string contains
    ``'<'``.

    Args:
        probs: Indexable container of predictions. Only ``probs[0]`` is
            used, and its first entry is skipped.
            NOTE(review): presumably ``probs[0]`` is the single batch item
            and the skipped entry lines up with the ``'<SOS>'`` position
            added by ``encode`` - confirm against the model output.
        input_text: The original text; diacritics already present are
            removed before pairing characters with predictions.

    Returns:
        The text with predicted marks inserted. ``zip`` stops at the
        shorter sequence, so characters beyond the available predictions
        are not included in the result.
    """
    # Skip the first prediction so the remaining ones pair with characters.
    probs = probs[0][1:]
    stripped_text = remove_diacritics(input_text, DIACRITICS_LIST)

    # Collect pieces and join once instead of growing a string with +=.
    pieces = []
    for char, prediction in zip(stripped_text, probs):
        pieces.append(char)
        if char not in ARABIC_LETTERS_LIST:
            # Only Arabic letters receive a predicted mark.
            continue
        label = REV_CLASSES_MAPPING[torch.argmax(prediction).item()]
        if '<' in label:
            # Labels containing '<' are never emitted into the output.
            continue
        pieces.append(label)
    return ''.join(pieces)