Spaces:
Sleeping
Sleeping
""" | |
Paper: "UTRNet: High-Resolution Urdu Text Recognition In Printed Documents" presented at ICDAR 2023 | |
Authors: Abdur Rahman, Arjun Ghosh, Chetan Arora | |
GitHub Repository: https://github.com/abdur75648/UTRNet-High-Resolution-Urdu-Text-Recognition | |
Project Website: https://abdur75648.github.io/UTRNet/ | |
Copyright (c) 2023-present: This work is licensed under the Creative Commons Attribution-NonCommercial | |
4.0 International License (http://creativecommons.org/licenses/by-nc/4.0/) | |
""" | |
import pytz | |
import torch | |
import numpy as np | |
from datetime import datetime | |
import matplotlib.pyplot as plt | |
from torch.autograd import Variable | |
import os,random,shutil | |
import matplotlib.pyplot as plt | |
import warnings | |
# Module-level side effect: importing this file suppresses every UserWarning process-wide.
warnings.simplefilter("ignore", category=UserWarning)
class CTCLabelConverter(object):
    """Map text labels to CTC index tensors and back.

    Index 0 is reserved for the CTC blank token ('[CTCblank]'); the symbols of
    ``character`` are numbered from 1 in the order given.
    """

    def __init__(self, character):
        # character (str): every symbol the model can emit.
        symbols = list(character)
        # Start numbering at 1 so index 0 stays free for the blank required by CTCLoss.
        self.dict = {symbol: index for index, symbol in enumerate(symbols, start=1)}
        self.character = ['[CTCblank]'] + symbols

    def encode(self, text, batch_max_length=25):
        """Turn a batch of labels into a zero-padded index matrix plus label lengths.

        Args:
            text: sequence of label strings, one per image.
            batch_max_length: width of the returned index matrix (25 by default).

        Returns:
            (batch_text, length): a ``[batch_size, batch_max_length]`` LongTensor of
            symbol indices right-padded with 0, and an IntTensor holding the length
            of each label.

        Raises:
            KeyError: if a label contains a symbol that is not in the charset.
            A label longer than ``batch_max_length`` makes the slice assignment
            fail with a tensor size mismatch.
        """
        lengths = [len(label) for label in text]
        # Padding index 0 does not affect the CTC loss calculation.
        batch_text = torch.LongTensor(len(text), batch_max_length).zero_()
        for row, label in enumerate(text):
            indices = [self.dict[symbol] for symbol in label]
            batch_text[row][:len(indices)] = torch.LongTensor(indices)
        return (batch_text, torch.IntTensor(lengths))

    def decode(self, text_index, length):
        """Greedy CTC decoding: collapse consecutive repeats, then drop blanks.

        Args:
            text_index: 2-D indexable of symbol indices, one row per sample.
            length: number of leading positions of each row to decode.

        Returns:
            list of decoded strings, one per entry of ``length``.
        """
        texts = []
        for row_idx, seq_len in enumerate(length):
            row = text_index[row_idx, :]
            codes = [row[pos] for pos in range(seq_len)]
            # Keep a code only if it is not the blank and differs from its predecessor.
            texts.append(''.join(
                self.character[code]
                for pos, code in enumerate(codes)
                if code != 0 and (pos == 0 or codes[pos - 1] != code)
            ))
        return texts
