Tayaba171 commited on
Commit
a172c2e
1 Parent(s): 2d5345c

Update CALTextModel.py

Browse files
Files changed (1) hide show
  1. CALTextModel.py +59 -0
CALTextModel.py CHANGED
@@ -606,3 +606,62 @@ def get_sample(ctx0, h_0, k , maxlen, stochastic, training, model):
606
 
607
  return sample, sample_score,sample_att
608
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
606
 
607
  return sample, sample_score,sample_att
608
 
609
+
610
+ #######################Predict
611
+ @tf.function(experimental_relax_shapes=True)
612
+ def execute_model(xx,xx_mask,CALTEXT):
613
+
614
+ anno = CALTEXT(xx,xx_mask, training=False)
615
+ hidden_state_0 = CALTEXT.get_hidden_state_0(anno)
616
+ return anno,hidden_state_0
617
+
618
+
619
+ def test_error( images, x_mask):
620
+ # training=False is only needed if there are layers with different
621
+ # behavior during training versus inference (e.g. Dropout).
622
+ batch_loss=0
623
+ img_ind=1
624
+ for img_ind in range(len(images)):
625
+ xx = images[img_ind][tf.newaxis, ... ]
626
+ xx_mask = x_mask[img_ind][tf.newaxis, ... ]
627
+ anno,hidden_state_0=execute_model(xx,xx_mask,CALTEXT)
628
+
629
+ sample, score,hypalpha=CALTextModel.get_sample(anno, hidden_state_0,10, 130, False, False, CALTEXT)
630
+
631
+
632
+ score = score / np.array([len(s) for s in sample])
633
+ ss = sample[score.argmin()]
634
+ img_ind=img_ind+1
635
+
636
+ ind=0
637
+ num=int(len(ss)/2)
638
+
639
+ #### output string
640
+ ind=0
641
+ outstr=u''
642
+ frames = []
643
+ font = ImageFont.truetype("Jameel Noori Nastaleeq.ttf",60)
644
+ while (ind<len(ss)-1):
645
+ k=(len(ss)-2)-ind
646
+ outstr=outstr+worddicts_r[int(ss[k])]
647
+ textimg = Image.new('RGB', (1400,100),(255,255,255))
648
+ drawtext = ImageDraw.Draw(textimg)
649
+ drawtext.text((20, 20), outstr ,(0,0,0),font=font)
650
+ fig,axes=plt.subplots(2,1)
651
+ axes[0].imshow(textimg)
652
+ axes[0].axis('off')
653
+ axes[1].axis('off')
654
+ axes[1].imshow(xx[0,:,:],cmap='gray')
655
+ visualization=resize(hypalpha[k], (100,800),anti_aliasing=True)
656
+ axes[1].imshow(255-(255 * visualization), alpha=0.2)
657
+ plt.axis('off')
658
+
659
+ plt.savefig('res.png')
660
+ frames.append(Image.fromarray(cv2.imread('res.png'), 'RGB'))
661
+ ind=ind+1
662
+ frame_one = frames[0]
663
+ frame_one.save("vis.gif", format="GIF", append_images=frames,save_all=True, duration=300, loop=0)
664
+ gif_image="vis.gif"
665
+ return outstr,gif_image
666
+
667
+