MohammedHamdy32's picture
test
7b7f574
"""
Paper: "UTRNet: High-Resolution Urdu Text Recognition In Printed Documents" presented at ICDAR 2023
Authors: Abdur Rahman, Arjun Ghosh, Chetan Arora
GitHub Repository: https://github.com/abdur75648/UTRNet-High-Resolution-Urdu-Text-Recognition
Project Website: https://abdur75648.github.io/UTRNet/
Copyright (c) 2023-present: This work is licensed under the Creative Commons Attribution-NonCommercial
4.0 International License (http://creativecommons.org/licenses/by-nc/4.0/)
"""
import pytz
import torch
import numpy as np
from datetime import datetime
import matplotlib.pyplot as plt
from torch.autograd import Variable
import os,random,shutil
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
class CTCLabelConverter(object):
    """Translate text labels to CTC index tensors and back again.

    Index 0 is the '[CTCblank]' token; the characters passed to the
    constructor are numbered from 1 in the order they are given.
    """

    def __init__(self, character):
        # character (str): set of the possible characters.
        symbols = list(character)
        # Numbering starts at 1 because 0 is reserved for the blank token
        # required by CTCLoss.
        self.dict = {symbol: position for position, symbol in enumerate(symbols, start=1)}
        # Reverse lookup table; slot 0 is a dummy '[CTCblank]' entry.
        self.character = ['[CTCblank]'] + symbols

    def encode(self, text, batch_max_length=25):
        """Convert text labels into a zero-padded index tensor.

        input:
            text: text labels of each image. [batch_size]
            batch_max_length: number of columns of the padded tensor. 25 by default
        output:
            text: text index for CTCLoss. [batch_size, batch_max_length]
            length: length of each text. [batch_size]

        A character missing from the table raises KeyError.
        """
        lengths = [len(label) for label in text]
        # Padding uses index 0, i.e. the blank token.
        padded = torch.LongTensor(len(text), batch_max_length).fill_(0)
        for row, label in enumerate(text):
            indices = [self.dict[symbol] for symbol in list(label)]
            padded[row][:len(indices)] = torch.LongTensor(indices)
        return (padded, torch.IntTensor(lengths))

    def decode(self, text_index, length):
        """Convert index rows back into text labels.

        Only the first length[k] entries of row k are read. Blanks (index 0)
        are dropped, and an entry equal to its predecessor is dropped too.
        """
        decoded = []
        for row, seq_len in enumerate(length):
            seq = text_index[row, :]
            kept = [
                self.character[seq[pos]]
                for pos in range(seq_len)
                # keep non-blank entries that do not repeat the previous one
                if seq[pos] != 0 and (pos == 0 or seq[pos - 1] != seq[pos])
            ]
            decoded.append(''.join(kept))
        return decoded
class CTCLabelConverterForBaiduWarpctc(object):
    """Translate text labels to flat (concatenated) index tensors and back, for baidu warpctc."""

    def __init__(self, character):
        # character (str): set of the possible characters.
        symbols = list(character)
        # Numbering starts at 1 because 0 is reserved for the blank token
        # required by CTCLoss.
        self.dict = {symbol: position for position, symbol in enumerate(symbols, start=1)}
        # Reverse lookup table; slot 0 is a dummy '[CTCblank]' entry.
        self.character = ['[CTCblank]'] + symbols

    def encode(self, text, batch_max_length=25):
        """Convert text labels into one concatenated index tensor.

        input:
            text: text labels of each image. [batch_size]
            batch_max_length: accepted but not used by this converter.
        output:
            text: concatenated text index for CTCLoss.
                [sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
            length: length of each text. [batch_size]
        """
        lengths = [len(label) for label in text]
        joined = ''.join(text)
        flat_indices = list(map(self.dict.__getitem__, joined))
        return (torch.IntTensor(flat_indices), torch.IntTensor(lengths))

    def decode(self, text_index, length):
        """Convert a flat index sequence back into text labels.

        Consecutive slices of text_index, sized by the entries of length, are
        decoded one after another. Blanks (index 0) and entries equal to
        their predecessor within the slice are dropped.
        """
        decoded = []
        offset = 0
        for seq_len in length:
            seq = text_index[offset:offset + seq_len]
            kept = [
                self.character[seq[pos]]
                for pos in range(seq_len)
                # keep non-blank entries that do not repeat the previous one
                if seq[pos] != 0 and (pos == 0 or seq[pos - 1] != seq[pos])
            ]
            decoded.append(''.join(kept))
            offset += seq_len
        return decoded