class CTCLabelConverterForBaiduWarpctc(object):
    """Map text labels to the flat, concatenated CTC targets used with Baidu warp-ctc, and back.

    Index 0 is reserved for the CTC blank token ('[CTCblank]'); the symbols of
    ``character`` are numbered from 1 in the order given.
    """

    def __init__(self, character):
        # character (str): every symbol the model can emit.
        symbols = list(character)
        # Start numbering at 1 so index 0 stays free for the CTC blank.
        self.dict = {symbol: index for index, symbol in enumerate(symbols, start=1)}
        self.character = ['[CTCblank]'] + symbols

    def encode(self, text, batch_max_length=25):
        """Concatenate every label's indices into one flat target tensor.

        Args:
            text: sequence of label strings, one per image.
            batch_max_length: accepted for interface parity with the other
                converters; not used here.

        Returns:
            (targets, length): an IntTensor of shape ``[sum(label lengths)]`` holding
            label 0's indices, then label 1's, and so on; and an IntTensor with the
            length of each label.

        Raises:
            KeyError: if a label contains a symbol that is not in the charset.
        """
        lengths = [len(label) for label in text]
        flat = ''.join(text)
        return (torch.IntTensor([self.dict[symbol] for symbol in flat]), torch.IntTensor(lengths))

    def decode(self, text_index, length):
        """Split a flat index sequence by ``length`` and greedily decode each segment.

        Within each segment consecutive repeats are collapsed and blanks (0) dropped.

        Returns:
            list of decoded strings, one per entry of ``length``.
        """
        texts = []
        offset = 0
        for seg_len in length:
            segment = text_index[offset:offset + seg_len]
            codes = [segment[pos] for pos in range(seg_len)]
            # Keep a code only if it is not the blank and differs from its predecessor.
            texts.append(''.join(
                self.character[code]
                for pos, code in enumerate(codes)
                if code != 0 and (pos == 0 or codes[pos - 1] != code)
            ))
            offset += seg_len
        return texts
class AttnLabelConverter(object):
    """Map text labels to attention-decoder index tensors and back.

    Index 0 is the '[GO]' start token (also used as padding), index 1 is the
    '[s]' end-of-sentence token, and the symbols of ``character`` follow from
    index 2 in the order given.
    """

    def __init__(self, character):
        # character (str): every symbol the model can emit.
        self.character = ['[GO]', '[s]'] + list(character)
        self.dict = {symbol: index for index, symbol in enumerate(self.character)}

    def encode(self, text, batch_max_length=25):
        """Turn a batch of labels into decoder inputs padded with '[GO]'.

        Args:
            text: sequence of label strings, one per image.
            batch_max_length: longest label the batch may contain (25 by default).

        Returns:
            (batch_text, length):
            - ``batch_text`` is a ``[batch_size, batch_max_length + 2]`` LongTensor.
              Column 0 is '[GO]', followed by the label's indices and '[s]', then
              '[GO]' padding.
            - ``length`` is an IntTensor with ``len(label) + 1`` per label, counting
              the '[s]' token.

        A label containing a symbol outside the charset is skipped. Its row stays
        all 0 ('[GO]'), but its entry in ``length`` is still reported.
        """
        # +1 per label for the trailing '[s]'.
        lengths = [len(label) + 1 for label in text]
        # Width comes from batch_max_length, not max(lengths): the original notes a
        # data-dependent width is not allowed in the multi-GPU setting.
        # +1 for '[s]' and +1 for the leading '[GO]' column.
        width = batch_max_length + 2
        batch_text = torch.LongTensor(len(text), width).zero_()
        for row, label in enumerate(text):
            try:
                indices = [self.dict[symbol] for symbol in list(label) + ['[s]']]
            except KeyError:
                # Unknown symbol: leave this row as all-'[GO]' padding.
                continue
            # Column 0 stays '[GO]'; the label and its '[s]' start at column 1.
            batch_text[row][1:1 + len(indices)] = torch.LongTensor(indices)
        return (batch_text, torch.IntTensor(lengths))

    def decode(self, text_index, length):
        """Map every index of each row back to its token string.

        No truncation at '[s]' happens here. ``length`` is iterated only to decide
        how many rows to decode; its values are not used.
        """
        return [
            ''.join(self.character[code] for code in text_index[row, :])
            for row, _ in enumerate(length)
        ]
