import tensorflow as tf
from matplotlib import pyplot as plt
from skimage.transform import rescale, resize
import pickle as pkl
import numpy as np
import os
import cv2
from PIL import Image, ImageFont, ImageDraw
import gradio as gr

import CALTextModel
#### Training setup parameters ####
lambda_val = 1e-4
gamma_val = 1

################################### Utility functions ###################################

def load_dict_picklefile(dictFile):
    # Load the character vocabulary (character -> integer index); also
    # return the index of the space character.
    with open(dictFile, 'rb') as fp:
        lexicon = pkl.load(fp)
    return lexicon, lexicon[' ']
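
# Illustrative usage (the structure is an assumption drawn from the lookup
# above: the pickle maps characters to integer indices and contains ' '):
#   worddicts, space_idx = load_dict_picklefile('vocabulary.pkl')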
def preprocess_img(img):
    # Convert to grayscale, pad narrow images with white on the left,
    # resize to the fixed 100x800 input size, and min-max normalise.
    if len(img.shape) > 2:
        img = cv2.cvtColor(img.astype('float32'), cv2.COLOR_BGR2GRAY)
    height = img.shape[0]
    width = img.shape[1]
    if width < 300:
        # Copy the narrow image into the right half of a white canvas twice
        # as wide (i.e. pad on the left).
        result = np.ones([height, width * 2]) * 255
        result[0:height, width:width * 2] = img
        img = result
    img = cv2.resize(img, dsize=(800, 100), interpolation=cv2.INTER_AREA)
    img = (img - img.min()) / (img.max() - img.min())
    # The mask marks every pixel as valid (all ones), since the resized image
    # fills the whole canvas.
    xx_pad = np.ones((100, 800), dtype='float32')
    # Add a leading batch dimension: (1, 100, 800).
    xx_pad = xx_pad[None, :, :]
    img = img[None, :, :]
    return img, xx_pad
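
# Illustrative contract of preprocess_img (shapes inferred from the code
# above; the file name is hypothetical):
#   img, mask = preprocess_img(cv2.imread('line.png'))
#   img.shape  -> (1, 100, 800), float values scaled to [0, 1]
#   mask.shape -> (1, 100, 800), all ones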
# Build the reverse vocabulary (integer index -> character); note that the
# loop stops before mapping the final dictionary entry.
worddicts, _ = load_dict_picklefile('vocabulary.pkl')
worddicts_r = [None] * len(worddicts)
i = 1
for kk, vv in worddicts.items():
    if i < len(worddicts):
        worddicts_r[vv] = kk
    else:
        break
    i = i + 1

# Create an instance of the model and restore the trained weights.
CALTEXT = CALTextModel.CALTEXT_Model(training=False)
CALTEXT.load_weights('final_caltextModel/cp-0037.ckpt')

test_loss = tf.keras.metrics.Mean(name='test_loss')
@tf.function(experimental_relax_shapes=True)
def execute_model(xx, xx_mask, CALTEXT):
    # One forward pass of the encoder: returns the annotation feature map and
    # the initial decoder hidden state. training=False matters for layers
    # such as Dropout that behave differently at inference time.
    anno = CALTEXT(xx, xx_mask, training=False)
    hidden_state_0 = CALTEXT.get_hidden_state_0(anno)
    return anno, hidden_state_0
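
# Note: tf.function compiles the wrapped call into a TensorFlow graph;
# experimental_relax_shapes relaxes the input signature to cut down on
# retracing (newer TensorFlow releases expose this as reduce_retracing).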

def test_error(images, x_mask):
    # Run beam-search decoding on each image (the demo always passes a single
    # image) and decode the best hypothesis into a character string.
    for img_ind in range(len(images)):
        xx = images[img_ind][tf.newaxis, ...]
        xx_mask = x_mask[img_ind][tf.newaxis, ...]
        anno, hidden_state_0 = execute_model(xx, xx_mask, CALTEXT)
        sample, score, hypalpha = CALTextModel.get_sample(anno, hidden_state_0, 10, 130, False, False, CALTEXT)
        # Length-normalise the hypothesis scores and keep the best (lowest) one.
        score = score / np.array([len(s) for s in sample])
        ss = sample[score.argmin()]

    #### Output string: read the tokens of the best hypothesis in reverse,
    #### skipping the final token (presumably end-of-sequence), and map each
    #### index back to its character.
    outstr = u''
    frames = []
    gif_image = None  # attention-GIF generation below is currently disabled
    # font = ImageFont.truetype("Jameel Noori Nastaleeq.ttf", 60)
    ind = 0
    while ind < len(ss) - 1:
        k = (len(ss) - 2) - ind
        outstr = outstr + worddicts_r[int(ss[k])]
        '''textimg = Image.new('RGB', (1400, 100), (255, 255, 255))
        drawtext = ImageDraw.Draw(textimg)
        drawtext.text((20, 20), outstr, (0, 0, 0), font=font)
        fig, axes = plt.subplots(2, 1)
        axes[0].imshow(textimg)
        axes[0].axis('off')
        axes[1].axis('off')
        axes[1].imshow(xx[0, :, :], cmap='gray')
        visualization = resize(hypalpha[k], (100, 800), anti_aliasing=True)
        axes[1].imshow(255 - (255 * visualization), alpha=0.2)
        plt.axis('off')
        plt.savefig('/content/gdrive/My Drive/CALText_Demo/res.png')
        frames.append(Image.fromarray(cv2.imread('/content/gdrive/My Drive/CALText_Demo/res.png'), 'RGB'))'''
        ind = ind + 1
    '''frame_one = frames[0]
    frame_one.save("/content/gdrive/My Drive/CALText_Demo/vis.gif", format="GIF", append_images=frames, save_all=True, duration=300, loop=0)
    gif_image = "/content/gdrive/My Drive/CALText_Demo/vis.gif"'''
    return outstr, gif_image

'''examples = [
['/content/gdrive/My Drive/CALText_Demo/sample_test_images/59-11.png'],
['/content/gdrive/My Drive/CALText_Demo/sample_test_images/59-21.png'],
['/content/gdrive/My Drive/CALText_Demo/sample_test_images/59-32.png'],
['/content/gdrive/My Drive/CALText_Demo/sample_test_images/59-37.png'],
['/content/gdrive/My Drive/CALText_Demo/sample_test_images/91-47.png'],
['/content/gdrive/My Drive/CALText_Demo/sample_test_images/91-49.png'],
]'''

def recognize_text(input_image):
    # Gradio entry point: preprocess the uploaded image and decode it.
    x, x_mask = preprocess_img(input_image)
    output_str, gif_image = test_error(x, x_mask)
    return output_str, gif_image
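
# Quick local sanity check (illustrative only; the sample path is an
# assumption based on the commented-out examples list above):
#   text, _ = recognize_text(cv2.imread('sample_test_images/59-11.png'))
#   print(text)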

title = "CALText Demo"
description = "<p style='text-align: center'>Gradio demo for the CALText model architecture <a href='https://github.com/nazar-khan/CALText'>[GitHub Code]</a>, trained on the <a href='http://faculty.pucit.edu.pk/nazarkhan/work/urdu_ohtr/pucit_ohul_dataset.html'>PUCIT-OHUL</a> dataset. To use it, upload an image or click one of the examples to load it.</p>"
article = "<p style='text-align: center'></p>"
css = "#0 {object-fit: contain;} #1 {object-fit: contain;}"
inputs = gr.Image(label="Input Image")
demo = gr.Interface(fn=recognize_text,
                    inputs=inputs,
                    outputs=[gr.Textbox(label="Output"), gr.Image(label="Demonstration of attention")],
                    title=title,
                    description=description,
                    article=article,
                    css=css,  # apply the object-fit rules defined above
                    allow_flagging='never')
demo.launch()