Ransaka commited on
Commit
7f200cc
1 Parent(s): 4736a15

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +123 -0
app.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from torchvision import transforms
6
+ from torchvision.transforms import functional as TF
7
+ from PIL import Image
8
+ from sinlib import Tokenizer
9
+
10
+ MAX_LENGTH = 32
11
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
12
+
13
+ # Load tokenizer
14
+ @st.cache_resource
15
+ def load_tokenizer():
16
+ tokenizer = Tokenizer(max_length=1000).load_from_pretrained("gpt2.json")
17
+ tokenizer.max_length = MAX_LENGTH
18
+ return tokenizer
19
+
20
+ tokenizer = load_tokenizer()
21
+
22
+ class CRNN(nn.Module):
23
+ def __init__(self, num_chars):
24
+ super(CRNN, self).__init__()
25
+
26
+ self.cnn = nn.Sequential(
27
+ nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1),
28
+ nn.ReLU(),
29
+ nn.MaxPool2d(kernel_size=2, stride=2),
30
+ nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
31
+ nn.ReLU(),
32
+ nn.MaxPool2d(kernel_size=2, stride=2),
33
+ nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
34
+ nn.BatchNorm2d(256),
35
+ nn.ReLU(),
36
+ nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
37
+ nn.ReLU(),
38
+ nn.MaxPool2d(kernel_size=(2, 1)),
39
+ nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
40
+ nn.BatchNorm2d(512),
41
+ nn.ReLU(),
42
+ nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
43
+ nn.ReLU(),
44
+ nn.MaxPool2d(kernel_size=(2, 1)),
45
+ nn.Conv2d(512, 512, kernel_size=2, stride=1),
46
+ nn.BatchNorm2d(512),
47
+ nn.ReLU()
48
+ )
49
+
50
+ # RNN layers
51
+ self.rnn = nn.GRU(512 * 7, 256, bidirectional=True, batch_first=True, num_layers=2)
52
+ self.linear = nn.Linear(512, num_chars)
53
+
54
+ def forward(self, x):
55
+ conv = self.cnn(x)
56
+ batch, channel, height, width = conv.size()
57
+ conv = conv.permute(0, 3, 1, 2)
58
+ conv = conv.contiguous().view(batch, width, channel * height)
59
+ output, _ = self.rnn(conv)
60
+ output = self.linear(output)
61
+ return output
62
+
63
+ @st.cache_resource
64
+ def load_model():
65
+ model = CRNN(num_chars=len(tokenizer))
66
+ model.load_state_dict(torch.load('checkpoint-with-cer-0.18952566385269165.pth', map_location=torch.device('cpu')))
67
+ model.eval()
68
+ return model
69
+
70
+ model = load_model()
71
+
72
+ def preprocess_image(image):
73
+ transform = transforms.Compose([
74
+ transforms.Grayscale(),
75
+ transforms.ToTensor(),
76
+ ])
77
+
78
+ image = TF.resize(image, (128, 2600), interpolation=Image.BILINEAR)
79
+ image = transform(image)
80
+
81
+ if image.shape[0] != 1:
82
+ image = image.mean(dim=0, keepdim=True)
83
+
84
+ image = image.unsqueeze(0)
85
+ return image
86
+
87
+ def inference(model, image):
88
+ with torch.no_grad():
89
+ image = image.to(DEVICE)
90
+ outputs = model(image)
91
+ log_probs = F.log_softmax(outputs, dim=2)
92
+ pred_chars = torch.argmax(log_probs, dim=2)
93
+ return pred_chars.squeeze().cpu().numpy()
94
+
95
+ st.title("CRNN Printed Text Recognition")
96
+ st.warning("**Note**: This model was trained on images with these settings, \
97
+ with width ranging from 800 to 2600 pixels and height ranging from 128 to 600 pixels. \
98
+ For better results, use images within these limitations."
99
+ )
100
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
101
+
102
+ if uploaded_file is not None:
103
+ image = Image.open(uploaded_file)
104
+ st.image(image, caption='Uploaded Image', use_column_width=True)
105
+ w,h = image.size
106
+ w_color = h_color = 'green'
107
+ if not 800 <= w <= 2600:
108
+ w_color = "red"
109
+ if not 128 <= h <= 600:
110
+ h_color = "red"
111
+ with st.expander("Click See Image Details"):
112
+ st.write(f"Width = :{w_color}[{w}];",f"Height = :{h_color}[{h}]")
113
+
114
+ if st.button('Predict'):
115
+ processed_image = preprocess_image(image)
116
+ predicted_sequence = inference(model, processed_image)
117
+
118
+ decoded_text = tokenizer.decode(predicted_sequence, skip_special_tokens=True)
119
+ st.write("Predicted Text:")
120
+ st.write(decoded_text)
121
+
122
+ st.markdown("---")
123
+ st.write("Note: This app uses a pre-trained CRNN model for printed Sinhala text recognition.")