# OCR-CRNN / app.py: Streamlit app for printed Sinhala text recognition with a CRNN.
import streamlit as st
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.transforms import functional as TF
from PIL import Image
from sinlib import Tokenizer
from pathlib import Path
MAX_LENGTH = 32
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
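# MAX_LENGTH caps the decoded sequence length; the model and inputs are placed on DEVICE.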
# Load tokenizer
@st.cache_resource
def load_tokenizer():
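    # Load the pretrained tokenizer from gpt2.json, then cap its max_length at MAX_LENGTH.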
tokenizer = Tokenizer(max_length=1000).load_from_pretrained("gpt2.json")
tokenizer.max_length = MAX_LENGTH
return tokenizer
tokenizer = load_tokenizer()
class CRNN(nn.Module):
def __init__(self, num_chars):
super(CRNN, self).__init__()
self.cnn = nn.Sequential(
nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(256),
nn.ReLU(),
nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=(2, 1)),
nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(512),
nn.ReLU(),
nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=(2, 1)),
nn.Conv2d(512, 512, kernel_size=2, stride=1),
nn.BatchNorm2d(512),
nn.ReLU()
)
# RNN layers
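        # The CNN reduces a 128-pixel-high input to a 7-pixel-high feature map
        # (128 -> 64 -> 32 -> 16 -> 8 via the pooling layers, minus 1 from the final 2x2 conv),
        # so each time step fed to the GRU carries 512 * 7 features.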
self.rnn = nn.GRU(512 * 7, 256, bidirectional=True, batch_first=True, num_layers=2)
self.linear = nn.Linear(512, num_chars)
def forward(self, x):
conv = self.cnn(x)
batch, channel, height, width = conv.size()
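        # Treat the width axis as the sequence (time) dimension for the GRU.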
conv = conv.permute(0, 3, 1, 2)
conv = conv.contiguous().view(batch, width, channel * height)
output, _ = self.rnn(conv)
output = self.linear(output)
return output
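# Shape flow for the (1, 1, 128, 2600) input produced by preprocess_image below:
# cnn -> (1, 512, 7, 649), reshape -> (1, 649, 3584), GRU -> (1, 649, 512),
# linear -> (1, 649, num_chars) logits per time step.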
@st.cache_resource
def load_model(selected_model_path):
model = CRNN(num_chars=len(tokenizer))
    # Weights are loaded onto the CPU first, then the model is moved to the active device.
    model.load_state_dict(torch.load(selected_model_path, map_location=torch.device('cpu')))
    model.to(DEVICE)
    model.eval()
return model
def preprocess_image(image):
transform = transforms.Compose([
transforms.Grayscale(),
transforms.ToTensor(),
])
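    # Resize to (height, width) = (128, 2600), the fixed size the CRNN is fed at inference.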
image = TF.resize(image, (128, 2600), interpolation=Image.BILINEAR)
image = transform(image)
if image.shape[0] != 1:
image = image.mean(dim=0, keepdim=True)
image = image.unsqueeze(0)
return image
def inference(model, image):
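    # Greedy decoding: take the most likely token id at every time step.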
with torch.no_grad():
image = image.to(DEVICE)
outputs = model(image)
log_probs = F.log_softmax(outputs, dim=2)
pred_chars = torch.argmax(log_probs, dim=2)
return pred_chars.squeeze().cpu().numpy()
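# Streamlit UI: pick a checkpoint, upload an image, and run recognition.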
st.title("CRNN Printed Text Recognition")
st.warning("**Note**: This model was trained on images with widths ranging \
from 800 to 2600 pixels and heights ranging from 128 to 600 pixels. \
For best results, use images within these limits."
)
# List the available model checkpoints (*.pth files) in the app directory.
fp = sorted(Path(".").glob("*.pth"))
selected_model_path = st.selectbox(label="Select Model...", options=fp)
model = load_model(selected_model_path)
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
image = Image.open(uploaded_file)
st.image(image, caption='Uploaded Image', use_column_width=True)
    w, h = image.size
    # Flag dimensions that fall outside the training range (width 800-2600, height 128-600).
    w_color = h_color = 'green'
    if not 800 <= w <= 2600:
        w_color = "red"
    if not 128 <= h <= 600:
        h_color = "red"
    with st.expander("Click to See Image Details"):
        st.write(f"Width = :{w_color}[{w}];", f"Height = :{h_color}[{h}]")
if st.button('Predict'):
processed_image = preprocess_image(image)
predicted_sequence = inference(model, processed_image)
decoded_text = tokenizer.decode(predicted_sequence, skip_special_tokens=True)
st.write("Predicted Text:")
st.write(decoded_text)
st.markdown("---")
st.write("Note: This app uses a pre-trained CRNN model for printed Sinhala text recognition.")