from pathlib import Path

import streamlit as st
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from sinlib import Tokenizer
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from torchvision.transforms import functional as TF


# Value assigned to the tokenizer's ``max_length`` once it has been loaded.
MAX_LENGTH = 32

# Compute device name: CUDA when PyTorch reports it as available, else CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


@st.cache_resource
def load_tokenizer():
    """Create the tokenizer from ``gpt2.json`` and cap its ``max_length``.

    The tokenizer is constructed with ``max_length=1000``, loaded from the
    ``gpt2.json`` file, and then has ``max_length`` overwritten with
    ``MAX_LENGTH``.  ``st.cache_resource`` caches the returned object.

    Returns:
        The object returned by ``Tokenizer.load_from_pretrained``.
    """
    loaded = Tokenizer(max_length=1000).load_from_pretrained("gpt2.json")
    loaded.max_length = MAX_LENGTH
    return loaded


# Module-level tokenizer instance used by the rest of the app.
tokenizer = load_tokenizer()


class CRNN(nn.Module):
    """Convolutional-recurrent network producing per-column character scores.

    A stack of convolutions turns a single-channel image into a feature map,
    every feature-map column is flattened into one vector, a two-layer
    bidirectional GRU runs over the columns, and a linear layer maps each
    step to ``num_chars`` scores.

    The GRU input size is fixed at ``512 * 7``, so the feature map must be
    7 rows high (this is what a 128-pixel-high input reduces to).
    """

    def __init__(self, num_chars):
        """Build the layers; ``num_chars`` is the size of the output layer."""
        super().__init__()

        def conv3x3(n_in, n_out):
            # 3x3 convolution that preserves the spatial size.
            return nn.Conv2d(n_in, n_out, kernel_size=3, stride=1, padding=1)

        # The order of this list must not change: checkpoint keys are derived
        # from each module's position inside the Sequential.
        feature_layers = [
            conv3x3(1, 64),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            conv3x3(64, 128),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            conv3x3(128, 256),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            conv3x3(256, 256),
            nn.ReLU(),
            # Pool along the height only, keeping the width resolution.
            nn.MaxPool2d((2, 1)),
            conv3x3(256, 512),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            conv3x3(512, 512),
            nn.ReLU(),
            nn.MaxPool2d((2, 1)),
            # Unpadded 2x2 convolution: shrinks height and width by one.
            nn.Conv2d(512, 512, kernel_size=2, stride=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
        ]
        self.cnn = nn.Sequential(*feature_layers)

        self.rnn = nn.GRU(
            input_size=512 * 7,
            hidden_size=256,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
        )
        # Both GRU directions are concatenated: 2 * 256 input features.
        self.linear = nn.Linear(2 * 256, num_chars)

    def forward(self, x):
        """Return scores of shape ``(batch, width, num_chars)`` for ``x``."""
        features = self.cnn(x)
        # (batch, channel, height, width) -> (batch, width, channel, height)
        columns = features.permute(0, 3, 1, 2)
        # Merge channel and height so each column becomes one feature vector.
        sequence = columns.flatten(start_dim=2)
        hidden, _ = self.rnn(sequence)
        return self.linear(hidden)


@st.cache_resource
def load_model(selected_model_path):
    """Build a ``CRNN`` and load its weights from ``selected_model_path``.

    The output layer is sized with ``len(tokenizer)``.  The checkpoint is
    mapped onto ``DEVICE`` and the model is moved there as well, so the
    model lives on the same device that inference inputs are sent to.

    Args:
        selected_model_path: Path (or path-like/string) of a ``.pth`` file
            containing a state dict compatible with ``CRNN``.

    Returns:
        The model in evaluation mode, placed on ``DEVICE``.
    """
    model = CRNN(num_chars=len(tokenizer))
    # NOTE: torch.load unpickles the file; only open trusted checkpoints.
    state_dict = torch.load(str(selected_model_path), map_location=DEVICE)
    model.load_state_dict(state_dict)
    model.to(DEVICE)
    model.eval()
    return model


def preprocess_image(image):
    """Convert an image into the batched grayscale tensor fed to the model.

    The image is resized to 128 x 2600 (height x width) with bilinear
    interpolation, converted to a single grayscale channel, turned into a
    tensor and given a leading batch dimension.

    Args:
        image: The input image (the app passes a PIL image).

    Returns:
        A tensor with a batch dimension of 1 in front of the
        single-channel image tensor.
    """
    # Use the InterpolationMode enum; passing the PIL integer constant to
    # torchvision is deprecated.
    resized = TF.resize(
        image, (128, 2600), interpolation=InterpolationMode.BILINEAR
    )

    # Grayscale() emits exactly one channel, so no channel averaging is
    # needed afterwards.
    to_gray_tensor = transforms.Compose([
        transforms.Grayscale(),
        transforms.ToTensor(),
    ])
    tensor = to_gray_tensor(resized)

    # Add the batch dimension expected by the model.
    return tensor.unsqueeze(0)


def inference(model, image):
    """Run the model on ``image`` and return the best class index per step.

    The input is moved to the device the model's parameters live on, so the
    call works whether the model sits on the CPU or on a GPU.

    Args:
        model: Module mapping ``image`` to scores of shape
            ``(batch, steps, classes)``.
        image: Input tensor; the app passes a batch of one image.

    Returns:
        A NumPy array of arg-max class indices with the leading batch
        dimension removed when it has size 1 (shape ``(steps,)``).
    """
    # Follow the model's own device instead of assuming a global one; a
    # parameter-less module leaves the input where it already is.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        image = image.to(first_param.device)

    with torch.no_grad():
        outputs = model(image)

    # log-softmax is monotonic, so the arg-max of the raw scores is the same.
    pred_chars = outputs.argmax(dim=2)

    # Drop only the batch axis so a one-step sequence stays one-dimensional.
    return pred_chars.squeeze(0).cpu().numpy()


# ---------------------------------------------------------------------------
# Streamlit page: pick a checkpoint, upload an image, show the prediction.
# ---------------------------------------------------------------------------
st.title("CRNN Printed Text Recognition")
st.warning(
    "**Note**: This model was trained on images with these settings, "
    "with width ranging from 800 to 2600 pixels and height ranging from 128 to 600 pixels. "
    "For better results, use images within these limitations."
)

# Materialise and sort the checkpoint list so the select box gets a stable,
# re-iterable sequence rather than a one-shot generator.
fp = sorted(Path(".").glob("*.pth"))
if not fp:
    # Without a checkpoint there is nothing to load; stop instead of letting
    # the loader fail on a missing file.
    st.error("No model checkpoint (`*.pth`) was found in the current directory.")
    st.stop()

selected_model_path = st.selectbox(label="Select Model...", options=fp)
model = load_model(selected_model_path)
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    image = Image.open(uploaded_file)
    st.image(image, caption='Uploaded Image', use_column_width=True)

    # Colour each dimension green when inside the range stated in the
    # warning above, red otherwise.
    w, h = image.size
    w_color = "green" if 800 <= w <= 2600 else "red"
    h_color = "green" if 128 <= h <= 600 else "red"
    with st.expander("Click See Image Details"):
        st.write(f"Width = :{w_color}[{w}];", f"Height = :{h_color}[{h}]")

    if st.button('Predict'):
        processed_image = preprocess_image(image)
        predicted_sequence = inference(model, processed_image)

        decoded_text = tokenizer.decode(predicted_sequence, skip_special_tokens=True)
        st.write("Predicted Text:")
        st.write(decoded_text)

st.markdown("---")
st.write("Note: This app uses a pre-trained CRNN model for printed Sinhala text recognition.")