Spaces:
Runtime error
Runtime error
Update CALTextModel.py
Browse files- CALTextModel.py +59 -0
CALTextModel.py
CHANGED
@@ -606,3 +606,62 @@ def get_sample(ctx0, h_0, k , maxlen, stochastic, training, model):
|
|
606 |
|
607 |
return sample, sample_score,sample_att
|
608 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
606 |
|
607 |
return sample, sample_score,sample_att
|
608 |
|
609 |
+
|
610 |
+
#######################Predict
|
611 |
+
@tf.function(experimental_relax_shapes=True)
|
612 |
+
def execute_model(xx,xx_mask,CALTEXT):
|
613 |
+
|
614 |
+
anno = CALTEXT(xx,xx_mask, training=False)
|
615 |
+
hidden_state_0 = CALTEXT.get_hidden_state_0(anno)
|
616 |
+
return anno,hidden_state_0
|
617 |
+
|
618 |
+
|
619 |
+
def test_error( images, x_mask):
|
620 |
+
# training=False is only needed if there are layers with different
|
621 |
+
# behavior during training versus inference (e.g. Dropout).
|
622 |
+
batch_loss=0
|
623 |
+
img_ind=1
|
624 |
+
for img_ind in range(len(images)):
|
625 |
+
xx = images[img_ind][tf.newaxis, ... ]
|
626 |
+
xx_mask = x_mask[img_ind][tf.newaxis, ... ]
|
627 |
+
anno,hidden_state_0=execute_model(xx,xx_mask,CALTEXT)
|
628 |
+
|
629 |
+
sample, score,hypalpha=CALTextModel.get_sample(anno, hidden_state_0,10, 130, False, False, CALTEXT)
|
630 |
+
|
631 |
+
|
632 |
+
score = score / np.array([len(s) for s in sample])
|
633 |
+
ss = sample[score.argmin()]
|
634 |
+
img_ind=img_ind+1
|
635 |
+
|
636 |
+
ind=0
|
637 |
+
num=int(len(ss)/2)
|
638 |
+
|
639 |
+
#### output string
|
640 |
+
ind=0
|
641 |
+
outstr=u''
|
642 |
+
frames = []
|
643 |
+
font = ImageFont.truetype("Jameel Noori Nastaleeq.ttf",60)
|
644 |
+
while (ind<len(ss)-1):
|
645 |
+
k=(len(ss)-2)-ind
|
646 |
+
outstr=outstr+worddicts_r[int(ss[k])]
|
647 |
+
textimg = Image.new('RGB', (1400,100),(255,255,255))
|
648 |
+
drawtext = ImageDraw.Draw(textimg)
|
649 |
+
drawtext.text((20, 20), outstr ,(0,0,0),font=font)
|
650 |
+
fig,axes=plt.subplots(2,1)
|
651 |
+
axes[0].imshow(textimg)
|
652 |
+
axes[0].axis('off')
|
653 |
+
axes[1].axis('off')
|
654 |
+
axes[1].imshow(xx[0,:,:],cmap='gray')
|
655 |
+
visualization=resize(hypalpha[k], (100,800),anti_aliasing=True)
|
656 |
+
axes[1].imshow(255-(255 * visualization), alpha=0.2)
|
657 |
+
plt.axis('off')
|
658 |
+
|
659 |
+
plt.savefig('res.png')
|
660 |
+
frames.append(Image.fromarray(cv2.imread('res.png'), 'RGB'))
|
661 |
+
ind=ind+1
|
662 |
+
frame_one = frames[0]
|
663 |
+
frame_one.save("vis.gif", format="GIF", append_images=frames,save_all=True, duration=300, loop=0)
|
664 |
+
gif_image="vis.gif"
|
665 |
+
return outstr,gif_image
|
666 |
+
|
667 |
+
|