def imshow(img, title, batch_size=1):
    """Undo ImageNet-style normalization on an image and display it with matplotlib.

    The per-channel statistics used are mean (0.485, 0.456, 0.406) and
    std (0.229, 0.224, 0.225).

    Args:
        img: channel-first ``(C, H, W)`` image, either a torch tensor (any device,
            may require grad) or an array-like. C must broadcast against 3, so it is
            3 or 1. A single channel is expanded into a 3-channel image.
        title: figure title.
        batch_size: scales the figure width (``batch_size * 4`` inches).
    """
    std_correction = np.asarray([0.229, 0.224, 0.225]).reshape(3, 1, 1)
    mean_correction = np.asarray([0.485, 0.456, 0.406]).reshape(3, 1, 1)
    if isinstance(img, torch.Tensor):
        # Tensor.numpy() refuses CUDA tensors and tensors that require grad.
        img = img.detach().cpu().numpy()
    npimg = np.multiply(np.asarray(img), std_correction) + mean_correction
    # matplotlib clips float RGB data to [0, 1] anyway (and logs a warning);
    # clip explicitly to get the same picture without the warning.
    npimg = np.clip(npimg, 0, 1)
    plt.figure(figsize=(batch_size * 4, 4))
    plt.axis("off")
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.title(title)
    plt.show()
class Averager(object):
    """Running mean over every element of the tensors passed to ``add``.

    Used to average loss values across iterations.
    """

    def __init__(self):
        self.reset()

    def add(self, v):
        """Accumulate the element count and element sum of tensor ``v`` (gradient-free view)."""
        values = v.data
        count, total = values.numel(), values.sum()
        self.n_count += count
        self.sum += total

    def reset(self):
        """Forget everything accumulated so far."""
        self.n_count = 0
        self.sum = 0

    def val(self):
        """Return sum / count, or the int 0 if nothing has been added yet."""
        if self.n_count == 0:
            return 0
        return self.sum / float(self.n_count)
class Logger(object):
    """Echo log lines to stdout and append them to a UTF-8 text file (used while training)."""

    def __init__(self, path):
        self.logFile = path
        stamp = datetime.now(pytz.timezone('Asia/Kolkata')).strftime("%Y-%m-%d_%H-%M-%S")
        # Mode "w": creating a Logger truncates any existing file at `path`.
        with open(self.logFile, "w", encoding="utf-8") as handle:
            handle.write("Logging at @ " + stamp + "\n")

    def log(self, *parts):
        """Print ``parts`` joined by single spaces (after ``str()``, ends stripped) and append that line to the file."""
        line = " ".join(str(part) for part in parts).strip()
        print(line)
        # The file is opened and closed on every call, so nothing is held in a
        # Python-side buffer between calls.
        with open(self.logFile, "a", encoding="utf-8") as handle:
            handle.write(line + "\n")
def allign_two_strings(x: str, y: str, pxy: int = 1, pgap: int = 1):
    """Globally align two strings with minimum total cost (edit-distance style DP).

    Adapted from https://www.geeksforgeeks.org/sequence-alignment-problem/

    Args:
        x: first string.
        y: second string.
        pxy: cost of aligning two different characters.
        pgap: cost of aligning a character against a gap.

    Returns:
        (x_seq, y_seq): two equal-length strings in which gaps are written as '_'.
        Deleting the gap markers gives back ``x`` and ``y``. If the inputs themselves
        contain '_', real underscores are indistinguishable from gaps in the output.
    """
    m, n = len(x), len(y)
    # dp[i, j] = minimum cost of aligning the prefixes x[:i] and y[:j].
    dp = np.zeros([m + 1, n + 1], dtype=int)
    dp[:, 0] = [i * pgap for i in range(m + 1)]
    dp[0, :] = [j * pgap for j in range(n + 1)]
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            if x[i - 1] == y[j - 1]:
                dp[i, j] = dp[i - 1, j - 1]
            else:
                dp[i, j] = min(dp[i - 1, j - 1] + pxy,
                               dp[i - 1, j] + pgap,
                               dp[i, j - 1] + pgap)

    # Trace back from (m, n), collecting aligned columns right-to-left.
    # Ties prefer the diagonal move, then a gap in y, then a gap in x.
    x_cols, y_cols = [], []
    i, j = m, n
    while i > 0 and j > 0:
        if x[i - 1] == y[j - 1] or dp[i - 1, j - 1] + pxy == dp[i, j]:
            x_cols.append(x[i - 1])
            y_cols.append(y[j - 1])
            i -= 1
            j -= 1
        elif dp[i - 1, j] + pgap == dp[i, j]:
            x_cols.append(x[i - 1])
            y_cols.append('_')
            i -= 1
        else:
            # With integer penalties dp[i, j] is the min of the three moves, so the
            # remaining case is dp[i, j - 1] + pgap == dp[i, j].
            x_cols.append('_')
            y_cols.append(y[j - 1])
            j -= 1
    # Whatever prefix of one string is left over is aligned against gaps.
    while i > 0:
        x_cols.append(x[i - 1])
        y_cols.append('_')
        i -= 1
    while j > 0:
        x_cols.append('_')
        y_cols.append(y[j - 1])
        j -= 1
    # Columns are tracked directly instead of locating the alignment start by
    # scanning for a '_'/'_' column. That scan truncated alignments of inputs
    # containing '_'.
    return ''.join(reversed(x_cols)), ''.join(reversed(y_cols))
