DocTRv2 / vietocr /translate.py
hantech's picture
Upload 28 files
0667c13 verified
raw
history blame contribute delete
No virus
1.95 kB
import torch
import numpy as np
import math
from PIL import Image
from torch.nn.functional import softmax
def translate(img, model, max_seq_length=128, sos_token=1, eos_token=2):
    """Greedily decode a batch of images into token-id sequences.

    Args:
        img: Batch of images as a tensor laid out B x C x H x W. It is fed
            to ``model.cnn`` and its device is used for the decoder inputs.
        model: Object exposing ``eval()``, ``cnn`` and a ``transformer``
            with ``forward_encoder`` and ``forward_decoder``.
        max_seq_length: Decoding budget. The loop runs while the step
            counter is ``<= max_seq_length``, so up to
            ``max_seq_length + 1`` tokens follow the start token.
        sos_token: Token id placed at the start of every sequence.
        eos_token: Token id that marks a sequence as finished.

    Returns:
        A numpy array of token ids with one row per image. Every row
        starts with ``sos_token``. Decoding stops once every row has
        produced ``eos_token`` (or the budget runs out), so rows that
        finished early still carry the decoder's later predictions.
    """
    model.eval()

    batch_size = len(img)
    # Decoder inputs have to sit on the same device as the image batch;
    # a plain CPU LongTensor breaks decoding when the model runs on GPU.
    device = img.device

    with torch.no_grad():
        src = model.cnn(img)
        memory = model.transformer.forward_encoder(src)

        # Time-major history: one list of `batch_size` token ids per step.
        translated_sentence = [[sos_token] * batch_size]
        # Per-row "has emitted EOS" flags, updated incrementally instead of
        # rescanning the whole history on every step. The start token counts
        # too, which matters only when sos_token == eos_token.
        finished = np.full(batch_size, sos_token == eos_token, dtype=bool)

        max_length = 0
        while max_length <= max_seq_length and not finished.all():
            tgt_inp = torch.tensor(
                translated_sentence, dtype=torch.long, device=device
            )
            output, memory = model.transformer.forward_decoder(tgt_inp, memory)

            # Only the best class of the newest step is needed. Softmax is
            # monotonic, so argmax over the raw scores selects the same id
            # and also works for vocabularies with fewer than 5 entries.
            # NOTE(review): indexing assumes output is (batch, steps, vocab).
            next_tokens = output[:, -1, :].argmax(dim=-1).tolist()

            translated_sentence.append(next_tokens)
            finished |= np.asarray(next_tokens) == eos_token
            max_length += 1

    # Transpose from (steps, batch) to (batch, steps).
    return np.asarray(translated_sentence).T
def resize(w, h, expected_height, image_min_width, image_max_width):
    """Compute the target size for an image scaled to a fixed height.

    The width is scaled to keep the ``w / h`` aspect ratio at
    ``expected_height``, truncated to an integer, rounded up to the next
    multiple of 10, then clamped: first raised to ``image_min_width``,
    then capped at ``image_max_width``.

    Returns:
        A ``(new_width, expected_height)`` tuple.
    """
    step = 10
    scaled_width = int(expected_height * float(w) / float(h))
    snapped_width = math.ceil(scaled_width / step) * step
    # Lower bound is applied first, so the upper bound wins on conflict.
    clamped_width = min(max(snapped_width, image_min_width), image_max_width)
    return clamped_width, expected_height
def process_image(image, image_height, image_min_width, image_max_width):
    """Convert a PIL image into a normalised channel-first array.

    The image is converted to RGB, resized with Lanczos resampling to the
    size chosen by ``resize`` for the given height and width limits, moved
    from H x W x C to C x H x W, and divided by 255.

    Returns:
        A numpy array in C x H x W layout.
    """
    rgb = image.convert('RGB')
    src_w, src_h = rgb.size
    target_w, target_h = resize(
        src_w, src_h, image_height, image_min_width, image_max_width
    )
    scaled = rgb.resize((target_w, target_h), Image.Resampling.LANCZOS)
    # PIL yields H x W x C; reorder the axes to channel-first.
    channel_first = np.transpose(np.asarray(scaled), (2, 0, 1))
    return channel_first / 255
def process_input(image, image_height, image_min_width, image_max_width):
    """Turn a single PIL image into a float tensor with a batch axis.

    Runs ``process_image`` on the image, prepends a batch axis of
    length 1 to the resulting array and wraps it in a
    ``torch.FloatTensor``.
    """
    channel_first = process_image(
        image, image_height, image_min_width, image_max_width
    )
    batched = np.expand_dims(channel_first, axis=0)
    return torch.FloatTensor(batched)