amst / utils.py
Gizachew's picture
Create utils.py
ae0e027 verified
# utils.py
import torch
from torchvision import transforms
from PIL import Image
# Character-to-Index Mapping (should be the same as used during training)
amharic_chars = list(' ሀሁሂሃሄህሆለሉሊላሌልሎሐሑሒሓሔሕሖመሙሚማሜምሞሰሱሲሳስሶረሩሪራሬርሮሠሡሢሣሤሥሦሸሹሺሻሼሽሾቀቁቂቃቄቅቆበቡቢባቤብቦተቱቲታቴትቶቸቹቺቻቼችቾኀኃነኑኒናኔንኖኘኙኚኛኜኝኞአኡኢኣኤእኦከኩኪካኬክኮኸኹኺኻኼኽኾወዉዊዋዌውዎዐዑዒዓዔዕዖዘዙዚዛዜዝዞዠዡዢዣዤዥዦየዩዪያዬይዮደዱዲዳዴድዶጀጁጂጃጄጅጆገጉጊጋጌግጎጠጡጢጣጤጥጦጨጩጪጫጬጭጮጰጱጲጳጴጵጶጸጹጺጻጼጽጾፀፁፂፃፄፅፆፈፉፊፋፌፍፎፐፑፒፓፔፕፖቨቩቪቫቬቭቮ0123456789፥፣()-ሏሟሷሯሿቧቆቈቋቷቿኗኟዟዧዷጇጧጯጿፏኳኋኧቯጐጕጓ።')
char_to_idx = {char: idx + 1 for idx, char in enumerate(amharic_chars)} # Start indexing from 1
char_to_idx['<UNK>'] = len(amharic_chars) + 1 # Unknown characters
idx_to_char = {idx: char for char, idx in char_to_idx.items()}
idx_to_char[0] = '<blank>' # CTC blank token
def preprocess_image(image: Image.Image) -> torch.Tensor:
"""
Preprocess the input image for the model.
"""
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
])
return transform(image)
def decode_predictions(preds: torch.Tensor) -> str:
"""
Decode the model's predictions using Best Path Decoding.
"""
preds = torch.argmax(preds, dim=2).transpose(0, 1) # [batch_size, H*W]
decoded_texts = []
for pred in preds:
pred = pred.cpu().numpy()
decoded = []
previous = 0 # Assuming blank index is 0
for p in pred:
if p != previous and p != 0:
decoded.append(idx_to_char.get(p, '<UNK>'))
previous = p
recognized_text = ''.join(decoded)
decoded_texts.append(recognized_text)
return decoded_texts