def count_parameters(model, precision=2):
    """Return the number of trainable parameters of ``model`` in millions.

    Only parameters with ``requires_grad`` set are counted. The result is rounded
    to ``precision`` decimal places.
    """
    trainable = sum(param.numel() for param in model.parameters() if param.requires_grad)
    return round(trainable / 10.**6, precision)
def count_model_flops(model, in_channels=1, input_res=(32, 400), multiply_adds=True):
    """Estimate the FLOPs of one forward pass of ``model`` on a random input.

    Adapted from https://github.com/fdbtrs/ElasticFace/blob/main/utils/countFLOPS.py
    (used here to count the FLOPs of the CNN backbone during inference).

    Forward hooks are attached to every leaf module of a supported type. One
    forward pass is then run, without autograd, on a random tensor of shape
    ``(1, in_channels, input_res[1], input_res[0])`` placed on the device of the
    model's first parameter (CPU if it has none). The per-module counts are
    summed. Leaf modules of unsupported types get a printed warning and are not
    counted.

    The model's train/eval mode is restored afterwards, and all hooks are removed
    even if the forward pass raises.

    Args:
        model: ``torch.nn.Module`` to profile.
        in_channels: number of channels of the dummy input.
        input_res: two-element sequence; ``input_res[1]`` is the input height and
            ``input_res[0]`` the width.
            NOTE(review): with the default (32, 400) this yields height 400 and
            width 32 — confirm the intended axis order.
        multiply_adds: if True, a multiply-accumulate counts as two operations.

    Returns:
        str: the total formatted in MFLOPS, e.g. ``"0.0006 MFLOPS"``.
    """
    list_conv = []
    list_linear = []
    list_bn = []
    list_relu = []
    list_pooling = []
    mac_factor = 2 if multiply_adds else 1

    def conv_hook(module, inputs, output):
        batch_size = inputs[0].size(0)
        output_channels, output_height, output_width = output[0].size()
        kernel_ops = module.kernel_size[0] * module.kernel_size[1] * (module.in_channels / module.groups)
        bias_ops = 1 if module.bias is not None else 0
        flops = (kernel_ops * mac_factor + bias_ops) * output_channels * output_height * output_width * batch_size
        list_conv.append(flops)

    def linear_hook(module, inputs, output):
        # NOTE(review): inputs with more than 2 dims are counted as batch size 1.
        batch_size = inputs[0].size(0) if inputs[0].dim() == 2 else 1
        weight_ops = module.weight.nelement() * mac_factor
        bias_ops = module.bias.nelement() if module.bias is not None else 0
        list_linear.append(batch_size * (weight_ops + bias_ops))

    def bn_hook(module, inputs, output):
        list_bn.append(inputs[0].nelement() * 2)

    def relu_hook(module, inputs, output):
        list_relu.append(inputs[0].nelement())

    def pooling_hook(module, inputs, output):
        batch_size = inputs[0].size(0)
        output_channels, output_height, output_width = output[0].size()
        # kernel_size may be a tuple (product of its two elements) or an int (squared).
        ks = module.kernel_size
        kernel_ops = ks[0] * ks[1] if isinstance(ks, tuple) else ks * ks
        list_pooling.append(kernel_ops * output_channels * output_height * output_width * batch_size)

    def dropout_hook(module, inputs, output):
        # Assume one comparison and one multiplication per element. nelement()
        # works for any input rank, not only 4-D.
        list_conv.append(2 * inputs[0].nelement())

    def sigmoid_hook(module, inputs, output):
        # Assume two multiplications and one addition per element.
        list_conv.append(3 * inputs[0].nelement())

    def upsample_hook(module, inputs, output):
        batch_size = inputs[0].size(0)
        output_channels, output_height, output_width = output[0].size()
        kernel_ops = module.scale_factor * module.scale_factor
        list_conv.append(kernel_ops * mac_factor * output_channels * output_height * output_width * batch_size)

    # Checked in order; the first matching type decides the hook.
    hook_table = (
        ((torch.nn.Conv2d, torch.nn.ConvTranspose2d), conv_hook),
        ((torch.nn.Linear,), linear_hook),
        ((torch.nn.BatchNorm2d, torch.nn.BatchNorm1d), bn_hook),
        ((torch.nn.ReLU, torch.nn.PReLU), relu_hook),
        ((torch.nn.MaxPool2d, torch.nn.AvgPool2d), pooling_hook),
        ((torch.nn.Dropout,), dropout_hook),
        ((torch.nn.Upsample,), upsample_hook),
        ((torch.nn.Sigmoid,), sigmoid_hook),
    )
    handles = []

    def register_hooks(net):
        """Recursively attach a hook to every supported leaf module of ``net``."""
        children = list(net.children())
        if children:
            for child in children:
                register_hooks(child)
            return
        for module_types, hook in hook_table:
            if isinstance(net, module_types):
                handles.append(net.register_forward_hook(hook))
                return
        print("warning" + str(net))

    def flops_to_string(flops, units='MFLOPS', precision=4):
        if units == 'GFLOPS':
            return str(round(flops / 10.**9, precision)) + ' ' + units
        elif units == 'MFLOPS':
            return str(round(flops / 10.**6, precision)) + ' ' + units
        elif units == 'KFLOPS':
            return str(round(flops / 10.**3, precision)) + ' ' + units
        else:
            return str(flops) + ' FLOPS'

    was_training = model.training
    try:
        model.eval()
        register_hooks(model)
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
        dummy_input = torch.rand(in_channels, input_res[1], input_res[0], device=device).unsqueeze(0)
        # Only the hooks' side effects are needed; skip building an autograd graph.
        with torch.no_grad():
            model(dummy_input)
    finally:
        for handle in handles:
            handle.remove()
        # Restore the caller's mode instead of unconditionally switching to train().
        model.train(was_training)

    total_flops = sum(list_conv) + sum(list_linear) + sum(list_bn) + sum(list_relu) + sum(list_pooling)
    return flops_to_string(total_flops)
def draw_feature_map(visual_feature,vis_dir,num_channel=10): | |
"""draws feature maps for the given visual features | |
Args: | |
visual_feature (Tensor): Shape (C, H, W) | |
vis_dir (String): Directory to save the feature maps | |
""" | |
if os.path.exists(vis_dir): | |
shutil.rmtree(vis_dir) | |
os.makedirs(vis_dir) | |
# Save visual_feature from num_channel random channels for visualization | |
for i in range(num_channel): | |
random_channel = random.randint(0, visual_feature.shape[1]-1) | |
visual_feature_for_visualization = visual_feature[0, random_channel, :, :].detach().cpu().numpy() | |
# Horizontal flip | |
visual_feature_for_visualization = visual_feature_for_visualization[:,::-1] | |
# Normalize | |
visual_feature_for_visualization = (visual_feature_for_visualization - visual_feature_for_visualization.min()) / (visual_feature_for_visualization.max() - visual_feature_for_visualization.min()) | |
# Draw heatmap | |
plt.imshow(visual_feature_for_visualization, cmap='gray', interpolation='nearest') | |
plt.axis("off") | |
plt.savefig(os.path.join(vis_dir, "channel_{}.png".format(random_channel)), bbox_inches='tight', pad_inches=0) |