Tayaba171's picture
Update app.py
3469024
raw
history blame
No virus
4.77 kB
import tensorflow as tf
from matplotlib import pyplot as plt
from skimage.transform import rescale, resize
import pickle as pkl
import numpy as np
import os
import cv2
from PIL import Image,ImageFont, ImageDraw
import CALTextModel
import gradio as gr
# ---- Hyper-parameters carried over from the training setup ----
# NOTE(review): nothing visible in this file reads these two values; confirm
# no other module imports them before removing.
lambda_val = 1e-4
gamma_val = 1

# Hide every CUDA device so TensorFlow does not use a GPU for this demo.
os.environ.update({'CUDA_VISIBLE_DEVICES': '-1'})
################################### Utility functions###################################
def load_dict_picklefile(dictFile):
    """Load a pickled vocabulary dictionary.

    Args:
        dictFile: path to a pickle file containing a dict.

    Returns:
        A ``(lexicon, lexicon[' '])`` tuple: the unpickled dict and the value
        stored under the single-space key.

    Raises:
        KeyError: if the unpickled dict has no ``' '`` entry.
    """
    # SECURITY NOTE: pickle.load can execute arbitrary code while loading;
    # only use this on trusted files such as the vocabulary bundled with the app.
    # The ``with`` block closes the file even when unpickling raises.
    with open(dictFile, 'rb') as fp:
        lexicon = pkl.load(fp)
    return lexicon, lexicon[' ']
def preprocess_img(img):
    """Convert an input image into the fixed-size model input and its mask.

    Args:
        img: image as a NumPy array. A 2-D array is used as-is (grayscale);
            an array with more than two dimensions is converted to grayscale
            with ``cv2.COLOR_BGR2GRAY``.

    Returns:
        ``(img, xx_pad)``. ``img`` has shape ``(1, 100, 800)`` with values
        min-max scaled to [0, 1] (all zeros for a constant image).
        ``xx_pad`` is an all-ones float32 mask of shape ``(1, 100, 800)``.
    """
    if len(img.shape) > 2:
        # NOTE(review): Gradio supplies RGB arrays but this uses the BGR
        # weights (red/blue weights swapped). Confirm this matches training
        # preprocessing before changing it.
        img = cv2.cvtColor(img.astype('float32'), cv2.COLOR_BGR2GRAY)
    height, width = img.shape[0], img.shape[1]
    if width < 300:
        # Narrow images are doubled in width. The white (255) padding goes on
        # the LEFT and the original image on the right half, presumably
        # because the script is written right-to-left (TODO confirm).
        padded = np.full((height, width * 2), 255.0)
        padded[:, width:] = img
        img = padded
    # dsize is (width, height), so the result has shape (100, 800).
    img = cv2.resize(img, dsize=(800, 100), interpolation=cv2.INTER_AREA)
    # Min-max normalise to [0, 1]. A constant image (e.g. a blank upload) has
    # zero range; dividing by it would fill the input with NaNs.
    lo = img.min()
    value_range = img.max() - lo
    if value_range > 0:
        img = (img - lo) / value_range
    else:
        img = np.zeros_like(img, dtype='float32')
    # Every pixel of the fixed-size input is valid, so the mask is all ones.
    xx_pad = np.ones((1, 100, 800), dtype='float32')
    img = img[None, :, :]
    return img, xx_pad
# Load the vocabulary (token -> integer id) and build the reverse lookup table
# (id -> token) used to turn decoded ids back into characters.
worddicts, _ = load_dict_picklefile('vocabulary.pkl')
worddicts_r = [None] * len(worddicts)
# Every entry except the dict's LAST one (in insertion order) is placed in
# the reverse table; that final entry is deliberately skipped.
# NOTE(review): presumably a special token whose id is out of range. Confirm.
for token, token_id in list(worddicts.items())[:-1]:
    worddicts_r[token_id] = token
