"""Streamlit app for printed Sinhala text recognition with a CRNN model.

The user picks a ``*.pth`` checkpoint from the working directory, uploads an
image, and the app shows the text decoded from the per-step argmax tokens.
"""

import streamlit as st
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.transforms import functional as TF
from PIL import Image
from sinlib import Tokenizer
from pathlib import Path

MAX_LENGTH = 32
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Fixed size every input is resized to before inference. The height must stay
# at 128: the CNN reduces it to 7 rows, which is what the GRU input size
# (512 * 7) expects.
IMAGE_HEIGHT = 128
IMAGE_WIDTH = 2600

# Image size ranges shown to the user as the recommended input limits.
MIN_WIDTH, MAX_WIDTH = 800, 2600
MIN_HEIGHT, MAX_HEIGHT = 128, 600


# Load tokenizer
@st.cache_resource
def load_tokenizer():
    """Load the tokenizer from ``gpt2.json`` (cached across Streamlit reruns)."""
    tokenizer = Tokenizer(max_length=1000).load_from_pretrained("gpt2.json")
    # NOTE(review): max_length is overridden after loading; how sinlib uses it
    # during decoding is not visible here — confirm.
    tokenizer.max_length = MAX_LENGTH
    return tokenizer


tokenizer = load_tokenizer()


class CRNN(nn.Module):
    """Convolutional-recurrent network producing per-column token scores.

    Input is a ``(batch, 1, H, W)`` tensor. The CNN halves the height four
    times and trims one row with the final 2x2 convolution, so H=128 becomes
    7 rows (128 -> 64 -> 32 -> 16 -> 8 -> 7). The width is halved twice and
    trimmed by one, e.g. W=2600 becomes 649 time steps.

    Output is ``(batch, T, num_chars)`` unnormalized scores per time step.
    """

    def __init__(self, num_chars):
        super().__init__()
        # Do not reorder these layers: state_dict keys are derived from the
        # Sequential indices, so existing checkpoints depend on this exact order.
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            # (2, 1) pooling halves the height only, preserving horizontal resolution.
            nn.MaxPool2d(kernel_size=(2, 1)),
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 1)),
            # Unpadded 2x2 conv trims one row and one column.
            nn.Conv2d(512, 512, kernel_size=2, stride=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
        )
        # RNN layers: each time step is one CNN column flattened to 512 * 7 features.
        self.rnn = nn.GRU(512 * 7, 256, bidirectional=True, batch_first=True, num_layers=2)
        # Bidirectional GRU concatenates 2 x 256 hidden units = 512 features per step.
        self.linear = nn.Linear(512, num_chars)

    def forward(self, x):
        conv = self.cnn(x)
        batch, channel, height, width = conv.size()
        # Treat each feature-map column as a time step: (B, C, H, W) -> (B, W, C * H).
        conv = conv.permute(0, 3, 1, 2)
        conv = conv.contiguous().view(batch, width, channel * height)
        output, _ = self.rnn(conv)
        return self.linear(output)


@st.cache_resource
def load_model(selected_model_path):
    """Build a CRNN, load weights from *selected_model_path* and prepare it for inference.

    The weights are loaded onto the CPU first and the model is then moved to
    ``DEVICE``, so a GPU is used when available regardless of where the
    checkpoint was saved.
    """
    model = CRNN(num_chars=len(tokenizer))
    # torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(selected_model_path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    model.to(DEVICE)
    model.eval()
    return model


def preprocess_image(image):
    """Convert a PIL image into a ``(1, 1, IMAGE_HEIGHT, IMAGE_WIDTH)`` float tensor.

    The image is stretched to the fixed network input size (aspect ratio is not
    preserved), converted to single-channel grayscale, and scaled to [0, 1].
    """
    image = TF.resize(
        image,
        (IMAGE_HEIGHT, IMAGE_WIDTH),
        interpolation=transforms.InterpolationMode.BILINEAR,
    )
    # Grayscale always yields exactly one channel, matching the CNN's first conv.
    transform = transforms.Compose([
        transforms.Grayscale(),
        transforms.ToTensor(),
    ])
    tensor = transform(image)
    return tensor.unsqueeze(0)  # add the batch dimension


def inference(model, image):
    """Return the highest-scoring token id at every time step as a 1-D numpy array.

    The input is moved to whatever device the model lives on, so this works
    for both CPU and GPU models. Expects a batch of size 1.
    """
    device = next(model.parameters()).device
    with torch.no_grad():
        logits = model(image.to(device))
    # argmax over raw logits equals argmax over log_softmax (a per-step shift).
    pred_ids = logits.argmax(dim=2)
    # Drop only the batch dimension so the sequence axis is always kept.
    return pred_ids.squeeze(0).cpu().numpy()


st.title("CRNN Printed Text Recognition")
st.warning(
    "**Note**: This model was trained on images with these settings, "
    f"with width ranging from {MIN_WIDTH} to {MAX_WIDTH} pixels and height "
    f"ranging from {MIN_HEIGHT} to {MAX_HEIGHT} pixels. "
    "For better results, use images within these limitations."
)

# Sorted so the dropdown order is stable across filesystems and reruns.
model_paths = sorted(Path(".").glob("*.pth"))
if not model_paths:
    # Without this guard selectbox returns None and torch.load crashes.
    st.error("No model checkpoints (*.pth) were found in the working directory.")
    st.stop()

selected_model_path = st.selectbox(label="Select Model...", options=model_paths)
model = load_model(selected_model_path)

uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
    image = Image.open(uploaded_file)
    st.image(image, caption='Uploaded Image', use_column_width=True)

    w, h = image.size
    w_color = 'green' if MIN_WIDTH <= w <= MAX_WIDTH else 'red'
    h_color = 'green' if MIN_HEIGHT <= h <= MAX_HEIGHT else 'red'
    with st.expander("Click See Image Details"):
        st.write(f"Width = :{w_color}[{w}];", f"Height = :{h_color}[{h}]")

    if st.button('Predict'):
        processed_image = preprocess_image(image)
        predicted_sequence = inference(model, processed_image)
        # NOTE(review): the raw per-step argmax is decoded directly. If the model
        # was trained with CTC loss, repeated tokens and blanks would need to be
        # collapsed first — confirm against the training code.
        decoded_text = tokenizer.decode(predicted_sequence, skip_special_tokens=True)
        st.write("Predicted Text:")
        st.write(decoded_text)

st.markdown("---")
st.write("Note: This app uses a pre-trained CRNN model for printed Sinhala text recognition.")