class AttnLabelConverter(object):
    """Translate text labels to attention-decoder index tensors and back."""

    def __init__(self, character):
        # character (str): set of the possible characters.
        # '[GO]' (index 0) starts the attention decoder; '[s]' (index 1) ends a sentence.
        special_tokens = ['[GO]', '[s]']  # ['[s]','[UNK]','[PAD]','[GO]']
        self.character = special_tokens + list(character)
        self.dict = {token: position for position, token in enumerate(self.character)}

    def encode(self, text, batch_max_length=25):
        """Convert text labels into decoder input indices.

        input:
            text: text labels of each image. [batch_size]
            batch_max_length: max length of text label in the batch. 25 by default
        output:
            text : the input of attention decoder. [batch_size x (max_length+2)] +1 for [GO] token and +1 for [s] token.
                text[:, 0] is [GO] token and text is padded with [GO] token after [s] token.
            length : the length of output of attention decoder, which count [s] token also. [3, 7, ....] [batch_size]

        A label containing a character that is not in the table is skipped:
        its row stays all zeros ([GO]) while its entry in length is still
        len(label) + 1.
        """
        lengths = [len(label) + 1 for label in text]  # +1 for [s] at end of sentence.
        # The width is derived from the argument rather than max(lengths);
        # the latter is not allowed for the multi-gpu setting.
        # One extra column for [s] and one for the leading [GO].
        width = batch_max_length + 2
        padded = torch.LongTensor(len(text), width).fill_(0)
        for row, label in enumerate(text):
            tokens = list(label) + ['[s]']
            try:
                indices = [self.dict[token] for token in tokens]
            except KeyError:
                # Unknown character: leave this row as [GO] padding.
                continue
            # Column 0 is left as the [GO] token.
            padded[row][1:1 + len(indices)] = torch.LongTensor(indices)
        return (padded, torch.IntTensor(lengths))

    def decode(self, text_index, length):
        """Convert index rows back into text.

        Every entry of a row is mapped to its token (special tokens
        included); length only determines how many rows are decoded.
        """
        return [
            ''.join([self.character[i] for i in text_index[row, :]])
            for row, _ in enumerate(length)
        ]
def imshow(img, title, batch_size=1):
    """Display a normalized image tensor with matplotlib.

    The image is rescaled per channel as ``img * std + mean`` using the fixed
    three-value constants below, moved from channel-first to channel-last
    layout, and shown in a new figure of size (batch_size * 4, 4) inches
    with the axes hidden.

    Args:
        img: object exposing ``.numpy()``; the result must broadcast against
            the (3, 1, 1) constants and have three axes.
        title: title drawn above the image.
        batch_size: scales the figure width.
    """
    # Per-channel constants, shaped to broadcast over the two trailing axes.
    channel_std = np.asarray([0.229, 0.224, 0.225]).reshape(3, 1, 1)
    channel_mean = np.asarray([0.485, 0.456, 0.406]).reshape(3, 1, 1)
    restored = img.numpy() * channel_std + channel_mean
    # Move the leading (channel) axis to the end for plt.imshow.
    channel_last = restored.transpose(1, 2, 0)

    plt.figure(figsize=(batch_size * 4, 4))
    plt.axis("off")
    plt.imshow(channel_last)
    plt.title(title)
    plt.show()
class Averager(object):
    """Running mean over the elements of torch.Tensor values, used for loss averaging."""

    def __init__(self):
        self.reset()

    def add(self, v):
        """Accumulate the element count and element sum of tensor v."""
        self.n_count += v.data.numel()
        self.sum += v.data.sum()

    def reset(self):
        """Discard everything accumulated so far."""
        self.n_count = 0
        self.sum = 0

    def val(self):
        """Return sum / count, or 0 when nothing has been added."""
        if self.n_count == 0:
            return 0
        return self.sum / float(self.n_count)