# Create the CALText model instance in inference mode (training=False) and
# restore its weights from the bundled checkpoint. test_error() reads CALTEXT
# as a module-level global.
CALTEXT = CALTextModel.CALTEXT_Model(training=False)
CALTEXT.load_weights('final_caltextModel/cp-0037.ckpt')
# Keras running-mean metric named 'test_loss'. Nothing visible in this file
# updates or reads it (NOTE(review): likely left over from evaluation code).
test_loss = tf.keras.metrics.Mean(name='test_loss')
@tf.function(experimental_relax_shapes=True)
def execute_model(xx, xx_mask, CALTEXT):
    """Run the CALText model on one batch and derive the initial hidden state.

    ``experimental_relax_shapes=True`` lets the traced graph be reused across
    inputs of differing shapes instead of retracing for each one.

    Args:
        xx: input image batch, presumably ``(batch, 100, 800)`` as produced by
            ``preprocess_img`` (TODO confirm).
        xx_mask: mask with the same layout as ``xx``.
        CALTEXT: the model instance to run.

    Returns:
        ``(anno, hidden_state_0)``: the model's output for ``xx`` and the value
        returned by ``CALTEXT.get_hidden_state_0`` for it (presumably the
        decoder's initial hidden state).
    """
    annotations = CALTEXT(xx, xx_mask, training=False)
    initial_state = CALTEXT.get_hidden_state_0(annotations)
    return annotations, initial_state
def _decode_best_sequence(images, x_mask):
    """Decode every image and return the LAST image with its best hypothesis.

    Returns ``(xx, ss, hypalpha)``:
        - ``xx``: the last image as a batch of one.
        - ``ss``: the token sequence with the lowest length-normalised score.
        - ``hypalpha``: the attention maps returned by ``CALTextModel.get_sample``.

    Raises:
        ValueError: if ``images`` is empty.
    """
    if len(images) == 0:
        raise ValueError('test_error() needs at least one image')
    for image, mask in zip(images, x_mask):
        xx = image[tf.newaxis, ...]
        xx_mask = mask[tf.newaxis, ...]
        anno, hidden_state_0 = execute_model(xx, xx_mask, CALTEXT)
        # NOTE(review): 10 and 130 are passed straight to get_sample;
        # presumably beam width and maximum decode length. Confirm.
        sample, score, hypalpha = CALTextModel.get_sample(
            anno, hidden_state_0, 10, 130, False, False, CALTEXT)
        # Normalise each hypothesis' score by its length; keep the lowest.
        score = score / np.array([len(s) for s in sample])
        ss = sample[score.argmin()]
    # Later images overwrite earlier ones: only the final result is returned.
    return xx, ss, hypalpha


def _render_attention_frame(textimg, xx, attention):
    """Render ``textimg`` above the input image overlaid with ``attention``.

    Returns the rendered figure as an RGB ``PIL.Image``.
    """
    fig, axes = plt.subplots(2, 1)
    try:
        axes[0].imshow(textimg)
        axes[0].axis('off')
        axes[1].axis('off')
        axes[1].imshow(xx[0, :, :], cmap='gray')
        visualization = resize(attention, (100, 800), anti_aliasing=True)
        axes[1].imshow(255 - (255 * visualization), alpha=0.2)
        fig.savefig('res.png')
    finally:
        # One figure is created per decoded token; close it so figures do not
        # accumulate (and leak memory) across frames and requests.
        plt.close(fig)
    # Read back with PIL so the channels really are RGB. cv2.imread returns
    # BGR, which would swap red and blue when labelled 'RGB'.
    with Image.open('res.png') as frame:
        return frame.convert('RGB')


