# utils.py

from typing import List

import torch
from torchvision import transforms
from PIL import Image

# Character-to-index mapping (must match the vocabulary used during training)
amharic_chars = list(' ሀሁሂሃሄህሆለሉሊላሌልሎሐሑሒሓሔሕሖመሙሚማሜምሞሰሱሲሳስሶረሩሪራሬርሮሠሡሢሣሤሥሦሸሹሺሻሼሽሾቀቁቂቃቄቅቆበቡቢባቤብቦተቱቲታቴትቶቸቹቺቻቼችቾኀኃነኑኒናኔንኖኘኙኚኛኜኝኞአኡኢኣኤእኦከኩኪካኬክኮኸኹኺኻኼኽኾወዉዊዋዌውዎዐዑዒዓዔዕዖዘዙዚዛዜዝዞዠዡዢዣዤዥዦየዩዪያዬይዮደዱዲዳዴድዶጀጁጂጃጄጅጆገጉጊጋጌግጎጠጡጢጣጤጥጦጨጩጪጫጬጭጮጰጱጲጳጴጵጶጸጹጺጻጼጽጾፀፁፂፃፄፅፆፈፉፊፋፌፍፎፐፑፒፓፔፕፖቨቩቪቫቬቭቮ0123456789፥፣()-ሏሟሷሯሿቧቆቈቋቷቿኗኟዟዧዷጇጧጯጿፏኳኋኧቯጐጕጓ።')

char_to_idx = {char: idx + 1 for idx, char in enumerate(amharic_chars)}  # Index 0 is reserved for the CTC blank
char_to_idx['<UNK>'] = len(amharic_chars) + 1  # Fallback index for out-of-vocabulary characters
idx_to_char = {idx: char for char, idx in char_to_idx.items()}
idx_to_char[0] = '<blank>'  # CTC blank token
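
# Illustrative helper (a sketch, not part of the original training code):
# `encode_text` shows how char_to_idx would turn a label string into CTC
# target indices, falling back to <UNK> for out-of-vocabulary characters.
def encode_text(text: str) -> List[int]:
    """Map each character of `text` to its index; unknown chars become <UNK>."""
    return [char_to_idx.get(ch, char_to_idx['<UNK>']) for ch in text]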

def preprocess_image(image: Image.Image) -> torch.Tensor:
    """
    Preprocess the input image for the model: resize to 224x224, convert to
    a tensor, and normalize with the ImageNet mean/std used during training.
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    return transform(image)
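
# Example usage (sketch; `model` and the image path are assumptions, not
# defined in this module):
#   image = Image.open('sample.png').convert('RGB')
#   batch = preprocess_image(image).unsqueeze(0)  # add batch dim: [1, 3, 224, 224]
#   preds = model(batch)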

def decode_predictions(preds: torch.Tensor) -> List[str]:
    """
    Decode the model's predictions with best-path (greedy) CTC decoding:
    take the argmax at each time step, collapse consecutive repeats, and
    drop the blank token (index 0).

    Expects `preds` of shape [seq_len, batch_size, num_classes] and returns
    one decoded string per batch element.
    """
    preds = torch.argmax(preds, dim=2).transpose(0, 1)  # [batch_size, seq_len]
    decoded_texts = []

    for pred in preds:
        pred = pred.cpu().numpy()
        decoded = []
        previous = 0  # CTC blank index
        for p in pred:
            if p != previous and p != 0:
                decoded.append(idx_to_char.get(int(p), '<UNK>'))
            previous = p
        decoded_texts.append(''.join(decoded))

    return decoded_texts
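
# Runnable self-test sketch (illustrative, not in the original module):
# feed random logits of the expected shape through the decoder to check
# the [seq_len, batch, num_classes] convention and the blank-at-0 setup.
if __name__ == '__main__':
    num_classes = len(amharic_chars) + 2  # blank (0) + characters + <UNK>
    dummy = torch.randn(50, 2, num_classes)  # [seq_len, batch, num_classes]
    print(decode_predictions(dummy))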