class Logger(object):
    """Minimal training logger that echoes messages to stdout and appends them to a file."""

    def __init__(self, path):
        """Create (or truncate) the log file at path and write a timestamped header.

        The timestamp is taken in the 'Asia/Kolkata' timezone.
        """
        self.logFile = path
        kolkata = pytz.timezone('Asia/Kolkata')
        stamp = str(datetime.now(kolkata).strftime("%Y-%m-%d_%H-%M-%S"))
        with open(self.logFile, "w", encoding="utf-8") as handle:
            handle.write("Logging at @ " + str(stamp) + "\n")

    def log(self, *input):
        """Print the space-joined str() of all arguments and append it as one line to the log file.

        Leading and trailing whitespace of the joined message is stripped.
        """
        message = " ".join(str(part) for part in input).strip()
        print(message)
        with open(self.logFile, "a", encoding="utf-8") as handle:
            handle.write(str(message) + "\n")
def allign_two_strings(x:str, y:str, pxy:int=1, pgap:int=1):
    """Align two strings with minimum total penalty using dynamic programming.

    A mismatched pair of characters costs ``pxy`` and a character aligned
    against a gap costs ``pgap``. Gaps are rendered as '_' in the output.

    Args:
        x: first string.
        y: second string.
        pxy: penalty for aligning two different characters.
        pgap: penalty for aligning a character with a gap.

    Returns:
        (x_seq, y_seq): the two aligned strings, of equal length.

    The alignment is assembled directly during the traceback, so input
    strings that themselves contain '_' are returned in full (locating the
    start of the alignment by searching for a column of two '_' would cut
    such results short).

    Source: https://www.geeksforgeeks.org/sequence-alignment-problem/
    """
    m = len(x)
    n = len(y)

    # dp[i][j] is the minimum penalty for aligning x[:i] with y[:j].
    dp = np.zeros([m + 1, n + 1], dtype=int)
    # Aligning a prefix with the empty string costs one gap per character.
    dp[0:(m + 1), 0] = [i * pgap for i in range(m + 1)]
    dp[0, 0:(n + 1)] = [j * pgap for j in range(n + 1)]

    for i in range(1, m + 1):
        for j in range(1, n + 1):
            if x[i - 1] == y[j - 1]:
                dp[i][j] = dp[i - 1][j - 1]
            else:
                dp[i][j] = min(dp[i - 1][j - 1] + pxy,
                               dp[i - 1][j] + pgap,
                               dp[i][j - 1] + pgap)

    # Walk back from the bottom-right cell, collecting the alignment in
    # reverse order.
    x_rev = []
    y_rev = []
    i = m
    j = n
    while i > 0 and j > 0:
        if x[i - 1] == y[j - 1] or (dp[i - 1][j - 1] + pxy) == dp[i][j]:
            # Match or substitution: consume one character from each string.
            x_rev.append(x[i - 1])
            y_rev.append(y[j - 1])
            i -= 1
            j -= 1
        elif (dp[i - 1][j] + pgap) == dp[i][j]:
            # Gap in y.
            x_rev.append(x[i - 1])
            y_rev.append('_')
            i -= 1
        else:
            # Gap in x (the only remaining way dp[i][j] can have been reached).
            x_rev.append('_')
            y_rev.append(y[j - 1])
            j -= 1

    # At most one of the strings still has an unconsumed prefix; pair it
    # with gaps.
    while i > 0:
        x_rev.append(x[i - 1])
        y_rev.append('_')
        i -= 1
    while j > 0:
        x_rev.append('_')
        y_rev.append(y[j - 1])
        j -= 1

    x_seq = ''.join(reversed(x_rev))
    y_seq = ''.join(reversed(y_rev))
    return x_seq, y_seq
