|
|
""" |
|
|
Paper: "UTRNet: High-Resolution Urdu Text Recognition In Printed Documents" presented at ICDAR 2023 |
|
|
Authors: Abdur Rahman, Arjun Ghosh, Chetan Arora |
|
|
GitHub Repository: https://github.com/abdur75648/UTRNet-High-Resolution-Urdu-Text-Recognition |
|
|
Project Website: https://abdur75648.github.io/UTRNet/ |
|
|
Copyright (c) 2023-present: This work is licensed under the Creative Commons Attribution-NonCommercial |
|
|
4.0 International License (http://creativecommons.org/licenses/by-nc/4.0/) |
|
|
""" |
|
|
|
|
|
import pytz |
|
|
import torch |
|
|
import numpy as np |
|
|
from datetime import datetime |
|
|
import matplotlib.pyplot as plt |
|
|
from torch.autograd import Variable |
|
|
|
|
|
import os
import random
import shutil
|
|
|
|
|
|
|
import warnings |
|
|
warnings.filterwarnings("ignore", category=UserWarning) |
|
|
|
|
|
class CTCLabelConverter(object): |
|
|
""" Convert between text-label and text-index """ |
|
|
|
|
|
    def __init__(self, character):
        # character (str): the set of possible characters; the CTC blank is added separately below.
        dict_character = list(character)

        # Map each character to an index; index 0 is reserved for the CTC blank token.
        self.dict = {}
        for i, char in enumerate(dict_character):
            self.dict[char] = i + 1

        # Dummy '[CTCblank]' token for CTCLoss (index 0).
        self.character = ['[CTCblank]'] + dict_character
|
|
|
|
|
def encode(self, text, batch_max_length=25): |
|
|
"""convert text-label into text-index. |
|
|
input: |
|
|
text: text labels of each image. [batch_size] |
|
|
batch_max_length: max length of text label in the batch. 25 by default |
|
|
|
|
|
output: |
|
|
text: text index for CTCLoss. [batch_size, batch_max_length] |
|
|
length: length of each text. [batch_size] |
|
|
""" |
|
|
length = [len(s) for s in text] |
|
|
|
|
|
|
|
|
        # Pad each label with 0 (the CTC blank index) up to batch_max_length.
        batch_text = torch.LongTensor(len(text), batch_max_length).fill_(0)
        for i, t in enumerate(text):
            indices = [self.dict[char] for char in t]
            batch_text[i][:len(indices)] = torch.LongTensor(indices)
        return (batch_text, torch.IntTensor(length))
|
|
|
|
|
def decode(self, text_index, length): |
|
|
""" convert text-index into text-label. """ |
|
|
texts = [] |
|
|
for index, l in enumerate(length): |
|
|
t = text_index[index, :] |
|
|
|
|
|
char_list = [] |
|
|
            for i in range(l):
                # Keep a character only if it is not the blank (index 0) and not a repeat of the previous index.
                if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])):
                    char_list.append(self.character[t[i]])
|
|
text = ''.join(char_list) |
|
|
|
|
|
texts.append(text) |
|
|
return texts |
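# Illustrative usage sketch for CTCLabelConverter (the character set and labels below are
# hypothetical, not the actual Urdu vocabulary used for training):
#
#   converter = CTCLabelConverter("ابج")
#   batch_text, lengths = converter.encode(["اب", "ج"], batch_max_length=25)
#   # batch_text: LongTensor [2 x 25] of character indices, padded with 0 (the CTC blank)
#   # lengths:    IntTensor [2] holding the original label lengths
#   decoded = converter.decode(batch_text, lengths)
#   # At inference, decode() is applied to the per-timestep argmax of the model output,
#   # collapsing repeated indices and dropping blanks.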
|
|
|
|
|
|
|
|
class CTCLabelConverterForBaiduWarpctc(object): |
|
|
""" Convert between text-label and text-index for baidu warpctc """ |
|
|
|
|
|
def __init__(self, character): |
|
|
|
|
|
dict_character = list(character) |
|
|
|
|
|
self.dict = {} |
|
|
for i, char in enumerate(dict_character): |
|
|
|
|
|
self.dict[char] = i + 1 |
|
|
|
|
|
self.character = ['[CTCblank]'] + dict_character |
|
|
|
|
|
def encode(self, text, batch_max_length=25): |
|
|
"""convert text-label into text-index. |
|
|
input: |
|
|
text: text labels of each image. [batch_size] |
|
|
output: |
|
|
text: concatenated text index for CTCLoss. |
|
|
[sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)] |
|
|
length: length of each text. [batch_size] |
|
|
""" |
|
|
        length = [len(s) for s in text]
        # Concatenate all labels into one flat index sequence, as expected by warpctc.
        flat_text = ''.join(text)
        indices = [self.dict[char] for char in flat_text]

        return (torch.IntTensor(indices), torch.IntTensor(length))
|
|
|
|
|
def decode(self, text_index, length): |
|
|
""" convert text-index into text-label. """ |
|
|
texts = [] |
|
|
index = 0 |
|
|
for l in length: |
|
|
t = text_index[index:index + l] |
|
|
|
|
|
char_list = [] |
|
|
for i in range(l): |
|
|
if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])): |
|
|
char_list.append(self.character[t[i]]) |
|
|
text = ''.join(char_list) |
|
|
|
|
|
texts.append(text) |
|
|
index += l |
|
|
return texts |
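# Illustrative sketch of the warpctc-style encoding (hypothetical character set and labels):
#
#   converter = CTCLabelConverterForBaiduWarpctc("ab")
#   flat, lengths = converter.encode(["ab", "b"])
#   # flat:    IntTensor [3] = indices of "a", "b", "b" concatenated across the batch
#   # lengths: IntTensor [2] = [2, 1], used by warpctc to split the flat tensor per sample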
|
|
|
|
|
|
|
|
class AttnLabelConverter(object): |
|
|
""" Convert between text-label and text-index """ |
|
|
|
|
|
def __init__(self, character): |
|
|
|
|
|
|
|
|
        # [GO] marks the start of the attention decoder's input; [s] marks the end of a sequence.
        list_token = ['[GO]', '[s]']
        list_character = list(character)
        self.character = list_token + list_character

        self.dict = {}
        for i, char in enumerate(self.character):
            self.dict[char] = i
|
|
|
|
|
def encode(self, text, batch_max_length=25): |
|
|
""" convert text-label into text-index. |
|
|
input: |
|
|
text: text labels of each image. [batch_size] |
|
|
batch_max_length: max length of text label in the batch. 25 by default |
|
|
|
|
|
output: |
|
|
text : the input of attention decoder. [batch_size x (max_length+2)] +1 for [GO] token and +1 for [s] token. |
|
|
text[:, 0] is [GO] token and text is padded with [GO] token after [s] token. |
|
|
length : the length of output of attention decoder, which count [s] token also. [3, 7, ....] [batch_size] |
|
|
""" |
|
|
        # Each label gets an [s] end token appended, so its decoded length is len(s) + 1.
        length = [len(s) + 1 for s in text]

        # +1 here and +1 in the tensor size below reserve room for the [s] and [GO] tokens.
        batch_max_length += 1

        # Index 0 is the [GO] token, which also serves as padding after [s].
        batch_text = torch.LongTensor(len(text), batch_max_length + 1).fill_(0)
        for i, t in enumerate(text):
            chars = list(t)
            chars.append('[s]')
            try:
                indices = [self.dict[char] for char in chars]
            except KeyError:
                # Skip labels containing characters outside the vocabulary.
                continue
            batch_text[i][1:1 + len(indices)] = torch.LongTensor(indices)
        return (batch_text, torch.IntTensor(length))
|
|
|
|
|
def decode(self, text_index, length): |
|
|
""" convert text-index into text-label. """ |
|
|
texts = [] |
|
|
for index, l in enumerate(length): |
|
|
text = ''.join([self.character[i] for i in text_index[index, :]]) |
|
|
texts.append(text) |
|
|
return texts |
|
|
|
|
|
|
|
|
def imshow(img, title, batch_size=1):
    # Undo ImageNet-style per-channel normalization (std then mean) before displaying.
    std_correction = np.asarray([0.229, 0.224, 0.225]).reshape(3, 1, 1)
    mean_correction = np.asarray([0.485, 0.456, 0.406]).reshape(3, 1, 1)
    npimg = np.multiply(img.numpy(), std_correction) + mean_correction
|
|
plt.figure(figsize = (batch_size * 4, 4)) |
|
|
plt.axis("off") |
|
|
plt.imshow(np.transpose(npimg, (1, 2, 0))) |
|
|
plt.title(title) |
|
|
plt.show() |
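# Illustrative sketch: visualizing one normalized batch (assumes 3-channel images normalized
# with the ImageNet mean/std used above; the loader and grid call are hypothetical):
#
#   images, labels = next(iter(train_loader))
#   imshow(torchvision.utils.make_grid(images.cpu()), title="training batch", batch_size=images.size(0))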
|
|
|
|
|
|
|
|
class Averager(object): |
|
|
"""Compute average for torch.Tensor, used for loss average.""" |
|
|
|
|
|
def __init__(self): |
|
|
self.reset() |
|
|
|
|
|
def add(self, v): |
|
|
count = v.data.numel() |
|
|
v = v.data.sum() |
|
|
self.n_count += count |
|
|
self.sum += v |
|
|
|
|
|
def reset(self): |
|
|
self.n_count = 0 |
|
|
self.sum = 0 |
|
|
|
|
|
def val(self): |
|
|
res = 0 |
|
|
if self.n_count != 0: |
|
|
res = self.sum / float(self.n_count) |
|
|
return res |
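# Illustrative sketch of Averager in a training loop (variable names are hypothetical):
#
#   loss_avg = Averager()
#   for images, labels in train_loader:
#       loss = criterion(model(images), labels)
#       loss_avg.add(loss)
#   print("mean loss:", loss_avg.val())
#   loss_avg.reset()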
|
|
|
|
|
class Logger(object): |
|
|
"""For logging while training""" |
|
|
def __init__(self, path): |
|
|
self.logFile = path |
|
|
        datetime_now = datetime.now(pytz.timezone('Asia/Kolkata')).strftime("%Y-%m-%d_%H-%M-%S")
        with open(self.logFile, "w", encoding="utf-8") as f:
            f.write("Logging started at " + datetime_now + "\n")
|
|
|
|
|
def log(self,*input): |
|
|
message = "" |
|
|
for x in input: |
|
|
message+=str(x) + " " |
|
|
message = message.strip() |
|
|
print(message) |
|
|
with open(self.logFile,"a",encoding="utf-8") as f: |
|
|
f.write(str(message)+"\n") |
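# Illustrative sketch of Logger usage (the log path is hypothetical):
#
#   logger = Logger("saved_models/experiment/log_train.txt")
#   logger.log("Epoch", 1, "loss", 0.42)   # prints and appends "Epoch 1 loss 0.42"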
|
|
|
|
|
|
|
|
def allign_two_strings(x: str, y: str, pxy: int = 1, pgap: int = 1):
    """
    Global alignment of two strings (e.g. a predicted and a ground-truth label) with
    mismatch penalty pxy and gap penalty pgap; gaps are rendered as '_' in the output.
    Source: https://www.geeksforgeeks.org/sequence-alignment-problem/
    """
|
|
i = 0 |
|
|
j = 0 |
|
|
m = len(x) |
|
|
n = len(y) |
|
|
dp = np.zeros([m+1,n+1], dtype=int) |
|
|
dp[0:(m+1),0] = [ i * pgap for i in range(m+1)] |
|
|
dp[0,0:(n+1)] = [ i * pgap for i in range(n+1)] |
|
|
|
|
|
i = 1 |
|
|
while i <= m: |
|
|
j = 1 |
|
|
while j <= n: |
|
|
if x[i - 1] == y[j - 1]: |
|
|
dp[i][j] = dp[i - 1][j - 1] |
|
|
else: |
|
|
dp[i][j] = min(dp[i - 1][j - 1] + pxy, |
|
|
dp[i - 1][j] + pgap, |
|
|
dp[i][j - 1] + pgap) |
|
|
j += 1 |
|
|
i += 1 |
|
|
|
|
|
l = n + m |
|
|
i = m |
|
|
j = n |
|
|
|
|
|
xpos = l |
|
|
ypos = l |
|
|
|
|
|
xans = np.zeros(l+1, dtype=int) |
|
|
yans = np.zeros(l+1, dtype=int) |
|
|
|
|
|
while not (i == 0 or j == 0): |
|
|
|
|
|
if x[i - 1] == y[j - 1]: |
|
|
xans[xpos] = ord(x[i - 1]) |
|
|
yans[ypos] = ord(y[j - 1]) |
|
|
xpos -= 1 |
|
|
ypos -= 1 |
|
|
i -= 1 |
|
|
j -= 1 |
|
|
elif (dp[i - 1][j - 1] + pxy) == dp[i][j]: |
|
|
|
|
|
xans[xpos] = ord(x[i - 1]) |
|
|
yans[ypos] = ord(y[j - 1]) |
|
|
xpos -= 1 |
|
|
ypos -= 1 |
|
|
i -= 1 |
|
|
j -= 1 |
|
|
|
|
|
elif (dp[i - 1][j] + pgap) == dp[i][j]: |
|
|
xans[xpos] = ord(x[i - 1]) |
|
|
yans[ypos] = ord('_') |
|
|
xpos -= 1 |
|
|
ypos -= 1 |
|
|
i -= 1 |
|
|
|
|
|
elif (dp[i][j - 1] + pgap) == dp[i][j]: |
|
|
xans[xpos] = ord('_') |
|
|
yans[ypos] = ord(y[j - 1]) |
|
|
xpos -= 1 |
|
|
ypos -= 1 |
|
|
j -= 1 |
|
|
|
|
|
|
|
|
while xpos > 0: |
|
|
if i > 0: |
|
|
i -= 1 |
|
|
xans[xpos] = ord(x[i]) |
|
|
xpos -= 1 |
|
|
else: |
|
|
xans[xpos] = ord('_') |
|
|
xpos -= 1 |
|
|
|
|
|
while ypos > 0: |
|
|
if j > 0: |
|
|
j -= 1 |
|
|
yans[ypos] = ord(y[j]) |
|
|
ypos -= 1 |
|
|
else: |
|
|
yans[ypos] = ord('_') |
|
|
ypos -= 1 |
|
|
|
|
|
id = 1 |
|
|
i = l |
|
|
while i >= 1: |
|
|
if (chr(yans[i]) == '_') and chr(xans[i]) == '_': |
|
|
id = i + 1 |
|
|
break |
|
|
|
|
|
i -= 1 |
|
|
|
|
|
i = id |
|
|
x_seq = "" |
|
|
while i <= l: |
|
|
x_seq += chr(xans[i]) |
|
|
i += 1 |
|
|
|
|
|
|
|
|
i = id |
|
|
y_seq = "" |
|
|
while i <= l: |
|
|
y_seq += chr(yans[i]) |
|
|
i += 1 |
|
|
|
|
|
return x_seq, y_seq |
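# Illustrative sketch: aligning a prediction with its ground truth to inspect errors.
# The function returns two equal-length strings in which '_' marks an inserted gap,
# so mismatches line up position by position (variable names below are hypothetical):
#
#   pred_aligned, gt_aligned = allign_two_strings(prediction, ground_truth)
#   errors = sum(p != g for p, g in zip(pred_aligned, gt_aligned))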
|
|
|
|
|
|
|
|
def count_parameters(model, precision=2):
    # Number of trainable parameters, in millions, rounded to `precision` decimal places.
    return round(sum(p.numel() for p in model.parameters() if p.requires_grad) / 10.**6, precision)
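# Illustrative sketch (any torch.nn.Module works; the layer below is a hypothetical stand-in):
#
#   model = torch.nn.Linear(512, 181)
#   print(count_parameters(model), "M trainable parameters")   # ~0.09 M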
|
|
|
|
|
# Code for counting the number of FLOPs in the CNN backbone during inference
# Source: https://github.com/fdbtrs/ElasticFace/blob/main/utils/countFLOPS.py
|
|
|
|
|
def count_model_flops(model,in_channels=1, input_res=[32, 400], multiply_adds=True): |
|
|
list_conv = [] |
|
|
|
|
|
def conv_hook(self, input, output): |
|
|
batch_size, input_channels, input_height, input_width = input[0].size() |
|
|
output_channels, output_height, output_width = output[0].size() |
|
|
|
|
|
kernel_ops = self.kernel_size[0] * self.kernel_size[1] * (self.in_channels / self.groups) |
|
|
bias_ops = 1 if self.bias is not None else 0 |
|
|
|
|
|
params = output_channels * (kernel_ops + bias_ops) |
|
|
flops = (kernel_ops * ( |
|
|
2 if multiply_adds else 1) + bias_ops) * output_channels * output_height * output_width * batch_size |
|
|
list_conv.append(flops) |
|
|
list_linear = [] |
|
|
|
|
|
def linear_hook(self, input, output): |
|
|
batch_size = input[0].size(0) if input[0].dim() == 2 else 1 |
|
|
|
|
|
weight_ops = self.weight.nelement() * (2 if multiply_adds else 1) |
|
|
if self.bias is not None: |
|
|
bias_ops = self.bias.nelement() if self.bias.nelement() else 0 |
|
|
flops = batch_size * (weight_ops + bias_ops) |
|
|
else: |
|
|
flops = batch_size * weight_ops |
|
|
list_linear.append(flops) |
|
|
|
|
|
list_bn = [] |
|
|
|
|
|
def bn_hook(self, input, output): |
|
|
list_bn.append(input[0].nelement() * 2) |
|
|
|
|
|
list_relu = [] |
|
|
|
|
|
def relu_hook(self, input, output): |
|
|
list_relu.append(input[0].nelement()) |
|
|
|
|
|
list_pooling = [] |
|
|
|
|
|
def pooling_hook(self, input, output): |
|
|
batch_size, input_channels, input_height, input_width = input[0].size() |
|
|
output_channels, output_height, output_width = output[0].size() |
|
|
|
|
|
kernel_ops = self.kernel_size[0] * self.kernel_size[1] if isinstance(self.kernel_size, tuple) else self.kernel_size * self.kernel_size |
|
|
bias_ops = 0 |
|
|
params = 0 |
|
|
flops = (kernel_ops + bias_ops) * output_channels * output_height * output_width * batch_size |
|
|
list_pooling.append(flops) |
|
|
|
|
|
def dropout_hook(self, input, output): |
|
|
|
|
|
batch_size, input_channels, input_height, input_width = input[0].size() |
|
|
list_conv.append(2*batch_size*input_channels*input_height*input_width) |
|
|
|
|
|
def sigmoid_hook(self,input,output): |
|
|
|
|
|
batch_size, input_channels, input_height, input_width = input[0].size() |
|
|
list_conv.append(3*batch_size*input_channels*input_height*input_width) |
|
|
|
|
|
def upsample_hook(self, input, output): |
|
|
batch_size, input_channels, input_height, input_width = input[0].size() |
|
|
output_channels, output_height, output_width = output[0].size() |
|
|
|
|
|
kernel_ops = self.scale_factor * self.scale_factor |
|
|
flops = (kernel_ops * ( |
|
|
2 if multiply_adds else 1)) * output_channels * output_height * output_width * batch_size |
|
|
list_conv.append(flops) |
|
|
|
|
|
handles = [] |
|
|
|
|
|
def foo(net): |
|
|
childrens = list(net.children()) |
|
|
if not childrens: |
|
|
if isinstance(net, torch.nn.Conv2d) or isinstance(net, torch.nn.ConvTranspose2d): |
|
|
handles.append(net.register_forward_hook(conv_hook)) |
|
|
elif isinstance(net, torch.nn.Linear): |
|
|
handles.append(net.register_forward_hook(linear_hook)) |
|
|
elif isinstance(net, torch.nn.BatchNorm2d) or isinstance(net, torch.nn.BatchNorm1d): |
|
|
handles.append(net.register_forward_hook(bn_hook)) |
|
|
elif isinstance(net, torch.nn.ReLU) or isinstance(net, torch.nn.PReLU): |
|
|
handles.append(net.register_forward_hook(relu_hook)) |
|
|
elif isinstance(net, torch.nn.MaxPool2d) or isinstance(net, torch.nn.AvgPool2d): |
|
|
handles.append(net.register_forward_hook(pooling_hook)) |
|
|
elif isinstance(net, torch.nn.Dropout): |
|
|
handles.append(net.register_forward_hook(dropout_hook)) |
|
|
elif isinstance(net,torch.nn.Upsample): |
|
|
handles.append(net.register_forward_hook(upsample_hook)) |
|
|
elif isinstance(net,torch.nn.Sigmoid): |
|
|
handles.append(net.register_forward_hook(sigmoid_hook)) |
|
|
else: |
|
|
print("warning" + str(net)) |
|
|
return |
|
|
for c in childrens: |
|
|
foo(c) |
|
|
|
|
|
model.eval() |
|
|
foo(model) |
|
|
    # Run a single dummy input of shape (1, C, H, W) through the model to trigger the hooks.
    input = Variable(torch.rand(in_channels, input_res[1], input_res[0]).unsqueeze(0), requires_grad=True)
|
|
out = model(input) |
|
|
total_flops = (sum(list_conv) + sum(list_linear) + sum(list_bn) + sum(list_relu) + sum(list_pooling)) |
|
|
for h in handles: |
|
|
h.remove() |
|
|
model.train() |
|
|
|
|
|
def flops_to_string(flops, units='MFLOPS', precision=4): |
|
|
if units == 'GFLOPS': |
|
|
return str(round(flops / 10.**9, precision)) + ' ' + units |
|
|
elif units == 'MFLOPS': |
|
|
return str(round(flops / 10.**6, precision)) + ' ' + units |
|
|
elif units == 'KFLOPS': |
|
|
return str(round(flops / 10.**3, precision)) + ' ' + units |
|
|
else: |
|
|
return str(flops) + ' FLOPS' |
|
|
|
|
|
return flops_to_string(total_flops) |
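# Illustrative sketch of the FLOP counter (a toy CNN stands in for the UTRNet backbone;
# the default input resolution of the function is used):
#
#   backbone = torch.nn.Sequential(torch.nn.Conv2d(1, 16, 3, padding=1), torch.nn.ReLU())
#   print(count_model_flops(backbone, in_channels=1, input_res=[32, 400]))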
|
|
|
|
|
|
|
|
def draw_feature_map(visual_feature,vis_dir,num_channel=10): |
|
|
"""draws feature maps for the given visual features |
|
|
Args: |
|
|
visual_feature (Tensor): Shape (C, H, W) |
|
|
vis_dir (String): Directory to save the feature maps |
|
|
""" |
|
|
if os.path.exists(vis_dir): |
|
|
shutil.rmtree(vis_dir) |
|
|
os.makedirs(vis_dir) |
|
|
|
|
|
    for i in range(num_channel):
        # Pick a random channel from the first sample in the batch.
        random_channel = random.randint(0, visual_feature.shape[1] - 1)
        visual_feature_for_visualization = visual_feature[0, random_channel, :, :].detach().cpu().numpy()

        # Flip horizontally, since Urdu text is written right to left.
        visual_feature_for_visualization = visual_feature_for_visualization[:, ::-1]

        # Min-max normalize to [0, 1] for display.
        visual_feature_for_visualization = (visual_feature_for_visualization - visual_feature_for_visualization.min()) / (visual_feature_for_visualization.max() - visual_feature_for_visualization.min())
|
|
|
|
|
plt.imshow(visual_feature_for_visualization, cmap='gray', interpolation='nearest') |
|
|
plt.axis("off") |
|
|
plt.savefig(os.path.join(vis_dir, "channel_{}.png".format(random_channel)), bbox_inches='tight', pad_inches=0) |
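# Illustrative sketch: dumping feature maps from the CNN backbone for one image (the attribute
# access and output directory below are hypothetical and depend on the model definition):
#
#   feats = model.FeatureExtraction(image_tensor)   # Tensor of shape (N, C, H, W)
#   draw_feature_map(feats, vis_dir="vis/feature_maps", num_channel=10)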