import torch
from torchvision import transforms
from PIL import Image

# Character set for Amharic OCR: space, the Ethiopic syllabary, digits,
# and common punctuation. Index 0 is reserved for the CTC blank token.
amharic_chars = list(' ሀሁሂሃሄህሆለሉሊላሌልሎሐሑሒሓሔሕሖመሙሚማሜምሞሰሱሲሳስሶረሩሪራሬርሮሠሡሢሣሤሥሦሸሹሺሻሼሽሾቀቁቂቃቄቅቆበቡቢባቤብቦተቱቲታቴትቶቸቹቺቻቼችቾኀኃነኑኒናኔንኖኘኙኚኛኜኝኞአኡኢኣኤእኦከኩኪካኬክኮኸኹኺኻኼኽኾወዉዊዋዌውዎዐዑዒዓዔዕዖዘዙዚዛዜዝዞዠዡዢዣዤዥዦየዩዪያዬይዮደዱዲዳዴድዶጀጁጂጃጄጅጆገጉጊጋጌግጎጠጡጢጣጤጥጦጨጩጪጫጬጭጮጰጱጲጳጴጵጶጸጹጺጻጼጽጾፀፁፂፃፄፅፆፈፉፊፋፌፍፎፐፑፒፓፔፕፖቨቩቪቫቬቭቮ0123456789፥፣()-ሏሟሷሯሿቧቆቈቋቷቿኗኟዟዧዷጇጧጯጿፏኳኋኧቯጐጕጓ።')

# Map each character to an index starting at 1; index 0 is reserved for
# the CTC blank, and out-of-vocabulary characters map to '<UNK>'.
char_to_idx = {char: idx + 1 for idx, char in enumerate(amharic_chars)}
char_to_idx['<UNK>'] = len(amharic_chars) + 1
idx_to_char = {idx: char for char, idx in char_to_idx.items()}
idx_to_char[0] = '<blank>'
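
# A minimal sketch of the inverse mapping, useful for building CTC training
# targets. `encode_text` is a hypothetical helper, not part of the original
# module; unknown characters fall back to the '<UNK>' index.
def encode_text(text: str) -> list[int]:
    """Map a string to a list of character indices for CTC training."""
    unk = char_to_idx['<UNK>']
    return [char_to_idx.get(char, unk) for char in text]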


def preprocess_image(image: Image.Image) -> torch.Tensor:
    """
    Preprocess an input image for the model: resize to 224x224, convert
    to a tensor, and normalize with ImageNet statistics.
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    return transform(image)
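
# Example usage (an assumption for illustration: an RGB PIL image loaded
# from 'sample.png'; the model consumes batched 4D input, hence the
# unsqueeze):
#
#     image = Image.open('sample.png').convert('RGB')
#     batch = preprocess_image(image).unsqueeze(0)  # shape (1, 3, 224, 224)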


def decode_predictions(preds: torch.Tensor) -> list[str]:
    """
    Decode model outputs with CTC best-path (greedy) decoding: take the
    argmax over classes at each timestep, collapse consecutive repeats,
    and drop blanks (index 0). Expects logits of shape (T, N, C) and
    returns one string per batch item.
    """
    preds = torch.argmax(preds, dim=2).transpose(0, 1)  # (T, N) -> (N, T)
    decoded_texts = []

    for pred in preds:
        pred = pred.cpu().numpy()
        decoded = []
        previous = 0
        for p in pred:
            # Emit a character only when it differs from the previous
            # timestep's label and is not the CTC blank.
            if p != previous and p != 0:
                decoded.append(idx_to_char.get(int(p), '<UNK>'))
            previous = p
        recognized_text = ''.join(decoded)
        decoded_texts.append(recognized_text)

    return decoded_texts
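

if __name__ == '__main__':
    # Smoke test with random logits standing in for a real model's output.
    # The shapes are assumptions for illustration: T=50 timesteps, a batch
    # of 2, and C = len(amharic_chars) + 2 classes (the blank at index 0,
    # the character set, and '<UNK>').
    dummy_logits = torch.randn(50, 2, len(amharic_chars) + 2)
    print(decode_predictions(dummy_logits))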