def test_error(images, x_mask):
    """Recognise the text in ``images`` and build an attention GIF.

    Args:
        images: batch of preprocessed images (as returned by
            ``preprocess_img``); only the result for the last one is used.
        x_mask: masks aligned with ``images``.

    Returns:
        ``(outstr, gif_image)``: the recognised string, and the path of the
        written GIF (``"vis.gif"``). ``gif_image`` is ``None`` when the decoded
        sequence yields no characters to animate.

    Raises:
        ValueError: if ``images`` is empty.
    """
    xx, ss, hypalpha = _decode_best_sequence(images, x_mask)
    font = ImageFont.truetype("Jameel Noori Nastaleeq.ttf", 60)
    outstr = ''
    frames = []
    # The final token of ``ss`` is not rendered (presumably end-of-sequence,
    # TODO confirm). The remaining tokens are read from the end backwards,
    # appending one character and producing one frame per step.
    for k in range(len(ss) - 2, -1, -1):
        outstr = outstr + worddicts_r[int(ss[k])]
        textimg = Image.new('RGB', (1400, 100), (255, 255, 255))
        ImageDraw.Draw(textimg).text((20, 20), outstr, (0, 0, 0), font=font)
        frames.append(_render_attention_frame(textimg, xx, hypalpha[k]))
    if not frames:
        # Nothing to animate; without this guard frames[0] raises IndexError.
        return outstr, None
    gif_image = "vis.gif"
    # append_images must hold only the frames AFTER the first one; passing
    # the whole list would show the first frame twice.
    frames[0].save(gif_image, format="GIF", append_images=frames[1:],
                   save_all=True, duration=300, loop=0)
    return outstr, gif_image
# Sample images offered in the Gradio UI. Each example is wrapped in a
# one-element list: the argument list for a single call of the interface fn.
_EXAMPLE_IMAGES = (
    'sample_test_images/59-11.png',
    'sample_test_images/59-21.png',
    'sample_test_images/59-32.png',
    'sample_test_images/59-37.png',
    'sample_test_images/91-47.png',
    'sample_test_images/91-49.png',
)
examples = [[image_path] for image_path in _EXAMPLE_IMAGES]
def recognize_text(input_image):
    """Gradio callback: recognise the text in an uploaded image.

    Args:
        input_image: image array from the Gradio image input, or ``None``
            (e.g. when the form is submitted without an image).

    Returns:
        ``(output_str, gif_path)`` for the "Output" textbox and the
        "Attended Regions" image. Returns ``('', None)`` when no image is
        given, instead of crashing inside ``preprocess_img``.
    """
    if input_image is None:
        return '', None
    x, x_mask = preprocess_img(input_image)
    output_str, gifImage = test_error(x, x_mask)
    return output_str, gifImage
# Text shown in the Gradio interface.
title = "CALText Demo"
description = (
    "<p style='text-align: center'>Gradio demo for the CALText model architecture "
    "<a href='https://github.com/nazar-khan/CALText'>[GitHub Code]</a> trained on the "
    "<a href='http://faculty.pucit.edu.pk/nazarkhan/work/urdu_ohtr/pucit_ohul_dataset.html'>PUCIT-OHUL</a> "
    "dataset. To use it, simply add your image, or click one of the examples to load it. </p>"
)
article = "<p style='text-align: center'></p>"
# NOTE(review): this CSS is never passed to gr.Interface in this file, so it
# currently has no effect.
css = "#0 {object-fit: contain;} #1 {object-fit: contain;}"
# Build and launch the Gradio UI: one image input; the recognised text and the
# attention GIF as outputs.
# gr.Image replaces the deprecated gr.inputs.Image (the gr.inputs namespace
# was removed in Gradio 4). It matches the gr.Textbox / gr.Image components
# already used for the outputs and keeps the same default type ("numpy").
inputs = gr.Image(label="Input Image")
demo = gr.Interface(fn=recognize_text,
                    inputs=inputs,
                    outputs=[gr.Textbox(label="Output"),
                             gr.Image(label="Attended Regions")],
                    examples=examples,
                    title=title,
                    description=description,
                    article=article,
                    allow_flagging='never')
demo.launch()