# Function to count the number of trainable parameters in a model in "Millions"
def count_parameters(model, precision=2):
    """Return the number of trainable parameters of model, in millions.

    Only parameters with ``requires_grad`` set are counted; the result is
    rounded to ``precision`` decimal places.
    """
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    return round(trainable / 10.**6, precision)
'''
# Code for counting the number of FLOPs in the CNN backbone during inference
Source - https://github.com/fdbtrs/ElasticFace/blob/main/utils/countFLOPS.py
'''
def count_model_flops(model, in_channels=1, input_res=(32, 400), multiply_adds=True):
    """Estimate the FLOPs of one inference pass of ``model`` on a random input.

    Forward hooks are attached to every leaf module of a supported type
    (Conv2d/ConvTranspose2d, Linear, BatchNorm1d/2d, ReLU/PReLU,
    MaxPool2d/AvgPool2d, Dropout, Upsample, Sigmoid). A single random tensor
    of shape (1, in_channels, input_res[1], input_res[0]) is pushed through
    the model in eval mode and the per-layer estimates are summed. Leaf
    modules of any other type are reported with a printed warning and
    contribute nothing.

    Args:
        model: the torch.nn.Module to profile.
        in_channels: channel count of the dummy input.
        input_res: two-element sequence; element 1 becomes the third axis of
            the dummy input and element 0 the fourth.
        multiply_adds: when True a multiply-accumulate counts as 2
            operations, otherwise as 1.

    Returns:
        str: the total formatted in MFLOPS, e.g. "12.3456 MFLOPS".

    The train/eval flag of every module is restored to what it was before
    the call, and the hooks are removed even if the forward pass raises.
    """
    mac_factor = 2 if multiply_adds else 1

    list_conv = []
    def conv_hook(module, inputs, output):
        batch_size = inputs[0].size(0)
        # output[0] is the first sample of the batch: (channels, height, width).
        output_channels, output_height, output_width = output[0].size()
        kernel_ops = module.kernel_size[0] * module.kernel_size[1] * (module.in_channels / module.groups)
        bias_ops = 1 if module.bias is not None else 0
        flops = (kernel_ops * mac_factor + bias_ops) * output_channels * output_height * output_width * batch_size
        list_conv.append(flops)

    list_linear = []
    def linear_hook(module, inputs, output):
        batch_size = inputs[0].size(0) if inputs[0].dim() == 2 else 1
        weight_ops = module.weight.nelement() * mac_factor
        bias_ops = module.bias.nelement() if module.bias is not None else 0
        list_linear.append(batch_size * (weight_ops + bias_ops))

    list_bn = []
    def bn_hook(module, inputs, output):
        list_bn.append(inputs[0].nelement() * 2)

    list_relu = []
    def relu_hook(module, inputs, output):
        list_relu.append(inputs[0].nelement())

    list_pooling = []
    def pooling_hook(module, inputs, output):
        batch_size = inputs[0].size(0)
        output_channels, output_height, output_width = output[0].size()
        # kernel_size is either a tuple (product of its two entries) or an int (squared).
        if isinstance(module.kernel_size, tuple):
            kernel_ops = module.kernel_size[0] * module.kernel_size[1]
        else:
            kernel_ops = module.kernel_size * module.kernel_size
        list_pooling.append(kernel_ops * output_channels * output_height * output_width * batch_size)

    # Dropout, sigmoid and upsample estimates are collected together.
    list_misc = []
    def dropout_hook(module, inputs, output):
        # One comparison and one multiplication per input element.
        list_misc.append(2 * inputs[0].nelement())

    def sigmoid_hook(module, inputs, output):
        # Two multiplications and one addition per input element.
        list_misc.append(3 * inputs[0].nelement())

    def upsample_hook(module, inputs, output):
        batch_size = inputs[0].size(0)
        output_channels, output_height, output_width = output[0].size()
        kernel_ops = module.scale_factor * module.scale_factor
        flops = (kernel_ops * mac_factor) * output_channels * output_height * output_width * batch_size
        list_misc.append(flops)

    handles = []
    def register_hooks(net):
        children = list(net.children())
        if children:
            for child in children:
                register_hooks(child)
            return
        # Leaf module: attach the hook matching its type, if any.
        if isinstance(net, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
            handles.append(net.register_forward_hook(conv_hook))
        elif isinstance(net, torch.nn.Linear):
            handles.append(net.register_forward_hook(linear_hook))
        elif isinstance(net, (torch.nn.BatchNorm2d, torch.nn.BatchNorm1d)):
            handles.append(net.register_forward_hook(bn_hook))
        elif isinstance(net, (torch.nn.ReLU, torch.nn.PReLU)):
            handles.append(net.register_forward_hook(relu_hook))
        elif isinstance(net, (torch.nn.MaxPool2d, torch.nn.AvgPool2d)):
            handles.append(net.register_forward_hook(pooling_hook))
        elif isinstance(net, torch.nn.Dropout):
            handles.append(net.register_forward_hook(dropout_hook))
        elif isinstance(net, torch.nn.Upsample):
            handles.append(net.register_forward_hook(upsample_hook))
        elif isinstance(net, torch.nn.Sigmoid):
            handles.append(net.register_forward_hook(sigmoid_hook))
        else:
            print("warning" + str(net))

    def flops_to_string(flops, units='MFLOPS', precision=4):
        if units == 'GFLOPS':
            return str(round(flops / 10.**9, precision)) + ' ' + units
        elif units == 'MFLOPS':
            return str(round(flops / 10.**6, precision)) + ' ' + units
        elif units == 'KFLOPS':
            return str(round(flops / 10.**3, precision)) + ' ' + units
        else:
            return str(flops) + ' FLOPS'

    # Remember each module's mode so the caller's train/eval state survives.
    saved_modes = [(module, module.training) for module in model.modules()]
    # Run the dummy input on the same device as the model's parameters.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    try:
        register_hooks(model)
        dummy = torch.rand(in_channels, input_res[1], input_res[0]).unsqueeze(0).to(device)
        # Only shapes matter for the estimate, so no autograd graph is needed.
        with torch.no_grad():
            model(dummy)
    finally:
        for handle in handles:
            handle.remove()
        for module, was_training in saved_modes:
            module.training = was_training

    total_flops = (sum(list_conv) + sum(list_linear) + sum(list_bn)
                   + sum(list_relu) + sum(list_pooling) + sum(list_misc))
    return flops_to_string(total_flops)
def draw_feature_map(visual_feature, vis_dir, num_channel=10):
    """Save grayscale images of randomly chosen feature-map channels.

    Args:
        visual_feature (Tensor): indexed as ``[0, channel, :, :]``, i.e. a
            4-D tensor whose second axis is the channel axis. Only the first
            entry along the leading axis is drawn.
        vis_dir (String): Directory to save the feature maps. If it already
            exists it is deleted with all its contents and recreated.
        num_channel (int): number of random draws. Channels are drawn with
            replacement, so a channel drawn twice overwrites its own file and
            fewer than num_channel files may result.

    Each map is flipped horizontally, min-max normalized and written to
    ``channel_<index>.png``.
    """
    if os.path.exists(vis_dir):
        shutil.rmtree(vis_dir)
    os.makedirs(vis_dir)

    for _ in range(num_channel):
        random_channel = random.randint(0, visual_feature.shape[1] - 1)
        feature = visual_feature[0, random_channel, :, :].detach().cpu().numpy()
        # Horizontal flip
        feature = feature[:, ::-1]
        # Normalize to [0, 1]; a constant map is drawn as all zeros instead
        # of dividing by a zero range.
        low = feature.min()
        value_range = feature.max() - low
        feature = feature - low
        if value_range > 0:
            feature = feature / value_range
        # Draw every map on its own figure and close it afterwards, so that
        # images do not pile up on one shared figure and no figure is left
        # open after the call.
        fig, ax = plt.subplots()
        try:
            ax.imshow(feature, cmap='gray', interpolation='nearest')
            ax.axis("off")
            fig.savefig(os.path.join(vis_dir, "channel_{}.png".format(random_channel)), bbox_inches='tight', pad_inches=0)
        finally:
            plt.close(fig)